ODM_pro/odm_preprocess.py
2024-12-21 12:03:54 +08:00

244 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
from datetime import timedelta
from dataclasses import dataclass
from typing import Dict
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from preprocess.cluster import GPSCluster
from preprocess.command_runner import CommandRunner
from preprocess.gps_extractor import GPSExtractor
from preprocess.gps_filter import GPSFilter
from preprocess.grid_divider import GridDivider
from preprocess.logger import setup_logger
@dataclass
class PreprocessConfig:
    """Settings for the image preprocessing pipeline.

    ``image_dir`` and ``output_dir`` are required; every other field has a
    default and can be overridden by keyword.
    """

    image_dir: str
    output_dir: str
    # Cluster filtering parameters
    cluster_eps: float = 0.01
    cluster_min_samples: int = 5
    # Isolated-point filtering parameters
    filter_distance_threshold: float = 0.001  # distance in latitude/longitude terms
    filter_min_neighbors: int = 6
    # Dense-point filtering parameters
    filter_grid_size: float = 0.001
    filter_dense_distance_threshold: float = 10  # ordinary distance, unit: metres
    filter_time_threshold: timedelta = timedelta(minutes=5)
    # Grid division parameters
    grid_overlap: float = 0.05
    grid_size: float = 500
    # Switches controlling which pipeline stages run
    enable_filter: bool = True
    enable_grid_division: bool = True
    enable_visualization: bool = True
    enable_copy_images: bool = True
    # Mode string; the value is runtime data and is kept verbatim
    mode: str = "快拼模式"
