Compare commits

..

28 Commits

Author SHA1 Message Date
weixin_46229132
94a0854175 再次修改odm参数 2025-04-18 17:07:36 +08:00
weixin_46229132
aef169e916 修改odm参数 2025-04-18 14:48:57 +08:00
weixin_46229132
a3a5e5738a 可视化模块修改 2025-04-12 22:48:07 +08:00
weixin_46229132
5c382e1810 精简代码 2025-04-11 17:22:11 +08:00
weixin_46229132
e86fe196f8 删除不必要的代码 2025-04-11 15:38:15 +08:00
weixin_46229132
e57c3b6ec9 加入obj分块功能 2025-04-11 10:47:28 +08:00
weixin_46229132
a25adbcc31 修改merge_tif,使用ODM的cut,feather和merge函数 2025-04-10 18:57:34 +08:00
weixin_46229132
697660b5b3 修改轨迹图的坐标系 2025-04-06 11:09:16 +08:00
weixin_46229132
b69a610dd2 修复merge_tif bug,TODO: 加入conv_obj2代码(切分obj) 2025-03-29 14:50:38 +08:00
weixin_46229132
4828544ad6 加入3dtiles参数 2025-03-18 10:29:08 +08:00
weixin_46229132
d035d27975 快拼模式不生产高程 2025-03-17 16:34:01 +08:00
weixin_46229132
4b8038b1b2 添加高精度模式 2025-03-17 16:07:53 +08:00
weixin_46229132
b1bcdd5f5f 修改obj文件的z坐标 2025-02-21 10:41:41 +08:00
weixin_46229132
fe17a96e54 增加一个odm可选参数 2025-02-20 20:12:33 +08:00
weixin_46229132
696fe7fe7e z轴平移 2025-02-20 20:04:14 +08:00
weixin_46229132
3b6af50acc 修改网格划分,给出warning 2025-02-19 10:06:17 +08:00
weixin_46229132
ffc217fe53 修改ODM参数 2025-02-18 11:19:08 +08:00
weixin_46229132
f6d5e6cd0e 修改默认参数,针对三维模式 2025-02-18 11:16:46 +08:00
weixin_46229132
105d113e4b 修改obj的法向量,subprocess.Popen 2025-02-18 10:30:49 +08:00
weixin_46229132
30a451a19b 修改对齐方式,使用25d纹理模型 2025-02-17 19:40:45 +08:00
weixin_46229132
5ebbc664a1 不过滤孤立点 2025-02-15 15:57:27 +08:00
weixin_46229132
554bf319d5 main里不用fast 2025-02-15 15:18:37 +08:00
weixin_46229132
86940bd1b9 obj到osgb的坐标系变换 2025-02-15 14:53:02 +08:00
weixin_46229132
d05f278d79 修改odm运行参数 2025-02-15 09:49:40 +08:00
weixin_46229132
f6d6f112c9 通过UTM坐标系进行偏移 2025-02-14 20:14:06 +08:00
weixin_46229132
971517c145 修改瓦片的偏移 2025-02-06 19:01:19 +08:00
weixin_46229132
a3951c47d0 增加osgb合并功能 2025-02-06 18:14:56 +08:00
weixin_46229132
df74970ca9 增加obj格式转换功能 2025-02-06 16:54:23 +08:00
27 changed files with 1039 additions and 3582 deletions

160
app_plugin.py Normal file
View File

@ -0,0 +1,160 @@
import os
import shutil
from dataclasses import dataclass
from typing import Dict, Tuple
import psutil
import pandas as pd
from filter.cluster_filter import GPSCluster
from utils.directory_manager import DirectoryManager
from utils.odm_monitor import ODMProcessMonitor
from utils.gps_extractor import GPSExtractor
from utils.grid_divider import GridDivider
from utils.logger import setup_logger
from utils.visualizer import FilterVisualizer
from post_pro.merge_tif import MergeTif
from post_pro.conv_obj import ConvertOBJ
@dataclass
class ProcessConfig:
"""预处理配置类"""
image_dir: str
output_dir: str
# 聚类过滤参数
cluster_eps: float = 0.01
cluster_min_samples: int = 5
# 网格划分参数
grid_overlap: float = 0.05
grid_size: float = 500
mode: str = "三维模式"
class ODM_Plugin:
def __init__(self, config):
self.config = config
# 初始化目录管理器
self.dir_manager = DirectoryManager(config)
# 清理并重建输出目录
self.dir_manager.clean_output_dir()
self.dir_manager.setup_output_dirs()
# 检查磁盘空间
self.dir_manager.check_disk_space()
# 初始化其他组件
self.logger = setup_logger(config.output_dir)
self.gps_points = pd.DataFrame(columns=["file", "lat", "lon"])
self.odm_monitor = ODMProcessMonitor(
config.output_dir, mode=config.mode)
self.visualizer = FilterVisualizer(config.output_dir)
def extract_gps(self) -> pd.DataFrame:
"""提取GPS数据"""
self.logger.info("开始提取GPS数据")
extractor = GPSExtractor(self.config.image_dir)
self.gps_points = extractor.extract_all_gps()
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
def cluster(self):
"""使用DBSCAN对GPS点进行聚类只保留最大的类"""
previous_points = self.gps_points.copy()
clusterer = GPSCluster(
self.gps_points,
eps=self.config.cluster_eps,
min_samples=self.config.cluster_min_samples
)
self.clustered_points = clusterer.fit()
self.gps_points = clusterer.get_cluster_stats(self.clustered_points)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "1-Clustering")
def divide_grids(self) -> Dict[tuple, pd.DataFrame]:
"""划分网格
Returns:
- grid_points: 网格点数据字典
- translations: 网格平移量字典
"""
grid_divider = GridDivider(
overlap=self.config.grid_overlap,
grid_size=self.config.grid_size,
output_dir=self.config.output_dir
)
grids, grid_points = grid_divider.adjust_grid_size_and_overlap(
self.gps_points
)
grid_divider.visualize_grids(self.gps_points, grids)
if len(grids) >= 20:
self.logger.warning("网格数量已超过20, 需要人工调整分区")
return grid_points
def copy_images(self, grid_points: Dict[tuple, pd.DataFrame]):
"""复制图像到目标文件夹"""
self.logger.info("开始复制图像文件")
for grid_id, points in grid_points.items():
output_dir = os.path.join(
self.config.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
"images"
)
os.makedirs(output_dir, exist_ok=True)
for point in points:
src = os.path.join(self.config.image_dir, point["file"])
dst = os.path.join(output_dir, point["file"])
shutil.copy(src, dst)
self.logger.info(
f"网格 ({grid_id[0]},{grid_id[1]}) 包含 {len(points)} 张图像")
def merge_tif(self, grid_lt):
"""合并所有网格的影像产品"""
self.logger.info("开始合并所有影像产品")
merger = MergeTif(self.config.output_dir)
merger.merge_orthophoto(grid_lt)
def convert_obj(self, grid_lt):
"""转换OBJ模型"""
self.logger.info("开始转换OBJ模型")
converter = ConvertOBJ(self.config.output_dir)
converter.convert_grid_obj(grid_lt)
def post_process(self, successful_grid_lt: list, grid_points: Dict[tuple, pd.DataFrame]):
"""后处理:合并或复制处理结果"""
if len(successful_grid_lt) < len(grid_points):
self.logger.warning(
f"{len(grid_points) - len(successful_grid_lt)} 个网格处理失败,"
f"将只合并成功处理的 {len(successful_grid_lt)} 个网格"
)
self.merge_tif(successful_grid_lt)
if self.config.mode == "三维模式":
self.convert_obj(successful_grid_lt)
else:
pass
def process(self):
"""执行完整的预处理流程"""
try:
self.extract_gps()
self.cluster()
grid_points = self.divide_grids()
self.copy_images(grid_points)
self.logger.info("预处理任务完成")
successful_grid_lt = self.odm_monitor.process_all_grids(
grid_points)
self.post_process(successful_grid_lt, grid_points)
self.logger.info("重建任务完成")
except Exception as e:
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
raise

View File

@ -1,248 +0,0 @@
import os
import math
from itertools import combinations
import numpy as np
from scipy.spatial import KDTree
import logging
import pandas as pd
from datetime import datetime, timedelta
class GPSFilter:
"""过滤密集点及孤立点"""
def __init__(self, output_dir):
self.logger = logging.getLogger('UAV_Preprocess.GPSFilter')
@staticmethod
def _haversine(lat1, lon1, lat2, lon2):
"""计算两点之间的地理距离(单位:米)"""
R = 6371000 # 地球平均半径,单位:米
phi1, phi2 = math.radians(lat1), math.radians(lat2)
delta_phi = math.radians(lat2 - lat1)
delta_lambda = math.radians(lon2 - lon1)
a = math.sin(delta_phi / 2) ** 2 + math.cos(phi1) * \
math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
return R * c
@staticmethod
def _assign_to_grid(lat, lon, grid_size, min_lat, min_lon):
"""根据经纬度和网格大小,将点分配到网格"""
grid_x = int((lat - min_lat) // grid_size)
grid_y = int((lon - min_lon) // grid_size)
return grid_x, grid_y
def _get_distances(self, points_df, grid_size):
"""读取图片 GPS 坐标,计算点对之间的距离并排序"""
# 确定经纬度范围
min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max()
min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max()
self.logger.info(
f"经纬度范围:纬度[{min_lat:.6f}, {max_lat:.6f}],纬度范围[{max_lat-min_lat:.6f}]"
f"经度[{min_lon:.6f}, {max_lon:.6f}],经度范围[{max_lon-min_lon:.6f}]")
# 分配到网格
grid_map = {}
for _, row in points_df.iterrows():
grid = self._assign_to_grid(
row['lat'], row['lon'], grid_size, min_lat, min_lon)
if grid not in grid_map:
grid_map[grid] = []
grid_map[grid].append((row['file'], row['lat'], row['lon']))
self.logger.info(f"图像点已分配到 {len(grid_map)} 个网格中")
# 在每个网格中计算两两距离并排序
sorted_distances = {}
for grid, images in grid_map.items():
distances = []
for (img1, lat1, lon1), (img2, lat2, lon2) in combinations(images, 2):
dist = self._haversine(lat1, lon1, lat2, lon2)
distances.append((img1, img2, dist))
distances.sort(key=lambda x: x[2]) # 按距离升序排序
sorted_distances[grid] = distances
self.logger.debug(f"网格 {grid} 中计算了 {len(distances)} 个距离对")
return sorted_distances
def _group_by_time(self, points_df: pd.DataFrame, time_threshold: timedelta) -> list:
"""根据拍摄时间分组图片
如果相邻两张图片的拍摄时间差超过5分钟则进行切分
Args:
points_df: 包含图片信息的DataFrame必须包含'file''date'
time_threshold: 时间间隔阈值默认5分钟
Returns:
list: 每个元素为时间组内的点数据
"""
if 'date' not in points_df.columns:
self.logger.error("数据中缺少date列")
return [points_df]
# 将date为空的行单独作为一组
null_date_group = points_df[points_df['date'].isna()]
valid_date_points = points_df[points_df['date'].notna()]
if not null_date_group.empty:
self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组")
if valid_date_points.empty:
self.logger.warning("没有有效的时间戳数据")
return [null_date_group] if not null_date_group.empty else []
# 按时间排序
valid_date_points = valid_date_points.sort_values('date')
self.logger.info(
f"有效时间范围: {valid_date_points['date'].min()}{valid_date_points['date'].max()}")
# 计算时间差
time_diffs = valid_date_points['date'].diff()
# 找到时间差超过阈值的位置
time_groups = []
current_group_start = 0
for idx, time_diff in enumerate(time_diffs):
if time_diff and time_diff > time_threshold:
# 添加当前组
current_group = valid_date_points.iloc[current_group_start:idx]
time_groups.append(current_group)
# 记录断点信息
break_time = valid_date_points.iloc[idx]['date']
group_start_time = current_group.iloc[0]['date']
group_end_time = current_group.iloc[-1]['date']
self.logger.info(
f"时间组 {len(time_groups)}: {len(current_group)} 个点, "
f"时间范围 [{group_start_time} - {group_end_time}]"
)
self.logger.info(
f"在时间 {break_time} 处发现断点,时间差为 {time_diff}")
current_group_start = idx
# 添加最后一组
last_group = valid_date_points.iloc[current_group_start:]
if not last_group.empty:
time_groups.append(last_group)
self.logger.info(
f"时间组 {len(time_groups)}: {len(last_group)} 个点, "
f"时间范围 [{last_group.iloc[0]['date']} - {last_group.iloc[-1]['date']}]"
)
# 如果有空时间戳的点,将其作为最后一组
if not null_date_group.empty:
time_groups.append(null_date_group)
self.logger.info(f"添加无时间戳组: {len(null_date_group)} 个点")
self.logger.info(f"共分为 {len(time_groups)} 个时间组")
return time_groups
def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13, time_threshold=timedelta(minutes=5)):
"""
过滤密集点先按时间分组再在每个时间组内过滤
空时间戳的点不进行过滤
Args:
points_df: 点数据
grid_size: 网格大小
distance_threshold: 距离阈值
time_interval: 时间间隔
"""
self.logger.info(f"开始按时间分组过滤密集点 (网格大小: {grid_size}, "
f"距离阈值: {distance_threshold}米, 分组时间间隔: {time_threshold}秒)")
# 按时间分组
time_groups = self._group_by_time(points_df, time_threshold)
# 存储所有要删除的图片
all_to_del_imgs = []
# 对每个时间组进行密集点过滤
for group_idx, group_points in enumerate(time_groups):
# 检查是否为空时间戳组(最后一组)
if group_idx == len(time_groups) - 1 and group_points['date'].isna().any():
self.logger.info(f"跳过无时间戳组 (包含 {len(group_points)} 个点)")
continue
self.logger.info(
f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)")
# 计算该组内的点间距离
sorted_distances = self._get_distances(group_points, grid_size)
group_to_del_imgs = []
# 在每个网格中过滤密集点
for grid, distances in sorted_distances.items():
grid_del_count = 0
while distances:
candidate_img1, candidate_img2, dist = distances[0]
if dist < distance_threshold:
distances.pop(0)
# 获取候选图片的其他最短距离
candidate_img1_dist = None
candidate_img2_dist = None
for distance in distances:
if candidate_img1 in distance:
candidate_img1_dist = distance[2]
break
for distance in distances:
if candidate_img2 in distance:
candidate_img2_dist = distance[2]
break
# 选择要删除的点
if candidate_img1_dist and candidate_img2_dist:
to_del_img = candidate_img1 if candidate_img1_dist < candidate_img2_dist else candidate_img2
group_to_del_imgs.append(to_del_img)
grid_del_count += 1
self.logger.debug(
f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)")
distances = [
d for d in distances if to_del_img not in d]
else:
break
if grid_del_count > 0:
self.logger.info(
f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点")
all_to_del_imgs.extend(group_to_del_imgs)
self.logger.info(
f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点")
# 过滤数据
filtered_df = points_df[~points_df['file'].isin(all_to_del_imgs)]
self.logger.info(
f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点")
return filtered_df
def filter_isolated_points(self, points_df, threshold_distance=0.001, min_neighbors=6):
"""过滤孤立点"""
self.logger.info(
f"开始过滤孤立点 (距离阈值: {threshold_distance}, 最小邻居数: {min_neighbors})")
coords = points_df[['lat', 'lon']].values
kdtree = KDTree(coords)
neighbors_count = [len(kdtree.query_ball_point(
coord, threshold_distance)) for coord in coords]
isolated_points = []
for i, (_, row) in enumerate(points_df.iterrows()):
if neighbors_count[i] < min_neighbors:
isolated_points.append(row['file'])
self.logger.debug(
f"删除孤立点: {row['file']} (邻居数: {neighbors_count[i]})")
filtered_df = points_df[~points_df['file'].isin(isolated_points)]
self.logger.info(
f"孤立点过滤完成,共删除 {len(isolated_points)} 个点,剩余 {len(filtered_df)} 个点")
return filtered_df

View File

@ -1,201 +0,0 @@
import shutil
import pandas as pd
from shapely.geometry import box
from utils.logger import setup_logger
from utils.gps_extractor import GPSExtractor
import numpy as np
import logging
from datetime import timedelta
import matplotlib.pyplot as plt
import os
import sys
class TimeGroupOverlapFilter:
"""基于时间组重叠度的图像过滤器"""
def __init__(self, image_dir: str, output_dir: str, overlap_threshold: float = 0.7):
"""
初始化过滤器
Args:
image_dir: 图像目录
output_dir: 输出目录
overlap_threshold: 重叠阈值默认0.7
"""
self.image_dir = image_dir
self.output_dir = output_dir
self.overlap_threshold = overlap_threshold
self.logger = logging.getLogger('UAV_Preprocess.TimeGroupFilter')
def _group_by_time(self, points_df, time_threshold=timedelta(minutes=5)):
"""按时间间隔对点进行分组"""
if 'date' not in points_df.columns:
self.logger.error("数据中缺少date列")
return []
# 将date为空的行单独作为一组
null_date_group = points_df[points_df['date'].isna()]
valid_date_points = points_df[points_df['date'].notna()]
if not null_date_group.empty:
self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组")
if valid_date_points.empty:
self.logger.warning("没有有效的时间戳数据")
return [null_date_group] if not null_date_group.empty else []
# 按时间排序
valid_date_points = valid_date_points.sort_values('date')
# 计算时间差
time_diffs = valid_date_points['date'].diff()
# 找到时间差超过阈值的位置
time_groups = []
current_group_start = 0
for idx, time_diff in enumerate(time_diffs):
if time_diff and time_diff > time_threshold:
# 添加当前组
current_group = valid_date_points.iloc[current_group_start:idx]
time_groups.append(current_group)
current_group_start = idx
# 添加最后一组
last_group = valid_date_points.iloc[current_group_start:]
if not last_group.empty:
time_groups.append(last_group)
# 如果有空时间戳的点,将其作为最后一组
if not null_date_group.empty:
time_groups.append(null_date_group)
return time_groups
def _get_group_bbox(self, group_df):
"""获取组内点的边界框"""
min_lon = group_df['lon'].min()
max_lon = group_df['lon'].max()
min_lat = group_df['lat'].min()
max_lat = group_df['lat'].max()
return box(min_lon, min_lat, max_lon, max_lat)
def _calculate_overlap(self, box1, box2):
"""计算两个边界框的重叠率"""
if box1.intersects(box2):
intersection_area = box1.intersection(box2).area
smaller_area = min(box1.area, box2.area)
if smaller_area == 0:
overlap_ratio = 1
else:
overlap_ratio = intersection_area / smaller_area
else:
overlap_ratio = 0
return overlap_ratio
def filter_overlapping_groups(self, gps_points, time_threshold=timedelta(minutes=5)):
"""过滤重叠的时间组"""
# 按时间分组
self.logger.info("开始过滤重叠时间组")
time_groups = self._group_by_time(gps_points, time_threshold)
# 计算每个组的边界框
group_boxes = []
for idx, group in enumerate(time_groups):
if not group['date'].isna().any(): # 只处理有时间戳的组
bbox = self._get_group_bbox(group)
group_boxes.append((idx, group, bbox))
# 找出需要删除的组
groups_to_delete = set()
for i in range(len(group_boxes)):
if i in groups_to_delete:
continue
idx1, group1, box1 = group_boxes[i]
area1 = box1.area
for j in range(i + 1, len(group_boxes)):
if j in groups_to_delete:
continue
idx2, group2, box2 = group_boxes[j]
area2 = box2.area
overlap_ratio = self._calculate_overlap(box1, box2)
if overlap_ratio > self.overlap_threshold:
# 删除面积较小的组
if area1 < area2:
group_to_delete = idx1
smaller_area = area1
larger_area = area2
else:
group_to_delete = idx2
smaller_area = area2
larger_area = area1
groups_to_delete.add(group_to_delete)
self.logger.info(
f"时间组 {group_to_delete + 1} 与时间组 "
f"{idx2 + 1 if group_to_delete == idx1 else idx1 + 1} "
f"重叠率为 {overlap_ratio:.2f}"
f"面积比为 {smaller_area/larger_area:.2f}"
f"将删除较小面积的组 {group_to_delete + 1}"
)
# 删除重复组的图像
deleted_files = []
for group_idx in groups_to_delete:
group_files = time_groups[group_idx]['file'].tolist()
deleted_files.extend(group_files)
self.logger.info(f"共删除 {len(groups_to_delete)} 个重复时间组,"
f"{len(deleted_files)} 张图像")
# 可视化结果
self._visualize_results(time_groups, groups_to_delete)
retained_points = gps_points[~gps_points['file'].isin(
deleted_files)]
self.logger.info(f"重叠时间组过滤后剩余 {len(retained_points)} 个GPS点")
return retained_points
def _visualize_results(self, time_groups, groups_to_delete):
"""可视化过滤结果"""
plt.figure(figsize=(15, 10))
# 生成不同的颜色
colors = plt.cm.rainbow(np.linspace(0, 1, len(time_groups)))
# 绘制所有组的边界框
for idx, (group, color) in enumerate(zip(time_groups, colors)):
if not group['date'].isna().any(): # 只处理有时间戳的组
bbox = self._get_group_bbox(group)
x, y = bbox.exterior.xy
if idx in groups_to_delete:
# 被删除的组用虚线表示
plt.plot(x, y, '--', color=color, alpha=0.6,
label=f'Deleted Group {idx + 1}')
else:
# 保留的组用实线表示
plt.plot(x, y, '-', color=color, alpha=0.6,
label=f'Group {idx + 1}')
# 绘制该组的GPS点
plt.scatter(group['lon'], group['lat'], color=color,
s=30, alpha=0.6)
plt.title("Time Groups and Their Bounding Boxes", fontsize=14)
plt.xlabel("Longitude", fontsize=12)
plt.ylabel("Latitude", fontsize=12)
plt.grid(True)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
plt.tight_layout()
# 保存图片
plt.savefig(os.path.join(self.output_dir, 'filter_imgs', 'time_groups_overlap_bbox.png'),
dpi=300, bbox_inches='tight')
plt.close()

View File

@ -1,236 +0,0 @@
from utils.gps_extractor import GPSExtractor
import os
import sys
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
import pandas as pd
from matplotlib.font_manager import FontProperties
class GPSSelector:
def __init__(self, image_dir: str, output_dir: str = None):
# 移除中文字体设置
self.image_dir = image_dir
self.output_dir = output_dir
self.gps_points = None
self.selected_points = []
self.fig, self.ax = plt.subplots(figsize=(12, 8))
self.scatter = None
self.rs = None
self.setup_plot()
def extract_gps(self):
"""提取GPS数据"""
extractor = GPSExtractor(self.image_dir)
self.gps_points = extractor.extract_all_gps()
print(f"成功提取 {len(self.gps_points)} 个GPS点")
def setup_plot(self):
"""设置绘图"""
self.ax.set_title('GPS Points - Use mouse to drag and select points to delete')
self.ax.set_xlabel('Longitude')
self.ax.set_ylabel('Latitude')
self.ax.grid(True)
# 设置坐标轴使用相同的比例
self.ax.set_aspect('equal')
# 设置矩形选择器
self.rs = RectangleSelector(
self.ax, self.on_select,
interactive=True,
useblit=True,
button=[1], # 只响应左键
props=dict(facecolor='red', alpha=0.3)
)
# 添加按钮回调
self.fig.canvas.mpl_connect('key_press_event', self.on_key_press)
# 添加缩放和平移功能
self.fig.canvas.mpl_connect('scroll_event', self.on_scroll)
self.fig.canvas.mpl_connect('button_press_event', self.on_press)
self.fig.canvas.mpl_connect('button_release_event', self.on_release)
self.fig.canvas.mpl_connect('motion_notify_event', self.on_motion)
# 用于平移功能的变量
self._pan_start = None
def plot_gps_points(self):
"""绘制GPS点"""
if self.scatter is not None:
self.scatter.remove()
# 计算经纬度的范围
lon_range = self.gps_points['lon'].max() - self.gps_points['lon'].min()
lat_range = self.gps_points['lat'].max() - self.gps_points['lat'].min()
# 设置合适的图形大小,保持经纬度的真实比例
aspect_ratio = lon_range / lat_range
fig_width = 12
fig_height = fig_width / aspect_ratio
self.fig.set_size_inches(fig_width, fig_height)
self.scatter = self.ax.scatter(
self.gps_points['lon'],
self.gps_points['lat'],
c='blue',
s=20,
alpha=0.6
)
# 设置适当的显示范围,添加一些边距
margin = 0.1
x_margin = lon_range * margin
y_margin = lat_range * margin
self.ax.set_xlim([
self.gps_points['lon'].min() - x_margin,
self.gps_points['lon'].max() + x_margin
])
self.ax.set_ylim([
self.gps_points['lat'].min() - y_margin,
self.gps_points['lat'].max() + y_margin
])
# 关闭自动缩放
self.ax.autoscale(False)
# 使用更精确的刻度
self.ax.ticklabel_format(useOffset=False, style='plain')
self.fig.canvas.draw_idle()
def on_select(self, eclick, erelease):
"""矩形选择回调"""
x1, y1 = eclick.xdata, eclick.ydata
x2, y2 = erelease.xdata, erelease.ydata
# 获取选中区域内的点
mask = (
(self.gps_points['lon'] >= min(x1, x2)) &
(self.gps_points['lon'] <= max(x1, x2)) &
(self.gps_points['lat'] >= min(y1, y2)) &
(self.gps_points['lat'] <= max(y1, y2))
)
selected = self.gps_points[mask]
self.selected_points.extend(selected['file'].tolist())
# 从数据中移除选中的点
self.gps_points = self.gps_points[~mask]
# 更新绘图
self.plot_gps_points()
print(f"选中 {len(selected)} 个点,剩余 {len(self.gps_points)} 个点")
def on_key_press(self, event):
"""键盘事件回调"""
if event.key == 'enter':
self.save_results()
plt.close()
elif event.key == 'escape':
plt.close()
def save_results(self):
"""保存结果"""
if not self.output_dir:
return
# 创建输出目录
os.makedirs(self.output_dir, exist_ok=True)
# 获取所有保留的图片文件名
remaining_files = self.gps_points['file'].tolist()
# 移动保留的图片到输出目录
for img_name in remaining_files:
src = os.path.join(self.image_dir, img_name)
dst = os.path.join(self.output_dir, img_name)
shutil.copy2(src, dst) # 使用copy2保留文件的元数据
# 保存剩余点的信息
self.gps_points.to_csv(
os.path.join(self.output_dir, "remaining_points.csv"),
index=False
)
print(f"已选择删除 {len(self.selected_points)} 张图片")
print(f"已复制 {len(remaining_files)} 张保留的图片到 {self.output_dir}")
def run(self):
"""运行选择器"""
self.extract_gps()
self.plot_gps_points()
plt.show()
def on_scroll(self, event):
"""鼠标滚轮缩放"""
if event.inaxes != self.ax:
return
# 获取当前视图范围
cur_xlim = self.ax.get_xlim()
cur_ylim = self.ax.get_ylim()
# 缩放因子
base_scale = 1.1
xdata = event.xdata
ydata = event.ydata
if event.button == 'up':
# 放大
scale_factor = 1/base_scale
else:
# 缩小
scale_factor = base_scale
# 设置新的视图范围
new_width = (cur_xlim[1] - cur_xlim[0]) * scale_factor
new_height = (cur_ylim[1] - cur_ylim[0]) * scale_factor
self.ax.set_xlim([xdata - new_width * (xdata - cur_xlim[0]) / (cur_xlim[1] - cur_xlim[0]),
xdata + new_width * (cur_xlim[1] - xdata) / (cur_xlim[1] - cur_xlim[0])])
self.ax.set_ylim([ydata - new_height * (ydata - cur_ylim[0]) / (cur_ylim[1] - cur_ylim[0]),
ydata + new_height * (cur_ylim[1] - ydata) / (cur_ylim[1] - cur_ylim[0])])
self.fig.canvas.draw_idle()
def on_press(self, event):
"""鼠标按下事件"""
if event.inaxes != self.ax or event.button != 3: # 只响应右键
return
self._pan_start = (event.xdata, event.ydata)
def on_release(self, event):
"""鼠标释放事件"""
self._pan_start = None
def on_motion(self, event):
"""鼠标移动事件"""
if self._pan_start is None or event.inaxes != self.ax:
return
# 计算移动距离
dx = event.xdata - self._pan_start[0]
dy = event.ydata - self._pan_start[1]
# 更新视图范围
cur_xlim = self.ax.get_xlim()
cur_ylim = self.ax.get_ylim()
self.ax.set_xlim(cur_xlim - dx)
self.ax.set_ylim(cur_ylim - dy)
self.fig.canvas.draw_idle()
if __name__ == "__main__":
# 使用示例
selector = GPSSelector(
image_dir=r"G:\error_data\20240930091614\project\images",
output_dir=r"C:\datasets\ODM_output\error1_L"
)
selector.run()

297
grid.py
View File

@ -1,297 +0,0 @@
import csv
import numpy as np
import matplotlib.pyplot as plt
import math
from shapely.geometry import box, MultiPoint
from shapely.ops import unary_union
from scipy.spatial import cKDTree
from utils.gps_extractor import GPSExtractor
# ---------------------- overlap 截断为不超过 10% ----------------------
def clamp_overlap(overlap):
if overlap < 0:
return 0.0
elif overlap > 0.1:
return 0.1
else:
return overlap
# ====================== 1) 生成可用矩形并记录其覆盖点集 ======================
def generate_rectangles_with_point_indices(points, w, h, overlap=0.1, min_points=800):
"""
bounding box (w, h) + overlap 布置网格生成所有矩形
过滤只保留"矩形内点数 >= min_points"的矩形
返回:
rect_info_list: list of (rect_polygon, covered_indices)
- rect_polygon: Shapely Polygon
- covered_indices: 一个 set表示该矩形覆盖的所有点索引
"""
overlap = clamp_overlap(overlap)
if len(points) == 0:
return []
minx, miny = np.min(points, axis=0)
maxx, maxy = np.max(points, axis=0)
# 特殊情况:只有一个点或非常小范围 -> 很难满足 800 点
if abs(maxx - minx) < 1e-15 and abs(maxy - miny) < 1e-15:
return []
# 步长
step_x = w * (1 - overlap)
step_y = h * (1 - overlap)
x_coords = np.arange(minx, maxx + step_x, step_x)
y_coords = np.arange(miny, maxy + step_y, step_y)
# 建立 KDTree加速查找
tree = cKDTree(points)
rect_info_list = []
for x in x_coords:
for y in y_coords:
rect_poly = box(x, y, x + w, y + h)
rx_min, ry_min, rx_max, ry_max = rect_poly.bounds
cx = (rx_min + rx_max) / 2
cy = (ry_min + ry_max) / 2
r = math.sqrt((rx_max - rx_min) ** 2 + (ry_max - ry_min) ** 2) / 2
candidate_ids = tree.query_ball_point([cx, cy], r)
if not candidate_ids:
continue
covered_set = set()
for idx_pt in candidate_ids:
px, py = points[idx_pt]
if rx_min <= px <= rx_max and ry_min <= py <= ry_max:
covered_set.add(idx_pt)
# 如果覆盖的点数不足 min_points就不保留
if len(covered_set) < min_points:
continue
rect_info_list.append((rect_poly, covered_set))
return rect_info_list
# ====================== 2) 贪心算法选取子集覆盖所有点 ======================
def cover_all_points_greedy(points, rect_info_list):
"""
给定所有点 points 以及 "可用矩形+覆盖点集合" rect_info_list
要求:
- 选出若干矩形使得所有点都被覆盖 (每个点至少属于1个选中矩形)
- 最终并集面积最小 (做近似贪心)
返回:
chosen_rects: 最终选出的矩形列表 (每个是 shapely Polygon)
"""
n = len(points)
all_indices = set(range(n)) # 所有点的索引
uncovered = set(all_indices) # 尚未被覆盖的点索引
chosen_rects = []
union_polygon = None # 当前已选矩形的并集
# 如果没有任何矩形可用,就直接失败
if not rect_info_list:
return []
# 为了在贪心过程中快速评估"新矩形带来的额外并集面积"
# 我们每次选择矩形后更新 union_polygon然后比较 union_polygon.union(new_rect).area - union_polygon.area
# 但 union_polygon 初始为 None
while uncovered:
best_gain = 0
best_new_area = float('inf')
best_rect = None
best_covered_new = set()
for rect_poly, covered_set in rect_info_list:
# 计算能覆盖多少"尚未覆盖"的点
newly_covered = uncovered.intersection(covered_set)
if not newly_covered:
continue
# 计算额外增加的并集面积
if union_polygon is None:
# 第一次选union_polygon 为空 => new_area = rect_poly.area
new_area = rect_poly.area
area_increase = new_area
else:
# 计算 union_polygon rect_poly 的面积
test_union = union_polygon.union(rect_poly)
new_area = test_union.area
area_increase = new_area - union_polygon.area
# 贪心策略:最大化 (覆盖点数量) / (面积增量)
# 或者 equivalently, (覆盖点数量) 多、(面积增量) 小 都是好
ratio = len(newly_covered) / max(area_increase, 1e-12)
# 我们要找到 ratio 最大的那个
if ratio > best_gain:
best_gain = ratio
best_new_area = area_increase
best_rect = rect_poly
best_covered_new = newly_covered
if best_rect is None:
# 没有可选的矩形能覆盖任何剩余点 => 失败 (无法覆盖所有点)
return []
# 选中 best_rect
chosen_rects.append(best_rect)
uncovered -= best_covered_new
# 更新并集
if union_polygon is None:
union_polygon = best_rect
else:
union_polygon = union_polygon.union(best_rect)
return chosen_rects
# ====================== 3) 主流程: 离散搜索 (w,h) + 贪心覆盖 ======================
def find_optimal_rectangles_cover_all_points(
points,
base_w,
base_h,
overlap=0.1,
steps=5,
min_points=800
):
"""
[0.5*base_w,1.5*base_w] x [0.5*base_h,1.5*base_h] 的离散区间枚举 (w,h)
- 生成可用矩形(800 )的列表
- 用贪心算法选出子集来覆盖所有点
- 计算选中矩形的并集面积
选出面积最小的方案并返回
"""
overlap = clamp_overlap(overlap)
n = len(points)
if n == 0:
return [], (base_w, base_h), 0.0 # 没有点就不用覆盖了
w_candidates = np.linspace(0.3 * base_w, 2 * base_w, steps)
h_candidates = np.linspace(0.3 * base_h, 2 * base_h, steps)
best_rects = []
best_area = float('inf')
best_w, best_h = base_w, base_h
for w in w_candidates:
for h in h_candidates:
rect_info_list = generate_rectangles_with_point_indices(points, w, h, overlap, min_points)
if not rect_info_list:
# 说明没有任何矩形能达到 "≥800点"
continue
# 用贪心覆盖所有点
chosen_rects = cover_all_points_greedy(points, rect_info_list)
if not chosen_rects:
# 无法覆盖所有点
continue
# 计算并集面积
union_poly = unary_union(chosen_rects)
area_covered = union_poly.area
if area_covered < best_area:
best_area = area_covered
best_rects = chosen_rects
best_w, best_h = w, h
return best_rects, (best_w, best_h), best_area
# ====================== 4) 读取 CSV + 可视化 ======================
def plot_image_points_cover_all_min_area(
image_dir, # 新参数:图片文件夹路径
base_rect_width=0.001,
base_rect_height=0.001,
overlap=0.1,
steps=5,
min_points=800
):
"""
从图片文件夹读取GPS坐标:
1) 使用 GPSExtractor 从图片中提取GPS坐标
2) [0.5*base_w,1.5*base_w] x [0.5*base_h,1.5*base_h] 离散搜索 (w,h)
3) 对每个 (w,h), 先生成所有"含≥800点"的矩形 => 再用贪心覆盖所有点 => 计算并集面积
4) 最小并集面积的方案即近似最优解
5) 最终用该方案可视化
"""
overlap = clamp_overlap(overlap)
# 使用 GPSExtractor 读取图片GPS坐标
extractor = GPSExtractor(image_dir)
gps_df = extractor.extract_all_gps()
if gps_df.empty:
print("未能从图片中提取到GPS坐标。")
return
points = np.column_stack((gps_df['lon'], gps_df['lat'])) # (N, 2), [x=lon, y=lat]
n = len(points)
if n == 0:
print("No points extracted from images.")
return
# 贪心 + 离散搜索
chosen_rects, (best_w, best_h), best_area = find_optimal_rectangles_cover_all_points(
points,
base_w=base_rect_width,
base_h=base_rect_height,
overlap=overlap,
steps=steps,
min_points=min_points
)
if not chosen_rects:
print(f"无法找到满足 '每个矩形≥{min_points}' 且覆盖所有点 的方案,试着调大尺寸/步数/overlap。")
return
# 可视化
plt.figure(figsize=(10, 8))
# 画点
plt.scatter(points[:, 0], points[:, 1], c='red', s=10, label='Points')
# 画矩形
for i, rect in enumerate(chosen_rects):
if rect.is_empty:
continue
x, y = rect.exterior.xy
plt.fill(x, y, edgecolor='green', fill=False, alpha=0.3,
label='Chosen Rectangles' if i == 0 else "")
plt.title(
f"Cover All Points, Each Rect≥{min_points} pts, Minimal Union Area\n"
f"base=({base_rect_width:.6f} x {base_rect_height:.6f}), overlap≤{overlap}, steps={steps}\n"
f"best (w,h)=({best_w:.6f},{best_h:.6f}), union area={best_area:.6f}, #rect={len(chosen_rects)}"
)
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.legend()
plt.grid(True)
plt.show()
# ------------------ 测试入口 ------------------
if __name__ == "__main__":
image_dir = r"C:\datasets\134\code\images" # 替换为你的图片文件夹路径
plot_image_points_cover_all_min_area(
image_dir,
base_rect_width=0.01,
base_rect_height=0.01,
overlap=0.05, # 会被截断到 0.1
steps=40,
min_points=100
)

27
main.py
View File

@ -1,6 +1,6 @@
import argparse
from datetime import timedelta
from odm_preprocess import PreprocessConfig, ImagePreprocessor
from app_plugin import ProcessConfig, ODM_Plugin
def parse_args():
parser = argparse.ArgumentParser(description='ODM预处理工具')
@ -10,41 +10,36 @@ def parse_args():
parser.add_argument('--output_dir', required=True, help='输出目录路径')
# 可选参数
parser.add_argument('--mode', default='重建模式', choices=['快拼模式', '三维模式', '重建模式'], help='处理模式')
parser.add_argument('--mode', default='三维模式',
choices=['快拼模式', '三维模式'], help='处理模式')
parser.add_argument('--grid_size', type=float, default=800, help='网格大小(米)')
parser.add_argument('--grid_overlap', type=float, default=0.05, help='网格重叠率')
parser.add_argument('--produce_dem', action='store_true', help='是否生成DEM')
parser.add_argument('--grid_overlap', type=float,
default=0.05, help='网格重叠率')
args = parser.parse_args()
return args
def main():
args = parse_args()
# 创建配置
config = PreprocessConfig(
config = ProcessConfig(
image_dir=args.image_dir,
output_dir=args.output_dir,
mode=args.mode,
grid_size=args.grid_size,
grid_overlap=args.grid_overlap,
produce_dem=args.produce_dem,
# 其他参数使用默认值
grid_overlap=0.05,
cluster_eps=0.01,
cluster_min_samples=5,
time_group_overlap_threshold=0.7,
time_group_interval=timedelta(minutes=5),
filter_distance_threshold=0.001,
filter_min_neighbors=6,
filter_grid_size=0.001,
filter_dense_distance_threshold=10,
filter_time_threshold=timedelta(minutes=5),
)
# 创建处理器并执行
processor = ImagePreprocessor(config)
processor = ODM_Plugin(config)
processor.process()
if __name__ == '__main__':
main()

View File

@ -1,341 +0,0 @@
import os
import shutil
from datetime import timedelta
from dataclasses import dataclass
from typing import Dict, Tuple
import psutil
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from filter.cluster_filter import GPSCluster
from filter.time_group_overlap_filter import TimeGroupOverlapFilter
from filter.gps_filter import GPSFilter
from utils.odm_monitor import ODMProcessMonitor
from utils.gps_extractor import GPSExtractor
from utils.grid_divider import GridDivider
from utils.logger import setup_logger
from utils.visualizer import FilterVisualizer
from post_pro.merge_tif import MergeTif
from post_pro.merge_obj import MergeObj
from post_pro.obj_post_pro import ObjPostProcessor
from post_pro.merge_laz import MergePly
@dataclass
class PreprocessConfig:
"""预处理配置类"""
image_dir: str
output_dir: str
# 聚类过滤参数
cluster_eps: float = 0.01
cluster_min_samples: int = 5
# 时间组重叠过滤参数
time_group_overlap_threshold: float = 0.7
time_group_interval: timedelta = timedelta(minutes=5)
# 孤立点过滤参数
filter_distance_threshold: float = 0.001 # 经纬度距离
filter_min_neighbors: int = 6
# 密集点过滤参数
filter_grid_size: float = 0.001
filter_dense_distance_threshold: float = 10 # 普通距离,单位:米
filter_time_threshold: timedelta = timedelta(minutes=5)
# 网格划分参数
grid_overlap: float = 0.05
grid_size: float = 500
# 几个pipline过程是否开启
mode: str = "快拼模式"
produce_dem: bool = False
class ImagePreprocessor:
def __init__(self, config: PreprocessConfig):
self.config = config
# 检查磁盘空间
self._check_disk_space()
# 清理并重建输出目录
if os.path.exists(config.output_dir):
self._clean_output_dir()
self._setup_output_dirs()
# 初始化其他组件
self.logger = setup_logger(config.output_dir)
self.gps_points = None
self.odm_monitor = ODMProcessMonitor(
config.output_dir, mode=config.mode)
self.visualizer = FilterVisualizer(config.output_dir)
def _clean_output_dir(self):
"""清理输出目录"""
try:
shutil.rmtree(self.config.output_dir)
print(f"已清理输出目录: {self.config.output_dir}")
except Exception as e:
print(f"清理输出目录时发生错误: {str(e)}")
raise
def _setup_output_dirs(self):
"""创建必要的输出目录结构"""
try:
# 创建主输出目录
os.makedirs(self.config.output_dir)
# 创建过滤图像保存目录
os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs'))
# 创建日志目录
os.makedirs(os.path.join(self.config.output_dir, 'logs'))
print(f"已创建输出目录结构: {self.config.output_dir}")
except Exception as e:
print(f"创建输出目录时发生错误: {str(e)}")
raise
def _get_directory_size(self, path):
"""获取目录的总大小(字节)"""
total_size = 0
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
file_path = os.path.join(dirpath, filename)
try:
total_size += os.path.getsize(file_path)
except (OSError, FileNotFoundError):
continue
return total_size
def _check_disk_space(self):
"""检查磁盘空间是否足够"""
# 获取输入目录大小
input_size = self._get_directory_size(self.config.image_dir)
# 获取输出目录所在磁盘的剩余空间
output_drive = os.path.splitdrive(
os.path.abspath(self.config.output_dir))[0]
if not output_drive: # 处理Linux/Unix路径
output_drive = '/home'
disk_usage = psutil.disk_usage(output_drive)
free_space = disk_usage.free
# 计算所需空间输入大小的1.5倍)
required_space = input_size * 12
if free_space < required_space:
error_msg = (
f"磁盘空间不足!\n"
f"输入目录大小: {input_size / (1024**3):.2f} GB\n"
f"所需空间: {required_space / (1024**3):.2f} GB\n"
f"可用空间: {free_space / (1024**3):.2f} GB\n"
f"在驱动器 {output_drive}"
)
raise RuntimeError(error_msg)
def extract_gps(self) -> pd.DataFrame:
"""提取GPS数据"""
self.logger.info("开始提取GPS数据")
extractor = GPSExtractor(self.config.image_dir)
self.gps_points = extractor.extract_all_gps()
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
def cluster(self):
"""使用DBSCAN对GPS点进行聚类只保留最大的类"""
previous_points = self.gps_points.copy()
clusterer = GPSCluster(
self.gps_points,
eps=self.config.cluster_eps,
min_samples=self.config.cluster_min_samples
)
self.clustered_points = clusterer.fit()
self.gps_points = clusterer.get_cluster_stats(self.clustered_points)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "1-Clustering")
def filter_isolated_points(self):
"""过滤孤立点"""
filter = GPSFilter(self.config.output_dir)
previous_points = self.gps_points.copy()
self.gps_points = filter.filter_isolated_points(
self.gps_points,
self.config.filter_distance_threshold,
self.config.filter_min_neighbors,
)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "2-Isolated Points")
def filter_time_group_overlap(self):
"""过滤重叠的时间组"""
previous_points = self.gps_points.copy()
filter = TimeGroupOverlapFilter(
self.config.image_dir,
self.config.output_dir,
overlap_threshold=self.config.time_group_overlap_threshold
)
self.gps_points = filter.filter_overlapping_groups(
self.gps_points,
time_threshold=self.config.time_group_interval
)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "3-Time Group Overlap")
def filter_alternate_images(self):
"""按时间顺序隔一个删一个图像来降低密度"""
previous_points = self.gps_points.copy()
# 按时间戳排序
self.gps_points = self.gps_points.sort_values('date')
# 保留索引为偶数的行(即隔一个保留一个)
self.gps_points = self.gps_points.iloc[::2].reset_index(drop=True)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "4-Alternate Images")
self.logger.info(f"交替过滤后剩余 {len(self.gps_points)} 个点")
def divide_grids(self) -> Tuple[Dict[tuple, pd.DataFrame], Dict[tuple, tuple]]:
"""划分网格
Returns:
tuple: (grid_points, translations)
- grid_points: 网格点数据字典
- translations: 网格平移量字典
"""
grid_divider = GridDivider(
overlap=self.config.grid_overlap,
grid_size=self.config.grid_size,
output_dir=self.config.output_dir
)
grids, translations, grid_points = grid_divider.adjust_grid_size_and_overlap(
self.gps_points
)
grid_divider.visualize_grids(self.gps_points, grids)
return grid_points, translations
def copy_images(self, grid_points: Dict[tuple, pd.DataFrame]):
"""复制图像到目标文件夹"""
self.logger.info("开始复制图像文件")
for grid_id, points in grid_points.items():
output_dir = os.path.join(
self.config.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
"images"
)
os.makedirs(output_dir, exist_ok=True)
for point in points:
src = os.path.join(self.config.image_dir, point["file"])
dst = os.path.join(output_dir, point["file"])
shutil.copy(src, dst)
self.logger.info(
f"网格 ({grid_id[0]},{grid_id[1]}) 包含 {len(points)} 张图像")
def merge_tif(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool):
"""合并所有网格的影像产品"""
self.logger.info("开始合并所有影像产品")
merger = MergeTif(self.config.output_dir)
merger.merge_all_tifs(grid_points, produce_dem)
def merge_ply(self, grid_points: Dict[tuple, pd.DataFrame]):
"""合并所有网格的PLY点云"""
self.logger.info("开始合并PLY点云")
merger = MergePly(self.config.output_dir)
merger.merge_grid_laz(grid_points)
def merge_obj(self, grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]):
"""合并所有网格的OBJ模型并转换为OSGB格式"""
self.logger.info("开始合并OBJ模型")
merger = MergeObj(self.config.output_dir)
center_lon, center_lat, bounding_box = merger.merge_grid_obj(grid_points)
# 转换为OSGB格式
self.logger.info("开始转换为OSGB格式")
processor = ObjPostProcessor(self.config.output_dir)
if not processor.convert_to_osgb(center_lon, center_lat, bounding_box):
self.logger.error("OSGB转换失败")
def post_process(self, successful_grid_points: Dict[tuple, pd.DataFrame], grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]):
"""后处理:合并或复制处理结果"""
if len(successful_grid_points) < len(grid_points):
self.logger.warning(
f"{len(grid_points) - len(successful_grid_points)} 个网格处理失败,"
f"将只合并成功处理的 {len(successful_grid_points)} 个网格"
)
if self.config.mode == "快拼模式":
self.merge_tif(successful_grid_points, self.config.produce_dem)
elif self.config.mode == "三维模式":
# self.merge_ply(successful_grid_points)
self.merge_obj(successful_grid_points, translations)
else:
self.merge_tif(successful_grid_points, self.config.produce_dem)
# self.merge_ply(successful_grid_points)
self.merge_obj(successful_grid_points, translations)
def process(self):
"""执行完整的预处理流程"""
try:
self.extract_gps()
self.cluster()
self.filter_isolated_points()
# self.filter_time_group_overlap()
# self.filter_alternate_images()
grid_points, translations = self.divide_grids()
self.copy_images(grid_points)
self.logger.info("预处理任务完成")
successful_grid_points = self.odm_monitor.process_all_grids(
grid_points, self.config.produce_dem)
self.post_process(successful_grid_points,
grid_points, translations)
except Exception as e:
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
raise
if __name__ == "__main__":
# 创建配置
config = PreprocessConfig(
image_dir=r"E:\datasets\UAV\134\project\images",
output_dir=r"G:\ODM_output\134",
cluster_eps=0.01,
cluster_min_samples=5,
# 添加时间组重叠过滤参数
time_group_overlap_threshold=0.7,
time_group_interval=timedelta(minutes=5),
filter_distance_threshold=0.001,
filter_min_neighbors=6,
filter_grid_size=0.001,
filter_dense_distance_threshold=10,
filter_time_threshold=timedelta(minutes=5),
grid_size=800,
grid_overlap=0.05,
mode="重建模式",
produce_dem=False,
)
# 创建处理器并执行
processor = ImagePreprocessor(config)
processor.process()

