UAV/utils/visualizer.py
2024-12-23 11:31:20 +08:00

121 lines
4.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import matplotlib.pyplot as plt
import pandas as pd
import logging
from typing import Optional
class FilterVisualizer:
    """Plot which GPS points a filtering step kept and which it removed."""

    def __init__(self, output_dir: str):
        """
        Initialise the visualizer.

        Args:
            output_dir: Root output directory. Figures are written to its
                ``filter_imgs`` sub-directory.
        """
        self.output_dir = output_dir
        self.logger = logging.getLogger('UAV_Preprocess.Visualizer')

    @staticmethod
    def _filter_rate(filtered_count: int, original_count: int) -> float:
        """Return the percentage of points removed, or 0.0 if there were none to begin with."""
        if original_count == 0:
            # Nothing went in, so nothing can have been filtered; avoid ZeroDivisionError.
            return 0.0
        return filtered_count / original_count * 100

    @staticmethod
    def _find_filtered_points(current_points: pd.DataFrame,
                              previous_points: pd.DataFrame) -> pd.DataFrame:
        """Return the rows of ``previous_points`` whose 'file' value is absent from ``current_points``."""
        filtered_files = set(previous_points['file']) - set(current_points['file'])
        return previous_points[previous_points['file'].isin(filtered_files)]

    def visualize_filter_step(self,
                              current_points: pd.DataFrame,
                              previous_points: pd.DataFrame,
                              step_name: str,
                              save_name: Optional[str] = None) -> None:
        """
        Draw and save a scatter plot summarising one filtering step.

        Retained points are drawn in blue and removed points as red crosses.
        The figure is saved as ``<output_dir>/filter_imgs/filter_<save_name>.png``;
        the ``filter_imgs`` directory is created if it does not exist.

        Args:
            current_points: Points remaining after the step. Must provide the
                'file', 'lon' and 'lat' columns.
            previous_points: Points before the step, with the same columns.
            step_name: Human-readable step name, used in the title and log messages.
            save_name: File name stem for the image. Defaults to ``step_name``
                lower-cased with spaces replaced by underscores.
        """
        self.logger.info(f"开始生成{step_name}的可视化结果")

        # Points present before the step but missing afterwards were filtered out.
        filtered_points = self._find_filtered_points(current_points, previous_points)
        filter_rate = self._filter_rate(len(filtered_points), len(previous_points))

        # Resolve the target path first and make sure its directory exists:
        # savefig does not create missing parent directories.
        save_name = save_name or step_name.lower().replace(' ', '_')
        image_dir = os.path.join(self.output_dir, 'filter_imgs')
        os.makedirs(image_dir, exist_ok=True)
        save_path = os.path.join(image_dir, f'filter_{save_name}.png')

        fig = plt.figure(figsize=(20, 16))
        try:
            # Retained points.
            plt.scatter(current_points['lon'], current_points['lat'],
                        color='blue', label='Retained Points',
                        alpha=0.6, s=50)

            # Removed points (skipped when the step removed nothing).
            if not filtered_points.empty:
                plt.scatter(filtered_points['lon'], filtered_points['lat'],
                            color='red', marker='x', label='Filtered Points',
                            alpha=0.6, s=100)

            plt.title(f"GPS Points After {step_name}\n"
                      f"(Filtered: {len(filtered_points)}, Retained: {len(current_points)})",
                      fontsize=14)
            plt.xlabel("Longitude", fontsize=12)
            plt.ylabel("Latitude", fontsize=12)
            plt.grid(True)

            # Summary box in the bottom-left corner of the figure.
            stats_text = (
                f"Original Points: {len(previous_points)}\n"
                f"Filtered Points: {len(filtered_points)}\n"
                f"Remaining Points: {len(current_points)}\n"
                f"Filter Rate: {filter_rate:.1f}%"
            )
            plt.figtext(0.02, 0.02, stats_text, fontsize=10,
                        bbox=dict(facecolor='white', alpha=0.8))

            plt.legend(loc='upper right', fontsize=10)
            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting or saving failed,
            # so repeated calls do not accumulate open figures.
            plt.close(fig)

        self.logger.info(
            f"{step_name}过滤可视化结果已保存至 {save_path}\n"
            f"过滤掉 {len(filtered_points)} 个点,"
            f"保留 {len(current_points)} 个点,"
            f"过滤率 {filter_rate:.1f}%"
        )
if __name__ == '__main__':
    # Manual smoke test: build random GPS points, pretend 200 of them were
    # filtered out, and render the resulting figure under ./test_output.
    import numpy as np
    from datetime import datetime

    # Fixed seed so the generated points (and the sampled subset) are reproducible.
    np.random.seed(42)
    n_points = 1000

    # Random points in a 1-degree by 1-degree box.
    test_data = pd.DataFrame({
        'lon': np.random.uniform(120, 121, n_points),
        'lat': np.random.uniform(30, 31, n_points),
        'file': [f'img_{i}.jpg' for i in range(n_points)],
        'date': [datetime.now() for _ in range(n_points)]
    })

    # Randomly keep 800 points to stand in for the post-filter result.
    filtered_data = test_data.sample(n=800)

    output_dir = 'test_output'
    # The figure is written to <output_dir>/filter_imgs/, and savefig does not
    # create missing directories, so create the full path up front.
    os.makedirs(os.path.join(output_dir, 'filter_imgs'), exist_ok=True)

    visualizer = FilterVisualizer(output_dir)
    visualizer.visualize_filter_step(
        filtered_data,
        test_data,
        "Test Filter"
    )