diff --git a/odm_preprocess.py b/odm_preprocess.py index 4d6b1a6..4cd8162 100644 --- a/odm_preprocess.py +++ b/odm_preprocess.py @@ -1,5 +1,6 @@ import os import shutil +from datetime import timedelta from dataclasses import dataclass from typing import Dict @@ -21,19 +22,25 @@ class PreprocessConfig: image_dir: str output_dir: str - eps: float = 0.01 - min_samples: int = 5 - filter_grid_size: float = 0.001 - filter_dense_distance_threshold: float = 10 - filter_distance_threshold: float = 0.001 + # 聚类过滤参数 + cluster_eps: float = 0.01 + cluster_min_samples: int = 5 + # 孤立点过滤参数 + filter_distance_threshold: float = 0.001 # 经纬度距离 filter_min_neighbors: int = 6 + # 密集点过滤参数 + filter_grid_size: float = 0.001 + filter_dense_distance_threshold: float = 10 # 普通距离,单位:米 + filter_time_threshold: timedelta = timedelta(minutes=5) + # 网格划分参数 grid_overlap: float = 0.05 grid_size: float = 500 + # 几个pipline过程是否开启 enable_filter: bool = True enable_grid_division: bool = True enable_visualization: bool = True enable_copy_images: bool = True - + mode: str = "快拼模式" class ImagePreprocessor: def __init__(self, config: PreprocessConfig): @@ -55,11 +62,13 @@ class ImagePreprocessor: self.logger.info("开始聚类") # 创建聚类器并执行聚类 clusterer = GPSCluster( - self.gps_points, output_dir=self.config.output_dir) + self.gps_points, output_dir=self.config.output_dir, + eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples) # 获取主要类别的点 - self.gps_points = clusterer.get_main_cluster() + self.clustered_points = clusterer.fit() + self.gps_points = clusterer.get_main_cluster(self.clustered_points) # 获取统计信息并记录 - stats = clusterer.get_cluster_stats() + stats = clusterer.get_cluster_stats(self.clustered_points) self.logger.info( f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点," f"噪声点 {stats['noise_points']} 个" @@ -90,6 +99,7 @@ class ImagePreprocessor: self.gps_points, grid_size=self.config.filter_grid_size, distance_threshold=self.config.filter_dense_distance_threshold, + time_threshold=self.config.filter_time_threshold, ) self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点") return self.gps_points @@ -188,15 +198,15 @@ class ImagePreprocessor: try: self.extract_gps() self.cluster() - # self.time_filter() - # self.filter_points() + self.filter_points() grid_points = self.divide_grids() self.copy_images(grid_points) self.visualize_results() - # self.logger.info("预处理任务完成") + self.logger.info("预处理任务完成") self.command_runner.run_grid_commands( grid_points, - self.config.enable_grid_division + self.config.enable_grid_division, + self.mode ) except Exception as e: self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) @@ -206,14 +216,22 @@ class ImagePreprocessor: if __name__ == "__main__": # 创建配置 config = PreprocessConfig( - image_dir=r"E:\湖南省第二测绘院\11-06-项目移交文件(王辉给)\无人机二三维节点扩容生产影像\影像数据\199\code\images", + image_dir=r"E:\datasets\UAV\1815\images", output_dir=r"test", - filter_grid_size=0.001, - filter_dense_distance_threshold=10, + + cluster_eps=0.01, + cluster_min_samples=5, + filter_distance_threshold=0.001, filter_min_neighbors=6, + + filter_grid_size=0.001, + filter_dense_distance_threshold=10, + filter_time_threshold=timedelta(minutes=5), + grid_overlap=0.05, grid_size=500, + enable_filter=True, enable_grid_division=True, enable_visualization=True, diff --git a/preprocess/cluster.py b/preprocess/cluster.py index ae32033..85ed6aa 100644 --- a/preprocess/cluster.py +++ b/preprocess/cluster.py @@ -17,7 +17,6 @@ class GPSCluster: self.dbscan = DBSCAN(eps=eps, min_samples=min_samples) self.scaler = StandardScaler() self.gps_points = gps_points - self.clustered_points = self.fit() self.log_file = os.path.join(output_dir, 'del_imgs.txt') def fit(self): @@ -57,7 +56,7 @@ class GPSCluster: return result_df - def get_cluster_stats(self): + def get_cluster_stats(self, clustered_points): """ 获取聚类统计信息 @@ -67,22 +66,22 @@ class GPSCluster: 返回: 聚类统计信息的字典 """ - main_cluster_points = sum(self.clustered_points["cluster"] == 1) + main_cluster_points = sum(clustered_points["cluster"] == 1) stats = { - "total_points": len(self.clustered_points), + "total_points": len(clustered_points), "main_cluster_points": main_cluster_points, - "noise_points": sum(self.clustered_points["cluster"] == -1), + "noise_points": sum(clustered_points["cluster"] == -1), } - noise_cluster = self.get_noise_cluster() + noise_cluster = self.get_noise_cluster(clustered_points) with open(self.log_file, 'a', encoding='utf-8') as f: for i, (_, row) in enumerate(noise_cluster.iterrows()): f.write(row['file']+'\n') f.write('\n') return stats - def get_main_cluster(self): - return self.clustered_points[self.clustered_points["cluster"] == 1] + def get_main_cluster(self, clustered_points): + return clustered_points[clustered_points["cluster"] == 1] - def get_noise_cluster(self): - return self.clustered_points[self.clustered_points["cluster"] == -1] + def get_noise_cluster(self, clustered_points): + return clustered_points[clustered_points["cluster"] == -1] diff --git a/preprocess/command_runner.py b/preprocess/command_runner.py index 75f9f51..8d3616b 100644 --- a/preprocess/command_runner.py +++ b/preprocess/command_runner.py @@ -35,11 +35,15 @@ i """ try: grid_dir = os.path.join(self.output_dir, f'grid_{grid_idx + 1}') - command = f"docker run -ti --rm -v {grid_dir}:/datasets opendronemap/odm --project-path /datasets project --feature-quality lowest --force-gps --use-3dmesh" + if self.mode == "快拼模式": + command = f"docker run -ti --rm -v {grid_dir}:/datasets opendronemap/odm --project-path /datasets project --feature-quality lowest --force-gps --use-3dmesh --fast-orthophoto --skip-3dmodel" + else: + command = f"docker run -ti --rm -v {grid_dir}:/datasets opendronemap/odm --project-path /datasets project --feature-quality lowest --force-gps --use-3dmesh" self.logger.info(f"开始执行命令: {command}") - success, error_msg = self.monitor.run_odm_with_monitor(command, grid_dir, grid_idx) - + success, error_msg = self.monitor.run_odm_with_monitor( + command, grid_dir, grid_idx) + if not success: raise Exception(error_msg) diff --git a/preprocess/gps_filter.py b/preprocess/gps_filter.py index a88977c..706d793 100644 --- a/preprocess/gps_filter.py +++ b/preprocess/gps_filter.py @@ -52,7 +52,7 @@ class GPSFilter: if grid not in grid_map: grid_map[grid] = [] grid_map[grid].append((row['file'], row['lat'], row['lon'])) - + self.logger.info(f"图像点已分配到 {len(grid_map)} 个网格中") # 在每个网格中计算两两距离并排序 @@ -68,11 +68,11 @@ class GPSFilter: return sorted_distances - def _group_by_time(self, points_df: pd.DataFrame, time_threshold: timedelta = timedelta(minutes=5)) -> list: + def _group_by_time(self, points_df: pd.DataFrame, time_threshold: timedelta) -> list: """根据拍摄时间分组图片 - + 如果相邻两张图片的拍摄时间差超过5分钟,则进行切分 - + Args: points_df: 包含图片信息的DataFrame,必须包含'file'和'date'列 time_threshold: 时间间隔阈值,默认5分钟 @@ -97,32 +97,34 @@ class GPSFilter: # 按时间排序 valid_date_points = valid_date_points.sort_values('date') - self.logger.info(f"有效时间范围: {valid_date_points['date'].min()} 到 {valid_date_points['date'].max()}") + self.logger.info( + f"有效时间范围: {valid_date_points['date'].min()} 到 {valid_date_points['date'].max()}") # 计算时间差 time_diffs = valid_date_points['date'].diff() - + # 找到时间差超过阈值的位置 time_groups = [] current_group_start = 0 - + for idx, time_diff in enumerate(time_diffs): if time_diff and time_diff > time_threshold: # 添加当前组 current_group = valid_date_points.iloc[current_group_start:idx] time_groups.append(current_group) - + # 记录断点信息 break_time = valid_date_points.iloc[idx]['date'] group_start_time = current_group.iloc[0]['date'] group_end_time = current_group.iloc[-1]['date'] - + self.logger.info( f"时间组 {len(time_groups)}: {len(current_group)} 个点, " f"时间范围 [{group_start_time} - {group_end_time}]" ) - self.logger.info(f"在时间 {break_time} 处发现断点,时间差为 {time_diff}") - + self.logger.info( + f"在时间 {break_time} 处发现断点,时间差为 {time_diff}") + current_group_start = idx # 添加最后一组 @@ -142,11 +144,11 @@ class GPSFilter: self.logger.info(f"共分为 {len(time_groups)} 个时间组") return time_groups - def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13, time_interval=60): + def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13, time_threshold=timedelta(minutes=5)): """ 过滤密集点,先按时间分组,再在每个时间组内过滤。 空时间戳的点不进行过滤。 - + Args: points_df: 点数据 grid_size: 网格大小 @@ -154,27 +156,28 @@ class GPSFilter: time_interval: 时间间隔(秒) """ self.logger.info(f"开始按时间分组过滤密集点 (网格大小: {grid_size}, " - f"距离阈值: {distance_threshold}米, 时间间隔: {time_interval}秒)") - + f"距离阈值: {distance_threshold}米, 分组时间间隔: {time_threshold}秒)") + # 按时间分组 - time_groups = self._group_by_time(points_df, time_interval) - + time_groups = self._group_by_time(points_df, time_threshold) + # 存储所有要删除的图片 all_to_del_imgs = [] - + # 对每个时间组进行密集点过滤 for group_idx, group_points in enumerate(time_groups): # 检查是否为空时间戳组(最后一组) if group_idx == len(time_groups) - 1 and group_points['date'].isna().any(): self.logger.info(f"跳过无时间戳组 (包含 {len(group_points)} 个点)") continue - - self.logger.info(f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)") - + + self.logger.info( + f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)") + # 计算该组内的点间距离 sorted_distances = self._get_distances(group_points, grid_size) group_to_del_imgs = [] - + # 在每个网格中过滤密集点 for grid, distances in sorted_distances.items(): grid_del_count = 0 @@ -182,7 +185,7 @@ class GPSFilter: candidate_img1, candidate_img2, dist = distances[0] if dist < distance_threshold: distances.pop(0) - + # 获取候选图片的其他最短距离 candidate_img1_dist = None candidate_img2_dist = None @@ -194,52 +197,60 @@ class GPSFilter: if candidate_img2 in distance: candidate_img2_dist = distance[2] break - + # 选择要删除的点 if candidate_img1_dist and candidate_img2_dist: to_del_img = candidate_img1 if candidate_img1_dist < candidate_img2_dist else candidate_img2 group_to_del_imgs.append(to_del_img) grid_del_count += 1 - self.logger.debug(f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)") - distances = [d for d in distances if to_del_img not in d] + self.logger.debug( + f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)") + distances = [ + d for d in distances if to_del_img not in d] else: break - + if grid_del_count > 0: - self.logger.info(f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点") - + self.logger.info( + f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点") + all_to_del_imgs.extend(group_to_del_imgs) - self.logger.info(f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点") + self.logger.info( + f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点") # 写入删除日志 with open(self.log_file, 'a', encoding='utf-8') as f: for img in all_to_del_imgs: f.write(img + '\n') - + # 过滤数据 filtered_df = points_df[~points_df['file'].isin(all_to_del_imgs)] - self.logger.info(f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点") - + self.logger.info( + f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点") + return filtered_df def filter_isolated_points(self, points_df, threshold_distance=0.001, min_neighbors=6): """过滤孤立点""" - self.logger.info(f"开始过滤孤立点 (距离阈值: {threshold_distance}, 最小邻居数: {min_neighbors})") - + self.logger.info( + f"开始过滤孤立点 (距离阈值: {threshold_distance}, 最小邻居数: {min_neighbors})") + coords = points_df[['lat', 'lon']].values kdtree = KDTree(coords) neighbors_count = [len(kdtree.query_ball_point( coord, threshold_distance)) for coord in coords] - + isolated_points = [] with open(self.log_file, 'a', encoding='utf-8') as f: for i, (_, row) in enumerate(points_df.iterrows()): if neighbors_count[i] < min_neighbors: isolated_points.append(row['file']) f.write(row['file']+'\n') - self.logger.debug(f"删除孤立点: {row['file']} (邻居数: {neighbors_count[i]})") + self.logger.debug( + f"删除孤立点: {row['file']} (邻居数: {neighbors_count[i]})") f.write('\n') - + filtered_df = points_df[~points_df['file'].isin(isolated_points)] - self.logger.info(f"孤立点过滤完成,共删除 {len(isolated_points)} 个点,剩余 {len(filtered_df)} 个点") + self.logger.info( + f"孤立点过滤完成,共删除 {len(isolated_points)} 个点,剩余 {len(filtered_df)} 个点") return filtered_df diff --git a/preprocess/odm_monitor.py b/preprocess/odm_monitor.py index 2fd766d..24761d0 100644 --- a/preprocess/odm_monitor.py +++ b/preprocess/odm_monitor.py @@ -8,7 +8,7 @@ from typing import Optional, Tuple class ODMProcessMonitor: """ODM进程监控器""" - def __init__(self, max_retries: int = 3, check_interval: int = 5): + def __init__(self, max_retries: int = 3, check_interval: int = 300): """ 初始化监控器 @@ -31,7 +31,10 @@ class ODMProcessMonitor: def _check_success(self, grid_dir: str) -> bool: """检查ODM是否执行成功""" # ODM成功完成时会生成这些文件夹 - success_markers = ['odm_orthophoto', 'odm_georeferencing', 'odm_texturing'] + if self.mode == "快拼模式": + success_markers = ['odm_orthophoto', 'odm_georeferencing'] + else: + success_markers = ['odm_orthophoto', 'odm_georeferencing', 'odm_texturing'] return all(os.path.exists(os.path.join(grid_dir, marker)) for marker in success_markers) def run_odm_with_monitor(self, command: str, grid_dir: str, grid_idx: int) -> Tuple[bool, str]: