UAV_odm_merge/odm_preprocess.py
2024-12-30 23:08:11 +08:00

347 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
from datetime import timedelta
from dataclasses import dataclass
from typing import Dict, Optional
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from filter.cluster_filter import GPSCluster
from filter.time_group_overlap_filter import TimeGroupOverlapFilter
from filter.gps_filter import GPSFilter
from utils.odm_monitor import ODMProcessMonitor
from utils.gps_extractor import GPSExtractor
from utils.grid_divider import GridDivider
from utils.logger import setup_logger
from utils.visualizer import FilterVisualizer
from post_pro.merge_tif import MergeTif
from tools.test_docker_run import run_docker_command
from post_pro.merge_obj import MergeObj
from post_pro.merge_ply import MergePly
@dataclass
class PreprocessConfig:
    """Settings consumed by the image preprocessing pipeline.

    Only ``image_dir`` is required; every other field has a default.
    Note that ``ImagePreprocessor.__init__`` overwrites ``output_dir`` with a
    ``preprocess`` directory derived from ``image_dir``.
    """

    # --- input / output locations ---
    image_dir: str
    output_dir: Optional[str] = None

    # --- clustering filter ---
    cluster_eps: float = 0.01
    cluster_min_samples: int = 5

    # --- time-group overlap filter ---
    time_group_overlap_threshold: float = 0.7
    time_group_interval: timedelta = timedelta(minutes=5)

    # --- isolated-point filter ---
    # Threshold expressed as a latitude/longitude distance.
    filter_distance_threshold: float = 0.001
    filter_min_neighbors: int = 6

    # --- dense-point filter ---
    filter_grid_size: float = 0.001
    # Plain (metric) distance, in metres.
    filter_dense_distance_threshold: float = 10
    filter_time_threshold: timedelta = timedelta(minutes=5)

    # --- grid division ---
    grid_overlap: float = 0.05
    grid_size: float = 500

    # --- pipeline switches ---
    fast_mode: bool = False
