# Source: ODM_pro/odm_preprocess.py
# Snapshot: 2024-12-19 12:09:48 +08:00
#
# Listing metadata: 204 lines, 7.5 KiB, Python

from preprocess.gps_extractor import GPSExtractor
from preprocess.time_filter import TimeFilter
from preprocess.gps_filter import GPSFilter
from preprocess.grid_divider import GridDivider
from preprocess.logger import setup_logger
from preprocess.command_runner import CommandRunner
import os
import pandas as pd
import shutil
import matplotlib.pyplot as plt
from typing import List, Dict, Optional
from dataclasses import dataclass
from tqdm import tqdm
import subprocess
from concurrent.futures import ThreadPoolExecutor
@dataclass
class PreprocessConfig:
    """Settings consumed by ImagePreprocessor.

    Only the two directories are required; every other field has a default.
    The units of the numeric thresholds are defined by the GPSFilter and
    GridDivider helpers that receive them, not by this class.
    """

    # --- locations -------------------------------------------------------
    # Directory the source images are read from (GPS extraction, copying).
    image_dir: str
    # Directory that receives logs, plots and the per-grid project folders.
    output_dir: str

    # --- GPS point filtering ---------------------------------------------
    # Passed as ``grid_size`` to GPSFilter.filter_dense_points.
    filter_grid_size: float = 0.001
    # Passed as ``distance_threshold`` to GPSFilter.filter_dense_points.
    filter_dense_distance_threshold: float = 10
    # Distance threshold handed to GPSFilter.filter_isolated_points.
    filter_distance_threshold: float = 0.001
    # Minimum neighbour count handed to GPSFilter.filter_isolated_points.
    filter_min_neighbors: int = 6

    # --- grid division ---------------------------------------------------
    # Passed as ``overlap`` to the GridDivider constructor.
    grid_overlap: float = 0.05
    # Passed as ``grid_size`` to GridDivider.divide_grids.
    grid_size: float = 250

    # --- pipeline stage switches -----------------------------------------
    # When False, filter_points returns the points untouched.
    enable_filter: bool = True
    # When False, all points are placed in a single grid with index 0.
    enable_grid_division: bool = True
    # When False, no summary plot is produced.
    enable_visualization: bool = True
    # When False, no image files are copied.
    enable_copy_images: bool = True