View File

@ -1,341 +0,0 @@
import os
import shutil
from datetime import timedelta
from dataclasses import dataclass
from typing import Dict, Tuple
import psutil
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from filter.cluster_filter import GPSCluster
from filter.time_group_overlap_filter import TimeGroupOverlapFilter
from filter.gps_filter import GPSFilter
from utils.odm_monitor import ODMProcessMonitor
from utils.gps_extractor import GPSExtractor
from utils.grid_divider import GridDivider
from utils.logger import setup_logger
from utils.visualizer import FilterVisualizer
from post_pro.merge_tif import MergeTif
from post_pro.merge_obj import MergeObj
from post_pro.obj_post_pro import ObjPostProcessor
from post_pro.merge_laz import MergePly
@dataclass
class PreprocessConfig:
"""预处理配置类"""
image_dir: str
output_dir: str
# 聚类过滤参数
cluster_eps: float = 0.01
cluster_min_samples: int = 5
# 时间组重叠过滤参数
time_group_overlap_threshold: float = 0.7
time_group_interval: timedelta = timedelta(minutes=5)
# 孤立点过滤参数
filter_distance_threshold: float = 0.001 # 经纬度距离
filter_min_neighbors: int = 6
# 密集点过滤参数
filter_grid_size: float = 0.001
filter_dense_distance_threshold: float = 10 # 普通距离,单位:米
filter_time_threshold: timedelta = timedelta(minutes=5)
# 网格划分参数
grid_overlap: float = 0.05
grid_size: float = 500
# 几个pipline过程是否开启
mode: str = "快拼模式"
produce_dem: bool = False
class ImagePreprocessor:
def __init__(self, config: PreprocessConfig):
self.config = config
# 检查磁盘空间
self._check_disk_space()
# # 清理并重建输出目录
# if os.path.exists(config.output_dir):
# self._clean_output_dir()
# self._setup_output_dirs()
# 初始化其他组件
self.logger = setup_logger(config.output_dir)
self.gps_points = None
self.odm_monitor = ODMProcessMonitor(
config.output_dir, mode=config.mode)
self.visualizer = FilterVisualizer(config.output_dir)
def _clean_output_dir(self):
"""清理输出目录"""
try:
shutil.rmtree(self.config.output_dir)
print(f"已清理输出目录: {self.config.output_dir}")
except Exception as e:
print(f"清理输出目录时发生错误: {str(e)}")
raise
def _setup_output_dirs(self):
"""创建必要的输出目录结构"""
try:
# 创建主输出目录
os.makedirs(self.config.output_dir)
# 创建过滤图像保存目录
os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs'))
# 创建日志目录
os.makedirs(os.path.join(self.config.output_dir, 'logs'))
print(f"已创建输出目录结构: {self.config.output_dir}")
except Exception as e:
print(f"创建输出目录时发生错误: {str(e)}")
raise
def _get_directory_size(self, path):
"""获取目录的总大小(字节)"""
total_size = 0
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
file_path = os.path.join(dirpath, filename)
try:
total_size += os.path.getsize(file_path)
except (OSError, FileNotFoundError):
continue
return total_size
def _check_disk_space(self):
"""检查磁盘空间是否足够"""
# 获取输入目录大小
input_size = self._get_directory_size(self.config.image_dir)
# 获取输出目录所在磁盘的剩余空间
output_drive = os.path.splitdrive(
os.path.abspath(self.config.output_dir))[0]
if not output_drive: # 处理Linux/Unix路径
output_drive = '/home'
disk_usage = psutil.disk_usage(output_drive)
free_space = disk_usage.free
# 计算所需空间输入大小的1.5倍)
required_space = input_size * 12
if free_space < required_space:
error_msg = (
f"磁盘空间不足!\n"
f"输入目录大小: {input_size / (1024**3):.2f} GB\n"
f"所需空间: {required_space / (1024**3):.2f} GB\n"
f"可用空间: {free_space / (1024**3):.2f} GB\n"
f"在驱动器 {output_drive}"
)
raise RuntimeError(error_msg)
def extract_gps(self) -> pd.DataFrame:
"""提取GPS数据"""
self.logger.info("开始提取GPS数据")
extractor = GPSExtractor(self.config.image_dir)
self.gps_points = extractor.extract_all_gps()
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
def cluster(self):
"""使用DBSCAN对GPS点进行聚类只保留最大的类"""
previous_points = self.gps_points.copy()
clusterer = GPSCluster(
self.gps_points,
eps=self.config.cluster_eps,
min_samples=self.config.cluster_min_samples
)
self.clustered_points = clusterer.fit()
self.gps_points = clusterer.get_cluster_stats(self.clustered_points)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "1-Clustering")
def filter_isolated_points(self):
"""过滤孤立点"""
filter = GPSFilter(self.config.output_dir)
previous_points = self.gps_points.copy()
self.gps_points = filter.filter_isolated_points(
self.gps_points,
self.config.filter_distance_threshold,
self.config.filter_min_neighbors,
)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "2-Isolated Points")
def filter_time_group_overlap(self):
"""过滤重叠的时间组"""
previous_points = self.gps_points.copy()
filter = TimeGroupOverlapFilter(
self.config.image_dir,
self.config.output_dir,
overlap_threshold=self.config.time_group_overlap_threshold
)
self.gps_points = filter.filter_overlapping_groups(
self.gps_points,
time_threshold=self.config.time_group_interval
)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "3-Time Group Overlap")
def filter_alternate_images(self):
"""按时间顺序隔一个删一个图像来降低密度"""
previous_points = self.gps_points.copy()
# 按时间戳排序
self.gps_points = self.gps_points.sort_values('date')
# 保留索引为偶数的行(即隔一个保留一个)
self.gps_points = self.gps_points.iloc[::2].reset_index(drop=True)
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "4-Alternate Images")
self.logger.info(f"交替过滤后剩余 {len(self.gps_points)} 个点")
def divide_grids(self) -> Tuple[Dict[tuple, pd.DataFrame], Dict[tuple, tuple]]:
"""划分网格
Returns:
tuple: (grid_points, translations)
- grid_points: 网格点数据字典
- translations: 网格平移量字典
"""
grid_divider = GridDivider(
overlap=self.config.grid_overlap,
grid_size=self.config.grid_size,
output_dir=self.config.output_dir
)
grids, translations, grid_points = grid_divider.adjust_grid_size_and_overlap(
self.gps_points
)
grid_divider.visualize_grids(self.gps_points, grids)
return grid_points, translations
def copy_images(self, grid_points: Dict[tuple, pd.DataFrame]):
"""复制图像到目标文件夹"""
self.logger.info("开始复制图像文件")
for grid_id, points in grid_points.items():
output_dir = os.path.join(
self.config.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
"images"
)
os.makedirs(output_dir, exist_ok=True)
for point in points:
src = os.path.join(self.config.image_dir, point["file"])
dst = os.path.join(output_dir, point["file"])
shutil.copy(src, dst)
self.logger.info(
f"网格 ({grid_id[0]},{grid_id[1]}) 包含 {len(points)} 张图像")
def merge_tif(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool):
"""合并所有网格的影像产品"""
self.logger.info("开始合并所有影像产品")
merger = MergeTif(self.config.output_dir)
merger.merge_all_tifs(grid_points, produce_dem)
def merge_ply(self, grid_points: Dict[tuple, pd.DataFrame]):
"""合并所有网格的PLY点云"""
self.logger.info("开始合并PLY点云")
merger = MergePly(self.config.output_dir)
merger.merge_grid_laz(grid_points)
def merge_obj(self, grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]):
"""合并所有网格的OBJ模型并转换为OSGB格式"""
self.logger.info("开始合并OBJ模型")
merger = MergeObj(self.config.output_dir)
center_lon, center_lat, bounding_box = merger.merge_grid_obj(grid_points)
# 转换为OSGB格式
self.logger.info("开始转换为OSGB格式")
processor = ObjPostProcessor(self.config.output_dir)
if not processor.convert_to_osgb(center_lon, center_lat, bounding_box):
self.logger.error("OSGB转换失败")
def post_process(self, successful_grid_points: Dict[tuple, pd.DataFrame], grid_points: Dict[tuple, pd.DataFrame], translations: Dict[tuple, tuple]):
"""后处理:合并或复制处理结果"""
if len(successful_grid_points) < len(grid_points):
self.logger.warning(
f"{len(grid_points) - len(successful_grid_points)} 个网格处理失败,"
f"将只合并成功处理的 {len(successful_grid_points)} 个网格"
)
if self.config.mode == "快拼模式":
self.merge_tif(successful_grid_points, self.config.produce_dem)
elif self.config.mode == "三维模式":
# self.merge_ply(successful_grid_points)
self.merge_obj(successful_grid_points, translations)
else:
self.merge_tif(successful_grid_points, self.config.produce_dem)
# self.merge_ply(successful_grid_points)
self.merge_obj(successful_grid_points, translations)
def process(self):
"""执行完整的预处理流程"""
try:
self.extract_gps()
self.cluster()
self.filter_isolated_points()
# self.filter_time_group_overlap()
# self.filter_alternate_images()
grid_points, translations = self.divide_grids()
# self.copy_images(grid_points)
self.logger.info("预处理任务完成")
# successful_grid_points = self.odm_monitor.process_all_grids(
# grid_points, self.config.produce_dem)
successful_grid_points = grid_points
self.post_process(successful_grid_points,
grid_points, translations)
except Exception as e:
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
raise
if __name__ == "__main__":
# 创建配置
config = PreprocessConfig(
image_dir=r"E:\datasets\UAV\134\project\images",
output_dir=r"G:\ODM_output\134",
cluster_eps=0.01,
cluster_min_samples=5,
# 添加时间组重叠过滤参数
time_group_overlap_threshold=0.7,
time_group_interval=timedelta(minutes=5),
filter_distance_threshold=0.001,
filter_min_neighbors=6,
filter_grid_size=0.001,
filter_dense_distance_threshold=10,
filter_time_threshold=timedelta(minutes=5),
grid_size=800,
grid_overlap=0.05,
mode="重建模式",
produce_dem=False,
)
# 创建处理器并执行
processor = ImagePreprocessor(config)
processor.process()

266
post_pro/conv_obj.py Normal file
View File

@ -0,0 +1,266 @@
import os
import subprocess
import json
import shutil
import logging
from pyproj import Transformer
import cv2
class ConvertOBJ:
def __init__(self, output_dir: str):
self.output_dir = output_dir
# 用于存储所有grid的UTM范围
self.ref_east = float('inf')
self.ref_north = float('inf')
# 初始化UTM到WGS84的转换器
self.transformer = Transformer.from_crs(
"EPSG:32649", "EPSG:4326", always_xy=True)
self.logger = logging.getLogger('UAV_Preprocess.ConvertOBJ')
def convert_grid_obj(self, grid_lt):
"""转换每个网格的OBJ文件为OSGB格式"""
os.makedirs(os.path.join(self.output_dir,
"osgb", "Data"), exist_ok=True)
# 以第一个grid的UTM坐标作为参照系
first_grid_id = grid_lt[0]
first_grid_dir = os.path.join(
self.output_dir,
f"grid_{first_grid_id[0]}_{first_grid_id[1]}",
"project"
)
log_file = os.path.join(
first_grid_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
self.ref_east, self.ref_north = self.read_utm_offset(log_file)
for grid_id in grid_lt:
try:
self._convert_single_grid(grid_id)
except Exception as e:
self.logger.error(f"网格 {grid_id} 转换失败: {str(e)}")
self._create_merged_metadata()
def _convert_single_grid(self, grid_id):
"""转换单个网格的OBJ文件"""
# 构建相关路径
grid_name = f"grid_{grid_id[0]}_{grid_id[1]}"
project_dir = os.path.join(self.output_dir, grid_name, "project")
texturing_dir = os.path.join(project_dir, "odm_texturing")
texturing_dst_dir = os.path.join(project_dir, "odm_texturing_dst")
split_obj_dir = os.path.join(texturing_dst_dir, "split_obj")
log_file = os.path.join(
project_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
os.makedirs(texturing_dst_dir, exist_ok=True)
# 修改obj文件z坐标的值
min_25d_z = self.get_min_z_from_obj(os.path.join(
project_dir, 'odm_texturing_25d', 'odm_textured_model_geo.obj'))
self.modify_z_in_obj(texturing_dir, min_25d_z)
# 在新文件夹下利用UTM偏移量修改obj文件顶点坐标纹理文件下采样
utm_offset = self.read_utm_offset(log_file)
modified_obj = self.modify_obj_coordinates(
texturing_dir, texturing_dst_dir, utm_offset)
self.downsample_texture(texturing_dir, texturing_dst_dir)
# 将obj文件进行切片
self.logger.info(f"开始切片网格 {grid_id} 的OBJ文件")
os.makedirs(split_obj_dir)
cmd = (
f"D:\\software\\Obj2Tiles\\Obj2Tiles.exe --stage Splitting --lods 1 --divisions 3 "
f"{modified_obj} {split_obj_dir}"
)
subprocess.run(cmd, check=True)
# 执行格式转换Linux下osgconv有问题记得注释掉
self.logger.info(f"开始转换网格 {grid_id} 的OBJ文件")
# 先获取split_obj_dir下的所有obj文件
obj_lod_dir = os.path.join(split_obj_dir, "LOD-0")
obj_files = [f for f in os.listdir(
obj_lod_dir) if f.endswith('.obj')]
for obj_file in obj_files:
obj_path = os.path.join(obj_lod_dir, obj_file)
osgb_file = os.path.splitext(obj_file)[0] + '.osgb'
osgb_path = os.path.join(split_obj_dir, osgb_file)
# 执行 osgconv 命令
subprocess.run(['osgconv', obj_path, osgb_path, '--compressed',
'--smooth', '--fix-transparency'], check=True)
# 创建OSGB目录结构复制文件
osgb_base_dir = os.path.join(self.output_dir, "osgb")
data_dir = os.path.join(osgb_base_dir, "Data")
for obj_file in obj_files:
obj_file_name = os.path.splitext(obj_file)[0]
tile_dirs = os.path.join(
data_dir, f"grid_{grid_id[0]}_{grid_id[1]}_{obj_file_name}")
os.makedirs(tile_dirs, exist_ok=True)
shutil.copy2(os.path.join(
split_obj_dir, obj_file_name+".osgb"), tile_dirs)
os.rename(os.path.join(tile_dirs, obj_file_name+".osgb"),
os.path.join(tile_dirs, f"grid_{grid_id[0]}_{grid_id[1]}_{obj_file_name}.osgb"))
def _create_merged_metadata(self):
"""创建合并后的metadata.xml文件"""
# 转换为WGS84经纬度
center_lon, center_lat = self.transformer.transform(
self.ref_east, self.ref_north)
metadata_content = f"""<?xml version="1.0" encoding="utf-8"?>
<ModelMetadata version="1">
<SRS>EPSG:4326</SRS>
<SRSOrigin>{center_lon},{center_lat},0</SRSOrigin>
<Texture>
<ColorSource>Visible</ColorSource>
</Texture>
</ModelMetadata>"""
metadata_file = os.path.join(self.output_dir, "osgb", "metadata.xml")
with open(metadata_file, 'w', encoding='utf-8') as f:
f.write(metadata_content)
def read_utm_offset(self, log_file: str) -> tuple:
"""读取UTM偏移量"""
try:
east_offset = None
north_offset = None
with open(log_file, 'r') as f:
lines = f.readlines()
for i, line in enumerate(lines):
if 'utm_north_offset' in line and i + 1 < len(lines):
north_offset = float(lines[i + 1].strip())
elif 'utm_east_offset' in line and i + 1 < len(lines):
east_offset = float(lines[i + 1].strip())
if east_offset is None or north_offset is None:
raise ValueError("未找到UTM偏移量")
return east_offset, north_offset
except Exception as e:
self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}")
raise
def modify_obj_coordinates(self, texturing_dir: str, texturing_dst_dir: str, utm_offset: tuple) -> str:
"""修改obj文件中的顶点坐标使用相对坐标系"""
obj_file = os.path.join(
texturing_dir, "odm_textured_model_modified.obj")
obj_dst_file = os.path.join(
texturing_dst_dir, "odm_textured_model_geo_utm.obj")
if not os.path.exists(obj_file):
raise FileNotFoundError(f"找不到OBJ文件: {obj_file}")
shutil.copy2(os.path.join(texturing_dir, "odm_textured_model_geo.mtl"),
os.path.join(texturing_dst_dir, "odm_textured_model_geo.mtl"))
east_offset, north_offset = utm_offset
self.logger.info(
f"UTM坐标偏移{east_offset - self.ref_east}, {north_offset - self.ref_north}")
try:
with open(obj_file, 'r') as f_in, open(obj_dst_file, 'w') as f_out:
for line in f_in:
if line.startswith('v '):
# 处理顶点坐标行
parts = line.strip().split()
# 使用相对于整体最小UTM坐标的偏移
x = float(parts[1]) + (east_offset - self.ref_east)
y = float(parts[2]) + (north_offset - self.ref_north)
z = float(parts[3])
f_out.write(f'v {x:.6f} {z:.6f} {-y:.6f}\n')
elif line.startswith('vn '): # 处理法线向量
parts = line.split()
nx = float(parts[1])
ny = float(parts[2])
nz = float(parts[3])
# 同步反转法线的 Y 轴
new_line = f"vn {nx} {nz} {-ny}\n"
f_out.write(new_line)
else:
# 其他行直接写入
f_out.write(line)
return obj_dst_file
except Exception as e:
self.logger.error(f"修改obj坐标时发生错误: {str(e)}")
raise
def downsample_texture(self, src_dir: str, dst_dir: str):
"""复制并重命名纹理文件对大于100MB的文件进行多次下采样直到文件小于100MB
Args:
src_dir: 源纹理目录
dst_dir: 目标纹理目录
"""
for file in os.listdir(src_dir):
if file.lower().endswith(('.png')):
src_path = os.path.join(src_dir, file)
dst_path = os.path.join(dst_dir, file)
# 检查文件大小(以字节为单位)
file_size = os.path.getsize(src_path)
if file_size <= 100 * 1024 * 1024: # 如果文件小于等于100MB直接复制
shutil.copy2(src_path, dst_path)
else:
# 文件大于100MB进行下采样
img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED)
if_first_ds = True
while file_size > 100 * 1024 * 1024: # 大于100MB
self.logger.info(f"纹理文件 {file} 大于100MB进行下采样")
if if_first_ds:
# 计算新的尺寸长宽各变为1/4
new_size = (img.shape[1] // 4,
img.shape[0] // 4) # 逐步减小尺寸
# 使用双三次插值进行下采样
resized_img = cv2.resize(
img, new_size, interpolation=cv2.INTER_CUBIC)
if_first_ds = False
else:
# 计算新的尺寸长宽各变为1/2
new_size = (img.shape[1] // 2,
img.shape[0] // 2) # 逐步减小尺寸
# 使用双三次插值进行下采样
resized_img = cv2.resize(
img, new_size, interpolation=cv2.INTER_CUBIC)
# 更新文件路径为下采样后的路径
cv2.imwrite(dst_path, resized_img, [
cv2.IMWRITE_PNG_COMPRESSION, 9])
# 更新文件大小和图像
file_size = os.path.getsize(dst_path)
img = cv2.imread(dst_path, cv2.IMREAD_UNCHANGED)
self.logger.info(
f"下采样后文件大小: {file_size / (1024 * 1024):.2f} MB")
def get_min_z_from_obj(self, file_path):
min_z = float('inf') # 初始值设为无穷大
with open(file_path, 'r') as obj_file:
for line in obj_file:
# 检查每一行是否是顶点定义(以 'v ' 开头)
if line.startswith('v '):
# 获取顶点坐标
parts = line.split()
# 将z值转换为浮动数字
z = float(parts[3])
# 更新最小z值
if z < min_z:
min_z = z
return min_z
def modify_z_in_obj(self, texturing_dir, min_25d_z):
obj_file = os.path.join(texturing_dir, 'odm_textured_model_geo.obj')
output_file = os.path.join(
texturing_dir, 'odm_textured_model_modified.obj')
with open(obj_file, 'r') as f_in, open(output_file, 'w') as f_out:
for line in f_in:
if line.startswith('v '): # 顶点坐标行
parts = line.strip().split()
x = float(parts[1])
y = float(parts[2])
z = float(parts[3])
if z < min_25d_z:
z = min_25d_z
f_out.write(f"v {x} {y} {z}\n")
else:
f_out.write(line)

View File

@ -1,62 +0,0 @@
import os
import logging
import numpy as np
from typing import Dict, Tuple
import pandas as pd
import subprocess
import shutil
class MergePly:
def __init__(self, output_dir: str):
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.MergePly')
def merge_grid_laz(self, grid_points: Dict[tuple, pd.DataFrame]):
"""合并所有网格的点云数据"""
try:
# 获取所有点云文件路径
laz_files = []
for grid_id, points in grid_points.items():
laz_path = os.path.join(
self.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
"odm_georeferencing",
"odm_georeferenced_model.laz"
)
if os.path.exists(laz_path):
laz_files.append(laz_path)
else:
self.logger.warning(
f"网格 ({grid_id[0]},{grid_id[1]}) 的点云文件不存在")
kwargs = {
'all_inputs': " ".join(laz_files),
'output': os.path.join(self.output_dir, 'pointcloud.laz')
}
subprocess.run(
'D:\\software\\LAStools\\bin\\lasmerge64.exe -i {all_inputs} -o "{output}"'.format(**kwargs))
except Exception as e:
self.logger.error(f"PLY点云合并过程中发生错误: {str(e)}", exc_info=True)
raise
if __name__ == "__main__":
from utils.logger import setup_logger
# 设置输出目录和日志
output_dir = r"G:\ODM_output\1009"
setup_logger(output_dir)
# 构造测试用的grid_points字典
grid_points = {
(0, 0): [], # 不再需要GPS点信息
(0, 1): []
}
# 创建MergePly实例并执行合并
merge_ply = MergePly(output_dir)
merge_ply.merge_grid_laz(grid_points)

View File

@ -1,463 +0,0 @@
import os
import logging
import pandas as pd
from typing import Dict, List, Tuple
import numpy as np
import shutil
import time
import cv2
import subprocess
from pyproj import Transformer
class MergeObj:
def __init__(self, output_dir: str):
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.MergeObj')
# 用于存储所有grid的UTM范围
self.min_east = float('inf')
self.min_north = float('inf')
self.max_east = float('-inf')
self.max_north = float('-inf')
# 初始化UTM到WGS84的转换器
self.transformer = Transformer.from_crs(
"EPSG:32649", "EPSG:4326", always_xy=True)
def read_obj(self, file_path):
"""读取.obj文件返回顶点、纹理坐标、法线、面的列表和MTL文件名"""
vertices = [] # v
tex_coords = [] # vt
normals = [] # vn
faces = [] # f
face_materials = [] # 每个面对应的材质名称
mtl_file = None # mtl文件名
current_material = None # 当前使用的材质
with open(file_path, 'r') as file:
for line in file:
if line.startswith('#') or not line.strip():
continue
parts = line.strip().split()
if not parts:
continue
if parts[0] == 'mtllib': # MTL文件引用
mtl_file = parts[1]
elif parts[0] == 'usemtl': # 材质使用
current_material = parts[1]
elif parts[0] == 'v': # 顶点
vertices.append(
[float(parts[1]), float(parts[2]), float(parts[3])])
elif parts[0] == 'vt': # 纹理坐标
tex_coords.append([float(parts[1]), float(parts[2])])
elif parts[0] == 'vn': # 法线
normals.append(
[float(parts[1]), float(parts[2]), float(parts[3])])
elif parts[0] == 'f': # 面
# 处理面的索引 (v/vt/vn)
face_v = []
face_vt = []
face_vn = []
for p in parts[1:]:
indices = p.split('/')
face_v.append(int(indices[0]))
if len(indices) > 1 and indices[1]:
face_vt.append(int(indices[1]))
if len(indices) > 2:
face_vn.append(int(indices[2]))
faces.append((face_v, face_vt, face_vn))
face_materials.append(current_material) # 记录这个面使用的材质
return vertices, tex_coords, normals, faces, face_materials, mtl_file
def write_obj(self, file_path, vertices, tex_coords, normals, faces, face_materials, mtl_file=None):
"""将顶点、纹理坐标、法线和面写入到.obj文件"""
with open(file_path, 'w') as file:
# 写入MTL文件引用
if mtl_file:
file.write(f"mtllib {mtl_file}\n")
# 写入顶点
for v in vertices:
file.write(f"v {v[0]} {v[1]} {v[2]}\n")
# 写入纹理坐标
for vt in tex_coords:
file.write(f"vt {vt[0]} {vt[1]}\n")
# 写入法线
for vn in normals:
file.write(f"vn {vn[0]} {vn[1]} {vn[2]}\n")
# 写入面(按材质分组)
current_material = None
for face, material in zip(faces, face_materials):
# 如果材质发生变化写入新的usemtl
if material != current_material:
file.write(f"usemtl {material}\n")
current_material = material
face_str = "f"
for i in range(len(face[0])):
face_str += " "
face_str += str(face[0][i])
if face[1]:
face_str += f"/{face[1][i]}"
else:
face_str += "/"
if face[2]:
face_str += f"/{face[2][i]}"
else:
face_str += "/"
file.write(face_str + "\n")
def merge_grid_obj(self, grid_points: Dict[tuple, pd.DataFrame]) -> Tuple[float, float]:
"""合并所有网格的OBJ模型
Args:
grid_points: 网格点数据字典
Returns:
Tuple[float, float]: (longitude, latitude)中心点经纬度坐标
"""
try:
# 创建输出目录
output_model_dir = os.path.join(self.output_dir, "texturing")
os.makedirs(output_model_dir, exist_ok=True)
# 初始化全局边界框坐标
global_min_lon = float('inf')
global_min_lat = float('inf')
global_max_lon = float('-inf')
global_max_lat = float('-inf')
# 第一次遍历获取所有grid的UTM范围
for grid_id, points in grid_points.items():
base_dir = os.path.join(
self.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project"
)
log_file = os.path.join(
base_dir, "odm_orthophoto", "odm_orthophoto_log.txt")
east_offset, north_offset = self.read_utm_offset(log_file)
# 更新UTM范围
self.min_east = min(self.min_east, east_offset)
self.min_north = min(self.min_north, north_offset)
self.max_east = max(self.max_east, east_offset)
self.max_north = max(self.max_north, north_offset)
# 收集所有grid的数据
all_vertices = [] # 所有顶点
all_tex_coords = [] # 所有纹理坐标
all_normals = [] # 所有法线
all_faces = [] # 所有面
all_face_materials = [] # 所有面的材质
all_materials = {} # 所有材质信息
grid_centers = [] # 所有grid的中心点
# 处理每个grid
for grid_id, points in grid_points.items():
base_dir = os.path.join(
self.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
"odm_texturing"
)
obj_path = os.path.join(base_dir, "odm_textured_model_geo.obj")
mtl_path = os.path.join(base_dir, "odm_textured_model_geo.mtl")
if not os.path.exists(obj_path) or not os.path.exists(mtl_path):
self.logger.warning(
f"网格 ({grid_id[0]},{grid_id[1]}) 的文件不存在")
continue
# 读取UTM偏移量并修改obj文件的顶点坐标
log_file = os.path.join(
base_dir, "..", "odm_orthophoto", "odm_orthophoto_log.txt")
utm_offset = self.read_utm_offset(log_file)
modified_obj = self.modify_obj_coordinates(
obj_path, utm_offset)
# 读取obj文件内容
vertices, tex_coords, normals, faces, face_materials, _ = self.read_obj(
modified_obj)
# 计算当前grid的中心点
grid_center_lon, grid_center_lat, grid_bounding_box = self.get_center_coordinates(
vertices)
grid_centers.append((grid_center_lon, grid_center_lat))
self.logger.info(
f"网格 ({grid_id[0]},{grid_id[1]}) 中心点经纬度: ({grid_center_lon}, {grid_center_lat})")
# 更新全局边界框坐标
global_min_lon = min(
global_min_lon, grid_bounding_box['LB_lon'])
global_min_lat = min(
global_min_lat, grid_bounding_box['LB_lat'])
global_max_lon = max(
global_max_lon, grid_bounding_box['RU_lon'])
global_max_lat = max(
global_max_lat, grid_bounding_box['RU_lat'])
# 复制并重命名纹理文件
texture_map = self.copy_and_rename_texture(
base_dir,
output_model_dir,
grid_id
)
# 读取并更新材质内容
materials = self.read_mtl(mtl_path)
updated_materials = self.update_mtl_content(
materials,
texture_map,
grid_id
)
all_materials.update(updated_materials)
# 计算顶点偏移量
v_offset = len(all_vertices)
vt_offset = len(all_tex_coords)
vn_offset = len(all_normals)
# 添加顶点、纹理坐标和法线
all_vertices.extend(vertices)
all_tex_coords.extend(tex_coords)
all_normals.extend(normals)
# 添加面和材质
for face, material in zip(faces, face_materials):
# 调整面的索引
new_face_v = [f + v_offset for f in face[0]]
new_face_vt = [
f + vt_offset for f in face[1]] if face[1] else []
new_face_vn = [
f + vn_offset for f in face[2]] if face[2] else []
all_faces.append((new_face_v, new_face_vt, new_face_vn))
# 添加材质前缀
if material:
all_face_materials.append(
f"material_{grid_id[0]}_{grid_id[1]}_{material}")
else:
all_face_materials.append(material)
if not all_vertices:
self.logger.error("没有找到有效的文件")
return
# 写入合并后的MTL文件
final_mtl = os.path.join(output_model_dir, "textured_model.mtl")
with open(final_mtl, 'w') as f:
for mat_name, content in all_materials.items():
f.write(f"newmtl {mat_name}\n")
for line in content:
f.write(f"{line}\n")
f.write("\n")
# 写入合并后的OBJ文件
final_obj = os.path.join(output_model_dir, "textured_model.obj")
self.write_obj(final_obj, all_vertices, all_tex_coords, all_normals,
all_faces, all_face_materials, "textured_model.mtl")
# 计算整体中心点
center_lon = sum(center[0]
for center in grid_centers) / len(grid_centers)
center_lat = sum(center[1]
for center in grid_centers) / len(grid_centers)
self.logger.info(f"模型整体中心点经纬度: ({center_lon}, {center_lat})")
# 计算整个区域的边界框
bounding_box = [global_min_lon, global_min_lat, global_max_lon, global_max_lat]
self.logger.info(
f"模型整体边界框: ({bounding_box[0]}, {bounding_box[1]}) - ({bounding_box[2]}, {bounding_box[3]})")
return center_lon, center_lat, bounding_box
except Exception as e:
self.logger.error(f"合并过程中发生错误: {str(e)}", exc_info=True)
raise
def get_center_coordinates(self, vertices: List[List[float]]) -> Tuple[float, float, Dict[str, float]]:
"""计算顶点的中心点UTM坐标并转换为WGS84经纬度。
注意顶点坐标是相对于整体最小UTM坐标的偏移值需要加回最小UTM坐标
Args:
vertices: 顶点列表每个顶点是[x, y, z]格式x和y是相对于最小UTM坐标的偏移
Returns:
Tuple[float, float, Dict[str, float]]: (longitude, latitude, bounding_box)
"""
# 计算相对坐标的边界框
x_coords = [v[0] for v in vertices]
y_coords = [v[1] for v in vertices]
# 计算中心点相对坐标
center_x_relative = (min(x_coords) + max(x_coords)) / 2
center_y_relative = (min(y_coords) + max(y_coords)) / 2
# 加回最小UTM坐标得到实际的UTM坐标
center_x_utm = center_x_relative + self.min_east
center_y_utm = center_y_relative + self.min_north
# 转换为WGS84经纬度
lon, lat = self.transformer.transform(center_x_utm, center_y_utm)
# 计算边界框并转换为经纬度
bounding_box = {
'LB_lon': self.transformer.transform(min(x_coords) + self.min_east, min(y_coords) + self.min_north)[0],
'LB_lat': self.transformer.transform(min(x_coords) + self.min_east, min(y_coords) + self.min_north)[1],
'RU_lon': self.transformer.transform(max(x_coords) + self.min_east, max(y_coords) + self.min_north)[0],
'RU_lat': self.transformer.transform(max(x_coords) + self.min_east, max(y_coords) + self.min_north)[1]
}
self.logger.info(f"模型UTM中心点: ({center_x_utm}, {center_y_utm})")
return lon, lat, bounding_box
def read_mtl(self, mtl_path: str) -> dict:
"""读取MTL文件内容
Returns:
dict: 材质名称到材质信息的映射
"""
materials = {}
current_material = None
with open(mtl_path, 'r') as f:
content = f.read()
for line in content.strip().split('\n'):
if not line:
continue
parts = line.split()
if not parts:
continue
if parts[0] == 'newmtl':
current_material = parts[1]
materials[current_material] = []
elif current_material:
materials[current_material].append(line)
return materials
def copy_and_rename_texture(self, src_dir: str, dst_dir: str, grid_id: tuple) -> dict:
"""复制并重命名纹理文件对大于100MB的文件进行下采样
Args:
src_dir: 源纹理目录
dst_dir: 目标纹理目录
grid_id: 网格ID
Returns:
dict: 原始文件名到新文件名的映射
"""
texture_map = {}
os.makedirs(dst_dir, exist_ok=True)
for file in os.listdir(src_dir):
if file.lower().endswith(('.png', '.jpg', '.jpeg')):
# 生成新的文件名
new_name = f"grid_{grid_id[0]}_{grid_id[1]}_{file}"
src_path = os.path.join(src_dir, file)
dst_path = os.path.join(dst_dir, new_name)
# 检查文件大小(以字节为单位)
file_size = os.path.getsize(src_path)
if file_size > 100 * 1024 * 1024: # 大于100MB
self.logger.info(f"纹理文件 {file} 大于100MB进行4倍下采样")
# 读取图像
img = cv2.imread(src_path, cv2.IMREAD_UNCHANGED)
if img is not None:
# 计算新的尺寸长宽各变为1/4
new_size = (img.shape[1] // 4, img.shape[0] // 4)
# 使用双三次插值进行下采样
resized_img = cv2.resize(
img, new_size, interpolation=cv2.INTER_CUBIC)
# 保存压缩后的图像
if file.lower().endswith('.png'):
cv2.imwrite(dst_path, resized_img, [
cv2.IMWRITE_PNG_COMPRESSION, 9])
else:
cv2.imwrite(dst_path, resized_img, [
cv2.IMWRITE_JPEG_QUALITY, 95])
else:
self.logger.warning(f"无法读取图像文件: {src_path}")
shutil.copy2(src_path, dst_path)
else:
# 文件大小未超过100MB直接复制
shutil.copy2(src_path, dst_path)
texture_map[file] = new_name
self.logger.debug(f"处理纹理文件: {file} -> {new_name}")
return texture_map
def update_mtl_content(self, materials: dict, texture_map: dict, grid_id: tuple) -> dict:
"""更新材质内容,修改材质名称和纹理路径
Args:
materials: 原始材质信息
texture_map: 纹理文件映射
grid_id: 网格ID
Returns:
dict: 更新后的材质信息
"""
updated_materials = {}
for mat_name, content in materials.items():
# 为材质名称添加网格ID前缀与OBJ文件中的usemtl保持一致
new_mat_name = f"material_{grid_id[0]}_{grid_id[1]}_{mat_name}"
updated_content = []
for line in content:
if line.startswith('map_'): # 更新纹理文件路径
parts = line.split()
old_texture = parts[-1]
if old_texture in texture_map:
parts[-1] = texture_map[old_texture]
line = ' '.join(parts)
updated_content.append(line)
updated_materials[new_mat_name] = updated_content
return updated_materials
def read_utm_offset(self, log_file: str) -> tuple:
"""读取UTM偏移量"""
try:
east_offset = None
north_offset = None
with open(log_file, 'r') as f:
lines = f.readlines()
for i, line in enumerate(lines):
if 'utm_north_offset' in line and i + 1 < len(lines):
north_offset = float(lines[i + 1].strip())
elif 'utm_east_offset' in line and i + 1 < len(lines):
east_offset = float(lines[i + 1].strip())
if east_offset is None or north_offset is None:
raise ValueError("未找到UTM偏移量")
return east_offset, north_offset
except Exception as e:
self.logger.error(f"读取UTM偏移量时发生错误: {str(e)}")
raise
def modify_obj_coordinates(self, obj_file: str, utm_offset: tuple) -> str:
"""修改obj文件中的顶点坐标使用相对坐标系"""
east_offset, north_offset = utm_offset
output_obj = obj_file.replace('.obj', '_utm.obj')
try:
with open(obj_file, 'r') as f_in, open(output_obj, 'w') as f_out:
for line in f_in:
if line.startswith('v '):
# 处理顶点坐标行
parts = line.strip().split()
# 使用相对于整体最小UTM坐标的偏移
x = float(parts[1]) + (east_offset - self.min_east)
y = float(parts[2]) + (north_offset - self.min_north)
z = float(parts[3])
f_out.write(f'v {x:.6f} {y:.6f} {z:.6f}\n')
else:
# 其他行直接写入
f_out.write(line)
return output_obj
except Exception as e:
self.logger.error(f"修改obj坐标时发生错误: {str(e)}")
raise

View File

@ -5,6 +5,13 @@ from typing import Dict
import pandas as pd
import time
import shutil
import rasterio
from rasterio.mask import mask
from rasterio.transform import Affine, rowcol
import fiona
from edt import edt
import numpy as np
import math
class MergeTif:
@ -12,251 +19,271 @@ class MergeTif:
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.MergeTif')
def merge_two_tifs(self, input_tif1: str, input_tif2: str, output_tif: str):
"""合并两张TIF影像"""
def merge_orthophoto(self, grid_lt):
"""合并网格的正射影像"""
try:
self.logger.info("开始合并TIF影像")
self.logger.info(f"输入影像1: {input_tif1}")
self.logger.info(f"输入影像2: {input_tif2}")
self.logger.info(f"输出影像: {output_tif}")
# 检查输入文件是否存在
if not os.path.exists(input_tif1) or not os.path.exists(input_tif2):
error_msg = "输入影像文件不存在"
self.logger.error(error_msg)
raise FileNotFoundError(error_msg)
# 打开影像,检查投影是否一致
datasets = []
try:
for tif in [input_tif1, input_tif2]:
ds = gdal.Open(tif)
if ds is None:
error_msg = f"无法打开影像文件: {tif}"
self.logger.error(error_msg)
raise ValueError(error_msg)
datasets.append(ds)
projections = [ds.GetProjection() for ds in datasets]
self.logger.debug(f"影像1投影: {projections[0]}")
self.logger.debug(f"影像2投影: {projections[1]}")
# 检查投影是否一致
if len(set(projections)) != 1:
error_msg = "影像的投影不一致,请先进行重投影!"
self.logger.error(error_msg)
raise ValueError(error_msg)
# 如果输出文件已存在,先删除
if os.path.exists(output_tif):
try:
os.remove(output_tif)
except Exception as e:
self.logger.warning(f"删除已存在的输出文件失败: {str(e)}")
# 生成一个新的输出文件名
base, ext = os.path.splitext(output_tif)
output_tif = f"{base}_{int(time.time())}{ext}"
self.logger.info(f"使用新的输出文件名: {output_tif}")
# 创建 GDAL Warp 选项
warp_options = gdal.WarpOptions(
format="GTiff",
resampleAlg="average",
srcNodata=0,
dstNodata=0,
multithread=True
)
self.logger.info("开始执行影像拼接...")
result = gdal.Warp(output_tif, datasets, options=warp_options)
if result is None:
error_msg = "影像拼接失败"
self.logger.error(error_msg)
raise RuntimeError(error_msg)
# 获取输出影像的基本信息
output_dataset = gdal.Open(output_tif)
if output_dataset:
width = output_dataset.RasterXSize
height = output_dataset.RasterYSize
bands = output_dataset.RasterCount
self.logger.info(
f"拼接完成,输出影像大小: {width}x{height},波段数: {bands}")
output_dataset = None # 显式关闭数据集
self.logger.info(f"影像拼接成功,输出文件保存至: {output_tif}")
finally:
# 确保所有数据集都被正确关闭
for ds in datasets:
if ds:
ds = None
result = None
except Exception as e:
self.logger.error(f"影像拼接过程中发生错误: {str(e)}", exc_info=True)
raise
def merge_grid_tif(self, grid_points: Dict[tuple, pd.DataFrame], product_info: dict):
"""合并指定产品的所有网格"""
product_name = product_info['name']
product_path = product_info['path']
filename_original = product_info['filename']
filename = filename_original.replace(".original", "")
self.logger.info(f"开始合并{product_name}")
input_tif1, input_tif2 = None, None
merge_count = 0
temp_files = []
try:
for grid_id, points in grid_points.items():
grid_tif_original = os.path.join(
all_orthos_and_ortho_cuts = []
for grid_id in grid_lt:
grid_ortho_dir = os.path.join(
self.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
product_path,
filename_original
"odm_orthophoto",
)
grid_tif = os.path.join(
self.output_dir,
f"grid_{grid_id[0]}_{grid_id[1]}",
"project",
product_path,
filename
)
if os.path.exists(grid_tif_original) and os.path.exists(grid_tif):
self.logger.info(
f"网格 ({grid_id[0]},{grid_id[1]}) 的{product_name}存在: {grid_tif_original, grid_tif}")
# 如果文件大于600MB则不使用original文件
file_size_mb_original = os.path.getsize(
grid_tif_original) / (1024 * 1024) # 转换为MB
if file_size_mb_original > 600:
to_merge_tif = grid_tif
else:
to_merge_tif = grid_tif_original
elif os.path.exists(grid_tif_original) and not os.path.exists(grid_tif):
to_merge_tif = grid_tif_original
elif not os.path.exists(grid_tif_original) and os.path.exists(grid_tif):
to_merge_tif = grid_tif
else:
self.logger.warning(
f"网格 ({grid_id[0]},{grid_id[1]}) 的{product_name}不存在: {grid_tif_original, grid_tif}")
continue
tif_path = os.path.join(grid_ortho_dir, "odm_orthophoto.tif")
tif_mask = os.path.join(grid_ortho_dir, "cutline.gpkg")
output_cut_tif = os.path.join(
grid_ortho_dir, "odm_orthophoto_cut.tif")
output_feathered_tif = os.path.join(
grid_ortho_dir, "odm_orthophoto_feathered.tif")
if input_tif1 is None:
input_tif1 = to_merge_tif
self.logger.info(f"设置第一个输入{product_name}: {input_tif1}")
else:
input_tif2 = to_merge_tif
# 生成带时间戳的临时输出文件名
temp_output = os.path.join(
self.output_dir,
f"temp_merged_{int(time.time())}_{product_info['output']}"
)
self.compute_mask_raster(
tif_path, tif_mask, output_cut_tif, blend_distance=20)
self.feather_raster(
tif_path, output_feathered_tif, blend_distance=20)
all_orthos_and_ortho_cuts.append(
[output_feathered_tif, output_cut_tif])
self.logger.info(
f"开始合并{product_name}{merge_count + 1} 次:\n"
f"输入1: {input_tif1}\n"
f"输入2: {input_tif2}\n"
f"输出: {temp_output}"
)
self.merge_two_tifs(input_tif1, input_tif2, temp_output)
merge_count += 1
input_tif1 = temp_output
input_tif2 = None
temp_files.append(temp_output)
final_output = os.path.join(
self.output_dir, product_info['output'])
shutil.copy2(input_tif1, final_output)
# 清理所有临时文件
for temp_file in temp_files:
try:
os.remove(temp_file)
except Exception as e:
self.logger.warning(f"删除临时文件失败: {str(e)}")
self.logger.info(
f"{product_name}合并完成,共执行 {merge_count} 次合并,"
f"最终输出文件: {final_output}"
)
except Exception as e:
self.logger.error(
f"{product_name}合并过程中发生错误: {str(e)}", exc_info=True)
raise
def merge_all_tifs(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool):
"""合并所有产品正射影像、DSM和DTM"""
try:
products = [
{
'name': '正射影像',
'path': 'odm_orthophoto',
'filename': 'odm_orthophoto.original.tif',
'output': 'orthophoto.tif'
},
]
if produce_dem:
products.append(
{
'name': 'DSM',
'path': 'odm_dem',
'filename': 'dsm.original.tif',
'output': 'dsm.tif'
orthophoto_vars = {
'TILED': 'NO',
'COMPRESS': False,
'PREDICTOR': '1',
'BIGTIFF': 'IF_SAFER',
'BLOCKXSIZE': 512,
'BLOCKYSIZE': 512,
'NUM_THREADS': 15
}
)
products.append(
{
'name': 'DTM',
'path': 'odm_dem',
'filename': 'dtm.original.tif',
'output': 'dtm.tif'
}
)
for product in products:
self.merge_grid_tif(grid_points, product)
self.merge(all_orthos_and_ortho_cuts, os.path.join(
self.output_dir, "orthophoto.tif"), orthophoto_vars)
self.logger.info("所有产品合并完成")
except Exception as e:
self.logger.error(f"产品合并过程中发生错误: {str(e)}", exc_info=True)
raise
def compute_mask_raster(self, input_raster, vector_mask, output_raster, blend_distance=20, only_max_coords_feature=False):
if not os.path.exists(input_raster):
print("Cannot mask raster, %s does not exist" % input_raster)
return
if __name__ == "__main__":
import sys
sys.path.append(os.path.dirname(
os.path.dirname(os.path.abspath(__file__))))
from utils.logger import setup_logger
import pandas as pd
if not os.path.exists(vector_mask):
print("Cannot mask raster, %s does not exist" % vector_mask)
return
# 设置输出目录和日志
output_dir = r"G:\ODM_output\1009"
setup_logger(output_dir)
print("Computing mask raster: %s" % output_raster)
# 构造测试用的grid_points字典
# 假设我们有两个网格每个网格包含一些GPS点的DataFrame
grid_points = {
(0, 0): pd.DataFrame({
'latitude': [39.9, 39.91],
'longitude': [116.3, 116.31],
'altitude': [100, 101]
}),
(0, 1): pd.DataFrame({
'latitude': [39.92, 39.93],
'longitude': [116.32, 116.33],
'altitude': [102, 103]
})
}
with rasterio.open(input_raster, 'r') as rast:
with fiona.open(vector_mask) as src:
burn_features = src
# 创建MergeTif实例并执行合并
merge_tif = MergeTif(output_dir)
merge_tif.merge_all_tifs(grid_points)
if only_max_coords_feature:
max_coords_count = 0
max_coords_feature = None
for feature in src:
if feature is not None:
# No complex shapes
if len(feature['geometry']['coordinates'][0]) > max_coords_count:
max_coords_count = len(
feature['geometry']['coordinates'][0])
max_coords_feature = feature
if max_coords_feature is not None:
burn_features = [max_coords_feature]
shapes = [feature["geometry"] for feature in burn_features]
out_image, out_transform = mask(rast, shapes, nodata=0)
if blend_distance > 0:
if out_image.shape[0] >= 4:
# alpha_band = rast.dataset_mask()
alpha_band = out_image[-1]
dist_t = edt(alpha_band, black_border=True, parallel=0)
dist_t[dist_t <= blend_distance] /= blend_distance
dist_t[dist_t > blend_distance] = 1
np.multiply(alpha_band, dist_t,
out=alpha_band, casting="unsafe")
else:
print(
"%s does not have an alpha band, cannot blend cutline!" % input_raster)
with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst:
dst.colorinterp = rast.colorinterp
dst.write(out_image)
return output_raster
def feather_raster(self, input_raster, output_raster, blend_distance=20):
if not os.path.exists(input_raster):
print("Cannot feather raster, %s does not exist" % input_raster)
return
print("Computing feather raster: %s" % output_raster)
with rasterio.open(input_raster, 'r') as rast:
out_image = rast.read()
if blend_distance > 0:
if out_image.shape[0] >= 4:
alpha_band = out_image[-1]
dist_t = edt(alpha_band, black_border=True, parallel=0)
dist_t[dist_t <= blend_distance] /= blend_distance
dist_t[dist_t > blend_distance] = 1
np.multiply(alpha_band, dist_t,
out=alpha_band, casting="unsafe")
else:
print(
"%s does not have an alpha band, cannot feather raster!" % input_raster)
with rasterio.open(output_raster, 'w', BIGTIFF="IF_SAFER", **rast.profile) as dst:
dst.colorinterp = rast.colorinterp
dst.write(out_image)
return output_raster
def merge(self, input_ortho_and_ortho_cuts, output_orthophoto, orthophoto_vars={}):
"""
Based on https://github.com/mapbox/rio-merge-rgba/
Merge orthophotos around cutlines using a blend buffer.
"""
inputs = []
bounds = None
precision = 7
for o, c in input_ortho_and_ortho_cuts:
inputs.append((o, c))
with rasterio.open(inputs[0][0]) as first:
res = first.res
dtype = first.dtypes[0]
profile = first.profile
num_bands = first.meta['count'] - 1 # minus alpha
colorinterp = first.colorinterp
print("%s valid orthophoto rasters to merge" % len(inputs))
sources = [(rasterio.open(o), rasterio.open(c)) for o, c in inputs]
# scan input files.
# while we're at it, validate assumptions about inputs
xs = []
ys = []
for src, _ in sources:
left, bottom, right, top = src.bounds
xs.extend([left, right])
ys.extend([bottom, top])
if src.profile["count"] < 2:
raise ValueError("Inputs must be at least 2-band rasters")
dst_w, dst_s, dst_e, dst_n = min(xs), min(ys), max(xs), max(ys)
print("Output bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n))
output_transform = Affine.translation(dst_w, dst_n)
output_transform *= Affine.scale(res[0], -res[1])
# Compute output array shape. We guarantee it will cover the output
# bounds completely.
output_width = int(math.ceil((dst_e - dst_w) / res[0]))
output_height = int(math.ceil((dst_n - dst_s) / res[1]))
# Adjust bounds to fit.
dst_e, dst_s = output_transform * (output_width, output_height)
print("Output width: %d, height: %d" %
(output_width, output_height))
print("Adjusted bounds: %r %r %r %r" % (dst_w, dst_s, dst_e, dst_n))
profile["transform"] = output_transform
profile["height"] = output_height
profile["width"] = output_width
profile["tiled"] = orthophoto_vars.get('TILED', 'YES') == 'YES'
profile["blockxsize"] = orthophoto_vars.get('BLOCKXSIZE', 512)
profile["blockysize"] = orthophoto_vars.get('BLOCKYSIZE', 512)
profile["compress"] = orthophoto_vars.get('COMPRESS', 'LZW')
profile["predictor"] = orthophoto_vars.get('PREDICTOR', '2')
profile["bigtiff"] = orthophoto_vars.get('BIGTIFF', 'IF_SAFER')
profile.update()
# create destination file
with rasterio.open(output_orthophoto, "w", **profile) as dstrast:
dstrast.colorinterp = colorinterp
for idx, dst_window in dstrast.block_windows():
left, bottom, right, top = dstrast.window_bounds(dst_window)
blocksize = dst_window.width
dst_rows, dst_cols = (dst_window.height, dst_window.width)
# initialize array destined for the block
dst_count = first.count
dst_shape = (dst_count, dst_rows, dst_cols)
dstarr = np.zeros(dst_shape, dtype=dtype)
# First pass, write all rasters naively without blending
for src, _ in sources:
src_window = tuple(zip(rowcol(
src.transform, left, top, op=round, precision=precision
), rowcol(
src.transform, right, bottom, op=round, precision=precision
)))
temp = np.zeros(dst_shape, dtype=dtype)
temp = src.read(
out=temp, window=src_window, boundless=True, masked=False
)
# pixels without data yet are available to write
write_region = np.logical_and(
(dstarr[-1] == 0), (temp[-1] != 0) # 0 is nodata
)
np.copyto(dstarr, temp, where=write_region)
# check if dest has any nodata pixels available
if np.count_nonzero(dstarr[-1]) == blocksize:
break
# Second pass, write all feathered rasters
# blending the edges
for src, _ in sources:
src_window = tuple(zip(rowcol(
src.transform, left, top, op=round, precision=precision
), rowcol(
src.transform, right, bottom, op=round, precision=precision
)))
temp = np.zeros(dst_shape, dtype=dtype)
temp = src.read(
out=temp, window=src_window, boundless=True, masked=False
)
where = temp[-1] != 0
for b in range(0, num_bands):
blended = temp[-1] / 255.0 * temp[b] + \
(1 - temp[-1] / 255.0) * dstarr[b]
np.copyto(dstarr[b], blended,
casting='unsafe', where=where)
dstarr[-1][where] = 255.0
# check if dest has any nodata pixels available
if np.count_nonzero(dstarr[-1]) == blocksize:
break
# Third pass, write cut rasters
# blending the cutlines
for _, cut in sources:
src_window = tuple(zip(rowcol(
cut.transform, left, top, op=round, precision=precision
), rowcol(
cut.transform, right, bottom, op=round, precision=precision
)))
temp = np.zeros(dst_shape, dtype=dtype)
temp = cut.read(
out=temp, window=src_window, boundless=True, masked=False
)
# For each band, average alpha values between
# destination raster and cut raster
for b in range(0, num_bands):
blended = temp[-1] / 255.0 * temp[b] + \
(1 - temp[-1] / 255.0) * dstarr[b]
np.copyto(dstarr[b], blended,
casting='unsafe', where=temp[-1] != 0)
dstrast.write(dstarr, window=dst_window)
return output_orthophoto

View File

@ -1,102 +0,0 @@
import os
import logging
import subprocess
from typing import Tuple, Dict
class ObjPostProcessor:
def __init__(self, output_dir: str):
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.ObjPostProcessor')
def create_metadata_xml(self, osgb_dir: str, lon: float, lat: float, bounding_box):
"""创建metadata.xml文件包含地理参考信息
Args:
osgb_dir: osgb输出目录
lon: 中心点经度
lat: 中心点纬度
"""
try:
metadata_content = f'''<?xml version="1.0" encoding="utf-8"?>
<ModelMetadata version="1">
<!-- Spatial Reference System -->
<SRS>EPSG:4326</SRS>
<!-- Center point in Spatial Reference System (in Longitude, Latitude, Height) -->
<SRSOrigin>{lon},{lat},0.000000</SRSOrigin>
<!-- Bounding Box with Two Points -->
<BoundingBox>
<!-- Left-Bottom Corner (Longitude, Latitude) -->
<LB_lon>{bounding_box[0]}</LB_lon>
<LB_lat>{bounding_box[1]}</LB_lat>
<!-- Right-Top Corner (Longitude, Latitude) -->
<RU_lon>{bounding_box[2]}</RU_lon>
<RU_lat>{bounding_box[3]}</RU_lat>
</BoundingBox>
<Texture>
<ColorSource>Visible</ColorSource>
</Texture>
</ModelMetadata>
'''
# metadata.xml 放在根目录
metadata_path = os.path.join(osgb_dir, 'metadata.xml')
with open(metadata_path, 'w', encoding='utf-8') as f:
f.write(metadata_content)
self.logger.info(f"已创建metadata.xml: {metadata_path}")
except Exception as e:
self.logger.error(f"创建metadata.xml时发生错误: {str(e)}")
raise
def convert_to_osgb(self, center_lon, center_lat, bounding_box):
"""将obj转换为osgb并创建metadata.xml
Args:
center_coords: (longitude, latitude)中心点经纬度坐标
"""
try:
# 获取合并后的obj文件路径
obj_dir = os.path.join(self.output_dir, 'texturing')
obj_file = os.path.join(obj_dir, 'textured_model.obj')
if not os.path.exists(obj_file):
raise Exception(f"未找到obj文件: {obj_file}")
# 创建osgb目录结构
osgb_dir = os.path.join(self.output_dir, 'osgb')
osgb_data_dir = os.path.join(osgb_dir, 'Data', 'textured_model')
os.makedirs(osgb_data_dir, exist_ok=True)
# 输出文件路径
output_osgb = os.path.join(osgb_data_dir, 'textured_model.osgb')
# 构建osgconv命令
cmd = [
'osgconv',
'--compressed',
'--smooth',
'--fix-transparency',
'-o', '0,1,0-0,0,-1',
obj_file,
output_osgb
]
# 执行命令
self.logger.info(f"执行osgconv命令{' '.join(cmd)}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise Exception(f"osgb格式转换失败: {result.stderr}")
self.logger.info(f"转换完成: {output_osgb}")
# 创建metadata.xml
self.create_metadata_xml(osgb_dir, center_lon, center_lat, bounding_box)
return True
except Exception as e:
self.logger.error(f"转换osgb时发生错误: {str(e)}")
return False

View File

@ -1,6 +0,0 @@
@echo off
set IMAGE_DIR=E:\datasets\UAV\134\project\images\
set OUTPUT_DIR=G:\ODM_output\134
python main.py --image_dir %IMAGE_DIR% --output_dir %OUTPUT_DIR% --mode 三维模式
pause

View File

@ -1,41 +0,0 @@
from PIL import Image
import os
import shutil
from multiprocessing import Pool
from functools import partial
def convert_image(file_name, img_dir, output_dir, convert_format):
input_path = os.path.join(img_dir, file_name)
output_path = os.path.join(output_dir, file_name.replace(".jpg", f".{convert_format}"))
# 打开并转换图像
img = Image.open(input_path)
img.save(output_path)
def main():
convert_format = "png"
img_dir = r"E:\datasets\UAV\134\project\images"
output_dir = r"E:\datasets\UAV\134_png\project\images"
# 如果输出目录存在,先删除
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
# 创建输出目录
os.makedirs(output_dir)
# 获取所有文件名
file_names = os.listdir(img_dir)
# 创建部分函数,固定除文件名外的其他参数
convert_partial = partial(convert_image,
img_dir=img_dir,
output_dir=output_dir,
convert_format=convert_format)
# 使用进程池并行处理
with Pool() as pool:
pool.map(convert_partial, file_names)
if __name__ == '__main__':
main()

View File

@ -1,45 +0,0 @@
import cv2
import numpy as np
def resize_image(image, max_size=1200):
# 获取原始尺寸
height, width = image.shape[:2]
# 计算缩放比例
scale = min(max_size/width, max_size/height)
if scale < 1:
# 只有当图像过大时才进行缩放
new_width = int(width * scale)
new_height = int(height * scale)
resized = cv2.resize(image, (new_width, new_height))
return resized, scale
return image, 1.0
def mouse_callback(event, x, y, flags, param):
if event == cv2.EVENT_LBUTTONDOWN:
# 计算原始图像上的坐标
original_x = int(x / scale)
original_y = int(y / scale)
print(f'原始图像坐标 (x, y): ({original_x}, {original_y})')
# 在缩放后的图像上标记点击位置
cv2.circle(displayed_img, (x, y), 3, (0, 255, 0), -1)
cv2.imshow('image', displayed_img)
# 读取图像
img = cv2.imread(r"E:\datasets\UAV\134\project\images\20240312_093841_W_W.jpg")
if img is None:
print('错误:无法读取图像')
exit()
# 缩放图像
displayed_img, scale = resize_image(img)
# 创建窗口并设置鼠标回调函数
cv2.imshow('image', displayed_img)
cv2.setMouseCallback('image', mouse_callback)
# 等待按键,按 'q' 退出
while True:
if cv2.waitKey(1) & 0xFF == ord('q'):
break
cv2.destroyAllWindows()

View File

@ -1,245 +0,0 @@
import os
import shutil
def read_obj(file_path):
"""读取OBJ文件返回顶点、纹理坐标、法线、面的列表和MTL文件名"""
vertices = [] # v
tex_coords = [] # vt
normals = [] # vn
faces = [] # f
face_materials = [] # 每个面对应的材质名称
mtl_file = None # mtl文件名
current_material = None # 当前使用的材质
with open(file_path, 'r') as file:
for line in file:
if line.startswith('#') or not line.strip():
continue
parts = line.strip().split()
if not parts:
continue
if parts[0] == 'mtllib': # MTL文件引用
mtl_file = parts[1]
elif parts[0] == 'usemtl': # 材质使用
current_material = parts[1]
elif parts[0] == 'v': # 顶点
vertices.append([float(parts[1]), float(parts[2]), float(parts[3])])
elif parts[0] == 'vt': # 纹理坐标
tex_coords.append([float(parts[1]), float(parts[2])])
elif parts[0] == 'vn': # 法线
normals.append([float(parts[1]), float(parts[2]), float(parts[3])])
elif parts[0] == 'f': # 面
# 处理面的索引 (v/vt/vn)
face_v = []
face_vt = []
face_vn = []
for p in parts[1:]:
indices = p.split('/')
face_v.append(int(indices[0]))
if len(indices) > 1 and indices[1]:
face_vt.append(int(indices[1]))
if len(indices) > 2:
face_vn.append(int(indices[2]))
faces.append((face_v, face_vt, face_vn))
face_materials.append(current_material) # 记录这个面使用的材质
return vertices, tex_coords, normals, faces, face_materials, mtl_file
def read_mtl(mtl_path: str) -> tuple:
"""读取MTL文件内容
Returns:
tuple: (文件内容列表, 材质名称列表)
"""
content = []
material_names = []
with open(mtl_path, 'r') as f:
content = f.readlines()
for line in content:
if line.startswith('newmtl'):
material_names.append(line.strip().split()[1])
return content, material_names
def update_mtl_content(content: list, model_id: int) -> tuple:
"""更新MTL文件内容修改材质名称和纹理文件路径
Returns:
tuple: (更新后的内容, 更新后的材质名称列表)
"""
updated_lines = []
updated_material_names = []
current_material = None
for line in content:
if line.startswith('newmtl'):
# 为材质名称添加前缀
parts = line.strip().split()
material_name = parts[1]
current_material = f"grid_{model_id}_{material_name}"
updated_material_names.append(current_material)
updated_lines.append(f"newmtl {current_material}\n")
elif line.startswith('map_'):
# 更新纹理文件路径
parts = line.strip().split()
texture_file = os.path.basename(parts[-1])
parts[-1] = f"grid_{model_id}_{texture_file}"
updated_lines.append(' '.join(parts) + '\n')
else:
updated_lines.append(line)
return updated_lines, updated_material_names
def merge_mtl_files(mtl1_path: str, mtl2_path: str, output_path: str) -> tuple:
"""合并两个MTL文件
Returns:
tuple: (第一个模型的材质名称列表, 第二个模型的材质名称列表)
"""
# 读取两个MTL文件
content1, materials1 = read_mtl(mtl1_path)
content2, materials2 = read_mtl(mtl2_path)
# 更新两个MTL的内容
updated_content1, updated_materials1 = update_mtl_content(content1, 0)
updated_content2, updated_materials2 = update_mtl_content(content2, 1)
# 合并并写入新的MTL文件
with open(output_path, 'w') as f:
f.writelines(updated_content1)
f.write('\n') # 添加分隔行
f.writelines(updated_content2)
return updated_materials1, updated_materials2
def write_obj(file_path, vertices, tex_coords, normals, faces, face_materials, mtl_file=None):
"""将顶点、纹理坐标、法线和面写入到OBJ文件"""
with open(file_path, 'w') as file:
# 写入MTL文件引用
if mtl_file:
file.write(f"mtllib {mtl_file}\n")
# 写入顶点
for v in vertices:
file.write(f"v {v[0]} {v[1]} {v[2]}\n")
# 写入纹理坐标
for vt in tex_coords:
file.write(f"vt {vt[0]} {vt[1]}\n")
# 写入法线
for vn in normals:
file.write(f"vn {vn[0]} {vn[1]} {vn[2]}\n")
# 写入面(按材质分组)
current_material = None
for face, material in zip(faces, face_materials):
# 如果材质发生变化写入新的usemtl
if material != current_material:
file.write(f"usemtl {material}\n")
current_material = material
face_str = "f"
for j in range(len(face[0])):
face_str += " "
face_str += str(face[0][j])
if face[1]:
face_str += f"/{face[1][j]}"
else:
face_str += "/"
if face[2]:
face_str += f"/{face[2][j]}"
else:
face_str += "/"
file.write(face_str + "\n")
def translate_vertices(vertices, translation):
"""平移顶点"""
return [[v[0] + translation[0], v[1] + translation[1], v[2] + translation[2]] for v in vertices]
def copy_mtl_and_textures(src_dir: str, dst_dir: str, model_id: int):
"""复制MTL文件和相关的纹理文件并重命名避免冲突
Args:
src_dir: 源目录包含MTL和纹理文件
dst_dir: 目标目录
model_id: 模型ID用于重命名
"""
# 复制并重命名纹理文件
for file in os.listdir(src_dir):
if file.lower().endswith('.png'):
src_file = os.path.join(src_dir, file)
new_name = f"grid_{model_id}_{file}"
dst_file = os.path.join(dst_dir, new_name)
shutil.copy2(src_file, dst_file)
print(f"复制纹理文件: {file} -> {new_name}")
def merge_objs(obj1_path, obj2_path, output_path):
"""合并两个OBJ文件"""
print(f"开始合并OBJ模型:\n输入1: {obj1_path}\n输入2: {obj2_path}")
# 读取两个obj文件
vertices1, tex_coords1, normals1, faces1, face_materials1, mtl1 = read_obj(obj1_path)
vertices2, tex_coords2, normals2, faces2, face_materials2, mtl2 = read_obj(obj2_path)
# 固定平移量(0, 1000, 0)
translation = (0, 1000, 0)
# 平移第二个模型的顶点
vertices2_translated = translate_vertices(vertices2, translation)
# 计算偏移量
v_offset = len(vertices1)
vt_offset = len(tex_coords1)
vn_offset = len(normals1)
# 合并顶点、纹理坐标和法线
all_vertices = vertices1 + vertices2_translated
all_tex_coords = tex_coords1 + tex_coords2
all_normals = normals1 + normals2
# 调整第二个模型的面索引和材质名称
all_faces = faces1.copy()
all_face_materials = face_materials1.copy()
for face, material in zip(faces2, face_materials2):
new_face_v = [f + v_offset for f in face[0]]
new_face_vt = [f + vt_offset for f in face[1]] if face[1] else []
new_face_vn = [f + vn_offset for f in face[2]] if face[2] else []
all_faces.append((new_face_v, new_face_vt, new_face_vn))
# 为第二个模型的材质名称添加前缀
all_face_materials.append(f"grid_1_{material}")
# 为第一个模型的材质添加前缀
all_face_materials[:len(faces1)] = [f"grid_0_{mat}" for mat in face_materials1]
# 创建输出子目录
output_dir = os.path.dirname(output_path)
os.makedirs(output_dir, exist_ok=True)
# 复制并重命名两个模型的纹理文件
src_dir1 = os.path.dirname(obj1_path)
src_dir2 = os.path.dirname(obj2_path)
copy_mtl_and_textures(src_dir1, output_dir, 0)
copy_mtl_and_textures(src_dir2, output_dir, 1)
# 合并MTL文件并获取材质名称
src_mtl1 = os.path.join(src_dir1, mtl1)
src_mtl2 = os.path.join(src_dir2, mtl2)
dst_mtl = os.path.join(output_dir, "merged_model.mtl")
merge_mtl_files(src_mtl1, src_mtl2, dst_mtl)
# 写入合并后的obj文件
write_obj(output_path, all_vertices, all_tex_coords, all_normals,
all_faces, all_face_materials, "merged_model.mtl")
print(f"模型合并成功,已保存至: {output_path}")
if __name__ == "__main__":
# 测试参数
obj1_path = r"G:\ODM_output\1009\grid_0_0\project\odm_texturing\odm_textured_model_geo.obj"
obj2_path = r"G:\ODM_output\1009\grid_0_1\project\odm_texturing\odm_textured_model_geo.obj"
output_dir = r"G:\ODM_output\1009\merge_test"
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "merged_test.obj")
# 执行合并
merge_objs(obj1_path, obj2_path, output_path)

View File

@ -1,55 +0,0 @@
from datetime import datetime
import json
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="ODM log time")
parser.add_argument(
"--path", default=r"E:\datasets\UAV\134\project\log.json")
args = parser.parse_args()
return args
def main(args):
# 读取 JSON 文件
with open(args.path, 'r') as file:
data = json.load(file)
# 提取 "stages" 中每个步骤的开始时间和持续时间
stage_timings = []
for i, stage in enumerate(data.get("stages", [])):
stage_name = stage.get("name", "Unnamed Stage")
start_time = stage.get("startTime")
# 获取当前阶段的开始时间
if start_time:
start_dt = datetime.fromisoformat(start_time)
# 获取阶段的结束时间:可以是下一个阶段的开始时间,或当前阶段的 `endTime`(如果存在)
if i + 1 < len(data["stages"]):
end_time = data["stages"][i + 1].get("startTime")
else:
end_time = stage.get("endTime") or data.get("endTime")
if end_time:
end_dt = datetime.fromisoformat(end_time)
duration = (end_dt - start_dt).total_seconds()
stage_timings.append((stage_name, duration))
# 输出每个阶段的持续时间,调整为对齐格式
total_time = 0
print(f"{'Stage Name':<25} {'Duration (seconds)':>15}")
print("=" * 45)
for stage_name, duration in stage_timings:
print(f"{stage_name:<25} {duration:>15.2f}")
total_time += duration
print('Total Time:', total_time)
if __name__ == '__main__':
args = parse_args()
main(args)

View File

@ -1,63 +0,0 @@
import os
import piexif
from PIL import Image
def dms_to_decimal(dms):
"""将DMS格式转换为十进制度"""
if not dms:
return None
degrees = dms[0][0] / dms[0][1]
minutes = dms[1][0] / dms[1][1] / 60
seconds = dms[2][0] / dms[2][1] / 3600
return degrees + minutes + seconds
def get_gps_info(image_path):
"""获取图片的GPS信息"""
try:
image = Image.open(image_path)
exif_data = piexif.load(image.info['exif'])
gps_info = exif_data.get("GPS", {})
if not gps_info:
return None, None, None
# 获取纬度
lat = dms_to_decimal(gps_info.get(2))
if lat and gps_info.get(1) and gps_info[1] == b'S':
lat = -lat
# 获取经度
lon = dms_to_decimal(gps_info.get(4))
if lon and gps_info.get(3) and gps_info[3] == b'W':
lon = -lon
# 获取高度
alt = None
if 6 in gps_info:
alt = gps_info[6][0] / gps_info[6][1]
return lat, lon, alt
except Exception as e:
print(f"读取文件 {image_path} 时出错: {str(e)}")
return None, None, None
def main():
# 设置输入输出路径
image_dir = r"E:\datasets\UAV\134\project\images"
output_path = r"E:\datasets\UAV\134\project\gps.txt"
with open(output_path, 'w', encoding='utf-8') as f:
for filename in os.listdir(image_dir):
if filename.lower().endswith(('.jpg', '.jpeg')):
image_path = os.path.join(image_dir, filename)
lat, lon, alt = get_gps_info(image_path)
if lat is not None and lon is not None:
# 如果没有高度信息使用0
alt = alt if alt is not None else 0
# filename = filename.replace(".jpg", ".tif")
f.write(f"{filename} {lat} {lon} {alt}\n")
if __name__ == '__main__':
main()

View File

@ -1,51 +0,0 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import matplotlib.pyplot as plt
from utils.gps_extractor import GPSExtractor
DATASET = r'E:\datasets\UAV\1009\project\images'
if __name__ == '__main__':
extractor = GPSExtractor(DATASET)
gps_points = extractor.extract_all_gps()
# 创建两个子图
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
# 左图:原始散点图
ax1.scatter(gps_points['lon'], gps_points['lat'],
color='blue', marker='o', label='GPS Points')
ax1.set_title("GPS Coordinates of Images", fontsize=14)
ax1.set_xlabel("Longitude", fontsize=12)
ax1.set_ylabel("Latitude", fontsize=12)
ax1.grid(True)
ax1.legend()
# # 右图:按时间排序的轨迹图
# gps_points_sorted = gps_points.sort_values('date')
# # 绘制飞行轨迹线
# ax2.plot(gps_points_sorted['lon'][300:600], gps_points_sorted['lat'][300:600],
# color='blue', linestyle='-', linewidth=1, alpha=0.6)
# # 绘制GPS点
# ax2.scatter(gps_points_sorted['lon'][300:600], gps_points_sorted['lat'][300:600],
# color='red', marker='o', s=30, label='GPS Points')
# 标记起点和终点
# ax2.scatter(gps_points_sorted['lon'].iloc[0], gps_points_sorted['lat'].iloc[0],
# color='green', marker='^', s=100, label='Start')
# ax2.scatter(gps_points_sorted['lon'].iloc[-1], gps_points_sorted['lat'].iloc[-1],
# color='purple', marker='s', s=100, label='End')
ax2.set_title("UAV Flight Trajectory", fontsize=14)
ax2.set_xlabel("Longitude", fontsize=12)
ax2.set_ylabel("Latitude", fontsize=12)
ax2.grid(True)
ax2.legend()
# 调整子图之间的间距
plt.tight_layout()
plt.show()

View File

@ -1,138 +0,0 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import matplotlib.pyplot as plt
from datetime import timedelta
import logging
import numpy as np
from utils.gps_extractor import GPSExtractor
from utils.logger import setup_logger
class GPSTimeVisualizer:
"""按时间组可视化GPS点"""
def __init__(self, image_dir: str, output_dir: str):
self.image_dir = image_dir
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.GPSVisualizer')
def _group_by_time(self, points_df, time_threshold=timedelta(minutes=5)):
"""按时间间隔对点进行分组"""
if 'date' not in points_df.columns:
self.logger.error("数据中缺少date列")
return [points_df]
# 将date为空的行单独作为一组
null_date_group = points_df[points_df['date'].isna()]
valid_date_points = points_df[points_df['date'].notna()]
if not null_date_group.empty:
self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组")
if valid_date_points.empty:
self.logger.warning("没有有效的时间戳数据")
return [null_date_group] if not null_date_group.empty else []
# 按时间排序
valid_date_points = valid_date_points.sort_values('date')
# 计算时间差
time_diffs = valid_date_points['date'].diff()
# 找到时间差超过阈值的位置
time_groups = []
current_group_start = 0
for idx, time_diff in enumerate(time_diffs):
if time_diff and time_diff > time_threshold:
# 添加当前组
current_group = valid_date_points.iloc[current_group_start:idx]
time_groups.append(current_group)
current_group_start = idx
# 添加最后一组
last_group = valid_date_points.iloc[current_group_start:]
if not last_group.empty:
time_groups.append(last_group)
# 如果有空时间戳的点,将其作为最后一组
if not null_date_group.empty:
time_groups.append(null_date_group)
return time_groups
def visualize_time_groups(self, time_threshold=timedelta(minutes=5)):
"""在同一张图上显示所有时间组,用不同颜色区分"""
# 提取GPS数据
extractor = GPSExtractor(self.image_dir)
gps_points = extractor.extract_all_gps()
# 按时间分组
time_groups = self._group_by_time(gps_points, time_threshold)
# 创建图形
plt.figure(figsize=(15, 10))
# 生成不同的颜色
colors = plt.cm.rainbow(np.linspace(0, 1, len(time_groups)))
# 为每个时间组绘制点和轨迹
for idx, (group, color) in enumerate(zip(time_groups, colors)):
if not group['date'].isna().any():
# 有时间戳的组
sorted_group = group.sort_values('date')
# 绘制轨迹线
plt.plot(sorted_group['lon'], sorted_group['lat'],
color=color, linestyle='-', linewidth=1.5, alpha=0.6,
label=f'Flight Path {idx + 1}')
# 绘制GPS点
plt.scatter(sorted_group['lon'], sorted_group['lat'],
color=color, marker='o', s=30, alpha=0.6)
# 标记起点和终点
plt.scatter(sorted_group['lon'].iloc[0], sorted_group['lat'].iloc[0],
color=color, marker='^', s=100,
label=f'Start {idx + 1} ({sorted_group["date"].min().strftime("%H:%M:%S")})')
plt.scatter(sorted_group['lon'].iloc[-1], sorted_group['lat'].iloc[-1],
color=color, marker='s', s=100,
label=f'End {idx + 1} ({sorted_group["date"].max().strftime("%H:%M:%S")})')
else:
# 无时间戳的组
plt.scatter(group['lon'], group['lat'],
color=color, marker='x', s=50, alpha=0.6,
label='No Timestamp Points')
plt.title("GPS Points by Time Groups", fontsize=14)
plt.xlabel("Longitude", fontsize=12)
plt.ylabel("Latitude", fontsize=12)
plt.grid(True)
# 调整图例位置和大小
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
# 调整布局以适应图例
plt.tight_layout()
# 保存图片
plt.savefig(os.path.join(self.output_dir, 'gps_time_groups_combined.png'),
dpi=300, bbox_inches='tight')
plt.close()
self.logger.info(f"已生成包含 {len(time_groups)} 个时间组的组合可视化图形")
if __name__ == '__main__':
# 设置数据集路径
DATASET = r'F:\error_data\20241108134711\3D'
output_dir = r'E:\studio2\ODM_pro\test'
os.makedirs(output_dir, exist_ok=True)
# 设置日志
setup_logger(os.path.dirname(output_dir))
# 创建可视化器并生成图形
visualizer = GPSTimeVisualizer(DATASET, output_dir)
visualizer.visualize_time_groups(time_threshold=timedelta(minutes=5))

View File

@ -0,0 +1,81 @@
import os
import shutil
import psutil
class DirectoryManager:
def __init__(self, config):
"""
初始化目录管理器
Args:
config: 配置对象包含输入和输出目录等信息
"""
self.config = config
def clean_output_dir(self):
"""清理输出目录"""
try:
if os.path.exists(self.config.output_dir):
shutil.rmtree(self.config.output_dir)
print(f"已清理输出目录: {self.config.output_dir}")
else:
pass
except Exception as e:
print(f"清理输出目录时发生错误: {str(e)}")
raise
def setup_output_dirs(self):
"""创建必要的输出目录结构"""
try:
# 创建主输出目录
os.makedirs(self.config.output_dir)
# 创建过滤图像保存目录
os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs'))
# 创建日志目录
os.makedirs(os.path.join(self.config.output_dir, 'logs'))
print(f"已创建输出目录结构: {self.config.output_dir}")
except Exception as e:
print(f"创建输出目录时发生错误: {str(e)}")
raise
def _get_directory_size(self, path):
"""获取目录的总大小(字节)"""
total_size = 0
for dirpath, dirnames, filenames in os.walk(path):
for filename in filenames:
file_path = os.path.join(dirpath, filename)
try:
total_size += os.path.getsize(file_path)
except (OSError, FileNotFoundError):
continue
return total_size
def check_disk_space(self):
"""检查磁盘空间是否足够"""
# 获取输入目录大小
input_size = self._get_directory_size(self.config.image_dir)
# 获取输出目录所在磁盘的剩余空间
output_drive = os.path.splitdrive(
os.path.abspath(self.config.output_dir))[0]
if not output_drive: # 处理Linux/Unix路径
output_drive = '/home'
disk_usage = psutil.disk_usage(output_drive)
free_space = disk_usage.free
# 计算所需空间输入大小的10倍
required_space = input_size * 8
if free_space < required_space:
error_msg = (
f"磁盘空间不足!\n"
f"输入目录大小: {input_size / (1024**3):.2f} GB\n"
f"所需空间: {required_space / (1024**3):.2f} GB\n"
f"可用空间: {free_space / (1024**3):.2f} GB\n"
f"在驱动器 {output_drive}"
)
raise RuntimeError(error_msg)

View File

@ -3,11 +3,10 @@ from PIL import Image
import piexif
import logging
import pandas as pd
from datetime import datetime
class GPSExtractor:
"""从图像文件提取GPS坐标和拍摄日期"""
"""从图像文件提取GPS坐标"""
def __init__(self, image_dir):
self.image_dir = image_dir
@ -18,17 +17,8 @@ class GPSExtractor:
"""将DMS格式转换为十进制度"""
return dms[0][0] / dms[0][1] + (dms[1][0] / dms[1][1]) / 60 + (dms[2][0] / dms[2][1]) / 3600
@staticmethod
def _parse_datetime(datetime_str):
"""解析EXIF中的日期时间字符串"""
try:
# EXIF日期格式通常为 'YYYY:MM:DD HH:MM:SS'
return datetime.strptime(datetime_str.decode(), '%Y:%m:%d %H:%M:%S')
except Exception:
return None
def get_gps_and_date(self, image_path):
"""提取单张图片的GPS坐标和拍摄日期"""
def get_gps(self, image_path):
"""提取单张图片的GPS坐标"""
try:
image = Image.open(image_path)
exif_data = piexif.load(image.info['exif'])
@ -39,38 +29,21 @@ class GPSExtractor:
if gps_info:
lat = self._dms_to_decimal(gps_info.get(2, []))
lon = self._dms_to_decimal(gps_info.get(4, []))
self.logger.debug(f"成功提取图片GPS坐标: {image_path} - 纬度: {lat}, 经度: {lon}")
# 提取拍摄日期
date_info = None
if "Exif" in exif_data:
# 优先使用DateTimeOriginal
date_str = exif_data["Exif"].get(36867) # DateTimeOriginal
if not date_str:
# 备选DateTime
date_str = exif_data["Exif"].get(36868) # DateTimeDigitized
if not date_str:
# 最后使用基本DateTime
date_str = exif_data["0th"].get(306) # DateTime
if date_str:
date_info = self._parse_datetime(date_str)
self.logger.debug(f"成功提取图片拍摄日期: {image_path} - {date_info}")
self.logger.debug(
f"成功提取图片GPS坐标: {image_path} - 纬度: {lat}, 经度: {lon}")
if not gps_info:
self.logger.warning(f"图片无GPS信息: {image_path}")
if not date_info:
self.logger.warning(f"图片无拍摄日期信息: {image_path}")
return lat, lon, date_info
return lat, lon
except Exception as e:
self.logger.error(f"提取图片信息时发生错误: {image_path} - {str(e)}")
return None, None, None
def extract_all_gps(self):
"""提取所有图片的GPS坐标和拍摄日期"""
self.logger.info(f"开始从目录提取GPS坐标和拍摄日期: {self.image_dir}")
"""提取所有图片的GPS坐标"""
self.logger.info(f"开始从目录提取GPS坐标: {self.image_dir}")
gps_data = []
total_images = 0
successful_extractions = 0
@ -78,15 +51,15 @@ class GPSExtractor:
for image_file in os.listdir(self.image_dir):
total_images += 1
image_path = os.path.join(self.image_dir, image_file)
lat, lon, date = self.get_gps_and_date(image_path)
lat, lon = self.get_gps(image_path)
if lat and lon: # 仍然以GPS信息作为主要判断依据
successful_extractions += 1
gps_data.append({
'file': image_file,
'lat': lat,
'lon': lon,
'date': date
})
self.logger.info(f"GPS坐标和拍摄日期提取完成 - 总图片数: {total_images}, 成功提取: {successful_extractions}, 失败: {total_images - successful_extractions}")
self.logger.info(
f"GPS坐标提取完成 - 总图片数: {total_images}, 成功提取: {successful_extractions}, 失败: {total_images - successful_extractions}")
return pd.DataFrame(gps_data)

View File

@ -16,50 +16,15 @@ class GridDivider:
self.num_grids_width = 0 # 添加网格数量属性
self.num_grids_height = 0
def adjust_grid_size(self, points_df):
"""动态调整网格大小
Args:
points_df: 包含GPS点的DataFrame
Returns:
tuple: (grids, translations, grid_points, final_grid_size)
"""
self.logger.info(f"开始动态调整网格大小,初始大小: {self.grid_size}")
while True:
# 使用当前grid_size划分网格
grids, translations = self.divide_grids(points_df)
grid_points, multiple_grid_points = self.assign_to_grids(points_df, grids)
# 检查每个网格中的点数
max_points = 0
for grid_id, points in grid_points.items():
max_points = max(max_points, len(points))
self.logger.info(f"当前网格大小: {self.grid_size}米, 单个网格最大点数: {max_points}")
# 如果最大点数超过1500减小网格大小
if max_points > 1500:
self.grid_size -= 100
self.logger.info(f"点数超过1500减小网格大小至: {self.grid_size}")
if self.grid_size < 500: # 设置一个最小网格大小限制
self.logger.warning("网格大小已达到最小值500米停止调整")
break
else:
self.logger.info(f"找到合适的网格大小: {self.grid_size}")
break
return grids
def adjust_grid_size_and_overlap(self, points_df):
"""动态调整网格重叠率"""
grids = self.adjust_grid_size(points_df)
self.logger.info(f"开始动态调整网格重叠率,初始重叠率: {self.overlap}")
while True:
# 使用调整好的网格大小划分网格
grids, translations = self.divide_grids(points_df)
grid_points, multiple_grid_points = self.assign_to_grids(points_df, grids)
grids = self.divide_grids(points_df)
grid_points, multiple_grid_points = self.assign_to_grids(
points_df, grids)
if len(grids) == 1:
self.logger.info(f"网格数量为1跳过重叠率调整")
@ -68,16 +33,52 @@ class GridDivider:
self.overlap += 0.02
self.logger.info(f"重叠率增加到: {self.overlap}")
else:
self.logger.info(f"找到合适的重叠率: {self.overlap}, 有{multiple_grid_points}个点被分配到多个网格")
self.logger.info(
f"找到合适的重叠率: {self.overlap}, 有{multiple_grid_points}个点被分配到多个网格")
break
return grids, translations, grid_points
return grids, grid_points
def adjust_grid_size(self, points_df):
"""动态调整网格大小
Args:
points_df: 包含GPS点的DataFrame
Returns:
tuple: grids
"""
self.logger.info(f"开始动态调整网格大小,初始大小: {self.grid_size}")
while True:
# 使用当前grid_size划分网格
grids = self.divide_grids(points_df)
grid_points, multiple_grid_points = self.assign_to_grids(
points_df, grids)
# 检查每个网格中的点数
max_points = 0
for grid_id, points in grid_points.items():
max_points = max(max_points, len(points))
self.logger.info(
f"当前网格大小: {self.grid_size}米, 单个网格最大点数: {max_points}")
# 如果最大点数超过2000减小网格大小
if max_points > 2000:
self.grid_size -= 100
self.logger.info(f"点数超过2000减小网格大小至: {self.grid_size}")
if self.grid_size < 500: # 设置一个最小网格大小限制
self.logger.warning("网格大小已达到最小值500米停止调整")
break
else:
self.logger.info(f"找到合适的网格大小: {self.grid_size}")
break
return grids
def divide_grids(self, points_df):
"""计算边界框并划分网格
Returns:
tuple: (grids, translations)
- grids: 网格边界列表
- translations: 网格平移量字典
tuple: grids 网格边界列表
"""
self.logger.info("开始划分网格")
@ -91,12 +92,15 @@ class GridDivider:
self.logger.info(f"区域宽度: {width:.2f}米, 高度: {height:.2f}")
# 精细调整网格的长宽避免出现2*grid_size-1的情况的影响
grid_size_lt = [self.grid_size -200, self.grid_size -100, self.grid_size , self.grid_size +100, self.grid_size +200]
grid_size_lt = [self.grid_size - 200, self.grid_size - 100,
self.grid_size, self.grid_size + 100, self.grid_size + 200]
width_modulus_lt = [width % grid_size for grid_size in grid_size_lt]
grid_width = grid_size_lt[width_modulus_lt.index(min(width_modulus_lt))]
grid_width = grid_size_lt[width_modulus_lt.index(
min(width_modulus_lt))]
height_modulus_lt = [height % grid_size for grid_size in grid_size_lt]
grid_height = grid_size_lt[height_modulus_lt.index(min(height_modulus_lt))]
grid_height = grid_size_lt[height_modulus_lt.index(
min(height_modulus_lt))]
self.logger.info(f"网格宽度: {grid_width:.2f}米, 网格高度: {grid_height:.2f}")
# 计算需要划分的网格数量
@ -108,18 +112,19 @@ class GridDivider:
lon_step = (max_lon - min_lon) / self.num_grids_width
grids = []
grid_translations = {} # 存储每个网格相对于第一个网格的平移量
# 先创建所有网格
for i in range(self.num_grids_height):
for j in range(self.num_grids_width):
grid_min_lat = min_lat + i * lat_step - self.overlap * lat_step
grid_max_lat = min_lat + (i + 1) * lat_step + self.overlap * lat_step
grid_max_lat = min_lat + \
(i + 1) * lat_step + self.overlap * lat_step
grid_min_lon = min_lon + j * lon_step - self.overlap * lon_step
grid_max_lon = min_lon + (j + 1) * lon_step + self.overlap * lon_step
grid_max_lon = min_lon + \
(j + 1) * lon_step + self.overlap * lon_step
grid_id = (i, j) # 使用(i,j)作为网格标识i代表行j代表列
grid_bounds = (grid_min_lat, grid_max_lat, grid_min_lon, grid_max_lon)
grid_bounds = (grid_min_lat, grid_max_lat,
grid_min_lon, grid_max_lon)
grids.append(grid_bounds)
self.logger.debug(
@ -127,26 +132,10 @@ class GridDivider:
f"经度[{grid_min_lon:.6f}, {grid_max_lon:.6f}]"
)
# 计算每个网格相对于第一个网格的平移量
reference_grid = grids[0]
for i in range(self.num_grids_height):
for j in range(self.num_grids_width):
grid_id = (i, j)
grid_idx = i * self.num_grids_width + j
if grid_idx == 0: # 参考网格
grid_translations[grid_id] = (0, 0)
else:
translation = self.calculate_grid_translation(reference_grid, grids[grid_idx])
grid_translations[grid_id] = translation
self.logger.debug(
f"网格[{i},{j}]相对于参考网格的平移量: x={translation[0]:.2f}m, y={translation[1]:.2f}m"
)
self.logger.info(
f"成功划分为 {len(grids)} 个网格 ({self.num_grids_width}x{self.num_grids_height})")
return grids, grid_translations
return grids
def assign_to_grids(self, points_df, grids):
"""将点分配到对应网格"""
@ -205,7 +194,8 @@ class GridDivider:
# 计算网格的实际长度和宽度(米)
width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters
height = geodesic((min_lat, min_lon), (max_lat, min_lon)).meters
height = geodesic((min_lat, min_lon),
(max_lat, min_lon)).meters
plt.plot([min_lon, max_lon, max_lon, min_lon, min_lon],
[min_lat, min_lat, max_lat, max_lat, min_lat],

View File

@ -2,6 +2,7 @@ import logging
import os
from datetime import datetime
def setup_logger(output_dir):
# 创建logs目录
log_dir = os.path.join(output_dir, 'logs')

View File

@ -1,252 +1,121 @@
import os
import time
import logging
import subprocess
from typing import Dict, Tuple
import pandas as pd
import numpy as np
from osgeo import gdal
class NotOverlapError(Exception):
"""图像重叠度不足异常"""
pass
import docker
class ODMProcessMonitor:
"""ODM处理监控器"""
def __init__(self, output_dir: str, mode: str = "快拼模式"):
def __init__(self, output_dir: str, mode: str = "三维模式"):
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.ODMMonitor')
self.mode = mode
def _check_success(self, grid_dir: str) -> bool:
"""检查ODM是否执行成功
检查项目:
1. 必要的文件夹和文件是否存在
2. 产品文件是否有效
"""
project_dir = os.path.join(grid_dir, 'project')
# 根据不同模式检查不同的产品
if self.mode == "快拼模式":
# 只检查正射影像
# if not self._check_orthophoto(project_dir):
# return False
pass
elif self.mode == "三维模式":
# 检查点云和实景三维
if not all([
os.path.exists(os.path.join(project_dir, 'odm_georeferencing', 'odm_georeferenced_model.laz')),
os.path.exists(os.path.join(project_dir, 'odm_texturing', 'odm_textured_model_geo.obj'))
]):
self.logger.error("点云或实景三维文件夹未生成")
return False
# TODO: 添加点云和实景三维的质量检查
elif self.mode == "重建模式":
# 检查所有产品
if not all([
os.path.exists(os.path.join(project_dir, 'odm_georeferencing', 'odm_georeferenced_model.laz')),
os.path.exists(os.path.join(project_dir, 'odm_texturing', 'odm_textured_model_geo.obj'))
]):
self.logger.error("部分必要的文件夹未生成")
return False
# 检查正射影像
# if not self._check_orthophoto(project_dir):
# return False
# TODO: 添加点云和实景三维的质量检查
return True
# TODO 正射影像怎么检查最好
def _check_orthophoto(self, project_dir: str) -> bool:
"""检查正射影像的质量"""
ortho_path = os.path.join(project_dir, 'odm_orthophoto', 'odm_orthophoto.original.tif')
if not os.path.exists(ortho_path):
self.logger.error("正射影像文件未生成")
return False
# 检查文件大小
file_size_mb = os.path.getsize(ortho_path) / (1024 * 1024) # 转换为MB
if file_size_mb < 1:
self.logger.error(f"正射影像文件过小: {file_size_mb:.2f}MB")
return False
try:
# 打开影像文件
ds = gdal.Open(ortho_path)
if ds is None:
self.logger.error("无法打开正射影像文件")
return False
# 读取第一个波段
band = ds.GetRasterBand(1)
# 获取统计信息
stats = band.GetStatistics(False, True)
if stats is None:
self.logger.error("无法获取影像统计信息")
return False
min_val, max_val, mean, std = stats
# 计算空值比例
no_data_value = band.GetNoDataValue()
array = band.ReadAsArray()
if no_data_value is not None:
no_data_ratio = np.sum(array == no_data_value) / array.size
else:
no_data_ratio = 0
# 检查空值比例是否过高超过50%
if no_data_ratio > 0.5:
self.logger.error(f"正射影像空值比例过高: {no_data_ratio:.2%}")
return False
# 检查影像是否全黑或全白
if max_val - min_val < 1:
self.logger.error("正射影像可能无效:像素值范围过小")
return False
ds = None # 关闭数据集
return True
except Exception as e:
self.logger.error(f"检查正射影像时发生错误: {str(e)}")
return False
def run_odm_with_monitor(self, grid_dir: str, grid_id: tuple, produce_dem: bool = False) -> Tuple[bool, str]:
def run_odm_with_monitor(self, grid_dir: str, grid_id: tuple) -> Tuple[bool, str]:
"""运行ODM命令"""
self.logger.info(f"开始处理网格 ({grid_id[0]},{grid_id[1]})")
success = False
error_msg = ""
max_retries = 3
current_try = 0
cpu_cores = os.cpu_count()
# 根据模式设置是否使用lowest quality
use_lowest_quality = self.mode == "快拼模式"
# 初始化 Docker 客户端
client = docker.from_env()
while current_try < max_retries:
current_try += 1
self.logger.info(
f"{current_try} 次尝试处理网格 ({grid_id[0]},{grid_id[1]})")
try:
# 构建Docker命令
grid_dir = grid_dir[0].lower()+grid_dir[1:].replace('\\', '/')
docker_command = (
f"docker run --gpus all -ti --rm "
f"-v {grid_dir}:/datasets "
f"opendronemap/odm:gpu "
# 构建 Docker 容器运行参数
grid_dir = grid_dir[0].lower(
) + grid_dir[1:].replace('\\', '/')
volumes = {
grid_dir: {'bind': '/datasets', 'mode': 'rw'}
}
command = (
f"--project-path /datasets project "
f"--max-concurrency 15 "
f"--max-concurrency {cpu_cores} "
f"--force-gps "
f"--use-exif "
f"--use-hybrid-bundle-adjustment "
f"--optimize-disk-space "
# f"--feature-quality ultra "
)
# 根据是否使用lowest quality添加参数
if use_lowest_quality:
docker_command += f"--feature-quality lowest "
docker_command += f"--orthophoto-resolution 8 "
if produce_dem:
docker_command += (
f"--dsm "
f"--dtm "
f"--orthophoto-cutline "
f"--feature-type sift "
f"--orthophoto-resolution 8 "
# f"--mesh-size 5000000 "
# f"--mesh-octree-depth 13 "
)
if self.mode == "快拼模式":
docker_command += (
#f"--fast-orthophoto "
command += (
f"--fast-orthophoto "
f"--skip-3dmodel "
)
# elif self.mode == "三维模式":
# docker_command += (
# f"--skip-orthophoto "
# )
else: # 三维模式参数
command += (
f"--dsm "
f"--dtm "
)
if current_try == 1:
command += (
f"--feature-quality low "
)
docker_command += "--rerun-all"
self.logger.info(docker_command)
command += "--rerun-all"
result = subprocess.run(
docker_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = result.stdout.decode(
'utf-8'), result.stderr.decode('utf-8')
self.logger.info(f"Docker 命令: {command}")
stdout_lines = stdout.strip().split('\n')
last_lines = '\n'.join(
stdout_lines[-50:] if len(stdout_lines) > 10 else stdout_lines)
self.logger.info(f"==========stdout==========: {last_lines}")
if stderr:
self.logger.error(f"docker run指令执行失败")
self.logger.error(f"==========stderr==========: {stderr}")
if "error during connect" in stderr or "The system cannot find the file specified" in stderr:
error_msg = "Docker没有启动请启动Docker"
elif "user declined directory sharing" in stderr:
error_msg = "Docker无法访问目录请检查目录权限和共享设置"
# 运行 Docker 容器
container = client.containers.run(
image="opendronemap/odm:gpu",
command=command,
volumes=volumes,
detach=True,
remove=False,
runtime="nvidia", # 使用 GPU
)
# 等待容器运行完成
exit_status = container.wait()
if exit_status["StatusCode"] != 0:
self.logger.error(
f"容器运行失败,退出状态码: {exit_status['StatusCode']}")
# 获取容器的错误日志
error_msg = container.logs(
stderr=True).decode("utf-8").splitlines()
self.logger.error("容器运行失败的详细错误日志:")
for line in error_msg:
self.logger.error(line)
container.remove()
time.sleep(5)
else:
error_msg = "Docker运行失败需要人工排查错误"
break
else:
self.logger.info("docker run指令执行成功")
if "ODM app finished" in last_lines:
self.logger.info("ODM处理完成")
if self._check_success(grid_dir):
self.logger.info(
f"网格 ({grid_id[0]},{grid_id[1]}) 处理成功")
# 获取所有日志
logs = container.logs().decode("utf-8").splitlines()
# 输出最后 50 行日志
self.logger.info("容器运行完成,以下是最后 50 行日志:")
for line in logs[-50:]:
self.logger.info(line)
success = True
error_msg = ""
break
else:
self.logger.error(
f"虽然ODM处理完成但是生产产品质量可能不合格需要人工检查")
raise NotOverlapError
# TODO 先写成这样,后面这三种情况可能处理不一样
elif "enough overlap" in last_lines:
raise NotOverlapError
elif "out of memory" in last_lines:
raise NotOverlapError
elif "strange values" in last_lines:
raise NotOverlapError
else:
raise NotOverlapError
except NotOverlapError:
if use_lowest_quality and self.mode == "快拼模式":
self.logger.warning(
"检测到not overlap错误移除lowest quality参数后重试")
use_lowest_quality = False
time.sleep(10)
continue
else:
self.logger.error(
"出现错误,需要人工检查数据集")
error_msg = "图像重叠度不足,需要人工检查数据集的采样间隔情况"
container.remove()
break
return success, error_msg
def process_all_grids(self, grid_points: Dict[tuple, pd.DataFrame], produce_dem: bool) -> Dict[tuple, pd.DataFrame]:
def process_all_grids(self, grid_points: Dict[tuple, pd.DataFrame]) -> list:
"""处理所有网格
Returns:
Dict[tuple, pd.DataFrame]: 成功处理的网格点数据字典
"""
self.logger.info("开始执行网格处理")
successful_grid_points = {}
successful_grid_lt = []
failed_grids = []
for grid_id, points in grid_points.items():
@ -258,11 +127,10 @@ class ODMProcessMonitor:
success, error_msg = self.run_odm_with_monitor(
grid_dir=grid_dir,
grid_id=grid_id,
produce_dem=produce_dem
)
if success:
successful_grid_points[grid_id] = points
successful_grid_lt.append(grid_id)
else:
self.logger.error(
f"网格 ({grid_id[0]},{grid_id[1]}) 处理失败: {error_msg}")
@ -277,7 +145,7 @@ class ODMProcessMonitor:
# 汇总处理结果
total_grids = len(grid_points)
failed_count = len(failed_grids)
success_count = len(successful_grid_points)
success_count = len(successful_grid_lt)
self.logger.info(
f"网格处理完成。总计: {total_grids}, 成功: {success_count}, 失败: {failed_count}")
@ -287,7 +155,7 @@ class ODMProcessMonitor:
self.logger.error(
f"网格 ({grid_id[0]},{grid_id[1]}): {error_msg}")
if len(successful_grid_points) == 0:
if len(successful_grid_lt) == 0:
raise Exception("所有网格处理都失败,无法继续处理")
return successful_grid_points
return successful_grid_lt

View File

@ -3,6 +3,7 @@ import matplotlib.pyplot as plt
import pandas as pd
import logging
from typing import Optional
from pyproj import Transformer
class FilterVisualizer:
@ -17,6 +18,25 @@ class FilterVisualizer:
"""
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.Visualizer')
# 创建坐标转换器
self.transformer = Transformer.from_crs(
"EPSG:4326", # WGS84经纬度坐标系
"EPSG:32649", # UTM49N
always_xy=True
)
def _convert_to_utm(self, lon: pd.Series, lat: pd.Series) -> tuple:
"""
将经纬度坐标转换为UTM坐标
Args:
lon: 经度序列
lat: 纬度序列
Returns:
tuple: (x坐标, y坐标)
"""
return self.transformer.transform(lon, lat)
def visualize_filter_step(self,
current_points: pd.DataFrame,
@ -35,37 +55,47 @@ class FilterVisualizer:
self.logger.info(f"开始生成{step_name}的可视化结果")
# 找出被过滤掉的点
filtered_files = set(previous_points['file']) - set(current_points['file'])
filtered_points = previous_points[previous_points['file'].isin(filtered_files)]
filtered_files = set(
previous_points['file']) - set(current_points['file'])
filtered_points = previous_points[previous_points['file'].isin(
filtered_files)]
# 转换坐标到UTM
current_x, current_y = self._convert_to_utm(
current_points['lon'], current_points['lat'])
filtered_x, filtered_y = self._convert_to_utm(
filtered_points['lon'], filtered_points['lat'])
# 创建图形
plt.figure(figsize=(20, 16))
plt.rcParams['font.sans-serif'] = ['SimHei'] # 黑体
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(20, 20))
# 绘制保留的点
plt.scatter(current_points['lon'], current_points['lat'],
color='blue', label='Retained Points',
alpha=0.6, s=50)
plt.scatter(current_x, current_y,
color='blue', label='保留的点',
alpha=0.6, s=5)
# 绘制被过滤的点
if not filtered_points.empty:
plt.scatter(filtered_points['lon'], filtered_points['lat'],
color='red', marker='x', label='Filtered Points',
alpha=0.6, s=100)
plt.scatter(filtered_x, filtered_y,
color='red', marker='x', label='过滤的点')
# 设置图形属性
plt.title(f"GPS Points After {step_name}\n"
f"(Filtered: {len(filtered_points)}, Retained: {len(current_points)})",
plt.title(f"{step_name}后的GPS点\n"
f"(过滤: {len(filtered_points)}, 保留: {len(current_points)})",
fontsize=14)
plt.xlabel("Longitude", fontsize=12)
plt.ylabel("Latitude", fontsize=12)
plt.xlabel("东向坐标 (米)", fontsize=12)
plt.ylabel("北向坐标 (米)", fontsize=12)
plt.grid(True)
plt.axis('equal')
# 添加统计信息
stats_text = (
f"Original Points: {len(previous_points)}\n"
f"Filtered Points: {len(filtered_points)}\n"
f"Remaining Points: {len(current_points)}\n"
f"Filter Rate: {len(filtered_points)/len(previous_points)*100:.1f}%"
f"原始点数: {len(previous_points)}\n"
f"过滤点数: {len(filtered_points)}\n"
f"保留点数: {len(current_points)}\n"
f"过滤率: {len(filtered_points)/len(previous_points)*100:.1f}%"
)
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
bbox=dict(facecolor='white', alpha=0.8))
@ -78,7 +108,8 @@ class FilterVisualizer:
# 保存图形
save_name = save_name or step_name.lower().replace(' ', '_')
save_path = os.path.join(self.output_dir, 'filter_imgs', f'filter_{save_name}.png')
save_path = os.path.join(
self.output_dir, 'filter_imgs', f'filter_{save_name}.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()