diff --git a/odm_preprocess.py b/odm_preprocess.py
index 1b1a076..3d42411 100644
--- a/odm_preprocess.py
+++ b/odm_preprocess.py
@@ -46,6 +46,7 @@ class PreprocessConfig:
     grid_size: float = 500
     # Controls which pipeline stages are enabled
     mode: str = "快拼模式"
+    produce_dem: bool = False
 
 
 class ImagePreprocessor:
@@ -237,11 +238,11 @@ class ImagePreprocessor:
             shutil.copy(src, dst)
         self.logger.info(f"Grid ({grid_id[0]},{grid_id[1]}) contains {len(points)} images")
 
-    def merge_tif(self, grid_points: Dict[tuple, pd.DataFrame]):
+    def merge_tif(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool):
         """Merge the imagery products of all grids"""
         self.logger.info("Starting to merge all imagery products")
         merger = MergeTif(self.config.output_dir)
-        merger.merge_all_tifs(grid_points)
+        merger.merge_all_tifs(grid_points, produce_dem)
 
     def merge_ply(self, grid_points: Dict[tuple, pd.DataFrame]):
         """Merge the PLY point clouds of all grids"""
@@ -266,8 +267,8 @@ class ImagePreprocessor:
             self.copy_images(grid_points)
             self.logger.info("Preprocessing complete")
 
-            self.odm_monitor.process_all_grids(grid_points)
-            self.merge_tif(grid_points)
+            self.odm_monitor.process_all_grids(grid_points, self.config.produce_dem)
+            self.merge_tif(grid_points, self.config.produce_dem)
             self.merge_ply(grid_points)
             self.merge_obj(grid_points, translations)
         except Exception as e:
@@ -278,8 +279,8 @@
 if __name__ == "__main__":
     # Create the configuration
     config = PreprocessConfig(
-        image_dir=r"E:\datasets\UAV\1815\project\images",
-        output_dir=r"G:\ODM_output\1815",
+        image_dir=r"E:\datasets\UAV\283\project\images",
+        output_dir=r"G:\ODM_output\283",
 
         cluster_eps=0.01,
         cluster_min_samples=5,
@@ -295,11 +296,12 @@ if __name__ == "__main__":
         filter_dense_distance_threshold=10,
         filter_time_threshold=timedelta(minutes=5),
 
-        grid_size=500,
+        grid_size=400,
         grid_overlap=0.1,
 
         mode="重建模式",
+        produce_dem=False,
     )
 
     # Create the processor and run it
diff --git a/post_pro/merge_obj.py b/post_pro/merge_obj.py
index a192605..7ed94ce 100644
--- a/post_pro/merge_obj.py
+++ b/post_pro/merge_obj.py
@@ -4,6 +4,7 @@ import numpy as np
 from typing import Dict
 import pandas as pd
 import shutil
+import time
 
 
 class MergeObj:
@@ -124,8 +125,23 @@ class MergeObj:
         materials2 = self.read_mtl(mtl2_path)
 
         # Build the material-name mapping (same naming format as the MTL files)
-        material_map1 = {old_name: f"material_{grid_id1[0]}_{grid_id1[1]}_{old_name}" for old_name in materials1.keys()}
-        material_map2 = {old_name: f"material_{grid_id2[0]}_{grid_id2[1]}_{old_name}" for old_name in materials2.keys()}
+        material_map1 = {}
+        material_map2 = {}
+
+        # Material mapping for the first model
+        for old_name in materials1.keys():
+            # If the name already carries the grid-ID prefix, do not add it again
+            if old_name.startswith(f"material_{grid_id1[0]}_{grid_id1[1]}_"):
+                material_map1[old_name] = old_name
+            else:
+                material_map1[old_name] = f"material_{grid_id1[0]}_{grid_id1[1]}_{old_name}"
+
+        # Material mapping for the second model
+        for old_name in materials2.keys():
+            if old_name.startswith(f"material_{grid_id2[0]}_{grid_id2[1]}_"):
+                material_map2[old_name] = old_name
+            else:
+                material_map2[old_name] = f"material_{grid_id2[0]}_{grid_id2[1]}_{old_name}"
 
         # Translate the vertices of the second model
         vertices2_translated = self.translate_vertices(vertices2, translation)
@@ -315,24 +331,51 @@ class MergeObj:
             # Merge the OBJ files
             reference_id = list(grid_files.keys())[0]
             merged_obj = grid_files[reference_id]['obj']
