每过滤一次就可视化

This commit is contained in:
龙澳 2024-12-22 20:19:12 +08:00
parent b1ffb863b3
commit c730543985
2 changed files with 158 additions and 148 deletions

View File

@ -15,6 +15,7 @@ from filter.gps_filter import GPSFilter
from utils.grid_divider import GridDivider from utils.grid_divider import GridDivider
from utils.logger import setup_logger from utils.logger import setup_logger
from filter.time_group_overlap_filter import TimeGroupOverlapFilter from filter.time_group_overlap_filter import TimeGroupOverlapFilter
from utils.visualizer import FilterVisualizer
@dataclass @dataclass
@ -52,9 +53,12 @@ class ImagePreprocessor:
def __init__(self, config: PreprocessConfig): def __init__(self, config: PreprocessConfig):
self.config = config self.config = config
self.logger = setup_logger(config.output_dir) self.logger = setup_logger(config.output_dir)
self.gps_points = [] self.gps_points = None
self.command_runner = CommandRunner( self.command_runner = CommandRunner(config.output_dir, mode=config.mode)
config.output_dir, mode=config.mode) self.visualizer = FilterVisualizer(config.output_dir)
# 用于存储每个步骤的点数据
self.points_history = []
self.step_names = []
def extract_gps(self) -> pd.DataFrame: def extract_gps(self) -> pd.DataFrame:
"""提取GPS数据""" """提取GPS数据"""
@ -62,24 +66,32 @@ class ImagePreprocessor:
extractor = GPSExtractor(self.config.image_dir) extractor = GPSExtractor(self.config.image_dir)
self.gps_points = extractor.extract_all_gps() self.gps_points = extractor.extract_all_gps()
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点") self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
# 记录初始状态
self.points_history.append(self.gps_points.copy())
self.step_names.append("Initial")
return self.gps_points return self.gps_points
def cluster(self) -> pd.DataFrame: def cluster(self) -> pd.DataFrame:
"""使用DBSCAN对GPS点进行聚类只保留最大的类""" """使用DBSCAN对GPS点进行聚类"""
self.logger.info("开始聚类") self.logger.info("开始聚类")
# 创建聚类器并执行聚类 previous_points = self.gps_points.copy()
clusterer = GPSCluster( clusterer = GPSCluster(
self.gps_points, output_dir=self.config.output_dir, self.gps_points, output_dir=self.config.output_dir,
eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples) eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples)
# 获取主要类别的点
self.clustered_points = clusterer.fit() self.clustered_points = clusterer.fit()
self.gps_points = clusterer.get_main_cluster(self.clustered_points) self.gps_points = clusterer.get_main_cluster(self.clustered_points)
# 获取统计信息并记录
stats = clusterer.get_cluster_stats(self.clustered_points) # 可视化聚类结果
self.logger.info( if self.config.enable_visualization:
f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点," self.visualizer.visualize_filter_step(
f"噪声点 {stats['noise_points']}" self.gps_points, previous_points, "Clustering")
)
# 记录这一步的结果
self.points_history.append(self.gps_points.copy())
self.step_names.append("Clustering")
return self.gps_points
def filter_time_group_overlap(self) -> pd.DataFrame: def filter_time_group_overlap(self) -> pd.DataFrame:
"""过滤重叠的时间组""" """过滤重叠的时间组"""
@ -87,6 +99,8 @@ class ImagePreprocessor:
return self.gps_points return self.gps_points
self.logger.info("开始过滤重叠时间组") self.logger.info("开始过滤重叠时间组")
previous_points = self.gps_points.copy()
filter = TimeGroupOverlapFilter( filter = TimeGroupOverlapFilter(
self.config.image_dir, self.config.image_dir,
self.config.output_dir, self.config.output_dir,
@ -97,150 +111,31 @@ class ImagePreprocessor:
time_threshold=self.config.time_group_interval time_threshold=self.config.time_group_interval
) )
# 更新GPS点数据移除被删除的图像 self.gps_points = self.gps_points[~self.gps_points['file'].isin(deleted_files)]
self.gps_points = self.gps_points[~self.gps_points['file'].isin(
deleted_files)]
self.logger.info(f"重叠时间组过滤后剩余 {len(self.gps_points)} 个GPS点")
# 可视化过滤结果
if self.config.enable_visualization:
self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "Time Group Overlap")
# 记录这一步的结果
self.points_history.append(self.gps_points.copy())
self.step_names.append("Time Group Overlap")
return self.gps_points return self.gps_points
# TODO 过滤算法还需要更新
def filter_points(self) -> pd.DataFrame:
"""过滤GPS点"""
if not self.config.enable_filter:
return self.gps_points
self.logger.info("开始过滤GPS点")
filter = GPSFilter(self.config.output_dir)
self.logger.info(
f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})"
)
self.gps_points = filter.filter_isolated_points(
self.gps_points,
self.config.filter_distance_threshold,
self.config.filter_min_neighbors,
)
self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")
self.logger.info(
f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})"
)
self.gps_points = filter.filter_dense_points(
self.gps_points,
grid_size=self.config.filter_grid_size,
distance_threshold=self.config.filter_dense_distance_threshold,
time_threshold=self.config.filter_time_threshold,
)
self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
return self.gps_points
def divide_grids(self) -> Dict[int, pd.DataFrame]:
"""划分网格"""
if not self.config.enable_grid_division:
return {0: self.gps_points} # 不划分网格时,所有点放在一个网格中
self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
grid_divider = GridDivider(overlap=self.config.grid_overlap)
grids = grid_divider.divide_grids(
self.gps_points, grid_size=self.config.grid_size
)
grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
return grid_points
def copy_images(self, grid_points: Dict[int, pd.DataFrame]):
"""复制图像到目标文件夹"""
if not self.config.enable_copy_images:
return
self.logger.info("开始复制图像文件")
for grid_idx, points in grid_points.items():
if self.config.enable_grid_division:
output_dir = os.path.join(
self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
)
else:
output_dir = os.path.join(
self.config.output_dir, "project", "images")
os.makedirs(output_dir, exist_ok=True)
for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"):
src = os.path.join(self.config.image_dir, point["file"])
dst = os.path.join(output_dir, point["file"])
shutil.copy(src, dst)
self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")
def visualize_results(self):
"""可视化处理结果"""
if not self.config.enable_visualization:
return
self.logger.info("开始生成可视化结果")
extractor = GPSExtractor(self.config.image_dir)
original_points_df = extractor.extract_all_gps()
# 读取被过滤的图片列表
with open(
os.path.join(self.config.output_dir, "del_imgs.txt"), "r", encoding="utf-8"
) as file:
filtered_files = [line.strip() for line in file if line.strip()]
# 创建一个新的图形
plt.figure(figsize=(20, 16))
# 绘制所有原始点
plt.scatter(
original_points_df["lon"],
original_points_df["lat"],
color="blue",
label="Original Points",
alpha=0.6,
)
# 绘制被过滤的点
filtered_points_df = original_points_df[
original_points_df["file"].isin(filtered_files)
]
plt.scatter(
filtered_points_df["lon"],
filtered_points_df["lat"],
color="red",
marker="x",
label="Filtered Points",
alpha=0.6,
)
# 设置图形属性
plt.title("GPS Coordinates of Images", fontsize=14)
plt.xlabel("Longitude", fontsize=12)
plt.ylabel("Latitude", fontsize=12)
plt.grid(True)
plt.legend()
# 保存图形
plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png"))
plt.close()
self.logger.info("预处理结果图已保存")
def process(self): def process(self):
"""执行完整的预处理流程""" """执行完整的预处理流程"""
try: try:
self.extract_gps() self.extract_gps()
self.cluster() self.cluster()
self.filter_time_group_overlap() self.filter_time_group_overlap()
# self.filter_points()
# grid_points = self.divide_grids() # 在处理结束时生成所有步骤的可视化
# self.copy_images(grid_points) if self.config.enable_visualization:
# self.visualize_results() self.visualizer.visualize_all_steps(
# self.logger.info("预处理任务完成") self.points_history, self.step_names)
# self.command_runner.run_grid_commands(
# grid_points, self.logger.info("预处理任务完成")
# self.config.enable_grid_division,
# )
# TODO 拼图
except Exception as e: except Exception as e:
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
raise raise

