UAV/post_pro/merge_tif.py
2025-01-17 10:46:39 +08:00

263 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from osgeo import gdal
import logging
import os
from typing import Dict
import pandas as pd
import time
import shutil
class MergeTif:
    """Merge per-grid GeoTIFF products into one mosaic per product.

    Grid products are looked up under
    ``<output_dir>/grid_<i>_<j>/project/<product path>/<filename>`` and are
    merged pairwise with ``gdal.Warp``. The final mosaic is written to
    ``<output_dir>/<product output name>``.
    """

    def __init__(self, output_dir: str, original_max_size_mb: float = 600):
        """Create a merger rooted at ``output_dir``.

        Args:
            output_dir: Directory containing the ``grid_*`` sub-directories.
                Intermediate and final mosaics are written here too.
            original_max_size_mb: When a grid has both the ``.original`` file
                and the plain file, the ``.original`` one is used only if its
                size does not exceed this many megabytes. The default, 600, is
                the previously hard-coded limit.
        """
        self.output_dir = output_dir
        self.original_max_size_mb = original_max_size_mb
        self.logger = logging.getLogger('UAV_Preprocess.MergeTif')

    def _open_dataset(self, path: str):
        """Open ``path`` with GDAL, raising ValueError if it cannot be opened."""
        dataset = gdal.Open(path)
        if dataset is None:
            error_msg = f"无法打开影像文件: {path}"
            self.logger.error(error_msg)
            raise ValueError(error_msg)
        return dataset

    def merge_two_tifs(self, input_tif1: str, input_tif2: str, output_tif: str) -> str:
        """Mosaic two GeoTIFFs into a single file with ``gdal.Warp``.

        Value 0 is treated as nodata for both inputs and the output, and
        ``average`` resampling is used. If ``output_tif`` already exists it is
        deleted first. If that deletion fails, a timestamped sibling name is
        used instead.

        Args:
            input_tif1: Path of the first input image.
            input_tif2: Path of the second input image.
            output_tif: Requested output path.

        Returns:
            The path that was actually written. This is ``output_tif`` unless
            the fallback name had to be used.

        Raises:
            FileNotFoundError: An input file does not exist.
            ValueError: An input cannot be opened, or the projections differ.
            RuntimeError: ``gdal.Warp`` returned no dataset.
        """
        try:
            self.logger.info("开始合并TIF影像")
            self.logger.info(f"输入影像1: {input_tif1}")
            self.logger.info(f"输入影像2: {input_tif2}")
            self.logger.info(f"输出影像: {output_tif}")
            if not os.path.exists(input_tif1) or not os.path.exists(input_tif2):
                error_msg = "输入影像文件不存在"
                self.logger.error(error_msg)
                raise FileNotFoundError(error_msg)
            datasets = []
            result = None
            try:
                for tif in (input_tif1, input_tif2):
                    datasets.append(self._open_dataset(tif))
                projections = [dataset.GetProjection() for dataset in datasets]
                self.logger.debug(f"影像1投影: {projections[0]}")
                self.logger.debug(f"影像2投影: {projections[1]}")
                # Projections are compared as WKT strings, so equivalent CRSs
                # whose WKT is formatted differently are also rejected.
                if len(set(projections)) != 1:
                    error_msg = "影像的投影不一致,请先进行重投影!"
                    self.logger.error(error_msg)
                    raise ValueError(error_msg)
                if os.path.exists(output_tif):
                    try:
                        os.remove(output_tif)
                    except OSError as e:
                        self.logger.warning(f"删除已存在的输出文件失败: {str(e)}")
                        # Fall back to a timestamped name; it is returned to
                        # the caller so the new path is not lost.
                        base, ext = os.path.splitext(output_tif)
                        output_tif = f"{base}_{int(time.time())}{ext}"
                        self.logger.info(f"使用新的输出文件名: {output_tif}")
                warp_options = gdal.WarpOptions(
                    format="GTiff",
                    resampleAlg="average",
                    srcNodata=0,
                    dstNodata=0,
                    multithread=True
                )
                self.logger.info("开始执行影像拼接...")
                result = gdal.Warp(output_tif, datasets, options=warp_options)
                if result is None:
                    error_msg = "影像拼接失败"
                    self.logger.error(error_msg)
                    raise RuntimeError(error_msg)
                # Read the summary from the dataset Warp returned, instead of
                # re-opening the file while it is still open for writing.
                width = result.RasterXSize
                height = result.RasterYSize
                bands = result.RasterCount
                self.logger.info(
                    f"拼接完成,输出影像大小: {width}x{height},波段数: {bands}")
                # Dropping the last reference closes the dataset and flushes it
                # to disk, so the caller can immediately re-open the output.
                result = None
                self.logger.info(f"影像拼接成功,输出文件保存至: {output_tif}")
            finally:
                # GDAL datasets close only when their last reference goes away.
                # Clearing the list releases the inputs even on error paths.
                result = None
                datasets.clear()
            return output_tif
        except Exception as e:
            self.logger.error(f"影像拼接过程中发生错误: {str(e)}", exc_info=True)
            raise

    def _select_grid_tif(self, grid_id: tuple, product_path: str,
                         filename_original: str, filename: str, product_name: str):
        """Return the TIF to merge for one grid, or None if neither file exists.

        The ``.original`` file is preferred. When both files exist and the
        ``.original`` one is larger than ``original_max_size_mb``, the plain
        file is used instead.
        """
        grid_dir = os.path.join(
            self.output_dir,
            f"grid_{grid_id[0]}_{grid_id[1]}",
            "project",
            product_path,
        )
        grid_tif_original = os.path.join(grid_dir, filename_original)
        grid_tif = os.path.join(grid_dir, filename)
        has_original = os.path.exists(grid_tif_original)
        has_plain = os.path.exists(grid_tif)
        if has_original and has_plain:
            self.logger.info(
                f"网格 ({grid_id[0]},{grid_id[1]}) 的{product_name}存在: {grid_tif_original, grid_tif}")
            size_mb_original = os.path.getsize(grid_tif_original) / (1024 * 1024)
            if size_mb_original > self.original_max_size_mb:
                return grid_tif
            return grid_tif_original
        if has_original:
            return grid_tif_original
        if has_plain:
            return grid_tif
        self.logger.warning(
            f"网格 ({grid_id[0]},{grid_id[1]}) 的{product_name}不存在: {grid_tif_original, grid_tif}")
        return None

    def _remove_temp_files(self, temp_files):
        """Best-effort deletion of intermediate mosaics; failures are only logged."""
        for temp_file in temp_files:
            if not os.path.exists(temp_file):
                continue
            try:
                os.remove(temp_file)
            except OSError as e:
                self.logger.warning(f"删除临时文件失败: {str(e)}")

    def merge_grid_tif(self, grid_points: Dict[tuple, pd.DataFrame], product_info: dict):
        """Merge one product across all grids into ``<output_dir>/<output>``.

        Grids are visited in ``grid_points`` order. The first available grid
        file seeds the mosaic, and each further one is merged into it with
        ``merge_two_tifs``. Intermediate mosaics are temporary files in
        ``output_dir`` and are deleted afterwards, including on failure.

        Args:
            grid_points: Mapping keyed by ``(i, j)`` grid ids; the values are
                not used here.
            product_info: Dict with the keys ``name`` (label used in logs),
                ``path`` (sub-directory under ``project``), ``filename`` (the
                ``.original`` file name) and ``output`` (final file name).

        Raises:
            FileNotFoundError: No grid has the product.
        """
        product_name = product_info['name']
        product_path = product_info['path']
        filename_original = product_info['filename']
        filename = filename_original.replace(".original", "")
        self.logger.info(f"开始合并{product_name}")
        merged_tif = None
        merge_count = 0
        temp_files = []
        try:
            for grid_id in grid_points:
                to_merge_tif = self._select_grid_tif(
                    grid_id, product_path, filename_original, filename, product_name)
                if to_merge_tif is None:
                    continue
                if merged_tif is None:
                    merged_tif = to_merge_tif
                    self.logger.info(f"设置第一个输入{product_name}: {merged_tif}")
                    continue
                # merge_count keeps names unique even when two merges finish in
                # the same second. A collision would make merge_two_tifs delete
                # its own input before warping.
                temp_output = os.path.join(
                    self.output_dir,
                    f"temp_merged_{merge_count}_{int(time.time())}_{product_info['output']}"
                )
                self.logger.info(
                    f"开始合并{product_name}{merge_count + 1} 次:\n"
                    f"输入1: {merged_tif}\n"
                    f"输入2: {to_merge_tif}\n"
                    f"输出: {temp_output}"
                )
                # Register the file before merging so a partial output is cleaned up too.
                temp_files.append(temp_output)
                written = self.merge_two_tifs(merged_tif, to_merge_tif, temp_output)
                # merge_two_tifs may fall back to another name; continue from
                # the file it actually wrote.
                if written != temp_output:
                    temp_files.append(written)
                merged_tif = written
                merge_count += 1
            if merged_tif is None:
                error_msg = f"所有网格的{product_name}均不存在,无法合并"
                self.logger.error(error_msg)
                raise FileNotFoundError(error_msg)
            final_output = os.path.join(self.output_dir, product_info['output'])
            shutil.copy2(merged_tif, final_output)
            self.logger.info(
                f"{product_name}合并完成,共执行 {merge_count} 次合并,"
                f"最终输出文件: {final_output}"
            )
        except Exception as e:
            self.logger.error(
                f"{product_name}合并过程中发生错误: {str(e)}", exc_info=True)
            raise
        finally:
            self._remove_temp_files(temp_files)

    def merge_all_tifs(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool):
        """Merge the orthophoto and, optionally, the DSM and DTM of all grids.

        Args:
            grid_points: Mapping keyed by ``(i, j)`` grid ids.
            produce_dem: When True, the ``dsm`` and ``dtm`` products are merged
                in addition to the orthophoto.
        """
        try:
            products = [
                {
                    'name': '正射影像',
                    'path': 'odm_orthophoto',
                    'filename': 'odm_orthophoto.original.tif',
                    'output': 'orthophoto.tif'
                },
            ]
            if produce_dem:
                products.extend([
                    {
                        'name': 'DSM',
                        'path': 'odm_dem',
                        'filename': 'dsm.original.tif',
                        'output': 'dsm.tif'
                    },
                    {
                        'name': 'DTM',
                        'path': 'odm_dem',
                        'filename': 'dtm.original.tif',
                        'output': 'dtm.tif'
                    },
                ])
            for product in products:
                self.merge_grid_tif(grid_points, product)
            self.logger.info("所有产品合并完成")
        except Exception as e:
            self.logger.error(f"产品合并过程中发生错误: {str(e)}", exc_info=True)
            raise
if __name__ == "__main__":
    # Manual smoke test against a local ODM output directory.
    # The path below is machine-specific.
    import sys
    # Make the parent directory importable so that ``utils.logger`` resolves.
    sys.path.append(os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))))
    from utils.logger import setup_logger
    # Set the output directory and configure logging.
    output_dir = r"G:\ODM_output\1009"
    setup_logger(output_dir)
    # Build a sample grid_points dict: two grids, each with a DataFrame of a
    # few GPS points. The keys select the grid_<i>_<j> directories.
    grid_points = {
        (0, 0): pd.DataFrame({
            'latitude': [39.9, 39.91],
            'longitude': [116.3, 116.31],
            'altitude': [100, 101]
        }),
        (0, 1): pd.DataFrame({
            'latitude': [39.92, 39.93],
            'longitude': [116.32, 116.33],
            'altitude': [102, 103]
        })
    }
    # Create the merger and merge all products. produce_dem is a required
    # argument of merge_all_tifs; omitting it raised TypeError.
    merge_tif = MergeTif(output_dir)
    merge_tif.merge_all_tifs(grid_points, produce_dem=True)