first commit
This commit is contained in:
commit
e8178d191e
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
# 忽略所有__pycache__目录
|
||||||
|
**/__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
*.pyd
|
||||||
|
|
||||||
|
test/
|
15
README.md
Normal file
15
README.md
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
# ODM_Pro
|
||||||
|
无人机三维重建。
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
```
|
||||||
|
|
||||||
|
## 制作exe
|
||||||
|
|
||||||
|
1. odm_monitor.py docker容器名要记得改成uav
|
||||||
|
2. osgconv的compressed参数去掉,会报错
|
||||||
|
3. 运行指令`pyinstaller main.spec`
|
||||||
|
4. 测试执行指令`uav.exe --image_dir E:\datasets\UAV\134\project\images --output_dir G:\ODM_output\134`
|
163
app_plugin.py
Normal file
163
app_plugin.py
Normal file
@ -0,0 +1,163 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from filter.cluster_filter import GPSCluster
|
||||||
|
from utils.directory_manager import DirectoryManager
|
||||||
|
from utils.odm_monitor import ODMProcessMonitor
|
||||||
|
from utils.gps_extractor import GPSExtractor
|
||||||
|
from utils.grid_divider import GridDivider
|
||||||
|
from utils.logger import setup_logger
|
||||||
|
from utils.visualizer import FilterVisualizer
|
||||||
|
from post_pro.merge_tif import MergeTif
|
||||||
|
from post_pro.conv_obj2 import ConvertOBJ
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ProcessConfig:
|
||||||
|
"""预处理配置类"""
|
||||||
|
image_dir: str
|
||||||
|
output_dir: str
|
||||||
|
# 聚类过滤参数
|
||||||
|
cluster_eps: float = 0.01
|
||||||
|
cluster_min_samples: int = 5
|
||||||
|
|
||||||
|
# 网格划分参数
|
||||||
|
grid_overlap: float = 0.05
|
||||||
|
grid_size: float = 500
|
||||||
|
|
||||||
|
mode: str = "三维模式"
|
||||||
|
|
||||||
|
# ODM参数
|
||||||
|
feature_type: str = "sift"
|
||||||
|
orthophoto_resolution: float = 5
|
||||||
|
|
||||||
|
|
||||||
|
class ODM_Plugin:
|
||||||
|
def __init__(self, config):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# 初始化目录管理器
|
||||||
|
self.dir_manager = DirectoryManager(config)
|
||||||
|
# 清理并重建输出目录
|
||||||
|
self.dir_manager.clean_output_dir()
|
||||||
|
self.dir_manager.setup_output_dirs()
|
||||||
|
# 检查磁盘空间
|
||||||
|
self.dir_manager.check_disk_space()
|
||||||
|
|
||||||
|
# 初始化其他组件
|
||||||
|
self.logger = setup_logger(config.output_dir)
|
||||||
|
self.gps_points = pd.DataFrame(columns=["file", "lat", "lon"])
|
||||||
|
self.odm_monitor = ODMProcessMonitor(
|
||||||
|
config.output_dir, mode=config.mode, config=config)
|
||||||
|
self.visualizer = FilterVisualizer(config.output_dir)
|
||||||
|
|
||||||
|
def extract_gps(self) -> pd.DataFrame:
|
||||||
|
"""提取GPS数据"""
|
||||||
|
self.logger.info("开始提取GPS数据")
|
||||||
|
extractor = GPSExtractor(self.config.image_dir)
|
||||||
|
self.gps_points = extractor.extract_all_gps()
|
||||||
|
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
|
||||||
|
|
||||||
|
def cluster(self):
|
||||||
|
"""使用DBSCAN对GPS点进行聚类,只保留最大的类"""
|
||||||
|
previous_points = self.gps_points.copy()
|
||||||
|
clusterer = GPSCluster(
|
||||||
|
self.gps_points,
|
||||||
|
eps=self.config.cluster_eps,
|
||||||
|
min_samples=self.config.cluster_min_samples
|
||||||
|
)
|
||||||
|
|
||||||
|
self.clustered_points = clusterer.fit()
|
||||||
|
self.gps_points = clusterer.get_cluster_stats(self.clustered_points)
|
||||||
|
|
||||||
|
self.visualizer.visualize_filter_step(
|
||||||
|
self.gps_points, previous_points, "1-Clustering")
|
||||||
|
|
||||||
|
def divide_grids(self) -> Dict[tuple, pd.DataFrame]:
|
||||||
|
"""划分网格
|
||||||
|
Returns:
|
||||||
|
- grid_points: 网格点数据字典
|
||||||
|
- translations: 网格平移量字典
|
||||||
|
"""
|
||||||
|
grid_divider = GridDivider(
|
||||||
|
overlap=self.config.grid_overlap,
|
||||||
|
grid_size=self.config.grid_size,
|
||||||
|
output_dir=self.config.output_dir
|
||||||
|
)
|
||||||
|
grids, grid_points = grid_divider.adjust_grid_size_and_overlap(
|
||||||
|
self.gps_points
|
||||||
|
)
|
||||||
|
grid_divider.visualize_grids(self.gps_points, grids)
|
||||||
|
if len(grids) >= 20:
|
||||||
|
self.logger.warning("网格数量已超过20, 需要人工调整分区")
|
||||||
|
|
||||||
|
return grid_points
|
||||||
|
|
||||||
|
def copy_images(self, grid_points: Dict[tuple, pd.DataFrame]):
|
||||||
|
"""复制图像到目标文件夹"""
|
||||||
|
self.logger.info("开始复制图像文件")
|
||||||
|
|
||||||
|
for grid_id, points in grid_points.items():
|
||||||
|
output_dir = os.path.join(
|
||||||
|
self.config.output_dir,
|
||||||
|
f"grid_{grid_id[0]}_{grid_id[1]}",
|
||||||
|
"project",
|
||||||
|
"images"
|
||||||
|
)
|
||||||
|
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for point in points:
|
||||||
|
src = os.path.join(self.config.image_dir, point["file"])
|
||||||
|
dst = os.path.join(output_dir, point["file"])
|
||||||
|
shutil.copy(src, dst)
|
||||||
|
self.logger.info(
|
||||||
|
f"网格 ({grid_id[0]},{grid_id[1]}) 包含 {len(points)} 张图像")
|
||||||
|
|
||||||
|
def merge_tif(self, grid_lt):
|
||||||
|
"""合并所有网格的影像产品"""
|
||||||
|
self.logger.info("开始合并所有影像产品")
|
||||||
|
merger = MergeTif(self.config.output_dir)
|
||||||
|
merger.merge_orthophoto(grid_lt)
|
||||||
|
|
||||||
|
def convert_obj(self, grid_lt):
|
||||||
|
"""转换OBJ模型"""
|
||||||
|
self.logger.info("开始转换OBJ模型")
|
||||||
|
converter = ConvertOBJ(self.config.output_dir)
|
||||||
|
converter.convert_grid_obj(grid_lt)
|
||||||
|
|
||||||
|
def post_process(self, successful_grid_lt: list, grid_points: Dict[tuple, pd.DataFrame]):
|
||||||
|
"""后处理:合并或复制处理结果"""
|
||||||
|
if len(successful_grid_lt) < len(grid_points):
|
||||||
|
self.logger.warning(
|
||||||
|
f"有 {len(grid_points) - len(successful_grid_lt)} 个网格处理失败,"
|
||||||
|
f"将只合并成功处理的 {len(successful_grid_lt)} 个网格"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.merge_tif(successful_grid_lt)
|
||||||
|
if self.config.mode == "三维模式":
|
||||||
|
self.convert_obj(successful_grid_lt)
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def process(self):
|
||||||
|
"""执行完整的预处理流程"""
|
||||||
|
try:
|
||||||
|
self.extract_gps()
|
||||||
|
self.cluster()
|
||||||
|
grid_points = self.divide_grids()
|
||||||
|
self.copy_images(grid_points)
|
||||||
|
self.logger.info("预处理任务完成")
|
||||||
|
|
||||||
|
successful_grid_lt = self.odm_monitor.process_all_grids(
|
||||||
|
grid_points)
|
||||||
|
|
||||||
|
self.post_process(successful_grid_lt, grid_points)
|
||||||
|
self.logger.info("重建任务完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
|
||||||
|
raise
|
24
check_version.py
Normal file
24
check_version.py
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
import importlib
|
||||||
|
import pkg_resources
|
||||||
|
|
||||||
|
def check_versions(requirements_file):
|
||||||
|
with open(requirements_file, 'r', encoding='utf-8') as f:
|
||||||
|
packages = [line.strip() for line in f if line.strip() and not line.startswith('#') and not line.startswith('//')]
|
||||||
|
|
||||||
|
print(f"{'Package':<20} {'Installed Version':<20}")
|
||||||
|
print('-' * 40)
|
||||||
|
for pkg in packages:
|
||||||
|
try:
|
||||||
|
# 有些包名和import名不同,优先用pkg_resources
|
||||||
|
version = pkg_resources.get_distribution(pkg).version
|
||||||
|
print(f"{pkg:<20} {version:<20}")
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
mod = importlib.import_module(pkg)
|
||||||
|
version = getattr(mod, '__version__', 'Unknown')
|
||||||
|
print(f"{pkg:<20} {version:<20}")
|
||||||
|
except Exception:
|
||||||
|
print(f"{pkg:<20} {'Not Installed':<20}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
check_versions("requirements.txt")
|
77
filter/cluster_filter.py
Normal file
77
filter/cluster_filter.py
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
from sklearn.cluster import DBSCAN
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class GPSCluster:
|
||||||
|
def __init__(self, gps_points, eps=0.01, min_samples=3):
|
||||||
|
"""
|
||||||
|
初始化GPS聚类器
|
||||||
|
|
||||||
|
参数:
|
||||||
|
eps: DBSCAN的邻域半径参数
|
||||||
|
min_samples: DBSCAN的最小样本数参数
|
||||||
|
"""
|
||||||
|
self.eps = eps
|
||||||
|
self.min_samples = min_samples
|
||||||
|
self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
|
||||||
|
self.scaler = StandardScaler()
|
||||||
|
self.gps_points = gps_points
|
||||||
|
self.logger = logging.getLogger('UAV_Preprocess.GPSCluster')
|
||||||
|
|
||||||
|
def fit(self):
|
||||||
|
"""
|
||||||
|
对GPS点进行聚类,只保留最大的类
|
||||||
|
|
||||||
|
参数:
|
||||||
|
gps_points: 包含'lat'和'lon'列的DataFrame
|
||||||
|
|
||||||
|
返回:
|
||||||
|
带有聚类标签的DataFrame,其中最大类标记为1,其他点标记为-1
|
||||||
|
"""
|
||||||
|
self.logger.info("开始聚类")
|
||||||
|
# 提取经纬度数据
|
||||||
|
X = self.gps_points[["lon", "lat"]].values
|
||||||
|
|
||||||
|
# # 数据标准化
|
||||||
|
# X_scaled = self.scaler.fit_transform(X)
|
||||||
|
|
||||||
|
# 执行DBSCAN聚类
|
||||||
|
labels = self.dbscan.fit_predict(X)
|
||||||
|
|
||||||
|
# 找出最大类的标签(排除噪声点-1)
|
||||||
|
unique_labels = [l for l in set(labels) if l != -1]
|
||||||
|
if unique_labels: # 如果有聚类
|
||||||
|
label_counts = [(l, sum(labels == l)) for l in unique_labels]
|
||||||
|
largest_label = max(label_counts, key=lambda x: x[1])[0]
|
||||||
|
|
||||||
|
# 将最大类标记为1,其他都标记为-1
|
||||||
|
new_labels = (labels == largest_label).astype(int)
|
||||||
|
new_labels[new_labels == 0] = -1
|
||||||
|
else: # 如果没有聚类,全部标记为-1
|
||||||
|
new_labels = labels
|
||||||
|
|
||||||
|
# 将聚类结果添加到原始数据中
|
||||||
|
result_df = self.gps_points.copy()
|
||||||
|
result_df["cluster"] = new_labels
|
||||||
|
|
||||||
|
return result_df
|
||||||
|
|
||||||
|
def get_cluster_stats(self, clustered_points):
|
||||||
|
"""
|
||||||
|
获取聚类统计信息
|
||||||
|
|
||||||
|
参数:
|
||||||
|
clustered_points: 带有聚类标签的DataFrame
|
||||||
|
|
||||||
|
返回:
|
||||||
|
聚类统计信息的字典
|
||||||
|
"""
|
||||||
|
main_cluster = clustered_points[clustered_points["cluster"] == 1]
|
||||||
|
noise_cluster = clustered_points[clustered_points["cluster"] == -1]
|
||||||
|
|
||||||
|
self.logger.info(f"聚类完成:主要类别包含 {len(main_cluster)} 个点,"
|
||||||
|
f"噪声点 {len(noise_cluster)} 个")
|
||||||
|
|
||||||
|
return main_cluster
|
54
main.py
Normal file
54
main.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import argparse
|
||||||
|
from app_plugin import ProcessConfig, ODM_Plugin
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(description='ODM预处理工具')
|
||||||
|
|
||||||
|
# 必需参数
|
||||||
|
parser.add_argument('--image_dir', required=True, help='输入图片目录路径')
|
||||||
|
parser.add_argument('--output_dir', required=True, help='输出目录路径')
|
||||||
|
# parser.add_argument('--image_dir', default=r'G:\UAV_data\test_31\project\images', help='输入图片目录路径')
|
||||||
|
# parser.add_argument('--output_dir', default=r'G:\ODM_output\test_31', help='输出目录路径')
|
||||||
|
|
||||||
|
# 可选参数
|
||||||
|
parser.add_argument('--mode', default='三维模式',
|
||||||
|
choices=['快拼模式', '三维模式'], help='处理模式')
|
||||||
|
parser.add_argument('--grid_size', type=float, default=800, help='网格大小(米)')
|
||||||
|
parser.add_argument('--grid_overlap', type=float,
|
||||||
|
default=0.1, help='网格重叠率')
|
||||||
|
|
||||||
|
# ODM参数
|
||||||
|
parser.add_argument('--feature_type', default='sift', help='特征类型')
|
||||||
|
parser.add_argument('--orthophoto_resolution',
|
||||||
|
type=float, default=5, help='正射影像分辨率')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# 创建配置
|
||||||
|
config = ProcessConfig(
|
||||||
|
image_dir=args.image_dir,
|
||||||
|
output_dir=args.output_dir,
|
||||||
|
mode=args.mode,
|
||||||
|
grid_size=args.grid_size,
|
||||||
|
grid_overlap=args.grid_overlap,
|
||||||
|
feature_type=args.feature_type,
|
||||||
|
orthophoto_resolution=args.orthophoto_resolution,
|
||||||
|
|
||||||
|
# 其他参数使用默认值
|
||||||
|
cluster_eps=0.01,
|
||||||
|
cluster_min_samples=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建处理器并执行
|
||||||
|
processor = ODM_Plugin(config)
|
||||||
|
processor.process()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
47
main.spec
Normal file
47
main.spec
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# -*- mode: python ; coding: utf-8 -*-
|
||||||
|
|
||||||
|
|
||||||
|
a = Analysis(
|
||||||
|
['main.py'],
|
||||||
|
pathex=[],
|
||||||
|
binaries=[],
|
||||||
|
datas=[],
|
||||||
|
hiddenimports=[
|
||||||
|
'rasterio.sample',
|
||||||
|
'rasterio.vrt',
|
||||||
|
'rasterio._io',
|
||||||
|
'rasterio.enums',
|
||||||
|
'rasterio.errors',
|
||||||
|
'rasterio.transform',
|
||||||
|
'rasterio.crs',
|
||||||
|
'rasterio.warp'
|
||||||
|
],
|
||||||
|
hookspath=[],
|
||||||
|
hooksconfig={},
|
||||||
|
runtime_hooks=[],
|
||||||
|
excludes=[],
|
||||||
|
noarchive=False,
|
||||||
|
optimize=0,
|
||||||
|
)
|
||||||
|
pyz = PYZ(a.pure)
|
||||||
|
|
||||||
|
exe = EXE(
|
||||||
|
pyz,
|
||||||
|
a.scripts,
|
||||||
|
a.binaries,
|
||||||
|
a.datas,
|
||||||
|
[],
|
||||||
|
name='uav',
|
||||||
|
debug=False,
|
||||||
|
bootloader_ignore_signals=False,
|
||||||
|
strip=False,
|
||||||
|
upx=True,
|
||||||
|
upx_exclude=[],
|
||||||
|
runtime_tmpdir=None,
|
||||||
|
console=True,
|
||||||
|
disable_windowed_traceback=False,
|
||||||
|
argv_emulation=False,
|
||||||
|
target_arch=None,
|
||||||
|
codesign_identity=None,
|
||||||
|
entitlements_file=None,
|
||||||
|
)
|
265
post_pro/conv_obj.py
Normal file
265
post_pro/conv_obj.py
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import logging
|
||||||
|
from pyproj import Transformer
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertOBJ:
|
||||||
|
def __init__(self, output_dir: str):
|
||||||
|
self.output_dir = output_dir
|
||||||
|
# 用于存储所有grid的UTM范围
|
||||||
|
self.ref_east = float('inf')
|
||||||
|
self.ref_north = float('inf')
|
||||||
|
# 初始化UTM到WGS84的转换器
|
||||||
|
self.transformer = Transformer.from_crs(
|
||||||
|
"EPSG:32649", "EPSG:4326", always_xy=True)
|
||||||
|
self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ')
|
||||||
|
|
||||||
|
def convert_grid_obj(self, grid_lt):
|
||||||
|
"""转换每个网格的OBJ文件为OSGB格式"""
|
||||||
|
os.makedirs(os.path.join(self.output_dir,
|
||||||
|
"osgb", "Data"), exist_ok=True)
|
||||||
|
|
||||||
|
# 以第一个grid的UTM坐标作为参照系
|
||||||
|
first_grid_id = grid_lt[0]
|
||||||
|
first_grid_dir = os.path.join(
|
||||||
|
self.output_dir,
|
||||||
|
f"grid_{first_grid_id[0]}_{first_grid_id[1]}",
|
||||||
|
"project"
|
||||||
|
)
|
||||||
|
log_file = os.path.join(
|
||||||
|
first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
|
||||||
|
self.ref_east, self.ref_north = self.read_utm_offset(log_file)
|
||||||
|
|
||||||
|
for grid_id in grid_lt:
|
||||||
|
try:
|
||||||
|
self._convert_single_grid(grid_id)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"网格 {grid_id} 转换失败: {str(e)}")
|
||||||
|
|
||||||
|
self._create_merged_metadata()
|
||||||
|
|
||||||
|
def _convert_single_grid(self, grid_id):
|
||||||
|
"""转换单个网格的OBJ文件"""
|
||||||
|
# 构建相关路径
|
||||||
|
grid_name = f"grid_{grid_id[0]}_{grid_id[1]}"
|
||||||
|
project_dir = os.path.join(self.output_dir, grid_name, "project")
|
||||||
|
texturing_dir = os.path.join(project_dir, "odm_texturing")
|
||||||
|
texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst")
|
||||||
|
split_obj_dir = os.path.join(texturing_dst_dir, "split_obj")
|
||||||
|
log_file = os.path.join(
|
||||||
|
project_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
|
||||||
|
os.makedirs(texturing_dst_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 修改obj文件z坐标的值
|
||||||
|
min_25d_z = self.get_min_z_from_obj(os.path.join(
|
||||||
|
project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj'))
|
||||||
|
self.modify_z_in_obj(texturing_dir, min_25d_z)
|
||||||
|
|
||||||
|
# 在新文件夹下,利用UTM偏移量,修改obj文件顶点坐标,纹理文件下采样
|
||||||
|
utm_offset = self.read_utm_offset(log_file)
|
||||||
|
modified_obj = self.modify_obj_coordinates(
|
||||||
|
texturing_dir, texturing_dst_dir, utm_offset)
|
||||||
|
self.downsample_texture(texturing_dir, texturing_dst_dir)
|
||||||
|
|
||||||
|
# 将obj文件进行切片
|
||||||
|
self.logger.info(f"开始切片网格 {grid_id} 的OBJ文件")
|
||||||
|
os.makedirs(split_obj_dir)
|
||||||
|
cmd = (
|
||||||
|
f"D:\\software\\Obj2Tiles\\Obj2Tiles.exe --stage Splitting --lods 1 --divisions 2 "
|
||||||
|
f"{modified_obj} {split_obj_dir}"
|
||||||
|
)
|
||||||
|
subprocess.run(cmd, check=True)
|
||||||
|
|
||||||
|
# 执行格式转换,Linux下osgconv有问题,记得注释掉
|
||||||
|
self.logger.info(f"开始转换网格 {grid_id} 的OBJ文件")
|
||||||
|
# 先获取split_obj_dir下的所有obj文件
|
||||||
|
obj_lod_dir = os.path.join(split_obj_dir, "LOD-0")
|
||||||
|
obj_files = [f for f in os.listdir(
|
||||||
|
obj_lod_dir) if f.endswith('.obj')]
|
||||||
|
for obj_file in obj_files:
|
||||||
|
obj_path = os.path.join(obj_lod_dir, obj_file)
|
||||||
|
osgb_file = os.path.splitext(obj_file)[0] + '.osgb'
|
||||||
|
osgb_path = os.path.join(split_obj_dir, osgb_file)
|
||||||
|
# 执行 osgconv 命令
|
||||||
|
subprocess.run(['osgconv', obj_path, osgb_path], check=True)
|
||||||
|
|
||||||
|
# 创建OSGB目录结构,复制文件
|
||||||
|
osgb_base_dir = os.path.join(self.output_dir, "osgb")
|
||||||
|
data_dir = os.path.join(osgb_base_dir, "Data")
|
||||||
|
for obj_file in obj_files:
|
||||||
|
obj_file_name = os.path.splitext(obj_file)[0]
|
||||||
|
tile_dirs = os.path.join(
|
||||||
|
data_dir, f"grid_{grid_id[0]}_{grid_id[1]}_{obj_file_name}")
|
||||||
|
os.makedirs(tile_dirs, exist_ok=True)
|
||||||
|
shutil.copy2(os.path.join(
|
||||||
|
split_obj_dir, obj_file_name+".osgb"), tile_dirs)
|
||||||
|
os.rename(os.path.join(tile_dirs, obj_file_name+".osgb"),
|
||||||
|
os.path.join(tile_dirs, f"grid_{grid_id[0]}_{grid_id[1]}_{obj_file_name}.osgb"))
|
||||||
|
|
||||||
|
def _create_merged_metadata(self):
|
||||||
|
"""创建合并后的metadata.xml文件"""
|
||||||
|
# 转换为WGS84经纬度
|
||||||
|
center_lon, center_lat = self.transformer.transform(
|
||||||
|
self.ref_east, self.ref_north)
|
||||||
|
metadata_content = f"""<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<ModelMetadata version="1">
|
||||||
|
<SRS>EPSG:4326</SRS>
|
||||||
|
<SRSOrigin>{center_lon},{center_lat},0</SRSOrigin>
|
||||||
|
<Texture>
|
||||||
|
<ColorSource>Visible</ColorSource>
|
||||||
|
</Texture>
|
||||||
|
</ModelMetadata>"""
|
||||||
|
|
||||||
|
metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml")
|
||||||
|
with open(metadata_file, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(metadata_content)
|
||||||
|
|
||||||
|
def read_utm_offset(self, log_file: str) -> tuple:
|
||||||
|
"""读取UTM偏移量"""
|
||||||
|
try:
|
||||||
|
east_offset = None
|
||||||
|
north_offset = None
|
||||||
|
|
||||||
|
with open(log_file, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if 'utm_north_offset' in line and i + 1 < len(lines):
|
||||||
|
north_offset = float(lines[i + 1].strip())
|
||||||
|
elif 'utm_east_offset' in line and i + 1 < len(lines):
|
||||||
|
east_offset = float(lines[i + 1].strip())
|
||||||
|
|
||||||
|
if east_offset is None or north_offset is None:
|
||||||
|
raise ValueError("未找到UTM偏移量")
|
||||||
|
|
||||||
|
return east_offset, north_offset
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str:
|
||||||
|
"""修改obj文件中的顶点坐标,使用相对坐标系"""
|
||||||
|
obj_file = os.path.join(
|
||||||
|
texturing_dir, "odm_textured_model_modified.obj")
|
||||||
|
obj_dst_file = os.path.join(
|
||||||
|
texturing_dst_dir, "odm_textured_model_geo_utm.obj")
|
||||||
|
if not os.path.exists(obj_file):
|
||||||
|
raise FileNotFoundError(f"找不到OBJ文件: {obj_file}")
|
||||||
|
shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"),
|
||||||
|
os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl"))
|
||||||
|
east_offset, north_offset = utm_offset
|
||||||
|
self.logger.info(
|
||||||
|
f"UTM坐标偏移:{east_offset - self.ref_east}, {north_offset - self.ref_north}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out:
|
||||||
|
for line in f_in:
|
||||||
|
if line.startswith('v '):
|
||||||
|
# 处理顶点坐标行
|
||||||
|
parts = line.strip().split()
|
||||||
|
# 使用相对于整体最小UTM坐标的偏移
|
||||||
|
x = float(parts[1]) + (east_offset - self.ref_east)
|
||||||
|
y = float(parts[2]) + (north_offset - self.ref_north)
|
||||||
|
z = float(parts[3])
|
||||||
|
f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n')
|
||||||
|
elif line.startswith('vn '): # 处理法线向量
|
||||||
|
parts = line.split()
|
||||||
|
nx = float(parts[1])
|
||||||
|
ny = float(parts[2])
|
||||||
|
nz = float(parts[3])
|
||||||
|
# 同步反转法线的 Y 轴
|
||||||
|
new_line = f"vn {nx} {nz} {-ny}\n"
|
||||||
|
f_out.write(new_line)
|
||||||
|
else:
|
||||||
|
# 其他行直接写入
|
||||||
|
f_out.write(line)
|
||||||
|
|
||||||
|
return obj_dst_file
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"修改obj坐标时发生错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def downsample_texture(self, src_dir: str, dst_dir: str):
|
||||||
|
"""复制并重命名纹理文件,对大于100MB的文件进行多次下采样,直到文件小于100MB
|
||||||
|
Args:
|
||||||
|
src_dir: 源纹理目录
|
||||||
|
dst_dir: 目标纹理目录
|
||||||
|
"""
|
||||||
|
for file in os.listdir(src_dir):
|
||||||
|
if file.lower().endswith(('.png')):
|
||||||
|
src_path = os.path.join(src_dir, file)
|
||||||
|
dst_path = os.path.join(dst_dir, file)
|
||||||
|
|
||||||
|
# 检查文件大小(以字节为单位)
|
||||||
|
file_size = os.path.getsize(src_path)
|
||||||
|
if file_size <= 100 * 1024 * 1024: # 如果文件小于等于100MB,直接复制
|
||||||
|
shutil.copy2(src_path, dst_path)
|
||||||
|
else:
|
||||||
|
# 文件大于100MB,进行下采样
|
||||||
|
img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED)
|
||||||
|
if_first_ds = True
|
||||||
|
while file_size > 100 * 1024 * 1024: # 大于100MB
|
||||||
|
self.logger.info(f"纹理文件 {file} 大于100MB,进行下采样")
|
||||||
|
|
||||||
|
if if_first_ds:
|
||||||
|
# 计算新的尺寸(长宽各变为1/4)
|
||||||
|
new_size = (img.shape[1] // 4,
|
||||||
|
img.shape[0] // 4) # 逐步减小尺寸
|
||||||
|
# 使用双三次插值进行下采样
|
||||||
|
resized_img = cv2.resize(
|
||||||
|
img, new_size, interpolation=cv2.INTER_CUBIC)
|
||||||
|
if_first_ds = False
|
||||||
|
else:
|
||||||
|
# 计算新的尺寸(长宽各变为1/2)
|
||||||
|
new_size = (img.shape[1] // 2,
|
||||||
|
img.shape[0] // 2) # 逐步减小尺寸
|
||||||
|
# 使用双三次插值进行下采样
|
||||||
|
resized_img = cv2.resize(
|
||||||
|
img, new_size, interpolation=cv2.INTER_CUBIC)
|
||||||
|
|
||||||
|
# 更新文件路径为下采样后的路径
|
||||||
|
cv2.imwrite(dst_path, resized_img, [
|
||||||
|
cv2.IMWRITE_PNG_COMPRESSION, 9])
|
||||||
|
|
||||||
|
# 更新文件大小和图像
|
||||||
|
file_size = os.path.getsize(dst_path)
|
||||||
|
img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED)
|
||||||
|
self.logger.info(
|
||||||
|
f"下采样后文件大小: {file_size / (1024 * 1024):.2f} MB")
|
||||||
|
|
||||||
|
def get_min_z_from_obj(self, file_path):
|
||||||
|
min_z = float('inf') # 初始值设为无穷大
|
||||||
|
with open(file_path, 'r') as obj_file:
|
||||||
|
for line in obj_file:
|
||||||
|
# 检查每一行是否是顶点定义(以 'v ' 开头)
|
||||||
|
if line.startswith('v '):
|
||||||
|
# 获取顶点坐标
|
||||||
|
parts = line.split()
|
||||||
|
# 将z值转换为浮动数字
|
||||||
|
z = float(parts[3])
|
||||||
|
# 更新最小z值
|
||||||
|
if z < min_z:
|
||||||
|
min_z = z
|
||||||
|
return min_z
|
||||||
|
|
||||||
|
def modify_z_in_obj(self, texturing_dir, min_25d_z):
|
||||||
|
obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj')
|
||||||
|
output_file = os.path.join(
|
||||||
|
texturing_dir, 'odm_textured_model_modified.obj')
|
||||||
|
with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out:
|
||||||
|
for line in f_in:
|
||||||
|
if line.startswith('v '): # 顶点坐标行
|
||||||
|
parts = line.strip().split()
|
||||||
|
x = float(parts[1])
|
||||||
|
y = float(parts[2])
|
||||||
|
z = float(parts[3])
|
||||||
|
|
||||||
|
if z < min_25d_z:
|
||||||
|
z = min_25d_z
|
||||||
|
|
||||||
|
f_out.write(f"v {x} {y} {z}\n")
|
||||||
|
else:
|
||||||
|
f_out.write(line)
|
253
post_pro/conv_obj2.py
Normal file
253
post_pro/conv_obj2.py
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import logging
|
||||||
|
from pyproj import Transformer
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
class ConvertOBJ:
|
||||||
|
def __init__(self, output_dir: str):
|
||||||
|
self.output_dir = output_dir
|
||||||
|
# 用于存储所有grid的UTM范围
|
||||||
|
self.ref_east = float('inf')
|
||||||
|
self.ref_north = float('inf')
|
||||||
|
# 初始化UTM到WGS84的转换器
|
||||||
|
self.transformer = Transformer.from_crs(
|
||||||
|
"EPSG:32649", "EPSG:4326", always_xy=True)
|
||||||
|
self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ')
|
||||||
|
|
||||||
|
def convert_grid_obj(self, grid_lt):
|
||||||
|
"""转换每个网格的OBJ文件为OSGB格式"""
|
||||||
|
os.makedirs(os.path.join(self.output_dir,
|
||||||
|
"osgb", "Data"), exist_ok=True)
|
||||||
|
|
||||||
|
# 以第一个grid的UTM坐标作为参照系
|
||||||
|
first_grid_id = grid_lt[0]
|
||||||
|
first_grid_dir = os.path.join(
|
||||||
|
self.output_dir,
|
||||||
|
f"grid_{first_grid_id[0]}_{first_grid_id[1]}",
|
||||||
|
"project"
|
||||||
|
)
|
||||||
|
log_file = os.path.join(
|
||||||
|
first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
|
||||||
|
self.ref_east, self.ref_north = self.read_utm_offset(log_file)
|
||||||
|
|
||||||
|
for grid_id in grid_lt:
|
||||||
|
try:
|
||||||
|
self._convert_single_grid(grid_id)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"网格 {grid_id} 转换失败: {str(e)}")
|
||||||
|
|
||||||
|
self._create_merged_metadata()
|
||||||
|
|
||||||
|
def _convert_single_grid(self, grid_id):
|
||||||
|
"""转换单个网格的OBJ文件"""
|
||||||
|
# 构建相关路径
|
||||||
|
grid_name = f"grid_{grid_id[0]}_{grid_id[1]}"
|
||||||
|
project_dir = os.path.join(self.output_dir, grid_name, "project")
|
||||||
|
texturing_dir = os.path.join(project_dir, "odm_texturing")
|
||||||
|
texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst")
|
||||||
|
opensfm_dir = os.path.join(project_dir, "opensfm")
|
||||||
|
log_file = os.path.join(
|
||||||
|
project_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
|
||||||
|
os.makedirs(texturing_dst_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 修改obj文件z坐标的值
|
||||||
|
min_25d_z = self.get_min_z_from_obj(os.path.join(
|
||||||
|
project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj'))
|
||||||
|
self.modify_z_in_obj(texturing_dir, min_25d_z)
|
||||||
|
|
||||||
|
# 在新文件夹下,利用UTM偏移量,修改obj文件顶点坐标,纹理文件下采样
|
||||||
|
utm_offset = self.read_utm_offset(log_file)
|
||||||
|
modified_obj = self.modify_obj_coordinates(
|
||||||
|
texturing_dir, texturing_dst_dir, utm_offset)
|
||||||
|
self.downsample_texture(texturing_dir, texturing_dst_dir)
|
||||||
|
|
||||||
|
# 执行格式转换,Linux下osgconv有问题,记得注释掉
|
||||||
|
self.logger.info(f"开始转换网格 {grid_id} 的OBJ文件")
|
||||||
|
output_osgb = os.path.join(texturing_dst_dir, "Tile.osgb")
|
||||||
|
cmd = (
|
||||||
|
f"osgconv {modified_obj} {output_osgb} "
|
||||||
|
f"--smooth --fix-transparency "
|
||||||
|
)
|
||||||
|
self.logger.info(f"执行osgconv命令:{cmd}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(cmd, shell=True, check=True, cwd=texturing_dir)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
raise RuntimeError(f"OSGB转换失败: {str(e)}")
|
||||||
|
|
||||||
|
# 创建OSGB目录结构,复制文件
|
||||||
|
osgb_base_dir = os.path.join(self.output_dir, "osgb")
|
||||||
|
data_dir = os.path.join(osgb_base_dir, "Data")
|
||||||
|
tile_dir = os.path.join(data_dir, f"Tile_{grid_id[0]}_{grid_id[1]}")
|
||||||
|
os.makedirs(tile_dir, exist_ok=True)
|
||||||
|
target_osgb = os.path.join(
|
||||||
|
tile_dir, f"Tile_{grid_id[0]}_{grid_id[1]}.osgb")
|
||||||
|
shutil.copy2(output_osgb, target_osgb)
|
||||||
|
|
||||||
|
def _create_merged_metadata(self):
|
||||||
|
"""创建合并后的metadata.xml文件"""
|
||||||
|
# 转换为WGS84经纬度
|
||||||
|
center_lon, center_lat = self.transformer.transform(
|
||||||
|
self.ref_east, self.ref_north)
|
||||||
|
metadata_content = f"""<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<ModelMetadata version="1">
|
||||||
|
<SRS>EPSG:4326</SRS>
|
||||||
|
<SRSOrigin>{center_lon},{center_lat},0</SRSOrigin>
|
||||||
|
<Texture>
|
||||||
|
<ColorSource>Visible</ColorSource>
|
||||||
|
</Texture>
|
||||||
|
</ModelMetadata>"""
|
||||||
|
|
||||||
|
metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml")
|
||||||
|
with open(metadata_file, 'w', encoding='utf-8') as f:
|
||||||
|
f.write(metadata_content)
|
||||||
|
|
||||||
|
def read_utm_offset(self, log_file: str) -> tuple:
|
||||||
|
"""读取UTM偏移量"""
|
||||||
|
try:
|
||||||
|
east_offset = None
|
||||||
|
north_offset = None
|
||||||
|
|
||||||
|
with open(log_file, 'r') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if 'utm_north_offset' in line and i + 1 < len(lines):
|
||||||
|
north_offset = float(lines[i + 1].strip())
|
||||||
|
elif 'utm_east_offset' in line and i + 1 < len(lines):
|
||||||
|
east_offset = float(lines[i + 1].strip())
|
||||||
|
|
||||||
|
if east_offset is None or north_offset is None:
|
||||||
|
raise ValueError("未找到UTM偏移量")
|
||||||
|
|
||||||
|
return east_offset, north_offset
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str:
|
||||||
|
"""修改obj文件中的顶点坐标,使用相对坐标系"""
|
||||||
|
obj_file = os.path.join(
|
||||||
|
texturing_dir, "odm_textured_model_modified.obj")
|
||||||
|
obj_dst_file = os.path.join(
|
||||||
|
texturing_dst_dir, "odm_textured_model_geo_utm.obj")
|
||||||
|
if not os.path.exists(obj_file):
|
||||||
|
raise FileNotFoundError(f"找不到OBJ文件: {obj_file}")
|
||||||
|
shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"),
|
||||||
|
os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl"))
|
||||||
|
east_offset, north_offset = utm_offset
|
||||||
|
self.logger.info(
|
||||||
|
f"UTM坐标偏移:{east_offset - self.ref_east}, {north_offset - self.ref_north}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out:
|
||||||
|
for line in f_in:
|
||||||
|
if line.startswith('v '):
|
||||||
|
# 处理顶点坐标行
|
||||||
|
parts = line.strip().split()
|
||||||
|
# 使用相对于整体最小UTM坐标的偏移
|
||||||
|
x = float(parts[1]) + (east_offset - self.ref_east)
|
||||||
|
y = float(parts[2]) + (north_offset - self.ref_north)
|
||||||
|
z = float(parts[3])
|
||||||
|
f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n')
|
||||||
|
elif line.startswith('vn '): # 处理法线向量
|
||||||
|
parts = line.split()
|
||||||
|
nx = float(parts[1])
|
||||||
|
ny = float(parts[2])
|
||||||
|
nz = float(parts[3])
|
||||||
|
# 同步反转法线的 Y 轴
|
||||||
|
new_line = f"vn {nx} {nz} {-ny}\n"
|
||||||
|
f_out.write(new_line)
|
||||||
|
else:
|
||||||
|
# 其他行直接写入
|
||||||
|
f_out.write(line)
|
||||||
|
|
||||||
|
return obj_dst_file
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"修改obj坐标时发生错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def downsample_texture(self, src_dir: str, dst_dir: str):
|
||||||
|
"""复制并重命名纹理文件,对大于100MB的文件进行多次下采样,直到文件小于100MB
|
||||||
|
Args:
|
||||||
|
src_dir: 源纹理目录
|
||||||
|
dst_dir: 目标纹理目录
|
||||||
|
"""
|
||||||
|
for file in os.listdir(src_dir):
|
||||||
|
if file.lower().endswith(('.png')):
|
||||||
|
src_path = os.path.join(src_dir, file)
|
||||||
|
dst_path = os.path.join(dst_dir, file)
|
||||||
|
|
||||||
|
# 检查文件大小(以字节为单位)
|
||||||
|
file_size = os.path.getsize(src_path)
|
||||||
|
if file_size <= 100 * 1024 * 1024: # 如果文件小于等于100MB,直接复制
|
||||||
|
shutil.copy2(src_path, dst_path)
|
||||||
|
else:
|
||||||
|
# 文件大于100MB,进行下采样
|
||||||
|
img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED)
|
||||||
|
if_first_ds = True
|
||||||
|
while file_size > 100 * 1024 * 1024: # 大于100MB
|
||||||
|
self.logger.info(f"纹理文件 {file} 大于100MB,进行下采样")
|
||||||
|
|
||||||
|
if if_first_ds:
|
||||||
|
# 计算新的尺寸(长宽各变为1/4)
|
||||||
|
new_size = (img.shape[1] // 4,
|
||||||
|
img.shape[0] // 4) # 逐步减小尺寸
|
||||||
|
# 使用双三次插值进行下采样
|
||||||
|
resized_img = cv2.resize(
|
||||||
|
img, new_size, interpolation=cv2.INTER_CUBIC)
|
||||||
|
if_first_ds = False
|
||||||
|
else:
|
||||||
|
# 计算新的尺寸(长宽各变为1/2)
|
||||||
|
new_size = (img.shape[1] // 2,
|
||||||
|
img.shape[0] // 2) # 逐步减小尺寸
|
||||||
|
# 使用双三次插值进行下采样
|
||||||
|
resized_img = cv2.resize(
|
||||||
|
img, new_size, interpolation=cv2.INTER_CUBIC)
|
||||||
|
|
||||||
|
# 更新文件路径为下采样后的路径
|
||||||
|
cv2.imwrite(dst_path, resized_img, [
|
||||||
|
cv2.IMWRITE_PNG_COMPRESSION, 9])
|
||||||
|
|
||||||
|
# 更新文件大小和图像
|
||||||
|
file_size = os.path.getsize(dst_path)
|
||||||
|
img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED)
|
||||||
|
self.logger.info(
|
||||||
|
f"下采样后文件大小: {file_size / (1024 * 1024):.2f} MB")
|
||||||
|
|
||||||
|
def get_min_z_from_obj(self, file_path):
|
||||||
|
min_z = float('inf') # 初始值设为无穷大
|
||||||
|
with open(file_path, 'r') as obj_file:
|
||||||
|
for line in obj_file:
|
||||||
|
# 检查每一行是否是顶点定义(以 'v ' 开头)
|
||||||
|
if line.startswith('v '):
|
||||||
|
# 获取顶点坐标
|
||||||
|
parts = line.split()
|
||||||
|
# 将z值转换为浮动数字
|
||||||
|
z = float(parts[3])
|
||||||
|
# 更新最小z值
|
||||||
|
if z < min_z:
|
||||||
|
min_z = z
|
||||||
|
return min_z
|
||||||
|
|
||||||
|
def modify_z_in_obj(self, texturing_dir, min_25d_z):
|
||||||
|
obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj')
|
||||||
|
output_file = os.path.join(
|
||||||
|
texturing_dir, 'odm_textured_model_modified.obj')
|
||||||
|
with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out:
|
||||||
|
for line in f_in:
|
||||||
|
if line.startswith('v '): # 顶点坐标行
|
||||||
|
parts = line.strip().split()
|
||||||
|
x = float(parts[1])
|
||||||
|
y = float(parts[2])
|
||||||
|
z = float(parts[3])
|
||||||
|
|
||||||
|
if z < min_25d_z:
|
||||||
|
z = min_25d_z
|
||||||
|
|
||||||
|
f_out.write(f"v {x} {y} {z}\n")
|
||||||
|
else:
|
||||||
|
f_out.write(line)
|
285
post_pro/merge_tif.py
Normal file
285
post_pro/merge_tif.py
Normal file
@ -0,0 +1,285 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Dict
|
||||||
|
import rasterio
|
||||||
|
from rasterio.mask import mask
|
||||||
|
from rasterio.transform import Affine, rowcol
|
||||||
|
import fiona
|
||||||
|
from edt import edt
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class MergeTif:
|
||||||
|
def __init__(self, output_dir: str):
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.logger = logging.getLogger('UAV_Preprocess.MergeTif')
|
||||||
|
|
||||||
|
def merge_orthophoto(self, grid_lt):
|
||||||
|
"""合并网格的正射影像"""
|
||||||
|
try:
|
||||||
|
all_orthos_and_ortho_cuts = []
|
||||||
|
for grid_id in grid_lt:
|
||||||
|
grid_ortho_dir = os.path.join(
|
||||||
|
self.output_dir,
|
||||||
|
f"grid_{grid_id[0]}_{grid_id[1]}",
|
||||||
|
"project",
|
||||||
|
"odm_orthophoto",
|
||||||
|
)
|
||||||
|
tif_path = os.path.join(grid_ortho_dir, "odm_orthophoto.tif")
|
||||||
|
tif_mask = os.path.join(grid_ortho_dir, "cutline.gpkg")
|
||||||
|
output_cut_tif = os.path.join(
|
||||||
|
grid_ortho_dir, "odm_orthophoto_cut.tif")
|
||||||
|
output_feathered_tif = os.path.join(
|
||||||
|
grid_ortho_dir, "odm_orthophoto_feathered.tif")
|
||||||
|
|
||||||
|
self.compute_mask_raster(
|
||||||
|
tif_path, tif_mask, output_cut_tif, blend_distance=20)
|
||||||
|
self.feather_raster(
|
||||||
|
tif_path, output_feathered_tif, blend_distance=20)
|
||||||
|
all_orthos_and_ortho_cuts.append(
|
||||||
|
[output_feathered_tif, output_cut_tif])
|
||||||
|
|
||||||
|
orthophoto_vars = {
|
||||||
|
'TILED': 'NO',
|
||||||
|
'COMPRESS': False,
|
||||||
|
'PREDICTOR': '1',
|
||||||
|
'BIGTIFF': 'IF_SAFER',
|
||||||
|
'BLOCKXSIZE': 512,
|
||||||
|
'BLOCKYSIZE': 512,
|
||||||
|
'NUM_THREADS': 15
|
||||||
|
}
|
||||||
|
self.merge(all_orthos_and_ortho_cuts, os.path.join(
|
||||||
|
self.output_dir, "orthophoto.tif"), orthophoto_vars)
|
||||||
|
self.logger.info("所有产品合并完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"产品合并过程中发生错误: {str(e)}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def compute_mask_raster(self, input_raster, vector_mask, output_raster, blend_distance=20, only_max_coords_feature=False):
|
||||||
|
if not os.path.exists(input_raster):
|
||||||
|
print("Cannot mask raster, %s does not exist" % input_raster)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.exists(vector_mask):
|
||||||
|
print("Cannot mask raster, %s does not exist" % vector_mask)
|
||||||
|
return
|
||||||
|
|
||||||
|
print("Computing mask raster: %s" % output_raster)
|
||||||
|
|
||||||
|
with rasterio.open(input_raster, 'r') as rast:
|
||||||
|
with fiona.open(vector_mask) as src:
|
||||||
|
burn_features = src
|
||||||
|
|
||||||
|
if only_max_coords_feature:
|
||||||
|
max_coords_count = 0
|
||||||
|
max_coords_feature = None
|
||||||
|
for feature in src:
|
||||||
|
if feature is not None:
|
||||||
|
# No complex shapes
|
||||||
|
if len(feature['geometry']['coordinates'][0]) > max_coords_count:
|
||||||
|
max_coords_count = len(
|
||||||
|
feature['geometry']['coordinates'][0])
|
||||||
|
max_coords_feature = feature
|
||||||
|
if max_coords_feature is not None:
|
||||||
|
burn_features = [max_coords_feature]
|
||||||
|
|
||||||
|
shapes = [feature["geometry"] for feature in burn_features]
|
||||||
|
out_image, out_transform = mask(rast, shapes, nodata=0)
|
||||||
|
|
||||||
|
if blend_distance > 0:
|
||||||
|
if out_image.shape[0] >= 4:
|
||||||
|
# alpha_band = rast.dataset_mask()
|
||||||
|
alpha_band = out_image[-1]
|
||||||
|
dist_t = edt(alpha_band, black_border=True, parallel=0)
|
||||||
|
dist_t[dist_t <= blend_distance] /= blend_distance
|
||||||
|
dist_t[dist_t > blend_distance] = 1
|
||||||
|
np.multiply(alpha_band, dist_t,
|
||||||
|
out=alpha_band, casting="unsafe")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"%s does not have an alpha band, cannot blend cutline!" % input_raster)
|
||||||
|
|
||||||
|
with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst:
|
||||||
|
dst.colorinterp = rast.colorinterp
|
||||||
|
dst.write(out_image)
|
||||||
|
|
||||||
|
return output_raster
|
||||||
|
|
||||||
|
def feather_raster(self, input_raster, output_raster, blend_distance=20):
|
||||||
|
if not os.path.exists(input_raster):
|
||||||
|
print("Cannot feather raster, %s does not exist" % input_raster)
|
||||||
|
return
|
||||||
|
|
||||||
|
print("Computing feather raster: %s" % output_raster)
|
||||||
|
|
||||||
|
with rasterio.open(input_raster, 'r') as rast:
|
||||||
|
out_image = rast.read()
|
||||||
|
if blend_distance > 0:
|
||||||
|
if out_image.shape[0] >= 4:
|
||||||
|
alpha_band = out_image[-1]
|
||||||
|
dist_t = edt(alpha_band, black_border=True, parallel=0)
|
||||||
|
dist_t[dist_t <= blend_distance] /= blend_distance
|
||||||
|
dist_t[dist_t > blend_distance] = 1
|
||||||
|
np.multiply(alpha_band, dist_t,
|
||||||
|
out=alpha_band, casting="unsafe")
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
"%s does not have an alpha band, cannot feather raster!" % input_raster)
|
||||||
|
|
||||||
|
with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst:
|
||||||
|
dst.colorinterp = rast.colorinterp
|
||||||
|
dst.write(out_image)
|
||||||
|
|
||||||
|
return output_raster
|
||||||
|
|
||||||
|
def merge(self, input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}):
|
||||||
|
"""
|
||||||
|
Based on https://github.com/mapbox/rio-merge-rgba/
|
||||||
|
Merge orthophotos around cutlines using a blend buffer.
|
||||||
|
"""
|
||||||
|
inputs = []
|
||||||
|
bounds = None
|
||||||
|
precision = 7
|
||||||
|
|
||||||
|
for o, c in input_ortho_and_ortho_cuts:
|
||||||
|
inputs.append((o, c))
|
||||||
|
|
||||||
|
with rasterio.open(inputs[0][0]) as first:
|
||||||
|
res = first.res
|
||||||
|
dtype = first.dtypes[0]
|
||||||
|
profile = first.profile
|
||||||
|
num_bands = first.meta['count'] - 1 # minus alpha
|
||||||
|
colorinterp = first.colorinterp
|
||||||
|
|
||||||
|
print("%s valid orthophoto rasters to merge" % len(inputs))
|
||||||
|
sources = [(rasterio.open(o), rasterio.open(c)) for o, c in inputs]
|
||||||
|
|
||||||
|
# scan input files.
|
||||||
|
# while we're at it, validate assumptions about inputs
|
||||||
|
xs = []
|
||||||
|
ys = []
|
||||||
|
for src, _ in sources:
|
||||||
|
left, bottom, right, top = src.bounds
|
||||||
|
xs.extend([left, right])
|
||||||
|
ys.extend([bottom, top])
|
||||||
|
if src.profile["count"] < 2:
|
||||||
|
raise ValueError("Inputs must be at least 2-band rasters")
|
||||||
|
dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys)
|
||||||
|
print("Output bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n))
|
||||||
|
|
||||||
|
output_transform = Affine.translation(dst_w, dst_n)
|
||||||
|
output_transform *= Affine.scale(res[0], -res[1])
|
||||||
|
|
||||||
|
# Compute output array shape. We guarantee it will cover the output
|
||||||
|
# bounds completely.
|
||||||
|
output_width = int(math.ceil((dst_e - dst_w) / res[0]))
|
||||||
|
output_height = int(math.ceil((dst_n - dst_s) / res[1]))
|
||||||
|
|
||||||
|
# Adjust bounds to fit.
|
||||||
|
dst_e, dst_s = output_transform * (output_width, output_height)
|
||||||
|
print("Output width: %d, height: %d" %
|
||||||
|
(output_width, output_height))
|
||||||
|
print("Adjusted bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n))
|
||||||
|
|
||||||
|
profile["transform"] = output_transform
|
||||||
|
profile["height"] = output_height
|
||||||
|
profile["width"] = output_width
|
||||||
|
profile["tiled"] = orthophoto_vars.get('TILED', 'YES') == 'YES'
|
||||||
|
profile["blockxsize"] = orthophoto_vars.get('BLOCKXSIZE', 512)
|
||||||
|
profile["blockysize"] = orthophoto_vars.get('BLOCKYSIZE', 512)
|
||||||
|
profile["compress"] = orthophoto_vars.get('COMPRESS', 'LZW')
|
||||||
|
profile["predictor"] = orthophoto_vars.get('PREDICTOR', '2')
|
||||||
|
profile["bigtiff"] = orthophoto_vars.get('BIGTIFF', 'IF_SAFER')
|
||||||
|
profile.update()
|
||||||
|
|
||||||
|
# create destination file
|
||||||
|
with rasterio.open(output_orthophoto, "w", **profile) as dstrast:
|
||||||
|
dstrast.colorinterp = colorinterp
|
||||||
|
for idx, dst_window in dstrast.block_windows():
|
||||||
|
left, bottom, right, top = dstrast.window_bounds(dst_window)
|
||||||
|
|
||||||
|
blocksize = dst_window.width
|
||||||
|
dst_rows, dst_cols = (dst_window.height, dst_window.width)
|
||||||
|
|
||||||
|
# initialize array destined for the block
|
||||||
|
dst_count = first.count
|
||||||
|
dst_shape = (dst_count, dst_rows, dst_cols)
|
||||||
|
|
||||||
|
dstarr = np.zeros(dst_shape, dtype=dtype)
|
||||||
|
|
||||||
|
# First pass, write all rasters naively without blending
|
||||||
|
for src, _ in sources:
|
||||||
|
src_window = tuple(zip(rowcol(
|
||||||
|
src.transform, left, top, op=round, precision=precision
|
||||||
|
), rowcol(
|
||||||
|
src.transform, right, bottom, op=round, precision=precision
|
||||||
|
)))
|
||||||
|
|
||||||
|
temp = np.zeros(dst_shape, dtype=dtype)
|
||||||
|
temp = src.read(
|
||||||
|
out=temp, window=src_window, boundless=True, masked=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# pixels without data yet are available to write
|
||||||
|
write_region = np.logical_and(
|
||||||
|
(dstarr[-1] == 0), (temp[-1] != 0) # 0 is nodata
|
||||||
|
)
|
||||||
|
np.copyto(dstarr, temp, where=write_region)
|
||||||
|
|
||||||
|
# check if dest has any nodata pixels available
|
||||||
|
if np.count_nonzero(dstarr[-1]) == blocksize:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Second pass, write all feathered rasters
|
||||||
|
# blending the edges
|
||||||
|
for src, _ in sources:
|
||||||
|
src_window = tuple(zip(rowcol(
|
||||||
|
src.transform, left, top, op=round, precision=precision
|
||||||
|
), rowcol(
|
||||||
|
src.transform, right, bottom, op=round, precision=precision
|
||||||
|
)))
|
||||||
|
|
||||||
|
temp = np.zeros(dst_shape, dtype=dtype)
|
||||||
|
temp = src.read(
|
||||||
|
out=temp, window=src_window, boundless=True, masked=False
|
||||||
|
)
|
||||||
|
|
||||||
|
where = temp[-1] != 0
|
||||||
|
for b in range(0, num_bands):
|
||||||
|
blended = temp[-1] / 255.0 * temp[b] + \
|
||||||
|
(1 - temp[-1] / 255.0) * dstarr[b]
|
||||||
|
np.copyto(dstarr[b], blended,
|
||||||
|
casting='unsafe', where=where)
|
||||||
|
dstarr[-1][where] = 255.0
|
||||||
|
|
||||||
|
# check if dest has any nodata pixels available
|
||||||
|
if np.count_nonzero(dstarr[-1]) == blocksize:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Third pass, write cut rasters
|
||||||
|
# blending the cutlines
|
||||||
|
for _, cut in sources:
|
||||||
|
src_window = tuple(zip(rowcol(
|
||||||
|
cut.transform, left, top, op=round, precision=precision
|
||||||
|
), rowcol(
|
||||||
|
cut.transform, right, bottom, op=round, precision=precision
|
||||||
|
)))
|
||||||
|
|
||||||
|
temp = np.zeros(dst_shape, dtype=dtype)
|
||||||
|
temp = cut.read(
|
||||||
|
out=temp, window=src_window, boundless=True, masked=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# For each band, average alpha values between
|
||||||
|
# destination raster and cut raster
|
||||||
|
for b in range(0, num_bands):
|
||||||
|
blended = temp[-1] / 255.0 * temp[b] + \
|
||||||
|
(1 - temp[-1] / 255.0) * dstarr[b]
|
||||||
|
np.copyto(dstarr[b], blended,
|
||||||
|
casting='unsafe', where=temp[-1] != 0)
|
||||||
|
|
||||||
|
dstrast.write(dstarr, window=dst_window)
|
||||||
|
|
||||||
|
return output_orthophoto
|
15
requirements.txt
Normal file
15
requirements.txt
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
numpy==1.26.4
|
||||||
|
pandas==2.2.3
|
||||||
|
scikit-learn==1.6.1
|
||||||
|
matplotlib==3.10.0
|
||||||
|
piexif==1.1.3
|
||||||
|
geopy==2.4.1
|
||||||
|
psutil==6.1.1
|
||||||
|
docker==7.1.0
|
||||||
|
tqdm==4.66.5
|
||||||
|
pyproj==3.7.0
|
||||||
|
rasterio==1.4.3
|
||||||
|
edt==3.0.0
|
||||||
|
opencv-python==4.11.0.86
|
||||||
|
fiona==1.9.5
|
||||||
|
pyinstaller==6.14.1
|
33
trans_orthophoto.py
Normal file
33
trans_orthophoto.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
from osgeo import gdal
|
||||||
|
|
||||||
|
|
||||||
|
class TransOrthophoto:
|
||||||
|
def __init__(self):
|
||||||
|
self.logger = logging.getLogger('UAV_Preprocess.TransOrthophoto')
|
||||||
|
logging.basicConfig(level=logging.INFO,
|
||||||
|
format='%(asctime)s %(levelname)s %(message)s')
|
||||||
|
|
||||||
|
def trans_to_epsg4326(self, input_img, output_img):
|
||||||
|
# 目标坐标系
|
||||||
|
dst_srs = 'EPSG:4326'
|
||||||
|
# 使用 gdal.Warp 进行重投影
|
||||||
|
gdal.Warp(
|
||||||
|
destNameOrDestDS=output_img,
|
||||||
|
srcDSOrSrcDSTab=input_img,
|
||||||
|
dstSRS=dst_srs,
|
||||||
|
format='GTiff',
|
||||||
|
resampleAlg='near' # 最近邻插值
|
||||||
|
)
|
||||||
|
self.logger.info(f"文件已成功重投影: {output_img}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="正射影像投影转换为EPSG:4326 (GDAL实现)")
|
||||||
|
parser.add_argument('--input_img', required=True, help='输入影像路径')
|
||||||
|
parser.add_argument('--output_img', required=True, help='输出影像路径')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
trans = TransOrthophoto()
|
||||||
|
trans.trans_to_epsg4326(args.input_img, args.output_img)
|
81
utils/directory_manager.py
Normal file
81
utils/directory_manager.py
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
|
||||||
|
class DirectoryManager:
|
||||||
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
初始化目录管理器
|
||||||
|
Args:
|
||||||
|
config: 配置对象,包含输入和输出目录等信息
|
||||||
|
"""
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def clean_output_dir(self):
|
||||||
|
"""清理输出目录"""
|
||||||
|
try:
|
||||||
|
if os.path.exists(self.config.output_dir):
|
||||||
|
shutil.rmtree(self.config.output_dir)
|
||||||
|
print(f"已清理输出目录: {self.config.output_dir}")
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
print(f"清理输出目录时发生错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def setup_output_dirs(self):
|
||||||
|
"""创建必要的输出目录结构"""
|
||||||
|
try:
|
||||||
|
# 创建主输出目录
|
||||||
|
os.makedirs(self.config.output_dir)
|
||||||
|
|
||||||
|
# 创建过滤图像保存目录
|
||||||
|
os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs'))
|
||||||
|
|
||||||
|
# 创建日志目录
|
||||||
|
os.makedirs(os.path.join(self.config.output_dir, 'logs'))
|
||||||
|
|
||||||
|
print(f"已创建输出目录结构: {self.config.output_dir}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"创建输出目录时发生错误: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _get_directory_size(self, path):
|
||||||
|
"""获取目录的总大小(字节)"""
|
||||||
|
total_size = 0
|
||||||
|
for dirpath, dirnames, filenames in os.walk(path):
|
||||||
|
for filename in filenames:
|
||||||
|
file_path = os.path.join(dirpath, filename)
|
||||||
|
try:
|
||||||
|
total_size += os.path.getsize(file_path)
|
||||||
|
except (OSError, FileNotFoundError):
|
||||||
|
continue
|
||||||
|
return total_size
|
||||||
|
|
||||||
|
def check_disk_space(self):
|
||||||
|
"""检查磁盘空间是否足够"""
|
||||||
|
# 获取输入目录大小
|
||||||
|
input_size = self._get_directory_size(self.config.image_dir)
|
||||||
|
|
||||||
|
# 获取输出目录所在磁盘的剩余空间
|
||||||
|
output_drive = os.path.splitdrive(
|
||||||
|
os.path.abspath(self.config.output_dir))[0]
|
||||||
|
if not output_drive: # 处理Linux/Unix路径
|
||||||
|
output_drive = '/home'
|
||||||
|
|
||||||
|
disk_usage = psutil.disk_usage(output_drive)
|
||||||
|
free_space = disk_usage.free
|
||||||
|
|
||||||
|
# 计算所需空间(输入大小的10倍)
|
||||||
|
required_space = input_size * 8
|
||||||
|
|
||||||
|
if free_space < required_space:
|
||||||
|
error_msg = (
|
||||||
|
f"磁盘空间不足!\n"
|
||||||
|
f"输入目录大小: {input_size / (1024**3):.2f} GB\n"
|
||||||
|
f"所需空间: {required_space / (1024**3):.2f} GB\n"
|
||||||
|
f"可用空间: {free_space / (1024**3):.2f} GB\n"
|
||||||
|
f"在驱动器 {output_drive}"
|
||||||
|
)
|
||||||
|
raise RuntimeError(error_msg)
|
65
utils/gps_extractor.py
Normal file
65
utils/gps_extractor.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import os
|
||||||
|
from PIL import Image
|
||||||
|
import piexif
|
||||||
|
import logging
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
|
class GPSExtractor:
|
||||||
|
"""从图像文件提取GPS坐标"""
|
||||||
|
|
||||||
|
def __init__(self, image_dir):
|
||||||
|
self.image_dir = image_dir
|
||||||
|
self.logger = logging.getLogger('UAV_Preprocess.GPSExtractor')
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _dms_to_decimal(dms):
|
||||||
|
"""将DMS格式转换为十进制度"""
|
||||||
|
return dms[0][0] / dms[0][1] + (dms[1][0] / dms[1][1]) / 60 + (dms[2][0] / dms[2][1]) / 3600
|
||||||
|
|
||||||
|
def get_gps(self, image_path):
|
||||||
|
"""提取单张图片的GPS坐标"""
|
||||||
|
try:
|
||||||
|
image = Image.open(image_path)
|
||||||
|
exif_data = piexif.load(image.info['exif'])
|
||||||
|
|
||||||
|
# 提取GPS信息
|
||||||
|
gps_info = exif_data.get("GPS", {})
|
||||||
|
lat = lon = None
|
||||||
|
if gps_info:
|
||||||
|
lat = self._dms_to_decimal(gps_info.get(2, []))
|
||||||
|
lon = self._dms_to_decimal(gps_info.get(4, []))
|
||||||
|
self.logger.debug(
|
||||||
|
f"成功提取图片GPS坐标: {image_path} - 纬度: {lat}, 经度: {lon}")
|
||||||
|
|
||||||
|
if not gps_info:
|
||||||
|
self.logger.warning(f"图片无GPS信息: {image_path}")
|
||||||
|
|
||||||
|
return lat, lon
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"提取图片信息时发生错误: {image_path} - {str(e)}")
|
||||||
|
return None, None, None
|
||||||
|
|
||||||
|
def extract_all_gps(self):
|
||||||
|
"""提取所有图片的GPS坐标"""
|
||||||
|
self.logger.info(f"开始从目录提取GPS坐标: {self.image_dir}")
|
||||||
|
gps_data = []
|
||||||
|
total_images = 0
|
||||||
|
successful_extractions = 0
|
||||||
|
|
||||||
|
for image_file in os.listdir(self.image_dir):
|
||||||
|
total_images += 1
|
||||||
|
image_path = os.path.join(self.image_dir, image_file)
|
||||||
|
lat, lon = self.get_gps(image_path)
|
||||||
|
if lat and lon: # 仍然以GPS信息作为主要判断依据
|
||||||
|
successful_extractions += 1
|
||||||
|
gps_data.append({
|
||||||
|
'file': image_file,
|
||||||
|
'lat': lat,
|
||||||
|
'lon': lon,
|
||||||
|
})
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
f"GPS坐标提取完成 - 总图片数: {total_images}, 成功提取: {successful_extractions}, 失败: {total_images - successful_extractions}")
|
||||||
|
return pd.DataFrame(gps_data)
|
266
utils/grid_divider.py
Normal file
266
utils/grid_divider.py
Normal file
@ -0,0 +1,266 @@
|
|||||||
|
import logging
|
||||||
|
from geopy.distance import geodesic
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class GridDivider:
|
||||||
|
"""划分网格,并将图片分配到对应网格"""
|
||||||
|
|
||||||
|
def __init__(self, overlap=0.1, grid_size=500, output_dir=None):
|
||||||
|
self.overlap = overlap
|
||||||
|
self.grid_size = grid_size
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.logger = logging.getLogger('UAV_Preprocess.GridDivider')
|
||||||
|
self.logger.info(f"初始化网格划分器,重叠率: {overlap}")
|
||||||
|
self.num_grids_width = 0 # 添加网格数量属性
|
||||||
|
self.num_grids_height = 0
|
||||||
|
|
||||||
|
def adjust_grid_size_and_overlap(self, points_df):
|
||||||
|
"""动态调整网格重叠率"""
|
||||||
|
grids = self.adjust_grid_size(points_df)
|
||||||
|
self.logger.info(f"开始动态调整网格重叠率,初始重叠率: {self.overlap}")
|
||||||
|
while True:
|
||||||
|
# 使用调整好的网格大小划分网格
|
||||||
|
grids = self.divide_grids(points_df)
|
||||||
|
grid_points, multiple_grid_points = self.assign_to_grids(
|
||||||
|
points_df, grids)
|
||||||
|
|
||||||
|
if len(grids) == 1:
|
||||||
|
self.logger.info(f"网格数量为1,跳过重叠率调整")
|
||||||
|
break
|
||||||
|
elif multiple_grid_points < 0.1*len(points_df):
|
||||||
|
self.overlap += 0.02
|
||||||
|
self.logger.info(f"重叠率增加到: {self.overlap}")
|
||||||
|
else:
|
||||||
|
self.logger.info(
|
||||||
|
f"找到合适的重叠率: {self.overlap}, 有{multiple_grid_points}个点被分配到多个网格")
|
||||||
|
break
|
||||||
|
return grids, grid_points
|
||||||
|
|
||||||
|
def adjust_grid_size(self, points_df):
|
||||||
|
"""动态调整网格大小
|
||||||
|
|
||||||
|
Args:
|
||||||
|
points_df: 包含GPS点的DataFrame
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: grids
|
||||||
|
"""
|
||||||
|
self.logger.info(f"开始动态调整网格大小,初始大小: {self.grid_size}米")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# 使用当前grid_size划分网格
|
||||||
|
grids = self.divide_grids(points_df)
|
||||||
|
grid_points, multiple_grid_points = self.assign_to_grids(
|
||||||
|
points_df, grids)
|
||||||
|
|
||||||
|
# 检查每个网格中的点数
|
||||||
|
max_points = 0
|
||||||
|
for grid_id, points in grid_points.items():
|
||||||
|
max_points = max(max_points, len(points))
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
f"当前网格大小: {self.grid_size}米, 单个网格最大点数: {max_points}")
|
||||||
|
|
||||||
|
# 如果最大点数超过2000,减小网格大小
|
||||||
|
if max_points > 2000:
|
||||||
|
self.grid_size -= 100
|
||||||
|
self.logger.info(f"点数超过2000,减小网格大小至: {self.grid_size}米")
|
||||||
|
if self.grid_size < 500: # 设置一个最小网格大小限制
|
||||||
|
self.logger.warning("网格大小已达到最小值500米,停止调整")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
self.logger.info(f"找到合适的网格大小: {self.grid_size}米")
|
||||||
|
break
|
||||||
|
return grids
|
||||||
|
|
||||||
|
def divide_grids(self, points_df):
|
||||||
|
"""计算边界框并划分网格
|
||||||
|
Returns:
|
||||||
|
tuple: grids 网格边界列表
|
||||||
|
"""
|
||||||
|
self.logger.info("开始划分网格")
|
||||||
|
|
||||||
|
min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max()
|
||||||
|
min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max()
|
||||||
|
|
||||||
|
# 计算区域的实际距离(米)
|
||||||
|
width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters
|
||||||
|
height = geodesic((min_lat, min_lon), (max_lat, min_lon)).meters
|
||||||
|
|
||||||
|
self.logger.info(f"区域宽度: {width:.2f}米, 高度: {height:.2f}米")
|
||||||
|
|
||||||
|
# 精细调整网格的长宽,避免出现2*grid_size-1的情况的影响
|
||||||
|
grid_size_lt = [self.grid_size - 200, self.grid_size - 100,
|
||||||
|
self.grid_size, self.grid_size + 100, self.grid_size + 200]
|
||||||
|
|
||||||
|
width_modulus_lt = [width % grid_size for grid_size in grid_size_lt]
|
||||||
|
grid_width = grid_size_lt[width_modulus_lt.index(
|
||||||
|
min(width_modulus_lt))]
|
||||||
|
height_modulus_lt = [height % grid_size for grid_size in grid_size_lt]
|
||||||
|
grid_height = grid_size_lt[height_modulus_lt.index(
|
||||||
|
min(height_modulus_lt))]
|
||||||
|
self.logger.info(f"网格宽度: {grid_width:.2f}米, 网格高度: {grid_height:.2f}米")
|
||||||
|
|
||||||
|
# 计算需要划分的网格数量
|
||||||
|
self.num_grids_width = max(int(width / grid_width), 1)
|
||||||
|
self.num_grids_height = max(int(height / grid_height), 1)
|
||||||
|
|
||||||
|
# 计算每个网格对应的经纬度步长
|
||||||
|
lat_step = (max_lat - min_lat) / self.num_grids_height
|
||||||
|
lon_step = (max_lon - min_lon) / self.num_grids_width
|
||||||
|
|
||||||
|
grids = []
|
||||||
|
|
||||||
|
# 先创建所有网格
|
||||||
|
for i in range(self.num_grids_height):
|
||||||
|
for j in range(self.num_grids_width):
|
||||||
|
grid_min_lat = min_lat + i * lat_step - self.overlap * lat_step
|
||||||
|
grid_max_lat = min_lat + \
|
||||||
|
(i + 1) * lat_step + self.overlap * lat_step
|
||||||
|
grid_min_lon = min_lon + j * lon_step - self.overlap * lon_step
|
||||||
|
grid_max_lon = min_lon + \
|
||||||
|
(j + 1) * lon_step + self.overlap * lon_step
|
||||||
|
|
||||||
|
grid_bounds = (grid_min_lat, grid_max_lat,
|
||||||
|
grid_min_lon, grid_max_lon)
|
||||||
|
grids.append(grid_bounds)
|
||||||
|
|
||||||
|
self.logger.debug(
|
||||||
|
f"网格[{i},{j}]: 纬度[{grid_min_lat:.6f}, {grid_max_lat:.6f}], "
|
||||||
|
f"经度[{grid_min_lon:.6f}, {grid_max_lon:.6f}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
f"成功划分为 {len(grids)} 个网格 ({self.num_grids_width}x{self.num_grids_height})")
|
||||||
|
|
||||||
|
return grids
|
||||||
|
|
||||||
|
def assign_to_grids(self, points_df, grids):
|
||||||
|
"""将点分配到对应网格"""
|
||||||
|
self.logger.info(f"开始将 {len(points_df)} 个点分配到网格中")
|
||||||
|
|
||||||
|
grid_points = {} # 使用字典存储每个网格的点
|
||||||
|
points_assigned = 0
|
||||||
|
multiple_grid_points = 0
|
||||||
|
|
||||||
|
for i in range(self.num_grids_height):
|
||||||
|
for j in range(self.num_grids_width):
|
||||||
|
grid_points[(i, j)] = [] # 使用(i,j)元组
|
||||||
|
|
||||||
|
for _, point in points_df.iterrows():
|
||||||
|
point_assigned = False
|
||||||
|
for i in range(self.num_grids_height):
|
||||||
|
for j in range(self.num_grids_width):
|
||||||
|
grid_idx = i * self.num_grids_width + j
|
||||||
|
min_lat, max_lat, min_lon, max_lon = grids[grid_idx]
|
||||||
|
|
||||||
|
if min_lat <= point['lat'] <= max_lat and min_lon <= point['lon'] <= max_lon:
|
||||||
|
grid_points[(i, j)].append(point.to_dict())
|
||||||
|
if point_assigned:
|
||||||
|
multiple_grid_points += 1
|
||||||
|
else:
|
||||||
|
points_assigned += 1
|
||||||
|
point_assigned = True
|
||||||
|
|
||||||
|
# 记录每个网格的点数
|
||||||
|
for grid_id, points in grid_points.items():
|
||||||
|
self.logger.info(f"网格 {grid_id} 包含 {len(points)} 个点")
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
f"点分配完成: 总点数 {len(points_df)}, "
|
||||||
|
f"成功分配 {points_assigned} 个点, "
|
||||||
|
f"{multiple_grid_points} 个点被分配到多个网格"
|
||||||
|
)
|
||||||
|
|
||||||
|
return grid_points, multiple_grid_points
|
||||||
|
|
||||||
|
def visualize_grids(self, points_df, grids):
|
||||||
|
"""可视化网格划分和GPS点的分布"""
|
||||||
|
self.logger.info("开始可视化网格划分")
|
||||||
|
|
||||||
|
plt.figure(figsize=(12, 8))
|
||||||
|
|
||||||
|
# 绘制GPS点
|
||||||
|
plt.scatter(points_df['lon'], points_df['lat'],
|
||||||
|
c='blue', s=10, alpha=0.6, label='GPS points')
|
||||||
|
|
||||||
|
# 绘制网格
|
||||||
|
for i in range(self.num_grids_height):
|
||||||
|
for j in range(self.num_grids_width):
|
||||||
|
grid_idx = i * self.num_grids_width + j
|
||||||
|
min_lat, max_lat, min_lon, max_lon = grids[grid_idx]
|
||||||
|
|
||||||
|
# 计算网格的实际长度和宽度(米)
|
||||||
|
width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters
|
||||||
|
height = geodesic((min_lat, min_lon),
|
||||||
|
(max_lat, min_lon)).meters
|
||||||
|
|
||||||
|
plt.plot([min_lon, max_lon, max_lon, min_lon, min_lon],
|
||||||
|
[min_lat, min_lat, max_lat, max_lat, min_lat],
|
||||||
|
'r-', alpha=0.5)
|
||||||
|
# 在网格中心添加网格编号和尺寸信息
|
||||||
|
center_lon = (min_lon + max_lon) / 2
|
||||||
|
center_lat = (min_lat + max_lat) / 2
|
||||||
|
plt.text(center_lon, center_lat,
|
||||||
|
f"({i},{j})\n{width:.0f}m×{height:.0f}m", # 显示(i,j)和尺寸
|
||||||
|
horizontalalignment='center',
|
||||||
|
verticalalignment='center',
|
||||||
|
fontsize=8)
|
||||||
|
|
||||||
|
plt.title('Grid Division and GPS Point Distribution')
|
||||||
|
plt.xlabel('Longitude')
|
||||||
|
plt.ylabel('Latitude')
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True)
|
||||||
|
|
||||||
|
# 如果提供了输出目录,保存图像
|
||||||
|
if self.output_dir:
|
||||||
|
save_path = os.path.join(
|
||||||
|
self.output_dir, 'filter_imgs', 'grid_division.png')
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
self.logger.info(f"网格划分可视化图已保存至: {save_path}")
|
||||||
|
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def get_grid_center(self, grid_bounds) -> tuple:
|
||||||
|
"""计算网格中心点的经纬度
|
||||||
|
Args:
|
||||||
|
grid_bounds: (min_lat, max_lat, min_lon, max_lon)
|
||||||
|
Returns:
|
||||||
|
(center_lat, center_lon)
|
||||||
|
"""
|
||||||
|
min_lat, max_lat, min_lon, max_lon = grid_bounds
|
||||||
|
return ((min_lat + max_lat) / 2, (min_lon + max_lon) / 2)
|
||||||
|
|
||||||
|
def calculate_grid_translation(self, reference_grid: tuple, target_grid: tuple) -> tuple:
|
||||||
|
"""计算目标网格相对于参考网格的平移距离(米)
|
||||||
|
Args:
|
||||||
|
reference_grid: 参考网格的边界 (min_lat, max_lat, min_lon, max_lon)
|
||||||
|
target_grid: 目标网格的边界 (min_lat, max_lat, min_lon, max_lon)
|
||||||
|
Returns:
|
||||||
|
(x_translation, y_translation): 在米制单位下的平移量
|
||||||
|
"""
|
||||||
|
ref_center = self.get_grid_center(reference_grid)
|
||||||
|
target_center = self.get_grid_center(target_grid)
|
||||||
|
|
||||||
|
# 计算经度方向的距离(x轴)
|
||||||
|
x_distance = geodesic(
|
||||||
|
(ref_center[0], ref_center[1]),
|
||||||
|
(ref_center[0], target_center[1])
|
||||||
|
).meters
|
||||||
|
# 如果目标在参考点西边,距离为负
|
||||||
|
if target_center[1] < ref_center[1]:
|
||||||
|
x_distance = -x_distance
|
||||||
|
|
||||||
|
# 计算纬度方向的距离(y轴)
|
||||||
|
y_distance = geodesic(
|
||||||
|
(ref_center[0], ref_center[1]),
|
||||||
|
(target_center[0], ref_center[1])
|
||||||
|
).meters
|
||||||
|
# 如果目标在参考点南边,距离为负
|
||||||
|
if target_center[0] < ref_center[0]:
|
||||||
|
y_distance = -y_distance
|
||||||
|
|
||||||
|
return (x_distance, y_distance)
|
36
utils/logger.py
Normal file
36
utils/logger.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logger(output_dir):
|
||||||
|
# 创建logs目录
|
||||||
|
log_dir = os.path.join(output_dir, 'logs')
|
||||||
|
|
||||||
|
# 创建日志文件名(包含时间戳)
|
||||||
|
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
|
log_file = os.path.join(log_dir, f'preprocess_{timestamp}.log')
|
||||||
|
|
||||||
|
# 配置日志格式
|
||||||
|
formatter = logging.Formatter(
|
||||||
|
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置文件处理器
|
||||||
|
file_handler = logging.FileHandler(log_file, encoding='utf-8')
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
# 配置控制台处理器
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
# 获取根日志记录器
|
||||||
|
logger = logging.getLogger('UAV_Preprocess')
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
# 添加处理器
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
return logger
|
144
utils/odm_monitor.py
Normal file
144
utils/odm_monitor.py
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
import pandas as pd
|
||||||
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
|
class ODMProcessMonitor:
|
||||||
|
"""ODM处理监控器"""
|
||||||
|
|
||||||
|
def __init__(self, output_dir: str, mode: str = "三维模式", config=None):
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.logger = logging.getLogger('UAV_Preprocess.ODMMonitor')
|
||||||
|
self.mode = mode
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def run_odm_with_monitor(self, grid_dir: str, grid_id: tuple) -> Tuple[bool, str]:
|
||||||
|
"""运行ODM命令"""
|
||||||
|
self.logger.info(f"开始处理网格 ({grid_id[0]},{grid_id[1]})")
|
||||||
|
success = False
|
||||||
|
error_msg = ""
|
||||||
|
max_retries = 3
|
||||||
|
current_try = 0
|
||||||
|
cpu_cores = os.cpu_count()
|
||||||
|
|
||||||
|
while current_try < max_retries:
|
||||||
|
current_try += 1
|
||||||
|
self.logger.info(
|
||||||
|
f"第 {current_try} 次尝试处理网格 ({grid_id[0]},{grid_id[1]})")
|
||||||
|
|
||||||
|
# 构建 Docker 容器运行参数
|
||||||
|
grid_dir_fixed = grid_dir[0].lower() + grid_dir[1:].replace('\\', '/')
|
||||||
|
command = (
|
||||||
|
f"--max-concurrency {cpu_cores} "
|
||||||
|
f"--force-gps "
|
||||||
|
f"--use-exif "
|
||||||
|
f"--use-hybrid-bundle-adjustment "
|
||||||
|
f"--optimize-disk-space "
|
||||||
|
f"--orthophoto-cutline "
|
||||||
|
)
|
||||||
|
|
||||||
|
command += f"--feature-type {self.config.feature_type} "
|
||||||
|
command += f"--orthophoto-resolution {self.config.orthophoto_resolution} "
|
||||||
|
|
||||||
|
if self.mode == "快拼模式":
|
||||||
|
command += (
|
||||||
|
f"--fast-orthophoto "
|
||||||
|
f"--skip-3dmodel "
|
||||||
|
)
|
||||||
|
else: # 三维模式参数
|
||||||
|
command += (
|
||||||
|
f"--dsm "
|
||||||
|
f"--dtm "
|
||||||
|
)
|
||||||
|
if current_try == 1:
|
||||||
|
command += (
|
||||||
|
f"--feature-quality low "
|
||||||
|
)
|
||||||
|
|
||||||
|
command += "--rerun-all"
|
||||||
|
|
||||||
|
docker_cmd = f'python3 /code/run.py --project-path {grid_dir} project {command}'
|
||||||
|
self.logger.info(f"Docker 命令: {docker_cmd}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 执行命令
|
||||||
|
result = subprocess.run(
|
||||||
|
docker_cmd,
|
||||||
|
shell=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
encoding="utf-8"
|
||||||
|
)
|
||||||
|
if result.returncode != 0:
|
||||||
|
self.logger.error(
|
||||||
|
f"容器运行失败,退出状态码: {result.returncode}")
|
||||||
|
error_lines = result.stderr.splitlines()
|
||||||
|
self.logger.error("容器运行失败的最后 50 行错误日志:")
|
||||||
|
for line in error_lines[-50:]:
|
||||||
|
self.logger.error(line)
|
||||||
|
error_msg = "\n".join(error_lines[-50:])
|
||||||
|
time.sleep(5)
|
||||||
|
else:
|
||||||
|
logs = result.stdout.splitlines()
|
||||||
|
self.logger.info("容器运行完成,以下是最后 50 行日志:")
|
||||||
|
for line in logs[-50:]:
|
||||||
|
self.logger.info(line)
|
||||||
|
success = True
|
||||||
|
error_msg = ""
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
self.logger.error(f"执行docker命令时发生异常: {error_msg}")
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
return success, error_msg
|
||||||
|
|
||||||
|
def process_all_grids(self, grid_points: Dict[tuple, pd.DataFrame]) -> list:
|
||||||
|
"""处理所有网格
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[tuple, pd.DataFrame]: 成功处理的网格点数据字典
|
||||||
|
"""
|
||||||
|
self.logger.info("开始执行网格处理")
|
||||||
|
successful_grid_lt = []
|
||||||
|
failed_grids = []
|
||||||
|
|
||||||
|
for grid_id, points in grid_points.items():
|
||||||
|
grid_dir = os.path.join(
|
||||||
|
self.output_dir, f'grid_{grid_id[0]}_{grid_id[1]}'
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
success, error_msg = self.run_odm_with_monitor(
|
||||||
|
grid_dir=grid_dir,
|
||||||
|
grid_id=grid_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if success:
|
||||||
|
successful_grid_lt.append(grid_id)
|
||||||
|
else:
|
||||||
|
self.logger.error(
|
||||||
|
f"网格 ({grid_id[0]},{grid_id[1]})")
|
||||||
|
failed_grids.append(grid_id)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = str(e)
|
||||||
|
self.logger.error(
|
||||||
|
f"处理网格 ({grid_id[0]},{grid_id[1]}) 时发生异常: {error_msg}")
|
||||||
|
failed_grids.append((grid_id, error_msg))
|
||||||
|
|
||||||
|
# 汇总处理结果
|
||||||
|
total_grids = len(grid_points)
|
||||||
|
failed_count = len(failed_grids)
|
||||||
|
success_count = len(successful_grid_lt)
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
f"网格处理完成。总计: {total_grids}, 成功: {success_count}, 失败: {failed_count}")
|
||||||
|
|
||||||
|
if len(successful_grid_lt) == 0:
|
||||||
|
raise Exception("所有网格处理都失败,无法继续处理")
|
||||||
|
|
||||||
|
return successful_grid_lt
|
152
utils/visualizer.py
Normal file
152
utils/visualizer.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
import os
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
from pyproj import Transformer
|
||||||
|
|
||||||
|
|
||||||
|
class FilterVisualizer:
|
||||||
|
"""过滤结果可视化器"""
|
||||||
|
|
||||||
|
def __init__(self, output_dir: str):
|
||||||
|
"""
|
||||||
|
初始化可视化器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output_dir: 输出目录路径
|
||||||
|
"""
|
||||||
|
self.output_dir = output_dir
|
||||||
|
self.logger = logging.getLogger('UAV_Preprocess.Visualizer')
|
||||||
|
# 创建坐标转换器
|
||||||
|
self.transformer = Transformer.from_crs(
|
||||||
|
"EPSG:4326", # WGS84经纬度坐标系
|
||||||
|
"EPSG:32649", # UTM49N
|
||||||
|
always_xy=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def _convert_to_utm(self, lon: pd.Series, lat: pd.Series) -> tuple:
|
||||||
|
"""
|
||||||
|
将经纬度坐标转换为UTM坐标
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lon: 经度序列
|
||||||
|
lat: 纬度序列
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (x坐标, y坐标)
|
||||||
|
"""
|
||||||
|
return self.transformer.transform(lon, lat)
|
||||||
|
|
||||||
|
def visualize_filter_step(self,
|
||||||
|
current_points: pd.DataFrame,
|
||||||
|
previous_points: pd.DataFrame,
|
||||||
|
step_name: str,
|
||||||
|
save_name: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
可视化单个过滤步骤的结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_points: 当前步骤后的点
|
||||||
|
previous_points: 上一步骤的点
|
||||||
|
step_name: 步骤名称
|
||||||
|
save_name: 保存文件名,默认为step_name
|
||||||
|
"""
|
||||||
|
self.logger.info(f"开始生成{step_name}的可视化结果")
|
||||||
|
|
||||||
|
# 找出被过滤掉的点
|
||||||
|
filtered_files = set(
|
||||||
|
previous_points['file']) - set(current_points['file'])
|
||||||
|
filtered_points = previous_points[previous_points['file'].isin(
|
||||||
|
filtered_files)]
|
||||||
|
|
||||||
|
# 转换坐标到UTM
|
||||||
|
current_x, current_y = self._convert_to_utm(
|
||||||
|
current_points['lon'], current_points['lat'])
|
||||||
|
filtered_x, filtered_y = self._convert_to_utm(
|
||||||
|
filtered_points['lon'], filtered_points['lat'])
|
||||||
|
|
||||||
|
# 创建图形
|
||||||
|
plt.rcParams['font.sans-serif'] = ['SimHei'] # 黑体
|
||||||
|
plt.rcParams['axes.unicode_minus'] = False
|
||||||
|
plt.figure(figsize=(20, 20))
|
||||||
|
|
||||||
|
# 绘制保留的点
|
||||||
|
plt.scatter(current_x, current_y,
|
||||||
|
color='blue', label='保留的点',
|
||||||
|
alpha=0.6, s=5)
|
||||||
|
|
||||||
|
# 绘制被过滤的点
|
||||||
|
if not filtered_points.empty:
|
||||||
|
plt.scatter(filtered_x, filtered_y,
|
||||||
|
color='red', marker='x', label='过滤的点')
|
||||||
|
|
||||||
|
# 设置图形属性
|
||||||
|
plt.title(f"{step_name}后的GPS点\n"
|
||||||
|
f"(过滤: {len(filtered_points)}, 保留: {len(current_points)})",
|
||||||
|
fontsize=14)
|
||||||
|
plt.xlabel("东向坐标 (米)", fontsize=12)
|
||||||
|
plt.ylabel("北向坐标 (米)", fontsize=12)
|
||||||
|
plt.grid(True)
|
||||||
|
plt.axis('equal')
|
||||||
|
|
||||||
|
# 添加统计信息
|
||||||
|
stats_text = (
|
||||||
|
f"原始点数: {len(previous_points)}\n"
|
||||||
|
f"过滤点数: {len(filtered_points)}\n"
|
||||||
|
f"保留点数: {len(current_points)}\n"
|
||||||
|
f"过滤率: {len(filtered_points)/len(previous_points)*100:.1f}%"
|
||||||
|
)
|
||||||
|
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
|
||||||
|
bbox=dict(facecolor='white', alpha=0.8))
|
||||||
|
|
||||||
|
# 添加图例
|
||||||
|
plt.legend(loc='upper right', fontsize=10)
|
||||||
|
|
||||||
|
# 调整布局
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# 保存图形
|
||||||
|
save_name = save_name or step_name.lower().replace(' ', '_')
|
||||||
|
save_path = os.path.join(
|
||||||
|
self.output_dir, 'filter_imgs', f'filter_{save_name}.png')
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
self.logger.info(
|
||||||
|
f"{step_name}过滤可视化结果已保存至 {save_path}\n"
|
||||||
|
f"过滤掉 {len(filtered_points)} 个点,"
|
||||||
|
f"保留 {len(current_points)} 个点,"
|
||||||
|
f"过滤率 {len(filtered_points)/len(previous_points)*100:.1f}%"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试代码
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# 创建测试数据
|
||||||
|
np.random.seed(42)
|
||||||
|
n_points = 1000
|
||||||
|
|
||||||
|
# 生成随机点
|
||||||
|
test_data = pd.DataFrame({
|
||||||
|
'lon': np.random.uniform(120, 121, n_points),
|
||||||
|
'lat': np.random.uniform(30, 31, n_points),
|
||||||
|
'file': [f'img_{i}.jpg' for i in range(n_points)],
|
||||||
|
'date': [datetime.now() for _ in range(n_points)]
|
||||||
|
})
|
||||||
|
|
||||||
|
# 随机选择点作为过滤后的结果
|
||||||
|
filtered_data = test_data.sample(n=800)
|
||||||
|
|
||||||
|
# 测试可视化
|
||||||
|
visualizer = FilterVisualizer('test_output')
|
||||||
|
os.makedirs('test_output', exist_ok=True)
|
||||||
|
|
||||||
|
visualizer.visualize_filter_step(
|
||||||
|
filtered_data,
|
||||||
|
test_data,
|
||||||
|
"Test Filter"
|
||||||
|
)
|
Loading…
Reference in New Issue
Block a user