UAV/utils/visualizer.py
2025-04-12 22:48:07 +08:00

153 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import matplotlib.pyplot as plt
import pandas as pd
import logging
from typing import Optional
from pyproj import Transformer
class FilterVisualizer:
"""过滤结果可视化器"""
def __init__(self, output_dir: str):
"""
初始化可视化器
Args:
output_dir: 输出目录路径
"""
self.output_dir = output_dir
self.logger = logging.getLogger('UAV_Preprocess.Visualizer')
# 创建坐标转换器
self.transformer = Transformer.from_crs(
"EPSG:4326", # WGS84经纬度坐标系
"EPSG:32649", # UTM49N
always_xy=True
)
def _convert_to_utm(self, lon: pd.Series, lat: pd.Series) -> tuple:
"""
将经纬度坐标转换为UTM坐标
Args:
lon: 经度序列
lat: 纬度序列
Returns:
tuple: (x坐标, y坐标)
"""
return self.transformer.transform(lon, lat)
def visualize_filter_step(self,
current_points: pd.DataFrame,
previous_points: pd.DataFrame,
step_name: str,
save_name: Optional[str] = None):
"""
可视化单个过滤步骤的结果
Args:
current_points: 当前步骤后的点
previous_points: 上一步骤的点
step_name: 步骤名称
save_name: 保存文件名默认为step_name
"""
self.logger.info(f"开始生成{step_name}的可视化结果")
# 找出被过滤掉的点
filtered_files = set(
previous_points['file']) - set(current_points['file'])
filtered_points = previous_points[previous_points['file'].isin(
filtered_files)]
# 转换坐标到UTM
current_x, current_y = self._convert_to_utm(
current_points['lon'], current_points['lat'])
filtered_x, filtered_y = self._convert_to_utm(
filtered_points['lon'], filtered_points['lat'])
# 创建图形
plt.rcParams['font.sans-serif'] = ['SimHei'] # 黑体
plt.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(20, 20))
# 绘制保留的点
plt.scatter(current_x, current_y,
color='blue', label='保留的点',
alpha=0.6, s=5)
# 绘制被过滤的点
if not filtered_points.empty:
plt.scatter(filtered_x, filtered_y,
color='red', marker='x', label='过滤的点')
# 设置图形属性
plt.title(f"{step_name}后的GPS点\n"
f"(过滤: {len(filtered_points)}, 保留: {len(current_points)})",
fontsize=14)
plt.xlabel("东向坐标 (米)", fontsize=12)
plt.ylabel("北向坐标 (米)", fontsize=12)
plt.grid(True)
plt.axis('equal')
# 添加统计信息
stats_text = (
f"原始点数: {len(previous_points)}\n"
f"过滤点数: {len(filtered_points)}\n"
f"保留点数: {len(current_points)}\n"
f"过滤率: {len(filtered_points)/len(previous_points)*100:.1f}%"
)
plt.figtext(0.02, 0.02, stats_text, fontsize=10,
bbox=dict(facecolor='white', alpha=0.8))
# 添加图例
plt.legend(loc='upper right', fontsize=10)
# 调整布局
plt.tight_layout()
# 保存图形
save_name = save_name or step_name.lower().replace(' ', '_')
save_path = os.path.join(
self.output_dir, 'filter_imgs', f'filter_{save_name}.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
self.logger.info(
f"{step_name}过滤可视化结果已保存至 {save_path}\n"
f"过滤掉 {len(filtered_points)} 个点,"
f"保留 {len(current_points)} 个点,"
f"过滤率 {len(filtered_points)/len(previous_points)*100:.1f}%"
)
if __name__ == '__main__':
# 测试代码
import numpy as np
from datetime import datetime
# 创建测试数据
np.random.seed(42)
n_points = 1000
# 生成随机点
test_data = pd.DataFrame({
'lon': np.random.uniform(120, 121, n_points),
'lat': np.random.uniform(30, 31, n_points),
'file': [f'img_{i}.jpg' for i in range(n_points)],
'date': [datetime.now() for _ in range(n_points)]
})
# 随机选择点作为过滤后的结果
filtered_data = test_data.sample(n=800)
# 测试可视化
visualizer = FilterVisualizer('test_output')
os.makedirs('test_output', exist_ok=True)
visualizer.visualize_filter_step(
filtered_data,
test_data,
"Test Filter"
)