first commit

This commit is contained in:
weixin_46229132 2025-06-12 22:11:27 +08:00
commit e8178d191e
18 changed files with 1982 additions and 0 deletions

7
.gitignore vendored Normal file

@ -0,0 +1,7 @@
# Ignore all __pycache__ directories
**/__pycache__/
*.pyc
*.pyo
*.pyd
test/

15
README.md Normal file

@ -0,0 +1,15 @@
# ODM_Pro
UAV 3D reconstruction.
## Install
```bash
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```
## Building the exe
1. In odm_monitor.py, remember to change the Docker container name to uav.
2. Removing the compressed parameter from osgconv causes an error.
3. Build command: `pyinstaller main.spec`
4. Test command: `uav.exe --image_dir E:\datasets\UAV\134\project\images --output_dir G:\ODM_output\134`

163
app_plugin.py Normal file

@ -0,0 +1,163 @@
import os
import shutil
from dataclasses import dataclass
from typing import Dict, Tuple
import pandas as pd
from filter.cluster_filter import GPSCluster
from utils.directory_manager import DirectoryManager
from utils.odm_monitor import ODMProcessMonitor
from utils.gps_extractor import GPSExtractor
from utils.grid_divider import GridDivider
from utils.logger import setup_logger
from utils.visualizer import FilterVisualizer
from post_pro.merge_tif import MergeTif
from post_pro.conv_obj2 import ConvertOBJ
@dataclass
class ProcessConfig:
"""预处理配置类"""
image_dir: str
output_dir: str
# 聚类过滤参数
cluster_eps: float = 0.01
cluster_min_samples: int = 5
# 网格划分参数
grid_overlap: float = 0.05
grid_size: float = 500
mode: str = "三维模式"
# ODM参数
feature_type: str = "sift"
orthophoto_resolution: float = 5
class ODM_Plugin:
def __init__(self, config):
self.config = config
# 初始化目录管理器
self.dir_manager = DirectoryManager(config)
# 清理并重建输出目录
self.dir_manager.clean_output_dir()
self.dir_manager.setup_output_dirs()
# 检查磁盘空间
self.dir_manager.check_disk_space()
# 初始化其他组件
self.logger = setup_logger(config.output_dir)
self.gps_points = pd.DataFrame(columns=["file", "lat", "lon"])
self.odm_monitor = ODMProcessMonitor(
config.output_dir, mode=config.mode, config=config)
self.visualizer = FilterVisualizer(config.output_dir)
    def extract_gps(self):
"""提取GPS数据"""
self.logger.info("开始提取GPS数据")
extractor = GPSExtractor(self.config.image_dir)
self.gps_points = extractor.extract_all_gps()
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
def cluster(self):
"""使用DBSCAN对GPS点进行聚类只保留最大的类"""
previous_points = self.gps_points.copy()
clusterer = GPSCluster(
self.gps_points,
eps=self.config.cluster_eps,
min_samples=self.config.cluster_min_samples
)
self.clustered_points = clusterer.fit()
self.gps_points = clusterer.get_cluster_stats(self.clustered_points)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "1-Clustering")
    def divide_grids(self) -> Dict[tuple, list]:
        """Divide the survey area into grids.
        Returns:
            grid_points: dict mapping each grid id (i, j) to the list of GPS points assigned to it
        """
grid_divider = GridDivider(
overlap=self.config.grid_overlap,
grid_size=self.config.grid_size,
output_dir=self.config.output_dir
)
grids, grid_points = grid_divider.adjust_grid_size_and_overlap(
self.gps_points
)
grid_divider.visualize_grids(self.gps_points, grids)
if len(grids) >= 20:
self.logger.warning("网格数量已超过20, 需要人工调整分区")
return grid_points
    def copy_images(self, grid_points: Dict[tuple, list]):
        """Copy each grid's images into its project/images directory"""
self.logger.info("开始复制图像文件")
for grid_id, points in grid_points.items():
output_dir = os.path.join(
self.config.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
"images"
)
os.makedirs(output_dir, exist_ok=True)
for point in points:
src = os.path.join(self.config.image_dir, point["file"])
dst = os.path.join(output_dir, point["file"])
shutil.copy(src, dst)
self.logger.info(
f"网格 ({grid_id[0]},{grid_id[1]}) 包含 {len(points)} 张图像")
def merge_tif(self, grid_lt):
"""合并所有网格的影像产品"""
self.logger.info("开始合并所有影像产品")
merger = MergeTif(self.config.output_dir)
merger.merge_orthophoto(grid_lt)
def convert_obj(self, grid_lt):
"""转换OBJ模型"""
self.logger.info("开始转换OBJ模型")
converter = ConvertOBJ(self.config.output_dir)
converter.convert_grid_obj(grid_lt)
    def post_process(self, successful_grid_lt: list, grid_points: Dict[tuple, list]):
        """Post-processing: merge or copy the per-grid results"""
if len(successful_grid_lt) < len(grid_points):
self.logger.warning(
f"{len(grid_points) - len(successful_grid_lt)} 个网格处理失败,"
f"将只合并成功处理的 {len(successful_grid_lt)} 个网格"
)
self.merge_tif(successful_grid_lt)
        if self.config.mode == "三维模式":
            self.convert_obj(successful_grid_lt)
def process(self):
"""执行完整的预处理流程"""
try:
self.extract_gps()
self.cluster()
grid_points = self.divide_grids()
self.copy_images(grid_points)
self.logger.info("预处理任务完成")
successful_grid_lt = self.odm_monitor.process_all_grids(
grid_points)
self.post_process(successful_grid_lt, grid_points)
self.logger.info("重建任务完成")
except Exception as e:
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
raise
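
For reference, the plugin can also be driven programmatically instead of through main.py further below. The following is a minimal sketch using only the ProcessConfig fields defined above; the sample paths are the ones from the README, and an actual run still requires the ODM Docker environment described there.

```python
from app_plugin import ProcessConfig, ODM_Plugin

# Minimal programmatic run, equivalent to the CLI in main.py.
# Paths are the sample paths from the README, not guaranteed to exist locally.
config = ProcessConfig(
    image_dir=r"E:\datasets\UAV\134\project\images",
    output_dir=r"G:\ODM_output\134",
    mode="三维模式",      # or "快拼模式" for a fast orthophoto-only run
    grid_size=800,
    grid_overlap=0.1,
)
ODM_Plugin(config).process()
```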

24
check_version.py Normal file

@ -0,0 +1,24 @@
import importlib
import pkg_resources
def check_versions(requirements_file):
with open(requirements_file, 'r', encoding='utf-8') as f:
packages = [line.strip() for line in f if line.strip() and not line.startswith('#') and not line.startswith('//')]
print(f"{'Package':<20} {'Installed Version':<20}")
print('-' * 40)
    for pkg in packages:
        # Strip any pinned version spec, e.g. "numpy==1.26.4" -> "numpy"
        name = pkg.split('==')[0].strip()
        try:
            # Some distribution names differ from their import names, so try pkg_resources first
            version = pkg_resources.get_distribution(name).version
            print(f"{name:<20} {version:<20}")
        except Exception:
            try:
                mod = importlib.import_module(name)
                version = getattr(mod, '__version__', 'Unknown')
                print(f"{name:<20} {version:<20}")
            except Exception:
                print(f"{name:<20} {'Not Installed':<20}")
if __name__ == "__main__":
check_versions("requirements.txt")

77
filter/cluster_filter.py Normal file

@ -0,0 +1,77 @@
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
import os
import logging
class GPSCluster:
def __init__(self, gps_points, eps=0.01, min_samples=3):
"""
初始化GPS聚类器
参数:
eps: DBSCAN的邻域半径参数
min_samples: DBSCAN的最小样本数参数
"""
self.eps = eps
self.min_samples = min_samples
self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
self.scaler = StandardScaler()
self.gps_points = gps_points
self.logger = logging.getLogger('UAV_Preprocess.GPSCluster')
    def fit(self):
        """
        Cluster the GPS points with DBSCAN and keep only the largest cluster.
        Returns:
            DataFrame with a 'cluster' column, where points in the largest cluster are labeled 1 and all other points -1
        """
self.logger.info("开始聚类")
# 提取经纬度数据
X = self.gps_points[["lon", "lat"]].values
# # 数据标准化
# X_scaled = self.scaler.fit_transform(X)
# 执行DBSCAN聚类
labels = self.dbscan.fit_predict(X)
# 找出最大类的标签(排除噪声点-1
unique_labels = [l for l in set(labels) if l != -1]
if unique_labels: # 如果有聚类
label_counts = [(l, sum(labels == l)) for l in unique_labels]
largest_label = max(label_counts, key=lambda x: x[1])[0]
# 将最大类标记为1其他都标记为-1
new_labels = (labels == largest_label).astype(int)
new_labels[new_labels == 0] = -1
else: # 如果没有聚类,全部标记为-1
new_labels = labels
# 将聚类结果添加到原始数据中
result_df = self.gps_points.copy()
result_df["cluster"] = new_labels
return result_df
    def get_cluster_stats(self, clustered_points):
        """
        Log cluster statistics and return the main cluster.
        Args:
            clustered_points: DataFrame with a 'cluster' label column
        Returns:
            DataFrame containing only the points of the main cluster
        """
main_cluster = clustered_points[clustered_points["cluster"] == 1]
noise_cluster = clustered_points[clustered_points["cluster"] == -1]
self.logger.info(f"聚类完成:主要类别包含 {len(main_cluster)} 个点,"
f"噪声点 {len(noise_cluster)}")
return main_cluster
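
A minimal usage sketch of GPSCluster, mirroring how app_plugin.py calls it; the coordinates below are made-up toy values (five nearby points plus one outlier).

```python
import pandas as pd
from filter.cluster_filter import GPSCluster

# Toy data: five points close together plus one far-away outlier (illustrative values only).
points = pd.DataFrame({
    "file": [f"img_{i}.jpg" for i in range(6)],
    "lat":  [30.001, 30.002, 30.001, 30.003, 30.002, 31.500],
    "lon":  [120.001, 120.002, 120.003, 120.001, 120.002, 121.900],
})
clusterer = GPSCluster(points, eps=0.01, min_samples=3)
labeled = clusterer.fit()                            # adds 'cluster': 1 = main cluster, -1 = noise
main_points = clusterer.get_cluster_stats(labeled)   # rows of the main cluster only
print(main_points[["file", "lat", "lon"]])
```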

54
main.py Normal file

@ -0,0 +1,54 @@
import argparse
from app_plugin import ProcessConfig, ODM_Plugin
def parse_args():
parser = argparse.ArgumentParser(description='ODM预处理工具')
# 必需参数
parser.add_argument('--image_dir', required=True, help='输入图片目录路径')
parser.add_argument('--output_dir', required=True, help='输出目录路径')
# parser.add_argument('--image_dir', default=r'G:\UAV_data\test_31\project\images', help='输入图片目录路径')
# parser.add_argument('--output_dir', default=r'G:\ODM_output\test_31', help='输出目录路径')
# 可选参数
parser.add_argument('--mode', default='三维模式',
choices=['快拼模式', '三维模式'], help='处理模式')
parser.add_argument('--grid_size', type=float, default=800, help='网格大小(米)')
parser.add_argument('--grid_overlap', type=float,
default=0.1, help='网格重叠率')
# ODM参数
parser.add_argument('--feature_type', default='sift', help='特征类型')
parser.add_argument('--orthophoto_resolution',
type=float, default=5, help='正射影像分辨率')
args = parser.parse_args()
return args
def main():
args = parse_args()
# 创建配置
config = ProcessConfig(
image_dir=args.image_dir,
output_dir=args.output_dir,
mode=args.mode,
grid_size=args.grid_size,
grid_overlap=args.grid_overlap,
feature_type=args.feature_type,
orthophoto_resolution=args.orthophoto_resolution,
# 其他参数使用默认值
cluster_eps=0.01,
cluster_min_samples=5,
)
# 创建处理器并执行
processor = ODM_Plugin(config)
processor.process()
if __name__ == '__main__':
main()

47
main.spec Normal file

@ -0,0 +1,47 @@
# -*- mode: python ; coding: utf-8 -*-
a = Analysis(
['main.py'],
pathex=[],
binaries=[],
datas=[],
hiddenimports=[
'rasterio.sample',
'rasterio.vrt',
'rasterio._io',
'rasterio.enums',
'rasterio.errors',
'rasterio.transform',
'rasterio.crs',
'rasterio.warp'
],
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
noarchive=False,
optimize=0,
)
pyz = PYZ(a.pure)
exe = EXE(
pyz,
a.scripts,
a.binaries,
a.datas,
[],
name='uav',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=True,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)

265
post_pro/conv_obj.py Normal file

@ -0,0 +1,265 @@
import os
import subprocess
import json
import shutil
import logging
from pyproj import Transformer
import cv2
class ConvertOBJ:
def __init__(self, output_dir: str):
self.output_dir = output_dir
        # Reference UTM offset shared by all grids (taken from the first grid)
self.ref_east = float('inf')
self.ref_north = float('inf')
# 初始化UTM到WGS84的转换器
self.transformer = Transformer.from_crs(
"EPSG:32649", "EPSG:4326", always_xy=True)
self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ')
def convert_grid_obj(self, grid_lt):
"""转换每个网格的OBJ文件为OSGB格式"""
os.makedirs(os.path.join(self.output_dir,
"osgb", "Data"), exist_ok=True)
# 以第一个grid的UTM坐标作为参照系
first_grid_id = grid_lt[0]
first_grid_dir = os.path.join(
self.output_dir,
f"grid_{first_grid_id[0]}_{first_grid_id[1]}",
"project"
)
log_file = os.path.join(
first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
self.ref_east, self.ref_north = self.read_utm_offset(log_file)
for grid_id in grid_lt:
try:
self._convert_single_grid(grid_id)
except Exception as e:
self.logger.error(f"网格 {grid_id} 转换失败: {str(e)}")
self._create_merged_metadata()
def _convert_single_grid(self, grid_id):
"""转换单个网格的OBJ文件"""
# 构建相关路径
grid_name = f"grid_{grid_id[0]}_{grid_id[1]}"
project_dir = os.path.join(self.output_dir, grid_name, "project")
texturing_dir = os.path.join(project_dir, "odm_texturing")
texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst")
split_obj_dir = os.path.join(texturing_dst_dir, "split_obj")
log_file = os.path.join(
project_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
os.makedirs(texturing_dst_dir, exist_ok=True)
# 修改obj文件z坐标的值
min_25d_z = self.get_min_z_from_obj(os.path.join(
project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj'))
self.modify_z_in_obj(texturing_dir, min_25d_z)
# 在新文件夹下利用UTM偏移量修改obj文件顶点坐标纹理文件下采样
utm_offset = self.read_utm_offset(log_file)
modified_obj = self.modify_obj_coordinates(
texturing_dir, texturing_dst_dir, utm_offset)
self.downsample_texture(texturing_dir, texturing_dst_dir)
# 将obj文件进行切片
self.logger.info(f"开始切片网格 {grid_id} 的OBJ文件")
os.makedirs(split_obj_dir)
cmd = (
f"D:\\software\\Obj2Tiles\\Obj2Tiles.exe --stage Splitting --lods 1 --divisions 2 "
f"{modified_obj} {split_obj_dir}"
)
subprocess.run(cmd, check=True)
# 执行格式转换Linux下osgconv有问题记得注释掉
self.logger.info(f"开始转换网格 {grid_id} 的OBJ文件")
# 先获取split_obj_dir下的所有obj文件
obj_lod_dir = os.path.join(split_obj_dir, "LOD-0")
obj_files = [f for f in os.listdir(
obj_lod_dir) if f.endswith('.obj')]
for obj_file in obj_files:
obj_path = os.path.join(obj_lod_dir, obj_file)
osgb_file = os.path.splitext(obj_file)[0] + '.osgb'
osgb_path = os.path.join(split_obj_dir, osgb_file)
# 执行 osgconv 命令
subprocess.run(['osgconv', obj_path, osgb_path], check=True)
# 创建OSGB目录结构复制文件
osgb_base_dir = os.path.join(self.output_dir, "osgb")
data_dir = os.path.join(osgb_base_dir, "Data")
for obj_file in obj_files:
obj_file_name = os.path.splitext(obj_file)[0]
tile_dirs = os.path.join(
data_dir, f"grid_{grid_id[0]}_{grid_id[1]}_{obj_file_name}")
os.makedirs(tile_dirs, exist_ok=True)
shutil.copy2(os.path.join(
split_obj_dir, obj_file_name+".osgb"), tile_dirs)
os.rename(os.path.join(tile_dirs, obj_file_name+".osgb"),
os.path.join(tile_dirs, f"grid_{grid_id[0]}_{grid_id[1]}_{obj_file_name}.osgb"))
def _create_merged_metadata(self):
"""创建合并后的metadata.xml文件"""
# 转换为WGS84经纬度
center_lon, center_lat = self.transformer.transform(
self.ref_east, self.ref_north)
metadata_content = f"""<?xml version="1.0" encoding="utf-8"?>
<ModelMetadata version="1">
<SRS>EPSG:4326</SRS>
<SRSOrigin>{center_lon},{center_lat},0</SRSOrigin>
<Texture>
<ColorSource>Visible</ColorSource>
</Texture>
</ModelMetadata>"""
metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml")
with open(metadata_file, 'w', encoding='utf-8') as f:
f.write(metadata_content)
def read_utm_offset(self, log_file: str) -> tuple:
"""读取UTM偏移量"""
try:
east_offset = None
north_offset = None
with open(log_file, 'r') as f:
lines = f.readlines()
for i, line in enumerate(lines):
if 'utm_north_offset' in line and i + 1 < len(lines):
north_offset = float(lines[i + 1].strip())
elif 'utm_east_offset' in line and i + 1 < len(lines):
east_offset = float(lines[i + 1].strip())
if east_offset is None or north_offset is None:
raise ValueError("未找到UTM偏移量")
return east_offset, north_offset
except Exception as e:
self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}")
raise
def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str:
"""修改obj文件中的顶点坐标使用相对坐标系"""
obj_file = os.path.join(
texturing_dir, "odm_textured_model_modified.obj")
obj_dst_file = os.path.join(
texturing_dst_dir, "odm_textured_model_geo_utm.obj")
if not os.path.exists(obj_file):
raise FileNotFoundError(f"找不到OBJ文件: {obj_file}")
shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"),
os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl"))
east_offset, north_offset = utm_offset
self.logger.info(
f"UTM坐标偏移{east_offset - self.ref_east}, {north_offset - self.ref_north}")
try:
with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out:
for line in f_in:
if line.startswith('v '):
# 处理顶点坐标行
parts = line.strip().split()
# 使用相对于整体最小UTM坐标的偏移
x = float(parts[1]) + (east_offset - self.ref_east)
y = float(parts[2]) + (north_offset - self.ref_north)
z = float(parts[3])
f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n')
elif line.startswith('vn '): # 处理法线向量
parts = line.split()
nx = float(parts[1])
ny = float(parts[2])
nz = float(parts[3])
# 同步反转法线的 Y 轴
new_line = f"vn {nx} {nz} {-ny}\n"
f_out.write(new_line)
else:
# 其他行直接写入
f_out.write(line)
return obj_dst_file
except Exception as e:
self.logger.error(f"修改obj坐标时发生错误: {str(e)}")
raise
def downsample_texture(self, src_dir: str, dst_dir: str):
"""复制并重命名纹理文件对大于100MB的文件进行多次下采样直到文件小于100MB
Args:
src_dir: 源纹理目录
dst_dir: 目标纹理目录
"""
for file in os.listdir(src_dir):
if file.lower().endswith(('.png')):
src_path = os.path.join(src_dir, file)
dst_path = os.path.join(dst_dir, file)
# 检查文件大小(以字节为单位)
file_size = os.path.getsize(src_path)
if file_size <= 100 * 1024 * 1024: # 如果文件小于等于100MB直接复制
shutil.copy2(src_path, dst_path)
else:
# 文件大于100MB进行下采样
img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED)
if_first_ds = True
while file_size > 100 * 1024 * 1024: # 大于100MB
self.logger.info(f"纹理文件 {file} 大于100MB进行下采样")
if if_first_ds:
# 计算新的尺寸长宽各变为1/4
new_size = (img.shape[1] // 4,
img.shape[0] // 4) # 逐步减小尺寸
# 使用双三次插值进行下采样
resized_img = cv2.resize(
img, new_size, interpolation=cv2.INTER_CUBIC)
if_first_ds = False
else:
# 计算新的尺寸长宽各变为1/2
new_size = (img.shape[1] // 2,
img.shape[0] // 2) # 逐步减小尺寸
# 使用双三次插值进行下采样
resized_img = cv2.resize(
img, new_size, interpolation=cv2.INTER_CUBIC)
# 更新文件路径为下采样后的路径
cv2.imwrite(dst_path, resized_img, [
cv2.IMWRITE_PNG_COMPRESSION, 9])
# 更新文件大小和图像
file_size = os.path.getsize(dst_path)
img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED)
self.logger.info(
f"下采样后文件大小: {file_size / (1024 * 1024):.2f} MB")
def get_min_z_from_obj(self, file_path):
min_z = float('inf') # 初始值设为无穷大
with open(file_path, 'r') as obj_file:
for line in obj_file:
# 检查每一行是否是顶点定义(以 'v ' 开头)
if line.startswith('v '):
# 获取顶点坐标
parts = line.split()
# 将z值转换为浮动数字
z = float(parts[3])
# 更新最小z值
if z < min_z:
min_z = z
return min_z
def modify_z_in_obj(self, texturing_dir, min_25d_z):
obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj')
output_file = os.path.join(
texturing_dir, 'odm_textured_model_modified.obj')
with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out:
for line in f_in:
if line.startswith('v '): # 顶点坐标行
parts = line.strip().split()
x = float(parts[1])
y = float(parts[2])
z = float(parts[3])
if z < min_25d_z:
z = min_25d_z
f_out.write(f"v {x} {y} {z}\n")
else:
f_out.write(line)
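
The vertex rewrite in modify_obj_coordinates shifts each vertex by the grid's UTM offset relative to the reference grid and then swaps the axes into a Y-up frame (`v x z -y`), flipping normals the same way. A small standalone sketch of that per-vertex transform, with illustrative offset values only:

```python
def rebase_vertex(line: str, east_offset: float, north_offset: float,
                  ref_east: float, ref_north: float) -> str:
    """Apply the same per-vertex rewrite as ConvertOBJ.modify_obj_coordinates."""
    _, xs, ys, zs = line.split()
    x = float(xs) + (east_offset - ref_east)   # shift into the reference grid's frame
    y = float(ys) + (north_offset - ref_north)
    z = float(zs)
    return f"v {x:.6f} {z:.6f} {-y:.6f}"       # swap to Y-up and negate the former northing axis

# Illustrative offsets: this grid sits 100 m east / 50 m north of the reference grid.
print(rebase_vertex("v 1.0 2.0 3.0", 500100.0, 3300050.0, 500000.0, 3300000.0))
# -> v 101.000000 3.000000 -52.000000
```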

253
post_pro/conv_obj2.py Normal file

@ -0,0 +1,253 @@
import os
import subprocess
import json
import shutil
import logging
from pyproj import Transformer
import cv2
class ConvertOBJ:
def __init__(self, output_dir: str):
self.output_dir = output_dir
        # Reference UTM offset shared by all grids (taken from the first grid)
self.ref_east = float('inf')
self.ref_north = float('inf')
# 初始化UTM到WGS84的转换器
self.transformer = Transformer.from_crs(
"EPSG:32649", "EPSG:4326", always_xy=True)
self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ')
def convert_grid_obj(self, grid_lt):
"""转换每个网格的OBJ文件为OSGB格式"""
os.makedirs(os.path.join(self.output_dir,
"osgb", "Data"), exist_ok=True)
# 以第一个grid的UTM坐标作为参照系
first_grid_id = grid_lt[0]
first_grid_dir = os.path.join(
self.output_dir,
f"grid_{first_grid_id[0]}_{first_grid_id[1]}",
"project"
)
log_file = os.path.join(
first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
self.ref_east, self.ref_north = self.read_utm_offset(log_file)
for grid_id in grid_lt:
try:
self._convert_single_grid(grid_id)
except Exception as e:
self.logger.error(f"网格 {grid_id} 转换失败: {str(e)}")
self._create_merged_metadata()
def _convert_single_grid(self, grid_id):
"""转换单个网格的OBJ文件"""
# 构建相关路径
grid_name = f"grid_{grid_id[0]}_{grid_id[1]}"
project_dir = os.path.join(self.output_dir, grid_name, "project")
texturing_dir = os.path.join(project_dir, "odm_texturing")
texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst")
opensfm_dir = os.path.join(project_dir, "opensfm")
log_file = os.path.join(
project_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
os.makedirs(texturing_dst_dir, exist_ok=True)
# 修改obj文件z坐标的值
min_25d_z = self.get_min_z_from_obj(os.path.join(
project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj'))
self.modify_z_in_obj(texturing_dir, min_25d_z)
# 在新文件夹下利用UTM偏移量修改obj文件顶点坐标纹理文件下采样
utm_offset = self.read_utm_offset(log_file)
modified_obj = self.modify_obj_coordinates(
texturing_dir, texturing_dst_dir, utm_offset)
self.downsample_texture(texturing_dir, texturing_dst_dir)
# 执行格式转换Linux下osgconv有问题记得注释掉
self.logger.info(f"开始转换网格 {grid_id} 的OBJ文件")
output_osgb = os.path.join(texturing_dst_dir, "Tile.osgb")
cmd = (
f"osgconv {modified_obj} {output_osgb} "
f"--smooth --fix-transparency "
)
self.logger.info(f"执行osgconv命令{cmd}")
try:
subprocess.run(cmd, shell=True, check=True, cwd=texturing_dir)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"OSGB转换失败: {str(e)}")
# 创建OSGB目录结构复制文件
osgb_base_dir = os.path.join(self.output_dir, "osgb")
data_dir = os.path.join(osgb_base_dir, "Data")
tile_dir = os.path.join(data_dir, f"Tile_{grid_id[0]}_{grid_id[1]}")
os.makedirs(tile_dir, exist_ok=True)
target_osgb = os.path.join(
tile_dir, f"Tile_{grid_id[0]}_{grid_id[1]}.osgb")
shutil.copy2(output_osgb, target_osgb)
def _create_merged_metadata(self):
"""创建合并后的metadata.xml文件"""
# 转换为WGS84经纬度
center_lon, center_lat = self.transformer.transform(
self.ref_east, self.ref_north)
metadata_content = f"""<?xml version="1.0" encoding="utf-8"?>
<ModelMetadata version="1">
<SRS>EPSG:4326</SRS>
<SRSOrigin>{center_lon},{center_lat},0</SRSOrigin>
<Texture>
<ColorSource>Visible</ColorSource>
</Texture>
</ModelMetadata>"""
metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml")
with open(metadata_file, 'w', encoding='utf-8') as f:
f.write(metadata_content)
def read_utm_offset(self, log_file: str) -> tuple:
"""读取UTM偏移量"""
try:
east_offset = None
north_offset = None
with open(log_file, 'r') as f:
lines = f.readlines()
for i, line in enumerate(lines):
if 'utm_north_offset' in line and i + 1 < len(lines):
north_offset = float(lines[i + 1].strip())
elif 'utm_east_offset' in line and i + 1 < len(lines):
east_offset = float(lines[i + 1].strip())
if east_offset is None or north_offset is None:
raise ValueError("未找到UTM偏移量")
return east_offset, north_offset
except Exception as e:
self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}")
raise
def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str:
"""修改obj文件中的顶点坐标使用相对坐标系"""
obj_file = os.path.join(
texturing_dir, "odm_textured_model_modified.obj")
obj_dst_file = os.path.join(
texturing_dst_dir, "odm_textured_model_geo_utm.obj")
if not os.path.exists(obj_file):
raise FileNotFoundError(f"找不到OBJ文件: {obj_file}")
shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"),
os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl"))
east_offset, north_offset = utm_offset
self.logger.info(
f"UTM坐标偏移{east_offset - self.ref_east}, {north_offset - self.ref_north}")
try:
with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out:
for line in f_in:
if line.startswith('v '):
# 处理顶点坐标行
parts = line.strip().split()
# 使用相对于整体最小UTM坐标的偏移
x = float(parts[1]) + (east_offset - self.ref_east)
y = float(parts[2]) + (north_offset - self.ref_north)
z = float(parts[3])
f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n')
elif line.startswith('vn '): # 处理法线向量
parts = line.split()
nx = float(parts[1])
ny = float(parts[2])
nz = float(parts[3])
# 同步反转法线的 Y 轴
new_line = f"vn {nx} {nz} {-ny}\n"
f_out.write(new_line)
else:
# 其他行直接写入
f_out.write(line)
return obj_dst_file
except Exception as e:
self.logger.error(f"修改obj坐标时发生错误: {str(e)}")
raise
def downsample_texture(self, src_dir: str, dst_dir: str):
"""复制并重命名纹理文件对大于100MB的文件进行多次下采样直到文件小于100MB
Args:
src_dir: 源纹理目录
dst_dir: 目标纹理目录
"""
for file in os.listdir(src_dir):
if file.lower().endswith(('.png')):
src_path = os.path.join(src_dir, file)
dst_path = os.path.join(dst_dir, file)
# 检查文件大小(以字节为单位)
file_size = os.path.getsize(src_path)
if file_size <= 100 * 1024 * 1024: # 如果文件小于等于100MB直接复制
shutil.copy2(src_path, dst_path)
else:
# 文件大于100MB进行下采样
img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED)
if_first_ds = True
while file_size > 100 * 1024 * 1024: # 大于100MB
self.logger.info(f"纹理文件 {file} 大于100MB进行下采样")
if if_first_ds:
# 计算新的尺寸长宽各变为1/4
new_size = (img.shape[1] // 4,
img.shape[0] // 4) # 逐步减小尺寸
# 使用双三次插值进行下采样
resized_img = cv2.resize(
img, new_size, interpolation=cv2.INTER_CUBIC)
if_first_ds = False
else:
# 计算新的尺寸长宽各变为1/2
new_size = (img.shape[1] // 2,
img.shape[0] // 2) # 逐步减小尺寸
# 使用双三次插值进行下采样
resized_img = cv2.resize(
img, new_size, interpolation=cv2.INTER_CUBIC)
# 更新文件路径为下采样后的路径
cv2.imwrite(dst_path, resized_img, [
cv2.IMWRITE_PNG_COMPRESSION, 9])
# 更新文件大小和图像
file_size = os.path.getsize(dst_path)
img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED)
self.logger.info(
f"下采样后文件大小: {file_size / (1024 * 1024):.2f} MB")
def get_min_z_from_obj(self, file_path):
min_z = float('inf') # 初始值设为无穷大
with open(file_path, 'r') as obj_file:
for line in obj_file:
# 检查每一行是否是顶点定义(以 'v ' 开头)
if line.startswith('v '):
# 获取顶点坐标
parts = line.split()
# 将z值转换为浮动数字
z = float(parts[3])
# 更新最小z值
if z < min_z:
min_z = z
return min_z
def modify_z_in_obj(self, texturing_dir, min_25d_z):
obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj')
output_file = os.path.join(
texturing_dir, 'odm_textured_model_modified.obj')
with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out:
for line in f_in:
if line.startswith('v '): # 顶点坐标行
parts = line.strip().split()
x = float(parts[1])
y = float(parts[2])
z = float(parts[3])
if z < min_25d_z:
z = min_25d_z
f_out.write(f"v {x} {y} {z}\n")
else:
f_out.write(line)

285
post_pro/merge_tif.py Normal file

@ -0,0 +1,285 @@
import logging
import os
from typing import Dict
import rasterio
from rasterio.mask import mask
from rasterio.transform import Affine, rowcol
import fiona
from edt import edt
import numpy as np
import math
class MergeTif:
def __init__(self, output_dir: str):
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.MergeTif')
def merge_orthophoto(self, grid_lt):
"""合并网格的正射影像"""
try:
all_orthos_and_ortho_cuts = []
for grid_id in grid_lt:
grid_ortho_dir = os.path.join(
self.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
"odm_orthophoto",
)
tif_path = os.path.join(grid_ortho_dir, "odm_orthophoto.tif")
tif_mask = os.path.join(grid_ortho_dir, "cutline.gpkg")
output_cut_tif = os.path.join(
grid_ortho_dir, "odm_orthophoto_cut.tif")
output_feathered_tif = os.path.join(
grid_ortho_dir, "odm_orthophoto_feathered.tif")
self.compute_mask_raster(
tif_path, tif_mask, output_cut_tif, blend_distance=20)
self.feather_raster(
tif_path, output_feathered_tif, blend_distance=20)
all_orthos_and_ortho_cuts.append(
[output_feathered_tif, output_cut_tif])
orthophoto_vars = {
'TILED': 'NO',
'COMPRESS': False,
'PREDICTOR': '1',
'BIGTIFF': 'IF_SAFER',
'BLOCKXSIZE': 512,
'BLOCKYSIZE': 512,
'NUM_THREADS': 15
}
self.merge(all_orthos_and_ortho_cuts, os.path.join(
self.output_dir, "orthophoto.tif"), orthophoto_vars)
self.logger.info("所有产品合并完成")
except Exception as e:
self.logger.error(f"产品合并过程中发生错误: {str(e)}", exc_info=True)
raise
def compute_mask_raster(self, input_raster, vector_mask, output_raster, blend_distance=20, only_max_coords_feature=False):
if not os.path.exists(input_raster):
print("Cannot mask raster, %s does not exist" % input_raster)
return
if not os.path.exists(vector_mask):
print("Cannot mask raster, %s does not exist" % vector_mask)
return
print("Computing mask raster: %s" % output_raster)
with rasterio.open(input_raster, 'r') as rast:
with fiona.open(vector_mask) as src:
burn_features = src
if only_max_coords_feature:
max_coords_count = 0
max_coords_feature = None
for feature in src:
if feature is not None:
# No complex shapes
if len(feature['geometry']['coordinates'][0]) > max_coords_count:
max_coords_count = len(
feature['geometry']['coordinates'][0])
max_coords_feature = feature
if max_coords_feature is not None:
burn_features = [max_coords_feature]
shapes = [feature["geometry"] for feature in burn_features]
out_image, out_transform = mask(rast, shapes, nodata=0)
if blend_distance > 0:
if out_image.shape[0] >= 4:
# alpha_band = rast.dataset_mask()
alpha_band = out_image[-1]
dist_t = edt(alpha_band, black_border=True, parallel=0)
dist_t[dist_t <= blend_distance] /= blend_distance
dist_t[dist_t > blend_distance] = 1
np.multiply(alpha_band, dist_t,
out=alpha_band, casting="unsafe")
else:
print(
"%s does not have an alpha band, cannot blend cutline!" % input_raster)
with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst:
dst.colorinterp = rast.colorinterp
dst.write(out_image)
return output_raster
def feather_raster(self, input_raster, output_raster, blend_distance=20):
if not os.path.exists(input_raster):
print("Cannot feather raster, %s does not exist" % input_raster)
return
print("Computing feather raster: %s" % output_raster)
with rasterio.open(input_raster, 'r') as rast:
out_image = rast.read()
if blend_distance > 0:
if out_image.shape[0] >= 4:
alpha_band = out_image[-1]
dist_t = edt(alpha_band, black_border=True, parallel=0)
dist_t[dist_t <= blend_distance] /= blend_distance
dist_t[dist_t > blend_distance] = 1
np.multiply(alpha_band, dist_t,
out=alpha_band, casting="unsafe")
else:
print(
"%s does not have an alpha band, cannot feather raster!" % input_raster)
with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst:
dst.colorinterp = rast.colorinterp
dst.write(out_image)
return output_raster
def merge(self, input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}):
"""
Based on https://github.com/mapbox/rio-merge-rgba/
Merge orthophotos around cutlines using a blend buffer.
"""
inputs = []
bounds = None
precision = 7
for o, c in input_ortho_and_ortho_cuts:
inputs.append((o, c))
with rasterio.open(inputs[0][0]) as first:
res = first.res
dtype = first.dtypes[0]
profile = first.profile
num_bands = first.meta['count'] - 1 # minus alpha
colorinterp = first.colorinterp
print("%s valid orthophoto rasters to merge" % len(inputs))
sources = [(rasterio.open(o), rasterio.open(c)) for o, c in inputs]
# scan input files.
# while we're at it, validate assumptions about inputs
xs = []
ys = []
for src, _ in sources:
left, bottom, right, top = src.bounds
xs.extend([left, right])
ys.extend([bottom, top])
if src.profile["count"] < 2:
raise ValueError("Inputs must be at least 2-band rasters")
dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys)
print("Output bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n))
output_transform = Affine.translation(dst_w, dst_n)
output_transform *= Affine.scale(res[0], -res[1])
# Compute output array shape. We guarantee it will cover the output
# bounds completely.
output_width = int(math.ceil((dst_e - dst_w) / res[0]))
output_height = int(math.ceil((dst_n - dst_s) / res[1]))
# Adjust bounds to fit.
dst_e, dst_s = output_transform * (output_width, output_height)
print("Output width: %d, height: %d" %
(output_width, output_height))
print("Adjusted bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n))
profile["transform"] = output_transform
profile["height"] = output_height
profile["width"] = output_width
profile["tiled"] = orthophoto_vars.get('TILED', 'YES') == 'YES'
profile["blockxsize"] = orthophoto_vars.get('BLOCKXSIZE', 512)
profile["blockysize"] = orthophoto_vars.get('BLOCKYSIZE', 512)
profile["compress"] = orthophoto_vars.get('COMPRESS', 'LZW')
profile["predictor"] = orthophoto_vars.get('PREDICTOR', '2')
profile["bigtiff"] = orthophoto_vars.get('BIGTIFF', 'IF_SAFER')
profile.update()
# create destination file
with rasterio.open(output_orthophoto, "w", **profile) as dstrast:
dstrast.colorinterp = colorinterp
for idx, dst_window in dstrast.block_windows():
left, bottom, right, top = dstrast.window_bounds(dst_window)
blocksize = dst_window.width
dst_rows, dst_cols = (dst_window.height, dst_window.width)
# initialize array destined for the block
dst_count = first.count
dst_shape = (dst_count, dst_rows, dst_cols)
dstarr = np.zeros(dst_shape, dtype=dtype)
# First pass, write all rasters naively without blending
for src, _ in sources:
src_window = tuple(zip(rowcol(
src.transform, left, top, op=round, precision=precision
), rowcol(
src.transform, right, bottom, op=round, precision=precision
)))
temp = np.zeros(dst_shape, dtype=dtype)
temp = src.read(
out=temp, window=src_window, boundless=True, masked=False
)
# pixels without data yet are available to write
write_region = np.logical_and(
(dstarr[-1] == 0), (temp[-1] != 0) # 0 is nodata
)
np.copyto(dstarr, temp, where=write_region)
# check if dest has any nodata pixels available
if np.count_nonzero(dstarr[-1]) == blocksize:
break
# Second pass, write all feathered rasters
# blending the edges
for src, _ in sources:
src_window = tuple(zip(rowcol(
src.transform, left, top, op=round, precision=precision
), rowcol(
src.transform, right, bottom, op=round, precision=precision
)))
temp = np.zeros(dst_shape, dtype=dtype)
temp = src.read(
out=temp, window=src_window, boundless=True, masked=False
)
where = temp[-1] != 0
for b in range(0, num_bands):
blended = temp[-1] / 255.0 * temp[b] + \
(1 - temp[-1] / 255.0) * dstarr[b]
np.copyto(dstarr[b], blended,
casting='unsafe', where=where)
dstarr[-1][where] = 255.0
# check if dest has any nodata pixels available
if np.count_nonzero(dstarr[-1]) == blocksize:
break
# Third pass, write cut rasters
# blending the cutlines
for _, cut in sources:
src_window = tuple(zip(rowcol(
cut.transform, left, top, op=round, precision=precision
), rowcol(
cut.transform, right, bottom, op=round, precision=precision
)))
temp = np.zeros(dst_shape, dtype=dtype)
temp = cut.read(
out=temp, window=src_window, boundless=True, masked=False
)
# For each band, average alpha values between
# destination raster and cut raster
for b in range(0, num_bands):
blended = temp[-1] / 255.0 * temp[b] + \
(1 - temp[-1] / 255.0) * dstarr[b]
np.copyto(dstarr[b], blended,
casting='unsafe', where=temp[-1] != 0)
dstrast.write(dstarr, window=dst_window)
return output_orthophoto
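
In the second and third merge passes, each band is blended against the destination using the (feathered or cut) alpha channel as a weight. A minimal NumPy sketch of that per-pixel blend, with made-up values:

```python
import numpy as np

# Illustrative 1x3 pixel strip: the source fades out (alpha 255 -> 0) over an already-filled destination.
src_band  = np.array([200.0, 200.0, 200.0])
src_alpha = np.array([255.0, 128.0, 0.0])
dst_band  = np.array([100.0, 100.0, 100.0])

# Same formula as in MergeTif.merge: alpha-weighted average of source over destination.
blended = src_alpha / 255.0 * src_band + (1 - src_alpha / 255.0) * dst_band
print(blended)  # [200.0, ~150.2, 100.0]
```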

15
requirements.txt Normal file

@ -0,0 +1,15 @@
numpy==1.26.4
pandas==2.2.3
scikit-learn==1.6.1
matplotlib==3.10.0
piexif==1.1.3
geopy==2.4.1
psutil==6.1.1
docker==7.1.0
tqdm==4.66.5
pyproj==3.7.0
rasterio==1.4.3
edt==3.0.0
opencv-python==4.11.0.86
fiona==1.9.5
pyinstaller==6.14.1

33
trans_orthophoto.py Normal file

@ -0,0 +1,33 @@
import logging
import argparse
from osgeo import gdal
class TransOrthophoto:
def __init__(self):
self.logger = logging.getLogger('UAV_Preprocess.TransOrthophoto')
logging.basicConfig(level=logging.INFO,
format='%(asctime)s %(levelname)s %(message)s')
def trans_to_epsg4326(self, input_img, output_img):
# 目标坐标系
dst_srs = 'EPSG:4326'
# 使用 gdal.Warp 进行重投影
gdal.Warp(
destNameOrDestDS=output_img,
srcDSOrSrcDSTab=input_img,
dstSRS=dst_srs,
format='GTiff',
resampleAlg='near' # 最近邻插值
)
self.logger.info(f"文件已成功重投影: {output_img}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="正射影像投影转换为EPSG:4326 (GDAL实现)")
parser.add_argument('--input_img', required=True, help='输入影像路径')
parser.add_argument('--output_img', required=True, help='输出影像路径')
args = parser.parse_args()
trans = TransOrthophoto()
trans.trans_to_epsg4326(args.input_img, args.output_img)

81
utils/directory_manager.py Normal file

@ -0,0 +1,81 @@
import os
import shutil
import psutil
class DirectoryManager:
def __init__(self, config):
"""
初始化目录管理器
Args:
config: 配置对象包含输入和输出目录等信息
"""
self.config = config
def clean_output_dir(self):
"""清理输出目录"""
try:
if os.path.exists(self.config.output_dir):
shutil.rmtree(self.config.output_dir)
print(f"已清理输出目录: {self.config.output_dir}")
else:
pass
except Exception as e:
print(f"清理输出目录时发生错误: {str(e)}")
raise
def setup_output_dirs(self):
"""创建必要的输出目录结构"""
try:
# 创建主输出目录
os.makedirs(self.config.output_dir)
# 创建过滤图像保存目录
os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs'))
# 创建日志目录
os.makedirs(os.path.join(self.config.output_dir, 'logs'))
print(f"已创建输出目录结构: {self.config.output_dir}")
except Exception as e:
print(f"创建输出目录时发生错误: {str(e)}")
raise
def _get_directory_size(self, path):
"""获取目录的总大小(字节)"""
total_size = 0
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
file_path = os.path.join(dirpath, filename)
try:
total_size += os.path.getsize(file_path)
except (OSError, FileNotFoundError):
continue
return total_size
def check_disk_space(self):
"""检查磁盘空间是否足够"""
# 获取输入目录大小
input_size = self._get_directory_size(self.config.image_dir)
# 获取输出目录所在磁盘的剩余空间
output_drive = os.path.splitdrive(
os.path.abspath(self.config.output_dir))[0]
if not output_drive: # 处理Linux/Unix路径
output_drive = '/home'
disk_usage = psutil.disk_usage(output_drive)
free_space = disk_usage.free
        # Required space: 8x the size of the input directory
        required_space = input_size * 8
if free_space < required_space:
error_msg = (
f"磁盘空间不足!\n"
f"输入目录大小: {input_size / (1024**3):.2f} GB\n"
f"所需空间: {required_space / (1024**3):.2f} GB\n"
f"可用空间: {free_space / (1024**3):.2f} GB\n"
f"在驱动器 {output_drive}"
)
raise RuntimeError(error_msg)

65
utils/gps_extractor.py Normal file

@ -0,0 +1,65 @@
import os
from PIL import Image
import piexif
import logging
import pandas as pd
class GPSExtractor:
"""从图像文件提取GPS坐标"""
def __init__(self, image_dir):
self.image_dir = image_dir
self.logger = logging.getLogger('UAV_Preprocess.GPSExtractor')
@staticmethod
def _dms_to_decimal(dms):
"""将DMS格式转换为十进制度"""
return dms[0][0] / dms[0][1] + (dms[1][0] / dms[1][1]) / 60 + (dms[2][0] / dms[2][1]) / 3600
def get_gps(self, image_path):
"""提取单张图片的GPS坐标"""
try:
image = Image.open(image_path)
exif_data = piexif.load(image.info['exif'])
# 提取GPS信息
gps_info = exif_data.get("GPS", {})
lat = lon = None
if gps_info:
lat = self._dms_to_decimal(gps_info.get(2, []))
lon = self._dms_to_decimal(gps_info.get(4, []))
self.logger.debug(
f"成功提取图片GPS坐标: {image_path} - 纬度: {lat}, 经度: {lon}")
if not gps_info:
self.logger.warning(f"图片无GPS信息: {image_path}")
return lat, lon
except Exception as e:
self.logger.error(f"提取图片信息时发生错误: {image_path} - {str(e)}")
            return None, None
def extract_all_gps(self):
"""提取所有图片的GPS坐标"""
self.logger.info(f"开始从目录提取GPS坐标: {self.image_dir}")
gps_data = []
total_images = 0
successful_extractions = 0
for image_file in os.listdir(self.image_dir):
total_images += 1
image_path = os.path.join(self.image_dir, image_file)
lat, lon = self.get_gps(image_path)
if lat and lon: # 仍然以GPS信息作为主要判断依据
successful_extractions += 1
gps_data.append({
'file': image_file,
'lat': lat,
'lon': lon,
})
self.logger.info(
f"GPS坐标提取完成 - 总图片数: {total_images}, 成功提取: {successful_extractions}, 失败: {total_images - successful_extractions}")
return pd.DataFrame(gps_data)
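
A minimal usage sketch of GPSExtractor; the image directory is the sample path from the README, and any folder of JPEGs with EXIF GPS tags will work.

```python
from utils.gps_extractor import GPSExtractor

# Placeholder directory (sample path from the README).
extractor = GPSExtractor(r"E:\datasets\UAV\134\project\images")
gps_points = extractor.extract_all_gps()   # DataFrame with columns: file, lat, lon
print(gps_points.head())
```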

266
utils/grid_divider.py Normal file

@ -0,0 +1,266 @@
import logging
from geopy.distance import geodesic
import matplotlib.pyplot as plt
import os
class GridDivider:
"""划分网格,并将图片分配到对应网格"""
def __init__(self, overlap=0.1, grid_size=500, output_dir=None):
self.overlap = overlap
self.grid_size = grid_size
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.GridDivider')
self.logger.info(f"初始化网格划分器,重叠率: {overlap}")
self.num_grids_width = 0 # 添加网格数量属性
self.num_grids_height = 0
def adjust_grid_size_and_overlap(self, points_df):
"""动态调整网格重叠率"""
grids = self.adjust_grid_size(points_df)
self.logger.info(f"开始动态调整网格重叠率,初始重叠率: {self.overlap}")
while True:
# 使用调整好的网格大小划分网格
grids = self.divide_grids(points_df)
grid_points, multiple_grid_points = self.assign_to_grids(
points_df, grids)
if len(grids) == 1:
self.logger.info(f"网格数量为1跳过重叠率调整")
break
elif multiple_grid_points < 0.1*len(points_df):
self.overlap += 0.02
self.logger.info(f"重叠率增加到: {self.overlap}")
else:
self.logger.info(
f"找到合适的重叠率: {self.overlap}, 有{multiple_grid_points}个点被分配到多个网格")
break
return grids, grid_points
def adjust_grid_size(self, points_df):
"""动态调整网格大小
Args:
points_df: 包含GPS点的DataFrame
Returns:
tuple: grids
"""
self.logger.info(f"开始动态调整网格大小,初始大小: {self.grid_size}")
while True:
# 使用当前grid_size划分网格
grids = self.divide_grids(points_df)
grid_points, multiple_grid_points = self.assign_to_grids(
points_df, grids)
# 检查每个网格中的点数
max_points = 0
for grid_id, points in grid_points.items():
max_points = max(max_points, len(points))
self.logger.info(
f"当前网格大小: {self.grid_size}米, 单个网格最大点数: {max_points}")
# 如果最大点数超过2000减小网格大小
if max_points > 2000:
self.grid_size -= 100
self.logger.info(f"点数超过2000减小网格大小至: {self.grid_size}")
if self.grid_size < 500: # 设置一个最小网格大小限制
self.logger.warning("网格大小已达到最小值500米停止调整")
break
else:
self.logger.info(f"找到合适的网格大小: {self.grid_size}")
break
return grids
def divide_grids(self, points_df):
"""计算边界框并划分网格
Returns:
tuple: grids 网格边界列表
"""
self.logger.info("开始划分网格")
min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max()
min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max()
# 计算区域的实际距离(米)
width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters
height = geodesic((min_lat, min_lon), (max_lat, min_lon)).meters
self.logger.info(f"区域宽度: {width:.2f}米, 高度: {height:.2f}")
# 精细调整网格的长宽避免出现2*grid_size-1的情况的影响
grid_size_lt = [self.grid_size - 200, self.grid_size - 100,
self.grid_size, self.grid_size + 100, self.grid_size + 200]
width_modulus_lt = [width % grid_size for grid_size in grid_size_lt]
grid_width = grid_size_lt[width_modulus_lt.index(
min(width_modulus_lt))]
height_modulus_lt = [height % grid_size for grid_size in grid_size_lt]
grid_height = grid_size_lt[height_modulus_lt.index(
min(height_modulus_lt))]
self.logger.info(f"网格宽度: {grid_width:.2f}米, 网格高度: {grid_height:.2f}")
# 计算需要划分的网格数量
self.num_grids_width = max(int(width / grid_width), 1)
self.num_grids_height = max(int(height / grid_height), 1)
# 计算每个网格对应的经纬度步长
lat_step = (max_lat - min_lat) / self.num_grids_height
lon_step = (max_lon - min_lon) / self.num_grids_width
grids = []
# 先创建所有网格
for i in range(self.num_grids_height):
for j in range(self.num_grids_width):
grid_min_lat = min_lat + i * lat_step - self.overlap * lat_step
grid_max_lat = min_lat + \
(i + 1) * lat_step + self.overlap * lat_step
grid_min_lon = min_lon + j * lon_step - self.overlap * lon_step
grid_max_lon = min_lon + \
(j + 1) * lon_step + self.overlap * lon_step
grid_bounds = (grid_min_lat, grid_max_lat,
grid_min_lon, grid_max_lon)
grids.append(grid_bounds)
self.logger.debug(
f"网格[{i},{j}]: 纬度[{grid_min_lat:.6f}, {grid_max_lat:.6f}], "
f"经度[{grid_min_lon:.6f}, {grid_max_lon:.6f}]"
)
self.logger.info(
f"成功划分为 {len(grids)} 个网格 ({self.num_grids_width}x{self.num_grids_height})")
return grids
def assign_to_grids(self, points_df, grids):
"""将点分配到对应网格"""
self.logger.info(f"开始将 {len(points_df)} 个点分配到网格中")
grid_points = {} # 使用字典存储每个网格的点
points_assigned = 0
multiple_grid_points = 0
for i in range(self.num_grids_height):
for j in range(self.num_grids_width):
grid_points[(i, j)] = [] # 使用(i,j)元组
for _, point in points_df.iterrows():
point_assigned = False
for i in range(self.num_grids_height):
for j in range(self.num_grids_width):
grid_idx = i * self.num_grids_width + j
min_lat, max_lat, min_lon, max_lon = grids[grid_idx]
if min_lat <= point['lat'] <= max_lat and min_lon <= point['lon'] <= max_lon:
grid_points[(i, j)].append(point.to_dict())
if point_assigned:
multiple_grid_points += 1
else:
points_assigned += 1
point_assigned = True
# 记录每个网格的点数
for grid_id, points in grid_points.items():
self.logger.info(f"网格 {grid_id} 包含 {len(points)} 个点")
self.logger.info(
f"点分配完成: 总点数 {len(points_df)}, "
f"成功分配 {points_assigned} 个点, "
f"{multiple_grid_points} 个点被分配到多个网格"
)
return grid_points, multiple_grid_points
def visualize_grids(self, points_df, grids):
"""可视化网格划分和GPS点的分布"""
self.logger.info("开始可视化网格划分")
plt.figure(figsize=(12, 8))
# 绘制GPS点
plt.scatter(points_df['lon'], points_df['lat'],
c='blue', s=10, alpha=0.6, label='GPS points')
# 绘制网格
for i in range(self.num_grids_height):
for j in range(self.num_grids_width):
grid_idx = i * self.num_grids_width + j
min_lat, max_lat, min_lon, max_lon = grids[grid_idx]
# 计算网格的实际长度和宽度(米)
width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters
height = geodesic((min_lat, min_lon),
(max_lat, min_lon)).meters
plt.plot([min_lon, max_lon, max_lon, min_lon, min_lon],
[min_lat, min_lat, max_lat, max_lat, min_lat],
'r-', alpha=0.5)
# 在网格中心添加网格编号和尺寸信息
center_lon = (min_lon + max_lon) / 2
center_lat = (min_lat + max_lat) / 2
plt.text(center_lon, center_lat,
f"({i},{j})\n{width:.0f}m×{height:.0f}m", # 显示(i,j)和尺寸
horizontalalignment='center',
verticalalignment='center',
fontsize=8)
plt.title('Grid Division and GPS Point Distribution')
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.legend()
plt.grid(True)
# 如果提供了输出目录,保存图像
if self.output_dir:
save_path = os.path.join(
self.output_dir, 'filter_imgs', 'grid_division.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
self.logger.info(f"网格划分可视化图已保存至: {save_path}")
plt.close()
def get_grid_center(self, grid_bounds) -> tuple:
"""计算网格中心点的经纬度
Args:
grid_bounds: (min_lat, max_lat, min_lon, max_lon)
Returns:
(center_lat, center_lon)
"""
min_lat, max_lat, min_lon, max_lon = grid_bounds
return ((min_lat + max_lat) / 2, (min_lon + max_lon) / 2)
def calculate_grid_translation(self, reference_grid: tuple, target_grid: tuple) -> tuple:
"""计算目标网格相对于参考网格的平移距离(米)
Args:
reference_grid: 参考网格的边界 (min_lat, max_lat, min_lon, max_lon)
target_grid: 目标网格的边界 (min_lat, max_lat, min_lon, max_lon)
Returns:
(x_translation, y_translation): 在米制单位下的平移量
"""
ref_center = self.get_grid_center(reference_grid)
target_center = self.get_grid_center(target_grid)
# 计算经度方向的距离x轴
x_distance = geodesic(
(ref_center[0], ref_center[1]),
(ref_center[0], target_center[1])
).meters
# 如果目标在参考点西边,距离为负
if target_center[1] < ref_center[1]:
x_distance = -x_distance
# 计算纬度方向的距离y轴
y_distance = geodesic(
(ref_center[0], ref_center[1]),
(target_center[0], ref_center[1])
).meters
# 如果目标在参考点南边,距离为负
if target_center[0] < ref_center[0]:
y_distance = -y_distance
return (x_distance, y_distance)
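
GridDivider is normally driven by app_plugin.divide_grids; the following standalone sketch feeds it a small synthetic set of GPS points (illustrative coordinates and a placeholder output directory only):

```python
import os
import numpy as np
import pandas as pd
from utils.grid_divider import GridDivider

# Synthetic block of GPS points roughly 2 km on a side (illustrative coordinates only).
rng = np.random.default_rng(0)
gps_points = pd.DataFrame({
    "file": [f"img_{i}.jpg" for i in range(200)],
    "lat": 30.0 + rng.uniform(0, 0.02, 200),
    "lon": 120.0 + rng.uniform(0, 0.02, 200),
})
output_dir = "grid_demo_output"                       # placeholder output directory
os.makedirs(os.path.join(output_dir, "filter_imgs"), exist_ok=True)

divider = GridDivider(overlap=0.1, grid_size=800, output_dir=output_dir)
grids, grid_points = divider.adjust_grid_size_and_overlap(gps_points)
divider.visualize_grids(gps_points, grids)            # saves filter_imgs/grid_division.png
for grid_id, points in grid_points.items():
    print(grid_id, len(points))
```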

36
utils/logger.py Normal file

@ -0,0 +1,36 @@
import logging
import os
from datetime import datetime
def setup_logger(output_dir):
    # Locate the logs directory (created by DirectoryManager.setup_output_dirs)
    log_dir = os.path.join(output_dir, 'logs')
    os.makedirs(log_dir, exist_ok=True)
# 创建日志文件名(包含时间戳)
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file = os.path.join(log_dir, f'preprocess_{timestamp}.log')
# 配置日志格式
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
# 配置文件处理器
file_handler = logging.FileHandler(log_file, encoding='utf-8')
file_handler.setFormatter(formatter)
# 配置控制台处理器
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)
# 获取根日志记录器
logger = logging.getLogger('UAV_Preprocess')
logger.setLevel(logging.INFO)
# 添加处理器
logger.addHandler(file_handler)
logger.addHandler(console_handler)
return logger
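
setup_logger configures the parent 'UAV_Preprocess' logger; the module-level loggers used elsewhere (e.g. 'UAV_Preprocess.GridDivider') are its children and inherit the file and console handlers through propagation. A minimal sketch, with a placeholder output directory:

```python
import os
import logging
from utils.logger import setup_logger

output_dir = "demo_output"                            # placeholder
os.makedirs(os.path.join(output_dir, "logs"), exist_ok=True)
setup_logger(output_dir)                              # configures the 'UAV_Preprocess' logger

# Child loggers used throughout the project propagate to the configured parent.
logging.getLogger("UAV_Preprocess.GridDivider").info("hello from a child logger")
```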

144
utils/odm_monitor.py Normal file

@ -0,0 +1,144 @@
import os
import time
import logging
from typing import Dict, Tuple
import pandas as pd
import subprocess
class ODMProcessMonitor:
"""ODM处理监控器"""
def __init__(self, output_dir: str, mode: str = "三维模式", config=None):
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.ODMMonitor')
self.mode = mode
self.config = config
def run_odm_with_monitor(self, grid_dir: str, grid_id: tuple) -> Tuple[bool, str]:
"""运行ODM命令"""
self.logger.info(f"开始处理网格 ({grid_id[0]},{grid_id[1]})")
success = False
error_msg = ""
max_retries = 3
current_try = 0
cpu_cores = os.cpu_count()
while current_try < max_retries:
current_try += 1
self.logger.info(
f"{current_try} 次尝试处理网格 ({grid_id[0]},{grid_id[1]})")
# 构建 Docker 容器运行参数
grid_dir_fixed = grid_dir[0].lower() + grid_dir[1:].replace('\\', '/')
command = (
f"--max-concurrency {cpu_cores} "
f"--force-gps "
f"--use-exif "
f"--use-hybrid-bundle-adjustment "
f"--optimize-disk-space "
f"--orthophoto-cutline "
)
command += f"--feature-type {self.config.feature_type} "
command += f"--orthophoto-resolution {self.config.orthophoto_resolution} "
if self.mode == "快拼模式":
command += (
f"--fast-orthophoto "
f"--skip-3dmodel "
)
else: # 三维模式参数
command += (
f"--dsm "
f"--dtm "
)
if current_try == 1:
command += (
f"--feature-quality low "
)
command += "--rerun-all"
docker_cmd = f'python3 /code/run.py --project-path {grid_dir} project {command}'
self.logger.info(f"Docker 命令: {docker_cmd}")
try:
# 执行命令
result = subprocess.run(
docker_cmd,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
encoding="utf-8"
)
if result.returncode != 0:
self.logger.error(
f"容器运行失败,退出状态码: {result.returncode}")
error_lines = result.stderr.splitlines()
self.logger.error("容器运行失败的最后 50 行错误日志:")
for line in error_lines[-50:]:
self.logger.error(line)
error_msg = "\n".join(error_lines[-50:])
time.sleep(5)
else:
logs = result.stdout.splitlines()
self.logger.info("容器运行完成,以下是最后 50 行日志:")
for line in logs[-50:]:
self.logger.info(line)
success = True
error_msg = ""
break
except Exception as e:
error_msg = str(e)
self.logger.error(f"执行docker命令时发生异常: {error_msg}")
time.sleep(5)
return success, error_msg
    def process_all_grids(self, grid_points: Dict[tuple, list]) -> list:
        """Run ODM on every grid.
        Returns:
            list: ids of the grids that were processed successfully
        """
self.logger.info("开始执行网格处理")
successful_grid_lt = []
failed_grids = []
for grid_id, points in grid_points.items():
grid_dir = os.path.join(
self.output_dir, f'grid_{grid_id[0]}_{grid_id[1]}'
)
try:
success, error_msg = self.run_odm_with_monitor(
grid_dir=grid_dir,
grid_id=grid_id,
)
if success:
successful_grid_lt.append(grid_id)
else:
                    self.logger.error(
                        f"网格 ({grid_id[0]},{grid_id[1]}) 处理失败: {error_msg}")
failed_grids.append(grid_id)
except Exception as e:
error_msg = str(e)
self.logger.error(
f"处理网格 ({grid_id[0]},{grid_id[1]}) 时发生异常: {error_msg}")
failed_grids.append((grid_id, error_msg))
# 汇总处理结果
total_grids = len(grid_points)
failed_count = len(failed_grids)
success_count = len(successful_grid_lt)
self.logger.info(
f"网格处理完成。总计: {total_grids}, 成功: {success_count}, 失败: {failed_count}")
if len(successful_grid_lt) == 0:
raise Exception("所有网格处理都失败,无法继续处理")
return successful_grid_lt

152
utils/visualizer.py Normal file

@ -0,0 +1,152 @@
import os
import matplotlib.pyplot as plt
import pandas as pd
import logging
from typing import Optional
from pyproj import Transformer
class FilterVisualizer:
"""过滤结果可视化器"""
def __init__(self, output_dir: str):
"""
初始化可视化器
Args:
output_dir: 输出目录路径
"""
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.Visualizer')
# 创建坐标转换器
self.transformer = Transformer.from_crs(
"EPSG:4326", # WGS84经纬度坐标系
"EPSG:32649", # UTM49N
always_xy=True
)
def _convert_to_utm(self, lon: pd.Series, lat: pd.Series) -> tuple:
"""
将经纬度坐标转换为UTM坐标
Args:
lon: 经度序列
lat: 纬度序列
Returns:
tuple: (x坐标, y坐标)
"""
return self.transformer.transform(lon, lat)
def visualize_filter_step(self,
current_points: pd.DataFrame,
previous_points: pd.DataFrame,
step_name: str,
save_name: Optional[str] = None):
"""
可视化单个过滤步骤的结果
Args:
current_points: 当前步骤后的点
previous_points: 上一步骤的点
step_name: 步骤名称
save_name: 保存文件名默认为step_name
"""
self.logger.info(f"开始生成{step_name}的可视化结果")
# 找出被过滤掉的点
filtered_files = set(
previous_points['file']) - set(current_points['file'])
filtered_points = previous_points[previous_points['file'].isin(
filtered_files)]
# 转换坐标到UTM
current_x, current_y = self._convert_to_utm(
current_points['lon'], current_points['lat'])
filtered_x, filtered_y = self._convert_to_utm(
filtered_points['lon'], filtered_points['lat'])
# 创建图形
plt.rcParams['font.sans-serif'] = ['SimHei'] # 黑体
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(20, 20))
# 绘制保留的点
plt.scatter(current_x, current_y,
color='blue', label='保留的点',
alpha=0.6, s=5)
# 绘制被过滤的点
if not filtered_points.empty:
plt.scatter(filtered_x, filtered_y,
color='red', marker='x', label='过滤的点')
# 设置图形属性
plt.title(f"{step_name}后的GPS点\n"
f"(过滤: {len(filtered_points)}, 保留: {len(current_points)})",
fontsize=14)
plt.xlabel("东向坐标 (米)", fontsize=12)
plt.ylabel("北向坐标 (米)", fontsize=12)
plt.grid(True)
plt.axis('equal')
# 添加统计信息
stats_text = (
f"原始点数: {len(previous_points)}\n"
f"过滤点数: {len(filtered_points)}\n"
f"保留点数: {len(current_points)}\n"
f"过滤率: {len(filtered_points)/len(previous_points)*100:.1f}%"
)
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
bbox=dict(facecolor='white', alpha=0.8))
# 添加图例
plt.legend(loc='upper right', fontsize=10)
# 调整布局
plt.tight_layout()
# 保存图形
save_name = save_name or step_name.lower().replace(' ', '_')
save_path = os.path.join(
self.output_dir, 'filter_imgs', f'filter_{save_name}.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
self.logger.info(
f"{step_name}过滤可视化结果已保存至 {save_path}\n"
f"过滤掉 {len(filtered_points)} 个点,"
f"保留 {len(current_points)} 个点,"
f"过滤率 {len(filtered_points)/len(previous_points)*100:.1f}%"
)
if __name__ == '__main__':
# 测试代码
import numpy as np
from datetime import datetime
# 创建测试数据
np.random.seed(42)
n_points = 1000
# 生成随机点
test_data = pd.DataFrame({
'lon': np.random.uniform(120, 121, n_points),
'lat': np.random.uniform(30, 31, n_points),
'file': [f'img_{i}.jpg' for i in range(n_points)],
'date': [datetime.now() for _ in range(n_points)]
})
# 随机选择点作为过滤后的结果
filtered_data = test_data.sample(n=800)
# 测试可视化
visualizer = FilterVisualizer('test_output')
os.makedirs('test_output', exist_ok=True)
visualizer.visualize_filter_step(
filtered_data,
test_data,
"Test Filter"
)