import os
import shutil
from dataclasses import dataclass
from typing import Dict, List

import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from preprocess.cluster import GPSCluster
from preprocess.command_runner import CommandRunner
from preprocess.gps_extractor import GPSExtractor
from preprocess.gps_filter import GPSFilter
from preprocess.grid_divider import GridDivider
from preprocess.logger import setup_logger
from preprocess.time_filter import TimeFilter


@dataclass
class PreprocessConfig:
    """Configuration for the image preprocessing pipeline."""

    # Directory containing the source images.
    image_dir: str
    # Directory for logs, intermediate files, per-grid image folders and plots.
    output_dir: str
    # Clustering parameters. NOTE: not yet forwarded to GPSCluster (see the
    # TODO on ImagePreprocessor.cluster).
    eps: float = 0.01
    min_samples: int = 5
    # Dense-point filter parameters (passed to GPSFilter.filter_dense_points).
    filter_grid_size: float = 0.001
    filter_dense_distance_threshold: float = 10
    # Isolated-point filter parameters (passed to
    # GPSFilter.filter_isolated_points).
    filter_distance_threshold: float = 0.001
    filter_min_neighbors: int = 6
    # Grid division parameters (passed to GridDivider). The unit of grid_size
    # is defined by GridDivider -- presumably meters, TODO confirm.
    grid_overlap: float = 0.05
    grid_size: float = 500
    # Toggles for the individual pipeline stages.
    enable_filter: bool = True
    enable_grid_division: bool = True
    enable_visualization: bool = True
    enable_copy_images: bool = True


