diff --git a/odm_preprocess.py b/odm_preprocess.py index 74b0b7e..2f3762d 100644 --- a/odm_preprocess.py +++ b/odm_preprocess.py @@ -254,7 +254,7 @@ class ImagePreprocessor: """合并所有网格的影像产品""" self.logger.info("开始合并所有影像产品") merger = MergeTif(self.config.output_dir) - merger.merge_all_tifs(grid_points, mode) + merger.merge_orthophoto(grid_points) def merge_ply(self, grid_points: Dict[tuple, pd.DataFrame]): """合并所有网格的PLY点云""" diff --git a/odm_preprocess_fast.py b/odm_preprocess_fast.py deleted file mode 100644 index 11ad0aa..0000000 --- a/odm_preprocess_fast.py +++ /dev/null @@ -1,352 +0,0 @@ -import os -import shutil -from datetime import timedelta -from dataclasses import dataclass -from typing import Dict, Tuple -import psutil -import pandas as pd - -from filter.cluster_filter import GPSCluster -from filter.time_group_overlap_filter import TimeGroupOverlapFilter -from filter.gps_filter import GPSFilter -from utils.odm_monitor import ODMProcessMonitor -from utils.gps_extractor import GPSExtractor -from utils.grid_divider import GridDivider -from utils.logger import setup_logger -from utils.visualizer import FilterVisualizer -from post_pro.merge_tif import MergeTif -from post_pro.merge_obj import MergeObj -from post_pro.merge_laz import MergePly -from post_pro.conv_obj2 import ConvertOBJ - - -@dataclass -class PreprocessConfig: - """预处理配置类""" - - image_dir: str - output_dir: str - # 聚类过滤参数 - cluster_eps: float = 0.01 - cluster_min_samples: int = 5 - # 时间组重叠过滤参数 - time_group_overlap_threshold: float = 0.7 - time_group_interval: timedelta = timedelta(minutes=5) - # 孤立点过滤参数 - filter_distance_threshold: float = 0.001 # 经纬度距离 - filter_min_neighbors: int = 6 - # 密集点过滤参数 - filter_grid_size: float = 0.001 - filter_dense_distance_threshold: float = 10 # 普通距离,单位:米 - filter_time_threshold: timedelta = timedelta(minutes=5) - # 网格划分参数 - grid_overlap: float = 0.05 - grid_size: float = 500 - # 几个pipline过程是否开启 - mode: str = "快拼模式" - accuracy: str = "medium" - produce_dem: bool = False - - -class ImagePreprocessor: - def __init__(self, config: PreprocessConfig): - self.config = config - - # 检查磁盘空间 - self._check_disk_space() - - # # 清理并重建输出目录 - # if os.path.exists(config.output_dir): - # self._clean_output_dir() - # self._setup_output_dirs() - - # 初始化其他组件 - self.logger = setup_logger(config.output_dir) - self.gps_points = None - self.odm_monitor = ODMProcessMonitor( - config.output_dir, mode=config.mode) - self.visualizer = FilterVisualizer(config.output_dir) - - def _clean_output_dir(self): - """清理输出目录""" - try: - shutil.rmtree(self.config.output_dir) - print(f"已清理输出目录: {self.config.output_dir}") - except Exception as e: - print(f"清理输出目录时发生错误: {str(e)}") - raise - - def _setup_output_dirs(self): - """创建必要的输出目录结构""" - try: - # 创建主输出目录 - os.makedirs(self.config.output_dir) - - # 创建过滤图像保存目录 - os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs')) - - # 创建日志目录 - os.makedirs(os.path.join(self.config.output_dir, 'logs')) - - print(f"已创建输出目录结构: {self.config.output_dir}") - except Exception as e: - print(f"创建输出目录时发生错误: {str(e)}") - raise - - def _get_directory_size(self, path): - """获取目录的总大小(字节)""" - total_size = 0 - for dirpath, dirnames, filenames in os.walk(path): - for filename in filenames: - file_path = os.path.join(dirpath, filename) - try: - total_size += os.path.getsize(file_path) - except (OSError, FileNotFoundError): - continue - return total_size - - def _check_disk_space(self): - """检查磁盘空间是否足够""" - # 获取输入目录大小 - input_size = self._get_directory_size(self.config.image_dir) - - # 获取输出目录所在磁盘的剩余空间 - output_drive = os.path.splitdrive( - os.path.abspath(self.config.output_dir))[0] - if not output_drive: # 处理Linux/Unix路径 - output_drive = '/home' - - disk_usage = psutil.disk_usage(output_drive) - free_space = disk_usage.free - - # 计算所需空间(输入大小的1.5倍) - required_space = input_size * 12 - - if free_space < required_space: - error_msg = ( - f"磁盘空间不足!\n" - f"输入目录大小: {input_size / (1024**3):.2f} GB\n" - f"所需空间: {required_space / (1024**3):.2f} GB\n" - f"可用空间: {free_space / (1024**3):.2f} GB\n" - f"在驱动器 {output_drive}" - ) - raise RuntimeError(error_msg) - - def extract_gps(self) -> pd.DataFrame: - """提取GPS数据""" - self.logger.info("开始提取GPS数据") - extractor = GPSExtractor(self.config.image_dir) - self.gps_points = extractor.extract_all_gps() - self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点") - - def cluster(self): - """使用DBSCAN对GPS点进行聚类,只保留最大的类""" - previous_points = self.gps_points.copy() - clusterer = GPSCluster( - self.gps_points, - eps=self.config.cluster_eps, - min_samples=self.config.cluster_min_samples - ) - - self.clustered_points = clusterer.fit() - self.gps_points = clusterer.get_cluster_stats(self.clustered_points) - - self.visualizer.visualize_filter_step( - self.gps_points, previous_points, "1-Clustering") - - def filter_isolated_points(self): - """过滤孤立点""" - filter = GPSFilter(self.config.output_dir) - previous_points = self.gps_points.copy() - - self.gps_points = filter.filter_isolated_points( - self.gps_points, - self.config.filter_distance_threshold, - self.config.filter_min_neighbors, - ) - - self.visualizer.visualize_filter_step( - self.gps_points, previous_points, "2-Isolated Points") - - def filter_time_group_overlap(self): - """过滤重叠的时间组""" - previous_points = self.gps_points.copy() - - filter = TimeGroupOverlapFilter( - self.config.image_dir, - self.config.output_dir, - overlap_threshold=self.config.time_group_overlap_threshold - ) - - self.gps_points = filter.filter_overlapping_groups( - self.gps_points, - time_threshold=self.config.time_group_interval - ) - - self.visualizer.visualize_filter_step( - self.gps_points, previous_points, "3-Time Group Overlap") - - def calculate_center_coordinates(self): - """计算剩余点的中心经纬度坐标""" - mean_lat = self.gps_points['lat'].mean() - mean_lon = self.gps_points['lon'].mean() - self.logger.info(f"区域中心坐标:纬度 {mean_lat:.6f}, 经度 {mean_lon:.6f}") - return mean_lat, mean_lon - - def filter_alternate_images(self): - """按时间顺序隔一个删一个图像来降低密度""" - previous_points = self.gps_points.copy() - - # 按时间戳排序 - self.gps_points = self.gps_points.sort_values('date') - - # 保留索引为偶数的行(即隔一个保留一个) - self.gps_points = self.gps_points.iloc[::2].reset_index(drop=True) - - self.visualizer.visualize_filter_step( - self.gps_points, previous_points, "4-Alternate Images") - - self.logger.info(f"交替过滤后剩余 {len(self.gps_points)} 个点") - - def divide_grids(self) -> Tuple[Dict[tuple, pd.DataFrame], Dict[tuple, tuple]]: - """划分网格 - Returns: - tuple: (grid_points, translations) - - grid_points: 网格点数据字典 - - translations: 网格平移量字典 - """ - grid_divider = GridDivider( - overlap=self.config.grid_overlap, - grid_size=self.config.grid_size, - output_dir=self.config.output_dir - ) - grids, translations, grid_points = grid_divider.adjust_grid_size_and_overlap( - self.gps_points - ) - grid_divider.visualize_grids(self.gps_points, grids) - if len(grids) >= 20: - self.logger.warning("网格数量已超过20, 需要人工调整分区") - - return grid_points, translations - - def copy_images(self, grid_points: Dict[tuple, pd.DataFrame]): - """复制图像到目标文件夹""" - self.logger.info("开始复制图像文件") - - for grid_id, points in grid_points.items(): - output_dir = os.path.join( - self.config.output_dir, - f"grid_{grid_id[0]}_{grid_id[1]}", - "project", - "images" - ) - - os.makedirs(output_dir, exist_ok=True) - - for point in points: - src = os.path.join(self.config.image_dir, point["file"]) - dst = os.path.join(output_dir, point["file"]) - shutil.copy(src, dst) - self.logger.info( - f"网格 ({grid_id[0]},{grid_id[1]}) 包含 {len(points)} 张图像") - - def merge_tif(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool): - """合并所有网格的影像产品""" - self.logger.info("开始合并所有影像产品") - merger = MergeTif(self.config.output_dir) - merger.merge_all_tifs(grid_points, produce_dem) - - def merge_ply(self, grid_points: Dict[tuple, pd.DataFrame]): - """合并所有网格的PLY点云""" - self.logger.info("开始合并PLY点云") - merger = MergePly(self.config.output_dir) - merger.merge_grid_laz(grid_points) - - def merge_obj(self, grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]): - """合并所有网格的OBJ模型""" - self.logger.info("开始合并OBJ模型") - merger = MergeObj(self.config.output_dir) - merger.merge_grid_obj(grid_points, translations) - - def convert_obj(self, grid_points: Dict[tuple, pd.DataFrame]): - """转换OBJ模型""" - self.logger.info("开始转换OBJ模型") - converter = ConvertOBJ(self.config.output_dir) - converter.convert_grid_obj(grid_points) - - def post_process(self, successful_grid_points: Dict[tuple, pd.DataFrame], grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]): - """后处理:合并或复制处理结果""" - if len(successful_grid_points) < len(grid_points): - self.logger.warning( - f"有 {len(grid_points) - len(successful_grid_points)} 个网格处理失败," - f"将只合并成功处理的 {len(successful_grid_points)} 个网格" - ) - - if self.config.mode == "快拼模式": - self.merge_tif(successful_grid_points, self.config.produce_dem) - elif self.config.mode == "三维模式": - self.merge_tif(successful_grid_points, self.config.produce_dem) - # self.merge_ply(successful_grid_points) - # self.merge_obj(successful_grid_points, translations) - self.convert_obj(successful_grid_points) - else: - self.merge_tif(successful_grid_points, self.config.produce_dem) - # self.merge_ply(successful_grid_points) - # self.merge_obj(successful_grid_points, translations) - self.convert_obj(successful_grid_points) - - def process(self): - """执行完整的预处理流程""" - try: - self.extract_gps() - self.cluster() - # self.filter_isolated_points() - grid_points, translations = self.divide_grids() - # self.copy_images(grid_points) - # self.logger.info("预处理任务完成") - - # successful_grid_points = self.odm_monitor.process_all_grids( - # grid_points, self.config.produce_dem, self.config.accuracy) - - # successful_grid_points = self.odm_monitor.process_all_grids( - # grid_points, self.config.produce_dem) - successful_grid_points = grid_points - self.post_process(successful_grid_points, - grid_points, translations) - - except Exception as e: - self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) - raise - - -if __name__ == "__main__": - # 创建配置 - config = PreprocessConfig( - image_dir=r"E:\datasets\UAV\134\project\images", - output_dir=r"G:\ODM_output\134", - - cluster_eps=0.01, - cluster_min_samples=5, - - # 添加时间组重叠过滤参数 - time_group_overlap_threshold=0.7, - time_group_interval=timedelta(minutes=5), - - filter_distance_threshold=0.001, - filter_min_neighbors=6, - - filter_grid_size=0.001, - filter_dense_distance_threshold=10, - filter_time_threshold=timedelta(minutes=5), - - grid_size=800, - grid_overlap=0.05, - - - mode="重建模式", - produce_dem=False, - ) - - # 创建处理器并执行 - processor = ImagePreprocessor(config) - processor.process() diff --git a/post_pro/merge_tif.py b/post_pro/merge_tif.py index 613e535..6e004ef 100644 --- a/post_pro/merge_tif.py +++ b/post_pro/merge_tif.py @@ -5,6 +5,13 @@ from typing import Dict import pandas as pd import time import shutil +import rasterio +from rasterio.mask import mask +from rasterio.transform import Affine, rowcol +import fiona +from edt import edt +import numpy as np +import math class MergeTif: @@ -12,251 +19,271 @@ class MergeTif: self.output_dir = output_dir self.logger = logging.getLogger('UAV_Preprocess.MergeTif') - def merge_two_tifs(self, input_tif1: str, input_tif2: str, output_tif: str): - """合并两张TIF影像""" - try: - self.logger.info("开始合并TIF影像") - self.logger.info(f"输入影像1: {input_tif1}") - self.logger.info(f"输入影像2: {input_tif2}") - self.logger.info(f"输出影像: {output_tif}") - - # 检查输入文件是否存在 - if not os.path.exists(input_tif1) or not os.path.exists(input_tif2): - error_msg = "输入影像文件不存在" - self.logger.error(error_msg) - raise FileNotFoundError(error_msg) - - # 打开影像,检查投影是否一致 - datasets = [] - try: - for tif in [input_tif1, input_tif2]: - ds = gdal.Open(tif) - if ds is None: - error_msg = f"无法打开影像文件: {tif}" - self.logger.error(error_msg) - raise ValueError(error_msg) - datasets.append(ds) - - projections = [ds.GetProjection() for ds in datasets] - self.logger.debug(f"影像1投影: {projections[0]}") - self.logger.debug(f"影像2投影: {projections[1]}") - - # 检查投影是否一致 - if len(set(projections)) != 1: - error_msg = "影像的投影不一致,请先进行重投影!" - self.logger.error(error_msg) - raise ValueError(error_msg) - - # 如果输出文件已存在,先删除 - if os.path.exists(output_tif): - try: - os.remove(output_tif) - except Exception as e: - self.logger.warning(f"删除已存在的输出文件失败: {str(e)}") - # 生成一个新的输出文件名 - base, ext = os.path.splitext(output_tif) - output_tif = f"{base}_{int(time.time())}{ext}" - self.logger.info(f"使用新的输出文件名: {output_tif}") - - # 创建 GDAL Warp 选项 - warp_options = gdal.WarpOptions( - format="GTiff", - resampleAlg="average", - srcNodata=0, - dstNodata=0, - multithread=True - ) - - self.logger.info("开始执行影像拼接...") - result = gdal.Warp(output_tif, datasets, options=warp_options) - - if result is None: - error_msg = "影像拼接失败" - self.logger.error(error_msg) - raise RuntimeError(error_msg) - - # 获取输出影像的基本信息 - output_dataset = gdal.Open(output_tif) - if output_dataset: - width = output_dataset.RasterXSize - height = output_dataset.RasterYSize - bands = output_dataset.RasterCount - self.logger.info( - f"拼接完成,输出影像大小: {width}x{height},波段数: {bands}") - output_dataset = None # 显式关闭数据集 - - self.logger.info(f"影像拼接成功,输出文件保存至: {output_tif}") - - finally: - # 确保所有数据集都被正确关闭 - for ds in datasets: - if ds: - ds = None - result = None - - except Exception as e: - self.logger.error(f"影像拼接过程中发生错误: {str(e)}", exc_info=True) - raise - - def merge_grid_tif(self, grid_points: Dict[tuple, pd.DataFrame], product_info: dict): - """合并指定产品的所有网格""" - product_name = product_info['name'] - product_path = product_info['path'] - filename_original = product_info['filename'] - filename = filename_original.replace(".original", "") - - self.logger.info(f"开始合并{product_name}") - - input_tif1, input_tif2 = None, None - merge_count = 0 - temp_files = [] - + def merge_orthophoto(self, grid_points: Dict[tuple, pd.DataFrame]): + """合并网格的正射影像""" try: + all_orthos_and_ortho_cuts = [] for grid_id, points in grid_points.items(): - grid_tif_original = os.path.join( + grid_ortho_dir = os.path.join( self.output_dir, f"grid_{grid_id[0]}_{grid_id[1]}", "project", - product_path, - filename_original + "odm_orthophoto", ) - grid_tif = os.path.join( - self.output_dir, - f"grid_{grid_id[0]}_{grid_id[1]}", - "project", - product_path, - filename - ) - if os.path.exists(grid_tif_original) and os.path.exists(grid_tif): - self.logger.info( - f"网格 ({grid_id[0]},{grid_id[1]}) 的{product_name}存在: {grid_tif_original, grid_tif}") - # 如果文件大于600MB,则不使用original文件 - file_size_mb_original = os.path.getsize( - grid_tif_original) / (1024 * 1024) # 转换为MB - if file_size_mb_original > 600: - to_merge_tif = grid_tif - else: - to_merge_tif = grid_tif_original - elif os.path.exists(grid_tif_original) and not os.path.exists(grid_tif): - to_merge_tif = grid_tif_original - elif not os.path.exists(grid_tif_original) and os.path.exists(grid_tif): - to_merge_tif = grid_tif - else: - self.logger.warning( - f"网格 ({grid_id[0]},{grid_id[1]}) 的{product_name}不存在: {grid_tif_original, grid_tif}") - continue + tif_path = os.path.join(grid_ortho_dir, "odm_orthophoto.tif") + tif_mask = os.path.join(grid_ortho_dir, "cutline.gpkg") + output_cut_tif = os.path.join( + grid_ortho_dir, "odm_orthophoto_cut.tif") + output_feathered_tif = os.path.join( + grid_ortho_dir, "odm_orthophoto_feathered.tif") - if input_tif1 is None: - input_tif1 = to_merge_tif - self.logger.info(f"设置第一个输入{product_name}: {input_tif1}") - else: - input_tif2 = to_merge_tif - # 生成带时间戳的临时输出文件名 - temp_output = os.path.join( - self.output_dir, - f"temp_merged_{int(time.time())}_{product_info['output']}" - ) + self.compute_mask_raster( + tif_path, tif_mask, output_cut_tif, blend_distance=20) + self.feather_raster( + tif_path, output_feathered_tif, blend_distance=20) + all_orthos_and_ortho_cuts.append( + [output_feathered_tif, output_cut_tif]) - self.logger.info( - f"开始合并{product_name}第 {merge_count + 1} 次:\n" - f"输入1: {input_tif1}\n" - f"输入2: {input_tif2}\n" - f"输出: {temp_output}" - ) - - self.merge_two_tifs(input_tif1, input_tif2, temp_output) - merge_count += 1 - - input_tif1 = temp_output - input_tif2 = None - temp_files.append(temp_output) - - final_output = os.path.join( - self.output_dir, product_info['output']) - shutil.copy2(input_tif1, final_output) - - # 清理所有临时文件 - for temp_file in temp_files: - try: - os.remove(temp_file) - except Exception as e: - self.logger.warning(f"删除临时文件失败: {str(e)}") - - self.logger.info( - f"{product_name}合并完成,共执行 {merge_count} 次合并," - f"最终输出文件: {final_output}" - ) - - except Exception as e: - self.logger.error( - f"{product_name}合并过程中发生错误: {str(e)}", exc_info=True) - raise - - def merge_all_tifs(self, grid_points: Dict[tuple, pd.DataFrame], mode: str): - """合并所有产品(正射影像、DSM和DTM)""" - try: - products = [ - { - 'name': '正射影像', - 'path': 'odm_orthophoto', - 'filename': 'odm_orthophoto.original.tif', - 'output': 'orthophoto.tif' - }, - - ] - if mode == '三维模式': - products.append( - { - 'name': 'DSM', - 'path': 'odm_dem', - 'filename': 'dsm.original.tif', - 'output': 'dsm.tif' - } - ) - products.append( - { - 'name': 'DTM', - 'path': 'odm_dem', - 'filename': 'dtm.original.tif', - 'output': 'dtm.tif' - } - ) - - for product in products: - self.merge_grid_tif(grid_points, product) + orthophoto_vars = { + 'TILED': 'NO', + 'COMPRESS': False, + 'PREDICTOR': '1', + 'BIGTIFF': 'IF_SAFER', + 'BLOCKXSIZE': 512, + 'BLOCKYSIZE': 512, + 'NUM_THREADS': 15 + } + self.merge(all_orthos_and_ortho_cuts, os.path.join( + self.output_dir, "orthophoto.tif"), orthophoto_vars) self.logger.info("所有产品合并完成") except Exception as e: self.logger.error(f"产品合并过程中发生错误: {str(e)}", exc_info=True) raise + def compute_mask_raster(self, input_raster, vector_mask, output_raster, blend_distance=20, only_max_coords_feature=False): + if not os.path.exists(input_raster): + print("Cannot mask raster, %s does not exist" % input_raster) + return -if __name__ == "__main__": - import sys - sys.path.append(os.path.dirname( - os.path.dirname(os.path.abspath(__file__)))) - from utils.logger import setup_logger - import pandas as pd + if not os.path.exists(vector_mask): + print("Cannot mask raster, %s does not exist" % vector_mask) + return - # 设置输出目录和日志 - output_dir = r"G:\ODM_output\1009" - setup_logger(output_dir) + print("Computing mask raster: %s" % output_raster) - # 构造测试用的grid_points字典 - # 假设我们有两个网格,每个网格包含一些GPS点的DataFrame - grid_points = { - (0, 0): pd.DataFrame({ - 'latitude': [39.9, 39.91], - 'longitude': [116.3, 116.31], - 'altitude': [100, 101] - }), - (0, 1): pd.DataFrame({ - 'latitude': [39.92, 39.93], - 'longitude': [116.32, 116.33], - 'altitude': [102, 103] - }) - } + with rasterio.open(input_raster, 'r') as rast: + with fiona.open(vector_mask) as src: + burn_features = src - # 创建MergeTif实例并执行合并 - merge_tif = MergeTif(output_dir) - merge_tif.merge_all_tifs(grid_points) + if only_max_coords_feature: + max_coords_count = 0 + max_coords_feature = None + for feature in src: + if feature is not None: + # No complex shapes + if len(feature['geometry']['coordinates'][0]) > max_coords_count: + max_coords_count = len( + feature['geometry']['coordinates'][0]) + max_coords_feature = feature + if max_coords_feature is not None: + burn_features = [max_coords_feature] + + shapes = [feature["geometry"] for feature in burn_features] + out_image, out_transform = mask(rast, shapes, nodata=0) + + if blend_distance > 0: + if out_image.shape[0] >= 4: + # alpha_band = rast.dataset_mask() + alpha_band = out_image[-1] + dist_t = edt(alpha_band, black_border=True, parallel=0) + dist_t[dist_t <= blend_distance] /= blend_distance + dist_t[dist_t > blend_distance] = 1 + np.multiply(alpha_band, dist_t, + out=alpha_band, casting="unsafe") + else: + print( + "%s does not have an alpha band, cannot blend cutline!" % input_raster) + + with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst: + dst.colorinterp = rast.colorinterp + dst.write(out_image) + + return output_raster + + def feather_raster(self, input_raster, output_raster, blend_distance=20): + if not os.path.exists(input_raster): + print("Cannot feather raster, %s does not exist" % input_raster) + return + + print("Computing feather raster: %s" % output_raster) + + with rasterio.open(input_raster, 'r') as rast: + out_image = rast.read() + if blend_distance > 0: + if out_image.shape[0] >= 4: + alpha_band = out_image[-1] + dist_t = edt(alpha_band, black_border=True, parallel=0) + dist_t[dist_t <= blend_distance] /= blend_distance + dist_t[dist_t > blend_distance] = 1 + np.multiply(alpha_band, dist_t, + out=alpha_band, casting="unsafe") + else: + print( + "%s does not have an alpha band, cannot feather raster!" % input_raster) + + with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst: + dst.colorinterp = rast.colorinterp + dst.write(out_image) + + return output_raster + + def merge(self, input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}): + """ + Based on https://github.com/mapbox/rio-merge-rgba/ + Merge orthophotos around cutlines using a blend buffer. + """ + inputs = [] + bounds = None + precision = 7 + + for o, c in input_ortho_and_ortho_cuts: + inputs.append((o, c)) + + with rasterio.open(inputs[0][0]) as first: + res = first.res + dtype = first.dtypes[0] + profile = first.profile + num_bands = first.meta['count'] - 1 # minus alpha + colorinterp = first.colorinterp + + print("%s valid orthophoto rasters to merge" % len(inputs)) + sources = [(rasterio.open(o), rasterio.open(c)) for o, c in inputs] + + # scan input files. + # while we're at it, validate assumptions about inputs + xs = [] + ys = [] + for src, _ in sources: + left, bottom, right, top = src.bounds + xs.extend([left, right]) + ys.extend([bottom, top]) + if src.profile["count"] < 2: + raise ValueError("Inputs must be at least 2-band rasters") + dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys) + print("Output bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n)) + + output_transform = Affine.translation(dst_w, dst_n) + output_transform *= Affine.scale(res[0], -res[1]) + + # Compute output array shape. We guarantee it will cover the output + # bounds completely. + output_width = int(math.ceil((dst_e - dst_w) / res[0])) + output_height = int(math.ceil((dst_n - dst_s) / res[1])) + + # Adjust bounds to fit. + dst_e, dst_s = output_transform * (output_width, output_height) + print("Output width: %d, height: %d" % + (output_width, output_height)) + print("Adjusted bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n)) + + profile["transform"] = output_transform + profile["height"] = output_height + profile["width"] = output_width + profile["tiled"] = orthophoto_vars.get('TILED', 'YES') == 'YES' + profile["blockxsize"] = orthophoto_vars.get('BLOCKXSIZE', 512) + profile["blockysize"] = orthophoto_vars.get('BLOCKYSIZE', 512) + profile["compress"] = orthophoto_vars.get('COMPRESS', 'LZW') + profile["predictor"] = orthophoto_vars.get('PREDICTOR', '2') + profile["bigtiff"] = orthophoto_vars.get('BIGTIFF', 'IF_SAFER') + profile.update() + + # create destination file + with rasterio.open(output_orthophoto, "w", **profile) as dstrast: + dstrast.colorinterp = colorinterp + for idx, dst_window in dstrast.block_windows(): + left, bottom, right, top = dstrast.window_bounds(dst_window) + + blocksize = dst_window.width + dst_rows, dst_cols = (dst_window.height, dst_window.width) + + # initialize array destined for the block + dst_count = first.count + dst_shape = (dst_count, dst_rows, dst_cols) + + dstarr = np.zeros(dst_shape, dtype=dtype) + + # First pass, write all rasters naively without blending + for src, _ in sources: + src_window = tuple(zip(rowcol( + src.transform, left, top, op=round, precision=precision + ), rowcol( + src.transform, right, bottom, op=round, precision=precision + ))) + + temp = np.zeros(dst_shape, dtype=dtype) + temp = src.read( + out=temp, window=src_window, boundless=True, masked=False + ) + + # pixels without data yet are available to write + write_region = np.logical_and( + (dstarr[-1] == 0), (temp[-1] != 0) # 0 is nodata + ) + np.copyto(dstarr, temp, where=write_region) + + # check if dest has any nodata pixels available + if np.count_nonzero(dstarr[-1]) == blocksize: + break + + # Second pass, write all feathered rasters + # blending the edges + for src, _ in sources: + src_window = tuple(zip(rowcol( + src.transform, left, top, op=round, precision=precision + ), rowcol( + src.transform, right, bottom, op=round, precision=precision + ))) + + temp = np.zeros(dst_shape, dtype=dtype) + temp = src.read( + out=temp, window=src_window, boundless=True, masked=False + ) + + where = temp[-1] != 0 + for b in range(0, num_bands): + blended = temp[-1] / 255.0 * temp[b] + \ + (1 - temp[-1] / 255.0) * dstarr[b] + np.copyto(dstarr[b], blended, + casting='unsafe', where=where) + dstarr[-1][where] = 255.0 + + # check if dest has any nodata pixels available + if np.count_nonzero(dstarr[-1]) == blocksize: + break + + # Third pass, write cut rasters + # blending the cutlines + for _, cut in sources: + src_window = tuple(zip(rowcol( + cut.transform, left, top, op=round, precision=precision + ), rowcol( + cut.transform, right, bottom, op=round, precision=precision + ))) + + temp = np.zeros(dst_shape, dtype=dtype) + temp = cut.read( + out=temp, window=src_window, boundless=True, masked=False + ) + + # For each band, average alpha values between + # destination raster and cut raster + for b in range(0, num_bands): + blended = temp[-1] / 255.0 * temp[b] + \ + (1 - temp[-1] / 255.0) * dstarr[b] + np.copyto(dstarr[b], blended, + casting='unsafe', where=temp[-1] != 0) + + dstrast.write(dstarr, window=dst_window) + + return output_orthophoto diff --git a/utils/odm_monitor.py b/utils/odm_monitor.py index 8a5cf2a..17d0e2e 100644 --- a/utils/odm_monitor.py +++ b/utils/odm_monitor.py @@ -157,8 +157,8 @@ class ODMProcessMonitor: f"--use-exif " f"--use-hybrid-bundle-adjustment " f"--optimize-disk-space " - # f"--3d-tiles " - # f"--feature-type sift " + f"--orthophoto-cutline " + f"--feature-type sift " # f"--orthophoto-resolution 8 " ) if accuracy == "high": @@ -178,7 +178,8 @@ class ODMProcessMonitor: # 根据是否使用lowest quality添加参数 if use_lowest_quality: - docker_command += f"--feature-quality lowest " + # docker_command += f"--feature-quality lowest " + pass if self.mode == "快拼模式": docker_command += (