# ODM_pro/odm_preprocess.py
# 2024-12-20 21:10:22 +08:00
#
# 236 lines
# 8.4 KiB
# Python
# Raw Blame History

# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import shutil
from dataclasses import dataclass
from typing import Dict
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from preprocess.cluster import GPSCluster
from preprocess.command_runner import CommandRunner
from preprocess.gps_extractor import GPSExtractor
from preprocess.gps_filter import GPSFilter
from preprocess.grid_divider import GridDivider
from preprocess.logger import setup_logger
from preprocess.time_filter import TimeFilter
@dataclass
class PreprocessConfig:
    """Settings for one preprocessing run.

    Only ``image_dir`` and ``output_dir`` are required; every other field has
    a default. Field order is part of the interface (positional construction).
    """

    # --- Locations -------------------------------------------------------
    # Directory the source images are read from.
    image_dir: str
    # Root directory for logs, filter output and copied images.
    output_dir: str

    # --- Clustering ------------------------------------------------------
    # NOTE(review): presumably DBSCAN parameters; ImagePreprocessor.cluster()
    # does not pass them to GPSCluster as written - confirm intended use.
    eps: float = 0.01
    min_samples: int = 5

    # --- Point filtering -------------------------------------------------
    # Forwarded to GPSFilter.filter_dense_points (grid_size / distance_threshold).
    filter_grid_size: float = 0.001
    filter_dense_distance_threshold: float = 10
    # Forwarded to GPSFilter.filter_isolated_points.
    filter_distance_threshold: float = 0.001
    filter_min_neighbors: int = 6

    # --- Grid division ---------------------------------------------------
    # Forwarded to GridDivider(overlap=...) and divide_grids(grid_size=...).
    # NOTE(review): units are not visible from this file - confirm.
    grid_overlap: float = 0.05
    grid_size: float = 500

    # --- Stage toggles ---------------------------------------------------
    enable_filter: bool = True
    enable_grid_division: bool = True
    enable_visualization: bool = True
    enable_copy_images: bool = True