class ImagePreprocessor:
    """Filter a set of images by GPS position, group them into grids and run ODM.

    Directory layout used by this class, relative to ``image_dir``::

        <grandpa_dir>/
            <project dir>/          parent of ``image_dir``
                images/             ``image_dir`` itself
                image_groups.txt    written by :meth:`divide_grids`
            preprocess/             output directory (wiped on every start)
            ret/                    holding area for filtered-out images

    Images rejected by a filter step are moved into ``ret`` and are moved back
    into ``image_dir`` at the end of :meth:`process`.
    """

    def __init__(self, config: "PreprocessConfig"):
        """Prepare the output directory and the helper components.

        Args:
            config: Pipeline settings. ``config.output_dir`` is overwritten
                with ``<grandpa_dir>/preprocess``.

        Raises:
            Exception: Whatever ``shutil.rmtree`` / ``os.makedirs`` raise while
                the output directory is being recreated.
        """
        self.config = config
        # Normalise before walking up: with a trailing separator (or a bare
        # relative name) os.path.dirname would otherwise stop one level short.
        self.grandpa_dir = os.path.dirname(
            os.path.dirname(os.path.abspath(self.config.image_dir)))
        self.config.output_dir = os.path.join(self.grandpa_dir, 'preprocess')

        # Start from an empty output directory on every run.
        if os.path.exists(config.output_dir):
            self._clean_output_dir()
        self._setup_output_dirs()

        # The logger needs the output directory, so it is created after it.
        self.logger = setup_logger(config.output_dir)
        self.gps_points = None
        self.clustered_points = None
        self.odm_monitor = ODMProcessMonitor(
            config.output_dir, fast_mode=config.fast_mode)
        self.visualizer = FilterVisualizer(config.output_dir)

    def _clean_output_dir(self):
        """Delete the whole output directory tree.

        Uses ``print`` because the logger does not exist yet at this point.

        Raises:
            Exception: Re-raised from ``shutil.rmtree`` after being reported.
        """
        try:
            shutil.rmtree(self.config.output_dir)
            print(f"已清理输出目录: {self.config.output_dir}")
        except Exception as e:
            print(f"清理输出目录时发生错误: {str(e)}")
            raise

    def _setup_output_dirs(self):
        """Create the output directory and its fixed sub-directories.

        Raises:
            Exception: Re-raised from ``os.makedirs`` after being reported
                (for example when the directory already exists).
        """
        try:
            # Main output directory.
            os.makedirs(self.config.output_dir)
            # Directory for the filter visualisation images.
            os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs_visual'))
            # Directory for log files.
            os.makedirs(os.path.join(self.config.output_dir, 'logs'))
            print(f"已创建输出目录结构: {self.config.output_dir}")
        except Exception as e:
            print(f"创建输出目录时发生错误: {str(e)}")
            raise

    def extract_gps(self) -> pd.DataFrame:
        """Extract GPS data for the images in ``image_dir``.

        Returns:
            The result of ``GPSExtractor.extract_all_gps()``, which is also
            stored in ``self.gps_points``.
        """
        self.logger.info("开始提取GPS数据")
        extractor = GPSExtractor(self.config.image_dir)
        self.gps_points = extractor.extract_all_gps()
        self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def cluster(self, previous_points) -> pd.DataFrame:
        """Cluster the GPS points and keep the main cluster.

        Per the original author's note, ``GPSCluster`` applies DBSCAN and only
        the largest cluster is kept. The removed points are visualised and
        their images are moved out of ``image_dir``.

        Args:
            previous_points: GPS points to cluster.

        Returns:
            The retained points as reported by ``GPSCluster.get_cluster_stats``.
        """
        self.logger.info("开始聚类")
        clusterer = GPSCluster(
            previous_points, output_dir=self.config.output_dir,
            eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples)

        self.clustered_points = clusterer.fit()

        # Statistics plus the retained / removed split.
        stats, retained_points, removed_points = clusterer.get_cluster_stats(
            self.clustered_points)
        self.logger.info(
            f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
            f"噪声点 {stats['noise_points']}"
        )

        self.visualizer.visualize_filter_step(
            retained_points, removed_points, "1-Clustering")
        self.move_images(removed_points, "cluster")
        return retained_points

    def filter_time_group_overlap(self, previous_points) -> pd.DataFrame:
        """Drop the images of overlapping time groups.

        Args:
            previous_points: GPS points with a ``file`` column.

        Returns:
            The rows of ``previous_points`` whose ``file`` was not reported as
            deleted by ``TimeGroupOverlapFilter``.
        """
        self.logger.info("开始过滤重叠时间组")

        overlap_filter = TimeGroupOverlapFilter(
            self.config.image_dir,
            self.config.output_dir,
            overlap_threshold=self.config.time_group_overlap_threshold
        )
        deleted_files = overlap_filter.filter_overlapping_groups(
            time_threshold=self.config.time_group_interval
        )

        # Split once on membership in the deleted-file list.
        removed_mask = previous_points['file'].isin(deleted_files)
        retained_points = previous_points[~removed_mask]
        removed_points = previous_points[removed_mask]
        self.logger.info(f"重叠时间组过滤后剩余 {len(retained_points)} 个GPS点")

        self.visualizer.visualize_filter_step(
            retained_points, removed_points, "2-Time Group Overlap")
        self.move_images(removed_points, "time_group_overlap")
        return retained_points

    # TODO: the filtering algorithm still needs to be updated.
    def filter_points(self, previous_points) -> pd.DataFrame:
        """Remove isolated GPS points.

        Only the isolated-point filter runs. A dense-point pass
        (``GPSFilter.filter_dense_points`` driven by ``filter_grid_size``,
        ``filter_dense_distance_threshold`` and ``filter_time_threshold``) was
        present in the original code but is disabled.

        Args:
            previous_points: GPS points to filter.

        Returns:
            The points retained by ``GPSFilter.filter_isolated_points``.
        """
        self.logger.info("开始过滤GPS点")
        gps_filter = GPSFilter(self.config.output_dir)

        self.logger.info(
            f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, "
            f"最小邻居数: {self.config.filter_min_neighbors})"
        )
        retained_points, removed_points = gps_filter.filter_isolated_points(
            previous_points,
            self.config.filter_distance_threshold,
            self.config.filter_min_neighbors,
        )
        self.logger.info(f"孤立点过滤后剩余 {len(retained_points)} 个GPS点")

        self.visualizer.visualize_filter_step(
            retained_points, removed_points, "3-Isolated Points")
        self.move_images(removed_points, "isolated_points")
        return retained_points

    @staticmethod
    def _group_label(index: int) -> str:
        """Return the spreadsheet-style label for a zero-based grid index.

        ``0 -> 'A'``, ``25 -> 'Z'``, ``26 -> 'AA'``, ``27 -> 'AB'`` and so on.
        Indices 0-25 give the same letters as ``chr(65 + index)``; larger
        indices stay alphabetic instead of running into punctuation, control
        or whitespace characters.

        Raises:
            ValueError: If ``index`` is negative.
        """
        if index < 0:
            raise ValueError(f"grid index must be non-negative, got {index}")
        label = ""
        remaining = index + 1
        while remaining > 0:
            remaining, offset = divmod(remaining - 1, 26)
            label = chr(65 + offset) + label
        return label

    def divide_grids(self) -> Dict[int, pd.DataFrame]:
        """Assign the current GPS points to grids and write ``image_groups.txt``.

        The file is written next to ``image_dir`` (in its parent directory),
        one ``<file> <group label>`` line per image.

        Returns:
            The mapping returned by ``GridDivider.assign_to_grids`` (grid index
            to the points of that grid). NOTE(review): the values are iterated
            as sequences of records with a ``file`` key, so they may be lists
            rather than DataFrames; confirm against ``GridDivider``.

        Raises:
            Exception: Re-raised after logging if the groups file cannot be
                written.
        """
        self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
        grid_divider = GridDivider(
            overlap=self.config.grid_overlap,
            output_dir=self.config.output_dir
        )
        grids = grid_divider.divide_grids(
            self.gps_points, grid_size=self.config.grid_size
        )
        grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
        self.logger.info(f"成功划分为 {len(grid_points)} 个网格")

        try:
            # Normalise first so a trailing separator on image_dir does not
            # place the file inside the image directory itself.
            groups_file = os.path.join(
                os.path.dirname(os.path.abspath(self.config.image_dir)),
                "image_groups.txt"
            )
            self.logger.info(f"开始生成分组文件: {groups_file}")
            # Explicit UTF-8 so non-ASCII file names do not depend on the
            # platform default encoding.
            with open(groups_file, 'w', encoding='utf-8') as f:
                for grid_idx, points_lt in grid_points.items():
                    group_letter = self._group_label(grid_idx)
                    for point in points_lt:
                        f.write(f"{point['file']} {group_letter}\n")
            self.logger.info(f"分组文件生成成功: {groups_file}")
        except Exception as e:
            self.logger.error(f"生成分组文件时发生错误: {str(e)}", exc_info=True)
            raise
        return grid_points

    def move_images(self, removed_points: pd.DataFrame, step_name: str):
        """Move filtered-out images from ``image_dir`` into ``<grandpa_dir>/ret``.

        All steps share the same flat ``ret`` directory, which is what
        :meth:`restore_filtered_images` expects. A failure to move a single
        image is logged as a warning and does not stop the loop.

        Args:
            removed_points: Rows for the rejected images; must have a ``file``
                column. Nothing happens when it is empty.
            step_name: Name of the filter step, used only in log messages.
        """
        if removed_points.empty:
            return

        ret_dir = os.path.join(self.grandpa_dir, 'ret')
        os.makedirs(ret_dir, exist_ok=True)
        self.logger.info(f"开始移动{step_name}步骤中被过滤的图片")

        moved_count = 0
        for file_name in removed_points['file']:
            src_path = os.path.join(self.config.image_dir, file_name)
            dst_path = os.path.join(ret_dir, file_name)
            try:
                shutil.move(src_path, dst_path)
                moved_count += 1
            except Exception as e:
                self.logger.warning(f"移动图片 {file_name} 时发生错误: {str(e)}")

        # Report the images actually moved, not the number attempted.
        self.logger.info(f"完成移动 {moved_count} 张被{step_name}过滤的图片")

    def restore_filtered_images(self):
        """Move every entry of ``<grandpa_dir>/ret`` back into ``image_dir``.

        Does nothing (apart from a log line) when ``ret`` does not exist. A
        failure to move a single entry is logged as a warning and skipped.

        Raises:
            Exception: Re-raised after logging if listing ``ret`` fails.
        """
        try:
            ret_dir = os.path.join(self.grandpa_dir, 'ret')
            if not os.path.exists(ret_dir):
                self.logger.info("没有找到ret文件夹跳过恢复步骤")
                return

            self.logger.info("开始恢复被过滤的图片")
            restored_count = 0
            for img in os.listdir(ret_dir):
                src_path = os.path.join(ret_dir, img)
                dst_path = os.path.join(self.config.image_dir, img)
                try:
                    shutil.move(src_path, dst_path)
                    restored_count += 1
                except Exception as e:
                    self.logger.warning(f"恢复图片 {img} 时发生错误: {str(e)}")

            # Report the images actually restored, not the number attempted.
            self.logger.info(f"成功恢复 {restored_count} 张图片")
        except Exception as e:
            self.logger.error(f"恢复图片过程中发生错误: {str(e)}", exc_info=True)
            raise

    def process(self):
        """Run the full pipeline: extract, filter, divide into grids, run ODM.

        The time-group overlap filter is currently not part of the pipeline.
        Filtered-out images are moved back into ``image_dir`` whether or not
        the run succeeds, so a failure does not leave the image set depleted.

        Raises:
            Exception: Any error raised by a pipeline step, re-raised after
                being logged.
        """
        try:
            self.extract_gps()
            self.gps_points = self.cluster(self.gps_points)
            # self.gps_points = self.filter_time_group_overlap(self.gps_points)
            self.gps_points = self.filter_points(self.gps_points)
            self.divide_grids()
            self.logger.info("预处理任务完成")

            self.odm_monitor.run_odm_with_monitor(
                self.grandpa_dir, self.config.fast_mode)
        except Exception as e:
            self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
            raise
        finally:
            # Always give the user their original image set back.
            self.restore_filtered_images()
if __name__ == "__main__":
    # Script entry point: build one configuration and run the whole pipeline.
    # Settings are grouped by the pipeline stage they belong to.
    clustering = dict(cluster_eps=0.01, cluster_min_samples=5)
    # Time-group overlap filter settings.
    time_groups = dict(
        time_group_overlap_threshold=0.7,
        time_group_interval=timedelta(minutes=5),
    )
    point_filter = dict(
        filter_distance_threshold=0.001,
        filter_min_neighbors=6,
        filter_grid_size=0.001,
        filter_dense_distance_threshold=10,
        filter_time_threshold=timedelta(minutes=5),
    )
    grid = dict(grid_size=1000, grid_overlap=0.05)

    config = PreprocessConfig(
        image_dir=r"G:\error_data\20241104140457\project\images",
        **clustering,
        **time_groups,
        **point_filter,
        **grid,
        fast_mode=False,
    )

    # Create the preprocessor and execute it.
    processor = ImagePreprocessor(config)
    processor.process()