加入快拼模式,修改参数

This commit is contained in:
龙澳 2024-12-21 12:03:54 +08:00
parent ebe818125e
commit 0e91d125cb
5 changed files with 105 additions and 70 deletions

View File

@ -1,5 +1,6 @@
import os import os
import shutil import shutil
from datetime import timedelta
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict from typing import Dict
@ -21,19 +22,25 @@ class PreprocessConfig:
image_dir: str image_dir: str
output_dir: str output_dir: str
eps: float = 0.01 # 聚类过滤参数
min_samples: int = 5 cluster_eps: float = 0.01
filter_grid_size: float = 0.001 cluster_min_samples: int = 5
filter_dense_distance_threshold: float = 10 # 孤立点过滤参数
filter_distance_threshold: float = 0.001 filter_distance_threshold: float = 0.001 # 经纬度距离
filter_min_neighbors: int = 6 filter_min_neighbors: int = 6
# 密集点过滤参数
filter_grid_size: float = 0.001
filter_dense_distance_threshold: float = 10 # 普通距离,单位:米
filter_time_threshold: timedelta = timedelta(minutes=5)
# 网格划分参数
grid_overlap: float = 0.05 grid_overlap: float = 0.05
grid_size: float = 500 grid_size: float = 500
# 几个pipline过程是否开启
enable_filter: bool = True enable_filter: bool = True
enable_grid_division: bool = True enable_grid_division: bool = True
enable_visualization: bool = True enable_visualization: bool = True
enable_copy_images: bool = True enable_copy_images: bool = True
mode: str = "快拼模式"
class ImagePreprocessor: class ImagePreprocessor:
def __init__(self, config: PreprocessConfig): def __init__(self, config: PreprocessConfig):
@ -55,11 +62,13 @@ class ImagePreprocessor:
self.logger.info("开始聚类") self.logger.info("开始聚类")
# 创建聚类器并执行聚类 # 创建聚类器并执行聚类
clusterer = GPSCluster( clusterer = GPSCluster(
self.gps_points, output_dir=self.config.output_dir) self.gps_points, output_dir=self.config.output_dir,
eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples)
# 获取主要类别的点 # 获取主要类别的点
self.gps_points = clusterer.get_main_cluster() self.clustered_points = clusterer.fit()
self.gps_points = clusterer.get_main_cluster(self.clustered_points)
# 获取统计信息并记录 # 获取统计信息并记录
stats = clusterer.get_cluster_stats() stats = clusterer.get_cluster_stats(self.clustered_points)
self.logger.info( self.logger.info(
f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点," f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
f"噪声点 {stats['noise_points']}" f"噪声点 {stats['noise_points']}"
@ -90,6 +99,7 @@ class ImagePreprocessor:
self.gps_points, self.gps_points,
grid_size=self.config.filter_grid_size, grid_size=self.config.filter_grid_size,
distance_threshold=self.config.filter_dense_distance_threshold, distance_threshold=self.config.filter_dense_distance_threshold,
time_threshold=self.config.filter_time_threshold,
) )
self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点") self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
return self.gps_points return self.gps_points
@ -188,15 +198,15 @@ class ImagePreprocessor:
try: try:
self.extract_gps() self.extract_gps()
self.cluster() self.cluster()
# self.time_filter() self.filter_points()
# self.filter_points()
grid_points = self.divide_grids() grid_points = self.divide_grids()
self.copy_images(grid_points) self.copy_images(grid_points)
self.visualize_results() self.visualize_results()
# self.logger.info("预处理任务完成") self.logger.info("预处理任务完成")
self.command_runner.run_grid_commands( self.command_runner.run_grid_commands(
grid_points, grid_points,
self.config.enable_grid_division self.config.enable_grid_division,
self.mode
) )
except Exception as e: except Exception as e:
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
@ -206,14 +216,22 @@ class ImagePreprocessor:
if __name__ == "__main__": if __name__ == "__main__":
# 创建配置 # 创建配置
config = PreprocessConfig( config = PreprocessConfig(
image_dir=r"E:\湖南省第二测绘院\11-06-项目移交文件(王辉给)\无人机二三维节点扩容生产影像\影像数据\199\code\images", image_dir=r"E:\datasets\UAV\1815\images",
output_dir=r"test", output_dir=r"test",
filter_grid_size=0.001,
filter_dense_distance_threshold=10, cluster_eps=0.01,
cluster_min_samples=5,
filter_distance_threshold=0.001, filter_distance_threshold=0.001,
filter_min_neighbors=6, filter_min_neighbors=6,
filter_grid_size=0.001,
filter_dense_distance_threshold=10,
filter_time_threshold=timedelta(minutes=5),
grid_overlap=0.05, grid_overlap=0.05,
grid_size=500, grid_size=500,
enable_filter=True, enable_filter=True,
enable_grid_division=True, enable_grid_division=True,
enable_visualization=True, enable_visualization=True,

View File

@ -17,7 +17,6 @@ class GPSCluster:
self.dbscan = DBSCAN(eps=eps, min_samples=min_samples) self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
self.scaler = StandardScaler() self.scaler = StandardScaler()
self.gps_points = gps_points self.gps_points = gps_points
self.clustered_points = self.fit()
self.log_file = os.path.join(output_dir, 'del_imgs.txt') self.log_file = os.path.join(output_dir, 'del_imgs.txt')
def fit(self): def fit(self):
@ -57,7 +56,7 @@ class GPSCluster:
return result_df return result_df
def get_cluster_stats(self): def get_cluster_stats(self, clustered_points):
""" """
获取聚类统计信息 获取聚类统计信息
@ -67,22 +66,22 @@ class GPSCluster:
返回: 返回:
聚类统计信息的字典 聚类统计信息的字典
""" """
main_cluster_points = sum(self.clustered_points["cluster"] == 1) main_cluster_points = sum(clustered_points["cluster"] == 1)
stats = { stats = {
"total_points": len(self.clustered_points), "total_points": len(clustered_points),
"main_cluster_points": main_cluster_points, "main_cluster_points": main_cluster_points,
"noise_points": sum(self.clustered_points["cluster"] == -1), "noise_points": sum(clustered_points["cluster"] == -1),
} }
noise_cluster = self.get_noise_cluster() noise_cluster = self.get_noise_cluster(clustered_points)
with open(self.log_file, 'a', encoding='utf-8') as f: with open(self.log_file, 'a', encoding='utf-8') as f:
for i, (_, row) in enumerate(noise_cluster.iterrows()): for i, (_, row) in enumerate(noise_cluster.iterrows()):
f.write(row['file']+'\n') f.write(row['file']+'\n')
f.write('\n') f.write('\n')
return stats return stats
def get_main_cluster(self, clustered_points):
    """Return the rows of ``clustered_points`` whose ``cluster`` label is 1.

    Args:
        clustered_points: Table with a ``cluster`` column that supports
            boolean-mask indexing (presumably the pandas DataFrame
            returned by ``fit`` -- confirm against the caller).

    Returns:
        The subset of ``clustered_points`` where ``cluster == 1``.
    """
    is_main = clustered_points["cluster"] == 1
    return clustered_points[is_main]
def get_noise_cluster(self, clustered_points):
    """Return the rows of ``clustered_points`` whose ``cluster`` label is -1.

    Args:
        clustered_points: Table with a ``cluster`` column that supports
            boolean-mask indexing (presumably the pandas DataFrame
            returned by ``fit`` -- confirm against the caller).

    Returns:
        The subset of ``clustered_points`` where ``cluster == -1``.
    """
    is_noise = clustered_points["cluster"] == -1
    return clustered_points[is_noise]

View File

@ -35,11 +35,15 @@ i
""" """
try: try:
grid_dir = os.path.join(self.output_dir, f'grid_{grid_idx + 1}') grid_dir = os.path.join(self.output_dir, f'grid_{grid_idx + 1}')
command = f"docker run -ti --rm -v {grid_dir}:/datasets opendronemap/odm --project-path /datasets project --feature-quality lowest --force-gps --use-3dmesh" if self.mode == "快拼模式":
command = f"docker run -ti --rm -v {grid_dir}:/datasets opendronemap/odm --project-path /datasets project --feature-quality lowest --force-gps --use-3dmesh --fast-orthophoto --skip-3dmodel"
else:
command = f"docker run -ti --rm -v {grid_dir}:/datasets opendronemap/odm --project-path /datasets project --feature-quality lowest --force-gps --use-3dmesh"
self.logger.info(f"开始执行命令: {command}") self.logger.info(f"开始执行命令: {command}")
success, error_msg = self.monitor.run_odm_with_monitor(command, grid_dir, grid_idx) success, error_msg = self.monitor.run_odm_with_monitor(
command, grid_dir, grid_idx)
if not success: if not success:
raise Exception(error_msg) raise Exception(error_msg)

View File

@ -52,7 +52,7 @@ class GPSFilter:
if grid not in grid_map: if grid not in grid_map:
grid_map[grid] = [] grid_map[grid] = []
grid_map[grid].append((row['file'], row['lat'], row['lon'])) grid_map[grid].append((row['file'], row['lat'], row['lon']))
self.logger.info(f"图像点已分配到 {len(grid_map)} 个网格中") self.logger.info(f"图像点已分配到 {len(grid_map)} 个网格中")
# 在每个网格中计算两两距离并排序 # 在每个网格中计算两两距离并排序
@ -68,11 +68,11 @@ class GPSFilter:
return sorted_distances return sorted_distances
def _group_by_time(self, points_df: pd.DataFrame, time_threshold: timedelta = timedelta(minutes=5)) -> list: def _group_by_time(self, points_df: pd.DataFrame, time_threshold: timedelta) -> list:
"""根据拍摄时间分组图片 """根据拍摄时间分组图片
如果相邻两张图片的拍摄时间差超过5分钟则进行切分 如果相邻两张图片的拍摄时间差超过5分钟则进行切分
Args: Args:
points_df: 包含图片信息的DataFrame必须包含'file''date' points_df: 包含图片信息的DataFrame必须包含'file''date'
time_threshold: 时间间隔阈值默认5分钟 time_threshold: 时间间隔阈值默认5分钟
@ -97,32 +97,34 @@ class GPSFilter:
# 按时间排序 # 按时间排序
valid_date_points = valid_date_points.sort_values('date') valid_date_points = valid_date_points.sort_values('date')
self.logger.info(f"有效时间范围: {valid_date_points['date'].min()}{valid_date_points['date'].max()}") self.logger.info(
f"有效时间范围: {valid_date_points['date'].min()}{valid_date_points['date'].max()}")
# 计算时间差 # 计算时间差
time_diffs = valid_date_points['date'].diff() time_diffs = valid_date_points['date'].diff()
# 找到时间差超过阈值的位置 # 找到时间差超过阈值的位置
time_groups = [] time_groups = []
current_group_start = 0 current_group_start = 0
for idx, time_diff in enumerate(time_diffs): for idx, time_diff in enumerate(time_diffs):
if time_diff and time_diff > time_threshold: if time_diff and time_diff > time_threshold:
# 添加当前组 # 添加当前组
current_group = valid_date_points.iloc[current_group_start:idx] current_group = valid_date_points.iloc[current_group_start:idx]
time_groups.append(current_group) time_groups.append(current_group)
# 记录断点信息 # 记录断点信息
break_time = valid_date_points.iloc[idx]['date'] break_time = valid_date_points.iloc[idx]['date']
group_start_time = current_group.iloc[0]['date'] group_start_time = current_group.iloc[0]['date']
group_end_time = current_group.iloc[-1]['date'] group_end_time = current_group.iloc[-1]['date']
self.logger.info( self.logger.info(
f"时间组 {len(time_groups)}: {len(current_group)} 个点, " f"时间组 {len(time_groups)}: {len(current_group)} 个点, "
f"时间范围 [{group_start_time} - {group_end_time}]" f"时间范围 [{group_start_time} - {group_end_time}]"
) )
self.logger.info(f"在时间 {break_time} 处发现断点,时间差为 {time_diff}") self.logger.info(
f"在时间 {break_time} 处发现断点,时间差为 {time_diff}")
current_group_start = idx current_group_start = idx
# 添加最后一组 # 添加最后一组
@ -142,11 +144,11 @@ class GPSFilter:
self.logger.info(f"共分为 {len(time_groups)} 个时间组") self.logger.info(f"共分为 {len(time_groups)} 个时间组")
return time_groups return time_groups
def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13, time_interval=60): def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13, time_threshold=timedelta(minutes=5)):
""" """
过滤密集点先按时间分组再在每个时间组内过滤 过滤密集点先按时间分组再在每个时间组内过滤
空时间戳的点不进行过滤 空时间戳的点不进行过滤
Args: Args:
points_df: 点数据 points_df: 点数据
grid_size: 网格大小 grid_size: 网格大小
@ -154,27 +156,28 @@ class GPSFilter:
time_interval: 时间间隔 time_interval: 时间间隔
""" """
self.logger.info(f"开始按时间分组过滤密集点 (网格大小: {grid_size}, " self.logger.info(f"开始按时间分组过滤密集点 (网格大小: {grid_size}, "
f"距离阈值: {distance_threshold}米, 时间间隔: {time_interval}秒)") f"距离阈值: {distance_threshold}米, 分组时间间隔: {time_threshold}秒)")
# 按时间分组 # 按时间分组
time_groups = self._group_by_time(points_df, time_interval) time_groups = self._group_by_time(points_df, time_threshold)
# 存储所有要删除的图片 # 存储所有要删除的图片
all_to_del_imgs = [] all_to_del_imgs = []
# 对每个时间组进行密集点过滤 # 对每个时间组进行密集点过滤
for group_idx, group_points in enumerate(time_groups): for group_idx, group_points in enumerate(time_groups):
# 检查是否为空时间戳组(最后一组) # 检查是否为空时间戳组(最后一组)
if group_idx == len(time_groups) - 1 and group_points['date'].isna().any(): if group_idx == len(time_groups) - 1 and group_points['date'].isna().any():
self.logger.info(f"跳过无时间戳组 (包含 {len(group_points)} 个点)") self.logger.info(f"跳过无时间戳组 (包含 {len(group_points)} 个点)")
continue continue
self.logger.info(f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)") self.logger.info(
f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)")
# 计算该组内的点间距离 # 计算该组内的点间距离
sorted_distances = self._get_distances(group_points, grid_size) sorted_distances = self._get_distances(group_points, grid_size)
group_to_del_imgs = [] group_to_del_imgs = []
# 在每个网格中过滤密集点 # 在每个网格中过滤密集点
for grid, distances in sorted_distances.items(): for grid, distances in sorted_distances.items():
grid_del_count = 0 grid_del_count = 0
@ -182,7 +185,7 @@ class GPSFilter:
candidate_img1, candidate_img2, dist = distances[0] candidate_img1, candidate_img2, dist = distances[0]
if dist < distance_threshold: if dist < distance_threshold:
distances.pop(0) distances.pop(0)
# 获取候选图片的其他最短距离 # 获取候选图片的其他最短距离
candidate_img1_dist = None candidate_img1_dist = None
candidate_img2_dist = None candidate_img2_dist = None
@ -194,52 +197,60 @@ class GPSFilter:
if candidate_img2 in distance: if candidate_img2 in distance:
candidate_img2_dist = distance[2] candidate_img2_dist = distance[2]
break break
# 选择要删除的点 # 选择要删除的点
if candidate_img1_dist and candidate_img2_dist: if candidate_img1_dist and candidate_img2_dist:
to_del_img = candidate_img1 if candidate_img1_dist < candidate_img2_dist else candidate_img2 to_del_img = candidate_img1 if candidate_img1_dist < candidate_img2_dist else candidate_img2
group_to_del_imgs.append(to_del_img) group_to_del_imgs.append(to_del_img)
grid_del_count += 1 grid_del_count += 1
self.logger.debug(f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)") self.logger.debug(
distances = [d for d in distances if to_del_img not in d] f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)")
distances = [
d for d in distances if to_del_img not in d]
else: else:
break break
if grid_del_count > 0: if grid_del_count > 0:
self.logger.info(f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点") self.logger.info(
f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点")
all_to_del_imgs.extend(group_to_del_imgs) all_to_del_imgs.extend(group_to_del_imgs)
self.logger.info(f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点") self.logger.info(
f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点")
# 写入删除日志 # 写入删除日志
with open(self.log_file, 'a', encoding='utf-8') as f: with open(self.log_file, 'a', encoding='utf-8') as f:
for img in all_to_del_imgs: for img in all_to_del_imgs:
f.write(img + '\n') f.write(img + '\n')
# 过滤数据 # 过滤数据
filtered_df = points_df[~points_df['file'].isin(all_to_del_imgs)] filtered_df = points_df[~points_df['file'].isin(all_to_del_imgs)]
self.logger.info(f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点") self.logger.info(
f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点")
return filtered_df return filtered_df
def filter_isolated_points(self, points_df, threshold_distance=0.001, min_neighbors=6):
    """Remove points that have too few neighbours inside a fixed radius.

    A KD-tree is built over the 'lat'/'lon' columns and, for every point,
    the number of tree points within ``threshold_distance`` is counted.
    Points whose count is below ``min_neighbors`` are dropped, and their
    file names are appended to ``self.log_file`` (one per line, followed
    by a blank separator line).

    Args:
        points_df: DataFrame with 'file', 'lat' and 'lon' columns.
        threshold_distance: Search radius, expressed in the same units as
            the 'lat'/'lon' values.
        min_neighbors: Minimum neighbour count a point needs to be kept.
            NOTE(review): assuming ``KDTree`` is ``scipy.spatial.KDTree``,
            the ball query also returns the query point itself, so the
            count includes the point -- confirm.

    Returns:
        A DataFrame containing only the rows that were not flagged as
        isolated.
    """
    self.logger.info(
        f"开始过滤孤立点 (距离阈值: {threshold_distance}, 最小邻居数: {min_neighbors})")

    coords = points_df[['lat', 'lon']].values
    tree = KDTree(coords)

    # One ball query per point; the result length is the neighbour count.
    neighbor_totals = []
    for coord in coords:
        hits = tree.query_ball_point(coord, threshold_distance)
        neighbor_totals.append(len(hits))

    removed_files = []
    with open(self.log_file, 'a', encoding='utf-8') as log:
        for (_, row), total in zip(points_df.iterrows(), neighbor_totals):
            if total >= min_neighbors:
                continue
            name = row['file']
            removed_files.append(name)
            log.write(name + '\n')
            self.logger.debug(
                f"删除孤立点: {name} (邻居数: {total})")
        # Blank line separates this batch from later entries in the log.
        log.write('\n')

    keep_mask = ~points_df['file'].isin(removed_files)
    filtered_df = points_df[keep_mask]
    self.logger.info(
        f"孤立点过滤完成,共删除 {len(removed_files)} 个点,剩余 {len(filtered_df)} 个点")
    return filtered_df

View File

@ -8,7 +8,7 @@ from typing import Optional, Tuple
class ODMProcessMonitor: class ODMProcessMonitor:
"""ODM进程监控器""" """ODM进程监控器"""
def __init__(self, max_retries: int = 3, check_interval: int = 5): def __init__(self, max_retries: int = 3, check_interval: int = 300):
""" """
初始化监控器 初始化监控器
@ -31,7 +31,10 @@ class ODMProcessMonitor:
def _check_success(self, grid_dir: str) -> bool: def _check_success(self, grid_dir: str) -> bool:
"""检查ODM是否执行成功""" """检查ODM是否执行成功"""
# ODM成功完成时会生成这些文件夹 # ODM成功完成时会生成这些文件夹
success_markers = ['odm_orthophoto', 'odm_georeferencing', 'odm_texturing'] if self.mode == "快拼模式":
success_markers = ['odm_orthophoto', 'odm_georeferencing']
else:
success_markers = ['odm_orthophoto', 'odm_georeferencing', 'odm_texturing']
return all(os.path.exists(os.path.join(grid_dir, marker)) for marker in success_markers) return all(os.path.exists(os.path.join(grid_dir, marker)) for marker in success_markers)
def run_odm_with_monitor(self, command: str, grid_dir: str, grid_idx: int) -> Tuple[bool, str]: def run_odm_with_monitor(self, command: str, grid_dir: str, grid_idx: int) -> Tuple[bool, str]: