UAV/filter/gps_filter.py
2024-12-23 11:31:20 +08:00

249 lines
10 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import math
from itertools import combinations
import numpy as np
from scipy.spatial import KDTree
import logging
import pandas as pd
from datetime import datetime, timedelta
class GPSFilter:
    """Filter overly dense points and isolated points out of image GPS data.

    The DataFrames passed to the public methods must contain the columns
    ``'file'``, ``'lat'`` and ``'lon'``. ``filter_dense_points`` additionally
    uses an optional datetime-like ``'date'`` column to split points into
    capture sessions before filtering.
    """

    def __init__(self, output_dir):
        """Create a filter.

        Args:
            output_dir: Output directory. It is stored on the instance; no
                method of this class reads it.
        """
        self.output_dir = output_dir
        self.logger = logging.getLogger('UAV_Preprocess.GPSFilter')

    @staticmethod
    def _haversine(lat1, lon1, lat2, lon2):
        """Return the great-circle distance in metres between two points.

        Coordinates are decimal degrees. Uses the haversine formula with a
        mean Earth radius of 6,371,000 m.
        """
        R = 6371000  # mean Earth radius, metres
        phi1, phi2 = math.radians(lat1), math.radians(lat2)
        delta_phi = math.radians(lat2 - lat1)
        delta_lambda = math.radians(lon2 - lon1)
        a = math.sin(delta_phi / 2) ** 2 + math.cos(phi1) * \
            math.cos(phi2) * math.sin(delta_lambda / 2) ** 2
        # Rounding can push `a` marginally outside [0, 1] (e.g. near-antipodal
        # points), which would make math.sqrt(1 - a) raise ValueError.
        a = min(1.0, max(0.0, a))
        c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
        return R * c

    @staticmethod
    def _assign_to_grid(lat, lon, grid_size, min_lat, min_lon):
        """Map a coordinate to an integer grid cell ``(grid_x, grid_y)``.

        ``grid_x`` indexes latitude and ``grid_y`` longitude, both counted in
        steps of ``grid_size`` (degrees) from ``(min_lat, min_lon)``.
        """
        grid_x = int((lat - min_lat) // grid_size)
        grid_y = int((lon - min_lon) // grid_size)
        return grid_x, grid_y

    def _get_distances(self, points_df, grid_size):
        """Bucket points into a lat/lon grid and compute sorted pairwise distances.

        Args:
            points_df: DataFrame with 'file', 'lat' and 'lon' columns.
            grid_size: Grid cell size in degrees.

        Returns:
            dict mapping each grid cell ``(grid_x, grid_y)`` to a list of
            ``(file1, file2, distance_m)`` tuples covering every pair of points
            in that cell, sorted by ascending distance. Pairs that straddle a
            cell boundary are not considered.
        """
        min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max()
        min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max()
        self.logger.info(
            f"经纬度范围:纬度[{min_lat:.6f}, {max_lat:.6f}],纬度范围[{max_lat-min_lat:.6f}]"
            f"经度[{min_lon:.6f}, {max_lon:.6f}],经度范围[{max_lon-min_lon:.6f}]")

        # Assign every point to its grid cell.
        grid_map = {}
        for img, lat, lon in zip(points_df['file'], points_df['lat'], points_df['lon']):
            grid = self._assign_to_grid(lat, lon, grid_size, min_lat, min_lon)
            grid_map.setdefault(grid, []).append((img, lat, lon))
        self.logger.info(f"图像点已分配到 {len(grid_map)} 个网格中")

        # Within each cell, compute all pairwise distances and sort them.
        sorted_distances = {}
        for grid, images in grid_map.items():
            distances = [
                (img1, img2, self._haversine(lat1, lon1, lat2, lon2))
                for (img1, lat1, lon1), (img2, lat2, lon2) in combinations(images, 2)
            ]
            distances.sort(key=lambda x: x[2])  # ascending distance
            sorted_distances[grid] = distances
            self.logger.debug(f"网格 {grid} 中计算了 {len(distances)} 个距离对")
        return sorted_distances

    def _group_by_time(self, points_df: pd.DataFrame,
                       time_threshold: timedelta = timedelta(minutes=5)) -> list:
        """Split points into groups by capture time.

        Points with a valid 'date' are sorted by time, and a new group starts
        wherever two consecutive timestamps differ by more than
        ``time_threshold``. Points whose 'date' is null are collected into one
        extra group appended last.

        Args:
            points_df: DataFrame with at least 'file' and 'date' columns.
            time_threshold: Gap between consecutive shots that starts a new
                group (default 5 minutes).

        Returns:
            list of DataFrames, one per group. If the 'date' column is missing,
            ``[points_df]`` is returned unchanged.
        """
        if 'date' not in points_df.columns:
            self.logger.error("数据中缺少date列")
            return [points_df]

        # Rows with a null date are kept apart as their own group.
        null_date_group = points_df[points_df['date'].isna()]
        valid_date_points = points_df[points_df['date'].notna()]
        if not null_date_group.empty:
            self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组")
        if valid_date_points.empty:
            self.logger.warning("没有有效的时间戳数据")
            return [null_date_group] if not null_date_group.empty else []

        valid_date_points = valid_date_points.sort_values('date')
        self.logger.info(
            f"有效时间范围: {valid_date_points['date'].min()}{valid_date_points['date'].max()}")

        # time_diffs[i] is the gap between row i-1 and row i; the first is NaT.
        time_diffs = valid_date_points['date'].diff()
        time_groups = []
        current_group_start = 0
        for idx, time_diff in enumerate(time_diffs):
            if pd.notna(time_diff) and time_diff > time_threshold:
                # Close the group that ends just before the break.
                current_group = valid_date_points.iloc[current_group_start:idx]
                time_groups.append(current_group)
                break_time = valid_date_points.iloc[idx]['date']
                group_start_time = current_group.iloc[0]['date']
                group_end_time = current_group.iloc[-1]['date']
                self.logger.info(
                    f"时间组 {len(time_groups)}: {len(current_group)} 个点, "
                    f"时间范围 [{group_start_time} - {group_end_time}]"
                )
                self.logger.info(
                    f"在时间 {break_time} 处发现断点,时间差为 {time_diff}")
                current_group_start = idx

        # The trailing group after the last break.
        last_group = valid_date_points.iloc[current_group_start:]
        if not last_group.empty:
            time_groups.append(last_group)
            self.logger.info(
                f"时间组 {len(time_groups)}: {len(last_group)} 个点, "
                f"时间范围 [{last_group.iloc[0]['date']} - {last_group.iloc[-1]['date']}]"
            )

        # Null-date points go last so callers can recognise and skip them.
        if not null_date_group.empty:
            time_groups.append(null_date_group)
            self.logger.info(f"添加无时间戳组: {len(null_date_group)} 个点")

        self.logger.info(f"共分为 {len(time_groups)} 个时间组")
        return time_groups

    @staticmethod
    def _first_distance_involving(distances, img):
        """Return the distance of the first pair in `distances` that contains
        `img` (as either file), or None if no such pair exists."""
        return next((d for a, b, d in distances if img == a or img == b), None)

    def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13,
                            time_threshold=timedelta(minutes=5)):
        """Remove points that lie too close to another point.

        Points are first split into capture-time groups; each group is then
        bucketed into grid cells and filtered independently. Points whose
        timestamp is null are not filtered.

        Within a cell, pairs are visited in ascending distance while the
        distance is below ``distance_threshold``. For each such pair, the image
        whose next-nearest remaining neighbour is closer is deleted (ties delete
        the second image).

        Args:
            points_df: DataFrame with 'file', 'lat', 'lon' and optionally 'date'.
            grid_size: Grid cell size in degrees.
            distance_threshold: Pairs closer than this many metres are dense.
            time_threshold: ``timedelta`` gap that splits capture-time groups.

        Returns:
            ``points_df`` without the rows whose 'file' was marked for deletion.
        """
        self.logger.info(f"开始按时间分组过滤密集点 (网格大小: {grid_size}, "
                         f"距离阈值: {distance_threshold}米, 分组时间间隔: {time_threshold}秒)")

        time_groups = self._group_by_time(points_df, time_threshold)
        all_to_del_imgs = []

        for group_idx, group_points in enumerate(time_groups):
            # _group_by_time appends the null-timestamp group last. When the
            # 'date' column is absent it returns the whole frame as one group,
            # which is filtered normally instead of raising KeyError.
            is_last = group_idx == len(time_groups) - 1
            if (is_last and 'date' in group_points.columns
                    and group_points['date'].isna().any()):
                self.logger.info(f"跳过无时间戳组 (包含 {len(group_points)} 个点)")
                continue

            self.logger.info(
                f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)")
            sorted_distances = self._get_distances(group_points, grid_size)
            group_to_del_imgs = []

            for grid, distances in sorted_distances.items():
                grid_del_count = 0
                while distances and distances[0][2] < distance_threshold:
                    candidate_img1, candidate_img2, dist = distances.pop(0)
                    img1_next = self._first_distance_involving(distances, candidate_img1)
                    img2_next = self._first_distance_involving(distances, candidate_img2)
                    # Compare against None explicitly: a next-nearest distance
                    # of 0.0 (duplicate coordinates) is valid, not "missing".
                    if img1_next is None or img2_next is None:
                        # NOTE(review): a pair where either image has no other
                        # pair left in this cell is kept as-is (original rule).
                        continue
                    to_del_img = candidate_img1 if img1_next < img2_next else candidate_img2
                    group_to_del_imgs.append(to_del_img)
                    grid_del_count += 1
                    self.logger.debug(
                        f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)")
                    # Drop every remaining pair that involves the deleted image.
                    distances = [d for d in distances if to_del_img not in d[:2]]
                if grid_del_count > 0:
                    self.logger.info(
                        f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点")

            all_to_del_imgs.extend(group_to_del_imgs)
            self.logger.info(
                f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点")

        filtered_df = points_df[~points_df['file'].isin(all_to_del_imgs)]
        self.logger.info(
            f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点")
        return filtered_df

    def filter_isolated_points(self, points_df, threshold_distance=0.001, min_neighbors=6):
        """Remove points with too few neighbours.

        Neighbours are counted with a KDTree built directly on (lat, lon), so
        ``threshold_distance`` is in degrees. The count includes the point
        itself.

        Args:
            points_df: DataFrame with 'file', 'lat' and 'lon' columns.
            threshold_distance: Neighbour search radius in degrees.
            min_neighbors: Points with fewer neighbours than this are removed.

        Returns:
            ``points_df`` without the isolated rows (matched by 'file').
        """
        self.logger.info(
            f"开始过滤孤立点 (距离阈值: {threshold_distance}, 最小邻居数: {min_neighbors})")

        isolated_points = []
        # Building a KDTree needs at least one point; empty input has nothing
        # to remove.
        if not points_df.empty:
            coords = points_df[['lat', 'lon']].values
            kdtree = KDTree(coords)
            neighbors_count = [
                len(found) for found in kdtree.query_ball_point(coords, threshold_distance)]
            for img, count in zip(points_df['file'], neighbors_count):
                if count < min_neighbors:
                    isolated_points.append(img)
                    self.logger.debug(
                        f"删除孤立点: {img} (邻居数: {count})")

        filtered_df = points_df[~points_df['file'].isin(isolated_points)]
        self.logger.info(
            f"孤立点过滤完成,共删除 {len(isolated_points)} 个点,剩余 {len(filtered_df)} 个点")
        return filtered_df