115
utils/visualizer.py Normal file
View File

@ -0,0 +1,115 @@
import os
import matplotlib.pyplot as plt
import pandas as pd
import logging
from typing import List, Optional
class FilterVisualizer:
"""过滤结果可视化器"""
def __init__(self, output_dir: str):
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.Visualizer')
def visualize_filter_step(self,
current_points: pd.DataFrame,
previous_points: pd.DataFrame,
step_name: str,
save_name: Optional[str] = None):
"""
可视化单个过滤步骤的结果
Args:
current_points: 当前步骤后的点
previous_points: 上一步骤的点
step_name: 步骤名称
save_name: 保存文件名默认为step_name
"""
self.logger.info(f"开始生成{step_name}的可视化结果")
# 找出被过滤掉的点
filtered_files = set(previous_points['file']) - set(current_points['file'])
filtered_points = previous_points[previous_points['file'].isin(filtered_files)]
# 创建图形
plt.figure(figsize=(20, 16))
# 绘制保留的点
plt.scatter(current_points['lon'], current_points['lat'],
color='blue', label='Retained Points',
alpha=0.6, s=50)
# 绘制被过滤的点
plt.scatter(filtered_points['lon'], filtered_points['lat'],
color='red', marker='x', label='Filtered Points',
alpha=0.6, s=100)
# 设置图形属性
plt.title(f"GPS Points After {step_name}", fontsize=14)
plt.xlabel("Longitude", fontsize=12)
plt.ylabel("Latitude", fontsize=12)
plt.grid(True)
# 添加统计信息
stats_text = (f"Total Points: {len(previous_points)}\n"
f"Filtered Points: {len(filtered_points)}\n"
f"Remaining Points: {len(current_points)}")
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
bbox=dict(facecolor='white', alpha=0.8))
plt.legend()
# 保存图形
save_name = save_name or step_name.lower().replace(' ', '_')
save_path = os.path.join(self.output_dir, f'filter_{save_name}.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
self.logger.info(f"{step_name}过滤可视化结果已保存至 {save_path}")
def visualize_all_steps(self,
points_history: List[pd.DataFrame],
step_names: List[str]):
"""
可视化所有过滤步骤的结果
Args:
points_history: 每个步骤的点数据列表
step_names: 步骤名称列表
"""
if len(points_history) != len(step_names):
raise ValueError("点数据列表和步骤名称列表长度不匹配")
plt.figure(figsize=(20, 16))
# 使用不同的颜色表示不同步骤
colors = plt.cm.rainbow(np.linspace(0, 1, len(points_history)))
# 绘制每个步骤的点
for i, (points, color) in enumerate(zip(points_history, colors)):
plt.scatter(points['lon'], points['lat'],
color=color, label=f'After {step_names[i]}',
alpha=0.6, s=50)
# 设置图形属性
plt.title("GPS Points Filtering Process", fontsize=14)
plt.xlabel("Longitude", fontsize=12)
plt.ylabel("Latitude", fontsize=12)
plt.grid(True)
# 添加统计<E7BB9F><E8AEA1>
stats_text = "Points Count:\n" + "\n".join(
f"{step_names[i]}: {len(points)}"
for i, points in enumerate(points_history)
)
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
bbox=dict(facecolor='white', alpha=0.8))
plt.legend()
# 保存图形
save_path = os.path.join(self.output_dir, 'filter_all_steps.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
self.logger.info(f"所有过滤步骤的可视化结果已保存至 {save_path}")