import os
import shutil
from datetime import timedelta
from dataclasses import dataclass
from typing import Dict, Tuple

import psutil
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from filter.cluster_filter import GPSCluster
from filter.time_group_overlap_filter import TimeGroupOverlapFilter
from filter.gps_filter import GPSFilter
from utils.odm_monitor import ODMProcessMonitor
from utils.gps_extractor import GPSExtractor
from utils.grid_divider import GridDivider
from utils.logger import setup_logger
from utils.visualizer import FilterVisualizer
from post_pro.merge_tif import MergeTif
from post_pro.merge_obj import MergeObj
from post_pro.merge_laz import MergePly


@dataclass
class PreprocessConfig:
    """Configuration for the image preprocessing pipeline."""

    image_dir: str
    output_dir: str

    # Clustering filter parameters (passed to GPSCluster as eps / min_samples)
    cluster_eps: float = 0.01
    cluster_min_samples: int = 5

    # Time-group overlap filter parameters
    time_group_overlap_threshold: float = 0.7
    time_group_interval: timedelta = timedelta(minutes=5)

    # Isolated-point filter parameters
    filter_distance_threshold: float = 0.001  # distance in lat/lon degrees
    filter_min_neighbors: int = 6

    # Dense-point filter parameters
    filter_grid_size: float = 0.001
    filter_dense_distance_threshold: float = 10  # plain distance, in meters
    filter_time_threshold: timedelta = timedelta(minutes=5)

    # Grid division parameters
    grid_overlap: float = 0.05
    grid_size: float = 500

    # Pipeline switches. ``mode`` selects what post_process merges:
    # "快拼模式" -> TIF imagery only, "三维模式" -> PLY point cloud + OBJ
    # model, any other value -> all of them.
    mode: str = "快拼模式"
    produce_dem: bool = False