class ImagePreprocessor:
    """Run the image preprocessing pipeline described by a PreprocessConfig.

    The stages, in the order ``process`` runs them, are: GPS extraction,
    time filtering, spatial filtering, grid division, image copying,
    visualization, and finally handing the grids to the CommandRunner.
    Each stage stores its result in ``self.gps_points`` so the next stage
    works on the already-reduced set of points.
    """

    def __init__(self, config: "PreprocessConfig"):
        """Store *config* and create the logger and command runner.

        Both the logger and the command runner are set up with
        ``config.output_dir``.
        """
        self.config = config
        self.logger = setup_logger(config.output_dir)
        # Replaced by whatever GPSExtractor.extract_all_gps returns once
        # extract_gps has run.
        self.gps_points = []
        self.command_runner = CommandRunner(config.output_dir)

    def extract_gps(self) -> pd.DataFrame:
        """Extract GPS data for every image in ``config.image_dir``.

        Returns:
            The extracted points, also stored in ``self.gps_points``.
        """
        self.logger.info("开始提取GPS数据")
        extractor = GPSExtractor(self.config.image_dir)
        self.gps_points = extractor.extract_all_gps()
        self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def time_filter(self) -> pd.DataFrame:
        """Filter the current points with TimeFilter.filter_by_date.

        Returns:
            The remaining points, also stored in ``self.gps_points``.
        """
        self.logger.info("开始时间过滤")
        # Named date_filter so the local does not shadow this method's name.
        date_filter = TimeFilter(self.config.output_dir)
        self.gps_points = date_filter.filter_by_date(self.gps_points)
        self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    # TODO: the dense-point filtering algorithm needs improvement.
    def filter_points(self) -> pd.DataFrame:
        """Remove isolated points, then thin out dense points.

        Does nothing when ``config.enable_filter`` is False.

        Returns:
            The remaining points, also stored in ``self.gps_points``.
        """
        if not self.config.enable_filter:
            return self.gps_points

        cfg = self.config
        self.logger.info("开始过滤GPS点")
        # Named gps_filter so the builtin ``filter`` is not shadowed.
        gps_filter = GPSFilter(cfg.output_dir)

        self.logger.info(
            f"开始过滤孤立点(距离阈值: {cfg.filter_distance_threshold}, 最小邻居数: {cfg.filter_min_neighbors})")
        self.gps_points = gps_filter.filter_isolated_points(
            self.gps_points,
            cfg.filter_distance_threshold,
            cfg.filter_min_neighbors,
        )
        self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")

        self.logger.info(
            f"开始过滤密集点(网格大小: {cfg.filter_grid_size}, 距离阈值: {cfg.filter_dense_distance_threshold})")
        self.gps_points = gps_filter.filter_dense_points(
            self.gps_points,
            grid_size=cfg.filter_grid_size,
            distance_threshold=cfg.filter_dense_distance_threshold,
        )
        self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
        return self.gps_points

    def divide_grids(self) -> Dict[int, pd.DataFrame]:
        """Split the current points into grids.

        Returns:
            A mapping from grid index to the points of that grid. When
            ``config.enable_grid_division`` is False the mapping has the
            single key 0 holding all current points.
        """
        if not self.config.enable_grid_division:
            # Without grid division every point belongs to one grid.
            return {0: self.gps_points}

        self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
        grid_divider = GridDivider(overlap=self.config.grid_overlap)
        grids = grid_divider.divide_grids(
            self.gps_points, grid_size=self.config.grid_size)
        grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
        self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
        return grid_points

    @staticmethod
    def _file_names(points) -> List[str]:
        """Return the image file names referenced by *points*.

        *points* may be a DataFrame with a ``'file'`` column (this is what
        divide_grids returns when grid division is disabled) or any iterable
        of row-like mappings that each have a ``'file'`` key.
        """
        if isinstance(points, pd.DataFrame):
            # Iterating a DataFrame yields its column labels, not its rows,
            # so the column has to be read explicitly.
            return points['file'].tolist()
        return [point['file'] for point in points]

    def copy_images(self, grid_points: Dict[int, pd.DataFrame]):
        """Copy each grid's images into its project folder.

        With grid division enabled the target is
        ``<output_dir>/grid_<n>/project/images`` (``n`` is the grid index
        plus one); otherwise it is ``<output_dir>/project/images``.
        Does nothing when ``config.enable_copy_images`` is False.

        Args:
            grid_points: Mapping from grid index to that grid's points.
        """
        if not self.config.enable_copy_images:
            return

        self.logger.info("开始复制图像文件")
        for grid_idx, points in grid_points.items():
            if self.config.enable_grid_division:
                output_dir = os.path.join(
                    self.config.output_dir, f'grid_{grid_idx + 1}', 'project', 'images')
            else:
                output_dir = os.path.join(
                    self.config.output_dir, 'project', 'images')
            os.makedirs(output_dir, exist_ok=True)

            file_names = self._file_names(points)
            for file_name in tqdm(file_names, desc=f"复制网格 {grid_idx + 1} 的图像"):
                src = os.path.join(self.config.image_dir, file_name)
                dst = os.path.join(output_dir, file_name)
                shutil.copy(src, dst)
            self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")

    def visualize_results(self):
        """Plot all extracted GPS points and highlight the filtered ones.

        The GPS data is extracted again from ``config.image_dir`` to obtain
        the unfiltered set. File names listed in ``<output_dir>/del_imgs.txt``
        (one per line) are drawn in red on top of the original points. The
        plot is saved as ``<output_dir>/filter_GPS.png``.
        Does nothing when ``config.enable_visualization`` is False.
        """
        if not self.config.enable_visualization:
            return

        self.logger.info("开始生成可视化结果")
        extractor = GPSExtractor(self.config.image_dir)
        original_points_df = extractor.extract_all_gps()

        # Read the list of images that were filtered out. A missing list is
        # treated as "nothing was filtered" instead of aborting the pipeline.
        del_list_path = os.path.join(self.config.output_dir, 'del_imgs.txt')
        try:
            with open(del_list_path, "r", encoding="utf-8") as file:
                filtered_files = [line.strip() for line in file if line.strip()]
        except FileNotFoundError:
            self.logger.warning(f"未找到 {del_list_path},按没有被过滤的图片处理")
            filtered_files = []

        plt.figure(figsize=(20, 16))
        try:
            # All original points.
            plt.scatter(original_points_df['lon'],
                        original_points_df['lat'],
                        color='blue',
                        label="Original Points",
                        alpha=0.6)

            # Points whose image was filtered out.
            filtered_points_df = original_points_df[
                original_points_df['file'].isin(filtered_files)]
            plt.scatter(filtered_points_df['lon'],
                        filtered_points_df['lat'],
                        color="red",
                        label="Filtered Points",
                        alpha=0.6)

            plt.title("GPS Coordinates of Images", fontsize=14)
            plt.xlabel("Longitude", fontsize=12)
            plt.ylabel("Latitude", fontsize=12)
            plt.grid(True)
            plt.legend()

            plt.savefig(os.path.join(self.config.output_dir, 'filter_GPS.png'))
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close()
        self.logger.info("预处理结果图已保存")

    def process(self):
        """Run the complete preprocessing pipeline.

        Raises:
            Exception: Any error raised by a stage is logged with its
                traceback and then re-raised unchanged.
        """
        try:
            self.extract_gps()
            self.time_filter()
            self.filter_points()
            grid_points = self.divide_grids()
            self.copy_images(grid_points)
            self.visualize_results()
            self.logger.info("预处理任务完成")
            self.command_runner.run_grid_commands(
                grid_points,
                self.config.enable_grid_division
            )
        except Exception as e:
            self.logger.error(f"处理过程中发生错误: {e}", exc_info=True)
            raise
if __name__ == '__main__':
    # Example run against a local dataset. Any PreprocessConfig field that
    # is not listed here (e.g. grid_size) keeps its default.
    settings = dict(
        image_dir=r'E:\datasets\UAV\1815\images',
        output_dir=r'E:\datasets\UAV\1815\output',
        filter_grid_size=0.001,
        filter_dense_distance_threshold=10,
        filter_distance_threshold=0.001,
        filter_min_neighbors=6,
        grid_overlap=0.05,
        enable_filter=True,
        enable_grid_division=True,
        enable_visualization=True,
        enable_copy_images=True,
    )
    # Build the preprocessor from the configuration and run the pipeline.
    ImagePreprocessor(PreprocessConfig(**settings)).process()