From e8178d191e0d3e4063639bceacd2e0e18ebf6096 Mon Sep 17 00:00:00 2001
From: weixin_46229132
Date: Thu, 12 Jun 2025 22:11:27 +0800
Subject: [PATCH] first commit

---
 .gitignore                 |   7 +
 README.md                  |  15 ++
 app_plugin.py              | 163 +++++++++++++++
 check_version.py           |  24 ++++
 filter/cluster_filter.py   |  77 ++++++++++
 main.py                    |  54 +++++++
 main.spec                  |  47 ++++++
 post_pro/conv_obj.py       | 265 ++++++++++++++++++++++++++++++++++
 post_pro/conv_obj2.py      | 253 ++++++++++++++++++++++++++++++++
 post_pro/merge_tif.py      | 285 +++++++++++++++++++++++++++++++++++++
 requirements.txt           |  15 ++
 trans_orthophoto.py        |  33 +++++
 utils/directory_manager.py |  81 +++++++++++
 utils/gps_extractor.py     |  65 +++++++++
 utils/grid_divider.py      | 266 ++++++++++++++++++++++++++++++++++
 utils/logger.py            |  36 +++++
 utils/odm_monitor.py       | 144 +++++++++++++++
 utils/visualizer.py        | 152 ++++++++++++++++++++
 18 files changed, 1982 insertions(+)
 create mode 100644 .gitignore
 create mode 100644 README.md
 create mode 100644 app_plugin.py
 create mode 100644 check_version.py
 create mode 100644 filter/cluster_filter.py
 create mode 100644 main.py
 create mode 100644 main.spec
 create mode 100644 post_pro/conv_obj.py
 create mode 100644 post_pro/conv_obj2.py
 create mode 100644 post_pro/merge_tif.py
 create mode 100644 requirements.txt
 create mode 100644 trans_orthophoto.py
 create mode 100644 utils/directory_manager.py
 create mode 100644 utils/gps_extractor.py
 create mode 100644 utils/grid_divider.py
 create mode 100644 utils/logger.py
 create mode 100644 utils/odm_monitor.py
 create mode 100644 utils/visualizer.py

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..775aa57
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,7 @@
+# Ignore all __pycache__ directories
+**/__pycache__/
+*.pyc
+*.pyo
+*.pyd
+
+test/
\ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..4441e42
--- /dev/null
+++ b/README.md
@@ -0,0 +1,15 @@
+# ODM_Pro
+UAV (drone) 3D reconstruction.
+
+## Install
+
+```bash
+pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+```
+
+## Building the exe
+
+1. Remember to change the Docker container name in odm_monitor.py to uav.
+2. Remove the compressed option from osgconv, otherwise it raises an error.
+3. Build with `pyinstaller main.spec`.
+4. 
测试执行指令`uav.exe --image_dir E:\datasets\UAV\134\project\images --output_dir G:\ODM_output\134` diff --git a/app_plugin.py b/app_plugin.py new file mode 100644 index 0000000..b73aca6 --- /dev/null +++ b/app_plugin.py @@ -0,0 +1,163 @@ +import os +import shutil +from dataclasses import dataclass +from typing import Dict, Tuple +import pandas as pd + +from filter.cluster_filter import GPSCluster +from utils.directory_manager import DirectoryManager +from utils.odm_monitor import ODMProcessMonitor +from utils.gps_extractor import GPSExtractor +from utils.grid_divider import GridDivider +from utils.logger import setup_logger +from utils.visualizer import FilterVisualizer +from post_pro.merge_tif import MergeTif +from post_pro.conv_obj2 import ConvertOBJ + + +@dataclass +class ProcessConfig: + """预处理配置类""" + image_dir: str + output_dir: str + # 聚类过滤参数 + cluster_eps: float = 0.01 + cluster_min_samples: int = 5 + + # 网格划分参数 + grid_overlap: float = 0.05 + grid_size: float = 500 + + mode: str = "三维模式" + + # ODM参数 + feature_type: str = "sift" + orthophoto_resolution: float = 5 + + +class ODM_Plugin: + def __init__(self, config): + self.config = config + + # 初始化目录管理器 + self.dir_manager = DirectoryManager(config) + # 清理并重建输出目录 + self.dir_manager.clean_output_dir() + self.dir_manager.setup_output_dirs() + # 检查磁盘空间 + self.dir_manager.check_disk_space() + + # 初始化其他组件 + self.logger = setup_logger(config.output_dir) + self.gps_points = pd.DataFrame(columns=["file", "lat", "lon"]) + self.odm_monitor = ODMProcessMonitor( + config.output_dir, mode=config.mode, config=config) + self.visualizer = FilterVisualizer(config.output_dir) + + def extract_gps(self) -> pd.DataFrame: + """提取GPS数据""" + self.logger.info("开始提取GPS数据") + extractor = GPSExtractor(self.config.image_dir) + self.gps_points = extractor.extract_all_gps() + self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点") + + def cluster(self): + """使用DBSCAN对GPS点进行聚类,只保留最大的类""" + previous_points = self.gps_points.copy() + clusterer = GPSCluster( + self.gps_points, + eps=self.config.cluster_eps, + min_samples=self.config.cluster_min_samples + ) + + self.clustered_points = clusterer.fit() + self.gps_points = clusterer.get_cluster_stats(self.clustered_points) + + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "1-Clustering") + + def divide_grids(self) -> Dict[tuple, pd.DataFrame]: + """划分网格 + Returns: + - grid_points: 网格点数据字典 + - translations: 网格平移量字典 + """ + grid_divider = GridDivider( + overlap=self.config.grid_overlap, + grid_size=self.config.grid_size, + output_dir=self.config.output_dir + ) + grids, grid_points = grid_divider.adjust_grid_size_and_overlap( + self.gps_points + ) + grid_divider.visualize_grids(self.gps_points, grids) + if len(grids) >= 20: + self.logger.warning("网格数量已超过20, 需要人工调整分区") + + return grid_points + + def copy_images(self, grid_points: Dict[tuple, pd.DataFrame]): + """复制图像到目标文件夹""" + self.logger.info("开始复制图像文件") + + for grid_id, points in grid_points.items(): + output_dir = os.path.join( + self.config.output_dir, + f"grid_{grid_id[0]}_{grid_id[1]}", + "project", + "images" + ) + + os.makedirs(output_dir, exist_ok=True) + + for point in points: + src = os.path.join(self.config.image_dir, point["file"]) + dst = os.path.join(output_dir, point["file"]) + shutil.copy(src, dst) + self.logger.info( + f"网格 ({grid_id[0]},{grid_id[1]}) 包含 {len(points)} 张图像") + + def merge_tif(self, grid_lt): + """合并所有网格的影像产品""" + self.logger.info("开始合并所有影像产品") + merger = MergeTif(self.config.output_dir) + 
merger.merge_orthophoto(grid_lt) + + def convert_obj(self, grid_lt): + """转换OBJ模型""" + self.logger.info("开始转换OBJ模型") + converter = ConvertOBJ(self.config.output_dir) + converter.convert_grid_obj(grid_lt) + + def post_process(self, successful_grid_lt: list, grid_points: Dict[tuple, pd.DataFrame]): + """后处理:合并或复制处理结果""" + if len(successful_grid_lt) < len(grid_points): + self.logger.warning( + f"有 {len(grid_points) - len(successful_grid_lt)} 个网格处理失败," + f"将只合并成功处理的 {len(successful_grid_lt)} 个网格" + ) + + self.merge_tif(successful_grid_lt) + if self.config.mode == "三维模式": + self.convert_obj(successful_grid_lt) + else: + pass + + def process(self): + """执行完整的预处理流程""" + try: + self.extract_gps() + self.cluster() + grid_points = self.divide_grids() + self.copy_images(grid_points) + self.logger.info("预处理任务完成") + + successful_grid_lt = self.odm_monitor.process_all_grids( + grid_points) + + self.post_process(successful_grid_lt, grid_points) + self.logger.info("重建任务完成") + + except Exception as e: + self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) + raise diff --git a/check_version.py b/check_version.py new file mode 100644 index 0000000..f0652fa --- /dev/null +++ b/check_version.py @@ -0,0 +1,24 @@ +import importlib +import pkg_resources + +def check_versions(requirements_file): + with open(requirements_file, 'r', encoding='utf-8') as f: + packages = [line.strip() for line in f if line.strip() and not line.startswith('#') and not line.startswith('//')] + + print(f"{'Package':<20} {'Installed Version':<20}") + print('-' * 40) + for pkg in packages: + try: + # 有些包名和import名不同,优先用pkg_resources + version = pkg_resources.get_distribution(pkg).version + print(f"{pkg:<20} {version:<20}") + except Exception: + try: + mod = importlib.import_module(pkg) + version = getattr(mod, '__version__', 'Unknown') + print(f"{pkg:<20} {version:<20}") + except Exception: + print(f"{pkg:<20} {'Not Installed':<20}") + +if __name__ == "__main__": + check_versions("requirements.txt") \ No newline at end of file diff --git a/filter/cluster_filter.py b/filter/cluster_filter.py new file mode 100644 index 0000000..473da7d --- /dev/null +++ b/filter/cluster_filter.py @@ -0,0 +1,77 @@ +from sklearn.cluster import DBSCAN +from sklearn.preprocessing import StandardScaler +import os +import logging + + +class GPSCluster: + def __init__(self, gps_points, eps=0.01, min_samples=3): + """ + 初始化GPS聚类器 + + 参数: + eps: DBSCAN的邻域半径参数 + min_samples: DBSCAN的最小样本数参数 + """ + self.eps = eps + self.min_samples = min_samples + self.dbscan = DBSCAN(eps=eps, min_samples=min_samples) + self.scaler = StandardScaler() + self.gps_points = gps_points + self.logger = logging.getLogger('UAV_Preprocess.GPSCluster') + + def fit(self): + """ + 对GPS点进行聚类,只保留最大的类 + + 参数: + gps_points: 包含'lat'和'lon'列的DataFrame + + 返回: + 带有聚类标签的DataFrame,其中最大类标记为1,其他点标记为-1 + """ + self.logger.info("开始聚类") + # 提取经纬度数据 + X = self.gps_points[["lon", "lat"]].values + + # # 数据标准化 + # X_scaled = self.scaler.fit_transform(X) + + # 执行DBSCAN聚类 + labels = self.dbscan.fit_predict(X) + + # 找出最大类的标签(排除噪声点-1) + unique_labels = [l for l in set(labels) if l != -1] + if unique_labels: # 如果有聚类 + label_counts = [(l, sum(labels == l)) for l in unique_labels] + largest_label = max(label_counts, key=lambda x: x[1])[0] + + # 将最大类标记为1,其他都标记为-1 + new_labels = (labels == largest_label).astype(int) + new_labels[new_labels == 0] = -1 + else: # 如果没有聚类,全部标记为-1 + new_labels = labels + + # 将聚类结果添加到原始数据中 + result_df = self.gps_points.copy() + result_df["cluster"] = new_labels + + return result_df + + def 
get_cluster_stats(self, clustered_points): + """ + 获取聚类统计信息 + + 参数: + clustered_points: 带有聚类标签的DataFrame + + 返回: + 聚类统计信息的字典 + """ + main_cluster = clustered_points[clustered_points["cluster"] == 1] + noise_cluster = clustered_points[clustered_points["cluster"] == -1] + + self.logger.info(f"聚类完成:主要类别包含 {len(main_cluster)} 个点," + f"噪声点 {len(noise_cluster)} 个") + + return main_cluster diff --git a/main.py b/main.py new file mode 100644 index 0000000..ccf4919 --- /dev/null +++ b/main.py @@ -0,0 +1,54 @@ +import argparse +from app_plugin import ProcessConfig, ODM_Plugin + + +def parse_args(): + parser = argparse.ArgumentParser(description='ODM预处理工具') + + # 必需参数 + parser.add_argument('--image_dir', required=True, help='输入图片目录路径') + parser.add_argument('--output_dir', required=True, help='输出目录路径') + # parser.add_argument('--image_dir', default=r'G:\UAV_data\test_31\project\images', help='输入图片目录路径') + # parser.add_argument('--output_dir', default=r'G:\ODM_output\test_31', help='输出目录路径') + + # 可选参数 + parser.add_argument('--mode', default='三维模式', + choices=['快拼模式', '三维模式'], help='处理模式') + parser.add_argument('--grid_size', type=float, default=800, help='网格大小(米)') + parser.add_argument('--grid_overlap', type=float, + default=0.1, help='网格重叠率') + + # ODM参数 + parser.add_argument('--feature_type', default='sift', help='特征类型') + parser.add_argument('--orthophoto_resolution', + type=float, default=5, help='正射影像分辨率') + + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + # 创建配置 + config = ProcessConfig( + image_dir=args.image_dir, + output_dir=args.output_dir, + mode=args.mode, + grid_size=args.grid_size, + grid_overlap=args.grid_overlap, + feature_type=args.feature_type, + orthophoto_resolution=args.orthophoto_resolution, + + # 其他参数使用默认值 + cluster_eps=0.01, + cluster_min_samples=5, + ) + + # 创建处理器并执行 + processor = ODM_Plugin(config) + processor.process() + + +if __name__ == '__main__': + main() diff --git a/main.spec b/main.spec new file mode 100644 index 0000000..126e1d5 --- /dev/null +++ b/main.spec @@ -0,0 +1,47 @@ +# -*- mode: python ; coding: utf-8 -*- + + +a = Analysis( + ['main.py'], + pathex=[], + binaries=[], + datas=[], + hiddenimports=[ + 'rasterio.sample', + 'rasterio.vrt', + 'rasterio._io', + 'rasterio.enums', + 'rasterio.errors', + 'rasterio.transform', + 'rasterio.crs', + 'rasterio.warp' + ], + hookspath=[], + hooksconfig={}, + runtime_hooks=[], + excludes=[], + noarchive=False, + optimize=0, +) +pyz = PYZ(a.pure) + +exe = EXE( + pyz, + a.scripts, + a.binaries, + a.datas, + [], + name='uav', + debug=False, + bootloader_ignore_signals=False, + strip=False, + upx=True, + upx_exclude=[], + runtime_tmpdir=None, + console=True, + disable_windowed_traceback=False, + argv_emulation=False, + target_arch=None, + codesign_identity=None, + entitlements_file=None, +) diff --git a/post_pro/conv_obj.py b/post_pro/conv_obj.py new file mode 100644 index 0000000..a7e9975 --- /dev/null +++ b/post_pro/conv_obj.py @@ -0,0 +1,265 @@ +import os +import subprocess +import json +import shutil +import logging +from pyproj import Transformer +import cv2 + + +class ConvertOBJ: + def __init__(self, output_dir: str): + self.output_dir = output_dir + # 用于存储所有grid的UTM范围 + self.ref_east = float('inf') + self.ref_north = float('inf') + # 初始化UTM到WGS84的转换器 + self.transformer = Transformer.from_crs( + "EPSG:32649", "EPSG:4326", always_xy=True) + self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ') + + def convert_grid_obj(self, grid_lt): + """转换每个网格的OBJ文件为OSGB格式""" + 
os.makedirs(os.path.join(self.output_dir, + "osgb", "Data"), exist_ok=True) + + # 以第一个grid的UTM坐标作为参照系 + first_grid_id = grid_lt[0] + first_grid_dir = os.path.join( + self.output_dir, + f"grid_{first_grid_id[0]}_{first_grid_id[1]}", + "project" + ) + log_file = os.path.join( + first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt") + self.ref_east, self.ref_north = self.read_utm_offset(log_file) + + for grid_id in grid_lt: + try: + self._convert_single_grid(grid_id) + except Exception as e: + self.logger.error(f"网格 {grid_id} 转换失败: {str(e)}") + + self._create_merged_metadata() + + def _convert_single_grid(self, grid_id): + """转换单个网格的OBJ文件""" + # 构建相关路径 + grid_name = f"grid_{grid_id[0]}_{grid_id[1]}" + project_dir = os.path.join(self.output_dir, grid_name, "project") + texturing_dir = os.path.join(project_dir, "odm_texturing") + texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst") + split_obj_dir = os.path.join(texturing_dst_dir, "split_obj") + log_file = os.path.join( + project_dir, "odm_orthophoto", "odm_orthophoto_log.txt") + os.makedirs(texturing_dst_dir, exist_ok=True) + + # 修改obj文件z坐标的值 + min_25d_z = self.get_min_z_from_obj(os.path.join( + project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj')) + self.modify_z_in_obj(texturing_dir, min_25d_z) + + # 在新文件夹下,利用UTM偏移量,修改obj文件顶点坐标,纹理文件下采样 + utm_offset = self.read_utm_offset(log_file) + modified_obj = self.modify_obj_coordinates( + texturing_dir, texturing_dst_dir, utm_offset) + self.downsample_texture(texturing_dir, texturing_dst_dir) + + # 将obj文件进行切片 + self.logger.info(f"开始切片网格 {grid_id} 的OBJ文件") + os.makedirs(split_obj_dir) + cmd = ( + f"D:\\software\\Obj2Tiles\\Obj2Tiles.exe --stage Splitting --lods 1 --divisions 2 " + f"{modified_obj} {split_obj_dir}" + ) + subprocess.run(cmd, check=True) + + # 执行格式转换,Linux下osgconv有问题,记得注释掉 + self.logger.info(f"开始转换网格 {grid_id} 的OBJ文件") + # 先获取split_obj_dir下的所有obj文件 + obj_lod_dir = os.path.join(split_obj_dir, "LOD-0") + obj_files = [f for f in os.listdir( + obj_lod_dir) if f.endswith('.obj')] + for obj_file in obj_files: + obj_path = os.path.join(obj_lod_dir, obj_file) + osgb_file = os.path.splitext(obj_file)[0] + '.osgb' + osgb_path = os.path.join(split_obj_dir, osgb_file) + # 执行 osgconv 命令 + subprocess.run(['osgconv', obj_path, osgb_path], check=True) + + # 创建OSGB目录结构,复制文件 + osgb_base_dir = os.path.join(self.output_dir, "osgb") + data_dir = os.path.join(osgb_base_dir, "Data") + for obj_file in obj_files: + obj_file_name = os.path.splitext(obj_file)[0] + tile_dirs = os.path.join( + data_dir, f"grid_{grid_id[0]}_{grid_id[1]}_{obj_file_name}") + os.makedirs(tile_dirs, exist_ok=True) + shutil.copy2(os.path.join( + split_obj_dir, obj_file_name+".osgb"), tile_dirs) + os.rename(os.path.join(tile_dirs, obj_file_name+".osgb"), + os.path.join(tile_dirs, f"grid_{grid_id[0]}_{grid_id[1]}_{obj_file_name}.osgb")) + + def _create_merged_metadata(self): + """创建合并后的metadata.xml文件""" + # 转换为WGS84经纬度 + center_lon, center_lat = self.transformer.transform( + self.ref_east, self.ref_north) + metadata_content = f""" + + EPSG:4326 + {center_lon},{center_lat},0 + + Visible + + """ + + metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml") + with open(metadata_file, 'w', encoding='utf-8') as f: + f.write(metadata_content) + + def read_utm_offset(self, log_file: str) -> tuple: + """读取UTM偏移量""" + try: + east_offset = None + north_offset = None + + with open(log_file, 'r') as f: + lines = f.readlines() + for i, line in enumerate(lines): + if 'utm_north_offset' in line and i + 1 < 
len(lines): + north_offset = float(lines[i + 1].strip()) + elif 'utm_east_offset' in line and i + 1 < len(lines): + east_offset = float(lines[i + 1].strip()) + + if east_offset is None or north_offset is None: + raise ValueError("未找到UTM偏移量") + + return east_offset, north_offset + except Exception as e: + self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}") + raise + + def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str: + """修改obj文件中的顶点坐标,使用相对坐标系""" + obj_file = os.path.join( + texturing_dir, "odm_textured_model_modified.obj") + obj_dst_file = os.path.join( + texturing_dst_dir, "odm_textured_model_geo_utm.obj") + if not os.path.exists(obj_file): + raise FileNotFoundError(f"找不到OBJ文件: {obj_file}") + shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"), + os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl")) + east_offset, north_offset = utm_offset + self.logger.info( + f"UTM坐标偏移:{east_offset - self.ref_east}, {north_offset - self.ref_north}") + + try: + with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out: + for line in f_in: + if line.startswith('v '): + # 处理顶点坐标行 + parts = line.strip().split() + # 使用相对于整体最小UTM坐标的偏移 + x = float(parts[1]) + (east_offset - self.ref_east) + y = float(parts[2]) + (north_offset - self.ref_north) + z = float(parts[3]) + f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n') + elif line.startswith('vn '): # 处理法线向量 + parts = line.split() + nx = float(parts[1]) + ny = float(parts[2]) + nz = float(parts[3]) + # 同步反转法线的 Y 轴 + new_line = f"vn {nx} {nz} {-ny}\n" + f_out.write(new_line) + else: + # 其他行直接写入 + f_out.write(line) + + return obj_dst_file + except Exception as e: + self.logger.error(f"修改obj坐标时发生错误: {str(e)}") + raise + + def downsample_texture(self, src_dir: str, dst_dir: str): + """复制并重命名纹理文件,对大于100MB的文件进行多次下采样,直到文件小于100MB + Args: + src_dir: 源纹理目录 + dst_dir: 目标纹理目录 + """ + for file in os.listdir(src_dir): + if file.lower().endswith(('.png')): + src_path = os.path.join(src_dir, file) + dst_path = os.path.join(dst_dir, file) + + # 检查文件大小(以字节为单位) + file_size = os.path.getsize(src_path) + if file_size <= 100 * 1024 * 1024: # 如果文件小于等于100MB,直接复制 + shutil.copy2(src_path, dst_path) + else: + # 文件大于100MB,进行下采样 + img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED) + if_first_ds = True + while file_size > 100 * 1024 * 1024: # 大于100MB + self.logger.info(f"纹理文件 {file} 大于100MB,进行下采样") + + if if_first_ds: + # 计算新的尺寸(长宽各变为1/4) + new_size = (img.shape[1] // 4, + img.shape[0] // 4) # 逐步减小尺寸 + # 使用双三次插值进行下采样 + resized_img = cv2.resize( + img, new_size, interpolation=cv2.INTER_CUBIC) + if_first_ds = False + else: + # 计算新的尺寸(长宽各变为1/2) + new_size = (img.shape[1] // 2, + img.shape[0] // 2) # 逐步减小尺寸 + # 使用双三次插值进行下采样 + resized_img = cv2.resize( + img, new_size, interpolation=cv2.INTER_CUBIC) + + # 更新文件路径为下采样后的路径 + cv2.imwrite(dst_path, resized_img, [ + cv2.IMWRITE_PNG_COMPRESSION, 9]) + + # 更新文件大小和图像 + file_size = os.path.getsize(dst_path) + img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED) + self.logger.info( + f"下采样后文件大小: {file_size / (1024 * 1024):.2f} MB") + + def get_min_z_from_obj(self, file_path): + min_z = float('inf') # 初始值设为无穷大 + with open(file_path, 'r') as obj_file: + for line in obj_file: + # 检查每一行是否是顶点定义(以 'v ' 开头) + if line.startswith('v '): + # 获取顶点坐标 + parts = line.split() + # 将z值转换为浮动数字 + z = float(parts[3]) + # 更新最小z值 + if z < min_z: + min_z = z + return min_z + + def modify_z_in_obj(self, texturing_dir, min_25d_z): + obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj') + 
output_file = os.path.join( + texturing_dir, 'odm_textured_model_modified.obj') + with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out: + for line in f_in: + if line.startswith('v '): # 顶点坐标行 + parts = line.strip().split() + x = float(parts[1]) + y = float(parts[2]) + z = float(parts[3]) + + if z < min_25d_z: + z = min_25d_z + + f_out.write(f"v {x} {y} {z}\n") + else: + f_out.write(line) diff --git a/post_pro/conv_obj2.py b/post_pro/conv_obj2.py new file mode 100644 index 0000000..d0188f2 --- /dev/null +++ b/post_pro/conv_obj2.py @@ -0,0 +1,253 @@ +import os +import subprocess +import json +import shutil +import logging +from pyproj import Transformer +import cv2 + + +class ConvertOBJ: + def __init__(self, output_dir: str): + self.output_dir = output_dir + # 用于存储所有grid的UTM范围 + self.ref_east = float('inf') + self.ref_north = float('inf') + # 初始化UTM到WGS84的转换器 + self.transformer = Transformer.from_crs( + "EPSG:32649", "EPSG:4326", always_xy=True) + self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ') + + def convert_grid_obj(self, grid_lt): + """转换每个网格的OBJ文件为OSGB格式""" + os.makedirs(os.path.join(self.output_dir, + "osgb", "Data"), exist_ok=True) + + # 以第一个grid的UTM坐标作为参照系 + first_grid_id = grid_lt[0] + first_grid_dir = os.path.join( + self.output_dir, + f"grid_{first_grid_id[0]}_{first_grid_id[1]}", + "project" + ) + log_file = os.path.join( + first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt") + self.ref_east, self.ref_north = self.read_utm_offset(log_file) + + for grid_id in grid_lt: + try: + self._convert_single_grid(grid_id) + except Exception as e: + self.logger.error(f"网格 {grid_id} 转换失败: {str(e)}") + + self._create_merged_metadata() + + def _convert_single_grid(self, grid_id): + """转换单个网格的OBJ文件""" + # 构建相关路径 + grid_name = f"grid_{grid_id[0]}_{grid_id[1]}" + project_dir = os.path.join(self.output_dir, grid_name, "project") + texturing_dir = os.path.join(project_dir, "odm_texturing") + texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst") + opensfm_dir = os.path.join(project_dir, "opensfm") + log_file = os.path.join( + project_dir, "odm_orthophoto", "odm_orthophoto_log.txt") + os.makedirs(texturing_dst_dir, exist_ok=True) + + # 修改obj文件z坐标的值 + min_25d_z = self.get_min_z_from_obj(os.path.join( + project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj')) + self.modify_z_in_obj(texturing_dir, min_25d_z) + + # 在新文件夹下,利用UTM偏移量,修改obj文件顶点坐标,纹理文件下采样 + utm_offset = self.read_utm_offset(log_file) + modified_obj = self.modify_obj_coordinates( + texturing_dir, texturing_dst_dir, utm_offset) + self.downsample_texture(texturing_dir, texturing_dst_dir) + + # 执行格式转换,Linux下osgconv有问题,记得注释掉 + self.logger.info(f"开始转换网格 {grid_id} 的OBJ文件") + output_osgb = os.path.join(texturing_dst_dir, "Tile.osgb") + cmd = ( + f"osgconv {modified_obj} {output_osgb} " + f"--smooth --fix-transparency " + ) + self.logger.info(f"执行osgconv命令:{cmd}") + + try: + subprocess.run(cmd, shell=True, check=True, cwd=texturing_dir) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"OSGB转换失败: {str(e)}") + + # 创建OSGB目录结构,复制文件 + osgb_base_dir = os.path.join(self.output_dir, "osgb") + data_dir = os.path.join(osgb_base_dir, "Data") + tile_dir = os.path.join(data_dir, f"Tile_{grid_id[0]}_{grid_id[1]}") + os.makedirs(tile_dir, exist_ok=True) + target_osgb = os.path.join( + tile_dir, f"Tile_{grid_id[0]}_{grid_id[1]}.osgb") + shutil.copy2(output_osgb, target_osgb) + + def _create_merged_metadata(self): + """创建合并后的metadata.xml文件""" + # 转换为WGS84经纬度 + center_lon, center_lat = 
self.transformer.transform( + self.ref_east, self.ref_north) + metadata_content = f""" + + EPSG:4326 + {center_lon},{center_lat},0 + + Visible + + """ + + metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml") + with open(metadata_file, 'w', encoding='utf-8') as f: + f.write(metadata_content) + + def read_utm_offset(self, log_file: str) -> tuple: + """读取UTM偏移量""" + try: + east_offset = None + north_offset = None + + with open(log_file, 'r') as f: + lines = f.readlines() + for i, line in enumerate(lines): + if 'utm_north_offset' in line and i + 1 < len(lines): + north_offset = float(lines[i + 1].strip()) + elif 'utm_east_offset' in line and i + 1 < len(lines): + east_offset = float(lines[i + 1].strip()) + + if east_offset is None or north_offset is None: + raise ValueError("未找到UTM偏移量") + + return east_offset, north_offset + except Exception as e: + self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}") + raise + + def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str: + """修改obj文件中的顶点坐标,使用相对坐标系""" + obj_file = os.path.join( + texturing_dir, "odm_textured_model_modified.obj") + obj_dst_file = os.path.join( + texturing_dst_dir, "odm_textured_model_geo_utm.obj") + if not os.path.exists(obj_file): + raise FileNotFoundError(f"找不到OBJ文件: {obj_file}") + shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"), + os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl")) + east_offset, north_offset = utm_offset + self.logger.info( + f"UTM坐标偏移:{east_offset - self.ref_east}, {north_offset - self.ref_north}") + + try: + with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out: + for line in f_in: + if line.startswith('v '): + # 处理顶点坐标行 + parts = line.strip().split() + # 使用相对于整体最小UTM坐标的偏移 + x = float(parts[1]) + (east_offset - self.ref_east) + y = float(parts[2]) + (north_offset - self.ref_north) + z = float(parts[3]) + f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n') + elif line.startswith('vn '): # 处理法线向量 + parts = line.split() + nx = float(parts[1]) + ny = float(parts[2]) + nz = float(parts[3]) + # 同步反转法线的 Y 轴 + new_line = f"vn {nx} {nz} {-ny}\n" + f_out.write(new_line) + else: + # 其他行直接写入 + f_out.write(line) + + return obj_dst_file + except Exception as e: + self.logger.error(f"修改obj坐标时发生错误: {str(e)}") + raise + + def downsample_texture(self, src_dir: str, dst_dir: str): + """复制并重命名纹理文件,对大于100MB的文件进行多次下采样,直到文件小于100MB + Args: + src_dir: 源纹理目录 + dst_dir: 目标纹理目录 + """ + for file in os.listdir(src_dir): + if file.lower().endswith(('.png')): + src_path = os.path.join(src_dir, file) + dst_path = os.path.join(dst_dir, file) + + # 检查文件大小(以字节为单位) + file_size = os.path.getsize(src_path) + if file_size <= 100 * 1024 * 1024: # 如果文件小于等于100MB,直接复制 + shutil.copy2(src_path, dst_path) + else: + # 文件大于100MB,进行下采样 + img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED) + if_first_ds = True + while file_size > 100 * 1024 * 1024: # 大于100MB + self.logger.info(f"纹理文件 {file} 大于100MB,进行下采样") + + if if_first_ds: + # 计算新的尺寸(长宽各变为1/4) + new_size = (img.shape[1] // 4, + img.shape[0] // 4) # 逐步减小尺寸 + # 使用双三次插值进行下采样 + resized_img = cv2.resize( + img, new_size, interpolation=cv2.INTER_CUBIC) + if_first_ds = False + else: + # 计算新的尺寸(长宽各变为1/2) + new_size = (img.shape[1] // 2, + img.shape[0] // 2) # 逐步减小尺寸 + # 使用双三次插值进行下采样 + resized_img = cv2.resize( + img, new_size, interpolation=cv2.INTER_CUBIC) + + # 更新文件路径为下采样后的路径 + cv2.imwrite(dst_path, resized_img, [ + cv2.IMWRITE_PNG_COMPRESSION, 9]) + + # 更新文件大小和图像 + file_size = os.path.getsize(dst_path) + img = 
cv2.imread(dst_path, cv2.IMREAD_UNCHANGED) + self.logger.info( + f"下采样后文件大小: {file_size / (1024 * 1024):.2f} MB") + + def get_min_z_from_obj(self, file_path): + min_z = float('inf') # 初始值设为无穷大 + with open(file_path, 'r') as obj_file: + for line in obj_file: + # 检查每一行是否是顶点定义(以 'v ' 开头) + if line.startswith('v '): + # 获取顶点坐标 + parts = line.split() + # 将z值转换为浮动数字 + z = float(parts[3]) + # 更新最小z值 + if z < min_z: + min_z = z + return min_z + + def modify_z_in_obj(self, texturing_dir, min_25d_z): + obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj') + output_file = os.path.join( + texturing_dir, 'odm_textured_model_modified.obj') + with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out: + for line in f_in: + if line.startswith('v '): # 顶点坐标行 + parts = line.strip().split() + x = float(parts[1]) + y = float(parts[2]) + z = float(parts[3]) + + if z < min_25d_z: + z = min_25d_z + + f_out.write(f"v {x} {y} {z}\n") + else: + f_out.write(line) diff --git a/post_pro/merge_tif.py b/post_pro/merge_tif.py new file mode 100644 index 0000000..c02d32b --- /dev/null +++ b/post_pro/merge_tif.py @@ -0,0 +1,285 @@ +import logging +import os +from typing import Dict +import rasterio +from rasterio.mask import mask +from rasterio.transform import Affine, rowcol +import fiona +from edt import edt +import numpy as np +import math + + +class MergeTif: + def __init__(self, output_dir: str): + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.MergeTif') + + def merge_orthophoto(self, grid_lt): + """合并网格的正射影像""" + try: + all_orthos_and_ortho_cuts = [] + for grid_id in grid_lt: + grid_ortho_dir = os.path.join( + self.output_dir, + f"grid_{grid_id[0]}_{grid_id[1]}", + "project", + "odm_orthophoto", + ) + tif_path = os.path.join(grid_ortho_dir, "odm_orthophoto.tif") + tif_mask = os.path.join(grid_ortho_dir, "cutline.gpkg") + output_cut_tif = os.path.join( + grid_ortho_dir, "odm_orthophoto_cut.tif") + output_feathered_tif = os.path.join( + grid_ortho_dir, "odm_orthophoto_feathered.tif") + + self.compute_mask_raster( + tif_path, tif_mask, output_cut_tif, blend_distance=20) + self.feather_raster( + tif_path, output_feathered_tif, blend_distance=20) + all_orthos_and_ortho_cuts.append( + [output_feathered_tif, output_cut_tif]) + + orthophoto_vars = { + 'TILED': 'NO', + 'COMPRESS': False, + 'PREDICTOR': '1', + 'BIGTIFF': 'IF_SAFER', + 'BLOCKXSIZE': 512, + 'BLOCKYSIZE': 512, + 'NUM_THREADS': 15 + } + self.merge(all_orthos_and_ortho_cuts, os.path.join( + self.output_dir, "orthophoto.tif"), orthophoto_vars) + self.logger.info("所有产品合并完成") + + except Exception as e: + self.logger.error(f"产品合并过程中发生错误: {str(e)}", exc_info=True) + raise + + def compute_mask_raster(self, input_raster, vector_mask, output_raster, blend_distance=20, only_max_coords_feature=False): + if not os.path.exists(input_raster): + print("Cannot mask raster, %s does not exist" % input_raster) + return + + if not os.path.exists(vector_mask): + print("Cannot mask raster, %s does not exist" % vector_mask) + return + + print("Computing mask raster: %s" % output_raster) + + with rasterio.open(input_raster, 'r') as rast: + with fiona.open(vector_mask) as src: + burn_features = src + + if only_max_coords_feature: + max_coords_count = 0 + max_coords_feature = None + for feature in src: + if feature is not None: + # No complex shapes + if len(feature['geometry']['coordinates'][0]) > max_coords_count: + max_coords_count = len( + feature['geometry']['coordinates'][0]) + max_coords_feature = feature + if max_coords_feature 
is not None: + burn_features = [max_coords_feature] + + shapes = [feature["geometry"] for feature in burn_features] + out_image, out_transform = mask(rast, shapes, nodata=0) + + if blend_distance > 0: + if out_image.shape[0] >= 4: + # alpha_band = rast.dataset_mask() + alpha_band = out_image[-1] + dist_t = edt(alpha_band, black_border=True, parallel=0) + dist_t[dist_t <= blend_distance] /= blend_distance + dist_t[dist_t > blend_distance] = 1 + np.multiply(alpha_band, dist_t, + out=alpha_band, casting="unsafe") + else: + print( + "%s does not have an alpha band, cannot blend cutline!" % input_raster) + + with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst: + dst.colorinterp = rast.colorinterp + dst.write(out_image) + + return output_raster + + def feather_raster(self, input_raster, output_raster, blend_distance=20): + if not os.path.exists(input_raster): + print("Cannot feather raster, %s does not exist" % input_raster) + return + + print("Computing feather raster: %s" % output_raster) + + with rasterio.open(input_raster, 'r') as rast: + out_image = rast.read() + if blend_distance > 0: + if out_image.shape[0] >= 4: + alpha_band = out_image[-1] + dist_t = edt(alpha_band, black_border=True, parallel=0) + dist_t[dist_t <= blend_distance] /= blend_distance + dist_t[dist_t > blend_distance] = 1 + np.multiply(alpha_band, dist_t, + out=alpha_band, casting="unsafe") + else: + print( + "%s does not have an alpha band, cannot feather raster!" % input_raster) + + with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst: + dst.colorinterp = rast.colorinterp + dst.write(out_image) + + return output_raster + + def merge(self, input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}): + """ + Based on https://github.com/mapbox/rio-merge-rgba/ + Merge orthophotos around cutlines using a blend buffer. + """ + inputs = [] + bounds = None + precision = 7 + + for o, c in input_ortho_and_ortho_cuts: + inputs.append((o, c)) + + with rasterio.open(inputs[0][0]) as first: + res = first.res + dtype = first.dtypes[0] + profile = first.profile + num_bands = first.meta['count'] - 1 # minus alpha + colorinterp = first.colorinterp + + print("%s valid orthophoto rasters to merge" % len(inputs)) + sources = [(rasterio.open(o), rasterio.open(c)) for o, c in inputs] + + # scan input files. + # while we're at it, validate assumptions about inputs + xs = [] + ys = [] + for src, _ in sources: + left, bottom, right, top = src.bounds + xs.extend([left, right]) + ys.extend([bottom, top]) + if src.profile["count"] < 2: + raise ValueError("Inputs must be at least 2-band rasters") + dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys) + print("Output bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n)) + + output_transform = Affine.translation(dst_w, dst_n) + output_transform *= Affine.scale(res[0], -res[1]) + + # Compute output array shape. We guarantee it will cover the output + # bounds completely. + output_width = int(math.ceil((dst_e - dst_w) / res[0])) + output_height = int(math.ceil((dst_n - dst_s) / res[1])) + + # Adjust bounds to fit. 
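+ # output_transform maps a pixel index (col, row) to map coordinates, so applying it to
+ # (output_width, output_height) yields the lower-right corner of the ceil-rounded raster;
+ # dst_e/dst_s are expanded so the declared bounds match the output grid exactly.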
+ dst_e, dst_s = output_transform * (output_width, output_height) + print("Output width: %d, height: %d" % + (output_width, output_height)) + print("Adjusted bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n)) + + profile["transform"] = output_transform + profile["height"] = output_height + profile["width"] = output_width + profile["tiled"] = orthophoto_vars.get('TILED', 'YES') == 'YES' + profile["blockxsize"] = orthophoto_vars.get('BLOCKXSIZE', 512) + profile["blockysize"] = orthophoto_vars.get('BLOCKYSIZE', 512) + profile["compress"] = orthophoto_vars.get('COMPRESS', 'LZW') + profile["predictor"] = orthophoto_vars.get('PREDICTOR', '2') + profile["bigtiff"] = orthophoto_vars.get('BIGTIFF', 'IF_SAFER') + profile.update() + + # create destination file + with rasterio.open(output_orthophoto, "w", **profile) as dstrast: + dstrast.colorinterp = colorinterp + for idx, dst_window in dstrast.block_windows(): + left, bottom, right, top = dstrast.window_bounds(dst_window) + + blocksize = dst_window.width + dst_rows, dst_cols = (dst_window.height, dst_window.width) + + # initialize array destined for the block + dst_count = first.count + dst_shape = (dst_count, dst_rows, dst_cols) + + dstarr = np.zeros(dst_shape, dtype=dtype) + + # First pass, write all rasters naively without blending + for src, _ in sources: + src_window = tuple(zip(rowcol( + src.transform, left, top, op=round, precision=precision + ), rowcol( + src.transform, right, bottom, op=round, precision=precision + ))) + + temp = np.zeros(dst_shape, dtype=dtype) + temp = src.read( + out=temp, window=src_window, boundless=True, masked=False + ) + + # pixels without data yet are available to write + write_region = np.logical_and( + (dstarr[-1] == 0), (temp[-1] != 0) # 0 is nodata + ) + np.copyto(dstarr, temp, where=write_region) + + # check if dest has any nodata pixels available + if np.count_nonzero(dstarr[-1]) == blocksize: + break + + # Second pass, write all feathered rasters + # blending the edges + for src, _ in sources: + src_window = tuple(zip(rowcol( + src.transform, left, top, op=round, precision=precision + ), rowcol( + src.transform, right, bottom, op=round, precision=precision + ))) + + temp = np.zeros(dst_shape, dtype=dtype) + temp = src.read( + out=temp, window=src_window, boundless=True, masked=False + ) + + where = temp[-1] != 0 + for b in range(0, num_bands): + blended = temp[-1] / 255.0 * temp[b] + \ + (1 - temp[-1] / 255.0) * dstarr[b] + np.copyto(dstarr[b], blended, + casting='unsafe', where=where) + dstarr[-1][where] = 255.0 + + # check if dest has any nodata pixels available + if np.count_nonzero(dstarr[-1]) == blocksize: + break + + # Third pass, write cut rasters + # blending the cutlines + for _, cut in sources: + src_window = tuple(zip(rowcol( + cut.transform, left, top, op=round, precision=precision + ), rowcol( + cut.transform, right, bottom, op=round, precision=precision + ))) + + temp = np.zeros(dst_shape, dtype=dtype) + temp = cut.read( + out=temp, window=src_window, boundless=True, masked=False + ) + + # For each band, average alpha values between + # destination raster and cut raster + for b in range(0, num_bands): + blended = temp[-1] / 255.0 * temp[b] + \ + (1 - temp[-1] / 255.0) * dstarr[b] + np.copyto(dstarr[b], blended, + casting='unsafe', where=temp[-1] != 0) + + dstrast.write(dstarr, window=dst_window) + + return output_orthophoto diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1b79686 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,15 @@ +numpy==1.26.4 
+pandas==2.2.3 +scikit-learn==1.6.1 +matplotlib==3.10.0 +piexif==1.1.3 +geopy==2.4.1 +psutil==6.1.1 +docker==7.1.0 +tqdm==4.66.5 +pyproj==3.7.0 +rasterio==1.4.3 +edt==3.0.0 +opencv-python==4.11.0.86 +fiona==1.9.5 +pyinstaller==6.14.1 \ No newline at end of file diff --git a/trans_orthophoto.py b/trans_orthophoto.py new file mode 100644 index 0000000..370c0d1 --- /dev/null +++ b/trans_orthophoto.py @@ -0,0 +1,33 @@ +import logging +import argparse +from osgeo import gdal + + +class TransOrthophoto: + def __init__(self): + self.logger = logging.getLogger('UAV_Preprocess.TransOrthophoto') + logging.basicConfig(level=logging.INFO, + format='%(asctime)s %(levelname)s %(message)s') + + def trans_to_epsg4326(self, input_img, output_img): + # 目标坐标系 + dst_srs = 'EPSG:4326' + # 使用 gdal.Warp 进行重投影 + gdal.Warp( + destNameOrDestDS=output_img, + srcDSOrSrcDSTab=input_img, + dstSRS=dst_srs, + format='GTiff', + resampleAlg='near' # 最近邻插值 + ) + self.logger.info(f"文件已成功重投影: {output_img}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="正射影像投影转换为EPSG:4326 (GDAL实现)") + parser.add_argument('--input_img', required=True, help='输入影像路径') + parser.add_argument('--output_img', required=True, help='输出影像路径') + args = parser.parse_args() + + trans = TransOrthophoto() + trans.trans_to_epsg4326(args.input_img, args.output_img) diff --git a/utils/directory_manager.py b/utils/directory_manager.py new file mode 100644 index 0000000..1d5802d --- /dev/null +++ b/utils/directory_manager.py @@ -0,0 +1,81 @@ +import os +import shutil +import psutil + + +class DirectoryManager: + def __init__(self, config): + """ + 初始化目录管理器 + Args: + config: 配置对象,包含输入和输出目录等信息 + """ + self.config = config + + def clean_output_dir(self): + """清理输出目录""" + try: + if os.path.exists(self.config.output_dir): + shutil.rmtree(self.config.output_dir) + print(f"已清理输出目录: {self.config.output_dir}") + else: + pass + except Exception as e: + print(f"清理输出目录时发生错误: {str(e)}") + raise + + def setup_output_dirs(self): + """创建必要的输出目录结构""" + try: + # 创建主输出目录 + os.makedirs(self.config.output_dir) + + # 创建过滤图像保存目录 + os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs')) + + # 创建日志目录 + os.makedirs(os.path.join(self.config.output_dir, 'logs')) + + print(f"已创建输出目录结构: {self.config.output_dir}") + except Exception as e: + print(f"创建输出目录时发生错误: {str(e)}") + raise + + def _get_directory_size(self, path): + """获取目录的总大小(字节)""" + total_size = 0 + for dirpath, dirnames, filenames in os.walk(path): + for filename in filenames: + file_path = os.path.join(dirpath, filename) + try: + total_size += os.path.getsize(file_path) + except (OSError, FileNotFoundError): + continue + return total_size + + def check_disk_space(self): + """检查磁盘空间是否足够""" + # 获取输入目录大小 + input_size = self._get_directory_size(self.config.image_dir) + + # 获取输出目录所在磁盘的剩余空间 + output_drive = os.path.splitdrive( + os.path.abspath(self.config.output_dir))[0] + if not output_drive: # 处理Linux/Unix路径 + output_drive = '/home' + + disk_usage = psutil.disk_usage(output_drive) + free_space = disk_usage.free + + # 计算所需空间(输入大小的10倍) + required_space = input_size * 8 + + if free_space < required_space: + error_msg = ( + f"磁盘空间不足!\n" + f"输入目录大小: {input_size / (1024**3):.2f} GB\n" + f"所需空间: {required_space / (1024**3):.2f} GB\n" + f"可用空间: {free_space / (1024**3):.2f} GB\n" + f"在驱动器 {output_drive}" + ) + raise RuntimeError(error_msg) diff --git a/utils/gps_extractor.py b/utils/gps_extractor.py new file mode 100644 index 0000000..3e4090a --- /dev/null +++ b/utils/gps_extractor.py @@ -0,0 +1,65 @@ 
+import os +from PIL import Image +import piexif +import logging +import pandas as pd + + +class GPSExtractor: + """从图像文件提取GPS坐标""" + + def __init__(self, image_dir): + self.image_dir = image_dir + self.logger = logging.getLogger('UAV_Preprocess.GPSExtractor') + + @staticmethod + def _dms_to_decimal(dms): + """将DMS格式转换为十进制度""" + return dms[0][0] / dms[0][1] + (dms[1][0] / dms[1][1]) / 60 + (dms[2][0] / dms[2][1]) / 3600 + + def get_gps(self, image_path): + """提取单张图片的GPS坐标""" + try: + image = Image.open(image_path) + exif_data = piexif.load(image.info['exif']) + + # 提取GPS信息 + gps_info = exif_data.get("GPS", {}) + lat = lon = None + if gps_info: + lat = self._dms_to_decimal(gps_info.get(2, [])) + lon = self._dms_to_decimal(gps_info.get(4, [])) + self.logger.debug( + f"成功提取图片GPS坐标: {image_path} - 纬度: {lat}, 经度: {lon}") + + if not gps_info: + self.logger.warning(f"图片无GPS信息: {image_path}") + + return lat, lon + + except Exception as e: + self.logger.error(f"提取图片信息时发生错误: {image_path} - {str(e)}") + return None, None, None + + def extract_all_gps(self): + """提取所有图片的GPS坐标""" + self.logger.info(f"开始从目录提取GPS坐标: {self.image_dir}") + gps_data = [] + total_images = 0 + successful_extractions = 0 + + for image_file in os.listdir(self.image_dir): + total_images += 1 + image_path = os.path.join(self.image_dir, image_file) + lat, lon = self.get_gps(image_path) + if lat and lon: # 仍然以GPS信息作为主要判断依据 + successful_extractions += 1 + gps_data.append({ + 'file': image_file, + 'lat': lat, + 'lon': lon, + }) + + self.logger.info( + f"GPS坐标提取完成 - 总图片数: {total_images}, 成功提取: {successful_extractions}, 失败: {total_images - successful_extractions}") + return pd.DataFrame(gps_data) diff --git a/utils/grid_divider.py b/utils/grid_divider.py new file mode 100644 index 0000000..60e42e3 --- /dev/null +++ b/utils/grid_divider.py @@ -0,0 +1,266 @@ +import logging +from geopy.distance import geodesic +import matplotlib.pyplot as plt +import os + + +class GridDivider: + """划分网格,并将图片分配到对应网格""" + + def __init__(self, overlap=0.1, grid_size=500, output_dir=None): + self.overlap = overlap + self.grid_size = grid_size + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.GridDivider') + self.logger.info(f"初始化网格划分器,重叠率: {overlap}") + self.num_grids_width = 0 # 添加网格数量属性 + self.num_grids_height = 0 + + def adjust_grid_size_and_overlap(self, points_df): + """动态调整网格重叠率""" + grids = self.adjust_grid_size(points_df) + self.logger.info(f"开始动态调整网格重叠率,初始重叠率: {self.overlap}") + while True: + # 使用调整好的网格大小划分网格 + grids = self.divide_grids(points_df) + grid_points, multiple_grid_points = self.assign_to_grids( + points_df, grids) + + if len(grids) == 1: + self.logger.info(f"网格数量为1,跳过重叠率调整") + break + elif multiple_grid_points < 0.1*len(points_df): + self.overlap += 0.02 + self.logger.info(f"重叠率增加到: {self.overlap}") + else: + self.logger.info( + f"找到合适的重叠率: {self.overlap}, 有{multiple_grid_points}个点被分配到多个网格") + break + return grids, grid_points + + def adjust_grid_size(self, points_df): + """动态调整网格大小 + + Args: + points_df: 包含GPS点的DataFrame + + Returns: + tuple: grids + """ + self.logger.info(f"开始动态调整网格大小,初始大小: {self.grid_size}米") + + while True: + # 使用当前grid_size划分网格 + grids = self.divide_grids(points_df) + grid_points, multiple_grid_points = self.assign_to_grids( + points_df, grids) + + # 检查每个网格中的点数 + max_points = 0 + for grid_id, points in grid_points.items(): + max_points = max(max_points, len(points)) + + self.logger.info( + f"当前网格大小: {self.grid_size}米, 单个网格最大点数: {max_points}") + + # 如果最大点数超过2000,减小网格大小 + if max_points > 
2000: + self.grid_size -= 100 + self.logger.info(f"点数超过2000,减小网格大小至: {self.grid_size}米") + if self.grid_size < 500: # 设置一个最小网格大小限制 + self.logger.warning("网格大小已达到最小值500米,停止调整") + break + else: + self.logger.info(f"找到合适的网格大小: {self.grid_size}米") + break + return grids + + def divide_grids(self, points_df): + """计算边界框并划分网格 + Returns: + tuple: grids 网格边界列表 + """ + self.logger.info("开始划分网格") + + min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max() + min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max() + + # 计算区域的实际距离(米) + width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters + height = geodesic((min_lat, min_lon), (max_lat, min_lon)).meters + + self.logger.info(f"区域宽度: {width:.2f}米, 高度: {height:.2f}米") + + # 精细调整网格的长宽,避免出现2*grid_size-1的情况的影响 + grid_size_lt = [self.grid_size - 200, self.grid_size - 100, + self.grid_size, self.grid_size + 100, self.grid_size + 200] + + width_modulus_lt = [width % grid_size for grid_size in grid_size_lt] + grid_width = grid_size_lt[width_modulus_lt.index( + min(width_modulus_lt))] + height_modulus_lt = [height % grid_size for grid_size in grid_size_lt] + grid_height = grid_size_lt[height_modulus_lt.index( + min(height_modulus_lt))] + self.logger.info(f"网格宽度: {grid_width:.2f}米, 网格高度: {grid_height:.2f}米") + + # 计算需要划分的网格数量 + self.num_grids_width = max(int(width / grid_width), 1) + self.num_grids_height = max(int(height / grid_height), 1) + + # 计算每个网格对应的经纬度步长 + lat_step = (max_lat - min_lat) / self.num_grids_height + lon_step = (max_lon - min_lon) / self.num_grids_width + + grids = [] + + # 先创建所有网格 + for i in range(self.num_grids_height): + for j in range(self.num_grids_width): + grid_min_lat = min_lat + i * lat_step - self.overlap * lat_step + grid_max_lat = min_lat + \ + (i + 1) * lat_step + self.overlap * lat_step + grid_min_lon = min_lon + j * lon_step - self.overlap * lon_step + grid_max_lon = min_lon + \ + (j + 1) * lon_step + self.overlap * lon_step + + grid_bounds = (grid_min_lat, grid_max_lat, + grid_min_lon, grid_max_lon) + grids.append(grid_bounds) + + self.logger.debug( + f"网格[{i},{j}]: 纬度[{grid_min_lat:.6f}, {grid_max_lat:.6f}], " + f"经度[{grid_min_lon:.6f}, {grid_max_lon:.6f}]" + ) + + self.logger.info( + f"成功划分为 {len(grids)} 个网格 ({self.num_grids_width}x{self.num_grids_height})") + + return grids + + def assign_to_grids(self, points_df, grids): + """将点分配到对应网格""" + self.logger.info(f"开始将 {len(points_df)} 个点分配到网格中") + + grid_points = {} # 使用字典存储每个网格的点 + points_assigned = 0 + multiple_grid_points = 0 + + for i in range(self.num_grids_height): + for j in range(self.num_grids_width): + grid_points[(i, j)] = [] # 使用(i,j)元组 + + for _, point in points_df.iterrows(): + point_assigned = False + for i in range(self.num_grids_height): + for j in range(self.num_grids_width): + grid_idx = i * self.num_grids_width + j + min_lat, max_lat, min_lon, max_lon = grids[grid_idx] + + if min_lat <= point['lat'] <= max_lat and min_lon <= point['lon'] <= max_lon: + grid_points[(i, j)].append(point.to_dict()) + if point_assigned: + multiple_grid_points += 1 + else: + points_assigned += 1 + point_assigned = True + + # 记录每个网格的点数 + for grid_id, points in grid_points.items(): + self.logger.info(f"网格 {grid_id} 包含 {len(points)} 个点") + + self.logger.info( + f"点分配完成: 总点数 {len(points_df)}, " + f"成功分配 {points_assigned} 个点, " + f"{multiple_grid_points} 个点被分配到多个网格" + ) + + return grid_points, multiple_grid_points + + def visualize_grids(self, points_df, grids): + """可视化网格划分和GPS点的分布""" + self.logger.info("开始可视化网格划分") + + plt.figure(figsize=(12, 8)) + + # 
绘制GPS点 + plt.scatter(points_df['lon'], points_df['lat'], + c='blue', s=10, alpha=0.6, label='GPS points') + + # 绘制网格 + for i in range(self.num_grids_height): + for j in range(self.num_grids_width): + grid_idx = i * self.num_grids_width + j + min_lat, max_lat, min_lon, max_lon = grids[grid_idx] + + # 计算网格的实际长度和宽度(米) + width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters + height = geodesic((min_lat, min_lon), + (max_lat, min_lon)).meters + + plt.plot([min_lon, max_lon, max_lon, min_lon, min_lon], + [min_lat, min_lat, max_lat, max_lat, min_lat], + 'r-', alpha=0.5) + # 在网格中心添加网格编号和尺寸信息 + center_lon = (min_lon + max_lon) / 2 + center_lat = (min_lat + max_lat) / 2 + plt.text(center_lon, center_lat, + f"({i},{j})\n{width:.0f}m×{height:.0f}m", # 显示(i,j)和尺寸 + horizontalalignment='center', + verticalalignment='center', + fontsize=8) + + plt.title('Grid Division and GPS Point Distribution') + plt.xlabel('Longitude') + plt.ylabel('Latitude') + plt.legend() + plt.grid(True) + + # 如果提供了输出目录,保存图像 + if self.output_dir: + save_path = os.path.join( + self.output_dir, 'filter_imgs', 'grid_division.png') + plt.savefig(save_path, dpi=300, bbox_inches='tight') + self.logger.info(f"网格划分可视化图已保存至: {save_path}") + + plt.close() + + def get_grid_center(self, grid_bounds) -> tuple: + """计算网格中心点的经纬度 + Args: + grid_bounds: (min_lat, max_lat, min_lon, max_lon) + Returns: + (center_lat, center_lon) + """ + min_lat, max_lat, min_lon, max_lon = grid_bounds + return ((min_lat + max_lat) / 2, (min_lon + max_lon) / 2) + + def calculate_grid_translation(self, reference_grid: tuple, target_grid: tuple) -> tuple: + """计算目标网格相对于参考网格的平移距离(米) + Args: + reference_grid: 参考网格的边界 (min_lat, max_lat, min_lon, max_lon) + target_grid: 目标网格的边界 (min_lat, max_lat, min_lon, max_lon) + Returns: + (x_translation, y_translation): 在米制单位下的平移量 + """ + ref_center = self.get_grid_center(reference_grid) + target_center = self.get_grid_center(target_grid) + + # 计算经度方向的距离(x轴) + x_distance = geodesic( + (ref_center[0], ref_center[1]), + (ref_center[0], target_center[1]) + ).meters + # 如果目标在参考点西边,距离为负 + if target_center[1] < ref_center[1]: + x_distance = -x_distance + + # 计算纬度方向的距离(y轴) + y_distance = geodesic( + (ref_center[0], ref_center[1]), + (target_center[0], ref_center[1]) + ).meters + # 如果目标在参考点南边,距离为负 + if target_center[0] < ref_center[0]: + y_distance = -y_distance + + return (x_distance, y_distance) diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..80c1f10 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,36 @@ +import logging +import os +from datetime import datetime + + +def setup_logger(output_dir): + # 创建logs目录 + log_dir = os.path.join(output_dir, 'logs') + + # 创建日志文件名(包含时间戳) + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + log_file = os.path.join(log_dir, f'preprocess_{timestamp}.log') + + # 配置日志格式 + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + + # 配置文件处理器 + file_handler = logging.FileHandler(log_file, encoding='utf-8') + file_handler.setFormatter(formatter) + + # 配置控制台处理器 + console_handler = logging.StreamHandler() + console_handler.setFormatter(formatter) + + # 获取根日志记录器 + logger = logging.getLogger('UAV_Preprocess') + logger.setLevel(logging.INFO) + + # 添加处理器 + logger.addHandler(file_handler) + logger.addHandler(console_handler) + + return logger diff --git a/utils/odm_monitor.py b/utils/odm_monitor.py new file mode 100644 index 0000000..02ea4fd --- /dev/null +++ b/utils/odm_monitor.py @@ -0,0 +1,144 @@ +import 
os +import time +import logging +from typing import Dict, Tuple +import pandas as pd +import subprocess + + +class ODMProcessMonitor: + """ODM处理监控器""" + + def __init__(self, output_dir: str, mode: str = "三维模式", config=None): + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.ODMMonitor') + self.mode = mode + self.config = config + + def run_odm_with_monitor(self, grid_dir: str, grid_id: tuple) -> Tuple[bool, str]: + """运行ODM命令""" + self.logger.info(f"开始处理网格 ({grid_id[0]},{grid_id[1]})") + success = False + error_msg = "" + max_retries = 3 + current_try = 0 + cpu_cores = os.cpu_count() + + while current_try < max_retries: + current_try += 1 + self.logger.info( + f"第 {current_try} 次尝试处理网格 ({grid_id[0]},{grid_id[1]})") + + # 构建 Docker 容器运行参数 + grid_dir_fixed = grid_dir[0].lower() + grid_dir[1:].replace('\\', '/') + command = ( + f"--max-concurrency {cpu_cores} " + f"--force-gps " + f"--use-exif " + f"--use-hybrid-bundle-adjustment " + f"--optimize-disk-space " + f"--orthophoto-cutline " + ) + + command += f"--feature-type {self.config.feature_type} " + command += f"--orthophoto-resolution {self.config.orthophoto_resolution} " + + if self.mode == "快拼模式": + command += ( + f"--fast-orthophoto " + f"--skip-3dmodel " + ) + else: # 三维模式参数 + command += ( + f"--dsm " + f"--dtm " + ) + if current_try == 1: + command += ( + f"--feature-quality low " + ) + + command += "--rerun-all" + + docker_cmd = f'python3 /code/run.py --project-path {grid_dir} project {command}' + self.logger.info(f"Docker 命令: {docker_cmd}") + + try: + # 执行命令 + result = subprocess.run( + docker_cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + encoding="utf-8" + ) + if result.returncode != 0: + self.logger.error( + f"容器运行失败,退出状态码: {result.returncode}") + error_lines = result.stderr.splitlines() + self.logger.error("容器运行失败的最后 50 行错误日志:") + for line in error_lines[-50:]: + self.logger.error(line) + error_msg = "\n".join(error_lines[-50:]) + time.sleep(5) + else: + logs = result.stdout.splitlines() + self.logger.info("容器运行完成,以下是最后 50 行日志:") + for line in logs[-50:]: + self.logger.info(line) + success = True + error_msg = "" + break + except Exception as e: + error_msg = str(e) + self.logger.error(f"执行docker命令时发生异常: {error_msg}") + time.sleep(5) + + return success, error_msg + + def process_all_grids(self, grid_points: Dict[tuple, pd.DataFrame]) -> list: + """处理所有网格 + + Returns: + Dict[tuple, pd.DataFrame]: 成功处理的网格点数据字典 + """ + self.logger.info("开始执行网格处理") + successful_grid_lt = [] + failed_grids = [] + + for grid_id, points in grid_points.items(): + grid_dir = os.path.join( + self.output_dir, f'grid_{grid_id[0]}_{grid_id[1]}' + ) + + try: + success, error_msg = self.run_odm_with_monitor( + grid_dir=grid_dir, + grid_id=grid_id, + ) + + if success: + successful_grid_lt.append(grid_id) + else: + self.logger.error( + f"网格 ({grid_id[0]},{grid_id[1]})") + failed_grids.append(grid_id) + + except Exception as e: + error_msg = str(e) + self.logger.error( + f"处理网格 ({grid_id[0]},{grid_id[1]}) 时发生异常: {error_msg}") + failed_grids.append((grid_id, error_msg)) + + # 汇总处理结果 + total_grids = len(grid_points) + failed_count = len(failed_grids) + success_count = len(successful_grid_lt) + + self.logger.info( + f"网格处理完成。总计: {total_grids}, 成功: {success_count}, 失败: {failed_count}") + + if len(successful_grid_lt) == 0: + raise Exception("所有网格处理都失败,无法继续处理") + + return successful_grid_lt diff --git a/utils/visualizer.py b/utils/visualizer.py new file mode 100644 index 0000000..531a0ac --- /dev/null +++ 
b/utils/visualizer.py @@ -0,0 +1,152 @@ +import os +import matplotlib.pyplot as plt +import pandas as pd +import logging +from typing import Optional +from pyproj import Transformer + + +class FilterVisualizer: + """过滤结果可视化器""" + + def __init__(self, output_dir: str): + """ + 初始化可视化器 + + Args: + output_dir: 输出目录路径 + """ + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.Visualizer') + # 创建坐标转换器 + self.transformer = Transformer.from_crs( + "EPSG:4326", # WGS84经纬度坐标系 + "EPSG:32649", # UTM49N + always_xy=True + ) + + def _convert_to_utm(self, lon: pd.Series, lat: pd.Series) -> tuple: + """ + 将经纬度坐标转换为UTM坐标 + + Args: + lon: 经度序列 + lat: 纬度序列 + + Returns: + tuple: (x坐标, y坐标) + """ + return self.transformer.transform(lon, lat) + + def visualize_filter_step(self, + current_points: pd.DataFrame, + previous_points: pd.DataFrame, + step_name: str, + save_name: Optional[str] = None): + """ + 可视化单个过滤步骤的结果 + + Args: + current_points: 当前步骤后的点 + previous_points: 上一步骤的点 + step_name: 步骤名称 + save_name: 保存文件名,默认为step_name + """ + self.logger.info(f"开始生成{step_name}的可视化结果") + + # 找出被过滤掉的点 + filtered_files = set( + previous_points['file']) - set(current_points['file']) + filtered_points = previous_points[previous_points['file'].isin( + filtered_files)] + + # 转换坐标到UTM + current_x, current_y = self._convert_to_utm( + current_points['lon'], current_points['lat']) + filtered_x, filtered_y = self._convert_to_utm( + filtered_points['lon'], filtered_points['lat']) + + # 创建图形 + plt.rcParams['font.sans-serif'] = ['SimHei'] # 黑体 + plt.rcParams['axes.unicode_minus'] = False + plt.figure(figsize=(20, 20)) + + # 绘制保留的点 + plt.scatter(current_x, current_y, + color='blue', label='保留的点', + alpha=0.6, s=5) + + # 绘制被过滤的点 + if not filtered_points.empty: + plt.scatter(filtered_x, filtered_y, + color='red', marker='x', label='过滤的点') + + # 设置图形属性 + plt.title(f"{step_name}后的GPS点\n" + f"(过滤: {len(filtered_points)}, 保留: {len(current_points)})", + fontsize=14) + plt.xlabel("东向坐标 (米)", fontsize=12) + plt.ylabel("北向坐标 (米)", fontsize=12) + plt.grid(True) + plt.axis('equal') + + # 添加统计信息 + stats_text = ( + f"原始点数: {len(previous_points)}\n" + f"过滤点数: {len(filtered_points)}\n" + f"保留点数: {len(current_points)}\n" + f"过滤率: {len(filtered_points)/len(previous_points)*100:.1f}%" + ) + plt.figtext(0.02, 0.02, stats_text, fontsize=10, + bbox=dict(facecolor='white', alpha=0.8)) + + # 添加图例 + plt.legend(loc='upper right', fontsize=10) + + # 调整布局 + plt.tight_layout() + + # 保存图形 + save_name = save_name or step_name.lower().replace(' ', '_') + save_path = os.path.join( + self.output_dir, 'filter_imgs', f'filter_{save_name}.png') + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + + self.logger.info( + f"{step_name}过滤可视化结果已保存至 {save_path}\n" + f"过滤掉 {len(filtered_points)} 个点," + f"保留 {len(current_points)} 个点," + f"过滤率 {len(filtered_points)/len(previous_points)*100:.1f}%" + ) + + +if __name__ == '__main__': + # 测试代码 + import numpy as np + from datetime import datetime + + # 创建测试数据 + np.random.seed(42) + n_points = 1000 + + # 生成随机点 + test_data = pd.DataFrame({ + 'lon': np.random.uniform(120, 121, n_points), + 'lat': np.random.uniform(30, 31, n_points), + 'file': [f'img_{i}.jpg' for i in range(n_points)], + 'date': [datetime.now() for _ in range(n_points)] + }) + + # 随机选择点作为过滤后的结果 + filtered_data = test_data.sample(n=800) + + # 测试可视化 + visualizer = FilterVisualizer('test_output') + os.makedirs('test_output', exist_ok=True) + + visualizer.visualize_filter_step( + filtered_data, + test_data, + "Test Filter" + )
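
For reference, a minimal sketch of driving the same pipeline from Python instead of through the packaged uav.exe. It only uses the ProcessConfig and ODM_Plugin classes added above, and the example paths are the ones from the README test command; adjust them to your own data.

```python
from app_plugin import ProcessConfig, ODM_Plugin

# Example paths taken from the README test command; replace with your own directories.
config = ProcessConfig(
    image_dir=r"E:\datasets\UAV\134\project\images",
    output_dir=r"G:\ODM_output\134",
    mode="三维模式",      # "快拼模式" skips the 3D model and only produces the orthophoto
    grid_size=800,
    grid_overlap=0.1,
)

# Runs GPS extraction, clustering, grid division, per-grid ODM processing and merging.
ODM_Plugin(config).process()
```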