# ODM_pro/odm_preprocess.py
# 2024-12-22 20:19:12 +08:00
#
# 179 lines
# 5.9 KiB
# Python

import os
import shutil
from datetime import timedelta
from dataclasses import dataclass
from typing import Dict
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from filter.cluster_filter import GPSCluster
from utils.command_runner import CommandRunner
from utils.gps_extractor import GPSExtractor
from filter.gps_filter import GPSFilter
from utils.grid_divider import GridDivider
from utils.logger import setup_logger
from filter.time_group_overlap_filter import TimeGroupOverlapFilter
from utils.visualizer import FilterVisualizer
@dataclass
class PreprocessConfig:
    """Settings for the image preprocessing pipeline.

    ``image_dir`` and ``output_dir`` are required; every other field has a
    default and may be overridden positionally (in declaration order) or by
    keyword.
    """

    image_dir: str
    output_dir: str

    # --- Cluster filtering parameters ---
    cluster_eps: float = 0.01
    cluster_min_samples: int = 5

    # --- Time-group overlap filtering parameters ---
    time_group_overlap_threshold: float = 0.7
    time_group_interval: timedelta = timedelta(minutes=5)
    enable_time_group_filter: bool = True

    # --- Isolated-point filtering parameters ---
    filter_distance_threshold: float = 0.001  # latitude/longitude distance
    filter_min_neighbors: int = 6

    # --- Dense-point filtering parameters ---
    filter_grid_size: float = 0.001
    filter_dense_distance_threshold: float = 10  # ordinary distance, in meters
    filter_time_threshold: timedelta = timedelta(minutes=5)

    # --- Grid division parameters ---
    grid_overlap: float = 0.05
    grid_size: float = 500

    # --- Switches for the individual pipeline stages ---
    enable_filter: bool = True
    enable_grid_division: bool = True
    enable_visualization: bool = True
    enable_copy_images: bool = True

    # Processing mode label; the default is a runtime string and must be kept
    # exactly as written.
    mode: str = "快拼模式"
