diff --git a/filter/cluster_filter.py b/filter/cluster_filter.py
new file mode 100644
index 0000000..473da7d
--- /dev/null
+++ b/filter/cluster_filter.py
@@ -0,0 +1,77 @@
+from sklearn.cluster import DBSCAN
+from sklearn.preprocessing import StandardScaler
+import logging
+
+
+class GPSCluster:
+    def __init__(self, gps_points, eps=0.01, min_samples=3):
+        """
+        Initialize the GPS clusterer.
+
+        Args:
+            gps_points: DataFrame containing 'lat' and 'lon' columns
+            eps: DBSCAN neighborhood radius
+            min_samples: DBSCAN minimum-samples parameter
+        """
+ self.eps = eps
+ self.min_samples = min_samples
+ self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
+ self.scaler = StandardScaler()
+ self.gps_points = gps_points
+ self.logger = logging.getLogger('UAV_Preprocess.GPSCluster')
+
+    def fit(self):
+        """
+        Cluster the GPS points and keep only the largest cluster.
+
+        Returns:
+            DataFrame with an added 'cluster' column, where points in the
+            largest cluster are labeled 1 and all other points -1.
+        """
+        self.logger.info("Starting clustering")
+        # Extract the longitude/latitude columns
+        X = self.gps_points[["lon", "lat"]].values
+
+        # Standardization is intentionally skipped; eps is expressed in raw degrees
+        # X_scaled = self.scaler.fit_transform(X)
+
+        # Run DBSCAN
+        labels = self.dbscan.fit_predict(X)
+
+        # Find the label of the largest cluster (excluding noise, -1)
+        unique_labels = [l for l in set(labels) if l != -1]
+        if unique_labels:  # at least one cluster was found
+            label_counts = [(l, sum(labels == l)) for l in unique_labels]
+            largest_label = max(label_counts, key=lambda x: x[1])[0]
+
+            # Label the largest cluster 1 and everything else -1
+            new_labels = (labels == largest_label).astype(int)
+            new_labels[new_labels == 0] = -1
+        else:  # no clusters at all: every point is noise (-1)
+            new_labels = labels
+
+        # Attach the cluster labels to the original data
+        result_df = self.gps_points.copy()
+        result_df["cluster"] = new_labels
+
+ return result_df
+
+    def get_cluster_stats(self, clustered_points):
+        """
+        Log clustering statistics and return the main cluster.
+
+        Args:
+            clustered_points: DataFrame with a 'cluster' label column
+
+        Returns:
+            DataFrame containing only the points of the main cluster
+        """
+        main_cluster = clustered_points[clustered_points["cluster"] == 1]
+        noise_cluster = clustered_points[clustered_points["cluster"] == -1]
+
+        self.logger.info(f"Clustering finished: main cluster has {len(main_cluster)} points, "
+                         f"{len(noise_cluster)} noise points")
+
+ return main_cluster
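+
+
+if __name__ == '__main__':
+    # Minimal usage sketch on synthetic data (values are made up): 50 points
+    # around one location plus two far-away outliers that DBSCAN should drop.
+    import numpy as np
+    import pandas as pd
+
+    logging.basicConfig(level=logging.INFO)
+    rng = np.random.default_rng(0)
+    points = pd.DataFrame({
+        'lon': np.concatenate([120 + rng.normal(0, 0.001, 50), [121, 122]]),
+        'lat': np.concatenate([30 + rng.normal(0, 0.001, 50), [31, 32]]),
+    })
+    clusterer = GPSCluster(points, eps=0.01, min_samples=3)
+    main_points = clusterer.get_cluster_stats(clusterer.fit())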
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..aee07fe
--- /dev/null
+++ b/main.py
@@ -0,0 +1,249 @@
+import os
+import shutil
+from datetime import timedelta
+from dataclasses import dataclass
+from typing import Dict, Tuple
+import psutil
+import pandas as pd
+from pathlib import Path
+
+from filter.cluster_filter import GPSCluster
+from utils.gps_extractor import GPSExtractor
+from utils.grid_divider import GridDivider
+from utils.logger import setup_logger
+from utils.visualizer import FilterVisualizer
+from utils.docker_runner import DockerRunner
+from post_pro.conv_obj import ConvertOBJ
+
+
+@dataclass
+class ProcessConfig:
+    """Pipeline configuration"""
+
+    image_dir: str
+    output_dir: str
+    # Cluster filter parameters
+    cluster_eps: float = 0.01
+    cluster_min_samples: int = 5
+    # Time-group overlap filter parameters
+    time_group_overlap_threshold: float = 0.7
+    time_group_interval: timedelta = timedelta(minutes=5)
+    # Isolated-point filter parameters
+    filter_distance_threshold: float = 0.001  # distance in degrees
+    filter_min_neighbors: int = 6
+    # Dense-point filter parameters
+    filter_grid_size: float = 0.001
+    filter_dense_distance_threshold: float = 10  # metric distance, in meters
+    filter_time_threshold: timedelta = timedelta(minutes=5)
+    # Grid division parameters
+    grid_overlap: float = 0.05
+    grid_size: float = 500
+    # Pipeline stage toggles
+    mode: str = "quick-mosaic"
+    accuracy: str = "medium"
+    produce_dem: bool = False
+
+
+class ODM_Plugin:
+ def __init__(self, config: ProcessConfig):
+ self.config = config
+
+        # Check disk space
+        # TODO: the input directory's disk should be checked as well
+        self._check_disk_space()
+
+        # Clean and recreate the output directory
+        if os.path.exists(config.output_dir):
+            self._clean_output_dir()
+        self._setup_output_dirs()
+
+        # Restructure the input directory to match ODM's expectations; from
+        # here on, image_dir serves as the project_path
+        self._rename_input_dir()
+        self.project_path = self.config.image_dir
+
+        # Initialize the remaining components
+        self.logger = setup_logger(config.output_dir)
+        self.gps_points = None
+        self.grid_points = None
+        self.visualizer = FilterVisualizer(config.output_dir)
+
+    def _clean_output_dir(self):
+        """Remove the existing output directory"""
+        try:
+            shutil.rmtree(self.config.output_dir)
+            print(f"Cleaned output directory: {self.config.output_dir}")
+        except Exception as e:
+            print(f"Error while cleaning output directory: {str(e)}")
+            raise
+
+    def _setup_output_dirs(self):
+        """Create the required output directory structure"""
+        try:
+            # Main output directory
+            os.makedirs(self.config.output_dir)
+
+            # Directory for the filter visualization images
+            os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs'))
+
+            # Log directory
+            os.makedirs(os.path.join(self.config.output_dir, 'logs'))
+
+            print(f"Created output directory structure: {self.config.output_dir}")
+        except Exception as e:
+            print(f"Error while creating output directories: {str(e)}")
+            raise
+
+    def _get_directory_size(self, path):
+        """Return the total size of a directory tree in bytes"""
+ total_size = 0
+ for dirpath, dirnames, filenames in os.walk(path):
+ for filename in filenames:
+ file_path = os.path.join(dirpath, filename)
+ try:
+ total_size += os.path.getsize(file_path)
+ except (OSError, FileNotFoundError):
+ continue
+ return total_size
+
+    def _check_disk_space(self):
+        """Check whether the output disk has enough free space"""
+        # Size of the input directory
+        input_size = self._get_directory_size(self.config.image_dir)
+
+        # Free space on the disk holding the output directory
+        output_drive = os.path.splitdrive(
+            os.path.abspath(self.config.output_dir))[0]
+        if not output_drive:  # Linux/Unix paths have no drive letter
+            output_drive = '/home'
+
+        disk_usage = psutil.disk_usage(output_drive)
+        free_space = disk_usage.free
+
+        # Estimate the required space as 12x the input size
+        required_space = input_size * 12
+
+ if free_space < required_space:
+            error_msg = (
+                f"Not enough disk space!\n"
+                f"Input directory size: {input_size / (1024**3):.2f} GB\n"
+                f"Required space: {required_space / (1024**3):.2f} GB\n"
+                f"Available space: {free_space / (1024**3):.2f} GB\n"
+                f"on drive {output_drive}"
+            )
+ raise RuntimeError(error_msg)
+
+ def _rename_input_dir(self):
+ image_dir = Path(self.config.image_dir).resolve()
+
+ if not image_dir.exists() or not image_dir.is_dir():
+ raise ValueError(
+ f"Provided path '{image_dir}' is not a valid directory.")
+
+        # Original directory name and its parent
+        parent_dir = image_dir.parent
+        original_name = image_dir.name
+
+        # Path the original directory will be renamed to
+        images_path = parent_dir / "images"
+
+        # Rename the original directory to "images"
+        image_dir.rename(images_path)
+
+        # Create a new folder with the original directory's name
+        new_root = parent_dir / original_name
+        new_root.mkdir(exist_ok=False)
+
+        # Create the "project" subfolder
+        project_dir = new_root / "project"
+        project_dir.mkdir()
+
+        # Move the "images" folder under "project"
+        final_images_path = project_dir / "images"
+        shutil.move(str(images_path), str(final_images_path))
+
+        print(f"ODM-compliant folder structure created: {final_images_path}")
+
+ return final_images_path
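+
+    # Layout produced by _rename_input_dir (illustrative):
+    #
+    #   <parent>/<original_name>/    <- new root, used as the project_path
+    #       project/
+    #           images/              <- the original photos, moved here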
+
+    def extract_gps(self) -> pd.DataFrame:
+        """Extract the GPS data"""
+        self.logger.info("Starting GPS extraction")
+        extractor = GPSExtractor(self.project_path)
+        self.gps_points = extractor.extract_all_gps()
+        self.logger.info(f"Extracted {len(self.gps_points)} GPS points")
+
+    def cluster(self):
+        """Cluster the GPS points with DBSCAN and keep only the largest cluster"""
+ previous_points = self.gps_points.copy()
+ clusterer = GPSCluster(
+ self.gps_points,
+ eps=self.config.cluster_eps,
+ min_samples=self.config.cluster_min_samples
+ )
+
+ self.clustered_points = clusterer.fit()
+ self.gps_points = clusterer.get_cluster_stats(self.clustered_points)
+
+ self.visualizer.visualize_filter_step(
+ self.gps_points, previous_points, "1-Clustering")
+
+    def divide_grids(self):
+        """Divide the area into grids and assign points to them.
+
+        Sets self.grid_points to a dict mapping each grid id to its points.
+        """
+ grid_divider = GridDivider(
+ overlap=self.config.grid_overlap,
+ grid_size=self.config.grid_size,
+ project_path=self.project_path,
+ output_dir=self.config.output_dir
+ )
+ grids, self.grid_points = grid_divider.adjust_grid_size_and_overlap(
+ self.gps_points
+ )
+ grid_divider.visualize_grids(self.gps_points, grids)
+ grid_divider.save_image_groups(self.grid_points)
+        if len(grids) >= 20:
+            self.logger.warning("Grid count has reached 20; manual repartitioning is advised")
+
+    def odm_docker_runner(self):
+        """Run the ODM Docker container"""
+        self.logger.info("Starting the Docker container")
+        # TODO: add some error tolerance
+ docker_runner = DockerRunner(self.project_path)
+ docker_runner.run_odm_container()
+
+    def convert_obj(self):
+        """Convert the OBJ models"""
+        self.logger.info("Starting OBJ model conversion")
+ converter = ConvertOBJ(self.config.output_dir)
+ converter.convert_grid_obj(self.grid_points)
+
+    def post_process(self):
+        """Post-processing: merge or copy the results"""
+        self.logger.info("Starting post-processing")
+
+        self.logger.info("Copying the orthophoto to the output directory")
+        orthophoto_tif_path = os.path.join(
+            self.project_path, "odm_orthophoto", "odm_orthophoto.tif")
+        shutil.copy(orthophoto_tif_path, self.config.output_dir)
+        # if self.config.mode == "3D":
+        #     self.convert_obj()
+        # else:
+        #     pass
+
+    def process(self):
+        """Run the full preprocessing pipeline"""
+        try:
+            self.extract_gps()
+            self.cluster()
+            self.divide_grids()
+            self.logger.info("========== Preprocessing finished ==========")
+            self.odm_docker_runner()
+            self.post_process()
+
+        except Exception as e:
+            self.logger.error(f"Error during processing: {str(e)}", exc_info=True)
+            raise
diff --git a/post_pro/conv_obj.py b/post_pro/conv_obj.py
new file mode 100644
index 0000000..41d7ff7
--- /dev/null
+++ b/post_pro/conv_obj.py
@@ -0,0 +1,253 @@
+import os
+import subprocess
+import shutil
+import logging
+from pyproj import Transformer
+import cv2
+
+
+class ConvertOBJ:
+    def __init__(self, output_dir: str):
+        self.output_dir = output_dir
+        # Reference UTM offsets shared by all grids
+        self.ref_east = float('inf')
+        self.ref_north = float('inf')
+        # Transformer from UTM zone 49N to WGS84
+        self.transformer = Transformer.from_crs(
+            "EPSG:32649", "EPSG:4326", always_xy=True)
+        self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ')
+
+    def convert_grid_obj(self, grid_points):
+        """Convert each grid's OBJ file to OSGB format"""
+        os.makedirs(os.path.join(self.output_dir,
+                                 "osgb", "Data"), exist_ok=True)
+
+        # Use the first grid's UTM offsets as the reference frame
+        first_grid_id = list(grid_points.keys())[0]
+ first_grid_dir = os.path.join(
+ self.output_dir,
+ f"grid_{first_grid_id[0]}_{first_grid_id[1]}",
+ "project"
+ )
+ log_file = os.path.join(
+ first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
+ self.ref_east, self.ref_north = self.read_utm_offset(log_file)
+
+        for grid_id in grid_points.keys():
+            try:
+                self._convert_single_grid(grid_id, grid_points)
+            except Exception as e:
+                self.logger.error(f"Grid {grid_id} conversion failed: {str(e)}")
+
+ self._create_merged_metadata()
+
+    def _convert_single_grid(self, grid_id, grid_points):
+        """Convert a single grid's OBJ file"""
+        # Build the relevant paths
+ grid_name = f"grid_{grid_id[0]}_{grid_id[1]}"
+ project_dir = os.path.join(self.output_dir, grid_name, "project")
+ texturing_dir = os.path.join(project_dir, "odm_texturing")
+ texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst")
+ opensfm_dir = os.path.join(project_dir, "opensfm")
+ log_file = os.path.join(
+ project_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
+ os.makedirs(texturing_dst_dir, exist_ok=True)
+
+        # Clamp the z coordinates in the OBJ file
+        min_25d_z = self.get_min_z_from_obj(os.path.join(
+            project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj'))
+        self.modify_z_in_obj(texturing_dir, min_25d_z)
+
+        # In the new folder, shift the OBJ vertex coordinates by the UTM
+        # offset and downsample the texture files
+        utm_offset = self.read_utm_offset(log_file)
+        modified_obj = self.modify_obj_coordinates(
+            texturing_dir, texturing_dst_dir, utm_offset)
+        self.downsample_texture(texturing_dir, texturing_dst_dir)
+
+        # Run the format conversion; osgconv is problematic on Linux, so
+        # remember to comment this out there
+        self.logger.info(f"Converting grid {grid_id}'s OBJ file")
+ output_osgb = os.path.join(texturing_dst_dir, "Tile.osgb")
+ cmd = (
+ f"osgconv {modified_obj} {output_osgb} "
+ f"--compressed --smooth --fix-transparency "
+ )
+ self.logger.info(f"执行osgconv命令:{cmd}")
+
+ try:
+ subprocess.run(cmd, shell=True, check=True, cwd=texturing_dir)
+ except subprocess.CalledProcessError as e:
+ raise RuntimeError(f"OSGB转换失败: {str(e)}")
+
+        # Create the OSGB directory structure and copy the file over
+ osgb_base_dir = os.path.join(self.output_dir, "osgb")
+ data_dir = os.path.join(osgb_base_dir, "Data")
+ tile_dir = os.path.join(data_dir, f"Tile_{grid_id[0]}_{grid_id[1]}")
+ os.makedirs(tile_dir, exist_ok=True)
+ target_osgb = os.path.join(
+ tile_dir, f"Tile_{grid_id[0]}_{grid_id[1]}.osgb")
+ shutil.copy2(output_osgb, target_osgb)
+
+    def _create_merged_metadata(self):
+        """Create the merged metadata.xml file"""
+        # Convert the reference point to WGS84 longitude/latitude
+        center_lon, center_lat = self.transformer.transform(
+            self.ref_east, self.ref_north)
+        metadata_content = f"""<?xml version="1.0" encoding="utf-8"?>
+<ModelMetadata version="1">
+    <SRS>EPSG:4326</SRS>
+    <SRSOrigin>{center_lon},{center_lat},0</SRSOrigin>
+    <Texture>
+        <ColorSource>Visible</ColorSource>
+    </Texture>
+</ModelMetadata>
+"""
+
+ metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml")
+ with open(metadata_file, 'w', encoding='utf-8') as f:
+ f.write(metadata_content)
+
+    def read_utm_offset(self, log_file: str) -> tuple:
+        """Read the UTM offsets from the orthophoto log"""
+ try:
+ east_offset = None
+ north_offset = None
+
+ with open(log_file, 'r') as f:
+ lines = f.readlines()
+ for i, line in enumerate(lines):
+ if 'utm_north_offset' in line and i + 1 < len(lines):
+ north_offset = float(lines[i + 1].strip())
+ elif 'utm_east_offset' in line and i + 1 < len(lines):
+ east_offset = float(lines[i + 1].strip())
+
+                if east_offset is None or north_offset is None:
+                    raise ValueError("UTM offsets not found")
+
+ return east_offset, north_offset
+        except Exception as e:
+            self.logger.error(f"Error while reading UTM offsets: {str(e)}")
+            raise
+
+    def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str:
+        """Rewrite the OBJ vertex coordinates into the shared relative frame"""
+ obj_file = os.path.join(
+ texturing_dir, "odm_textured_model_modified.obj")
+ obj_dst_file = os.path.join(
+ texturing_dst_dir, "odm_textured_model_geo_utm.obj")
+        if not os.path.exists(obj_file):
+            raise FileNotFoundError(f"OBJ file not found: {obj_file}")
+ shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"),
+ os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl"))
+ east_offset, north_offset = utm_offset
+        self.logger.info(
+            f"UTM coordinate offset: {east_offset - self.ref_east}, {north_offset - self.ref_north}")
+
+ try:
+ with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out:
+ for line in f_in:
+                    if line.startswith('v '):
+                        # Vertex line: shift by this grid's offset relative
+                        # to the reference grid's UTM origin
+                        parts = line.strip().split()
+                        x = float(parts[1]) + (east_offset - self.ref_east)
+                        y = float(parts[2]) + (north_offset - self.ref_north)
+                        z = float(parts[3])
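+                        # Axis swap below: (x, y, z) -> (x, z, -y) turns the
+                        # Z-up model into OSG's Y-up convention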
+ f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n')
+                    elif line.startswith('vn '):  # normal vector line
+                        parts = line.split()
+                        nx = float(parts[1])
+                        ny = float(parts[2])
+                        nz = float(parts[3])
+                        # Flip the normal's Y axis to match the vertex swap
+                        new_line = f"vn {nx} {nz} {-ny}\n"
+                        f_out.write(new_line)
+                    else:
+                        # Copy all other lines unchanged
+                        f_out.write(line)
+
+ return obj_dst_file
+        except Exception as e:
+            self.logger.error(f"Error while modifying OBJ coordinates: {str(e)}")
+            raise
+
+    def downsample_texture(self, src_dir: str, dst_dir: str):
+        """Copy the texture files, repeatedly downsampling any file larger
+        than 100MB until it drops below that limit
+        Args:
+            src_dir: source texture directory
+            dst_dir: destination texture directory
+        """
+        for file in os.listdir(src_dir):
+            if file.lower().endswith('.png'):
+                src_path = os.path.join(src_dir, file)
+                dst_path = os.path.join(dst_dir, file)
+
+                # Check the file size (in bytes)
+                file_size = os.path.getsize(src_path)
+                if file_size <= 100 * 1024 * 1024:  # at most 100MB: copy as-is
+                    shutil.copy2(src_path, dst_path)
+                else:
+                    # Larger than 100MB: downsample
+                    img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED)
+                    if_first_ds = True
+                    while file_size > 100 * 1024 * 1024:
+                        self.logger.info(f"Texture file {file} exceeds 100MB, downsampling")
+
+                        if if_first_ds:
+                            # First pass: quarter each dimension
+                            new_size = (img.shape[1] // 4,
+                                        img.shape[0] // 4)
+                            # Downsample with bicubic interpolation
+                            resized_img = cv2.resize(
+                                img, new_size, interpolation=cv2.INTER_CUBIC)
+                            if_first_ds = False
+                        else:
+                            # Later passes: halve each dimension
+                            new_size = (img.shape[1] // 2,
+                                        img.shape[0] // 2)
+                            # Downsample with bicubic interpolation
+                            resized_img = cv2.resize(
+                                img, new_size, interpolation=cv2.INTER_CUBIC)
+
+                        # Write the downsampled image to the destination path
+                        cv2.imwrite(dst_path, resized_img, [
+                                    cv2.IMWRITE_PNG_COMPRESSION, 9])
+
+                        # Refresh the file size and image for the next pass
+                        file_size = os.path.getsize(dst_path)
+                        img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED)
+                        self.logger.info(
+                            f"File size after downsampling: {file_size / (1024 * 1024):.2f} MB")
+
+    def get_min_z_from_obj(self, file_path):
+        min_z = float('inf')  # start with positive infinity
+        with open(file_path, 'r') as obj_file:
+            for line in obj_file:
+                # Vertex definitions start with 'v '
+                if line.startswith('v '):
+                    parts = line.split()
+                    # Parse the z value and track the minimum
+                    z = float(parts[3])
+                    if z < min_z:
+                        min_z = z
+ return min_z
+
+ def modify_z_in_obj(self, texturing_dir, min_25d_z):
+ obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj')
+ output_file = os.path.join(
+ texturing_dir, 'odm_textured_model_modified.obj')
+        with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out:
+            for line in f_in:
+                if line.startswith('v '):  # vertex line
+                    parts = line.strip().split()
+                    x = float(parts[1])
+                    y = float(parts[2])
+                    z = float(parts[3])
+
+                    # Clamp z to the 2.5D model's minimum elevation
+                    if z < min_25d_z:
+                        z = min_25d_z
+
+                    f_out.write(f"v {x} {y} {z}\n")
+                else:
+                    f_out.write(line)
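+
+
+if __name__ == '__main__':
+    # Usage sketch (hypothetical paths): expects the per-grid ODM outputs,
+    # e.g. <output_dir>/grid_0_0/project/..., already produced by main.py.
+    logging.basicConfig(level=logging.INFO)
+    converter = ConvertOBJ(r"G:\ODM_output\test2")
+    converter.convert_grid_obj({(0, 0): []})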
diff --git a/post_pro/conv_obj2.py b/post_pro/conv_obj2.py
new file mode 100644
index 0000000..a155f16
--- /dev/null
+++ b/post_pro/conv_obj2.py
@@ -0,0 +1,263 @@
+import os
+import subprocess
+import shutil
+import logging
+from pyproj import Transformer
+import cv2
+
+
+class ConvertOBJ:
+    def __init__(self, output_dir: str):
+        self.output_dir = output_dir
+        # Reference UTM offsets shared by all grids
+        self.ref_east = float('inf')
+        self.ref_north = float('inf')
+        # Transformer from UTM zone 49N to WGS84
+        self.transformer = Transformer.from_crs(
+            "EPSG:32649", "EPSG:4326", always_xy=True)
+        self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ')
+
+    def convert_grid_obj(self, grid_points):
+        """Convert each grid's OBJ file to OSGB format"""
+        os.makedirs(os.path.join(self.output_dir,
+                                 "osgb", "Data"), exist_ok=True)
+
+        # Use the first grid's UTM offsets as the reference frame
+        first_grid_id = list(grid_points.keys())[0]
+ first_grid_dir = os.path.join(
+ self.output_dir,
+ f"grid_{first_grid_id[0]}_{first_grid_id[1]}",
+ "project"
+ )
+ log_file = os.path.join(
+ first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
+ self.ref_east, self.ref_north = self.read_utm_offset(log_file)
+
+        for grid_id in grid_points.keys():
+            try:
+                self._convert_single_grid(grid_id, grid_points)
+            except Exception as e:
+                self.logger.error(f"Grid {grid_id} conversion failed: {str(e)}")
+
+ self._create_merged_metadata()
+
+    def _convert_single_grid(self, grid_id, grid_points):
+        """Convert a single grid's OBJ file"""
+        # Build the relevant paths
+ grid_name = f"grid_{grid_id[0]}_{grid_id[1]}"
+ project_dir = os.path.join(self.output_dir, grid_name, "project")
+ texturing_dir = os.path.join(project_dir, "odm_texturing")
+ texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst")
+ split_obj_dir = os.path.join(texturing_dst_dir, "split_obj")
+ opensfm_dir = os.path.join(project_dir, "opensfm")
+ log_file = os.path.join(
+ project_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
+ os.makedirs(texturing_dst_dir, exist_ok=True)
+
+        # Clamp the z coordinates in the OBJ file
+        min_25d_z = self.get_min_z_from_obj(os.path.join(
+            project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj'))
+        self.modify_z_in_obj(texturing_dir, min_25d_z)
+
+        # In the new folder, shift the OBJ vertex coordinates by the UTM
+        # offset and downsample the texture files
+        utm_offset = self.read_utm_offset(log_file)
+        modified_obj = self.modify_obj_coordinates(
+            texturing_dir, texturing_dst_dir, utm_offset)
+        self.downsample_texture(texturing_dir, texturing_dst_dir)
+
+        # Slice the OBJ file into tiles
+        self.logger.info(f"Splitting grid {grid_id}'s OBJ file")
+        os.makedirs(split_obj_dir)
+        cmd = (
+            rf"D:\software\Obj2Tiles\Obj2Tiles.exe --stage Splitting --lods 1 --divisions 3 "
+            f"{modified_obj} {split_obj_dir}"
+        )
+        subprocess.run(cmd, check=True)
+
+        # Run the format conversion; osgconv is problematic on Linux, so
+        # remember to comment this out there
+        self.logger.info(f"Converting grid {grid_id}'s OBJ files")
+        # Collect all OBJ files under the LOD-0 split directory
+        obj_lod_dir = os.path.join(split_obj_dir, "LOD-0")
+        obj_files = [f for f in os.listdir(
+            obj_lod_dir) if f.endswith('.obj')]
+        for obj_file in obj_files:
+            obj_path = os.path.join(obj_lod_dir, obj_file)
+            osgb_file = os.path.splitext(obj_file)[0] + '.osgb'
+            osgb_path = os.path.join(split_obj_dir, osgb_file)
+            # Run the osgconv command
+            subprocess.run(['osgconv', obj_path, osgb_path], check=True)
+
+        # Create the OSGB directory structure and copy the files
+ osgb_base_dir = os.path.join(self.output_dir, "osgb")
+ data_dir = os.path.join(osgb_base_dir, "Data")
+ for obj_file in obj_files:
+ obj_file_name = os.path.splitext(obj_file)[0]
+ tile_dirs = os.path.join(data_dir, f"{obj_file_name}")
+ os.makedirs(tile_dirs, exist_ok=True)
+ shutil.copy2(os.path.join(
+ split_obj_dir, obj_file_name+".osgb"), tile_dirs)
+
+    def _create_merged_metadata(self):
+        """Create the merged metadata.xml file"""
+        # Convert the reference point to WGS84 longitude/latitude
+        center_lon, center_lat = self.transformer.transform(
+            self.ref_east, self.ref_north)
+        metadata_content = f"""<?xml version="1.0" encoding="utf-8"?>
+<ModelMetadata version="1">
+    <SRS>EPSG:4326</SRS>
+    <SRSOrigin>{center_lon},{center_lat},0</SRSOrigin>
+    <Texture>
+        <ColorSource>Visible</ColorSource>
+    </Texture>
+</ModelMetadata>
+"""
+
+ metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml")
+ with open(metadata_file, 'w', encoding='utf-8') as f:
+ f.write(metadata_content)
+
+    def read_utm_offset(self, log_file: str) -> tuple:
+        """Read the UTM offsets from the orthophoto log"""
+ try:
+ east_offset = None
+ north_offset = None
+
+ with open(log_file, 'r') as f:
+ lines = f.readlines()
+ for i, line in enumerate(lines):
+ if 'utm_north_offset' in line and i + 1 < len(lines):
+ north_offset = float(lines[i + 1].strip())
+ elif 'utm_east_offset' in line and i + 1 < len(lines):
+ east_offset = float(lines[i + 1].strip())
+
+                if east_offset is None or north_offset is None:
+                    raise ValueError("UTM offsets not found")
+
+ return east_offset, north_offset
+        except Exception as e:
+            self.logger.error(f"Error while reading UTM offsets: {str(e)}")
+            raise
+
+    def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str:
+        """Rewrite the OBJ vertex coordinates into the shared relative frame"""
+ obj_file = os.path.join(
+ texturing_dir, "odm_textured_model_modified.obj")
+ obj_dst_file = os.path.join(
+ texturing_dst_dir, "odm_textured_model_geo_utm.obj")
+        if not os.path.exists(obj_file):
+            raise FileNotFoundError(f"OBJ file not found: {obj_file}")
+ shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"),
+ os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl"))
+ east_offset, north_offset = utm_offset
+        self.logger.info(
+            f"UTM coordinate offset: {east_offset - self.ref_east}, {north_offset - self.ref_north}")
+
+ try:
+ with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out:
+ for line in f_in:
+                    if line.startswith('v '):
+                        # Vertex line: shift by this grid's offset relative
+                        # to the reference grid's UTM origin, then swap to
+                        # OSG's Y-up convention via (x, y, z) -> (x, z, -y)
+                        parts = line.strip().split()
+                        x = float(parts[1]) + (east_offset - self.ref_east)
+                        y = float(parts[2]) + (north_offset - self.ref_north)
+                        z = float(parts[3])
+                        f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n')
+                    elif line.startswith('vn '):  # normal vector line
+                        parts = line.split()
+                        nx = float(parts[1])
+                        ny = float(parts[2])
+                        nz = float(parts[3])
+                        # Flip the normal's Y axis to match the vertex swap
+                        new_line = f"vn {nx} {nz} {-ny}\n"
+                        f_out.write(new_line)
+                    else:
+                        # Copy all other lines unchanged
+                        f_out.write(line)
+
+ return obj_dst_file
+        except Exception as e:
+            self.logger.error(f"Error while modifying OBJ coordinates: {str(e)}")
+            raise
+
+    def downsample_texture(self, src_dir: str, dst_dir: str):
+        """Copy the texture files, repeatedly downsampling any file larger
+        than 100MB until it drops below that limit
+        Args:
+            src_dir: source texture directory
+            dst_dir: destination texture directory
+        """
+        for file in os.listdir(src_dir):
+            if file.lower().endswith('.png'):
+                src_path = os.path.join(src_dir, file)
+                dst_path = os.path.join(dst_dir, file)
+
+                # Check the file size (in bytes)
+                file_size = os.path.getsize(src_path)
+                if file_size <= 100 * 1024 * 1024:  # at most 100MB: copy as-is
+                    shutil.copy2(src_path, dst_path)
+                else:
+                    # Larger than 100MB: downsample
+                    img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED)
+                    if_first_ds = True
+                    while file_size > 100 * 1024 * 1024:
+                        self.logger.info(f"Texture file {file} exceeds 100MB, downsampling")
+
+                        if if_first_ds:
+                            # First pass: quarter each dimension
+                            new_size = (img.shape[1] // 4,
+                                        img.shape[0] // 4)
+                            # Downsample with bicubic interpolation
+                            resized_img = cv2.resize(
+                                img, new_size, interpolation=cv2.INTER_CUBIC)
+                            if_first_ds = False
+                        else:
+                            # Later passes: halve each dimension
+                            new_size = (img.shape[1] // 2,
+                                        img.shape[0] // 2)
+                            # Downsample with bicubic interpolation
+                            resized_img = cv2.resize(
+                                img, new_size, interpolation=cv2.INTER_CUBIC)
+
+                        # Write the downsampled image to the destination path
+                        cv2.imwrite(dst_path, resized_img, [
+                                    cv2.IMWRITE_PNG_COMPRESSION, 9])
+
+                        # Refresh the file size and image for the next pass
+                        file_size = os.path.getsize(dst_path)
+                        img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED)
+                        self.logger.info(
+                            f"File size after downsampling: {file_size / (1024 * 1024):.2f} MB")
+
+    def get_min_z_from_obj(self, file_path):
+        min_z = float('inf')  # start with positive infinity
+        with open(file_path, 'r') as obj_file:
+            for line in obj_file:
+                # Vertex definitions start with 'v '
+                if line.startswith('v '):
+                    parts = line.split()
+                    # Parse the z value and track the minimum
+                    z = float(parts[3])
+                    if z < min_z:
+                        min_z = z
+ return min_z
+
+ def modify_z_in_obj(self, texturing_dir, min_25d_z):
+ obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj')
+ output_file = os.path.join(
+ texturing_dir, 'odm_textured_model_modified.obj')
+        with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out:
+            for line in f_in:
+                if line.startswith('v '):  # vertex line
+                    parts = line.strip().split()
+                    x = float(parts[1])
+                    y = float(parts[2])
+                    z = float(parts[3])
+
+                    # Clamp z to the 2.5D model's minimum elevation
+                    if z < min_25d_z:
+                        z = min_25d_z
+
+                    f_out.write(f"v {x} {y} {z}\n")
+                else:
+                    f_out.write(line)
diff --git a/run.py b/run.py
new file mode 100644
index 0000000..d583018
--- /dev/null
+++ b/run.py
@@ -0,0 +1,59 @@
+import argparse
+from datetime import timedelta
+from main import ProcessConfig, ODM_Plugin
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='ODM preprocessing tool')
+
+    # Required parameters
+    # parser.add_argument('--image_dir', required=True, help='input image directory')
+    # parser.add_argument('--output_dir', required=True, help='output directory')
+    parser.add_argument(
+        '--image_dir', default=r'E:\datasets\UAV\199', help='input image directory')
+    parser.add_argument(
+        '--output_dir', default=r'G:\ODM_output\test2', help='output directory')
+    # Optional parameters
+    parser.add_argument('--mode', default='3D',
+                        choices=['quick-mosaic', '3D'], help='processing mode')
+    parser.add_argument('--accuracy', default='medium',
+                        choices=['high', 'medium', 'low'], help='accuracy level')
+    parser.add_argument('--grid_size', type=float,
+                        default=800, help='grid size (meters)')
+    parser.add_argument('--grid_overlap', type=float,
+                        default=0.05, help='grid overlap ratio')
+
+    args = parser.parse_args()
+    return args
+
+
+def main():
+    args = parse_args()
+
+    # Build the configuration
+    config = ProcessConfig(
+        image_dir=args.image_dir,
+        output_dir=args.output_dir,
+        mode=args.mode,
+        accuracy=args.accuracy,
+        grid_size=args.grid_size,
+        grid_overlap=args.grid_overlap,
+
+        # The remaining parameters keep their defaults
+ cluster_eps=0.01,
+ cluster_min_samples=5,
+ time_group_overlap_threshold=0.7,
+ time_group_interval=timedelta(minutes=5),
+ filter_distance_threshold=0.001,
+ filter_min_neighbors=6,
+ filter_grid_size=0.001,
+ filter_dense_distance_threshold=10,
+ filter_time_threshold=timedelta(minutes=5),
+ )
+
+    # Create the processor and run it
+ processor = ODM_Plugin(config)
+ processor.process()
+
+
+if __name__ == '__main__':
+ main()
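+
+# Example invocation (paths are placeholders):
+#   python run.py --image_dir /data/uav/images --output_dir /data/uav/output \
+#       --mode 3D --grid_size 800 --grid_overlap 0.05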
diff --git a/utils/docker_runner.py b/utils/docker_runner.py
new file mode 100644
index 0000000..de74a73
--- /dev/null
+++ b/utils/docker_runner.py
@@ -0,0 +1,89 @@
+import docker
+import logging
+
+
+class DockerRunner:
+ def __init__(self, project_path: str):
+ """
+ 初始化 DockerRunner
+
+ Args:
+ project_path (str): 项目路径,将挂载到 Docker 容器中
+ """
+ self.project_path = project_path
+ self.logger = logging.getLogger("DockerRunner")
+ self.docker_client = docker.from_env()
+
+    def run_odm_container(self):
+        """
+        Run the OpenDroneMap container via the Docker SDK
+        """
+        try:
+            self.logger.info("Starting docker run")
+            # Volume mounts
+ volume_mapping = {
+ self.project_path: {
+ 'bind': '/datasets',
+ 'mode': 'rw'
+ }
+ }
+
+            # Arguments passed to the ODM entrypoint
+ command = [
+ "--project-path", "/datasets",
+ "project",
+ "--max-concurrency", "15",
+ "--force-gps",
+ "--split-overlap", "0",
+ ]
+
+            # Run the container
+            container = self.docker_client.containers.run(
+                image="opendronemap/odm:gpu",
+                command=command,
+                volumes=volume_mapping,
+                device_requests=[
+                    docker.types.DeviceRequest(
+                        count=-1, capabilities=[["gpu"]])
+                ],  # enable GPU support
+                remove=False,  # keep the container after exit so logs can be fetched
+                tty=True,
+                detach=True  # run in the background
+            )
+
+            # Wait for the container to finish
+            exit_status = container.wait()
+            if exit_status["StatusCode"] != 0:
+                self.logger.error(f"Container failed with exit code: {exit_status['StatusCode']}")
+
+                # Fetch the container's error log
+                error_logs = container.logs(
+                    stderr=True).decode("utf-8").splitlines()
+                self.logger.error("Detailed error log of the failed container:")
+                for line in error_logs:
+                    self.logger.error(line)
+
+            else:
+                # Fetch the full log
+                logs = container.logs().decode("utf-8").splitlines()
+
+                # Print the last 50 lines
+                self.logger.info("Container finished; last 50 log lines:")
+                for line in logs[-50:]:
+                    self.logger.info(line)
+
+            # Remove the container
+ container.remove()
+
+        except Exception as e:
+            self.logger.error(f"Error while running the Docker container: {str(e)}", exc_info=True)
+            raise
+
+
+if __name__ == "__main__":
+    # Example usage
+ project_path = r"E:\datasets\UAV\199"
+ docker_runner = DockerRunner(project_path)
+ docker_runner.run_odm_container()
diff --git a/utils/gps_extractor.py b/utils/gps_extractor.py
new file mode 100644
index 0000000..7d1c5c2
--- /dev/null
+++ b/utils/gps_extractor.py
@@ -0,0 +1,96 @@
+import os
+from PIL import Image
+import piexif
+import logging
+import pandas as pd
+from datetime import datetime
+
+
+class GPSExtractor:
+    """Extract GPS coordinates and capture dates from image files"""
+
+ def __init__(self, project_path):
+ self.image_dir = os.path.join(project_path, 'project', 'images')
+ self.logger = logging.getLogger('UAV_Preprocess.GPSExtractor')
+
+    @staticmethod
+    def _dms_to_decimal(dms):
+        """Convert DMS (degree, minute, second) rationals to decimal degrees"""
+ return dms[0][0] / dms[0][1] + (dms[1][0] / dms[1][1]) / 60 + (dms[2][0] / dms[2][1]) / 3600
+
+    @staticmethod
+    def _parse_datetime(datetime_str):
+        """Parse an EXIF date/time string"""
+        try:
+            # EXIF dates are usually formatted as 'YYYY:MM:DD HH:MM:SS'
+            return datetime.strptime(datetime_str.decode(), '%Y:%m:%d %H:%M:%S')
+        except Exception:
+            return None
+
+    def get_gps_and_date(self, image_path):
+        """Extract the GPS coordinates and capture date of a single image"""
+ try:
+ image = Image.open(image_path)
+ exif_data = piexif.load(image.info['exif'])
+
+            # Extract the GPS info
+            gps_info = exif_data.get("GPS", {})
+            lat = lon = None
+            if gps_info:
+                lat = self._dms_to_decimal(gps_info.get(2, []))
+                lon = self._dms_to_decimal(gps_info.get(4, []))
+                self.logger.debug(
+                    f"Extracted GPS coordinates: {image_path} - lat: {lat}, lon: {lon}")
+
+            # Extract the capture date
+            date_info = None
+            if "Exif" in exif_data:
+                # Prefer DateTimeOriginal
+                date_str = exif_data["Exif"].get(36867)  # DateTimeOriginal
+                if not date_str:
+                    # Fall back to DateTimeDigitized
+                    date_str = exif_data["Exif"].get(
+                        36868)  # DateTimeDigitized
+                if not date_str:
+                    # Finally, fall back to the basic DateTime tag
+                    date_str = exif_data["0th"].get(306)  # DateTime
+
+                if date_str:
+                    date_info = self._parse_datetime(date_str)
+                    self.logger.debug(
+                        f"Extracted capture date: {image_path} - {date_info}")
+
+            if not gps_info:
+                self.logger.warning(f"Image has no GPS info: {image_path}")
+            if not date_info:
+                self.logger.warning(f"Image has no capture date: {image_path}")
+
+ return lat, lon, date_info
+
+        except Exception as e:
+            self.logger.error(f"Error while extracting image info: {image_path} - {str(e)}")
+            return None, None, None
+
+    def extract_all_gps(self):
+        """Extract GPS coordinates and capture dates from all images"""
+        self.logger.info(f"Extracting GPS coordinates and capture dates from: {self.image_dir}")
+        gps_data = []
+        total_images = 0
+        successful_extractions = 0
+
+        for image_file in os.listdir(self.image_dir):
+            total_images += 1
+            image_path = os.path.join(self.image_dir, image_file)
+            lat, lon, date = self.get_gps_and_date(image_path)
+            if lat and lon:  # GPS info remains the primary acceptance criterion
+ successful_extractions += 1
+ gps_data.append({
+ 'file': image_file,
+ 'lat': lat,
+ 'lon': lon,
+ 'date': date
+ })
+
+        self.logger.info(
+            f"GPS/date extraction finished - total images: {total_images}, succeeded: {successful_extractions}, failed: {total_images - successful_extractions}")
+ return pd.DataFrame(gps_data)
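+
+
+if __name__ == '__main__':
+    # Quick smoke test with a hypothetical project path; the extractor
+    # expects the photos under <project_path>/project/images.
+    logging.basicConfig(level=logging.DEBUG)
+    extractor = GPSExtractor(r'E:\datasets\UAV\199')
+    df = extractor.extract_all_gps()
+    print(df.head())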
diff --git a/utils/grid_divider.py b/utils/grid_divider.py
new file mode 100644
index 0000000..04c3fd8
--- /dev/null
+++ b/utils/grid_divider.py
@@ -0,0 +1,249 @@
+import logging
+from geopy.distance import geodesic
+import matplotlib.pyplot as plt
+import os
+
+
+class GridDivider:
+    """Divide the area into grids and assign images to them"""
+
+    def __init__(self, overlap, grid_size, project_path, output_dir):
+        self.overlap = overlap
+        self.grid_size = grid_size
+        self.project_path = project_path
+        self.output_dir = output_dir
+        self.logger = logging.getLogger('UAV_Preprocess.GridDivider')
+        self.logger.info(f"Initialized grid divider, overlap: {overlap}")
+        self.num_grids_width = 0  # number of grids along each axis
+        self.num_grids_height = 0
+
+    def adjust_grid_size_and_overlap(self, points_df):
+        """Dynamically adjust the grid size, then the grid overlap ratio"""
+        grids = self.adjust_grid_size(points_df)
+        self.logger.info(f"Adjusting grid overlap, initial overlap: {self.overlap}")
+        while True:
+            # Divide the grids using the adjusted grid size
+            grids = self.divide_grids(points_df)
+            grid_points, multiple_grid_points = self.assign_to_grids(
+                points_df, grids)
+
+            if len(grids) == 1:
+                self.logger.info("Only one grid; skipping overlap adjustment")
+                break
+            elif multiple_grid_points < 0.1 * len(points_df):
+                self.overlap += 0.02
+                self.logger.info(f"Overlap increased to: {self.overlap}")
+            else:
+                self.logger.info(
+                    f"Found a suitable overlap: {self.overlap}, {multiple_grid_points} points fall into multiple grids")
+                break
+        return grids, grid_points
+
+    def adjust_grid_size(self, points_df):
+        """Dynamically adjust the grid size
+
+        Args:
+            points_df: DataFrame of GPS points
+
+        Returns:
+            list: grid bounds computed with the final grid size
+        """
+        self.logger.info(f"Adjusting grid size, initial size: {self.grid_size}m")
+
+        while True:
+            # Divide the grids using the current grid_size
+            grids = self.divide_grids(points_df)
+            grid_points, multiple_grid_points = self.assign_to_grids(
+                points_df, grids)
+
+            # Find the most populated grid
+            max_points = 0
+            for grid_id, points in grid_points.items():
+                max_points = max(max_points, len(points))
+
+            self.logger.info(
+                f"Current grid size: {self.grid_size}m, max points in a single grid: {max_points}")
+
+            # If the largest grid holds more than 1600 points, shrink the grid
+            if max_points > 1600:
+                self.grid_size -= 100
+                self.logger.info(f"Point count exceeds 1600, shrinking grid size to: {self.grid_size}m")
+                if self.grid_size < 500:  # lower bound on the grid size
+                    self.logger.warning("Grid size fell below the 500m minimum, stopping adjustment")
+                    break
+            else:
+                self.logger.info(f"Found a suitable grid size: {self.grid_size}m")
+                break
+        return grids
+
+    def divide_grids(self, points_df):
+        """Compute the bounding box and divide it into grids
+        Returns:
+            list: grid bounds as (min_lat, max_lat, min_lon, max_lon) tuples
+        """
+        self.logger.info("Dividing grids")
+
+        min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max()
+        min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max()
+
+        # Actual extent of the area in meters
+        width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters
+        height = geodesic((min_lat, min_lon), (max_lat, min_lon)).meters
+
+        self.logger.info(f"Area width: {width:.2f}m, height: {height:.2f}m")
+
+        # Fine-tune the grid dimensions to avoid leaving a remainder strip of
+        # almost one full grid (the 2*grid_size-1 situation)
+        grid_size_lt = [self.grid_size - 200, self.grid_size - 100,
+                        self.grid_size, self.grid_size + 100, self.grid_size + 200]
+
+        width_modulus_lt = [width % grid_size for grid_size in grid_size_lt]
+        grid_width = grid_size_lt[width_modulus_lt.index(
+            min(width_modulus_lt))]
+        height_modulus_lt = [height % grid_size for grid_size in grid_size_lt]
+        grid_height = grid_size_lt[height_modulus_lt.index(
+            min(height_modulus_lt))]
+        self.logger.info(f"Grid width: {grid_width:.2f}m, grid height: {grid_height:.2f}m")
+
+        # Number of grids along each axis
+        self.num_grids_width = max(int(width / grid_width), 1)
+        self.num_grids_height = max(int(height / grid_height), 1)
+
+        # Latitude/longitude step per grid
+        lat_step = (max_lat - min_lat) / self.num_grids_height
+        lon_step = (max_lon - min_lon) / self.num_grids_width
+
+        grids = []
+
+        # Create every grid, expanded by the overlap fraction on each side
+ for i in range(self.num_grids_height):
+ for j in range(self.num_grids_width):
+ grid_min_lat = min_lat + i * lat_step - self.overlap * lat_step
+ grid_max_lat = min_lat + \
+ (i + 1) * lat_step + self.overlap * lat_step
+ grid_min_lon = min_lon + j * lon_step - self.overlap * lon_step
+ grid_max_lon = min_lon + \
+ (j + 1) * lon_step + self.overlap * lon_step
+
+ grid_bounds = (grid_min_lat, grid_max_lat,
+ grid_min_lon, grid_max_lon)
+ grids.append(grid_bounds)
+
+                self.logger.debug(
+                    f"Grid[{i},{j}]: lat[{grid_min_lat:.6f}, {grid_max_lat:.6f}], "
+                    f"lon[{grid_min_lon:.6f}, {grid_max_lon:.6f}]"
+                )
+
+        self.logger.info(
+            f"Divided into {len(grids)} grids ({self.num_grids_width}x{self.num_grids_height})")
+
+ return grids
+
+    def assign_to_grids(self, points_df, grids):
+        """Assign each point to the grids that contain it"""
+        self.logger.info(f"Assigning {len(points_df)} points to grids")
+
+        grid_points = {}  # per-grid point lists
+        points_assigned = 0
+        multiple_grid_points = 0
+
+        for i in range(self.num_grids_height):
+            for j in range(self.num_grids_width):
+                grid_points[(i, j)] = []  # keyed by the (i, j) tuple
+
+ for _, point in points_df.iterrows():
+ point_assigned = False
+ for i in range(self.num_grids_height):
+ for j in range(self.num_grids_width):
+ grid_idx = i * self.num_grids_width + j
+ min_lat, max_lat, min_lon, max_lon = grids[grid_idx]
+
+ if min_lat <= point['lat'] <= max_lat and min_lon <= point['lon'] <= max_lon:
+ grid_points[(i, j)].append(point.to_dict())
+ if point_assigned:
+ multiple_grid_points += 1
+ else:
+ points_assigned += 1
+ point_assigned = True
+
+        # Log the point count of every grid
+        for grid_id, points in grid_points.items():
+            self.logger.info(f"Grid {grid_id} contains {len(points)} points")
+
+        self.logger.info(
+            f"Assignment finished: {len(points_df)} points in total, "
+            f"{points_assigned} assigned, "
+            f"{multiple_grid_points} assigned to multiple grids"
+        )
+
+ return grid_points, multiple_grid_points
+
+    def visualize_grids(self, points_df, grids):
+        """Visualize the grid division and the GPS point distribution"""
+        self.logger.info("Visualizing grid division")
+
+ plt.figure(figsize=(12, 8))
+
+        # Plot the GPS points
+        plt.scatter(points_df['lon'], points_df['lat'],
+                    c='blue', s=10, alpha=0.6, label='GPS points')
+
+        # Draw the grids
+ for i in range(self.num_grids_height):
+ for j in range(self.num_grids_width):
+ grid_idx = i * self.num_grids_width + j
+ min_lat, max_lat, min_lon, max_lon = grids[grid_idx]
+
+                # Actual grid dimensions in meters
+                width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters
+                height = geodesic((min_lat, min_lon),
+                                  (max_lat, min_lon)).meters
+
+                plt.plot([min_lon, max_lon, max_lon, min_lon, min_lon],
+                         [min_lat, min_lat, max_lat, max_lat, min_lat],
+                         'r-', alpha=0.5)
+                # Put the grid id and its dimensions at the grid center
+                center_lon = (min_lon + max_lon) / 2
+                center_lat = (min_lat + max_lat) / 2
+                plt.text(center_lon, center_lat,
+                         f"({i},{j})\n{width:.0f}m×{height:.0f}m",
+                         horizontalalignment='center',
+                         verticalalignment='center',
+                         fontsize=8)
+
+ plt.title('Grid Division and GPS Point Distribution')
+ plt.xlabel('Longitude')
+ plt.ylabel('Latitude')
+ plt.legend()
+ plt.grid(True)
+
+        # Save the figure if an output directory was provided
+        if self.output_dir:
+            save_path = os.path.join(
+                self.output_dir, 'filter_imgs', 'grid_division.png')
+            plt.savefig(save_path, dpi=300, bbox_inches='tight')
+            self.logger.info(f"Grid division figure saved to: {save_path}")
+
+ plt.close()
+
+    def save_image_groups(self, grid_points, output_file_name="image_groups.txt"):
+        """Save the image grouping information to a file
+
+        Args:
+            grid_points (dict): per-grid points, keyed by (i, j)
+            output_file_name (str): name of the output file
+        """
+        self.logger.info(f"Saving image grouping information to {output_file_name}")
+
+ output_file = os.path.join(
+ self.project_path, 'project', output_file_name)
+ with open(output_file, 'w') as f:
+            for (i, j), points in grid_points.items():
+                # Group number in row-major order
+                group_id = i * self.num_grids_width + j + 1
+                for point in points:
+                    image_name = point.get('file', 'unknown')
+                    f.write(f"{image_name} {group_id}\n")
+
+        self.logger.info(f"Image grouping information saved to {output_file}")
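+
+
+if __name__ == '__main__':
+    # Synthetic sanity check (made-up coordinates); only divide_grids and
+    # assign_to_grids are exercised, so nothing is written to disk.
+    import numpy as np
+    import pandas as pd
+
+    logging.basicConfig(level=logging.INFO)
+    rng = np.random.default_rng(0)
+    pts = pd.DataFrame({
+        'lon': rng.uniform(120.00, 120.02, 200),
+        'lat': rng.uniform(30.00, 30.01, 200),
+        'file': [f'img_{i}.jpg' for i in range(200)],
+    })
+    divider = GridDivider(overlap=0.05, grid_size=500,
+                          project_path='.', output_dir=None)
+    grids = divider.divide_grids(pts)
+    grid_points, _ = divider.assign_to_grids(pts, grids)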
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..80c1f10
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,36 @@
+import logging
+import os
+from datetime import datetime
+
+
+def setup_logger(output_dir):
+    # Create the logs directory
+    log_dir = os.path.join(output_dir, 'logs')
+    os.makedirs(log_dir, exist_ok=True)
+
+    # Log file name with a timestamp
+    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+    log_file = os.path.join(log_dir, f'preprocess_{timestamp}.log')
+
+    # Log record format
+ formatter = logging.Formatter(
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S'
+ )
+
+    # File handler
+ file_handler = logging.FileHandler(log_file, encoding='utf-8')
+ file_handler.setFormatter(formatter)
+
+    # Console handler
+ console_handler = logging.StreamHandler()
+ console_handler.setFormatter(formatter)
+
+    # Get the package-level 'UAV_Preprocess' logger
+ logger = logging.getLogger('UAV_Preprocess')
+ logger.setLevel(logging.INFO)
+
+    # Attach the handlers
+ logger.addHandler(file_handler)
+ logger.addHandler(console_handler)
+
+ return logger
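+
+
+if __name__ == '__main__':
+    # Smoke test: writes test_output/logs/preprocess_<timestamp>.log
+    logger = setup_logger('test_output')
+    logger.info('logger smoke test')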
diff --git a/utils/visualizer.py b/utils/visualizer.py
new file mode 100644
index 0000000..964bd53
--- /dev/null
+++ b/utils/visualizer.py
@@ -0,0 +1,152 @@
+import os
+import matplotlib.pyplot as plt
+import pandas as pd
+import logging
+from typing import Optional
+from pyproj import Transformer
+
+
+class FilterVisualizer:
+    """Visualizer for filtering results"""
+
+    def __init__(self, output_dir: str):
+        """
+        Initialize the visualizer
+
+        Args:
+            output_dir: output directory path
+        """
+        self.output_dir = output_dir
+        self.logger = logging.getLogger('UAV_Preprocess.Visualizer')
+        # Coordinate transformer
+        self.transformer = Transformer.from_crs(
+            "EPSG:4326",   # WGS84 longitude/latitude
+            "EPSG:32649",  # UTM zone 49N
+            always_xy=True
+        )
+
+    def _convert_to_utm(self, lon: pd.Series, lat: pd.Series) -> tuple:
+        """
+        Convert longitude/latitude to UTM coordinates
+
+        Args:
+            lon: longitude series
+            lat: latitude series
+
+        Returns:
+            tuple: (x coordinates, y coordinates)
+        """
+        return self.transformer.transform(lon, lat)
+
+ def visualize_filter_step(self,
+ current_points: pd.DataFrame,
+ previous_points: pd.DataFrame,
+ step_name: str,
+ save_name: Optional[str] = None):
+ """
+ 可视化单个过滤步骤的结果
+
+ Args:
+ current_points: 当前步骤后的点
+ previous_points: 上一步骤的点
+ step_name: 步骤名称
+ save_name: 保存文件名,默认为step_name
+ """
+ self.logger.info(f"开始生成{step_name}的可视化结果")
+
+ # 找出被过滤掉的点
+ filtered_files = set(
+ previous_points['file']) - set(current_points['file'])
+ filtered_points = previous_points[previous_points['file'].isin(
+ filtered_files)]
+
+ # 转换坐标到UTM
+ current_x, current_y = self._convert_to_utm(
+ current_points['lon'], current_points['lat'])
+ filtered_x, filtered_y = self._convert_to_utm(
+ filtered_points['lon'], filtered_points['lat'])
+
+        # Create the figure
+        plt.rcParams['font.sans-serif'] = ['SimHei']  # font with CJK support
+        plt.rcParams['axes.unicode_minus'] = False
+        plt.figure(figsize=(20, 16))
+
+        # Plot the retained points
+        plt.scatter(current_x, current_y,
+                    color='blue', label='Retained points',
+                    alpha=0.6, s=50)
+
+        # Plot the filtered-out points
+        if not filtered_points.empty:
+            plt.scatter(filtered_x, filtered_y,
+                        color='red', marker='x', label='Filtered points',
+                        alpha=0.6, s=100)
+
+        # Figure properties
+        plt.title(f"GPS points after {step_name}\n"
+                  f"(filtered: {len(filtered_points)}, retained: {len(current_points)})",
+                  fontsize=14)
+        plt.xlabel("Easting (m)", fontsize=12)
+        plt.ylabel("Northing (m)", fontsize=12)
+        plt.grid(True)
+
+        # Summary statistics
+        stats_text = (
+            f"Original points: {len(previous_points)}\n"
+            f"Filtered points: {len(filtered_points)}\n"
+            f"Retained points: {len(current_points)}\n"
+            f"Filter rate: {len(filtered_points)/len(previous_points)*100:.1f}%"
+        )
+        plt.figtext(0.02, 0.02, stats_text, fontsize=10,
+                    bbox=dict(facecolor='white', alpha=0.8))
+
+        # Legend
+        plt.legend(loc='upper right', fontsize=10)
+
+        # Adjust the layout
+        plt.tight_layout()
+
+        # Save the figure
+        save_name = save_name or step_name.lower().replace(' ', '_')
+        save_path = os.path.join(
+            self.output_dir, 'filter_imgs', f'filter_{save_name}.png')
+        plt.savefig(save_path, dpi=300, bbox_inches='tight')
+        plt.close()
+
+        self.logger.info(
+            f"Visualization for {step_name} saved to {save_path}\n"
+            f"filtered {len(filtered_points)} points, "
+            f"retained {len(current_points)} points, "
+            f"filter rate {len(filtered_points)/len(previous_points)*100:.1f}%"
+        )
+
+
+if __name__ == '__main__':
+    # Test code
+    import numpy as np
+    from datetime import datetime
+
+    # Build test data
+    np.random.seed(42)
+    n_points = 1000
+
+    # Generate random points
+    test_data = pd.DataFrame({
+        'lon': np.random.uniform(120, 121, n_points),
+        'lat': np.random.uniform(30, 31, n_points),
+        'file': [f'img_{i}.jpg' for i in range(n_points)],
+        'date': [datetime.now() for _ in range(n_points)]
+    })
+
+    # Randomly sample points to act as the filtered result
+    filtered_data = test_data.sample(n=800)
+
+    # Run the visualization (the figure goes to test_output/filter_imgs)
+    visualizer = FilterVisualizer('test_output')
+    os.makedirs(os.path.join('test_output', 'filter_imgs'), exist_ok=True)
+
+ visualizer.visualize_filter_step(
+ filtered_data,
+ test_data,
+ "Test Filter"
+ )