diff --git a/filter/cluster_filter.py b/filter/cluster_filter.py new file mode 100644 index 0000000..473da7d --- /dev/null +++ b/filter/cluster_filter.py @@ -0,0 +1,77 @@ +from sklearn.cluster import DBSCAN +from sklearn.preprocessing import StandardScaler +import os +import logging + + +class GPSCluster: + def __init__(self, gps_points, eps=0.01, min_samples=3): + """ + 初始化GPS聚类器 + + 参数: + eps: DBSCAN的邻域半径参数 + min_samples: DBSCAN的最小样本数参数 + """ + self.eps = eps + self.min_samples = min_samples + self.dbscan = DBSCAN(eps=eps, min_samples=min_samples) + self.scaler = StandardScaler() + self.gps_points = gps_points + self.logger = logging.getLogger('UAV_Preprocess.GPSCluster') + + def fit(self): + """ + 对GPS点进行聚类,只保留最大的类 + + 参数: + gps_points: 包含'lat'和'lon'列的DataFrame + + 返回: + 带有聚类标签的DataFrame,其中最大类标记为1,其他点标记为-1 + """ + self.logger.info("开始聚类") + # 提取经纬度数据 + X = self.gps_points[["lon", "lat"]].values + + # # 数据标准化 + # X_scaled = self.scaler.fit_transform(X) + + # 执行DBSCAN聚类 + labels = self.dbscan.fit_predict(X) + + # 找出最大类的标签(排除噪声点-1) + unique_labels = [l for l in set(labels) if l != -1] + if unique_labels: # 如果有聚类 + label_counts = [(l, sum(labels == l)) for l in unique_labels] + largest_label = max(label_counts, key=lambda x: x[1])[0] + + # 将最大类标记为1,其他都标记为-1 + new_labels = (labels == largest_label).astype(int) + new_labels[new_labels == 0] = -1 + else: # 如果没有聚类,全部标记为-1 + new_labels = labels + + # 将聚类结果添加到原始数据中 + result_df = self.gps_points.copy() + result_df["cluster"] = new_labels + + return result_df + + def get_cluster_stats(self, clustered_points): + """ + 获取聚类统计信息 + + 参数: + clustered_points: 带有聚类标签的DataFrame + + 返回: + 聚类统计信息的字典 + """ + main_cluster = clustered_points[clustered_points["cluster"] == 1] + noise_cluster = clustered_points[clustered_points["cluster"] == -1] + + self.logger.info(f"聚类完成:主要类别包含 {len(main_cluster)} 个点," + f"噪声点 {len(noise_cluster)} 个") + + return main_cluster diff --git a/main.py b/main.py new file mode 100644 index 0000000..aee07fe --- /dev/null +++ b/main.py @@ -0,0 +1,249 @@ +import os +import shutil +from datetime import timedelta +from dataclasses import dataclass +from typing import Dict, Tuple +import psutil +import pandas as pd +from pathlib import Path + +from filter.cluster_filter import GPSCluster +from utils.gps_extractor import GPSExtractor +from utils.grid_divider import GridDivider +from utils.logger import setup_logger +from utils.visualizer import FilterVisualizer +from utils.docker_runner import DockerRunner +from post_pro.conv_obj import ConvertOBJ + + +@dataclass +class ProcessConfig: + """配置类""" + + image_dir: str + output_dir: str + # 聚类过滤参数 + cluster_eps: float = 0.01 + cluster_min_samples: int = 5 + # 时间组重叠过滤参数 + time_group_overlap_threshold: float = 0.7 + time_group_interval: timedelta = timedelta(minutes=5) + # 孤立点过滤参数 + filter_distance_threshold: float = 0.001 # 经纬度距离 + filter_min_neighbors: int = 6 + # 密集点过滤参数 + filter_grid_size: float = 0.001 + filter_dense_distance_threshold: float = 10 # 普通距离,单位:米 + filter_time_threshold: timedelta = timedelta(minutes=5) + # 网格划分参数 + grid_overlap: float = 0.05 + grid_size: float = 500 + # 几个pipline过程是否开启 + mode: str = "快拼模式" + accuracy: str = "medium" + produce_dem: bool = False + + +class ODM_Plugin: + def __init__(self, config: ProcessConfig): + self.config = config + + # 检查磁盘空间 + # TODO 现在输入目录的磁盘空间也需要检查 + self._check_disk_space() + + # 清理并重建输出目录 + if os.path.exists(config.output_dir): + self._clean_output_dir() + self._setup_output_dirs() + + # 修改输入目录,符合ODM要求,从这里开始,image_dir就是project_path + self._rename_input_dir() + self.project_path = self.config.image_dir + + # 初始化其他组件 + self.logger = setup_logger(config.output_dir) + self.gps_points = None + self.grid_points = None + self.visualizer = FilterVisualizer(config.output_dir) + + def _clean_output_dir(self): + """清理输出目录""" + try: + shutil.rmtree(self.config.output_dir) + print(f"已清理输出目录: {self.config.output_dir}") + except Exception as e: + print(f"清理输出目录时发生错误: {str(e)}") + raise + + def _setup_output_dirs(self): + """创建必要的输出目录结构""" + try: + # 创建主输出目录 + os.makedirs(self.config.output_dir) + + # 创建过滤图像保存目录 + os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs')) + + # 创建日志目录 + os.makedirs(os.path.join(self.config.output_dir, 'logs')) + + print(f"已创建输出目录结构: {self.config.output_dir}") + except Exception as e: + print(f"创建输出目录时发生错误: {str(e)}") + raise + + def _get_directory_size(self, path): + """获取目录的总大小(字节)""" + total_size = 0 + for dirpath, dirnames, filenames in os.walk(path): + for filename in filenames: + file_path = os.path.join(dirpath, filename) + try: + total_size += os.path.getsize(file_path) + except (OSError, FileNotFoundError): + continue + return total_size + + def _check_disk_space(self): + """检查磁盘空间是否足够""" + # 获取输入目录大小 + input_size = self._get_directory_size(self.config.image_dir) + + # 获取输出目录所在磁盘的剩余空间 + output_drive = os.path.splitdrive( + os.path.abspath(self.config.output_dir))[0] + if not output_drive: # 处理Linux/Unix路径 + output_drive = '/home' + + disk_usage = psutil.disk_usage(output_drive) + free_space = disk_usage.free + + # 计算所需空间(输入大小的1.5倍) + required_space = input_size * 12 + + if free_space < required_space: + error_msg = ( + f"磁盘空间不足!\n" + f"输入目录大小: {input_size / (1024**3):.2f} GB\n" + f"所需空间: {required_space / (1024**3):.2f} GB\n" + f"可用空间: {free_space / (1024**3):.2f} GB\n" + f"在驱动器 {output_drive}" + ) + raise RuntimeError(error_msg) + + def _rename_input_dir(self): + image_dir = Path(self.config.image_dir).resolve() + + if not image_dir.exists() or not image_dir.is_dir(): + raise ValueError( + f"Provided path '{image_dir}' is not a valid directory.") + + # 原目录名和父路径 + parent_dir = image_dir.parent + original_name = image_dir.name + + # 新的 images 路径(原目录重命名为 images) + images_path = parent_dir / "images" + + # 重命名原目录为 images + image_dir.rename(images_path) + + # 创建一个新的、和原目录同名的文件夹 + new_root = parent_dir / original_name + new_root.mkdir(exist_ok=False) + + # 创建 project 子文件夹 + project_dir = new_root / "project" + project_dir.mkdir() + + # 把 images 文件夹移动到 project 下 + final_images_path = project_dir / "images" + shutil.move(str(images_path), str(final_images_path)) + + print(f"符合标准输入的文件夹结构已经创建好了,{final_images_path}") + + return final_images_path + + def extract_gps(self) -> pd.DataFrame: + """提取GPS数据""" + self.logger.info("开始提取GPS数据") + extractor = GPSExtractor(self.project_path) + self.gps_points = extractor.extract_all_gps() + self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点") + + def cluster(self): + """使用DBSCAN对GPS点进行聚类,只保留最大的类""" + previous_points = self.gps_points.copy() + clusterer = GPSCluster( + self.gps_points, + eps=self.config.cluster_eps, + min_samples=self.config.cluster_min_samples + ) + + self.clustered_points = clusterer.fit() + self.gps_points = clusterer.get_cluster_stats(self.clustered_points) + + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "1-Clustering") + + def divide_grids(self): + """划分网格 + Returns: + tuple: (grid_points, translations) + - grid_points: 网格点数据字典 + - translations: 网格平移量字典 + """ + grid_divider = GridDivider( + overlap=self.config.grid_overlap, + grid_size=self.config.grid_size, + project_path=self.project_path, + output_dir=self.config.output_dir + ) + grids, self.grid_points = grid_divider.adjust_grid_size_and_overlap( + self.gps_points + ) + grid_divider.visualize_grids(self.gps_points, grids) + grid_divider.save_image_groups(self.grid_points) + if len(grids) >= 20: + self.logger.warning("网格数量已超过20, 需要人工调整分区") + + def odm_docker_runner(self): + """"运行OMD docker容器""" + self.logger.info("开始运行Docker容器") + # TODO:加一些容错处理 + docker_runner = DockerRunner(self.project_path) + docker_runner.run_odm_container() + + def convert_obj(self): + """转换OBJ模型""" + self.logger.info("开始转换OBJ模型") + converter = ConvertOBJ(self.config.output_dir) + converter.convert_grid_obj(self.grid_points) + + def post_process(self): + """后处理:合并或复制处理结果""" + self.logger.info("开始后处理") + + self.logger.info("拷贝正射影像至输出目录") + orthophoto_tif_path = os.path.join( + self.project_path, "odm_orthophoto", "odm_orthophoto.tif") + shutil.copy(orthophoto_tif_path, self.config.output_dir) + # if self.config.mode == "三维模式": + # self.convert_obj() + # else: + # pass + + def process(self): + """执行完整的预处理流程""" + try: + self.extract_gps() + self.cluster() + self.divide_grids() + self.logger.info("==========预处理任务完成==========") + self.odm_docker_runner() + self.post_process() + + except Exception as e: + self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) + raise diff --git a/post_pro/conv_obj.py b/post_pro/conv_obj.py new file mode 100644 index 0000000..41d7ff7 --- /dev/null +++ b/post_pro/conv_obj.py @@ -0,0 +1,253 @@ +import os +import subprocess +import json +import shutil +import logging +from pyproj import Transformer +import cv2 + + +class ConvertOBJ: + def __init__(self, output_dir: str): + self.output_dir = output_dir + # 用于存储所有grid的UTM范围 + self.ref_east = float('inf') + self.ref_north = float('inf') + # 初始化UTM到WGS84的转换器 + self.transformer = Transformer.from_crs( + "EPSG:32649", "EPSG:4326", always_xy=True) + self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ') + + def convert_grid_obj(self, grid_points): + """转换每个网格的OBJ文件为OSGB格式""" + os.makedirs(os.path.join(self.output_dir, + "osgb", "Data"), exist_ok=True) + + # 以第一个grid的UTM坐标作为参照系 + first_grid_id = list(grid_points.keys())[0] + first_grid_dir = os.path.join( + self.output_dir, + f"grid_{first_grid_id[0]}_{first_grid_id[1]}", + "project" + ) + log_file = os.path.join( + first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt") + self.ref_east, self.ref_north = self.read_utm_offset(log_file) + + for grid_id in grid_points.keys(): + try: + self._convert_single_grid(grid_id, grid_points) + except Exception as e: + self.logger.error(f"网格 {grid_id} 转换失败: {str(e)}") + + self._create_merged_metadata() + + def _convert_single_grid(self, grid_id, grid_points): + """转换单个网格的OBJ文件""" + # 构建相关路径 + grid_name = f"grid_{grid_id[0]}_{grid_id[1]}" + project_dir = os.path.join(self.output_dir, grid_name, "project") + texturing_dir = os.path.join(project_dir, "odm_texturing") + texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst") + opensfm_dir = os.path.join(project_dir, "opensfm") + log_file = os.path.join( + project_dir, "odm_orthophoto", "odm_orthophoto_log.txt") + os.makedirs(texturing_dst_dir, exist_ok=True) + + # 修改obj文件z坐标的值 + min_25d_z = self.get_min_z_from_obj(os.path.join( + project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj')) + self.modify_z_in_obj(texturing_dir, min_25d_z) + + # 在新文件夹下,利用UTM偏移量,修改obj文件顶点坐标,纹理文件下采样 + utm_offset = self.read_utm_offset(log_file) + modified_obj = self.modify_obj_coordinates( + texturing_dir, texturing_dst_dir, utm_offset) + self.downsample_texture(texturing_dir, texturing_dst_dir) + + # 执行格式转换,Linux下osgconv有问题,记得注释掉 + self.logger.info(f"开始转换网格 {grid_id} 的OBJ文件") + output_osgb = os.path.join(texturing_dst_dir, "Tile.osgb") + cmd = ( + f"osgconv {modified_obj} {output_osgb} " + f"--compressed --smooth --fix-transparency " + ) + self.logger.info(f"执行osgconv命令:{cmd}") + + try: + subprocess.run(cmd, shell=True, check=True, cwd=texturing_dir) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"OSGB转换失败: {str(e)}") + + # 创建OSGB目录结构,复制文件 + osgb_base_dir = os.path.join(self.output_dir, "osgb") + data_dir = os.path.join(osgb_base_dir, "Data") + tile_dir = os.path.join(data_dir, f"Tile_{grid_id[0]}_{grid_id[1]}") + os.makedirs(tile_dir, exist_ok=True) + target_osgb = os.path.join( + tile_dir, f"Tile_{grid_id[0]}_{grid_id[1]}.osgb") + shutil.copy2(output_osgb, target_osgb) + + def _create_merged_metadata(self): + """创建合并后的metadata.xml文件""" + # 转换为WGS84经纬度 + center_lon, center_lat = self.transformer.transform( + self.ref_east, self.ref_north) + metadata_content = f""" + + EPSG:4326 + {center_lon},{center_lat},0 + + Visible + + """ + + metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml") + with open(metadata_file, 'w', encoding='utf-8') as f: + f.write(metadata_content) + + def read_utm_offset(self, log_file: str) -> tuple: + """读取UTM偏移量""" + try: + east_offset = None + north_offset = None + + with open(log_file, 'r') as f: + lines = f.readlines() + for i, line in enumerate(lines): + if 'utm_north_offset' in line and i + 1 < len(lines): + north_offset = float(lines[i + 1].strip()) + elif 'utm_east_offset' in line and i + 1 < len(lines): + east_offset = float(lines[i + 1].strip()) + + if east_offset is None or north_offset is None: + raise ValueError("未找到UTM偏移量") + + return east_offset, north_offset + except Exception as e: + self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}") + raise + + def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str: + """修改obj文件中的顶点坐标,使用相对坐标系""" + obj_file = os.path.join( + texturing_dir, "odm_textured_model_modified.obj") + obj_dst_file = os.path.join( + texturing_dst_dir, "odm_textured_model_geo_utm.obj") + if not os.path.exists(obj_file): + raise FileNotFoundError(f"找不到OBJ文件: {obj_file}") + shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"), + os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl")) + east_offset, north_offset = utm_offset + self.logger.info( + f"UTM坐标偏移:{east_offset - self.ref_east}, {north_offset - self.ref_north}") + + try: + with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out: + for line in f_in: + if line.startswith('v '): + # 处理顶点坐标行 + parts = line.strip().split() + # 使用相对于整体最小UTM坐标的偏移 + x = float(parts[1]) + (east_offset - self.ref_east) + y = float(parts[2]) + (north_offset - self.ref_north) + z = float(parts[3]) + f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n') + elif line.startswith('vn '): # 处理法线向量 + parts = line.split() + nx = float(parts[1]) + ny = float(parts[2]) + nz = float(parts[3]) + # 同步反转法线的 Y 轴 + new_line = f"vn {nx} {nz} {-ny}\n" + f_out.write(new_line) + else: + # 其他行直接写入 + f_out.write(line) + + return obj_dst_file + except Exception as e: + self.logger.error(f"修改obj坐标时发生错误: {str(e)}") + raise + + def downsample_texture(self, src_dir: str, dst_dir: str): + """复制并重命名纹理文件,对大于100MB的文件进行多次下采样,直到文件小于100MB + Args: + src_dir: 源纹理目录 + dst_dir: 目标纹理目录 + """ + for file in os.listdir(src_dir): + if file.lower().endswith(('.png')): + src_path = os.path.join(src_dir, file) + dst_path = os.path.join(dst_dir, file) + + # 检查文件大小(以字节为单位) + file_size = os.path.getsize(src_path) + if file_size <= 100 * 1024 * 1024: # 如果文件小于等于100MB,直接复制 + shutil.copy2(src_path, dst_path) + else: + # 文件大于100MB,进行下采样 + img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED) + if_first_ds = True + while file_size > 100 * 1024 * 1024: # 大于100MB + self.logger.info(f"纹理文件 {file} 大于100MB,进行下采样") + + if if_first_ds: + # 计算新的尺寸(长宽各变为1/4) + new_size = (img.shape[1] // 4, + img.shape[0] // 4) # 逐步减小尺寸 + # 使用双三次插值进行下采样 + resized_img = cv2.resize( + img, new_size, interpolation=cv2.INTER_CUBIC) + if_first_ds = False + else: + # 计算新的尺寸(长宽各变为1/2) + new_size = (img.shape[1] // 2, + img.shape[0] // 2) # 逐步减小尺寸 + # 使用双三次插值进行下采样 + resized_img = cv2.resize( + img, new_size, interpolation=cv2.INTER_CUBIC) + + # 更新文件路径为下采样后的路径 + cv2.imwrite(dst_path, resized_img, [ + cv2.IMWRITE_PNG_COMPRESSION, 9]) + + # 更新文件大小和图像 + file_size = os.path.getsize(dst_path) + img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED) + self.logger.info( + f"下采样后文件大小: {file_size / (1024 * 1024):.2f} MB") + + def get_min_z_from_obj(self, file_path): + min_z = float('inf') # 初始值设为无穷大 + with open(file_path, 'r') as obj_file: + for line in obj_file: + # 检查每一行是否是顶点定义(以 'v ' 开头) + if line.startswith('v '): + # 获取顶点坐标 + parts = line.split() + # 将z值转换为浮动数字 + z = float(parts[3]) + # 更新最小z值 + if z < min_z: + min_z = z + return min_z + + def modify_z_in_obj(self, texturing_dir, min_25d_z): + obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj') + output_file = os.path.join( + texturing_dir, 'odm_textured_model_modified.obj') + with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out: + for line in f_in: + if line.startswith('v '): # 顶点坐标行 + parts = line.strip().split() + x = float(parts[1]) + y = float(parts[2]) + z = float(parts[3]) + + if z < min_25d_z: + z = min_25d_z + + f_out.write(f"v {x} {y} {z}\n") + else: + f_out.write(line) diff --git a/post_pro/conv_obj2.py b/post_pro/conv_obj2.py new file mode 100644 index 0000000..a155f16 --- /dev/null +++ b/post_pro/conv_obj2.py @@ -0,0 +1,263 @@ +import os +import subprocess +import json +import shutil +import logging +from pyproj import Transformer +import cv2 + + +class ConvertOBJ: + def __init__(self, output_dir: str): + self.output_dir = output_dir + # 用于存储所有grid的UTM范围 + self.ref_east = float('inf') + self.ref_north = float('inf') + # 初始化UTM到WGS84的转换器 + self.transformer = Transformer.from_crs( + "EPSG:32649", "EPSG:4326", always_xy=True) + self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ') + + def convert_grid_obj(self, grid_points): + """转换每个网格的OBJ文件为OSGB格式""" + os.makedirs(os.path.join(self.output_dir, + "osgb", "Data"), exist_ok=True) + + # 以第一个grid的UTM坐标作为参照系 + first_grid_id = list(grid_points.keys())[0] + first_grid_dir = os.path.join( + self.output_dir, + f"grid_{first_grid_id[0]}_{first_grid_id[1]}", + "project" + ) + log_file = os.path.join( + first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt") + self.ref_east, self.ref_north = self.read_utm_offset(log_file) + + for grid_id in grid_points.keys(): + try: + self._convert_single_grid(grid_id, grid_points) + except Exception as e: + self.logger.error(f"网格 {grid_id} 转换失败: {str(e)}") + + self._create_merged_metadata() + + def _convert_single_grid(self, grid_id, grid_points): + """转换单个网格的OBJ文件""" + # 构建相关路径 + grid_name = f"grid_{grid_id[0]}_{grid_id[1]}" + project_dir = os.path.join(self.output_dir, grid_name, "project") + texturing_dir = os.path.join(project_dir, "odm_texturing") + texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst") + split_obj_dir = os.path.join(texturing_dst_dir, "split_obj") + opensfm_dir = os.path.join(project_dir, "opensfm") + log_file = os.path.join( + project_dir, "odm_orthophoto", "odm_orthophoto_log.txt") + os.makedirs(texturing_dst_dir, exist_ok=True) + + # 修改obj文件z坐标的值 + min_25d_z = self.get_min_z_from_obj(os.path.join( + project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj')) + self.modify_z_in_obj(texturing_dir, min_25d_z) + + # 在新文件夹下,利用UTM偏移量,修改obj文件顶点坐标,纹理文件下采样 + utm_offset = self.read_utm_offset(log_file) + modified_obj = self.modify_obj_coordinates( + texturing_dir, texturing_dst_dir, utm_offset) + self.downsample_texture(texturing_dir, texturing_dst_dir) + + # 将obj文件进行切片 + self.logger.info(f"开始切片网格 {grid_id} 的OBJ文件") + os.makedirs(split_obj_dir) + cmd = ( + f"D:\software\Obj2Tiles\Obj2Tiles.exe --stage Splitting --lods 1 --divisions 3 " + f"{modified_obj} {split_obj_dir}" + ) + subprocess.run(cmd, check=True) + + # 执行格式转换,Linux下osgconv有问题,记得注释掉 + self.logger.info(f"开始转换网格 {grid_id} 的OBJ文件") + # 先获取split_obj_dir下的所有obj文件 + obj_lod_dir = os.path.join(split_obj_dir, "LOD-0") + obj_files = [f for f in os.listdir( + obj_lod_dir) if f.endswith('.obj')] + for obj_file in obj_files: + obj_path = os.path.join(obj_lod_dir, obj_file) + osgb_file = os.path.splitext(obj_file)[0] + '.osgb' + osgb_path = os.path.join(split_obj_dir, osgb_file) + # 执行 osgconv 命令 + subprocess.run(['osgconv', obj_path, osgb_path], check=True) + + # 创建OSGB目录结构,复制文件 + osgb_base_dir = os.path.join(self.output_dir, "osgb") + data_dir = os.path.join(osgb_base_dir, "Data") + for obj_file in obj_files: + obj_file_name = os.path.splitext(obj_file)[0] + tile_dirs = os.path.join(data_dir, f"{obj_file_name}") + os.makedirs(tile_dirs, exist_ok=True) + shutil.copy2(os.path.join( + split_obj_dir, obj_file_name+".osgb"), tile_dirs) + + def _create_merged_metadata(self): + """创建合并后的metadata.xml文件""" + # 转换为WGS84经纬度 + center_lon, center_lat = self.transformer.transform( + self.ref_east, self.ref_north) + metadata_content = f""" + + EPSG:4326 + {center_lon},{center_lat},0 + + Visible + + """ + + metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml") + with open(metadata_file, 'w', encoding='utf-8') as f: + f.write(metadata_content) + + def read_utm_offset(self, log_file: str) -> tuple: + """读取UTM偏移量""" + try: + east_offset = None + north_offset = None + + with open(log_file, 'r') as f: + lines = f.readlines() + for i, line in enumerate(lines): + if 'utm_north_offset' in line and i + 1 < len(lines): + north_offset = float(lines[i + 1].strip()) + elif 'utm_east_offset' in line and i + 1 < len(lines): + east_offset = float(lines[i + 1].strip()) + + if east_offset is None or north_offset is None: + raise ValueError("未找到UTM偏移量") + + return east_offset, north_offset + except Exception as e: + self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}") + raise + + def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str: + """修改obj文件中的顶点坐标,使用相对坐标系""" + obj_file = os.path.join( + texturing_dir, "odm_textured_model_modified.obj") + obj_dst_file = os.path.join( + texturing_dst_dir, "odm_textured_model_geo_utm.obj") + if not os.path.exists(obj_file): + raise FileNotFoundError(f"找不到OBJ文件: {obj_file}") + shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"), + os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl")) + east_offset, north_offset = utm_offset + self.logger.info( + f"UTM坐标偏移:{east_offset - self.ref_east}, {north_offset - self.ref_north}") + + try: + with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out: + for line in f_in: + if line.startswith('v '): + # 处理顶点坐标行 + parts = line.strip().split() + # 使用相对于整体最小UTM坐标的偏移 + x = float(parts[1]) + (east_offset - self.ref_east) + y = float(parts[2]) + (north_offset - self.ref_north) + z = float(parts[3]) + f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n') + elif line.startswith('vn '): # 处理法线向量 + parts = line.split() + nx = float(parts[1]) + ny = float(parts[2]) + nz = float(parts[3]) + # 同步反转法线的 Y 轴 + new_line = f"vn {nx} {nz} {-ny}\n" + f_out.write(new_line) + else: + # 其他行直接写入 + f_out.write(line) + + return obj_dst_file + except Exception as e: + self.logger.error(f"修改obj坐标时发生错误: {str(e)}") + raise + + def downsample_texture(self, src_dir: str, dst_dir: str): + """复制并重命名纹理文件,对大于100MB的文件进行多次下采样,直到文件小于100MB + Args: + src_dir: 源纹理目录 + dst_dir: 目标纹理目录 + """ + for file in os.listdir(src_dir): + if file.lower().endswith(('.png')): + src_path = os.path.join(src_dir, file) + dst_path = os.path.join(dst_dir, file) + + # 检查文件大小(以字节为单位) + file_size = os.path.getsize(src_path) + if file_size <= 100 * 1024 * 1024: # 如果文件小于等于100MB,直接复制 + shutil.copy2(src_path, dst_path) + else: + # 文件大于100MB,进行下采样 + img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED) + if_first_ds = True + while file_size > 100 * 1024 * 1024: # 大于100MB + self.logger.info(f"纹理文件 {file} 大于100MB,进行下采样") + + if if_first_ds: + # 计算新的尺寸(长宽各变为1/4) + new_size = (img.shape[1] // 4, + img.shape[0] // 4) # 逐步减小尺寸 + # 使用双三次插值进行下采样 + resized_img = cv2.resize( + img, new_size, interpolation=cv2.INTER_CUBIC) + if_first_ds = False + else: + # 计算新的尺寸(长宽各变为1/2) + new_size = (img.shape[1] // 2, + img.shape[0] // 2) # 逐步减小尺寸 + # 使用双三次插值进行下采样 + resized_img = cv2.resize( + img, new_size, interpolation=cv2.INTER_CUBIC) + + # 更新文件路径为下采样后的路径 + cv2.imwrite(dst_path, resized_img, [ + cv2.IMWRITE_PNG_COMPRESSION, 9]) + + # 更新文件大小和图像 + file_size = os.path.getsize(dst_path) + img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED) + self.logger.info( + f"下采样后文件大小: {file_size / (1024 * 1024):.2f} MB") + + def get_min_z_from_obj(self, file_path): + min_z = float('inf') # 初始值设为无穷大 + with open(file_path, 'r') as obj_file: + for line in obj_file: + # 检查每一行是否是顶点定义(以 'v ' 开头) + if line.startswith('v '): + # 获取顶点坐标 + parts = line.split() + # 将z值转换为浮动数字 + z = float(parts[3]) + # 更新最小z值 + if z < min_z: + min_z = z + return min_z + + def modify_z_in_obj(self, texturing_dir, min_25d_z): + obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj') + output_file = os.path.join( + texturing_dir, 'odm_textured_model_modified.obj') + with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out: + for line in f_in: + if line.startswith('v '): # 顶点坐标行 + parts = line.strip().split() + x = float(parts[1]) + y = float(parts[2]) + z = float(parts[3]) + + if z < min_25d_z: + z = min_25d_z + + f_out.write(f"v {x} {y} {z}\n") + else: + f_out.write(line) diff --git a/run.py b/run.py new file mode 100644 index 0000000..d583018 --- /dev/null +++ b/run.py @@ -0,0 +1,59 @@ +import argparse +from datetime import timedelta +from main import ProcessConfig, ODM_Plugin + + +def parse_args(): + parser = argparse.ArgumentParser(description='ODM预处理工具') + + # 必需参数 + # parser.add_argument('--image_dir', required=True, help='输入图片目录路径') + # parser.add_argument('--output_dir', required=True, help='输出目录路径') + parser.add_argument( + '--image_dir', default=r'E:\datasets\UAV\199', help='输入图片目录路径') + parser.add_argument( + '--output_dir', default=r'G:\ODM_output\test2', help='输出目录路径') + # 可选参数 + parser.add_argument('--mode', default='三维模式', + choices=['快拼模式', '三维模式'], help='处理模式') + parser.add_argument('--accuracy', default='medium', + choices=['high', 'medium', 'low'], help='精度') + parser.add_argument('--grid_size', type=float, default=800, help='网格大小(米)') + parser.add_argument('--grid_overlap', type=float, + default=0.05, help='网格重叠率') + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # 创建配置 + config = ProcessConfig( + image_dir=args.image_dir, + output_dir=args.output_dir, + mode=args.mode, + accuracy=args.accuracy, + grid_size=args.grid_size, + grid_overlap=args.grid_overlap, + + # 其他参数使用默认值 + cluster_eps=0.01, + cluster_min_samples=5, + time_group_overlap_threshold=0.7, + time_group_interval=timedelta(minutes=5), + filter_distance_threshold=0.001, + filter_min_neighbors=6, + filter_grid_size=0.001, + filter_dense_distance_threshold=10, + filter_time_threshold=timedelta(minutes=5), + ) + + # 创建处理器并执行 + processor = ODM_Plugin(config) + processor.process() + + +if __name__ == '__main__': + main() diff --git a/utils/docker_runner.py b/utils/docker_runner.py new file mode 100644 index 0000000..de74a73 --- /dev/null +++ b/utils/docker_runner.py @@ -0,0 +1,89 @@ +import docker +import os +import logging +from collections import deque + + +class DockerRunner: + def __init__(self, project_path: str): + """ + 初始化 DockerRunner + + Args: + project_path (str): 项目路径,将挂载到 Docker 容器中 + """ + self.project_path = project_path + self.logger = logging.getLogger("DockerRunner") + self.docker_client = docker.from_env() + + def run_odm_container(self): + """ + 使用 Docker SDK 运行 OpenDroneMap 容器 + """ + try: + self.logger.info("开始运行docker run指令") + # 挂载路径 + volume_mapping = { + self.project_path: { + 'bind': '/datasets', + 'mode': 'rw' + } + } + + # Docker 命令参数 + command = [ + "--project-path", "/datasets", + "project", + "--max-concurrency", "15", + "--force-gps", + "--split-overlap", "0", + ] + + # 运行容器 + container = self.docker_client.containers.run( + image="opendronemap/odm:gpu", + command=command, + volumes=volume_mapping, + device_requests=[ + docker.types.DeviceRequest( + count=-1, capabilities=[["gpu"]]) + ], # 添加 GPU 支持 + remove=False, # 容器运行结束后不自动删除,便于获取日志 + tty=True, + detach=True # 后台运行 + ) + + # 等待容器运行完成 + exit_status = container.wait() + if exit_status["StatusCode"] != 0: + self.logger.error(f"容器运行失败,退出状态码: {exit_status['StatusCode']}") + + # 获取容器的错误日志 + error_logs = container.logs( + stderr=True).decode("utf-8").splitlines() + self.logger.error("容器运行失败的详细错误日志:") + for line in error_logs: + self.logger.error(line) + + else: + # 获取所有日志 + logs = container.logs().decode("utf-8").splitlines() + + # 输出最后 50 行日志 + self.logger.info("容器运行完成,以下是最后 50 行日志:") + for line in logs[-50:]: + self.logger.info(line) + + # 删除容器 + container.remove() + + except Exception as e: + self.logger.error(f"运行 Docker 容器时发生错误: {str(e)}", exc_info=True) + raise + + +if __name__ == "__main__": + # 示例用法 + project_path = r"E:\datasets\UAV\199" + docker_runner = DockerRunner(project_path) + docker_runner.run_odm_container() diff --git a/utils/gps_extractor.py b/utils/gps_extractor.py new file mode 100644 index 0000000..7d1c5c2 --- /dev/null +++ b/utils/gps_extractor.py @@ -0,0 +1,96 @@ +import os +from PIL import Image +import piexif +import logging +import pandas as pd +from datetime import datetime + + +class GPSExtractor: + """从图像文件提取GPS坐标和拍摄日期""" + + def __init__(self, project_path): + self.image_dir = os.path.join(project_path, 'project', 'images') + self.logger = logging.getLogger('UAV_Preprocess.GPSExtractor') + + @staticmethod + def _dms_to_decimal(dms): + """将DMS格式转换为十进制度""" + return dms[0][0] / dms[0][1] + (dms[1][0] / dms[1][1]) / 60 + (dms[2][0] / dms[2][1]) / 3600 + + @staticmethod + def _parse_datetime(datetime_str): + """解析EXIF中的日期时间字符串""" + try: + # EXIF日期格式通常为 'YYYY:MM:DD HH:MM:SS' + return datetime.strptime(datetime_str.decode(), '%Y:%m:%d %H:%M:%S') + except Exception: + return None + + def get_gps_and_date(self, image_path): + """提取单张图片的GPS坐标和拍摄日期""" + try: + image = Image.open(image_path) + exif_data = piexif.load(image.info['exif']) + + # 提取GPS信息 + gps_info = exif_data.get("GPS", {}) + lat = lon = None + if gps_info: + lat = self._dms_to_decimal(gps_info.get(2, [])) + lon = self._dms_to_decimal(gps_info.get(4, [])) + self.logger.debug( + f"成功提取图片GPS坐标: {image_path} - 纬度: {lat}, 经度: {lon}") + + # 提取拍摄日期 + date_info = None + if "Exif" in exif_data: + # 优先使用DateTimeOriginal + date_str = exif_data["Exif"].get(36867) # DateTimeOriginal + if not date_str: + # 备选DateTime + date_str = exif_data["Exif"].get( + 36868) # DateTimeDigitized + if not date_str: + # 最后使用基本DateTime + date_str = exif_data["0th"].get(306) # DateTime + + if date_str: + date_info = self._parse_datetime(date_str) + self.logger.debug( + f"成功提取图片拍摄日期: {image_path} - {date_info}") + + if not gps_info: + self.logger.warning(f"图片无GPS信息: {image_path}") + if not date_info: + self.logger.warning(f"图片无拍摄日期信息: {image_path}") + + return lat, lon, date_info + + except Exception as e: + self.logger.error(f"提取图片信息时发生错误: {image_path} - {str(e)}") + return None, None, None + + def extract_all_gps(self): + """提取所有图片的GPS坐标和拍摄日期""" + self.logger.info(f"开始从目录提取GPS坐标和拍摄日期: {self.image_dir}") + gps_data = [] + total_images = 0 + successful_extractions = 0 + + for image_file in os.listdir(self.image_dir): + total_images += 1 + image_path = os.path.join(self.image_dir, image_file) + lat, lon, date = self.get_gps_and_date(image_path) + if lat and lon: # 仍然以GPS信息作为主要判断依据 + successful_extractions += 1 + gps_data.append({ + 'file': image_file, + 'lat': lat, + 'lon': lon, + 'date': date + }) + + self.logger.info( + f"GPS坐标和拍摄日期提取完成 - 总图片数: {total_images}, 成功提取: {successful_extractions}, 失败: {total_images - successful_extractions}") + return pd.DataFrame(gps_data) diff --git a/utils/grid_divider.py b/utils/grid_divider.py new file mode 100644 index 0000000..04c3fd8 --- /dev/null +++ b/utils/grid_divider.py @@ -0,0 +1,249 @@ +import logging +from geopy.distance import geodesic +import matplotlib.pyplot as plt +import os + + +class GridDivider: + """划分网格,并将图片分配到对应网格""" + + def __init__(self, overlap, grid_size, project_path, output_dir): + self.overlap = overlap + self.grid_size = grid_size + self.project_path = project_path + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.GridDivider') + self.logger.info(f"初始化网格划分器,重叠率: {overlap}") + self.num_grids_width = 0 # 添加网格数量属性 + self.num_grids_height = 0 + + def adjust_grid_size_and_overlap(self, points_df): + """动态调整网格重叠率""" + grids = self.adjust_grid_size(points_df) + self.logger.info(f"开始动态调整网格重叠率,初始重叠率: {self.overlap}") + while True: + # 使用调整好的网格大小划分网格 + grids = self.divide_grids(points_df) + grid_points, multiple_grid_points = self.assign_to_grids( + points_df, grids) + + if len(grids) == 1: + self.logger.info(f"网格数量为1,跳过重叠率调整") + break + elif multiple_grid_points < 0.1*len(points_df): + self.overlap += 0.02 + self.logger.info(f"重叠率增加到: {self.overlap}") + else: + self.logger.info( + f"找到合适的重叠率: {self.overlap}, 有{multiple_grid_points}个点被分配到多个网格") + break + return grids, grid_points + + def adjust_grid_size(self, points_df): + """动态调整网格大小 + + Args: + points_df: 包含GPS点的DataFrame + + Returns: + tuple: (grids, translations, grid_points, final_grid_size) + """ + self.logger.info(f"开始动态调整网格大小,初始大小: {self.grid_size}米") + + while True: + # 使用当前grid_size划分网格 + grids = self.divide_grids(points_df) + grid_points, multiple_grid_points = self.assign_to_grids( + points_df, grids) + + # 检查每个网格中的点数 + max_points = 0 + for grid_id, points in grid_points.items(): + max_points = max(max_points, len(points)) + + self.logger.info( + f"当前网格大小: {self.grid_size}米, 单个网格最大点数: {max_points}") + + # 如果最大点数超过1600,减小网格大小 + if max_points > 1600: + self.grid_size -= 100 + self.logger.info(f"点数超过1500,减小网格大小至: {self.grid_size}米") + if self.grid_size < 500: # 设置一个最小网格大小限制 + self.logger.warning("网格大小已达到最小值500米,停止调整") + break + else: + self.logger.info(f"找到合适的网格大小: {self.grid_size}米") + break + return grids + + def divide_grids(self, points_df): + """计算边界框并划分网格 + Returns: + tuple: (grids, translations) + - grids: 网格边界列表 + - translations: 网格平移量字典 + """ + self.logger.info("开始划分网格") + + min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max() + min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max() + + # 计算区域的实际距离(米) + width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters + height = geodesic((min_lat, min_lon), (max_lat, min_lon)).meters + + self.logger.info(f"区域宽度: {width:.2f}米, 高度: {height:.2f}米") + + # 精细调整网格的长宽,避免出现2*grid_size-1的情况的影响 + grid_size_lt = [self.grid_size - 200, self.grid_size - 100, + self.grid_size, self.grid_size + 100, self.grid_size + 200] + + width_modulus_lt = [width % grid_size for grid_size in grid_size_lt] + grid_width = grid_size_lt[width_modulus_lt.index( + min(width_modulus_lt))] + height_modulus_lt = [height % grid_size for grid_size in grid_size_lt] + grid_height = grid_size_lt[height_modulus_lt.index( + min(height_modulus_lt))] + self.logger.info(f"网格宽度: {grid_width:.2f}米, 网格高度: {grid_height:.2f}米") + + # 计算需要划分的网格数量 + self.num_grids_width = max(int(width / grid_width), 1) + self.num_grids_height = max(int(height / grid_height), 1) + + # 计算每个网格对应的经纬度步长 + lat_step = (max_lat - min_lat) / self.num_grids_height + lon_step = (max_lon - min_lon) / self.num_grids_width + + grids = [] + + # 先创建所有网格 + for i in range(self.num_grids_height): + for j in range(self.num_grids_width): + grid_min_lat = min_lat + i * lat_step - self.overlap * lat_step + grid_max_lat = min_lat + \ + (i + 1) * lat_step + self.overlap * lat_step + grid_min_lon = min_lon + j * lon_step - self.overlap * lon_step + grid_max_lon = min_lon + \ + (j + 1) * lon_step + self.overlap * lon_step + + grid_bounds = (grid_min_lat, grid_max_lat, + grid_min_lon, grid_max_lon) + grids.append(grid_bounds) + + self.logger.debug( + f"网格[{i},{j}]: 纬度[{grid_min_lat:.6f}, {grid_max_lat:.6f}], " + f"经度[{grid_min_lon:.6f}, {grid_max_lon:.6f}]" + ) + + self.logger.info( + f"成功划分为 {len(grids)} 个网格 ({self.num_grids_width}x{self.num_grids_height})") + + return grids + + def assign_to_grids(self, points_df, grids): + """将点分配到对应网格""" + self.logger.info(f"开始将 {len(points_df)} 个点分配到网格中") + + grid_points = {} # 使用字典存储每个网格的点 + points_assigned = 0 + multiple_grid_points = 0 + + for i in range(self.num_grids_height): + for j in range(self.num_grids_width): + grid_points[(i, j)] = [] # 使用(i,j)元组 + + for _, point in points_df.iterrows(): + point_assigned = False + for i in range(self.num_grids_height): + for j in range(self.num_grids_width): + grid_idx = i * self.num_grids_width + j + min_lat, max_lat, min_lon, max_lon = grids[grid_idx] + + if min_lat <= point['lat'] <= max_lat and min_lon <= point['lon'] <= max_lon: + grid_points[(i, j)].append(point.to_dict()) + if point_assigned: + multiple_grid_points += 1 + else: + points_assigned += 1 + point_assigned = True + + # 记录每个网格的点数 + for grid_id, points in grid_points.items(): + self.logger.info(f"网格 {grid_id} 包含 {len(points)} 个点") + + self.logger.info( + f"点分配完成: 总点数 {len(points_df)}, " + f"成功分配 {points_assigned} 个点, " + f"{multiple_grid_points} 个点被分配到多个网格" + ) + + return grid_points, multiple_grid_points + + def visualize_grids(self, points_df, grids): + """可视化网格划分和GPS点的分布""" + self.logger.info("开始可视化网格划分") + + plt.figure(figsize=(12, 8)) + + # 绘制GPS点 + plt.scatter(points_df['lon'], points_df['lat'], + c='blue', s=10, alpha=0.6, label='GPS points') + + # 绘制网格 + for i in range(self.num_grids_height): + for j in range(self.num_grids_width): + grid_idx = i * self.num_grids_width + j + min_lat, max_lat, min_lon, max_lon = grids[grid_idx] + + # 计算网格的实际长度和宽度(米) + width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters + height = geodesic((min_lat, min_lon), + (max_lat, min_lon)).meters + + plt.plot([min_lon, max_lon, max_lon, min_lon, min_lon], + [min_lat, min_lat, max_lat, max_lat, min_lat], + 'r-', alpha=0.5) + # 在网格中心添加网格编号和尺寸信息 + center_lon = (min_lon + max_lon) / 2 + center_lat = (min_lat + max_lat) / 2 + plt.text(center_lon, center_lat, + f"({i},{j})\n{width:.0f}m×{height:.0f}m", # 显示(i,j)和尺寸 + horizontalalignment='center', + verticalalignment='center', + fontsize=8) + + plt.title('Grid Division and GPS Point Distribution') + plt.xlabel('Longitude') + plt.ylabel('Latitude') + plt.legend() + plt.grid(True) + + # 如果提供了输出目录,保存图像 + if self.output_dir: + save_path = os.path.join( + self.output_dir, 'filter_imgs', 'grid_division.png') + plt.savefig(save_path, dpi=300, bbox_inches='tight') + self.logger.info(f"网格划分可视化图已保存至: {save_path}") + + plt.close() + + def save_image_groups(self, grid_points, output_file_name="image_groups.txt"): + """保存图像分组信息到文件 + + Args: + grid_points (dict): 每个网格的点信息,键为(i, j),值为点的列表 + output_file (str): 输出文件路径 + """ + self.logger.info(f"开始保存图像分组信息到 {output_file_name}") + + output_file = os.path.join( + self.project_path, 'project', output_file_name) + with open(output_file, 'w') as f: + for (i, j), points in grid_points.items(): + # 计算组编号(按行展开的顺序) + group_id = i * self.num_grids_width + j + 1 + for point in points: + image_name = point.get('file', 'unknown') + f.write(f"{image_name} {group_id}\n") + + self.logger.info(f"图像分组信息已保存到 {output_file}") diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..80c1f10 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,36 @@ +import logging +import os +from datetime import datetime + + +def setup_logger(output_dir): + # 创建logs目录 + log_dir = os.path.join(output_dir, 'logs') + + # 创建日志文件名(包含时间戳) + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + log_file = os.path.join(log_dir, f'preprocess_{timestamp}.log') + + # 配置日志格式 + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # 配置文件处理器 + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler.setFormatter(formatter) + + # 配置控制台处理器 + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + + # 获取根日志记录器 + logger = logging.getLogger('UAV_Preprocess') + logger.setLevel(logging.INFO) + + # 添加处理器 + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger diff --git a/utils/visualizer.py b/utils/visualizer.py new file mode 100644 index 0000000..964bd53 --- /dev/null +++ b/utils/visualizer.py @@ -0,0 +1,152 @@ +import os +import matplotlib.pyplot as plt +import pandas as pd +import logging +from typing import Optional +from pyproj import Transformer + + +class FilterVisualizer: + """过滤结果可视化器""" + + def __init__(self, output_dir: str): + """ + 初始化可视化器 + + Args: + output_dir: 输出目录路径 + """ + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.Visualizer') + # 创建坐标转换器 + self.transformer = Transformer.from_crs( + "EPSG:4326", # WGS84经纬度坐标系 + "EPSG:32649", # UTM49N + always_xy=True + ) + + def _convert_to_utm(self, lon: pd.Series, lat: pd.Series) -> tuple: + """ + 将经纬度坐标转换为UTM坐标 + + Args: + lon: 经度序列 + lat: 纬度序列 + + Returns: + tuple: (x坐标, y坐标) + """ + return self.transformer.transform(lon, lat) + + def visualize_filter_step(self, + current_points: pd.DataFrame, + previous_points: pd.DataFrame, + step_name: str, + save_name: Optional[str] = None): + """ + 可视化单个过滤步骤的结果 + + Args: + current_points: 当前步骤后的点 + previous_points: 上一步骤的点 + step_name: 步骤名称 + save_name: 保存文件名,默认为step_name + """ + self.logger.info(f"开始生成{step_name}的可视化结果") + + # 找出被过滤掉的点 + filtered_files = set( + previous_points['file']) - set(current_points['file']) + filtered_points = previous_points[previous_points['file'].isin( + filtered_files)] + + # 转换坐标到UTM + current_x, current_y = self._convert_to_utm( + current_points['lon'], current_points['lat']) + filtered_x, filtered_y = self._convert_to_utm( + filtered_points['lon'], filtered_points['lat']) + + # 创建图形 + plt.rcParams['font.sans-serif'] = ['SimHei'] # 黑体 + plt.rcParams['axes.unicode_minus'] = False + plt.figure(figsize=(20, 16)) + + # 绘制保留的点 + plt.scatter(current_x, current_y, + color='blue', label='保留的点', + alpha=0.6, s=50) + + # 绘制被过滤的点 + if not filtered_points.empty: + plt.scatter(filtered_x, filtered_y, + color='red', marker='x', label='过滤的点', + alpha=0.6, s=100) + + # 设置图形属性 + plt.title(f"{step_name}后的GPS点\n" + f"(过滤: {len(filtered_points)}, 保留: {len(current_points)})", + fontsize=14) + plt.xlabel("东向坐标 (米)", fontsize=12) + plt.ylabel("北向坐标 (米)", fontsize=12) + plt.grid(True) + + # 添加统计信息 + stats_text = ( + f"原始点数: {len(previous_points)}\n" + f"过滤点数: {len(filtered_points)}\n" + f"保留点数: {len(current_points)}\n" + f"过滤率: {len(filtered_points)/len(previous_points)*100:.1f}%" + ) + plt.figtext(0.02, 0.02, stats_text, fontsize=10, + bbox=dict(facecolor='white', alpha=0.8)) + + # 添加图例 + plt.legend(loc='upper right', fontsize=10) + + # 调整布局 + plt.tight_layout() + + # 保存图形 + save_name = save_name or step_name.lower().replace(' ', '_') + save_path = os.path.join( + self.output_dir, 'filter_imgs', f'filter_{save_name}.png') + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + + self.logger.info( + f"{step_name}过滤可视化结果已保存至 {save_path}\n" + f"过滤掉 {len(filtered_points)} 个点," + f"保留 {len(current_points)} 个点," + f"过滤率 {len(filtered_points)/len(previous_points)*100:.1f}%" + ) + + +if __name__ == '__main__': + # 测试代码 + import numpy as np + from datetime import datetime + + # 创建测试数据 + np.random.seed(42) + n_points = 1000 + + # 生成随机点 + test_data = pd.DataFrame({ + 'lon': np.random.uniform(120, 121, n_points), + 'lat': np.random.uniform(30, 31, n_points), + 'file': [f'img_{i}.jpg' for i in range(n_points)], + 'date': [datetime.now() for _ in range(n_points)] + }) + + # 随机选择点作为过滤后的结果 + filtered_data = test_data.sample(n=800) + + # 测试可视化 + visualizer = FilterVisualizer('test_output') + os.makedirs('test_output', exist_ok=True) + + visualizer.visualize_filter_step( + filtered_data, + test_data, + "Test Filter" + )