"""Image preprocessing pipeline.

Extracts GPS points from a directory of images, filters them, optionally
divides them into grids, copies the images of every grid into its own project
folder, renders an overview plot, and finally hands the grids to the command
runner.
"""
import os
import shutil
import subprocess
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from preprocess.command_runner import CommandRunner
from preprocess.gps_extractor import GPSExtractor
from preprocess.gps_filter import GPSFilter
from preprocess.grid_divider import GridDivider
from preprocess.logger import setup_logger
from preprocess.time_filter import TimeFilter


@dataclass
class PreprocessConfig:
    """Configuration for the preprocessing pipeline."""

    # Directory containing the source images.
    image_dir: str
    # Directory that receives logs, the filter list, plots and copied images.
    output_dir: str
    # Passed as ``grid_size`` to ``GPSFilter.filter_dense_points``.
    filter_grid_size: float = 0.001
    # Passed as ``distance_threshold`` to ``GPSFilter.filter_dense_points``.
    filter_dense_distance_threshold: float = 10
    # Distance threshold passed to ``GPSFilter.filter_isolated_points``.
    filter_distance_threshold: float = 0.001
    # Minimum neighbour count passed to ``GPSFilter.filter_isolated_points``.
    filter_min_neighbors: int = 6
    # Overlap ratio handed to ``GridDivider``.
    grid_overlap: float = 0.05
    # Passed as ``grid_size`` to ``GridDivider.divide_grids``.
    grid_size: float = 250
    # Feature switches for the individual pipeline steps.
    enable_filter: bool = True
    enable_grid_division: bool = True
    enable_visualization: bool = True
    enable_copy_images: bool = True


