ODM_pro/utils/visualizer.py
2024-12-22 20:19:12 +08:00

115 lines
4.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import os
from typing import List, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
class FilterVisualizer:
    """Render scatter plots of GPS points before/after filtering steps.

    Every plot is written as a PNG file inside ``output_dir``. The input
    DataFrames are read through their ``'lon'`` and ``'lat'`` columns;
    ``visualize_filter_step`` additionally reads the ``'file'`` column to
    tell which rows were removed.
    """

    def __init__(self, output_dir: str):
        """Remember the output directory and set up the logger.

        Args:
            output_dir: Directory the PNG files are saved into. It is created
                on demand the first time a figure is saved.
        """
        self.output_dir = output_dir
        self.logger = logging.getLogger('UAV_Preprocess.Visualizer')

    def _save_current_figure(self, filename: str) -> str:
        """Save the current pyplot figure as ``filename`` and return its path."""
        # An empty output_dir means "current directory"; os.makedirs('')
        # would raise, so only create the directory when one is given.
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        save_path = os.path.join(self.output_dir, filename)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        return save_path

    def visualize_filter_step(self,
                              current_points: pd.DataFrame,
                              previous_points: pd.DataFrame,
                              step_name: str,
                              save_name: Optional[str] = None) -> None:
        """Plot the points kept and the points removed by one filter step.

        Retained points are drawn as blue dots, removed points as red
        crosses. The figure is saved as ``filter_<save_name>.png``.

        Args:
            current_points: Points remaining after the step.
            previous_points: Points that existed before the step.
            step_name: Human-readable step name, used in the title and logs.
            save_name: File name stem; defaults to ``step_name`` lower-cased
                with spaces replaced by underscores.
        """
        self.logger.info(f"开始生成{step_name}的可视化结果")

        # A point counts as filtered when its 'file' value is present before
        # the step but absent afterwards.
        was_removed = ~previous_points['file'].isin(current_points['file'])
        filtered_points = previous_points[was_removed]

        fig = plt.figure(figsize=(20, 16))
        try:
            # Retained points.
            plt.scatter(current_points['lon'], current_points['lat'],
                        color='blue', label='Retained Points',
                        alpha=0.6, s=50)
            # Removed points.
            plt.scatter(filtered_points['lon'], filtered_points['lat'],
                        color='red', marker='x', label='Filtered Points',
                        alpha=0.6, s=100)

            plt.title(f"GPS Points After {step_name}", fontsize=14)
            plt.xlabel("Longitude", fontsize=12)
            plt.ylabel("Latitude", fontsize=12)
            plt.grid(True)

            # Summary box in the lower-left corner of the figure.
            stats_text = (f"Total Points: {len(previous_points)}\n"
                          f"Filtered Points: {len(filtered_points)}\n"
                          f"Remaining Points: {len(current_points)}")
            plt.figtext(0.02, 0.02, stats_text, fontsize=10,
                        bbox=dict(facecolor='white', alpha=0.8))
            plt.legend()

            save_name = save_name or step_name.lower().replace(' ', '_')
            save_path = self._save_current_figure(f'filter_{save_name}.png')
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

        self.logger.info(f"{step_name}过滤可视化结果已保存至 {save_path}")

    def visualize_all_steps(self,
                            points_history: List[pd.DataFrame],
                            step_names: List[str]) -> None:
        """Overlay the points remaining after every filter step in one plot.

        Each step gets its own colour from the rainbow colormap. The figure
        is saved as ``filter_all_steps.png``.

        Args:
            points_history: Point data after each step, in step order.
            step_names: Step names, parallel to ``points_history``.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(points_history) != len(step_names):
            raise ValueError("点数据列表和步骤名称列表长度不匹配")

        fig = plt.figure(figsize=(20, 16))
        try:
            # One evenly spaced colormap sample per step.
            colors = plt.cm.rainbow(np.linspace(0, 1, len(points_history)))

            for points, color, name in zip(points_history, colors, step_names):
                plt.scatter(points['lon'], points['lat'],
                            color=color, label=f'After {name}',
                            alpha=0.6, s=50)

            plt.title("GPS Points Filtering Process", fontsize=14)
            plt.xlabel("Longitude", fontsize=12)
            plt.ylabel("Latitude", fontsize=12)
            plt.grid(True)

            # Summary box listing the point count after each step.
            stats_text = "Points Count:\n" + "\n".join(
                f"{name}: {len(points)}"
                for name, points in zip(step_names, points_history)
            )
            plt.figtext(0.02, 0.02, stats_text, fontsize=10,
                        bbox=dict(facecolor='white', alpha=0.8))
            plt.legend()

            save_path = self._save_current_figure('filter_all_steps.png')
        finally:
            # Always release the figure, even when plotting or saving fails.
            plt.close(fig)

        self.logger.info(f"所有过滤步骤的可视化结果已保存至 {save_path}")