From 17deadb25e2615a35747a4fc9ea8d553f34c0009 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BE=99=E6=BE=B3?= Date: Mon, 23 Dec 2024 11:31:20 +0800 Subject: [PATCH] first commit --- .gitignore | 7 + README.md | 2 + filter/cluster_filter.py | 82 +++++++ filter/gps_filter.py | 248 +++++++++++++++++++++ filter/time_group_overlap_filter.py | 212 ++++++++++++++++++ odm_preprocess.py | 334 ++++++++++++++++++++++++++++ post_pro/merge_tif.py | 99 +++++++++ tools/odm_pip_time.py | 55 +++++ tools/show_GPS.py | 51 +++++ tools/show_GPS_by_time.py | 138 ++++++++++++ utils/command_runner.py | 71 ++++++ utils/gps_extractor.py | 93 ++++++++ utils/grid_divider.py | 87 ++++++++ utils/logger.py | 35 +++ utils/odm_monitor.py | 139 ++++++++++++ utils/visualizer.py | 121 ++++++++++ 16 files changed, 1774 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 filter/cluster_filter.py create mode 100644 filter/gps_filter.py create mode 100644 filter/time_group_overlap_filter.py create mode 100644 odm_preprocess.py create mode 100644 post_pro/merge_tif.py create mode 100644 tools/odm_pip_time.py create mode 100644 tools/show_GPS.py create mode 100644 tools/show_GPS_by_time.py create mode 100644 utils/command_runner.py create mode 100644 utils/gps_extractor.py create mode 100644 utils/grid_divider.py create mode 100644 utils/logger.py create mode 100644 utils/odm_monitor.py create mode 100644 utils/visualizer.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..775aa57 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +# 忽略所有__pycache__目录 +**/__pycache__/ +*.pyc +*.pyo +*.pyd + +test/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..2ed8f52 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# ODM_Pro +无人机三维重建 diff --git a/filter/cluster_filter.py b/filter/cluster_filter.py new file mode 100644 index 0000000..eadebf1 --- /dev/null +++ b/filter/cluster_filter.py @@ -0,0 +1,82 @@ +from sklearn.cluster import DBSCAN +from sklearn.preprocessing import StandardScaler +import os + + +class GPSCluster: + def __init__(self, gps_points, output_dir: str, eps=0.01, min_samples=5): + """ + 初始化GPS聚类器 + + 参数: + eps: DBSCAN的邻域半径参数 + min_samples: DBSCAN的最小样本数参数 + """ + self.eps = eps + self.min_samples = min_samples + self.dbscan = DBSCAN(eps=eps, min_samples=min_samples) + self.scaler = StandardScaler() + self.gps_points = gps_points + + def fit(self): + """ + 对GPS点进行聚类,只保留最大的类 + + 参数: + gps_points: 包含'lat'和'lon'列的DataFrame + + 返回: + 带有聚类标签的DataFrame,其中最大类标记为1,其他点标记为-1 + """ + # 提取经纬度数据 + X = self.gps_points[["lon", "lat"]].values + + # # 数据标准化 + # X_scaled = self.scaler.fit_transform(X) + + # 执行DBSCAN聚类 + labels = self.dbscan.fit_predict(X) + + # 找出最大类的标签(排除噪声点-1) + unique_labels = [l for l in set(labels) if l != -1] + if unique_labels: # 如果有聚类 + label_counts = [(l, sum(labels == l)) for l in unique_labels] + largest_label = max(label_counts, key=lambda x: x[1])[0] + + # 将最大类标记为1,其他都标记为-1 + new_labels = (labels == largest_label).astype(int) + new_labels[new_labels == 0] = -1 + else: # 如果没有聚类,全部标记为-1 + new_labels = labels + + # 将聚类结果添加到原始数据中 + result_df = self.gps_points.copy() + result_df["cluster"] = new_labels + + return result_df + + def get_cluster_stats(self, clustered_points): + """ + 获取聚类统计信息 + + 参数: + clustered_points: 带有聚类标签的DataFrame + + 返回: + 聚类统计信息的字典 + """ + main_cluster_points = sum(clustered_points["cluster"] == 1) + stats = { + "total_points": len(clustered_points), + "main_cluster_points": main_cluster_points, + "noise_points": sum(clustered_points["cluster"] == -1), + } + + noise_cluster = self.get_noise_cluster(clustered_points) + return stats + + def get_main_cluster(self, clustered_points): + return clustered_points[clustered_points["cluster"] == 1] + + def get_noise_cluster(self, clustered_points): + return clustered_points[clustered_points["cluster"] == -1] diff --git a/filter/gps_filter.py b/filter/gps_filter.py new file mode 100644 index 0000000..95d1e1b --- /dev/null +++ b/filter/gps_filter.py @@ -0,0 +1,248 @@ +import os +import math +from itertools import combinations +import numpy as np +from scipy.spatial import KDTree +import logging +import pandas as pd +from datetime import datetime, timedelta + + +class GPSFilter: + """过滤密集点及孤立点""" + + def __init__(self, output_dir): + self.logger = logging.getLogger('UAV_Preprocess.GPSFilter') + + @staticmethod + def _haversine(lat1, lon1, lat2, lon2): + """计算两点之间的地理距离(单位:米)""" + R = 6371000 # 地球平均半径,单位:米 + phi1, phi2 = math.radians(lat1), math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + a = math.sin(delta_phi / 2) ** 2 + math.cos(phi1) * \ + math.cos(phi2) * math.sin(delta_lambda / 2) ** 2 + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + return R * c + + @staticmethod + def _assign_to_grid(lat, lon, grid_size, min_lat, min_lon): + """根据经纬度和网格大小,将点分配到网格""" + grid_x = int((lat - min_lat) // grid_size) + grid_y = int((lon - min_lon) // grid_size) + return grid_x, grid_y + + def _get_distances(self, points_df, grid_size): + """读取图片 GPS 坐标,计算点对之间的距离并排序""" + # 确定经纬度范围 + min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max() + min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max() + self.logger.info( + f"经纬度范围:纬度[{min_lat:.6f}, {max_lat:.6f}],纬度范围[{max_lat-min_lat:.6f}]," + f"经度[{min_lon:.6f}, {max_lon:.6f}],经度范围[{max_lon-min_lon:.6f}]") + + # 分配到网格 + grid_map = {} + for _, row in points_df.iterrows(): + grid = self._assign_to_grid( + row['lat'], row['lon'], grid_size, min_lat, min_lon) + if grid not in grid_map: + grid_map[grid] = [] + grid_map[grid].append((row['file'], row['lat'], row['lon'])) + + self.logger.info(f"图像点已分配到 {len(grid_map)} 个网格中") + + # 在每个网格中计算两两距离并排序 + sorted_distances = {} + for grid, images in grid_map.items(): + distances = [] + for (img1, lat1, lon1), (img2, lat2, lon2) in combinations(images, 2): + dist = self._haversine(lat1, lon1, lat2, lon2) + distances.append((img1, img2, dist)) + distances.sort(key=lambda x: x[2]) # 按距离升序排序 + sorted_distances[grid] = distances + self.logger.debug(f"网格 {grid} 中计算了 {len(distances)} 个距离对") + + return sorted_distances + + def _group_by_time(self, points_df: pd.DataFrame, time_threshold: timedelta) -> list: + """根据拍摄时间分组图片 + + 如果相邻两张图片的拍摄时间差超过5分钟,则进行切分 + + Args: + points_df: 包含图片信息的DataFrame,必须包含'file'和'date'列 + time_threshold: 时间间隔阈值,默认5分钟 + + Returns: + list: 每个元素为时间组内的点数据 + """ + if 'date' not in points_df.columns: + self.logger.error("数据中缺少date列") + return [points_df] + + # 将date为空的行单独作为一组 + null_date_group = points_df[points_df['date'].isna()] + valid_date_points = points_df[points_df['date'].notna()] + + if not null_date_group.empty: + self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组") + + if valid_date_points.empty: + self.logger.warning("没有有效的时间戳数据") + return [null_date_group] if not null_date_group.empty else [] + + # 按时间排序 + valid_date_points = valid_date_points.sort_values('date') + self.logger.info( + f"有效时间范围: {valid_date_points['date'].min()} 到 {valid_date_points['date'].max()}") + + # 计算时间差 + time_diffs = valid_date_points['date'].diff() + + # 找到时间差超过阈值的位置 + time_groups = [] + current_group_start = 0 + + for idx, time_diff in enumerate(time_diffs): + if time_diff and time_diff > time_threshold: + # 添加当前组 + current_group = valid_date_points.iloc[current_group_start:idx] + time_groups.append(current_group) + + # 记录断点信息 + break_time = valid_date_points.iloc[idx]['date'] + group_start_time = current_group.iloc[0]['date'] + group_end_time = current_group.iloc[-1]['date'] + + self.logger.info( + f"时间组 {len(time_groups)}: {len(current_group)} 个点, " + f"时间范围 [{group_start_time} - {group_end_time}]" + ) + self.logger.info( + f"在时间 {break_time} 处发现断点,时间差为 {time_diff}") + + current_group_start = idx + + # 添加最后一组 + last_group = valid_date_points.iloc[current_group_start:] + if not last_group.empty: + time_groups.append(last_group) + self.logger.info( + f"时间组 {len(time_groups)}: {len(last_group)} 个点, " + f"时间范围 [{last_group.iloc[0]['date']} - {last_group.iloc[-1]['date']}]" + ) + + # 如果有空时间戳的点,将其作为最后一组 + if not null_date_group.empty: + time_groups.append(null_date_group) + self.logger.info(f"添加无时间戳组: {len(null_date_group)} 个点") + + self.logger.info(f"共分为 {len(time_groups)} 个时间组") + return time_groups + + def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13, time_threshold=timedelta(minutes=5)): + """ + 过滤密集点,先按时间分组,再在每个时间组内过滤。 + 空时间戳的点不进行过滤。 + + Args: + points_df: 点数据 + grid_size: 网格大小 + distance_threshold: 距离阈值(米) + time_interval: 时间间隔(秒) + """ + self.logger.info(f"开始按时间分组过滤密集点 (网格大小: {grid_size}, " + f"距离阈值: {distance_threshold}米, 分组时间间隔: {time_threshold}秒)") + + # 按时间分组 + time_groups = self._group_by_time(points_df, time_threshold) + + # 存储所有要删除的图片 + all_to_del_imgs = [] + + # 对每个时间组进行密集点过滤 + for group_idx, group_points in enumerate(time_groups): + # 检查是否为空时间戳组(最后一组) + if group_idx == len(time_groups) - 1 and group_points['date'].isna().any(): + self.logger.info(f"跳过无时间戳组 (包含 {len(group_points)} 个点)") + continue + + self.logger.info( + f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)") + + # 计算该组内的点间距离 + sorted_distances = self._get_distances(group_points, grid_size) + group_to_del_imgs = [] + + # 在每个网格中过滤密集点 + for grid, distances in sorted_distances.items(): + grid_del_count = 0 + while distances: + candidate_img1, candidate_img2, dist = distances[0] + if dist < distance_threshold: + distances.pop(0) + + # 获取候选图片的其他最短距离 + candidate_img1_dist = None + candidate_img2_dist = None + for distance in distances: + if candidate_img1 in distance: + candidate_img1_dist = distance[2] + break + for distance in distances: + if candidate_img2 in distance: + candidate_img2_dist = distance[2] + break + + # 选择要删除的点 + if candidate_img1_dist and candidate_img2_dist: + to_del_img = candidate_img1 if candidate_img1_dist < candidate_img2_dist else candidate_img2 + group_to_del_imgs.append(to_del_img) + grid_del_count += 1 + self.logger.debug( + f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)") + distances = [ + d for d in distances if to_del_img not in d] + else: + break + + if grid_del_count > 0: + self.logger.info( + f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点") + + all_to_del_imgs.extend(group_to_del_imgs) + self.logger.info( + f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点") + + + # 过滤数据 + filtered_df = points_df[~points_df['file'].isin(all_to_del_imgs)] + self.logger.info( + f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点") + + return filtered_df + + def filter_isolated_points(self, points_df, threshold_distance=0.001, min_neighbors=6): + """过滤孤立点""" + self.logger.info( + f"开始过滤孤立点 (距离阈值: {threshold_distance}, 最小邻居数: {min_neighbors})") + + coords = points_df[['lat', 'lon']].values + kdtree = KDTree(coords) + neighbors_count = [len(kdtree.query_ball_point( + coord, threshold_distance)) for coord in coords] + + isolated_points = [] + for i, (_, row) in enumerate(points_df.iterrows()): + if neighbors_count[i] < min_neighbors: + isolated_points.append(row['file']) + self.logger.debug( + f"删除孤立点: {row['file']} (邻居数: {neighbors_count[i]})") + + filtered_df = points_df[~points_df['file'].isin(isolated_points)] + self.logger.info( + f"孤立点过滤完成,共删除 {len(isolated_points)} 个点,剩余 {len(filtered_df)} 个点") + return filtered_df diff --git a/filter/time_group_overlap_filter.py b/filter/time_group_overlap_filter.py new file mode 100644 index 0000000..15ece6f --- /dev/null +++ b/filter/time_group_overlap_filter.py @@ -0,0 +1,212 @@ +import shutil +import pandas as pd +from shapely.geometry import box +from utils.logger import setup_logger +from utils.gps_extractor import GPSExtractor +import numpy as np +import logging +from datetime import timedelta +import matplotlib.pyplot as plt +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class TimeGroupOverlapFilter: + """基于时间组重叠度的图像过滤器""" + + def __init__(self, image_dir: str, output_dir: str, overlap_threshold: float = 0.7): + """ + 初始化过滤器 + + Args: + image_dir: 图像目录 + output_dir: 输出目录 + overlap_threshold: 重叠阈值,默认0.7 + """ + self.image_dir = image_dir + self.output_dir = output_dir + self.overlap_threshold = overlap_threshold + self.logger = logging.getLogger('UAV_Preprocess.TimeGroupFilter') + + def _group_by_time(self, points_df, time_threshold=timedelta(minutes=5)): + """按时间间隔对点进行分组""" + if 'date' not in points_df.columns: + self.logger.error("数据中缺少date列") + return [] + + # 将date为空的行单独作为一组 + null_date_group = points_df[points_df['date'].isna()] + valid_date_points = points_df[points_df['date'].notna()] + + if not null_date_group.empty: + self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组") + + if valid_date_points.empty: + self.logger.warning("没有有效的时间戳数据") + return [null_date_group] if not null_date_group.empty else [] + + # 按时间排序 + valid_date_points = valid_date_points.sort_values('date') + + # 计算时间差 + time_diffs = valid_date_points['date'].diff() + + # 找到时间差超过阈值的位置 + time_groups = [] + current_group_start = 0 + + for idx, time_diff in enumerate(time_diffs): + if time_diff and time_diff > time_threshold: + # 添加当前组 + current_group = valid_date_points.iloc[current_group_start:idx] + time_groups.append(current_group) + current_group_start = idx + + # 添加最后一组 + last_group = valid_date_points.iloc[current_group_start:] + if not last_group.empty: + time_groups.append(last_group) + + # 如果有空时间戳的点,将其作为最后一组 + if not null_date_group.empty: + time_groups.append(null_date_group) + + return time_groups + + def _get_group_bbox(self, group_df): + """获取组内点的边界框""" + min_lon = group_df['lon'].min() + max_lon = group_df['lon'].max() + min_lat = group_df['lat'].min() + max_lat = group_df['lat'].max() + return box(min_lon, min_lat, max_lon, max_lat) + + def _calculate_overlap(self, box1, box2): + """计算两个边界框的重叠率""" + if box1.intersects(box2): + intersection_area = box1.intersection(box2).area + smaller_area = min(box1.area, box2.area) + return intersection_area / smaller_area + return 0 + + def filter_overlapping_groups(self, time_threshold=timedelta(minutes=5)): + """过滤重叠的时间组""" + # 提取GPS数据 + extractor = GPSExtractor(self.image_dir) + gps_points = extractor.extract_all_gps() + + # 按时间分组 + time_groups = self._group_by_time(gps_points, time_threshold) + + # 计算每个组的边界框 + group_boxes = [] + for idx, group in enumerate(time_groups): + if not group['date'].isna().any(): # 只处理有时间戳的组 + bbox = self._get_group_bbox(group) + group_boxes.append((idx, group, bbox)) + + # 找出需要删除的组 + groups_to_delete = set() + for i in range(len(group_boxes)): + if i in groups_to_delete: + continue + + idx1, group1, box1 = group_boxes[i] + area1 = box1.area + + for j in range(i + 1, len(group_boxes)): + if j in groups_to_delete: + continue + + idx2, group2, box2 = group_boxes[j] + area2 = box2.area + + overlap_ratio = self._calculate_overlap(box1, box2) + + if overlap_ratio > self.overlap_threshold: + # 删除面积较小的组 + if area1 < area2: + group_to_delete = idx1 + smaller_area = area1 + larger_area = area2 + else: + group_to_delete = idx2 + smaller_area = area2 + larger_area = area1 + + groups_to_delete.add(group_to_delete) + self.logger.info( + f"时间组 {group_to_delete + 1} 与时间组 " + f"{idx2 + 1 if group_to_delete == idx1 else idx1 + 1} " + f"重叠率为 {overlap_ratio:.2f}," + f"面积比为 {smaller_area/larger_area:.2f}," + f"将删除较小面积的组 {group_to_delete + 1}" + ) + + # 删除重复组的图像 + deleted_files = [] + for group_idx in groups_to_delete: + group_files = time_groups[group_idx]['file'].tolist() + deleted_files.extend(group_files) + + self.logger.info(f"共删除 {len(groups_to_delete)} 个重复时间组," + f"{len(deleted_files)} 张图像") + + # 可视化结果 + self._visualize_results(time_groups, groups_to_delete) + + return deleted_files + + def _visualize_results(self, time_groups, groups_to_delete): + """可视化过滤结果""" + plt.figure(figsize=(15, 10)) + + # 生成不同的颜色 + colors = plt.cm.rainbow(np.linspace(0, 1, len(time_groups))) + + # 绘制所有组的边界框 + for idx, (group, color) in enumerate(zip(time_groups, colors)): + if not group['date'].isna().any(): # 只处理有时间戳的组 + bbox = self._get_group_bbox(group) + x, y = bbox.exterior.xy + + if idx in groups_to_delete: + # 被删除的组用虚线表示 + plt.plot(x, y, '--', color=color, alpha=0.6, + label=f'Deleted Group {idx + 1}') + else: + # 保留的组用实线表示 + plt.plot(x, y, '-', color=color, alpha=0.6, + label=f'Group {idx + 1}') + + # 绘制该组的GPS点 + plt.scatter(group['lon'], group['lat'], color=color, + s=30, alpha=0.6) + + plt.title("Time Groups and Their Bounding Boxes", fontsize=14) + plt.xlabel("Longitude", fontsize=12) + plt.ylabel("Latitude", fontsize=12) + plt.grid(True) + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10) + plt.tight_layout() + + # 保存图片 + plt.savefig(os.path.join(self.output_dir, 'filter_imgs', 'time_groups_overlap_bbox.png'), + dpi=300, bbox_inches='tight') + plt.close() + + +if __name__ == '__main__': + # 设置路径 + DATASET = r'F:\error_data\20241108134711\3D' + output_dir = r'E:\studio2\ODM_pro\test' + os.makedirs(output_dir, exist_ok=True) + + # 设置日志 + setup_logger(os.path.dirname(output_dir)) + + # 创建过滤器并执行过滤 + filter = TimeGroupOverlapFilter(DATASET, output_dir, overlap_threshold=0.7) + deleted_files = filter.filter_overlapping_groups( + time_threshold=timedelta(minutes=5)) diff --git a/odm_preprocess.py b/odm_preprocess.py new file mode 100644 index 0000000..c62d076 --- /dev/null +++ b/odm_preprocess.py @@ -0,0 +1,334 @@ +import os +import shutil +from datetime import timedelta +from dataclasses import dataclass +from typing import Dict + +import matplotlib.pyplot as plt +import pandas as pd +from tqdm import tqdm + +from filter.cluster_filter import GPSCluster +from filter.time_group_overlap_filter import TimeGroupOverlapFilter +from filter.gps_filter import GPSFilter +from utils.command_runner import CommandRunner +from utils.gps_extractor import GPSExtractor +from utils.grid_divider import GridDivider +from utils.logger import setup_logger +from utils.visualizer import FilterVisualizer +from post_pro.merge_tif import MergeTif + + +@dataclass +class PreprocessConfig: + """预处理配置类""" + + image_dir: str + output_dir: str + # 聚类过滤参数 + cluster_eps: float = 0.01 + cluster_min_samples: int = 5 + # 时间组重叠过滤参数 + time_group_overlap_threshold: float = 0.7 + time_group_interval: timedelta = timedelta(minutes=5) + # 孤立点过滤参数 + filter_distance_threshold: float = 0.001 # 经纬度距离 + filter_min_neighbors: int = 6 + # 密集点过滤参数 + filter_grid_size: float = 0.001 + filter_dense_distance_threshold: float = 10 # 普通距离,单位:米 + filter_time_threshold: timedelta = timedelta(minutes=5) + # 网格划分参数 + grid_overlap: float = 0.05 + grid_size: float = 500 + # 几个pipline过程是否开启 + mode: str = "快拼模式" + + +class ImagePreprocessor: + def __init__(self, config: PreprocessConfig): + self.config = config + + # 清理并重建输出目录 + if os.path.exists(config.output_dir): + self._clean_output_dir() + self._setup_output_dirs() + + # 初始化其他组件 + self.logger = setup_logger(config.output_dir) + self.gps_points = None + self.command_runner = CommandRunner( + config.output_dir, mode=config.mode) + self.visualizer = FilterVisualizer(config.output_dir) + + def _clean_output_dir(self): + """清理输出目录""" + try: + shutil.rmtree(self.config.output_dir) + print(f"已清理输出目录: {self.config.output_dir}") + except Exception as e: + print(f"清理输出目录时发生错误: {str(e)}") + raise + + def _setup_output_dirs(self): + """创建必要的输出目录结构""" + try: + # 创建主输出目录 + os.makedirs(self.config.output_dir) + + # 创建过滤图像保存目录 + os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs')) + + # 创建日志目录 + os.makedirs(os.path.join(self.config.output_dir, 'logs')) + + print(f"已创建输出目录结构: {self.config.output_dir}") + except Exception as e: + print(f"创建输出目录时发生错误: {str(e)}") + raise + + def extract_gps(self) -> pd.DataFrame: + """提取GPS数据""" + self.logger.info("开始提取GPS数据") + extractor = GPSExtractor(self.config.image_dir) + self.gps_points = extractor.extract_all_gps() + self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点") + return self.gps_points + + def cluster(self) -> pd.DataFrame: + """使用DBSCAN对GPS点进行聚类,只保留最大的类""" + self.logger.info("开始聚类") + previous_points = self.gps_points.copy() + + # 创建聚类器并执行聚类 + clusterer = GPSCluster( + self.gps_points, output_dir=self.config.output_dir, + eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples) + + # 获取主要类别的点 + self.clustered_points = clusterer.fit() + self.gps_points = clusterer.get_main_cluster(self.clustered_points) + + # 获取统计信息并记录 + stats = clusterer.get_cluster_stats(self.clustered_points) + self.logger.info( + f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点," + f"噪声点 {stats['noise_points']} 个" + ) + + # 可视化聚类结果 + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "1-Clustering") + + return self.gps_points + + def filter_time_group_overlap(self) -> pd.DataFrame: + """过滤重叠的时间组""" + self.logger.info("开始过滤重叠时间组") + + self.logger.info("开始过滤重叠时间组") + previous_points = self.gps_points.copy() + + filter = TimeGroupOverlapFilter( + self.config.image_dir, + self.config.output_dir, + overlap_threshold=self.config.time_group_overlap_threshold + ) + + deleted_files = filter.filter_overlapping_groups( + time_threshold=self.config.time_group_interval + ) + + # 更新GPS点数据,移除被删除的图像 + self.gps_points = self.gps_points[~self.gps_points['file'].isin( + deleted_files)] + self.logger.info(f"重叠时间组过滤后剩余 {len(self.gps_points)} 个GPS点") + + # 可视化过滤结果 + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "2-Time Group Overlap") + + return self.gps_points + + # TODO 过滤算法还需要更新 + def filter_points(self) -> pd.DataFrame: + """过滤GPS点""" + + self.logger.info("开始过滤GPS点") + filter = GPSFilter(self.config.output_dir) + + # 过滤孤立点 + previous_points = self.gps_points.copy() + self.logger.info( + f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, " + f"最小邻居数: {self.config.filter_min_neighbors})" + ) + self.gps_points = filter.filter_isolated_points( + self.gps_points, + self.config.filter_distance_threshold, + self.config.filter_min_neighbors, + ) + self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点") + + # 可视化孤立点过滤结果 + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "3-Isolated Points") + + # 过滤密集点 + previous_points = self.gps_points.copy() + self.logger.info( + f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, " + f"距离阈值: {self.config.filter_dense_distance_threshold})" + ) + self.gps_points = filter.filter_dense_points( + self.gps_points, + grid_size=self.config.filter_grid_size, + distance_threshold=self.config.filter_dense_distance_threshold, + time_threshold=self.config.filter_time_threshold, + ) + self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点") + + # 可视化密集点过滤结果 + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "4-Dense Points") + + return self.gps_points + + def divide_grids(self) -> Dict[int, pd.DataFrame]: + """划分网格""" + self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})") + self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})") + grid_divider = GridDivider(overlap=self.config.grid_overlap) + grids = grid_divider.divide_grids( + self.gps_points, grid_size=self.config.grid_size + ) + grid_points = grid_divider.assign_to_grids(self.gps_points, grids) + self.logger.info(f"成功划分为 {len(grid_points)} 个网格") + return grid_points + + def copy_images(self, grid_points: Dict[int, pd.DataFrame]): + """复制图像到目标文件夹""" + self.logger.info("开始复制图像文件") + self.logger.info("开始复制图像文件") + + for grid_idx, points in grid_points.items(): + output_dir = os.path.join( + self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images" + ) + + os.makedirs(output_dir, exist_ok=True) + + for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"): + src = os.path.join(self.config.image_dir, point["file"]) + dst = os.path.join(output_dir, point["file"]) + shutil.copy(src, dst) + self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像") + + def merge_tif(self, grid_points: Dict[int, pd.DataFrame]): + """合并所有网格的TIF影像""" + self.logger.info("开始合并TIF影像") + + # 检查是否有多个网格需要合并 + if len(grid_points) < 2: + self.logger.info("只有一个网格,无需合并TIF影像") + return + + input_tif1, input_tif2 = None, None + merge_count = 0 + + try: + for grid_idx, points in grid_points.items(): + grid_tif = os.path.join( + self.config.output_dir, + f"grid_{grid_idx + 1}", + "project", + "odm_orthophoto", + "odm_orthophoto.tif" + ) + + # 检查TIF文件是否存在 + if not os.path.exists(grid_tif): + self.logger.error( + f"网格 {grid_idx + 1} 的TIF文件不存在: {grid_tif}") + continue + + if input_tif1 is None: + input_tif1 = grid_tif + self.logger.info(f"设置第一个输入TIF: {input_tif1}") + else: + input_tif2 = grid_tif + output_tif = os.path.join( + self.config.output_dir, "merged_orthophoto.tif") + + self.logger.info( + f"开始合并第 {merge_count + 1} 次:\n" + f"输入1: {input_tif1}\n" + f"输入2: {input_tif2}\n" + f"输出: {output_tif}" + ) + + merge_tif = MergeTif(input_tif1, input_tif2, output_tif) + merge_tif.merge() + merge_count += 1 + + input_tif1 = output_tif + input_tif2 = None + + self.logger.info( + f"TIF影像合并完成,共执行 {merge_count} 次合并," + f"最终输出文件: {input_tif1}" + ) + + except Exception as e: + self.logger.error(f"TIF影像合并过程中发生错误: {str(e)}", exc_info=True) + raise + + def process(self): + """执行完整的预处理流程""" + try: + self.extract_gps() + self.cluster() + self.filter_time_group_overlap() + self.filter_points() + grid_points = self.divide_grids() + self.copy_images(grid_points) + self.logger.info("预处理任务完成") + self.command_runner.run_grid_commands( + grid_points, + ) + # 添加TIF合并步骤 + self.merge_tif(grid_points) + except Exception as e: + self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) + raise + + +if __name__ == "__main__": + # 创建配置 + config = PreprocessConfig( + image_dir=r"F:\error_data\20240930091614\project\images", + output_dir=r"G:\20240930091614\output", + + cluster_eps=0.01, + cluster_min_samples=5, + + # 添加时间组重叠过滤参数 + time_group_overlap_threshold=0.7, + time_group_interval=timedelta(minutes=5), + + filter_distance_threshold=0.001, + filter_min_neighbors=6, + + filter_grid_size=0.001, + filter_dense_distance_threshold=10, + filter_time_threshold=timedelta(minutes=5), + + grid_size=1000, + + + mode="快拼模式", + ) + + # 创建处理器并执行 + processor = ImagePreprocessor(config) + processor.process() diff --git a/post_pro/merge_tif.py b/post_pro/merge_tif.py new file mode 100644 index 0000000..f9a1ce6 --- /dev/null +++ b/post_pro/merge_tif.py @@ -0,0 +1,99 @@ +from osgeo import gdal +import logging +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class MergeTif: + def __init__(self, input_tif1, input_tif2, output_tif): + self.input_tif1 = input_tif1 + self.input_tif2 = input_tif2 + self.output_tif = output_tif + self.logger = logging.getLogger('UAV_Preprocess.MergeTif') + + def merge(self): + """合并两张TIF影像""" + try: + self.logger.info("开始合并TIF影像") + self.logger.info(f"输入影像1: {self.input_tif1}") + self.logger.info(f"输入影像2: {self.input_tif2}") + self.logger.info(f"输出影像: {self.output_tif}") + + # 检查输入文件是否存在 + if not os.path.exists(self.input_tif1) or not os.path.exists(self.input_tif2): + error_msg = "输入影像文件不存在" + self.logger.error(error_msg) + raise FileNotFoundError(error_msg) + + # 打开影像,检查投影是否一致 + datasets = [gdal.Open(tif) + for tif in [self.input_tif1, self.input_tif2]] + if None in datasets: + error_msg = "无法打开输入影像文件" + self.logger.error(error_msg) + raise ValueError(error_msg) + + projections = [dataset.GetProjection() for dataset in datasets] + self.logger.debug(f"影像1投影: {projections[0]}") + self.logger.debug(f"影像2投影: {projections[1]}") + + # 检查投影是否一致 + if len(set(projections)) != 1: + error_msg = "影像的投影不一致,请先进行重投影!" + self.logger.error(error_msg) + raise ValueError(error_msg) + + # 创建 GDAL Warp 选项 + warp_options = gdal.WarpOptions( + format="GTiff", + resampleAlg="average", # 设置重采样方法为平均值 + srcNodata=0, # 输入影像中的无效值 + dstNodata=0, # 输出影像中的无效值 + multithread=True # 启用多线程优化 + ) + + self.logger.info("开始执行影像拼接...") + + # 使用 GDAL 的 Warp 方法进行拼接 + result = gdal.Warp( + self.output_tif, + [self.input_tif1, self.input_tif2], # 输入多张影像 + options=warp_options + ) + + if result is None: + error_msg = "影像拼接失败" + self.logger.error(error_msg) + raise RuntimeError(error_msg) + + # 获取输出影像的基本信息 + output_dataset = gdal.Open(self.output_tif) + if output_dataset: + width = output_dataset.RasterXSize + height = output_dataset.RasterYSize + bands = output_dataset.RasterCount + self.logger.info(f"拼接完成,输出影像大小: {width}x{height},波段数: {bands}") + + self.logger.info(f"影像拼接成功,输出文件保存至: {self.output_tif}") + + except Exception as e: + self.logger.error(f"影像拼接过程中发生错误: {str(e)}", exc_info=True) + raise + + +if __name__ == "__main__": + from utils.logger import setup_logger + + # 定义影像路径 + input_tif1 = r"G:\20240930091614\output\grid_1\project\odm_orthophoto\odm_orthophoto.tif" + input_tif2 = r"G:\20240930091614\output\grid_2\project\odm_orthophoto\odm_orthophoto.tif" + output_tif = r"G:\20240930091614\output\merged_orthophoto.tif" + + # 设置日志 + output_dir = r"E:\studio2\ODM_pro\test" + setup_logger(output_dir) + + # 执行拼接 + merge_tif = MergeTif(input_tif1, input_tif2, output_tif) + merge_tif.merge() diff --git a/tools/odm_pip_time.py b/tools/odm_pip_time.py new file mode 100644 index 0000000..2dba550 --- /dev/null +++ b/tools/odm_pip_time.py @@ -0,0 +1,55 @@ +from datetime import datetime +import json + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="ODM log time") + + parser.add_argument( + "--path", default=r"E:\datasets\UAV\134\project\log.json") + args = parser.parse_args() + return args + + +def main(args): + # 读取 JSON 文件 + with open(args.path, 'r') as file: + data = json.load(file) + + # 提取 "stages" 中每个步骤的开始时间和持续时间 + stage_timings = [] + for i, stage in enumerate(data.get("stages", [])): + stage_name = stage.get("name", "Unnamed Stage") + start_time = stage.get("startTime") + + # 获取当前阶段的开始时间 + if start_time: + start_dt = datetime.fromisoformat(start_time) + + # 获取阶段的结束时间:可以是下一个阶段的开始时间,或当前阶段的 `endTime`(如果存在) + if i + 1 < len(data["stages"]): + end_time = data["stages"][i + 1].get("startTime") + else: + end_time = stage.get("endTime") or data.get("endTime") + + if end_time: + end_dt = datetime.fromisoformat(end_time) + duration = (end_dt - start_dt).total_seconds() + stage_timings.append((stage_name, duration)) + + # 输出每个阶段的持续时间,调整为对齐格式 + total_time = 0 + print(f"{'Stage Name':<25} {'Duration (seconds)':>15}") + print("=" * 45) + for stage_name, duration in stage_timings: + print(f"{stage_name:<25} {duration:>15.2f}") + total_time += duration + + print('Total Time:', total_time) + + +if __name__ == '__main__': + args = parse_args() + + main(args) diff --git a/tools/show_GPS.py b/tools/show_GPS.py new file mode 100644 index 0000000..f4690a6 --- /dev/null +++ b/tools/show_GPS.py @@ -0,0 +1,51 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import matplotlib.pyplot as plt +from preprocess.gps_extractor import GPSExtractor + +DATASET = r'F:\error_data\20240930091614\project\images' + +if __name__ == '__main__': + extractor = GPSExtractor(DATASET) + gps_points = extractor.extract_all_gps() + + # 创建两个子图 + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8)) + + # 左图:原始散点图 + ax1.scatter(gps_points['lon'], gps_points['lat'], + color='blue', marker='o', label='GPS Points') + ax1.set_title("GPS Coordinates of Images", fontsize=14) + ax1.set_xlabel("Longitude", fontsize=12) + ax1.set_ylabel("Latitude", fontsize=12) + ax1.grid(True) + ax1.legend() + + # # 右图:按时间排序的轨迹图 + # gps_points_sorted = gps_points.sort_values('date') + + # # 绘制飞行轨迹线 + # ax2.plot(gps_points_sorted['lon'][300:600], gps_points_sorted['lat'][300:600], + # color='blue', linestyle='-', linewidth=1, alpha=0.6) + + # # 绘制GPS点 + # ax2.scatter(gps_points_sorted['lon'][300:600], gps_points_sorted['lat'][300:600], + # color='red', marker='o', s=30, label='GPS Points') + + # 标记起点和终点 + # ax2.scatter(gps_points_sorted['lon'].iloc[0], gps_points_sorted['lat'].iloc[0], + # color='green', marker='^', s=100, label='Start') + # ax2.scatter(gps_points_sorted['lon'].iloc[-1], gps_points_sorted['lat'].iloc[-1], + # color='purple', marker='s', s=100, label='End') + + ax2.set_title("UAV Flight Trajectory", fontsize=14) + ax2.set_xlabel("Longitude", fontsize=12) + ax2.set_ylabel("Latitude", fontsize=12) + ax2.grid(True) + ax2.legend() + + # 调整子图之间的间距 + plt.tight_layout() + plt.show() diff --git a/tools/show_GPS_by_time.py b/tools/show_GPS_by_time.py new file mode 100644 index 0000000..f771275 --- /dev/null +++ b/tools/show_GPS_by_time.py @@ -0,0 +1,138 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import matplotlib.pyplot as plt +from datetime import timedelta +import logging +import numpy as np +from preprocess.gps_extractor import GPSExtractor +from preprocess.logger import setup_logger + +class GPSTimeVisualizer: + """按时间组可视化GPS点""" + + def __init__(self, image_dir: str, output_dir: str): + self.image_dir = image_dir + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.GPSVisualizer') + + def _group_by_time(self, points_df, time_threshold=timedelta(minutes=5)): + """按时间间隔对点进行分组""" + if 'date' not in points_df.columns: + self.logger.error("数据中缺少date列") + return [points_df] + + # 将date为空的行单独作为一组 + null_date_group = points_df[points_df['date'].isna()] + valid_date_points = points_df[points_df['date'].notna()] + + if not null_date_group.empty: + self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组") + + if valid_date_points.empty: + self.logger.warning("没有有效的时间戳数据") + return [null_date_group] if not null_date_group.empty else [] + + # 按时间排序 + valid_date_points = valid_date_points.sort_values('date') + + # 计算时间差 + time_diffs = valid_date_points['date'].diff() + + # 找到时间差超过阈值的位置 + time_groups = [] + current_group_start = 0 + + for idx, time_diff in enumerate(time_diffs): + if time_diff and time_diff > time_threshold: + # 添加当前组 + current_group = valid_date_points.iloc[current_group_start:idx] + time_groups.append(current_group) + current_group_start = idx + + # 添加最后一组 + last_group = valid_date_points.iloc[current_group_start:] + if not last_group.empty: + time_groups.append(last_group) + + # 如果有空时间戳的点,将其作为最后一组 + if not null_date_group.empty: + time_groups.append(null_date_group) + + return time_groups + + def visualize_time_groups(self, time_threshold=timedelta(minutes=5)): + """在同一张图上显示所有时间组,用不同颜色区分""" + # 提取GPS数据 + extractor = GPSExtractor(self.image_dir) + gps_points = extractor.extract_all_gps() + + # 按时间分组 + time_groups = self._group_by_time(gps_points, time_threshold) + + # 创建图形 + plt.figure(figsize=(15, 10)) + + # 生成不同的颜色 + colors = plt.cm.rainbow(np.linspace(0, 1, len(time_groups))) + + # 为每个时间组绘制点和轨迹 + for idx, (group, color) in enumerate(zip(time_groups, colors)): + if not group['date'].isna().any(): + # 有时间戳的组 + sorted_group = group.sort_values('date') + + # 绘制轨迹线 + plt.plot(sorted_group['lon'], sorted_group['lat'], + color=color, linestyle='-', linewidth=1.5, alpha=0.6, + label=f'Flight Path {idx + 1}') + + # 绘制GPS点 + plt.scatter(sorted_group['lon'], sorted_group['lat'], + color=color, marker='o', s=30, alpha=0.6) + + # 标记起点和终点 + plt.scatter(sorted_group['lon'].iloc[0], sorted_group['lat'].iloc[0], + color=color, marker='^', s=100, + label=f'Start {idx + 1} ({sorted_group["date"].min().strftime("%H:%M:%S")})') + plt.scatter(sorted_group['lon'].iloc[-1], sorted_group['lat'].iloc[-1], + color=color, marker='s', s=100, + label=f'End {idx + 1} ({sorted_group["date"].max().strftime("%H:%M:%S")})') + else: + # 无时间戳的组 + plt.scatter(group['lon'], group['lat'], + color=color, marker='x', s=50, alpha=0.6, + label='No Timestamp Points') + + plt.title("GPS Points by Time Groups", fontsize=14) + plt.xlabel("Longitude", fontsize=12) + plt.ylabel("Latitude", fontsize=12) + plt.grid(True) + + # 调整图例位置和大小 + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10) + + # 调整布局以适应图例 + plt.tight_layout() + + # 保存图片 + plt.savefig(os.path.join(self.output_dir, 'gps_time_groups_combined.png'), + dpi=300, bbox_inches='tight') + plt.close() + + self.logger.info(f"已生成包含 {len(time_groups)} 个时间组的组合可视化图形") + + +if __name__ == '__main__': + # 设置数据集路径 + DATASET = r'F:\error_data\20241108134711\3D' + output_dir = r'E:\studio2\ODM_pro\test' + os.makedirs(output_dir, exist_ok=True) + + # 设置日志 + setup_logger(os.path.dirname(output_dir)) + + # 创建可视化器并生成图形 + visualizer = GPSTimeVisualizer(DATASET, output_dir) + visualizer.visualize_time_groups(time_threshold=timedelta(minutes=5)) \ No newline at end of file diff --git a/utils/command_runner.py b/utils/command_runner.py new file mode 100644 index 0000000..26c4f44 --- /dev/null +++ b/utils/command_runner.py @@ -0,0 +1,71 @@ +import os +import logging +import subprocess +import time +from typing import Dict +import pandas as pd +from utils.odm_monitor import ODMProcessMonitor + + +class CommandRunner: + """执行网格处理命令的类""" + + def __init__(self, output_dir: str, max_retries: int = 3, mode: str = "快拼模式"): + """ + 初始化命令执行器 +i + Args: + output_dir: 输出目录路径 + max_retries: 最大重试次数 + """ + self.output_dir = output_dir + self.max_retries = max_retries + self.logger = logging.getLogger('UAV_Preprocess.CommandRunner') + self.monitor = ODMProcessMonitor(max_retries=max_retries, mode=mode) + self.mode = mode + + def _run_command(self, grid_idx: int): + """ + 执行单个网格的命令 + + Args: + grid_idx: 网格索引 + + Raises: + Exception: 当命令执行失败时抛出异常 + """ + try: + grid_dir = os.path.join(self.output_dir, f'grid_{grid_idx + 1}') + grid_dir = grid_dir[0].lower() + grid_dir[1:].replace('\\', '/') + if self.mode == "快拼模式": + command = f"docker run -ti --rm -v {grid_dir}:/datasets opendronemap/odm --project-path /datasets project --feature-quality lowest --force-gps --fast-orthophoto --skip-3dmodel" + else: + command = f"docker run -ti --rm -v {grid_dir}:/datasets opendronemap/odm --project-path /datasets project --feature-quality lowest --force-gps" + + self.logger.info(f"开始执行命令: {command}") + success, error_msg = self.monitor.run_odm_with_monitor( + command, grid_dir, grid_idx) + + if not success: + raise Exception(error_msg) + + except Exception as e: + self.logger.error(f"网格 {grid_idx + 1} 处理失败: {str(e)}") + raise + + def run_grid_commands(self, grid_points: Dict[int, pd.DataFrame]): + """ + 为每个网格顺序运行指定命令 + + Args: + grid_points: 网格点数据字典,键为网格索引,值为该网格的点数据 + """ + + self.logger.info("开始执行网格处理命令") + + for grid_idx in grid_points.keys(): + try: + self._run_command(grid_idx) + except Exception as e: + self.logger.error(f"网格 {grid_idx + 1} 处理失败,停止后续执行: {str(e)}") + raise diff --git a/utils/gps_extractor.py b/utils/gps_extractor.py new file mode 100644 index 0000000..9bb2e57 --- /dev/null +++ b/utils/gps_extractor.py @@ -0,0 +1,93 @@ +import os +from PIL import Image +import piexif +import logging +import pandas as pd +from datetime import datetime + + +class GPSExtractor: + """从图像文件提取GPS坐标和拍摄日期""" + + def __init__(self, image_dir): + self.image_dir = image_dir + self.logger = logging.getLogger('UAV_Preprocess.GPSExtractor') + + @staticmethod + def _dms_to_decimal(dms): + """将DMS格式转换为十进制度""" + return dms[0][0] / dms[0][1] + (dms[1][0] / dms[1][1]) / 60 + (dms[2][0] / dms[2][1]) / 3600 + + @staticmethod + def _parse_datetime(datetime_str): + """解析EXIF中的日期时间字符串""" + try: + # EXIF日期格式通常为 'YYYY:MM:DD HH:MM:SS' + return datetime.strptime(datetime_str.decode(), '%Y:%m:%d %H:%M:%S') + except Exception: + return None + + def get_gps_and_date(self, image_path): + """提取单张图片的GPS坐标和拍摄日期""" + try: + image = Image.open(image_path) + exif_data = piexif.load(image.info['exif']) + + # 提取GPS信息 + gps_info = exif_data.get("GPS", {}) + lat = lon = None + if gps_info: + lat = self._dms_to_decimal(gps_info.get(2, [])) + lon = self._dms_to_decimal(gps_info.get(4, [])) + self.logger.debug(f"成功提取图片GPS坐标: {image_path} - 纬度: {lat}, 经度: {lon}") + + # 提取拍摄日期 + date_info = None + if "Exif" in exif_data: + # 优先使用DateTimeOriginal + date_str = exif_data["Exif"].get(36867) # DateTimeOriginal + if not date_str: + # 备选DateTime + date_str = exif_data["Exif"].get(36868) # DateTimeDigitized + if not date_str: + # 最后使用基本DateTime + date_str = exif_data["0th"].get(306) # DateTime + + if date_str: + date_info = self._parse_datetime(date_str) + self.logger.debug(f"成功提取图片拍摄日期: {image_path} - {date_info}") + + if not gps_info: + self.logger.warning(f"图片无GPS信息: {image_path}") + if not date_info: + self.logger.warning(f"图片无拍摄日期信息: {image_path}") + + return lat, lon, date_info + + except Exception as e: + self.logger.error(f"提取图片信息时发生错误: {image_path} - {str(e)}") + return None, None, None + + def extract_all_gps(self): + """提取所有图片的GPS坐标和拍摄日期""" + self.logger.info(f"开始从目录提取GPS坐标和拍摄日期: {self.image_dir}") + gps_data = [] + total_images = 0 + successful_extractions = 0 + + for image_file in os.listdir(self.image_dir): + if image_file.lower().endswith('.jpg'): + total_images += 1 + image_path = os.path.join(self.image_dir, image_file) + lat, lon, date = self.get_gps_and_date(image_path) + if lat and lon: # 仍然以GPS信息作为主要判断依据 + successful_extractions += 1 + gps_data.append({ + 'file': image_file, + 'lat': lat, + 'lon': lon, + 'date': date + }) + + self.logger.info(f"GPS坐标和拍摄日期提取完成 - 总图片数: {total_images}, 成功提取: {successful_extractions}, 失败: {total_images - successful_extractions}") + return pd.DataFrame(gps_data) diff --git a/utils/grid_divider.py b/utils/grid_divider.py new file mode 100644 index 0000000..e51332c --- /dev/null +++ b/utils/grid_divider.py @@ -0,0 +1,87 @@ +import logging +from geopy.distance import geodesic + +class GridDivider: + """划分九宫格,并将图片分配到对应网格""" + + def __init__(self, overlap=0.1): + self.overlap = overlap + self.logger = logging.getLogger('UAV_Preprocess.GridDivider') + self.logger.info(f"初始化网格划分器,重叠率: {overlap}") + + def divide_grids(self, points_df, grid_size=500): + """计算边界框并划分九宫格""" + self.logger.info("开始划分九宫格") + + min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max() + min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max() + + # 计算区域的实际距离(米) + width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters + height = geodesic((min_lat, min_lon), (max_lat, min_lon)).meters + + self.logger.info( + f"区域宽度: {width:.2f}米, 高度: {height:.2f}米" + ) + + # 计算需要划分的网格数量 + num_grids_width = int(width / grid_size) if int(width / grid_size) > 0 else 1 + num_grids_height = int(height / grid_size) if int(height / grid_size) > 0 else 1 + + # 计算每个网格对应的经纬度步长 + lat_step = (max_lat - min_lat) / num_grids_height + lon_step = (max_lon - min_lon) / num_grids_width + + + grids = [] + for i in range(num_grids_height): + for j in range(num_grids_width): + grid_min_lat = min_lat + i * lat_step - self.overlap * lat_step + grid_max_lat = min_lat + (i + 1) * lat_step + self.overlap * lat_step + grid_min_lon = min_lon + j * lon_step - self.overlap * lon_step + grid_max_lon = min_lon + (j + 1) * lon_step + self.overlap * lon_step + grids.append((grid_min_lat, grid_max_lat, grid_min_lon, grid_max_lon)) + + self.logger.debug( + f"网格[{i},{j}]: 纬度[{grid_min_lat:.6f}, {grid_max_lat:.6f}], " + f"经度[{grid_min_lon:.6f}, {grid_max_lon:.6f}]" + ) + + self.logger.info(f"成功划分为 {len(grids)} 个网格 ({num_grids_width}x{num_grids_height})") + return grids + + def assign_to_grids(self, points_df, grids): + """将点分配到对应网格""" + self.logger.info(f"开始将 {len(points_df)} 个点分配到网格中") + + grid_points = {i: [] for i in range(len(grids))} + points_assigned = 0 + multiple_grid_points = 0 + + for _, point in points_df.iterrows(): + point_assigned = False + for i, (min_lat, max_lat, min_lon, max_lon) in enumerate(grids): + if min_lat <= point['lat'] <= max_lat and min_lon <= point['lon'] <= max_lon: + grid_points[i].append(point.to_dict()) + if point_assigned: + multiple_grid_points += 1 + else: + points_assigned += 1 + point_assigned = True + + self.logger.debug( + f"点 {point['file']} (纬度: {point['lat']:.6f}, 经度: {point['lon']:.6f}) " + f"被分配到网格" + ) + + # 记录每个网格的点数 + for grid_idx, points in grid_points.items(): + self.logger.info(f"网格 {grid_idx} 包含 {len(points)} 个点") + + self.logger.info( + f"点分配完成: 总点数 {len(points_df)}, " + f"成功分配 {points_assigned} 个点, " + f"{multiple_grid_points} 个点被分配到多个网格" + ) + + return grid_points diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..9aee1a6 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,35 @@ +import logging +import os +from datetime import datetime + +def setup_logger(output_dir): + # 创建logs目录 + log_dir = os.path.join(output_dir, 'logs') + + # 创建日志文件名(包含时间戳) + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + log_file = os.path.join(log_dir, f'preprocess_{timestamp}.log') + + # 配置日志格式 + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # 配置文件处理器 + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler.setFormatter(formatter) + + # 配置控制台处理器 + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + + # 获取根日志记录器 + logger = logging.getLogger('UAV_Preprocess') + logger.setLevel(logging.INFO) + + # 添加处理器 + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger \ No newline at end of file diff --git a/utils/odm_monitor.py b/utils/odm_monitor.py new file mode 100644 index 0000000..edaad8f --- /dev/null +++ b/utils/odm_monitor.py @@ -0,0 +1,139 @@ +import os +import time +import psutil +import logging +import subprocess +from typing import Optional, Tuple + + +class ODMProcessMonitor: + """ODM进程监控器""" + + def __init__(self, max_retries: int = 3, check_interval: int = 10, mode: str = "快拼模式"): + """ + 初始化监控器 + + Args: + max_retries: 最大重试次数 + check_interval: 检查间隔(秒) + mode: 模式 + """ + self.max_retries = max_retries + self.check_interval = check_interval + self.logger = logging.getLogger('UAV_Preprocess.ODMMonitor') + self.mode = mode + + def _check_docker_container(self, process_name: str = "opendronemap/odm") -> bool: + """检查是否有指定的Docker容器在运行""" + try: + result = subprocess.run( + ["docker", "ps", "--filter", + f"ancestor={process_name}", "--format", "{{.ID}}"], + capture_output=True, + text=True + ) + return bool(result.stdout.strip()) + except Exception as e: + self.logger.error(f"检查Docker容器状态时发生错误: {str(e)}") + return False + + def _check_success(self, grid_dir: str) -> bool: + """检查ODM是否执行成功""" + if self.mode == "快拼模式": + success_markers = ['odm_orthophoto', 'odm_georeferencing'] + else: + success_markers = ['odm_orthophoto', + 'odm_georeferencing', 'odm_texturing'] + return all(os.path.exists(os.path.join(grid_dir, 'project', marker)) for marker in success_markers) + + def run_odm_with_monitor(self, command: str, grid_dir: str, grid_idx: int) -> Tuple[bool, str]: + """ + 运行ODM命令并监控进程 + """ + attempt = 0 + while attempt < self.max_retries: + try: + self.logger.info(f"网格 {grid_idx + 1} 第 {attempt + 1} 次尝试执行ODM") + + # 创建日志文件 + log_file = os.path.join(grid_dir, f'odm_attempt_{attempt + 1}.log') + with open(log_file, 'w', encoding='utf-8') as f: + f.write(f"=== ODM处理日志 ===\n开始时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n\n") + + # 启动ODM进程,实时获取输出 + process = subprocess.Popen( + command, + shell=True, + cwd=grid_dir, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + bufsize=1, # 行缓冲 + universal_newlines=True + ) + + self.logger.info("ODM进程已启动,开始监控Docker容器") + + # 等待进程启动 + time.sleep(10) + + # 实时读取输出并写入日志 + def log_output(pipe, log_file, prefix=""): + with open(log_file, 'a', encoding='utf-8') as f: + for line in pipe: + f.write(f"{prefix}{line}") + f.flush() # 确保立即写入 + + # 创建线程读取输出 + from threading import Thread + stdout_thread = Thread(target=log_output, + args=(process.stdout, log_file)) + stderr_thread = Thread(target=log_output, + args=(process.stderr, log_file, "ERROR: ")) + + stdout_thread.daemon = True + stderr_thread.daemon = True + stdout_thread.start() + stderr_thread.start() + + # 监控Docker容器 + while True: + if not self._check_docker_container(): + # Docker容器已结束 + process.wait() # 等待进程完全结束 + + # 等待输出线程结束 + stdout_thread.join(timeout=5) + stderr_thread.join(timeout=5) + + # 记录结束时间 + with open(log_file, 'a', encoding='utf-8') as f: + f.write(f"\n=== 处理结束 ===\n结束时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n") + + # 检查是否成功完成 + if self._check_success(grid_dir): + self.logger.info(f"网格 {grid_idx + 1} ODM处理成功") + return True, "" + else: + self.logger.warning( + f"网格 {grid_idx + 1} 第 {attempt + 1} 次尝试失败") + break + + time.sleep(self.check_interval) + + # 如果不是最后一次尝试,等待后重试 + if attempt < self.max_retries - 1: + wait_time = (attempt + 1) * 30 + self.logger.info(f"等待 {wait_time} 秒后重试...") + time.sleep(wait_time) + + attempt += 1 + + except Exception as e: + error_msg = f"监控进程发生异常: {str(e)}" + self.logger.error(error_msg) + return False, error_msg + + error_msg = f"网格 {grid_idx + 1} 在 {self.max_retries} 次尝试后仍然失败,需要人工查看" + self.logger.error(error_msg) + return False, error_msg diff --git a/utils/visualizer.py b/utils/visualizer.py new file mode 100644 index 0000000..f74f5a9 --- /dev/null +++ b/utils/visualizer.py @@ -0,0 +1,121 @@ +import os +import matplotlib.pyplot as plt +import pandas as pd +import logging +from typing import Optional + + +class FilterVisualizer: + """过滤结果可视化器""" + + def __init__(self, output_dir: str): + """ + 初始化可视化器 + + Args: + output_dir: 输出目录路径 + """ + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.Visualizer') + + def visualize_filter_step(self, + current_points: pd.DataFrame, + previous_points: pd.DataFrame, + step_name: str, + save_name: Optional[str] = None): + """ + 可视化单个过滤步骤的结果 + + Args: + current_points: 当前步骤后的点 + previous_points: 上一步骤的点 + step_name: 步骤名称 + save_name: 保存文件名,默认为step_name + """ + self.logger.info(f"开始生成{step_name}的可视化结果") + + # 找出被过滤掉的点 + filtered_files = set(previous_points['file']) - set(current_points['file']) + filtered_points = previous_points[previous_points['file'].isin(filtered_files)] + + # 创建图形 + plt.figure(figsize=(20, 16)) + + # 绘制保留的点 + plt.scatter(current_points['lon'], current_points['lat'], + color='blue', label='Retained Points', + alpha=0.6, s=50) + + # 绘制被过滤的点 + if not filtered_points.empty: + plt.scatter(filtered_points['lon'], filtered_points['lat'], + color='red', marker='x', label='Filtered Points', + alpha=0.6, s=100) + + # 设置图形属性 + plt.title(f"GPS Points After {step_name}\n" + f"(Filtered: {len(filtered_points)}, Retained: {len(current_points)})", + fontsize=14) + plt.xlabel("Longitude", fontsize=12) + plt.ylabel("Latitude", fontsize=12) + plt.grid(True) + + # 添加统计信息 + stats_text = ( + f"Original Points: {len(previous_points)}\n" + f"Filtered Points: {len(filtered_points)}\n" + f"Remaining Points: {len(current_points)}\n" + f"Filter Rate: {len(filtered_points)/len(previous_points)*100:.1f}%" + ) + plt.figtext(0.02, 0.02, stats_text, fontsize=10, + bbox=dict(facecolor='white', alpha=0.8)) + + # 添加图例 + plt.legend(loc='upper right', fontsize=10) + + # 调整布局 + plt.tight_layout() + + # 保存图形 + save_name = save_name or step_name.lower().replace(' ', '_') + save_path = os.path.join(self.output_dir, 'filter_imgs', f'filter_{save_name}.png') + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + + self.logger.info( + f"{step_name}过滤可视化结果已保存至 {save_path}\n" + f"过滤掉 {len(filtered_points)} 个点," + f"保留 {len(current_points)} 个点," + f"过滤率 {len(filtered_points)/len(previous_points)*100:.1f}%" + ) + + +if __name__ == '__main__': + # 测试代码 + import numpy as np + from datetime import datetime + + # 创建测试数据 + np.random.seed(42) + n_points = 1000 + + # 生成随机点 + test_data = pd.DataFrame({ + 'lon': np.random.uniform(120, 121, n_points), + 'lat': np.random.uniform(30, 31, n_points), + 'file': [f'img_{i}.jpg' for i in range(n_points)], + 'date': [datetime.now() for _ in range(n_points)] + }) + + # 随机选择点作为过滤后的结果 + filtered_data = test_data.sample(n=800) + + # 测试可视化 + visualizer = FilterVisualizer('test_output') + os.makedirs('test_output', exist_ok=True) + + visualizer.visualize_filter_step( + filtered_data, + test_data, + "Test Filter" + ) \ No newline at end of file