+            temp_files = []  # keep track of every intermediate file
 
             for grid_id, files in list(grid_files.items())[1:]:
                 translation = translations[grid_id]
                 translation = (translation[0], translation[1], 0)
-                output_obj = os.path.join(
+                # Generate a temporary output name; including the grid id
+                # avoids collisions when two merges fall in the same second
+                temp_output = os.path.join(
                     output_model_dir,
-                    f"merged_model_{reference_id[0]}_{reference_id[1]}_{grid_id[0]}_{grid_id[1]}.obj"
+                    f"temp_merged_{grid_id[0]}_{grid_id[1]}_{int(time.time())}.obj"
                 )
+                temp_files.append(temp_output)  # add to the temp-file list
 
-                self.merge_two_objs(merged_obj, files['obj'], output_obj, translation, reference_id, grid_id)
-                merged_obj = output_obj
+                self.merge_two_objs(merged_obj, files['obj'], temp_output, translation, reference_id, grid_id)
+
+                # If the previous merged_obj was itself a temp file, delete it
+                if merged_obj != grid_files[reference_id]['obj'] and os.path.exists(merged_obj):
+                    try:
+                        os.remove(merged_obj)
+                    except Exception as e:
+                        self.logger.warning(f"Failed to delete temp file: {merged_obj}, error: {str(e)}")
+
+                merged_obj = temp_output
 
             # Final result
             final_obj = os.path.join(output_model_dir, "merged_model.obj")
-            if os.path.exists(merged_obj) and merged_obj != final_obj:
+            try:
+                if os.path.exists(final_obj):
+                    os.remove(final_obj)
+                os.rename(merged_obj, final_obj)
+            except Exception as e:
+                self.logger.warning(f"Failed to rename final file: {str(e)}")
                 shutil.copy2(merged_obj, final_obj)
-                os.remove(merged_obj)
+                try:
+                    os.remove(merged_obj)
+                except Exception:
+                    pass
+
+            # Clean up any remaining temp files
+            for temp_file in temp_files:
+                if os.path.exists(temp_file):
+                    try:
+                        os.remove(temp_file)
+                    except Exception as e:
+                        self.logger.warning(f"Failed to delete temp file: {temp_file}, error: {str(e)}")
 
             self.logger.info(
                 f"Model merge finished, output directory: {output_model_dir}\n"
diff --git a/post_pro/merge_tif.py b/post_pro/merge_tif.py
index ee3f048..c1e005c 100644
--- a/post_pro/merge_tif.py
+++ b/post_pro/merge_tif.py
@@ -3,6 +3,8 @@ import logging
 import os
 from typing import Dict
 import pandas as pd
+import time
+import shutil
 
 
 class MergeTif:
@@ -25,49 +27,72 @@ class MergeTif:
                 raise FileNotFoundError(error_msg)
 
             # Open the images and check that their projections match
-            datasets = [gdal.Open(tif) for tif in [input_tif1, input_tif2]]
-            if None in datasets:
-                error_msg = "Unable to open the input image files"
-                self.logger.error(error_msg)
-                raise ValueError(error_msg)
+            datasets = []
+            try:
+                for tif in [input_tif1, input_tif2]:
+                    ds = gdal.Open(tif)
+                    if ds is None:
+                        error_msg = f"Unable to open image file: {tif}"
+                        self.logger.error(error_msg)
+                        raise ValueError(error_msg)
+                    datasets.append(ds)
 
-            projections = [dataset.GetProjection() for dataset in datasets]
-            self.logger.debug(f"Image 1 projection: {projections[0]}")
-            self.logger.debug(f"Image 2 projection: {projections[1]}")
+                projections = [ds.GetProjection() for ds in datasets]
+                self.logger.debug(f"Image 1 projection: {projections[0]}")
+                self.logger.debug(f"Image 2 projection: {projections[1]}")
 
-            # Check that the projections are identical
-            if len(set(projections)) != 1:
-                error_msg = "The image projections differ; please reproject first!"
-                self.logger.error(error_msg)
-                raise ValueError(error_msg)
+                # Check that the projections are identical
+                if len(set(projections)) != 1:
+                    error_msg = "The image projections differ; please reproject first!"
+                    self.logger.error(error_msg)
+                    raise ValueError(error_msg)
 
-            # Create the GDAL Warp options
-            warp_options = gdal.WarpOptions(
-                format="GTiff",
-                resampleAlg="average",
-                srcNodata=0,
-                dstNodata=0,
-                multithread=True
-            )
+                # If the output file already exists, delete it first
+                if os.path.exists(output_tif):
+                    try:
+                        os.remove(output_tif)
+                    except Exception as e:
+                        self.logger.warning(f"Failed to delete existing output file: {str(e)}")
+                        # Fall back to a fresh output file name
+                        base, ext = os.path.splitext(output_tif)
+                        output_tif = f"{base}_{int(time.time())}{ext}"
+                        self.logger.info(f"Using new output file name: {output_tif}")
 
-            self.logger.info("Starting image mosaicking...")
-            result = gdal.Warp(
-                output_tif, [input_tif1, input_tif2], options=warp_options)
+                # Create the GDAL Warp options
+                warp_options = gdal.WarpOptions(
+                    format="GTiff",
+                    resampleAlg="average",
+                    srcNodata=0,
+                    dstNodata=0,
+                    multithread=True
+                )
 
-            if result is None:
-                error_msg = "Image mosaicking failed"
-                self.logger.error(error_msg)
-                raise RuntimeError(error_msg)
+                self.logger.info("Starting image mosaicking...")
+                result = gdal.Warp(output_tif, datasets, options=warp_options)
 
-            # Retrieve basic information about the output image
-            output_dataset = gdal.Open(output_tif)
-            if output_dataset:
-                width = output_dataset.RasterXSize
-                height = output_dataset.RasterYSize
-                bands = output_dataset.RasterCount
-                self.logger.info(f"Mosaicking finished, output size: {width}x{height}, bands: {bands}")
+                if result is None:
+                    error_msg = "Image mosaicking failed"
+                    self.logger.error(error_msg)
+                    raise RuntimeError(error_msg)
 
-            self.logger.info(f"Image mosaicking succeeded, output saved to: {output_tif}")
+                # Retrieve basic information about the output image
+                output_dataset = gdal.Open(output_tif)
+                if output_dataset:
+                    width = output_dataset.RasterXSize
+                    height = output_dataset.RasterYSize
+                    bands = output_dataset.RasterCount
+                    self.logger.info(
+                        f"Mosaicking finished, output size: {width}x{height}, bands: {bands}")
+                    output_dataset = None  # explicitly close the dataset
+
+                self.logger.info(f"Image mosaicking succeeded, output saved to: {output_tif}")
+
+            finally:
+                # Drop the references held by the list so GDAL actually closes the
+                # datasets (rebinding a loop variable alone would not release them)
+                datasets.clear()
+                result = None
 
         except Exception as e:
             self.logger.error(f"Error during image mosaicking: {str(e)}", exc_info=True)
@@ -108,25 +133,58 @@ class MergeTif:
                     self.logger.info(f"Set first input {product_name}: {input_tif1}")
                 else:
                     input_tif2 = grid_tif
-                    output_tif = os.path.join(
-                        self.output_dir, f"merged_{product_info['output']}")
+                    # Build a temporary output name; merge_count keeps it unique
+                    # even when two merges finish within the same second
+                    temp_output = os.path.join(
+                        self.output_dir,
+                        f"temp_merged_{merge_count}_{int(time.time())}_{product_info['output']}"
+                    )
 
                     self.logger.info(
                         f"Starting merge #{merge_count + 1} for {product_name}:\n"
                         f"Input 1: {input_tif1}\n"
                         f"Input 2: {input_tif2}\n"
-                        f"Output: {output_tif}"
+                        f"Output: {temp_output}"
                     )
 
-                    self.merge_two_tifs(input_tif1, input_tif2, output_tif)
+                    self.merge_two_tifs(input_tif1, input_tif2, temp_output)
                     merge_count += 1
 
-                    input_tif1 = output_tif
+                    # If the previous input was itself a temp file, delete it
+                    if 'temp_merged_' in input_tif1:
+                        try:
+                            os.remove(input_tif1)
+                        except Exception as e:
+                            self.logger.warning(f"Failed to delete temp file: {str(e)}")
+
+                    input_tif1 = temp_output
                     input_tif2 = None
 
+            # Rename the last temp file to the target file name
+            final_output = os.path.join(
+                self.output_dir, product_info['output'])
+            if os.path.exists(final_output):
+                try:
+                    os.remove(final_output)
+                except Exception as e:
+                    self.logger.warning(f"Failed to delete existing final output file: {str(e)}")
+                    final_output = os.path.join(
+                        self.output_dir,
+                        f"merged_{int(time.time())}_{product_info['output']}"
+                    )
+
+            try:
+                os.rename(input_tif1, final_output)
+            except Exception as e:
+                self.logger.warning(f"Failed to rename final file: {str(e)}")
+                shutil.copy2(input_tif1, final_output)
+                try:
+                    os.remove(input_tif1)
+                except Exception:
+                    pass
+
             self.logger.info(
                 f"{product_name} merge finished after {merge_count} merges, "
-                f"final output file: {input_tif1}"
+                f"final output file: {final_output}"
             )
 
         except Exception as e:
@@ -134,7 +192,7 @@ class MergeTif:
f"{product_name}合并过程中发生错误: {str(e)}", exc_info=True) raise - def merge_all_tifs(self, grid_points: Dict[tuple, pd.DataFrame]): + def merge_all_tifs(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool): """合并所有产品(正射影像、DSM和DTM)""" try: products = [ @@ -144,19 +202,23 @@ class MergeTif: 'filename': 'odm_orthophoto.original.tif', 'output': 'orthophoto.tif' }, - { - 'name': 'DSM', - 'path': 'odm_dem', - 'filename': 'dsm.original.tif', - 'output': 'dsm.tif' - }, - { - 'name': 'DTM', - 'path': 'odm_dem', - 'filename': 'dtm.original.tif', - 'output': 'dtm.tif' - } + ] + if produce_dem: + products.append( + { + 'name': 'DSM', + 'path': 'odm_dem', + 'filename': 'dsm.original.tif', + 'output': 'dsm.tif' + }, + { + 'name': 'DTM', + 'path': 'odm_dem', + 'filename': 'dtm.original.tif', + 'output': 'dtm.tif' + } + ) for product in products: self.merge_grid_tif(grid_points, product) @@ -173,7 +235,7 @@ if __name__ == "__main__": os.path.dirname(os.path.abspath(__file__)))) from utils.logger import setup_logger import pandas as pd - + # 设置输出目录和日志 output_dir = r"G:\ODM_output\1009" setup_logger(output_dir) diff --git a/utils/grid_divider.py b/utils/grid_divider.py index c871b57..b7f79f5 100644 --- a/utils/grid_divider.py +++ b/utils/grid_divider.py @@ -149,9 +149,9 @@ class GridDivider: plt.text(center_lon, center_lat, f"({i},{j})", # 显示(i,j) horizontalalignment='center', verticalalignment='center') - plt.title('网格划分与GPS点分布图') - plt.xlabel('经度') - plt.ylabel('纬度') + plt.title('Grid Division and GPS Point Distribution') + plt.xlabel('Longitude') + plt.ylabel('Latitude') plt.legend() plt.grid(True) diff --git a/utils/odm_monitor.py b/utils/odm_monitor.py index 12e6773..d053f5d 100644 --- a/utils/odm_monitor.py +++ b/utils/odm_monitor.py @@ -20,8 +20,12 @@ class ODMProcessMonitor: success_markers.append('odm_texturing') return all(os.path.exists(os.path.join(grid_dir, 'project', marker)) for marker in success_markers) - def run_odm_with_monitor(self, grid_dir: str, grid_id: tuple, fast_mode: bool = True) -> Tuple[bool, str]: + def run_odm_with_monitor(self, grid_dir: str, grid_id: tuple, fast_mode: bool = True, produce_dem: bool = False) -> Tuple[bool, str]: """运行ODM命令""" + if produce_dem and fast_mode: + self.logger.error("快拼模式下无法生成DEM,请调整生产参数") + return False, "快拼模式下无法生成DEM,请调整生产参数" + self.logger.info(f"开始处理网格 ({grid_id[0]},{grid_id[1]})") # 构建Docker命令 @@ -37,6 +41,12 @@ class ODMProcessMonitor: f"--orthophoto-resolution 10 " ) + if produce_dem: + docker_command += ( + f"--dsm " + f"--dtm " + ) + if fast_mode: docker_command += ( f"--fast-orthophoto " @@ -60,7 +70,7 @@ class ODMProcessMonitor: self.logger.error(f"网格 ({grid_id[0]},{grid_id[1]}) 处理失败") return False, f"网格 ({grid_id[0]},{grid_id[1]}) 处理失败" - def process_all_grids(self, grid_points: Dict[tuple, pd.DataFrame]): + def process_all_grids(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool): """处理所有网格""" self.logger.info("开始执行网格处理") for grid_id in grid_points.keys(): @@ -71,7 +81,8 @@ class ODMProcessMonitor: success, error_msg = self.run_odm_with_monitor( grid_dir=grid_dir, grid_id=grid_id, - fast_mode=(self.mode == "快拼模式") + fast_mode=(self.mode == "快拼模式"), + produce_dem=produce_dem ) if not success: