每过滤一次就可视化
This commit is contained in:
parent
b1ffb863b3
commit
c730543985
@ -15,6 +15,7 @@ from filter.gps_filter import GPSFilter
|
||||
from utils.grid_divider import GridDivider
|
||||
from utils.logger import setup_logger
|
||||
from filter.time_group_overlap_filter import TimeGroupOverlapFilter
|
||||
from utils.visualizer import FilterVisualizer
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -52,9 +53,12 @@ class ImagePreprocessor:
|
||||
def __init__(self, config: PreprocessConfig):
|
||||
self.config = config
|
||||
self.logger = setup_logger(config.output_dir)
|
||||
self.gps_points = []
|
||||
self.command_runner = CommandRunner(
|
||||
config.output_dir, mode=config.mode)
|
||||
self.gps_points = None
|
||||
self.command_runner = CommandRunner(config.output_dir, mode=config.mode)
|
||||
self.visualizer = FilterVisualizer(config.output_dir)
|
||||
# 用于存储每个步骤的点数据
|
||||
self.points_history = []
|
||||
self.step_names = []
|
||||
|
||||
def extract_gps(self) -> pd.DataFrame:
|
||||
"""提取GPS数据"""
|
||||
@ -62,24 +66,32 @@ class ImagePreprocessor:
|
||||
extractor = GPSExtractor(self.config.image_dir)
|
||||
self.gps_points = extractor.extract_all_gps()
|
||||
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
|
||||
|
||||
# 记录初始状态
|
||||
self.points_history.append(self.gps_points.copy())
|
||||
self.step_names.append("Initial")
|
||||
return self.gps_points
|
||||
|
||||
def cluster(self) -> pd.DataFrame:
|
||||
"""使用DBSCAN对GPS点进行聚类,只保留最大的类"""
|
||||
"""使用DBSCAN对GPS点进行聚类"""
|
||||
self.logger.info("开始聚类")
|
||||
# 创建聚类器并执行聚类
|
||||
previous_points = self.gps_points.copy()
|
||||
|
||||
clusterer = GPSCluster(
|
||||
self.gps_points, output_dir=self.config.output_dir,
|
||||
eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples)
|
||||
# 获取主要类别的点
|
||||
self.clustered_points = clusterer.fit()
|
||||
self.gps_points = clusterer.get_main_cluster(self.clustered_points)
|
||||
# 获取统计信息并记录
|
||||
stats = clusterer.get_cluster_stats(self.clustered_points)
|
||||
self.logger.info(
|
||||
f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
|
||||
f"噪声点 {stats['noise_points']} 个"
|
||||
)
|
||||
|
||||
# 可视化聚类结果
|
||||
if self.config.enable_visualization:
|
||||
self.visualizer.visualize_filter_step(
|
||||
self.gps_points, previous_points, "Clustering")
|
||||
|
||||
# 记录这一步的结果
|
||||
self.points_history.append(self.gps_points.copy())
|
||||
self.step_names.append("Clustering")
|
||||
return self.gps_points
|
||||
|
||||
def filter_time_group_overlap(self) -> pd.DataFrame:
|
||||
"""过滤重叠的时间组"""
|
||||
@ -87,6 +99,8 @@ class ImagePreprocessor:
|
||||
return self.gps_points
|
||||
|
||||
self.logger.info("开始过滤重叠时间组")
|
||||
previous_points = self.gps_points.copy()
|
||||
|
||||
filter = TimeGroupOverlapFilter(
|
||||
self.config.image_dir,
|
||||
self.config.output_dir,
|
||||
@ -97,150 +111,31 @@ class ImagePreprocessor:
|
||||
time_threshold=self.config.time_group_interval
|
||||
)
|
||||
|
||||
# 更新GPS点数据,移除被删除的图像
|
||||
self.gps_points = self.gps_points[~self.gps_points['file'].isin(
|
||||
deleted_files)]
|
||||
self.logger.info(f"重叠时间组过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||
self.gps_points = self.gps_points[~self.gps_points['file'].isin(deleted_files)]
|
||||
|
||||
# 可视化过滤结果
|
||||
if self.config.enable_visualization:
|
||||
self.visualizer.visualize_filter_step(
|
||||
self.gps_points, previous_points, "Time Group Overlap")
|
||||
|
||||
# 记录这一步的结果
|
||||
self.points_history.append(self.gps_points.copy())
|
||||
self.step_names.append("Time Group Overlap")
|
||||
return self.gps_points
|
||||
|
||||
# TODO 过滤算法还需要更新
|
||||
def filter_points(self) -> pd.DataFrame:
|
||||
"""过滤GPS点"""
|
||||
if not self.config.enable_filter:
|
||||
return self.gps_points
|
||||
|
||||
self.logger.info("开始过滤GPS点")
|
||||
filter = GPSFilter(self.config.output_dir)
|
||||
|
||||
self.logger.info(
|
||||
f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})"
|
||||
)
|
||||
self.gps_points = filter.filter_isolated_points(
|
||||
self.gps_points,
|
||||
self.config.filter_distance_threshold,
|
||||
self.config.filter_min_neighbors,
|
||||
)
|
||||
self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||
|
||||
self.logger.info(
|
||||
f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})"
|
||||
)
|
||||
self.gps_points = filter.filter_dense_points(
|
||||
self.gps_points,
|
||||
grid_size=self.config.filter_grid_size,
|
||||
distance_threshold=self.config.filter_dense_distance_threshold,
|
||||
time_threshold=self.config.filter_time_threshold,
|
||||
)
|
||||
self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||
return self.gps_points
|
||||
|
||||
def divide_grids(self) -> Dict[int, pd.DataFrame]:
|
||||
"""划分网格"""
|
||||
if not self.config.enable_grid_division:
|
||||
return {0: self.gps_points} # 不划分网格时,所有点放在一个网格中
|
||||
|
||||
self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
|
||||
grid_divider = GridDivider(overlap=self.config.grid_overlap)
|
||||
grids = grid_divider.divide_grids(
|
||||
self.gps_points, grid_size=self.config.grid_size
|
||||
)
|
||||
grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
|
||||
self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
|
||||
return grid_points
|
||||
|
||||
def copy_images(self, grid_points: Dict[int, pd.DataFrame]):
|
||||
"""复制图像到目标文件夹"""
|
||||
if not self.config.enable_copy_images:
|
||||
return
|
||||
|
||||
self.logger.info("开始复制图像文件")
|
||||
|
||||
for grid_idx, points in grid_points.items():
|
||||
if self.config.enable_grid_division:
|
||||
output_dir = os.path.join(
|
||||
self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
|
||||
)
|
||||
else:
|
||||
output_dir = os.path.join(
|
||||
self.config.output_dir, "project", "images")
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"):
|
||||
src = os.path.join(self.config.image_dir, point["file"])
|
||||
dst = os.path.join(output_dir, point["file"])
|
||||
shutil.copy(src, dst)
|
||||
self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")
|
||||
|
||||
def visualize_results(self):
|
||||
"""可视化处理结果"""
|
||||
if not self.config.enable_visualization:
|
||||
return
|
||||
|
||||
self.logger.info("开始生成可视化结果")
|
||||
extractor = GPSExtractor(self.config.image_dir)
|
||||
original_points_df = extractor.extract_all_gps()
|
||||
|
||||
# 读取被过滤的图片列表
|
||||
with open(
|
||||
os.path.join(self.config.output_dir, "del_imgs.txt"), "r", encoding="utf-8"
|
||||
) as file:
|
||||
filtered_files = [line.strip() for line in file if line.strip()]
|
||||
|
||||
# 创建一个新的图形
|
||||
plt.figure(figsize=(20, 16))
|
||||
|
||||
# 绘制所有原始点
|
||||
plt.scatter(
|
||||
original_points_df["lon"],
|
||||
original_points_df["lat"],
|
||||
color="blue",
|
||||
label="Original Points",
|
||||
alpha=0.6,
|
||||
)
|
||||
|
||||
# 绘制被过滤的点
|
||||
filtered_points_df = original_points_df[
|
||||
original_points_df["file"].isin(filtered_files)
|
||||
]
|
||||
plt.scatter(
|
||||
filtered_points_df["lon"],
|
||||
filtered_points_df["lat"],
|
||||
color="red",
|
||||
marker="x",
|
||||
label="Filtered Points",
|
||||
alpha=0.6,
|
||||
)
|
||||
|
||||
# 设置图形属性
|
||||
plt.title("GPS Coordinates of Images", fontsize=14)
|
||||
plt.xlabel("Longitude", fontsize=12)
|
||||
plt.ylabel("Latitude", fontsize=12)
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
|
||||
# 保存图形
|
||||
plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png"))
|
||||
plt.close()
|
||||
self.logger.info("预处理结果图已保存")
|
||||
|
||||
def process(self):
|
||||
"""执行完整的预处理流程"""
|
||||
try:
|
||||
self.extract_gps()
|
||||
self.cluster()
|
||||
self.filter_time_group_overlap()
|
||||
# self.filter_points()
|
||||
# grid_points = self.divide_grids()
|
||||
# self.copy_images(grid_points)
|
||||
# self.visualize_results()
|
||||
# self.logger.info("预处理任务完成")
|
||||
# self.command_runner.run_grid_commands(
|
||||
# grid_points,
|
||||
# self.config.enable_grid_division,
|
||||
# )
|
||||
# TODO 拼图
|
||||
|
||||
# 在处理结束时生成所有步骤的可视化
|
||||
if self.config.enable_visualization:
|
||||
self.visualizer.visualize_all_steps(
|
||||
self.points_history, self.step_names)
|
||||
|
||||
self.logger.info("预处理任务完成")
|
||||
except Exception as e:
|
||||
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
115
utils/visualizer.py
Normal file
115
utils/visualizer.py
Normal file
@ -0,0 +1,115 @@
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
class FilterVisualizer:
|
||||
"""过滤结果可视化器"""
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
self.logger = logging.getLogger('UAV_Preprocess.Visualizer')
|
||||
|
||||
def visualize_filter_step(self,
|
||||
current_points: pd.DataFrame,
|
||||
previous_points: pd.DataFrame,
|
||||
step_name: str,
|
||||
save_name: Optional[str] = None):
|
||||
"""
|
||||
可视化单个过滤步骤的结果
|
||||
|
||||
Args:
|
||||
current_points: 当前步骤后的点
|
||||
previous_points: 上一步骤的点
|
||||
step_name: 步骤名称
|
||||
save_name: 保存文件名,默认为step_name
|
||||
"""
|
||||
self.logger.info(f"开始生成{step_name}的可视化结果")
|
||||
|
||||
# 找出被过滤掉的点
|
||||
filtered_files = set(previous_points['file']) - set(current_points['file'])
|
||||
filtered_points = previous_points[previous_points['file'].isin(filtered_files)]
|
||||
|
||||
# 创建图形
|
||||
plt.figure(figsize=(20, 16))
|
||||
|
||||
# 绘制保留的点
|
||||
plt.scatter(current_points['lon'], current_points['lat'],
|
||||
color='blue', label='Retained Points',
|
||||
alpha=0.6, s=50)
|
||||
|
||||
# 绘制被过滤的点
|
||||
plt.scatter(filtered_points['lon'], filtered_points['lat'],
|
||||
color='red', marker='x', label='Filtered Points',
|
||||
alpha=0.6, s=100)
|
||||
|
||||
# 设置图形属性
|
||||
plt.title(f"GPS Points After {step_name}", fontsize=14)
|
||||
plt.xlabel("Longitude", fontsize=12)
|
||||
plt.ylabel("Latitude", fontsize=12)
|
||||
plt.grid(True)
|
||||
|
||||
# 添加统计信息
|
||||
stats_text = (f"Total Points: {len(previous_points)}\n"
|
||||
f"Filtered Points: {len(filtered_points)}\n"
|
||||
f"Remaining Points: {len(current_points)}")
|
||||
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
|
||||
bbox=dict(facecolor='white', alpha=0.8))
|
||||
|
||||
plt.legend()
|
||||
|
||||
# 保存图形
|
||||
save_name = save_name or step_name.lower().replace(' ', '_')
|
||||
save_path = os.path.join(self.output_dir, f'filter_{save_name}.png')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
self.logger.info(f"{step_name}过滤可视化结果已保存至 {save_path}")
|
||||
|
||||
def visualize_all_steps(self,
|
||||
points_history: List[pd.DataFrame],
|
||||
step_names: List[str]):
|
||||
"""
|
||||
可视化所有过滤步骤的结果
|
||||
|
||||
Args:
|
||||
points_history: 每个步骤的点数据列表
|
||||
step_names: 步骤名称列表
|
||||
"""
|
||||
if len(points_history) != len(step_names):
|
||||
raise ValueError("点数据列表和步骤名称列表长度不匹配")
|
||||
|
||||
plt.figure(figsize=(20, 16))
|
||||
|
||||
# 使用不同的颜色表示不同步骤
|
||||
colors = plt.cm.rainbow(np.linspace(0, 1, len(points_history)))
|
||||
|
||||
# 绘制每个步骤的点
|
||||
for i, (points, color) in enumerate(zip(points_history, colors)):
|
||||
plt.scatter(points['lon'], points['lat'],
|
||||
color=color, label=f'After {step_names[i]}',
|
||||
alpha=0.6, s=50)
|
||||
|
||||
# 设置图形属性
|
||||
plt.title("GPS Points Filtering Process", fontsize=14)
|
||||
plt.xlabel("Longitude", fontsize=12)
|
||||
plt.ylabel("Latitude", fontsize=12)
|
||||
plt.grid(True)
|
||||
|
||||
# 添加统计<E7BB9F><E8AEA1>息
|
||||
stats_text = "Points Count:\n" + "\n".join(
|
||||
f"{step_names[i]}: {len(points)}"
|
||||
for i, points in enumerate(points_history)
|
||||
)
|
||||
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
|
||||
bbox=dict(facecolor='white', alpha=0.8))
|
||||
|
||||
plt.legend()
|
||||
|
||||
# 保存图形
|
||||
save_path = os.path.join(self.output_dir, 'filter_all_steps.png')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
self.logger.info(f"所有过滤步骤的可视化结果已保存至 {save_path}")
|
Loading…
Reference in New Issue
Block a user