class ImagePreprocessor:
    """Run the GPS-based image preprocessing steps described by a config.

    Each step replaces ``self.gps_points`` with the points that survive it.
    After every step a copy of the current points is appended to
    ``points_history`` and its label to ``step_names`` so that all steps can
    be visualized together at the end of ``process``.
    """

    def __init__(self, config: "PreprocessConfig"):
        """Create the logger, command runner and visualizer for ``config``.

        Args:
            config: Pipeline settings. ``output_dir`` and ``mode`` are read
                here; the remaining fields are read by the individual steps.
        """
        self.config = config
        self.logger = setup_logger(config.output_dir)
        self.gps_points = None
        # Filled in by cluster(); declared here so the attribute always exists.
        self.clustered_points = None
        self.command_runner = CommandRunner(config.output_dir, mode=config.mode)
        self.visualizer = FilterVisualizer(config.output_dir)
        # Per-step snapshots of the points and their labels (parallel lists).
        self.points_history = []
        self.step_names = []

    def _record_step(self, step_name, previous_points=None):
        """Optionally visualize a step, then snapshot the current points.

        Args:
            step_name: Label appended to ``step_names`` and passed to the
                visualizer.
            previous_points: The points as they were before the step ran.
                When given and ``config.enable_visualization`` is true, the
                current/previous pair is handed to the visualizer.
        """
        if previous_points is not None and self.config.enable_visualization:
            self.visualizer.visualize_filter_step(
                self.gps_points, previous_points, step_name)
        # Store a copy so later steps cannot alter the recorded snapshot.
        self.points_history.append(self.gps_points.copy())
        self.step_names.append(step_name)

    def extract_gps(self) -> pd.DataFrame:
        """Extract GPS data for the images in ``config.image_dir``.

        Returns:
            The extracted points, also stored in ``self.gps_points``.
        """
        self.logger.info("开始提取GPS数据")
        extractor = GPSExtractor(self.config.image_dir)
        self.gps_points = extractor.extract_all_gps()
        self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
        # Record the initial state; there is no "before" to visualize.
        self._record_step("Initial")
        return self.gps_points

    def cluster(self) -> pd.DataFrame:
        """Cluster the GPS points and keep only the main cluster.

        The original author's note says the clustering is DBSCAN; the
        ``cluster_eps`` and ``cluster_min_samples`` settings are forwarded to
        ``GPSCluster`` as ``eps`` and ``min_samples``.

        Returns:
            The points returned by ``GPSCluster.get_main_cluster``, also
            stored in ``self.gps_points``.
        """
        self.logger.info("开始聚类")
        before = self.gps_points.copy()
        clusterer = GPSCluster(
            self.gps_points,
            output_dir=self.config.output_dir,
            eps=self.config.cluster_eps,
            min_samples=self.config.cluster_min_samples,
        )
        self.clustered_points = clusterer.fit()
        self.gps_points = clusterer.get_main_cluster(self.clustered_points)
        self._record_step("Clustering", before)
        return self.gps_points

    def filter_time_group_overlap(self) -> pd.DataFrame:
        """Drop points whose file was removed by the time-group overlap filter.

        Does nothing (and records no step) when
        ``config.enable_time_group_filter`` is false.

        Returns:
            The remaining points, also stored in ``self.gps_points``.
        """
        if not self.config.enable_time_group_filter:
            return self.gps_points

        self.logger.info("开始过滤重叠时间组")
        before = self.gps_points.copy()
        # Named overlap_filter to avoid shadowing the builtin ``filter``.
        overlap_filter = TimeGroupOverlapFilter(
            self.config.image_dir,
            self.config.output_dir,
            overlap_threshold=self.config.time_group_overlap_threshold,
        )
        deleted_files = overlap_filter.filter_overlapping_groups(
            time_threshold=self.config.time_group_interval
        )
        # Keep only rows whose 'file' value is not among the reported files.
        self.gps_points = self.gps_points[
            ~self.gps_points['file'].isin(deleted_files)]
        self._record_step("Time Group Overlap", before)
        return self.gps_points

    def process(self):
        """Run extraction, clustering and time-group filtering in order.

        When visualization is enabled, all recorded steps are visualized
        together at the end.

        Raises:
            Exception: Any error raised by a step is logged with its
                traceback and then re-raised unchanged.
        """
        try:
            self.extract_gps()
            self.cluster()
            self.filter_time_group_overlap()
            # Produce the combined visualization once every step has run.
            if self.config.enable_visualization:
                self.visualizer.visualize_all_steps(
                    self.points_history, self.step_names)
            self.logger.info("预处理任务完成")
        except Exception as e:
            self.logger.error(f"处理过程中发生错误: {e}", exc_info=True)
            raise
if __name__ == "__main__":
    # Settings for this run, listed in the same order as the config fields.
    run_settings = {
        "image_dir": r"F:\error_data\20241016140912\code\images",
        "output_dir": r"G:\output",
        "cluster_eps": 0.01,
        "cluster_min_samples": 5,
        # Time-group overlap filtering parameters
        "time_group_overlap_threshold": 0.7,
        "time_group_interval": timedelta(minutes=5),
        "enable_time_group_filter": True,
        "filter_distance_threshold": 0.001,
        "filter_min_neighbors": 6,
        "filter_grid_size": 0.001,
        "filter_dense_distance_threshold": 10,
        "filter_time_threshold": timedelta(minutes=5),
        "grid_overlap": 0.03,
        "grid_size": 1000,
        "enable_filter": True,
        "enable_grid_division": True,
        "enable_visualization": True,
        "enable_copy_images": True,
        "mode": "快拼模式",
    }
    # Build the preprocessor from the settings and run the whole pipeline.
    ImagePreprocessor(PreprocessConfig(**run_settings)).process()