diff --git a/odm_preprocess.py b/odm_preprocess.py index 40f2edc..d57ac8a 100644 --- a/odm_preprocess.py +++ b/odm_preprocess.py @@ -1,23 +1,25 @@ +import os +import shutil +from dataclasses import dataclass +from typing import Dict + +import matplotlib.pyplot as plt +import pandas as pd +from tqdm import tqdm + +from preprocess.cluster import GPSCluster +from preprocess.command_runner import CommandRunner from preprocess.gps_extractor import GPSExtractor -from preprocess.time_filter import TimeFilter from preprocess.gps_filter import GPSFilter from preprocess.grid_divider import GridDivider from preprocess.logger import setup_logger -from preprocess.command_runner import CommandRunner -import os -import pandas as pd -import shutil -import matplotlib.pyplot as plt -from typing import List, Dict, Optional -from dataclasses import dataclass -from tqdm import tqdm -import subprocess -from concurrent.futures import ThreadPoolExecutor +from preprocess.time_filter import TimeFilter @dataclass class PreprocessConfig: """预处理配置类""" + image_dir: str output_dir: str filter_grid_size: float = 0.001 @@ -55,6 +57,21 @@ class ImagePreprocessor: self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点") return self.gps_points + def cluster(self) -> pd.DataFrame: + """使用DBSCAN对GPS点进行聚类,只保留最大的类""" + self.logger.info("开始聚类") + # 创建聚类器并执行聚类 + clusterer = GPSCluster(self.gps_points, eps=0.01, min_samples=5) + # 获取主要类别的点 + self.gps_points = clusterer.get_main_cluster() + # 获取统计信息并记录 + stats = clusterer.get_cluster_stats() + self.logger.info( + f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点," + f"噪声点 {stats['noise_points']} 个" + ) + return self.gps_points + # TODO 过滤密集点算法需要改进 def filter_points(self) -> pd.DataFrame: """过滤GPS点""" @@ -65,20 +82,22 @@ class ImagePreprocessor: filter = GPSFilter(self.config.output_dir) self.logger.info( - f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})") + f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})" + ) self.gps_points = filter.filter_isolated_points( self.gps_points, self.config.filter_distance_threshold, - self.config.filter_min_neighbors + self.config.filter_min_neighbors, ) self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点") self.logger.info( - f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})") + f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})" + ) self.gps_points = filter.filter_dense_points( self.gps_points, grid_size=self.config.filter_grid_size, - distance_threshold=self.config.filter_dense_distance_threshold + distance_threshold=self.config.filter_dense_distance_threshold, ) self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点") return self.gps_points @@ -91,7 +110,8 @@ class ImagePreprocessor: self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})") grid_divider = GridDivider(overlap=self.config.grid_overlap) grids = grid_divider.divide_grids( - self.gps_points, grid_size=self.config.grid_size) + self.gps_points, grid_size=self.config.grid_size + ) grid_points = grid_divider.assign_to_grids(self.gps_points, grids) self.logger.info(f"成功划分为 {len(grid_points)} 个网格") return grid_points @@ -106,16 +126,16 @@ class ImagePreprocessor: for grid_idx, points in grid_points.items(): if self.config.enable_grid_division: output_dir = os.path.join( - self.config.output_dir, f'grid_{grid_idx + 1}', 'project', 'images') + self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images" + ) else: - output_dir = os.path.join( - self.config.output_dir, 'project', 'images') + output_dir = os.path.join(self.config.output_dir, "project", "images") os.makedirs(output_dir, exist_ok=True) for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"): - src = os.path.join(self.config.image_dir, point['file']) - dst = os.path.join(output_dir, point['file']) + src = os.path.join(self.config.image_dir, point["file"]) + dst = os.path.join(output_dir, point["file"]) shutil.copy(src, dst) self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像") @@ -129,27 +149,34 @@ class ImagePreprocessor: original_points_df = extractor.extract_all_gps() # 读取被过滤的图片列表 - with open(os.path.join(self.config.output_dir, 'del_imgs.txt'), "r", encoding="utf-8") as file: + with open( + os.path.join(self.config.output_dir, "del_imgs.txt"), "r", encoding="utf-8" + ) as file: filtered_files = [line.strip() for line in file if line.strip()] # 创建一个新的图形 plt.figure(figsize=(20, 16)) # 绘制所有原始点 - plt.scatter(original_points_df['lon'], - original_points_df['lat'], - color='blue', - label="Original Points", - alpha=0.6) + plt.scatter( + original_points_df["lon"], + original_points_df["lat"], + color="blue", + label="Original Points", + alpha=0.6, + ) # 绘制被过滤的点 - filtered_points_df = original_points_df[original_points_df['file'].isin( - filtered_files)] - plt.scatter(filtered_points_df['lon'], - filtered_points_df['lat'], - color="red", - label="Filtered Points", - alpha=0.6) + filtered_points_df = original_points_df[ + original_points_df["file"].isin(filtered_files) + ] + plt.scatter( + filtered_points_df["lon"], + filtered_points_df["lat"], + color="red", + label="Filtered Points", + alpha=0.6, + ) # 设置图形属性 plt.title("GPS Coordinates of Images", fontsize=14) @@ -159,7 +186,7 @@ class ImagePreprocessor: plt.legend() # 保存图形 - plt.savefig(os.path.join(self.config.output_dir, 'filter_GPS.png')) + plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png")) plt.close() self.logger.info("预处理结果图已保存") @@ -167,26 +194,27 @@ class ImagePreprocessor: """执行完整的预处理流程""" try: self.extract_gps() - self.time_filter() - self.filter_points() - grid_points = self.divide_grids() - self.copy_images(grid_points) - self.visualize_results() - self.logger.info("预处理任务完成") - self.command_runner.run_grid_commands( - grid_points, - self.config.enable_grid_division - ) + self.cluster() + # self.time_filter() + # self.filter_points() + # grid_points = self.divide_grids() + # self.copy_images(grid_points) + # self.visualize_results() + # self.logger.info("预处理任务完成") + # self.command_runner.run_grid_commands( + # grid_points, + # self.config.enable_grid_division + # ) except Exception as e: self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) raise -if __name__ == '__main__': +if __name__ == "__main__": # 创建配置 config = PreprocessConfig( - image_dir=r'C:\datasets\1815\images', - output_dir=r'C:\datasets\1815\output', + image_dir=r"../code/images", + output_dir=r"../code/output", filter_grid_size=0.001, filter_dense_distance_threshold=10, filter_distance_threshold=0.001, @@ -195,7 +223,7 @@ if __name__ == '__main__': enable_filter=True, enable_grid_division=True, enable_visualization=True, - enable_copy_images=True + enable_copy_images=True, ) # 创建处理器并执行 diff --git a/preprocess/cluster.py b/preprocess/cluster.py new file mode 100644 index 0000000..4837870 --- /dev/null +++ b/preprocess/cluster.py @@ -0,0 +1,80 @@ +from sklearn.cluster import DBSCAN +from sklearn.preprocessing import StandardScaler + + +class GPSCluster: + def __init__(self, gps_points, eps=0.01, min_samples=5): + """ + 初始化GPS聚类器 + + 参数: + eps: DBSCAN的邻域半径参数 + min_samples: DBSCAN的最小样本数参数 + """ + self.eps = eps + self.min_samples = min_samples + self.dbscan = DBSCAN(eps=eps, min_samples=min_samples) + self.scaler = StandardScaler() + self.gps_points = gps_points + self.clustered_points = self.fit() + + def fit(self): + """ + 对GPS点进行聚类,只保留最大的类 + + 参数: + gps_points: 包含'lat'和'lon'列的DataFrame + + 返回: + 带有聚类标签的DataFrame,其中最大类标记为1,其他点标记为-1 + """ + # 提取经纬度数据 + X = self.gps_points[["lon", "lat"]].values + + # # 数据标准化 + # X_scaled = self.scaler.fit_transform(X) + + # 执行DBSCAN聚类 + labels = self.dbscan.fit_predict(X) + + # 找出最大类的标签(排除噪声点-1) + unique_labels = [l for l in set(labels) if l != -1] + if unique_labels: # 如果有聚类 + label_counts = [(l, sum(labels == l)) for l in unique_labels] + largest_label = max(label_counts, key=lambda x: x[1])[0] + + # 将最大类标记为1,其他都标记为-1 + new_labels = (labels == largest_label).astype(int) + new_labels[new_labels == 0] = -1 + else: # 如果没有聚类,全部标记为-1 + new_labels = labels + + # 将聚类结果添加到原始数据中 + result_df = self.gps_points.copy() + result_df["cluster"] = new_labels + + return result_df + + def get_cluster_stats(self): + """ + 获取聚类统计信息 + + 参数: + clustered_points: 带有聚类标签的DataFrame + + 返回: + 聚类统计信息的字典 + """ + main_cluster_points = sum(self.clustered_points["cluster"] == 1) + stats = { + "total_points": len(self.clustered_points), + "main_cluster_points": main_cluster_points, + "noise_points": sum(self.clustered_points["cluster"] == -1), + } + return stats + + def get_main_cluster(self): + return self.clustered_points[self.clustered_points["cluster"] == 1] + + def get_noise_cluster(self): + return self.clustered_points[self.clustered_points["cluster"] == -1]