115 lines
4.3 KiB
Python
115 lines
4.3 KiB
Python
|
import os
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
import pandas as pd
|
|||
|
import logging
|
|||
|
from typing import List, Optional
|
|||
|
|
|||
|
class FilterVisualizer:
|
|||
|
"""过滤结果可视化器"""
|
|||
|
|
|||
|
def __init__(self, output_dir: str):
|
|||
|
self.output_dir = output_dir
|
|||
|
self.logger = logging.getLogger('UAV_Preprocess.Visualizer')
|
|||
|
|
|||
|
def visualize_filter_step(self,
|
|||
|
current_points: pd.DataFrame,
|
|||
|
previous_points: pd.DataFrame,
|
|||
|
step_name: str,
|
|||
|
save_name: Optional[str] = None):
|
|||
|
"""
|
|||
|
可视化单个过滤步骤的结果
|
|||
|
|
|||
|
Args:
|
|||
|
current_points: 当前步骤后的点
|
|||
|
previous_points: 上一步骤的点
|
|||
|
step_name: 步骤名称
|
|||
|
save_name: 保存文件名,默认为step_name
|
|||
|
"""
|
|||
|
self.logger.info(f"开始生成{step_name}的可视化结果")
|
|||
|
|
|||
|
# 找出被过滤掉的点
|
|||
|
filtered_files = set(previous_points['file']) - set(current_points['file'])
|
|||
|
filtered_points = previous_points[previous_points['file'].isin(filtered_files)]
|
|||
|
|
|||
|
# 创建图形
|
|||
|
plt.figure(figsize=(20, 16))
|
|||
|
|
|||
|
# 绘制保留的点
|
|||
|
plt.scatter(current_points['lon'], current_points['lat'],
|
|||
|
color='blue', label='Retained Points',
|
|||
|
alpha=0.6, s=50)
|
|||
|
|
|||
|
# 绘制被过滤的点
|
|||
|
plt.scatter(filtered_points['lon'], filtered_points['lat'],
|
|||
|
color='red', marker='x', label='Filtered Points',
|
|||
|
alpha=0.6, s=100)
|
|||
|
|
|||
|
# 设置图形属性
|
|||
|
plt.title(f"GPS Points After {step_name}", fontsize=14)
|
|||
|
plt.xlabel("Longitude", fontsize=12)
|
|||
|
plt.ylabel("Latitude", fontsize=12)
|
|||
|
plt.grid(True)
|
|||
|
|
|||
|
# 添加统计信息
|
|||
|
stats_text = (f"Total Points: {len(previous_points)}\n"
|
|||
|
f"Filtered Points: {len(filtered_points)}\n"
|
|||
|
f"Remaining Points: {len(current_points)}")
|
|||
|
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
|
|||
|
bbox=dict(facecolor='white', alpha=0.8))
|
|||
|
|
|||
|
plt.legend()
|
|||
|
|
|||
|
# 保存图形
|
|||
|
save_name = save_name or step_name.lower().replace(' ', '_')
|
|||
|
save_path = os.path.join(self.output_dir, f'filter_{save_name}.png')
|
|||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
|||
|
plt.close()
|
|||
|
|
|||
|
self.logger.info(f"{step_name}过滤可视化结果已保存至 {save_path}")
|
|||
|
|
|||
|
def visualize_all_steps(self,
|
|||
|
points_history: List[pd.DataFrame],
|
|||
|
step_names: List[str]):
|
|||
|
"""
|
|||
|
可视化所有过滤步骤的结果
|
|||
|
|
|||
|
Args:
|
|||
|
points_history: 每个步骤的点数据列表
|
|||
|
step_names: 步骤名称列表
|
|||
|
"""
|
|||
|
if len(points_history) != len(step_names):
|
|||
|
raise ValueError("点数据列表和步骤名称列表长度不匹配")
|
|||
|
|
|||
|
plt.figure(figsize=(20, 16))
|
|||
|
|
|||
|
# 使用不同的颜色表示不同步骤
|
|||
|
colors = plt.cm.rainbow(np.linspace(0, 1, len(points_history)))
|
|||
|
|
|||
|
# 绘制每个步骤的点
|
|||
|
for i, (points, color) in enumerate(zip(points_history, colors)):
|
|||
|
plt.scatter(points['lon'], points['lat'],
|
|||
|
color=color, label=f'After {step_names[i]}',
|
|||
|
alpha=0.6, s=50)
|
|||
|
|
|||
|
# 设置图形属性
|
|||
|
plt.title("GPS Points Filtering Process", fontsize=14)
|
|||
|
plt.xlabel("Longitude", fontsize=12)
|
|||
|
plt.ylabel("Latitude", fontsize=12)
|
|||
|
plt.grid(True)
|
|||
|
|
|||
|
# 添加统计<E7BB9F><E8AEA1>息
|
|||
|
stats_text = "Points Count:\n" + "\n".join(
|
|||
|
f"{step_names[i]}: {len(points)}"
|
|||
|
for i, points in enumerate(points_history)
|
|||
|
)
|
|||
|
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
|
|||
|
bbox=dict(facecolor='white', alpha=0.8))
|
|||
|
|
|||
|
plt.legend()
|
|||
|
|
|||
|
# 保存图形
|
|||
|
save_path = os.path.join(self.output_dir, 'filter_all_steps.png')
|
|||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
|||
|
plt.close()
|
|||
|
|
|||
|
self.logger.info(f"所有过滤步骤的可视化结果已保存至 {save_path}")
|