class ImagePreprocessor:
    """Run the preprocessing steps described by a :class:`PreprocessConfig`."""

    def __init__(self, config: PreprocessConfig):
        """Store the configuration and create the logger and command runner."""
        self.config = config
        self.logger = setup_logger(config.output_dir)
        self.gps_points = []
        self.command_runner = CommandRunner(config.output_dir)

    def extract_gps(self) -> pd.DataFrame:
        """Extract the GPS data of all images in ``config.image_dir``.

        Returns:
            The extracted points, which are also stored in ``self.gps_points``.
        """
        self.logger.info("开始提取GPS数据")
        extractor = GPSExtractor(self.config.image_dir)
        self.gps_points = extractor.extract_all_gps()
        self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def time_filter(self) -> pd.DataFrame:
        """Filter ``self.gps_points`` by date using ``TimeFilter``.

        Returns:
            The remaining points, which replace ``self.gps_points``.
        """
        self.logger.info("开始时间过滤")
        date_filter = TimeFilter(self.config.output_dir)
        self.gps_points = date_filter.filter_by_date(self.gps_points)
        self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    # TODO: the dense-point filtering algorithm needs improvement.
    def filter_points(self) -> pd.DataFrame:
        """Remove isolated points, then dense points, from ``self.gps_points``.

        Does nothing when ``config.enable_filter`` is false.

        Returns:
            The remaining points, which replace ``self.gps_points``.
        """
        if not self.config.enable_filter:
            return self.gps_points

        self.logger.info("开始过滤GPS点")
        # Named ``gps_filter`` so the builtin ``filter`` is not shadowed.
        gps_filter = GPSFilter(self.config.output_dir)

        self.logger.info(
            f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})")
        self.gps_points = gps_filter.filter_isolated_points(
            self.gps_points,
            self.config.filter_distance_threshold,
            self.config.filter_min_neighbors,
        )
        self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")

        self.logger.info(
            f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})")
        self.gps_points = gps_filter.filter_dense_points(
            self.gps_points,
            grid_size=self.config.filter_grid_size,
            distance_threshold=self.config.filter_dense_distance_threshold,
        )
        self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def divide_grids(self) -> Dict[int, pd.DataFrame]:
        """Assign ``self.gps_points`` to grids.

        Returns:
            A mapping from grid index to the points of that grid. When
            ``config.enable_grid_division`` is false, all points are placed in
            a single grid with index 0.
        """
        if not self.config.enable_grid_division:
            # Without grid division every point belongs to one single grid.
            return {0: self.gps_points}

        self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
        grid_divider = GridDivider(overlap=self.config.grid_overlap)
        grids = grid_divider.divide_grids(
            self.gps_points, grid_size=self.config.grid_size)
        grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
        self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
        return grid_points

    @staticmethod
    def _file_names(points) -> List[str]:
        """Return the ``'file'`` value of every point.

        ``points`` may be a DataFrame with a ``'file'`` column or an iterable
        of records indexable by ``'file'``. Iterating a DataFrame directly
        yields its column labels rather than its rows, so that case is read
        from the column instead.
        """
        if isinstance(points, pd.DataFrame):
            if points.empty:
                return []
            return points['file'].tolist()
        return [point['file'] for point in points]

    def copy_images(self, grid_points: Dict[int, pd.DataFrame]):
        """Copy the images of every grid into its project folder.

        With grid division enabled the target is
        ``<output_dir>/grid_<n>/project/images`` (``n`` is the grid index plus
        one), otherwise ``<output_dir>/project/images``. Does nothing when
        ``config.enable_copy_images`` is false.

        Args:
            grid_points: Mapping from grid index to the points of that grid;
                each value is either a DataFrame with a ``'file'`` column or
                an iterable of records indexable by ``'file'``.
        """
        if not self.config.enable_copy_images:
            return

        self.logger.info("开始复制图像文件")

        for grid_idx, points in grid_points.items():
            if self.config.enable_grid_division:
                output_dir = os.path.join(
                    self.config.output_dir,
                    f'grid_{grid_idx + 1}',
                    'project',
                    'images',
                )
            else:
                output_dir = os.path.join(
                    self.config.output_dir, 'project', 'images')

            os.makedirs(output_dir, exist_ok=True)

            file_names = self._file_names(points)
            for file_name in tqdm(file_names, desc=f"复制网格 {grid_idx + 1} 的图像"):
                src = os.path.join(self.config.image_dir, file_name)
                dst = os.path.join(output_dir, file_name)
                shutil.copy(src, dst)
            self.logger.info(f"网格 {grid_idx + 1} 包含 {len(file_names)} 张图像")

    def visualize_results(self):
        """Plot all original GPS points and highlight the filtered ones.

        The original points are re-extracted from ``config.image_dir``; the
        filtered file names are read from ``<output_dir>/del_imgs.txt``. The
        plot is saved as ``<output_dir>/filter_GPS.png``. Does nothing when
        ``config.enable_visualization`` is false.
        """
        if not self.config.enable_visualization:
            return

        self.logger.info("开始生成可视化结果")
        extractor = GPSExtractor(self.config.image_dir)
        original_points_df = extractor.extract_all_gps()

        # Read the list of filtered images. A missing list is treated as
        # "nothing was filtered" instead of aborting the whole pipeline.
        del_list_path = os.path.join(self.config.output_dir, 'del_imgs.txt')
        try:
            with open(del_list_path, "r", encoding="utf-8") as file:
                filtered_files = [line.strip() for line in file if line.strip()]
        except FileNotFoundError:
            self.logger.warning(f"未找到过滤列表文件 {del_list_path},按无过滤点处理")
            filtered_files = []

        figure = plt.figure(figsize=(20, 16))
        try:
            # All original points.
            plt.scatter(original_points_df['lon'], original_points_df['lat'],
                        color='blue', label="Original Points", alpha=0.6)

            # Points that were filtered out, drawn on top.
            filtered_points_df = original_points_df[
                original_points_df['file'].isin(filtered_files)]
            plt.scatter(filtered_points_df['lon'], filtered_points_df['lat'],
                        color="red", label="Filtered Points", alpha=0.6)

            # Figure decoration.
            plt.title("GPS Coordinates of Images", fontsize=14)
            plt.xlabel("Longitude", fontsize=12)
            plt.ylabel("Latitude", fontsize=12)
            plt.grid(True)
            plt.legend()

            plt.savefig(os.path.join(self.config.output_dir, 'filter_GPS.png'))
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(figure)
        self.logger.info("预处理结果图已保存")

    def process(self):
        """Run the complete preprocessing pipeline.

        Steps: GPS extraction, time filter, point filter, grid division,
        image copy, visualization, then the grid commands of the command
        runner.

        Raises:
            Exception: Any error raised by a step is logged with its
                traceback and re-raised.
        """
        try:
            self.extract_gps()
            self.time_filter()
            self.filter_points()
            grid_points = self.divide_grids()
            self.copy_images(grid_points)
            self.visualize_results()
            self.logger.info("预处理任务完成")
            self.command_runner.run_grid_commands(
                grid_points,
                self.config.enable_grid_division
            )
        except Exception as e:
            self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
            raise


def main():
    """Build the example configuration and run the preprocessor."""
    config = PreprocessConfig(
        image_dir=r'E:\datasets\UAV\1815',
        output_dir=r'E:\datasets\UAV\1815\output',

        filter_grid_size=0.001,
        filter_dense_distance_threshold=10,
        filter_distance_threshold=0.001,
        filter_min_neighbors=6,

        grid_overlap=0.05,
        enable_filter=True,
        enable_grid_division=True,
        enable_visualization=True,
        enable_copy_images=True
    )

    processor = ImagePreprocessor(config)
    processor.process()


if __name__ == '__main__':
    main()