class ImagePreprocessor:
    """Filter GPS-tagged images, split them into grids, run ODM, merge results."""

    # Free space required on the output filesystem, as a multiple of the
    # input directory size. NOTE(review): an earlier comment said 1.5x, but
    # the code has always required 12x; 12x is kept -- confirm intended value.
    REQUIRED_SPACE_FACTOR = 12

    def __init__(self, config: PreprocessConfig):
        self.config = config

        # Fail fast on insufficient disk space before touching output_dir.
        self._check_disk_space()

        # Always start from a clean output directory tree.
        if os.path.exists(config.output_dir):
            self._clean_output_dir()
        self._setup_output_dirs()

        # Created after the directory tree exists (presumably the logger
        # writes under output_dir/logs -- confirm in utils.logger).
        self.logger = setup_logger(config.output_dir)
        self.gps_points = None
        self.clustered_points = None
        self.odm_monitor = ODMProcessMonitor(
            config.output_dir, mode=config.mode)
        self.visualizer = FilterVisualizer(config.output_dir)

    def _clean_output_dir(self) -> None:
        """Recursively delete the output directory.

        Raises:
            Exception: whatever ``shutil.rmtree`` raised, after printing it.
        """
        try:
            shutil.rmtree(self.config.output_dir)
            print(f"已清理输出目录: {self.config.output_dir}")
        except Exception as e:
            print(f"清理输出目录时发生错误: {str(e)}")
            raise

    def _setup_output_dirs(self) -> None:
        """Create output_dir plus its 'filter_imgs' and 'logs' subdirectories.

        Raises:
            Exception: whatever ``os.makedirs`` raised, after printing it.
        """
        try:
            # Main output directory
            os.makedirs(self.config.output_dir)
            # Directory for the filter-step visualisations
            os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs'))
            # Log directory
            os.makedirs(os.path.join(self.config.output_dir, 'logs'))
            print(f"已创建输出目录结构: {self.config.output_dir}")
        except Exception as e:
            print(f"创建输出目录时发生错误: {str(e)}")
            raise

    def _get_directory_size(self, path):
        """Return the total size in bytes of all files under *path*.

        Files that disappear or cannot be stat'ed are skipped.
        """
        total_size = 0
        for dirpath, _dirnames, filenames in os.walk(path):
            for filename in filenames:
                file_path = os.path.join(dirpath, filename)
                try:
                    total_size += os.path.getsize(file_path)
                except OSError:
                    continue
        return total_size

    @staticmethod
    def _existing_ancestor(path: str) -> str:
        """Return the absolute *path* itself or its closest existing ancestor."""
        current = os.path.abspath(path)
        while not os.path.exists(current):
            parent = os.path.dirname(current)
            if parent == current:  # reached the filesystem root
                break
            current = parent
        return current

    def _check_disk_space(self) -> None:
        """Ensure the output filesystem has enough free space.

        Raises:
            RuntimeError: if free space on the filesystem that will hold
                ``output_dir`` is below ``REQUIRED_SPACE_FACTOR`` times the
                total size of ``image_dir``.
        """
        input_size = self._get_directory_size(self.config.image_dir)

        # Query the filesystem that will actually hold output_dir. Do not
        # substitute a fixed path such as '/home' on POSIX: output_dir may be
        # on another mount. output_dir may not exist yet, so walk up to the
        # closest existing ancestor.
        check_path = self._existing_ancestor(self.config.output_dir)
        free_space = psutil.disk_usage(check_path).free
        # Report the drive letter on Windows (e.g. "G:"), otherwise the path.
        output_drive = os.path.splitdrive(check_path)[0] or check_path

        required_space = input_size * self.REQUIRED_SPACE_FACTOR

        if free_space < required_space:
            error_msg = (
                f"磁盘空间不足!\n"
                f"输入目录大小: {input_size / (1024**3):.2f} GB\n"
                f"所需空间: {required_space / (1024**3):.2f} GB\n"
                f"可用空间: {free_space / (1024**3):.2f} GB\n"
                f"在驱动器 {output_drive}"
            )
            raise RuntimeError(error_msg)

    def extract_gps(self) -> pd.DataFrame:
        """Extract GPS data from image_dir into ``self.gps_points`` and return it."""
        self.logger.info("开始提取GPS数据")
        extractor = GPSExtractor(self.config.image_dir)
        self.gps_points = extractor.extract_all_gps()
        self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def cluster(self) -> None:
        """Cluster GPS points with DBSCAN, keeping only the largest cluster."""
        previous_points = self.gps_points.copy()
        clusterer = GPSCluster(
            self.gps_points,
            eps=self.config.cluster_eps,
            min_samples=self.config.cluster_min_samples
        )

        self.clustered_points = clusterer.fit()
        self.gps_points = clusterer.get_cluster_stats(self.clustered_points)

        self.visualizer.visualize_filter_step(
            self.gps_points, previous_points, "1-Clustering")

    def filter_isolated_points(self) -> None:
        """Remove isolated points (fewer than min_neighbors within threshold)."""
        gps_filter = GPSFilter(self.config.output_dir)
        previous_points = self.gps_points.copy()

        self.gps_points = gps_filter.filter_isolated_points(
            self.gps_points,
            self.config.filter_distance_threshold,
            self.config.filter_min_neighbors,
        )

        self.visualizer.visualize_filter_step(
            self.gps_points, previous_points, "2-Isolated Points")

    def filter_time_group_overlap(self) -> None:
        """Remove overlapping time groups of images."""
        previous_points = self.gps_points.copy()

        overlap_filter = TimeGroupOverlapFilter(
            self.config.image_dir,
            self.config.output_dir,
            overlap_threshold=self.config.time_group_overlap_threshold
        )

        self.gps_points = overlap_filter.filter_overlapping_groups(
            self.gps_points,
            time_threshold=self.config.time_group_interval
        )

        self.visualizer.visualize_filter_step(
            self.gps_points, previous_points, "3-Time Group Overlap")

    def divide_grids(self) -> Tuple[Dict[tuple, pd.DataFrame], Dict[tuple, tuple]]:
        """Split the filtered points into grids and save a grid visualisation.

        Returns:
            tuple: (grid_points, translations)
                - grid_points: per-grid point data, keyed by grid id
                - translations: per-grid translation, keyed by grid id
        """
        grid_divider = GridDivider(
            overlap=self.config.grid_overlap,
            grid_size=self.config.grid_size,
            output_dir=self.config.output_dir
        )
        grids, translations, grid_points = grid_divider.adjust_grid_size_and_overlap(
            self.gps_points
        )
        grid_divider.visualize_grids(self.gps_points, grids)

        return grid_points, translations

    def copy_images(self, grid_points: Dict[tuple, pd.DataFrame]) -> None:
        """Copy each grid's images to output_dir/grid_<i>_<j>/project/images.

        NOTE(review): each item of a grid is indexed as ``point["file"]``,
        which works for a list of dict-like records. Iterating a pandas
        DataFrame would yield column labels instead. Confirm the real value
        type returned by GridDivider.
        """
        self.logger.info("开始复制图像文件")

        for grid_id, points in grid_points.items():
            images_dir = os.path.join(
                self.config.output_dir,
                f"grid_{grid_id[0]}_{grid_id[1]}",
                "project",
                "images"
            )

            os.makedirs(images_dir, exist_ok=True)

            for point in tqdm(points, desc=f"复制网格 ({grid_id[0]},{grid_id[1]}) 的图像"):
                src = os.path.join(self.config.image_dir, point["file"])
                dst = os.path.join(images_dir, point["file"])
                shutil.copy(src, dst)
            self.logger.info(
                f"网格 ({grid_id[0]},{grid_id[1]}) 包含 {len(points)} 张图像")

    def merge_tif(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool) -> None:
        """Merge the TIF imagery products of all given grids."""
        self.logger.info("开始合并所有影像产品")
        merger = MergeTif(self.config.output_dir)
        merger.merge_all_tifs(grid_points, produce_dem)

    def merge_ply(self, grid_points: Dict[tuple, pd.DataFrame]) -> None:
        """Merge the point clouds of all given grids (via MergePly.merge_grid_laz)."""
        self.logger.info("开始合并PLY点云")
        merger = MergePly(self.config.output_dir)
        merger.merge_grid_laz(grid_points)

    def merge_obj(self, grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]) -> None:
        """Merge the OBJ models of all given grids using their translations."""
        self.logger.info("开始合并OBJ模型")
        merger = MergeObj(self.config.output_dir)
        merger.merge_grid_obj(grid_points, translations)

    def post_process(self, successful_grid_points: Dict[tuple, pd.DataFrame], grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]) -> None:
        """Merge the outputs of the successfully processed grids.

        Warns when some grids failed. What gets merged depends on
        ``config.mode``:
        - "快拼模式": TIF imagery only.
        - "三维模式": PLY point cloud and OBJ model.
        - Any other value: all three.
        """
        if len(successful_grid_points) < len(grid_points):
            self.logger.warning(
                f"有 {len(grid_points) - len(successful_grid_points)} 个网格处理失败,"
                f"将只合并成功处理的 {len(successful_grid_points)} 个网格"
            )

        if self.config.mode == "快拼模式":
            self.merge_tif(successful_grid_points, self.config.produce_dem)
        elif self.config.mode == "三维模式":
            self.merge_ply(successful_grid_points)
            self.merge_obj(successful_grid_points, translations)
        else:
            self.merge_tif(successful_grid_points, self.config.produce_dem)
            self.merge_ply(successful_grid_points)
            self.merge_obj(successful_grid_points, translations)

    def process(self) -> None:
        """Run the full pipeline.

        Steps: extract, filter, grid, copy, run ODM per grid, then merge.
        Any exception is logged with its traceback and re-raised.
        """
        try:
            self.extract_gps()
            self.cluster()
            self.filter_isolated_points()
            self.filter_time_group_overlap()
            grid_points, translations = self.divide_grids()
            self.copy_images(grid_points)
            self.logger.info("预处理任务完成")

            successful_grid_points = self.odm_monitor.process_all_grids(
                grid_points, self.config.produce_dem)

            self.post_process(successful_grid_points,
                              grid_points, translations)

        except Exception as e:
            self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
            raise


if __name__ == "__main__":
    # Build the configuration
    config = PreprocessConfig(
        image_dir=r"E:\datasets\UAV\134\project\images",
        output_dir=r"G:\ODM_output\134",

        cluster_eps=0.01,
        cluster_min_samples=5,

        # Time-group overlap filter parameters
        time_group_overlap_threshold=0.7,
        time_group_interval=timedelta(minutes=5),

        filter_distance_threshold=0.001,
        filter_min_neighbors=6,

        filter_grid_size=0.001,
        filter_dense_distance_threshold=10,
        filter_time_threshold=timedelta(minutes=5),

        grid_size=800,
        grid_overlap=0.05,

        mode="快拼模式",
        produce_dem=False,
    )

    # Create the processor and run the pipeline
    processor = ImagePreprocessor(config)
    processor.process()