class ImagePreprocessor:
    """Run the image preprocessing pipeline described by a ``PreprocessConfig``.

    The stages extract GPS points for the images in ``config.image_dir``,
    reduce them (clustering plus optional filters), optionally split them into
    grids, copy the remaining images below ``config.output_dir`` and plot the
    original versus removed points. ``self.gps_points`` holds the current
    working set and is replaced by each stage.
    """

    def __init__(self, config: "PreprocessConfig"):
        """Keep *config* and create the logger and command runner for it."""
        self.config = config
        self.logger = setup_logger(config.output_dir)
        self.gps_points = []
        self.command_runner = CommandRunner(config.output_dir)

    def extract_gps(self) -> pd.DataFrame:
        """Extract GPS data for every image and store it as the working set."""
        self.logger.info("开始提取GPS数据")
        extractor = GPSExtractor(self.config.image_dir)
        self.gps_points = extractor.extract_all_gps()
        self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def time_filter(self) -> pd.DataFrame:
        """Apply ``TimeFilter.filter_by_date`` to the working set."""
        self.logger.info("开始时间过滤")
        time_filter = TimeFilter(self.config.output_dir)
        self.gps_points = time_filter.filter_by_date(self.gps_points)
        self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    # TODO: add clustering parameters
    def cluster(self) -> pd.DataFrame:
        """Cluster the GPS points and keep only the main cluster.

        The statistics returned by ``GPSCluster.get_cluster_stats`` are
        logged (main cluster size and number of noise points).
        """
        self.logger.info("开始聚类")
        # Build the clusterer; it receives the points and the output directory.
        clusterer = GPSCluster(self.gps_points, output_dir=self.config.output_dir)
        # Keep only the points of the main cluster.
        self.gps_points = clusterer.get_main_cluster()
        # Fetch and log the clustering statistics.
        stats = clusterer.get_cluster_stats()
        self.logger.info(
            f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
            f"噪声点 {stats['noise_points']}"
        )
        return self.gps_points

    # TODO: the dense-point filtering algorithm needs improvement
    def filter_points(self) -> pd.DataFrame:
        """Drop isolated points, then dense points, from the working set.

        Returns the working set unchanged when ``config.enable_filter`` is
        false.
        """
        if not self.config.enable_filter:
            return self.gps_points
        self.logger.info("开始过滤GPS点")
        # Named gps_filter so the builtin ``filter`` is not shadowed.
        gps_filter = GPSFilter(self.config.output_dir)
        self.logger.info(
            f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})"
        )
        self.gps_points = gps_filter.filter_isolated_points(
            self.gps_points,
            self.config.filter_distance_threshold,
            self.config.filter_min_neighbors,
        )
        self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")
        self.logger.info(
            f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})"
        )
        self.gps_points = gps_filter.filter_dense_points(
            self.gps_points,
            grid_size=self.config.filter_grid_size,
            distance_threshold=self.config.filter_dense_distance_threshold,
        )
        self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def divide_grids(self) -> Dict[int, pd.DataFrame]:
        """Split the working set into grids keyed by grid index.

        When ``config.enable_grid_division`` is false, all points are
        returned as the single grid ``0``.
        """
        if not self.config.enable_grid_division:
            # Without grid division every point belongs to one grid.
            return {0: self.gps_points}
        self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
        grid_divider = GridDivider(overlap=self.config.grid_overlap)
        grids = grid_divider.divide_grids(
            self.gps_points, grid_size=self.config.grid_size
        )
        grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
        self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
        return grid_points

    @staticmethod
    def _file_names(points) -> list:
        """Return the ``"file"`` value of every point in one grid.

        Accepts a DataFrame with a ``"file"`` column or any iterable of
        records indexable by ``"file"``. Iterating a DataFrame directly
        yields column labels rather than rows, so it is handled separately.
        """
        if isinstance(points, pd.DataFrame):
            return points["file"].tolist()
        return [point["file"] for point in points]

    def copy_images(self, grid_points: Dict[int, pd.DataFrame]):
        """Copy each grid's images into its ``project/images`` directory.

        With grid division enabled the target is
        ``<output_dir>/grid_<n>/project/images`` (``n`` is the grid index
        plus one); otherwise ``<output_dir>/project/images``. Does nothing
        when ``config.enable_copy_images`` is false.
        """
        if not self.config.enable_copy_images:
            return
        self.logger.info("开始复制图像文件")
        for grid_idx, points in grid_points.items():
            if self.config.enable_grid_division:
                output_dir = os.path.join(
                    self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
                )
            else:
                output_dir = os.path.join(self.config.output_dir, "project", "images")
            os.makedirs(output_dir, exist_ok=True)
            file_names = self._file_names(points)
            for file_name in tqdm(file_names, desc=f"复制网格 {grid_idx + 1} 的图像"):
                src = os.path.join(self.config.image_dir, file_name)
                dst = os.path.join(output_dir, file_name)
                shutil.copy(src, dst)
            self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")

    def visualize_results(self):
        """Save a scatter plot of original versus filtered GPS points.

        The original points are re-extracted from ``config.image_dir``; the
        filtered file names are read from ``<output_dir>/del_imgs.txt`` and
        the figure is written to ``<output_dir>/filter_GPS.png``. Does
        nothing when ``config.enable_visualization`` is false.
        """
        if not self.config.enable_visualization:
            return
        self.logger.info("开始生成可视化结果")
        extractor = GPSExtractor(self.config.image_dir)
        original_points_df = extractor.extract_all_gps()
        # Read the list of filtered-out images. A missing list is treated as
        # "nothing was filtered" so the run does not abort at its last step.
        # NOTE(review): which stage writes del_imgs.txt is not visible here.
        del_list_path = os.path.join(self.config.output_dir, "del_imgs.txt")
        try:
            with open(del_list_path, "r", encoding="utf-8") as file:
                filtered_files = [line.strip() for line in file if line.strip()]
        except FileNotFoundError:
            self.logger.warning(
                f"未找到过滤列表 {del_list_path}, 按没有被过滤的图像处理"
            )
            filtered_files = []
        # Create a new figure; the finally clause closes it on any failure.
        plt.figure(figsize=(20, 16))
        try:
            # Draw every original point.
            plt.scatter(
                original_points_df["lon"],
                original_points_df["lat"],
                color="blue",
                label="Original Points",
                alpha=0.6,
            )
            # Draw the filtered-out points on top.
            filtered_points_df = original_points_df[
                original_points_df["file"].isin(filtered_files)
            ]
            plt.scatter(
                filtered_points_df["lon"],
                filtered_points_df["lat"],
                color="red",
                label="Filtered Points",
                alpha=0.6,
            )
            # Figure decoration.
            plt.title("GPS Coordinates of Images", fontsize=14)
            plt.xlabel("Longitude", fontsize=12)
            plt.ylabel("Latitude", fontsize=12)
            plt.grid(True)
            plt.legend()
            # Save the figure.
            plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png"))
        finally:
            plt.close()
        self.logger.info("预处理结果图已保存")

    def process(self):
        """Run the enabled pipeline stages in order.

        Any exception is logged with its traceback and re-raised.
        """
        try:
            self.extract_gps()
            self.cluster()
            # Currently disabled stages:
            # self.time_filter()
            # self.filter_points()
            grid_points = self.divide_grids()
            self.copy_images(grid_points)
            self.visualize_results()
            # Currently disabled: running the per-grid commands.
            # self.logger.info("预处理任务完成")
            # self.command_runner.run_grid_commands(
            #     grid_points,
            #     self.config.enable_grid_division
            # )
        except Exception as e:
            self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
            raise
if __name__ == "__main__":
    # Script entry point: run the whole pipeline once with a fixed
    # configuration. The image directory is a machine-specific path.
    options = {
        "image_dir": r"E:\湖南省第二测绘院\11-06-项目移交文件(王辉给)\无人机二三维节点扩容生产影像\影像数据\199\code\images",
        "output_dir": r"test",
        # Point-filter settings.
        "filter_grid_size": 0.001,
        "filter_dense_distance_threshold": 10,
        "filter_distance_threshold": 0.001,
        "filter_min_neighbors": 6,
        # Grid-division settings.
        "grid_overlap": 0.05,
        "grid_size": 500,
        # Stage toggles.
        "enable_filter": True,
        "enable_grid_division": True,
        "enable_visualization": True,
        "enable_copy_images": True,
    }
    config = PreprocessConfig(**options)

    # Build the preprocessor and execute every stage.
    processor = ImagePreprocessor(config)
    processor.process()