diff --git a/odm_preprocess.py b/odm_preprocess.py index 6dc68df..3332c91 100644 --- a/odm_preprocess.py +++ b/odm_preprocess.py @@ -15,6 +15,7 @@ from filter.gps_filter import GPSFilter from utils.grid_divider import GridDivider from utils.logger import setup_logger from filter.time_group_overlap_filter import TimeGroupOverlapFilter +from utils.visualizer import FilterVisualizer @dataclass @@ -52,9 +53,12 @@ class ImagePreprocessor: def __init__(self, config: PreprocessConfig): self.config = config self.logger = setup_logger(config.output_dir) - self.gps_points = [] - self.command_runner = CommandRunner( - config.output_dir, mode=config.mode) + self.gps_points = None + self.command_runner = CommandRunner(config.output_dir, mode=config.mode) + self.visualizer = FilterVisualizer(config.output_dir) + # 用于存储每个步骤的点数据 + self.points_history = [] + self.step_names = [] def extract_gps(self) -> pd.DataFrame: """提取GPS数据""" @@ -62,24 +66,32 @@ class ImagePreprocessor: extractor = GPSExtractor(self.config.image_dir) self.gps_points = extractor.extract_all_gps() self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点") + + # 记录初始状态 + self.points_history.append(self.gps_points.copy()) + self.step_names.append("Initial") return self.gps_points def cluster(self) -> pd.DataFrame: - """使用DBSCAN对GPS点进行聚类,只保留最大的类""" + """使用DBSCAN对GPS点进行聚类""" self.logger.info("开始聚类") - # 创建聚类器并执行聚类 + previous_points = self.gps_points.copy() + clusterer = GPSCluster( self.gps_points, output_dir=self.config.output_dir, eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples) - # 获取主要类别的点 self.clustered_points = clusterer.fit() self.gps_points = clusterer.get_main_cluster(self.clustered_points) - # 获取统计信息并记录 - stats = clusterer.get_cluster_stats(self.clustered_points) - self.logger.info( - f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点," - f"噪声点 {stats['noise_points']} 个" - ) + + # 可视化聚类结果 + if self.config.enable_visualization: + self.visualizer.visualize_filter_step( + self.gps_points, 
previous_points, "Clustering") + + # 记录这一步的结果 + self.points_history.append(self.gps_points.copy()) + self.step_names.append("Clustering") + return self.gps_points def filter_time_group_overlap(self) -> pd.DataFrame: """过滤重叠的时间组""" @@ -87,6 +99,8 @@ class ImagePreprocessor: return self.gps_points self.logger.info("开始过滤重叠时间组") + previous_points = self.gps_points.copy() + filter = TimeGroupOverlapFilter( self.config.image_dir, self.config.output_dir, @@ -97,150 +111,31 @@ class ImagePreprocessor: time_threshold=self.config.time_group_interval ) - # 更新GPS点数据,移除被删除的图像 - self.gps_points = self.gps_points[~self.gps_points['file'].isin( - deleted_files)] - self.logger.info(f"重叠时间组过滤后剩余 {len(self.gps_points)} 个GPS点") - + self.gps_points = self.gps_points[~self.gps_points['file'].isin(deleted_files)] + + # 可视化过滤结果 + if self.config.enable_visualization: + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "Time Group Overlap") + + # 记录这一步的结果 + self.points_history.append(self.gps_points.copy()) + self.step_names.append("Time Group Overlap") return self.gps_points - # TODO 过滤算法还需要更新 - def filter_points(self) -> pd.DataFrame: - """过滤GPS点""" - if not self.config.enable_filter: - return self.gps_points - - self.logger.info("开始过滤GPS点") - filter = GPSFilter(self.config.output_dir) - - self.logger.info( - f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})" - ) - self.gps_points = filter.filter_isolated_points( - self.gps_points, - self.config.filter_distance_threshold, - self.config.filter_min_neighbors, - ) - self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点") - - self.logger.info( - f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})" - ) - self.gps_points = filter.filter_dense_points( - self.gps_points, - grid_size=self.config.filter_grid_size, - distance_threshold=self.config.filter_dense_distance_threshold, - 
time_threshold=self.config.filter_time_threshold, - ) - self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点") - return self.gps_points - - def divide_grids(self) -> Dict[int, pd.DataFrame]: - """划分网格""" - if not self.config.enable_grid_division: - return {0: self.gps_points} # 不划分网格时,所有点放在一个网格中 - - self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})") - grid_divider = GridDivider(overlap=self.config.grid_overlap) - grids = grid_divider.divide_grids( - self.gps_points, grid_size=self.config.grid_size - ) - grid_points = grid_divider.assign_to_grids(self.gps_points, grids) - self.logger.info(f"成功划分为 {len(grid_points)} 个网格") - return grid_points - - def copy_images(self, grid_points: Dict[int, pd.DataFrame]): - """复制图像到目标文件夹""" - if not self.config.enable_copy_images: - return - - self.logger.info("开始复制图像文件") - - for grid_idx, points in grid_points.items(): - if self.config.enable_grid_division: - output_dir = os.path.join( - self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images" - ) - else: - output_dir = os.path.join( - self.config.output_dir, "project", "images") - - os.makedirs(output_dir, exist_ok=True) - - for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"): - src = os.path.join(self.config.image_dir, point["file"]) - dst = os.path.join(output_dir, point["file"]) - shutil.copy(src, dst) - self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像") - - def visualize_results(self): - """可视化处理结果""" - if not self.config.enable_visualization: - return - - self.logger.info("开始生成可视化结果") - extractor = GPSExtractor(self.config.image_dir) - original_points_df = extractor.extract_all_gps() - - # 读取被过滤的图片列表 - with open( - os.path.join(self.config.output_dir, "del_imgs.txt"), "r", encoding="utf-8" - ) as file: - filtered_files = [line.strip() for line in file if line.strip()] - - # 创建一个新的图形 - plt.figure(figsize=(20, 16)) - - # 绘制所有原始点 - plt.scatter( - original_points_df["lon"], - original_points_df["lat"], - color="blue", - label="Original 
Points", - alpha=0.6, - ) - - # 绘制被过滤的点 - filtered_points_df = original_points_df[ - original_points_df["file"].isin(filtered_files) - ] - plt.scatter( - filtered_points_df["lon"], - filtered_points_df["lat"], - color="red", - marker="x", - label="Filtered Points", - alpha=0.6, - ) - - # 设置图形属性 - plt.title("GPS Coordinates of Images", fontsize=14) - plt.xlabel("Longitude", fontsize=12) - plt.ylabel("Latitude", fontsize=12) - plt.grid(True) - plt.legend() - - # 保存图形 - plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png")) - plt.close() - self.logger.info("预处理结果图已保存") - def process(self): """执行完整的预处理流程""" try: self.extract_gps() self.cluster() self.filter_time_group_overlap() - # self.filter_points() - # grid_points = self.divide_grids() - # self.copy_images(grid_points) - # self.visualize_results() - # self.logger.info("预处理任务完成") - # self.command_runner.run_grid_commands( - # grid_points, - # self.config.enable_grid_division, - # ) - # TODO 拼图 + + # 在处理结束时生成所有步骤的可视化 + if self.config.enable_visualization: + self.visualizer.visualize_all_steps( + self.points_history, self.step_names) + + self.logger.info("预处理任务完成") except Exception as e: self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) raise diff --git a/utils/visualizer.py b/utils/visualizer.py new file mode 100644 index 0000000..015a1a5 --- /dev/null +++ b/utils/visualizer.py @@ -0,0 +1,116 @@ +import os +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import logging +from typing import List, Optional + +class FilterVisualizer: + """过滤结果可视化器""" + + def __init__(self, output_dir: str): + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.Visualizer') + + def visualize_filter_step(self, + current_points: pd.DataFrame, + previous_points: pd.DataFrame, + step_name: str, + save_name: Optional[str] = None): + """ + 可视化单个过滤步骤的结果 + + Args: + current_points: 当前步骤后的点 + previous_points: 上一步骤的点 + step_name: 步骤名称 + save_name: 保存文件名,默认为step_name + """ + 
self.logger.info(f"开始生成{step_name}的可视化结果") + + # 找出被过滤掉的点 + filtered_files = set(previous_points['file']) - set(current_points['file']) + filtered_points = previous_points[previous_points['file'].isin(filtered_files)] + + # 创建图形 + plt.figure(figsize=(20, 16)) + + # 绘制保留的点 + plt.scatter(current_points['lon'], current_points['lat'], + color='blue', label='Retained Points', + alpha=0.6, s=50) + + # 绘制被过滤的点 + plt.scatter(filtered_points['lon'], filtered_points['lat'], + color='red', marker='x', label='Filtered Points', + alpha=0.6, s=100) + + # 设置图形属性 + plt.title(f"GPS Points After {step_name}", fontsize=14) + plt.xlabel("Longitude", fontsize=12) + plt.ylabel("Latitude", fontsize=12) + plt.grid(True) + + # 添加统计信息 + stats_text = (f"Total Points: {len(previous_points)}\n" + f"Filtered Points: {len(filtered_points)}\n" + f"Remaining Points: {len(current_points)}") + plt.figtext(0.02, 0.02, stats_text, fontsize=10, + bbox=dict(facecolor='white', alpha=0.8)) + + plt.legend() + + # 保存图形 + save_name = save_name or step_name.lower().replace(' ', '_') + save_path = os.path.join(self.output_dir, f'filter_{save_name}.png') + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + + self.logger.info(f"{step_name}过滤可视化结果已保存至 {save_path}") + + def visualize_all_steps(self, + points_history: List[pd.DataFrame], + step_names: List[str]): + """ + 可视化所有过滤步骤的结果 + + Args: + points_history: 每个步骤的点数据列表 + step_names: 步骤名称列表 + """ + if len(points_history) != len(step_names): + raise ValueError("点数据列表和步骤名称列表长度不匹配") + + plt.figure(figsize=(20, 16)) + + # 使用不同的颜色表示不同步骤 + colors = plt.cm.rainbow(np.linspace(0, 1, len(points_history))) + + # 绘制每个步骤的点 + for i, (points, color) in enumerate(zip(points_history, colors)): + plt.scatter(points['lon'], points['lat'], + color=color, label=f'After {step_names[i]}', + alpha=0.6, s=50) + + # 设置图形属性 + plt.title("GPS Points Filtering Process", fontsize=14) + plt.xlabel("Longitude", fontsize=12) + plt.ylabel("Latitude", fontsize=12) + 
plt.grid(True) + + # 添加统计信息 + stats_text = "Points Count:\n" + "\n".join( + f"{step_names[i]}: {len(points)}" + for i, points in enumerate(points_history) + ) + plt.figtext(0.02, 0.02, stats_text, fontsize=10, + bbox=dict(facecolor='white', alpha=0.8)) + + plt.legend() + + # 保存图形 + save_path = os.path.join(self.output_dir, 'filter_all_steps.png') + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + + self.logger.info(f"所有过滤步骤的可视化结果已保存至 {save_path}") \ No newline at end of file