class ImagePreprocessor:
    """Run the GPS-based preprocessing pipeline over a directory of images.

    ``process`` runs the stages in this order: extract GPS data, cluster,
    filter, divide into grids, copy images, visualize, and finally hand the
    grids to the command runner. The working set of points is kept in
    ``self.gps_points`` and is replaced by each stage.
    """

    def __init__(self, config: "PreprocessConfig"):
        """Store the config and create the logger and command runner.

        Args:
            config: Pipeline settings; ``config.output_dir`` is passed to both
                the logger setup and the command runner.
        """
        self.config = config
        self.logger = setup_logger(config.output_dir)
        self.gps_points = []
        self.command_runner = CommandRunner(config.output_dir)

    def extract_gps(self) -> pd.DataFrame:
        """Extract GPS data for the images in ``config.image_dir``.

        Returns:
            The extracted points, which are also stored in ``self.gps_points``.
        """
        self.logger.info("开始提取GPS数据")
        extractor = GPSExtractor(self.config.image_dir)
        self.gps_points = extractor.extract_all_gps()
        self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def cluster(self) -> pd.DataFrame:
        """Cluster the GPS points and keep only the main cluster.

        The clustering itself is delegated to ``GPSCluster`` (DBSCAN, per the
        original description of this step).

        Returns:
            The points of the main cluster, also stored in ``self.gps_points``.
        """
        self.logger.info("开始聚类")
        # Build the clusterer and run the clustering
        clusterer = GPSCluster(
            self.gps_points,
            output_dir=self.config.output_dir,
            eps=self.config.cluster_eps,
            min_samples=self.config.cluster_min_samples,
        )
        # Keep only the points belonging to the main cluster
        self.clustered_points = clusterer.fit()
        self.gps_points = clusterer.get_main_cluster(self.clustered_points)
        # Fetch the statistics and log them
        stats = clusterer.get_cluster_stats(self.clustered_points)
        self.logger.info(
            f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
            f"噪声点 {stats['noise_points']}"
        )
        # The signature promises the retained points; previously nothing was returned
        return self.gps_points

    def filter_points(self) -> pd.DataFrame:
        """Filter isolated points, then dense points.

        Does nothing when ``config.enable_filter`` is false.

        Returns:
            The remaining points, also stored in ``self.gps_points``.
        """
        if not self.config.enable_filter:
            return self.gps_points
        self.logger.info("开始过滤GPS点")
        # Named gps_filter so the builtin ``filter`` is not shadowed
        gps_filter = GPSFilter(self.config.output_dir)
        self.logger.info(
            f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})"
        )
        self.gps_points = gps_filter.filter_isolated_points(
            self.gps_points,
            self.config.filter_distance_threshold,
            self.config.filter_min_neighbors,
        )
        self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")
        self.logger.info(
            f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})"
        )
        self.gps_points = gps_filter.filter_dense_points(
            self.gps_points,
            grid_size=self.config.filter_grid_size,
            distance_threshold=self.config.filter_dense_distance_threshold,
            time_threshold=self.config.filter_time_threshold,
        )
        self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def divide_grids(self) -> Dict[int, pd.DataFrame]:
        """Split the current points into grids.

        Returns:
            A mapping from grid index to that grid's points. When
            ``config.enable_grid_division`` is false, all points are placed
            under the single key ``0``.
        """
        if not self.config.enable_grid_division:
            # Without grid division, all points go into one grid
            return {0: self.gps_points}
        self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
        grid_divider = GridDivider(overlap=self.config.grid_overlap)
        grids = grid_divider.divide_grids(
            self.gps_points, grid_size=self.config.grid_size
        )
        grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
        self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
        return grid_points

    def copy_images(self, grid_points: Dict[int, pd.DataFrame]):
        """Copy each grid's images into its target folder.

        With grid division enabled, the target is
        ``<output_dir>/grid_<n>/project/images`` (``n`` is the grid index plus
        one); otherwise it is ``<output_dir>/project/images``. Does nothing
        when ``config.enable_copy_images`` is false.

        Args:
            grid_points: Mapping from grid index to that grid's points. Each
                point must expose a ``"file"`` entry holding the image file
                name relative to ``config.image_dir``.
        """
        if not self.config.enable_copy_images:
            return
        self.logger.info("开始复制图像文件")
        for grid_idx, points in grid_points.items():
            if self.config.enable_grid_division:
                output_dir = os.path.join(
                    self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
                )
            else:
                output_dir = os.path.join(
                    self.config.output_dir, "project", "images"
                )
            os.makedirs(output_dir, exist_ok=True)
            # Iterating a DataFrame yields column labels rather than rows, so
            # turn it into per-row dicts; other iterables are used as-is.
            if isinstance(points, pd.DataFrame):
                records = points.to_dict("records")
            else:
                records = points
            for point in tqdm(records, desc=f"复制网格 {grid_idx + 1} 的图像"):
                src = os.path.join(self.config.image_dir, point["file"])
                dst = os.path.join(output_dir, point["file"])
                shutil.copy(src, dst)
            self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")

    def visualize_results(self):
        """Plot all original GPS points and highlight the filtered ones.

        GPS data is extracted again from ``config.image_dir`` to get the
        unfiltered set. File names listed in ``<output_dir>/del_imgs.txt`` are
        drawn in red over the blue originals, and the figure is saved as
        ``<output_dir>/filter_GPS.png``. Does nothing when
        ``config.enable_visualization`` is false.
        """
        if not self.config.enable_visualization:
            return
        self.logger.info("开始生成可视化结果")
        extractor = GPSExtractor(self.config.image_dir)
        original_points_df = extractor.extract_all_gps()
        # Read the list of filtered images; a missing list means there is
        # nothing to highlight, which should not abort the whole pipeline.
        del_list_path = os.path.join(self.config.output_dir, "del_imgs.txt")
        try:
            with open(del_list_path, "r", encoding="utf-8") as file:
                filtered_files = [line.strip() for line in file if line.strip()]
        except FileNotFoundError:
            self.logger.warning("未找到 del_imgs.txt,将不绘制被过滤的点")
            filtered_files = []
        # Create a new figure
        plt.figure(figsize=(20, 16))
        try:
            # Draw all original points
            plt.scatter(
                original_points_df["lon"],
                original_points_df["lat"],
                color="blue",
                label="Original Points",
                alpha=0.6,
            )
            # Draw the filtered points
            filtered_points_df = original_points_df[
                original_points_df["file"].isin(filtered_files)
            ]
            plt.scatter(
                filtered_points_df["lon"],
                filtered_points_df["lat"],
                color="red",
                label="Filtered Points",
                alpha=0.6,
            )
            # Configure the figure
            plt.title("GPS Coordinates of Images", fontsize=14)
            plt.xlabel("Longitude", fontsize=12)
            plt.ylabel("Latitude", fontsize=12)
            plt.grid(True)
            plt.legend()
            # Save the figure
            plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png"))
        finally:
            # Release the figure even if plotting or saving failed
            plt.close()
        self.logger.info("预处理结果图已保存")

    def process(self):
        """Run the complete preprocessing pipeline.

        Raises:
            Exception: Any error from a stage is logged with its traceback and
                re-raised unchanged.
        """
        try:
            self.extract_gps()
            self.cluster()
            self.filter_points()
            grid_points = self.divide_grids()
            self.copy_images(grid_points)
            self.visualize_results()
            self.logger.info("预处理任务完成")
            self.command_runner.run_grid_commands(
                grid_points,
                self.config.enable_grid_division,
                # The mode lives on the config; this object has no ``mode`` attribute
                self.config.mode,
            )
        except Exception as e:
            self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
            raise
if __name__ == "__main__":
    # Gather the run settings first, then build the config from them in one step
    settings = {
        "image_dir": r"E:\datasets\UAV\1815\images",
        "output_dir": r"test",
        "cluster_eps": 0.01,
        "cluster_min_samples": 5,
        "filter_distance_threshold": 0.001,
        "filter_min_neighbors": 6,
        "filter_grid_size": 0.001,
        "filter_dense_distance_threshold": 10,
        "filter_time_threshold": timedelta(minutes=5),
        "grid_overlap": 0.05,
        "grid_size": 500,
        "enable_filter": True,
        "enable_grid_division": True,
        "enable_visualization": True,
        "enable_copy_images": True,
    }
    # Create the preprocessor and run the whole pipeline
    ImagePreprocessor(PreprocessConfig(**settings)).process()