class ImagePreprocessor:
    """GPS-based preprocessing of an image folder.

    The pipeline (see ``process``) extracts GPS positions from the images,
    keeps the main cluster, optionally divides the points into overlapping
    grids, copies the images into per-grid project folders, plots the result
    and finally hands the grids to ``CommandRunner``.
    """

    def __init__(self, config: PreprocessConfig):
        self.config = config
        self.logger = setup_logger(config.output_dir)
        # Current working set of GPS points; each stage replaces it.
        self.gps_points = pd.DataFrame()
        self.command_runner = CommandRunner(config.output_dir)

    def extract_gps(self) -> pd.DataFrame:
        """Extract GPS data from every image in ``config.image_dir``.

        Returns:
            The extracted points, also stored in ``self.gps_points``.
        """
        self.logger.info("开始提取GPS数据")
        extractor = GPSExtractor(self.config.image_dir)
        self.gps_points = extractor.extract_all_gps()
        self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def time_filter(self) -> pd.DataFrame:
        """Filter the current points by capture date using ``TimeFilter``.

        Returns:
            The remaining points, also stored in ``self.gps_points``.
        """
        self.logger.info("开始时间过滤")
        date_filter = TimeFilter(self.config.output_dir)
        self.gps_points = date_filter.filter_by_date(self.gps_points)
        self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    # TODO: forward config.eps / config.min_samples to GPSCluster once its
    # constructor accepts clustering parameters.
    def cluster(self) -> pd.DataFrame:
        """Cluster the GPS points (DBSCAN, per GPSCluster) and keep only the main cluster.

        Returns:
            The points of the main cluster, also stored in ``self.gps_points``.
        """
        self.logger.info("开始聚类")
        clusterer = GPSCluster(self.gps_points, output_dir=self.config.output_dir)
        self.gps_points = clusterer.get_main_cluster()

        stats = clusterer.get_cluster_stats()
        self.logger.info(
            f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
            f"噪声点 {stats['noise_points']} 个"
        )
        return self.gps_points

    # TODO: the dense-point filtering algorithm needs improvement.
    def filter_points(self) -> pd.DataFrame:
        """Remove isolated points, then thin out dense points.

        Does nothing when ``config.enable_filter`` is False.

        Returns:
            The remaining points, also stored in ``self.gps_points``.
        """
        if not self.config.enable_filter:
            return self.gps_points

        self.logger.info("开始过滤GPS点")
        gps_filter = GPSFilter(self.config.output_dir)

        self.logger.info(
            f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})"
        )
        self.gps_points = gps_filter.filter_isolated_points(
            self.gps_points,
            self.config.filter_distance_threshold,
            self.config.filter_min_neighbors,
        )
        self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")

        self.logger.info(
            f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})"
        )
        self.gps_points = gps_filter.filter_dense_points(
            self.gps_points,
            grid_size=self.config.filter_grid_size,
            distance_threshold=self.config.filter_dense_distance_threshold,
        )
        self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def divide_grids(self) -> Dict[int, pd.DataFrame]:
        """Divide the current points into (possibly overlapping) grids.

        Returns:
            Mapping of grid index to the points assigned to that grid. When
            grid division is disabled, all points are returned as grid 0.
        """
        if not self.config.enable_grid_division:
            # Without grid division every point goes into a single grid.
            return {0: self.gps_points}

        self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
        grid_divider = GridDivider(overlap=self.config.grid_overlap)
        grids = grid_divider.divide_grids(
            self.gps_points, grid_size=self.config.grid_size
        )
        grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
        self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
        return grid_points

    @staticmethod
    def _file_names(points) -> List[str]:
        """Return the image file name of every point in one grid.

        Iterating a DataFrame yields its column labels rather than its rows,
        so for DataFrames the ``file`` column is read explicitly. Sequences of
        mapping-like records (each with a ``"file"`` key) are also accepted.
        """
        if isinstance(points, pd.DataFrame):
            return points["file"].tolist()
        return [point["file"] for point in points]

    def copy_images(self, grid_points: Dict[int, pd.DataFrame]):
        """Copy each grid's images into its own ``project/images`` folder.

        With grid division enabled, grid *i* goes to
        ``<output_dir>/grid_<i+1>/project/images``. Otherwise, everything goes to
        ``<output_dir>/project/images``. Does nothing when
        ``config.enable_copy_images`` is False.

        Args:
            grid_points: Mapping of grid index to the points in that grid,
                as returned by ``divide_grids``.
        """
        if not self.config.enable_copy_images:
            return

        self.logger.info("开始复制图像文件")
        for grid_idx, points in grid_points.items():
            if self.config.enable_grid_division:
                output_dir = os.path.join(
                    self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
                )
            else:
                output_dir = os.path.join(self.config.output_dir, "project", "images")
            os.makedirs(output_dir, exist_ok=True)

            for file_name in tqdm(
                self._file_names(points), desc=f"复制网格 {grid_idx + 1} 的图像"
            ):
                src = os.path.join(self.config.image_dir, file_name)
                dst = os.path.join(output_dir, file_name)
                shutil.copy(src, dst)

            self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")

    def _read_filtered_files(self) -> List[str]:
        """Read the non-empty, stripped lines of ``<output_dir>/del_imgs.txt``.

        Nothing visible here guarantees the file was written (the filtering
        stages may be disabled). A missing file is therefore treated as
        "no filtered images", and a warning is logged instead of aborting.
        """
        path = os.path.join(self.config.output_dir, "del_imgs.txt")
        try:
            with open(path, "r", encoding="utf-8") as file:
                return [line.strip() for line in file if line.strip()]
        except FileNotFoundError:
            self.logger.warning(f"未找到被过滤图片列表 {path},将不标注被过滤点")
            return []

    def visualize_results(self):
        """Plot all original GPS points and highlight the filtered ones.

        The plot is saved to ``<output_dir>/filter_GPS.png``. Does nothing when
        ``config.enable_visualization`` is False.
        """
        if not self.config.enable_visualization:
            return

        self.logger.info("开始生成可视化结果")
        # Re-extract so the plot shows every point, before any stage removed some.
        extractor = GPSExtractor(self.config.image_dir)
        original_points_df = extractor.extract_all_gps()

        filtered_files = self._read_filtered_files()
        filtered_points_df = original_points_df[
            original_points_df["file"].isin(filtered_files)
        ]

        fig = plt.figure(figsize=(20, 16))
        try:
            plt.scatter(
                original_points_df["lon"],
                original_points_df["lat"],
                color="blue",
                label="Original Points",
                alpha=0.6,
            )
            plt.scatter(
                filtered_points_df["lon"],
                filtered_points_df["lat"],
                color="red",
                label="Filtered Points",
                alpha=0.6,
            )

            plt.title("GPS Coordinates of Images", fontsize=14)
            plt.xlabel("Longitude", fontsize=12)
            plt.ylabel("Latitude", fontsize=12)
            plt.grid(True)
            plt.legend()

            plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png"))
        finally:
            # Always release the figure, even if plotting or saving fails.
            plt.close(fig)
        self.logger.info("预处理结果图已保存")

    def process(self):
        """Run the full preprocessing pipeline.

        Any exception is logged with its traceback and then re-raised.
        """
        try:
            self.extract_gps()
            self.cluster()
            # Currently disabled stages:
            # self.time_filter()
            # self.filter_points()
            grid_points = self.divide_grids()
            self.copy_images(grid_points)
            self.visualize_results()
            # self.logger.info("预处理任务完成")
            self.command_runner.run_grid_commands(
                grid_points, self.config.enable_grid_division
            )
        except Exception as e:
            self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
            raise


if __name__ == "__main__":
    # Example configuration for a local run; adjust the paths before use.
    config = PreprocessConfig(
        image_dir=r"E:\湖南省第二测绘院\11-06-项目移交文件(王辉给)\无人机二三维节点扩容生产影像\影像数据\199\code\images",
        output_dir=r"test",
        filter_grid_size=0.001,
        filter_dense_distance_threshold=10,
        filter_distance_threshold=0.001,
        filter_min_neighbors=6,
        grid_overlap=0.05,
        grid_size=500,
        enable_filter=True,
        enable_grid_division=True,
        enable_visualization=True,
        enable_copy_images=True,
    )

    # Build the preprocessor and run the pipeline.
    processor = ImagePreprocessor(config)
    processor.process()