Modify merge_tif to use ODM's cut, feather and merge functions

This commit is contained in:
weixin_46229132 2025-04-10 18:57:34 +08:00
parent 697660b5b3
commit a25adbcc31
4 changed files with 262 additions and 586 deletions

View File

@@ -254,7 +254,7 @@ class ImagePreprocessor:
         """合并所有网格的影像产品"""
         self.logger.info("开始合并所有影像产品")
         merger = MergeTif(self.config.output_dir)
-        merger.merge_all_tifs(grid_points, mode)
+        merger.merge_orthophoto(grid_points)
 
     def merge_ply(self, grid_points: Dict[tuple, pd.DataFrame]):
         """合并所有网格的PLY点云"""

View File

@@ -1,352 +0,0 @@
import os
import shutil
from datetime import timedelta
from dataclasses import dataclass
from typing import Dict, Tuple
import psutil
import pandas as pd
from filter.cluster_filter import GPSCluster
from filter.time_group_overlap_filter import TimeGroupOverlapFilter
from filter.gps_filter import GPSFilter
from utils.odm_monitor import ODMProcessMonitor
from utils.gps_extractor import GPSExtractor
from utils.grid_divider import GridDivider
from utils.logger import setup_logger
from utils.visualizer import FilterVisualizer
from post_pro.merge_tif import MergeTif
from post_pro.merge_obj import MergeObj
from post_pro.merge_laz import MergePly
from post_pro.conv_obj2 import ConvertOBJ
@dataclass
class PreprocessConfig:
"""预处理配置类"""
image_dir: str
output_dir: str
# 聚类过滤参数
cluster_eps: float = 0.01
cluster_min_samples: int = 5
# 时间组重叠过滤参数
time_group_overlap_threshold: float = 0.7
time_group_interval: timedelta = timedelta(minutes=5)
# 孤立点过滤参数
filter_distance_threshold: float = 0.001 # 经纬度距离
filter_min_neighbors: int = 6
# 密集点过滤参数
filter_grid_size: float = 0.001
filter_dense_distance_threshold: float = 10 # 普通距离,单位:米
filter_time_threshold: timedelta = timedelta(minutes=5)
# 网格划分参数
grid_overlap: float = 0.05
grid_size: float = 500
# 几个pipeline过程是否开启
mode: str = "快拼模式"
accuracy: str = "medium"
produce_dem: bool = False
class ImagePreprocessor:
def __init__(self, config: PreprocessConfig):
self.config = config
# 检查磁盘空间
self._check_disk_space()
# # 清理并重建输出目录
# if os.path.exists(config.output_dir):
# self._clean_output_dir()
# self._setup_output_dirs()
# 初始化其他组件
self.logger = setup_logger(config.output_dir)
self.gps_points = None
self.odm_monitor = ODMProcessMonitor(
config.output_dir, mode=config.mode)
self.visualizer = FilterVisualizer(config.output_dir)
def _clean_output_dir(self):
"""清理输出目录"""
try:
shutil.rmtree(self.config.output_dir)
print(f"已清理输出目录: {self.config.output_dir}")
except Exception as e:
print(f"清理输出目录时发生错误: {str(e)}")
raise
def _setup_output_dirs(self):
"""创建必要的输出目录结构"""
try:
# 创建主输出目录
os.makedirs(self.config.output_dir)
# 创建过滤图像保存目录
os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs'))
# 创建日志目录
os.makedirs(os.path.join(self.config.output_dir, 'logs'))
print(f"已创建输出目录结构: {self.config.output_dir}")
except Exception as e:
print(f"创建输出目录时发生错误: {str(e)}")
raise
def _get_directory_size(self, path):
"""获取目录的总大小(字节)"""
total_size = 0
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
file_path = os.path.join(dirpath, filename)
try:
total_size += os.path.getsize(file_path)
except (OSError, FileNotFoundError):
continue
return total_size
def _check_disk_space(self):
"""检查磁盘空间是否足够"""
# 获取输入目录大小
input_size = self._get_directory_size(self.config.image_dir)
# 获取输出目录所在磁盘的剩余空间
output_drive = os.path.splitdrive(
os.path.abspath(self.config.output_dir))[0]
if not output_drive: # 处理Linux/Unix路径
output_drive = '/home'
disk_usage = psutil.disk_usage(output_drive)
free_space = disk_usage.free
# 计算所需空间(按输入目录大小的12倍预留)
required_space = input_size * 12
if free_space < required_space:
error_msg = (
f"磁盘空间不足!\n"
f"输入目录大小: {input_size / (1024**3):.2f} GB\n"
f"所需空间: {required_space / (1024**3):.2f} GB\n"
f"可用空间: {free_space / (1024**3):.2f} GB\n"
f"在驱动器 {output_drive}"
)
raise RuntimeError(error_msg)
def extract_gps(self) -> pd.DataFrame:
"""提取GPS数据"""
self.logger.info("开始提取GPS数据")
extractor = GPSExtractor(self.config.image_dir)
self.gps_points = extractor.extract_all_gps()
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
def cluster(self):
"""使用DBSCAN对GPS点进行聚类只保留最大的类"""
previous_points = self.gps_points.copy()
clusterer = GPSCluster(
self.gps_points,
eps=self.config.cluster_eps,
min_samples=self.config.cluster_min_samples
)
self.clustered_points = clusterer.fit()
self.gps_points = clusterer.get_cluster_stats(self.clustered_points)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "1-Clustering")
def filter_isolated_points(self):
"""过滤孤立点"""
filter = GPSFilter(self.config.output_dir)
previous_points = self.gps_points.copy()
self.gps_points = filter.filter_isolated_points(
self.gps_points,
self.config.filter_distance_threshold,
self.config.filter_min_neighbors,
)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "2-Isolated Points")
def filter_time_group_overlap(self):
"""过滤重叠的时间组"""
previous_points = self.gps_points.copy()
filter = TimeGroupOverlapFilter(
self.config.image_dir,
self.config.output_dir,
overlap_threshold=self.config.time_group_overlap_threshold
)
self.gps_points = filter.filter_overlapping_groups(
self.gps_points,
time_threshold=self.config.time_group_interval
)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "3-Time Group Overlap")
def calculate_center_coordinates(self):
"""计算剩余点的中心经纬度坐标"""
mean_lat = self.gps_points['lat'].mean()
mean_lon = self.gps_points['lon'].mean()
self.logger.info(f"区域中心坐标:纬度 {mean_lat:.6f}, 经度 {mean_lon:.6f}")
return mean_lat, mean_lon
def filter_alternate_images(self):
"""按时间顺序隔一个删一个图像来降低密度"""
previous_points = self.gps_points.copy()
# 按时间戳排序
self.gps_points = self.gps_points.sort_values('date')
# 保留索引为偶数的行(即隔一个保留一个)
self.gps_points = self.gps_points.iloc[::2].reset_index(drop=True)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "4-Alternate Images")
self.logger.info(f"交替过滤后剩余 {len(self.gps_points)} 个点")
def divide_grids(self) -> Tuple[Dict[tuple, pd.DataFrame], Dict[tuple, tuple]]:
"""划分网格
Returns:
tuple: (grid_points, translations)
- grid_points: 网格点数据字典
- translations: 网格平移量字典
"""
grid_divider = GridDivider(
overlap=self.config.grid_overlap,
grid_size=self.config.grid_size,
output_dir=self.config.output_dir
)
grids, translations, grid_points = grid_divider.adjust_grid_size_and_overlap(
self.gps_points
)
grid_divider.visualize_grids(self.gps_points, grids)
if len(grids) >= 20:
self.logger.warning("网格数量已超过20, 需要人工调整分区")
return grid_points, translations
def copy_images(self, grid_points: Dict[tuple, pd.DataFrame]):
"""复制图像到目标文件夹"""
self.logger.info("开始复制图像文件")
for grid_id, points in grid_points.items():
output_dir = os.path.join(
self.config.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
"images"
)
os.makedirs(output_dir, exist_ok=True)
for point in points:
src = os.path.join(self.config.image_dir, point["file"])
dst = os.path.join(output_dir, point["file"])
shutil.copy(src, dst)
self.logger.info(
f"网格 ({grid_id[0]},{grid_id[1]}) 包含 {len(points)} 张图像")
def merge_tif(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool):
"""合并所有网格的影像产品"""
self.logger.info("开始合并所有影像产品")
merger = MergeTif(self.config.output_dir)
merger.merge_all_tifs(grid_points, produce_dem)
def merge_ply(self, grid_points: Dict[tuple, pd.DataFrame]):
"""合并所有网格的PLY点云"""
self.logger.info("开始合并PLY点云")
merger = MergePly(self.config.output_dir)
merger.merge_grid_laz(grid_points)
def merge_obj(self, grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]):
"""合并所有网格的OBJ模型"""
self.logger.info("开始合并OBJ模型")
merger = MergeObj(self.config.output_dir)
merger.merge_grid_obj(grid_points, translations)
def convert_obj(self, grid_points: Dict[tuple, pd.DataFrame]):
"""转换OBJ模型"""
self.logger.info("开始转换OBJ模型")
converter = ConvertOBJ(self.config.output_dir)
converter.convert_grid_obj(grid_points)
def post_process(self, successful_grid_points: Dict[tuple, pd.DataFrame], grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]):
"""后处理:合并或复制处理结果"""
if len(successful_grid_points) < len(grid_points):
self.logger.warning(
f"{len(grid_points) - len(successful_grid_points)} 个网格处理失败,"
f"将只合并成功处理的 {len(successful_grid_points)} 个网格"
)
if self.config.mode == "快拼模式":
self.merge_tif(successful_grid_points, self.config.produce_dem)
elif self.config.mode == "三维模式":
self.merge_tif(successful_grid_points, self.config.produce_dem)
# self.merge_ply(successful_grid_points)
# self.merge_obj(successful_grid_points, translations)
self.convert_obj(successful_grid_points)
else:
self.merge_tif(successful_grid_points, self.config.produce_dem)
# self.merge_ply(successful_grid_points)
# self.merge_obj(successful_grid_points, translations)
self.convert_obj(successful_grid_points)
def process(self):
"""执行完整的预处理流程"""
try:
self.extract_gps()
self.cluster()
# self.filter_isolated_points()
grid_points, translations = self.divide_grids()
# self.copy_images(grid_points)
# self.logger.info("预处理任务完成")
# successful_grid_points = self.odm_monitor.process_all_grids(
# grid_points, self.config.produce_dem, self.config.accuracy)
# successful_grid_points = self.odm_monitor.process_all_grids(
# grid_points, self.config.produce_dem)
successful_grid_points = grid_points
self.post_process(successful_grid_points,
grid_points, translations)
except Exception as e:
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
raise
if __name__ == "__main__":
# 创建配置
config = PreprocessConfig(
image_dir=r"E:\datasets\UAV\134\project\images",
output_dir=r"G:\ODM_output\134",
cluster_eps=0.01,
cluster_min_samples=5,
# 添加时间组重叠过滤参数
time_group_overlap_threshold=0.7,
time_group_interval=timedelta(minutes=5),
filter_distance_threshold=0.001,
filter_min_neighbors=6,
filter_grid_size=0.001,
filter_dense_distance_threshold=10,
filter_time_threshold=timedelta(minutes=5),
grid_size=800,
grid_overlap=0.05,
mode="重建模式",
produce_dem=False,
)
# 创建处理器并执行
processor = ImagePreprocessor(config)
processor.process()

View File

@@ -5,6 +5,13 @@ from typing import Dict
import pandas as pd
import time
import shutil
import rasterio
from rasterio.mask import mask
from rasterio.transform import Affine, rowcol
import fiona
from edt import edt
import numpy as np
import math


class MergeTif:
@@ -12,251 +19,271 @@ class MergeTif:
        self.output_dir = output_dir
        self.logger = logging.getLogger('UAV_Preprocess.MergeTif')

    def merge_two_tifs(self, input_tif1: str, input_tif2: str, output_tif: str):
        """合并两张TIF影像"""
try:
self.logger.info("开始合并TIF影像")
self.logger.info(f"输入影像1: {input_tif1}")
self.logger.info(f"输入影像2: {input_tif2}")
self.logger.info(f"输出影像: {output_tif}")
# 检查输入文件是否存在
if not os.path.exists(input_tif1) or not os.path.exists(input_tif2):
error_msg = "输入影像文件不存在"
self.logger.error(error_msg)
raise FileNotFoundError(error_msg)
# 打开影像,检查投影是否一致
datasets = []
try:
for tif in [input_tif1, input_tif2]:
ds = gdal.Open(tif)
if ds is None:
error_msg = f"无法打开影像文件: {tif}"
self.logger.error(error_msg)
raise ValueError(error_msg)
datasets.append(ds)
projections = [ds.GetProjection() for ds in datasets]
self.logger.debug(f"影像1投影: {projections[0]}")
self.logger.debug(f"影像2投影: {projections[1]}")
# 检查投影是否一致
if len(set(projections)) != 1:
error_msg = "影像的投影不一致,请先进行重投影!"
self.logger.error(error_msg)
raise ValueError(error_msg)
# 如果输出文件已存在,先删除
if os.path.exists(output_tif):
try:
os.remove(output_tif)
except Exception as e:
self.logger.warning(f"删除已存在的输出文件失败: {str(e)}")
# 生成一个新的输出文件名
base, ext = os.path.splitext(output_tif)
output_tif = f"{base}_{int(time.time())}{ext}"
self.logger.info(f"使用新的输出文件名: {output_tif}")
# 创建 GDAL Warp 选项
warp_options = gdal.WarpOptions(
format="GTiff",
resampleAlg="average",
srcNodata=0,
dstNodata=0,
multithread=True
)
self.logger.info("开始执行影像拼接...")
result = gdal.Warp(output_tif, datasets, options=warp_options)
if result is None:
error_msg = "影像拼接失败"
self.logger.error(error_msg)
raise RuntimeError(error_msg)
# 获取输出影像的基本信息
output_dataset = gdal.Open(output_tif)
if output_dataset:
width = output_dataset.RasterXSize
height = output_dataset.RasterYSize
bands = output_dataset.RasterCount
self.logger.info(
f"拼接完成,输出影像大小: {width}x{height},波段数: {bands}")
output_dataset = None # 显式关闭数据集
self.logger.info(f"影像拼接成功,输出文件保存至: {output_tif}")
finally:
# 确保所有数据集都被正确关闭
for ds in datasets:
if ds:
ds = None
result = None
except Exception as e:
self.logger.error(f"影像拼接过程中发生错误: {str(e)}", exc_info=True)
raise
def merge_grid_tif(self, grid_points: Dict[tuple, pd.DataFrame], product_info: dict):
"""合并指定产品的所有网格"""
product_name = product_info['name']
product_path = product_info['path']
filename_original = product_info['filename']
filename = filename_original.replace(".original", "")
self.logger.info(f"开始合并{product_name}")
input_tif1, input_tif2 = None, None
merge_count = 0
temp_files = []
        try:
            for grid_id, points in grid_points.items():
                grid_tif_original = os.path.join(
                    self.output_dir,
                    f"grid_{grid_id[0]}_{grid_id[1]}",
                    "project",
                    product_path,
                    filename_original
                )
                grid_tif = os.path.join(
                    self.output_dir,
                    f"grid_{grid_id[0]}_{grid_id[1]}",
                    "project",
                    product_path,
                    filename
                )
                if os.path.exists(grid_tif_original) and os.path.exists(grid_tif):
                    self.logger.info(
                        f"网格 ({grid_id[0]},{grid_id[1]}) 的{product_name}存在: {grid_tif_original, grid_tif}")
                    # 如果文件大于600MB,则不使用original文件
                    file_size_mb_original = os.path.getsize(
                        grid_tif_original) / (1024 * 1024)  # 转换为MB
                    if file_size_mb_original > 600:
                        to_merge_tif = grid_tif
                    else:
                        to_merge_tif = grid_tif_original
                elif os.path.exists(grid_tif_original) and not os.path.exists(grid_tif):
                    to_merge_tif = grid_tif_original
                elif not os.path.exists(grid_tif_original) and os.path.exists(grid_tif):
                    to_merge_tif = grid_tif
                else:
                    self.logger.warning(
                        f"网格 ({grid_id[0]},{grid_id[1]}) 的{product_name}不存在: {grid_tif_original, grid_tif}")
                    continue

                if input_tif1 is None:
                    input_tif1 = to_merge_tif
                    self.logger.info(f"设置第一个输入{product_name}: {input_tif1}")
                else:
                    input_tif2 = to_merge_tif
                    # 生成带时间戳的临时输出文件名
                    temp_output = os.path.join(
                        self.output_dir,
                        f"temp_merged_{int(time.time())}_{product_info['output']}"
                    )
                    self.logger.info(
                        f"开始合并{product_name}第 {merge_count + 1} 次:\n"
                        f"输入1: {input_tif1}\n"
                        f"输入2: {input_tif2}\n"
                        f"输出: {temp_output}"
                    )
                    self.merge_two_tifs(input_tif1, input_tif2, temp_output)
                    merge_count += 1

                    input_tif1 = temp_output
                    input_tif2 = None
                    temp_files.append(temp_output)

            final_output = os.path.join(
                self.output_dir, product_info['output'])
            shutil.copy2(input_tif1, final_output)

            # 清理所有临时文件
            for temp_file in temp_files:
                try:
                    os.remove(temp_file)
                except Exception as e:
                    self.logger.warning(f"删除临时文件失败: {str(e)}")

            self.logger.info(
                f"{product_name}合并完成,共执行 {merge_count} 次合并,"
                f"最终输出文件: {final_output}"
            )
        except Exception as e:
            self.logger.error(
                f"{product_name}合并过程中发生错误: {str(e)}", exc_info=True)
            raise

    def merge_all_tifs(self, grid_points: Dict[tuple, pd.DataFrame], mode: str):
        """合并所有产品(正射影像、DSM和DTM)"""
        try:
            products = [
                {
                    'name': '正射影像',
                    'path': 'odm_orthophoto',
                    'filename': 'odm_orthophoto.original.tif',
                    'output': 'orthophoto.tif'
                },
            ]
            if mode == '三维模式':
                products.append(
                    {
                        'name': 'DSM',
                        'path': 'odm_dem',
                        'filename': 'dsm.original.tif',
                        'output': 'dsm.tif'
                    }
                )
                products.append(
                    {
                        'name': 'DTM',
                        'path': 'odm_dem',
                        'filename': 'dtm.original.tif',
                        'output': 'dtm.tif'
                    }
                )
            for product in products:
                self.merge_grid_tif(grid_points, product)
            self.logger.info("所有产品合并完成")
        except Exception as e:
            self.logger.error(f"产品合并过程中发生错误: {str(e)}", exc_info=True)
            raise


if __name__ == "__main__":
    import sys
    sys.path.append(os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))))
    from utils.logger import setup_logger
    import pandas as pd

    # 设置输出目录和日志
    output_dir = r"G:\ODM_output\1009"
    setup_logger(output_dir)

    # 构造测试用的grid_points字典
    # 假设我们有两个网格,每个网格包含一些GPS点的DataFrame
    grid_points = {
        (0, 0): pd.DataFrame({
            'latitude': [39.9, 39.91],
            'longitude': [116.3, 116.31],
            'altitude': [100, 101]
        }),
        (0, 1): pd.DataFrame({
            'latitude': [39.92, 39.93],
            'longitude': [116.32, 116.33],
            'altitude': [102, 103]
        })
    }

    # 创建MergeTif实例并执行合并
    merge_tif = MergeTif(output_dir)
    merge_tif.merge_all_tifs(grid_points)

    def merge_orthophoto(self, grid_points: Dict[tuple, pd.DataFrame]):
        """合并网格的正射影像"""
        try:
            all_orthos_and_ortho_cuts = []
            for grid_id, points in grid_points.items():
                grid_ortho_dir = os.path.join(
                    self.output_dir,
                    f"grid_{grid_id[0]}_{grid_id[1]}",
                    "project",
                    "odm_orthophoto",
                )
                tif_path = os.path.join(grid_ortho_dir, "odm_orthophoto.tif")
                tif_mask = os.path.join(grid_ortho_dir, "cutline.gpkg")
                output_cut_tif = os.path.join(
                    grid_ortho_dir, "odm_orthophoto_cut.tif")
                output_feathered_tif = os.path.join(
                    grid_ortho_dir, "odm_orthophoto_feathered.tif")

                self.compute_mask_raster(
                    tif_path, tif_mask, output_cut_tif, blend_distance=20)
                self.feather_raster(
                    tif_path, output_feathered_tif, blend_distance=20)
                all_orthos_and_ortho_cuts.append(
                    [output_feathered_tif, output_cut_tif])

            orthophoto_vars = {
                'TILED': 'NO',
                'COMPRESS': False,
                'PREDICTOR': '1',
                'BIGTIFF': 'IF_SAFER',
                'BLOCKXSIZE': 512,
                'BLOCKYSIZE': 512,
                'NUM_THREADS': 15
            }
            self.merge(all_orthos_and_ortho_cuts, os.path.join(
                self.output_dir, "orthophoto.tif"), orthophoto_vars)

            self.logger.info("所有产品合并完成")
        except Exception as e:
            self.logger.error(f"产品合并过程中发生错误: {str(e)}", exc_info=True)
            raise

    def compute_mask_raster(self, input_raster, vector_mask, output_raster, blend_distance=20, only_max_coords_feature=False):
        if not os.path.exists(input_raster):
            print("Cannot mask raster, %s does not exist" % input_raster)
            return

        if not os.path.exists(vector_mask):
            print("Cannot mask raster, %s does not exist" % vector_mask)
            return

        print("Computing mask raster: %s" % output_raster)

        with rasterio.open(input_raster, 'r') as rast:
            with fiona.open(vector_mask) as src:
                burn_features = src

                if only_max_coords_feature:
                    max_coords_count = 0
                    max_coords_feature = None
for feature in src:
if feature is not None:
# No complex shapes
if len(feature['geometry']['coordinates'][0]) > max_coords_count:
max_coords_count = len(
feature['geometry']['coordinates'][0])
max_coords_feature = feature
if max_coords_feature is not None:
burn_features = [max_coords_feature]
shapes = [feature["geometry"] for feature in burn_features]
out_image, out_transform = mask(rast, shapes, nodata=0)
if blend_distance > 0:
if out_image.shape[0] >= 4:
# alpha_band = rast.dataset_mask()
alpha_band = out_image[-1]
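# edt() gives each alpha pixel its distance (in pixels) to the nearest nodata/border pixel;
# dividing distances below blend_distance by blend_distance and clamping the rest to 1
# turns the alpha band into a linear feather ramp along the cut edge.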
dist_t = edt(alpha_band, black_border=True, parallel=0)
dist_t[dist_t <= blend_distance] /= blend_distance
dist_t[dist_t > blend_distance] = 1
np.multiply(alpha_band, dist_t,
out=alpha_band, casting="unsafe")
else:
print(
"%s does not have an alpha band, cannot blend cutline!" % input_raster)
with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst:
dst.colorinterp = rast.colorinterp
dst.write(out_image)
return output_raster
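# Note: compute_mask_raster() clips the orthophoto to the cutline polygon and feathers the
# alpha along that cut edge, while feather_raster() below keeps the full extent and only
# feathers the outer image edge; merge() consumes the two as (feathered, cut) pairs per grid.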
def feather_raster(self, input_raster, output_raster, blend_distance=20):
if not os.path.exists(input_raster):
print("Cannot feather raster, %s does not exist" % input_raster)
return
print("Computing feather raster: %s" % output_raster)
with rasterio.open(input_raster, 'r') as rast:
out_image = rast.read()
if blend_distance > 0:
if out_image.shape[0] >= 4:
alpha_band = out_image[-1]
dist_t = edt(alpha_band, black_border=True, parallel=0)
dist_t[dist_t <= blend_distance] /= blend_distance
dist_t[dist_t > blend_distance] = 1
np.multiply(alpha_band, dist_t,
out=alpha_band, casting="unsafe")
else:
print(
"%s does not have an alpha band, cannot feather raster!" % input_raster)
with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst:
dst.colorinterp = rast.colorinterp
dst.write(out_image)
return output_raster
def merge(self, input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}):
"""
Based on https://github.com/mapbox/rio-merge-rgba/
Merge orthophotos around cutlines using a blend buffer.
"""
inputs = []
bounds = None
precision = 7
for o, c in input_ortho_and_ortho_cuts:
inputs.append((o, c))
with rasterio.open(inputs[0][0]) as first:
res = first.res
dtype = first.dtypes[0]
profile = first.profile
num_bands = first.meta['count'] - 1 # minus alpha
colorinterp = first.colorinterp
print("%s valid orthophoto rasters to merge" % len(inputs))
sources = [(rasterio.open(o), rasterio.open(c)) for o, c in inputs]
# scan input files.
# while we're at it, validate assumptions about inputs
xs = []
ys = []
for src, _ in sources:
left, bottom, right, top = src.bounds
xs.extend([left, right])
ys.extend([bottom, top])
if src.profile["count"] < 2:
raise ValueError("Inputs must be at least 2-band rasters")
dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys)
print("Output bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n))
output_transform = Affine.translation(dst_w, dst_n)
output_transform *= Affine.scale(res[0], -res[1])
# Compute output array shape. We guarantee it will cover the output
# bounds completely.
output_width = int(math.ceil((dst_e - dst_w) / res[0]))
output_height = int(math.ceil((dst_n - dst_s) / res[1]))
# Adjust bounds to fit.
dst_e, dst_s = output_transform * (output_width, output_height)
print("Output width: %d, height: %d" %
(output_width, output_height))
print("Adjusted bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n))
profile["transform"] = output_transform
profile["height"] = output_height
profile["width"] = output_width
profile["tiled"] = orthophoto_vars.get('TILED', 'YES') == 'YES'
profile["blockxsize"] = orthophoto_vars.get('BLOCKXSIZE', 512)
profile["blockysize"] = orthophoto_vars.get('BLOCKYSIZE', 512)
profile["compress"] = orthophoto_vars.get('COMPRESS', 'LZW')
profile["predictor"] = orthophoto_vars.get('PREDICTOR', '2')
profile["bigtiff"] = orthophoto_vars.get('BIGTIFF', 'IF_SAFER')
profile.update()
# create destination file
with rasterio.open(output_orthophoto, "w", **profile) as dstrast:
dstrast.colorinterp = colorinterp
for idx, dst_window in dstrast.block_windows():
left, bottom, right, top = dstrast.window_bounds(dst_window)
blocksize = dst_window.width
dst_rows, dst_cols = (dst_window.height, dst_window.width)
# initialize array destined for the block
dst_count = first.count
dst_shape = (dst_count, dst_rows, dst_cols)
dstarr = np.zeros(dst_shape, dtype=dtype)
# First pass, write all rasters naively without blending
for src, _ in sources:
src_window = tuple(zip(rowcol(
src.transform, left, top, op=round, precision=precision
), rowcol(
src.transform, right, bottom, op=round, precision=precision
)))
temp = np.zeros(dst_shape, dtype=dtype)
temp = src.read(
out=temp, window=src_window, boundless=True, masked=False
)
# pixels without data yet are available to write
write_region = np.logical_and(
(dstarr[-1] == 0), (temp[-1] != 0) # 0 is nodata
)
np.copyto(dstarr, temp, where=write_region)
# check if dest has any nodata pixels available
if np.count_nonzero(dstarr[-1]) == blocksize:
break
# Second pass, write all feathered rasters
# blending the edges
for src, _ in sources:
src_window = tuple(zip(rowcol(
src.transform, left, top, op=round, precision=precision
), rowcol(
src.transform, right, bottom, op=round, precision=precision
)))
temp = np.zeros(dst_shape, dtype=dtype)
temp = src.read(
out=temp, window=src_window, boundless=True, masked=False
)
where = temp[-1] != 0
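# Alpha-weighted blend: wherever the feathered source has data, mix source and
# destination bands using the ramped alpha (0-255) as the weight so seams fade in
# gradually over the blend distance.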
for b in range(0, num_bands):
blended = temp[-1] / 255.0 * temp[b] + \
(1 - temp[-1] / 255.0) * dstarr[b]
np.copyto(dstarr[b], blended,
casting='unsafe', where=where)
dstarr[-1][where] = 255.0
# check if dest has any nodata pixels available
if np.count_nonzero(dstarr[-1]) == blocksize:
break
# Third pass, write cut rasters
# blending the cutlines
for _, cut in sources:
src_window = tuple(zip(rowcol(
cut.transform, left, top, op=round, precision=precision
), rowcol(
cut.transform, right, bottom, op=round, precision=precision
)))
temp = np.zeros(dst_shape, dtype=dtype)
temp = cut.read(
out=temp, window=src_window, boundless=True, masked=False
)
# For each band, average alpha values between
# destination raster and cut raster
for b in range(0, num_bands):
blended = temp[-1] / 255.0 * temp[b] + \
(1 - temp[-1] / 255.0) * dstarr[b]
np.copyto(dstarr[b], blended,
casting='unsafe', where=temp[-1] != 0)
dstrast.write(dstarr, window=dst_window)
return output_orthophoto
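A minimal driver sketch for the new merge flow, assuming each grid_{x}_{y} directory under output_dir already holds a finished ODM run whose odm_orthophoto folder contains odm_orthophoto.tif plus the cutline.gpkg that ODM emits when --orthophoto-cutline is set (see the odm_monitor change below); the output path and grid ids here are placeholders, not values from this commit:

if __name__ == "__main__":
    import sys
    sys.path.append(os.path.dirname(
        os.path.dirname(os.path.abspath(__file__))))
    from utils.logger import setup_logger

    output_dir = r"G:\ODM_output\example"  # hypothetical output directory
    setup_logger(output_dir)

    # Only the grid ids are used to locate grid_{x}_{y}/project/odm_orthophoto;
    # the point DataFrames themselves are never read by merge_orthophoto.
    grid_points = {(0, 0): pd.DataFrame(), (0, 1): pd.DataFrame()}

    merger = MergeTif(output_dir)
    merger.merge_orthophoto(grid_points)  # cut + feather each grid, then blend-merge into orthophoto.tif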

View File

@@ -157,8 +157,8 @@ class ODMProcessMonitor:
             f"--use-exif "
             f"--use-hybrid-bundle-adjustment "
             f"--optimize-disk-space "
-            # f"--3d-tiles "
-            # f"--feature-type sift "
+            f"--orthophoto-cutline "
+            f"--feature-type sift "
             # f"--orthophoto-resolution 8 "
         )
         if accuracy == "high":
@@ -178,7 +178,8 @@ class ODMProcessMonitor:
         # 根据是否使用lowest quality添加参数
         if use_lowest_quality:
-            docker_command += f"--feature-quality lowest "
+            # docker_command += f"--feature-quality lowest "
+            pass
         if self.mode == "快拼模式":
             docker_command += (