更新过滤密集点算法

This commit is contained in:
龙澳 2024-12-21 10:44:25 +08:00
parent 1d3c1bd71b
commit ebe818125e
3 changed files with 150 additions and 132 deletions

View File

@ -13,7 +13,6 @@ from preprocess.gps_extractor import GPSExtractor
from preprocess.gps_filter import GPSFilter from preprocess.gps_filter import GPSFilter
from preprocess.grid_divider import GridDivider from preprocess.grid_divider import GridDivider
from preprocess.logger import setup_logger from preprocess.logger import setup_logger
from preprocess.time_filter import TimeFilter
@dataclass @dataclass
@ -51,20 +50,12 @@ class ImagePreprocessor:
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点") self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
return self.gps_points return self.gps_points
def time_filter(self) -> pd.DataFrame:
"""时间过滤"""
self.logger.info("开始时间过滤")
time_filter = TimeFilter(self.config.output_dir)
self.gps_points = time_filter.filter_by_date(self.gps_points)
self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
return self.gps_points
# TODO 添加聚类参数
def cluster(self) -> pd.DataFrame: def cluster(self) -> pd.DataFrame:
"""使用DBSCAN对GPS点进行聚类只保留最大的类""" """使用DBSCAN对GPS点进行聚类只保留最大的类"""
self.logger.info("开始聚类") self.logger.info("开始聚类")
# 创建聚类器并执行聚类 # 创建聚类器并执行聚类
clusterer = GPSCluster(self.gps_points, output_dir=self.config.output_dir) clusterer = GPSCluster(
self.gps_points, output_dir=self.config.output_dir)
# 获取主要类别的点 # 获取主要类别的点
self.gps_points = clusterer.get_main_cluster() self.gps_points = clusterer.get_main_cluster()
# 获取统计信息并记录 # 获取统计信息并记录
@ -73,9 +64,7 @@ class ImagePreprocessor:
f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点," f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
f"噪声点 {stats['noise_points']}" f"噪声点 {stats['noise_points']}"
) )
return self.gps_points
# TODO 过滤密集点算法需要改进
def filter_points(self) -> pd.DataFrame: def filter_points(self) -> pd.DataFrame:
"""过滤GPS点""" """过滤GPS点"""
if not self.config.enable_filter: if not self.config.enable_filter:
@ -132,7 +121,8 @@ class ImagePreprocessor:
self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images" self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
) )
else: else:
output_dir = os.path.join(self.config.output_dir, "project", "images") output_dir = os.path.join(
self.config.output_dir, "project", "images")
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)

View File

@ -5,6 +5,7 @@ import numpy as np
from scipy.spatial import KDTree from scipy.spatial import KDTree
import logging import logging
import pandas as pd import pandas as pd
from datetime import datetime, timedelta
class GPSFilter: class GPSFilter:
@ -67,23 +68,122 @@ class GPSFilter:
return sorted_distances return sorted_distances
def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13): def _group_by_time(self, points_df: pd.DataFrame, time_threshold: timedelta = timedelta(minutes=5)) -> list:
"""过滤密集点,根据提供的距离阈值""" """根据拍摄时间分组图片
self.logger.info(f"开始过滤密集点 (网格大小: {grid_size}, 距离阈值: {distance_threshold}米)")
sorted_distances = self._get_distances(points_df, grid_size) 如果相邻两张图片的拍摄时间差超过5分钟则进行切分
to_del_imgs = []
"""遍历每个网格,删除网格中距离小于阈值的点""" Args:
points_df: 包含图片信息的DataFrame必须包含'file''date'
time_threshold: 时间间隔阈值默认5分钟
Returns:
list: 每个元素为时间组内的点数据
"""
if 'date' not in points_df.columns:
self.logger.error("数据中缺少date列")
return [points_df]
# 将date为空的行单独作为一组
null_date_group = points_df[points_df['date'].isna()]
valid_date_points = points_df[points_df['date'].notna()]
if not null_date_group.empty:
self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组")
if valid_date_points.empty:
self.logger.warning("没有有效的时间戳数据")
return [null_date_group] if not null_date_group.empty else []
# 按时间排序
valid_date_points = valid_date_points.sort_values('date')
self.logger.info(f"有效时间范围: {valid_date_points['date'].min()}{valid_date_points['date'].max()}")
# 计算时间差
time_diffs = valid_date_points['date'].diff()
# 找到时间差超过阈值的位置
time_groups = []
current_group_start = 0
for idx, time_diff in enumerate(time_diffs):
if time_diff and time_diff > time_threshold:
# 添加当前组
current_group = valid_date_points.iloc[current_group_start:idx]
time_groups.append(current_group)
# 记录断点信息
break_time = valid_date_points.iloc[idx]['date']
group_start_time = current_group.iloc[0]['date']
group_end_time = current_group.iloc[-1]['date']
self.logger.info(
f"时间组 {len(time_groups)}: {len(current_group)} 个点, "
f"时间范围 [{group_start_time} - {group_end_time}]"
)
self.logger.info(f"在时间 {break_time} 处发现断点,时间差为 {time_diff}")
current_group_start = idx
# 添加最后一组
last_group = valid_date_points.iloc[current_group_start:]
if not last_group.empty:
time_groups.append(last_group)
self.logger.info(
f"时间组 {len(time_groups)}: {len(last_group)} 个点, "
f"时间范围 [{last_group.iloc[0]['date']} - {last_group.iloc[-1]['date']}]"
)
# 如果有空时间戳的点,将其作为最后一组
if not null_date_group.empty:
time_groups.append(null_date_group)
self.logger.info(f"添加无时间戳组: {len(null_date_group)} 个点")
self.logger.info(f"共分为 {len(time_groups)} 个时间组")
return time_groups
def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13, time_interval=60):
"""
过滤密集点先按时间分组再在每个时间组内过滤
空时间戳的点不进行过滤
Args:
points_df: 点数据
grid_size: 网格大小
distance_threshold: 距离阈值
time_interval: 时间间隔
"""
self.logger.info(f"开始按时间分组过滤密集点 (网格大小: {grid_size}, "
f"距离阈值: {distance_threshold}米, 时间间隔: {time_interval}秒)")
# 按时间分组
time_groups = self._group_by_time(points_df, time_interval)
# 存储所有要删除的图片
all_to_del_imgs = []
# 对每个时间组进行密集点过滤
for group_idx, group_points in enumerate(time_groups):
# 检查是否为空时间戳组(最后一组)
if group_idx == len(time_groups) - 1 and group_points['date'].isna().any():
self.logger.info(f"跳过无时间戳组 (包含 {len(group_points)} 个点)")
continue
self.logger.info(f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)")
# 计算该组内的点间距离
sorted_distances = self._get_distances(group_points, grid_size)
group_to_del_imgs = []
# 在每个网格中过滤密集点
for grid, distances in sorted_distances.items(): for grid, distances in sorted_distances.items():
grid_del_count = 0 grid_del_count = 0
while distances: while distances:
candidate_img1, candidate_img2, dist = distances[0] candidate_img1, candidate_img2, dist = distances[0]
if dist < distance_threshold: if dist < distance_threshold:
# 有小于阈值的距离,先删除最近的距离,这两个点作为候选点,要删掉一个
distances.pop(0) distances.pop(0)
# 获取候选图片1和图片2倒数第二短的距离 # 获取候选图片的其他最短距离
candidate_img1_dist = None candidate_img1_dist = None
candidate_img2_dist = None candidate_img2_dist = None
for distance in distances: for distance in distances:
@ -95,30 +195,31 @@ class GPSFilter:
candidate_img2_dist = distance[2] candidate_img2_dist = distance[2]
break break
# 谁短删掉谁 # 选择要删除的点
if candidate_img1_dist and candidate_img2_dist: if candidate_img1_dist and candidate_img2_dist:
if candidate_img1_dist < candidate_img2_dist: to_del_img = candidate_img1 if candidate_img1_dist < candidate_img2_dist else candidate_img2
to_del_img = candidate_img1 group_to_del_imgs.append(to_del_img)
else:
to_del_img = candidate_img2
to_del_imgs.append(to_del_img)
grid_del_count += 1 grid_del_count += 1
self.logger.debug(f"在网格 {grid} 中删除密集点: {to_del_img} (距离: {dist:.2f}米)") self.logger.debug(f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)")
# 从距离列表中删除与被删除图片相关的记录 distances = [d for d in distances if to_del_img not in d]
distances = [
distance for distance in distances if to_del_img not in distance]
else: else:
break break
if grid_del_count > 0:
self.logger.info(f"网格 {grid} 中删除了 {grid_del_count} 个密集点")
# 过滤掉删除的图片,写入日志 if grid_del_count > 0:
self.logger.info(f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点")
all_to_del_imgs.extend(group_to_del_imgs)
self.logger.info(f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点")
# 写入删除日志
with open(self.log_file, 'a', encoding='utf-8') as f: with open(self.log_file, 'a', encoding='utf-8') as f:
for img in to_del_imgs: for img in all_to_del_imgs:
f.write(img + '\n') f.write(img + '\n')
filtered_df = points_df[~points_df['file'].isin(to_del_imgs)] # 过滤数据
self.logger.info(f"密集点过滤完成,共删除 {len(to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点") filtered_df = points_df[~points_df['file'].isin(all_to_del_imgs)]
self.logger.info(f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点")
return filtered_df return filtered_df
def filter_isolated_points(self, points_df, threshold_distance=0.001, min_neighbors=6): def filter_isolated_points(self, points_df, threshold_distance=0.001, min_neighbors=6):

View File

@ -1,73 +0,0 @@
import os
import logging
import pandas as pd
from datetime import datetime, timedelta
class TimeFilter:
"""基于拍摄时间过滤图片"""
def __init__(self, output_dir):
self.log_file = os.path.join(output_dir, 'del_imgs.txt')
self.logger = logging.getLogger('UAV_Preprocess.TimeFilter')
self.time_threshold = timedelta(minutes=5) # 5分钟阈值
def filter_by_date(self, points_df: pd.DataFrame) -> pd.DataFrame:
"""根据拍摄时间过滤图片
如果相邻两张图片的拍摄时间差超过5分钟过滤掉后续所有图片
Args:
points_df: 包含图片信息的DataFrame必须包含'file''date'
Returns:
过滤后的DataFrame
"""
self.logger.info("开始基于拍摄时间进行过滤")
# 确保date列存在且不为空
if 'date' not in points_df.columns:
self.logger.error("输入数据中缺少date列")
return points_df
# 删除date为空的行
points_df = points_df.dropna(subset=['date'])
if len(points_df) == 0:
self.logger.warning("没有有效的拍摄时间数据")
return points_df
# 按时间排序
points_df = points_df.sort_values('date')
self.logger.info(f"排序后的时间范围: {points_df['date'].min()}{points_df['date'].max()}")
# 计算时间差
time_diffs = points_df['date'].diff()
# 找到第一个时间差超过阈值的位置
break_idx = None
for idx, time_diff in enumerate(time_diffs):
if time_diff and time_diff > self.time_threshold:
break_idx = idx
break_time = points_df.iloc[idx]['date']
self.logger.info(f"在索引 {idx} 处发现时间断点,时间差为 {time_diff}")
self.logger.info(f"断点时间: {break_time}")
break
# 如果找到断点,过滤掉后续图片
if break_idx is not None:
to_delete = points_df.iloc[break_idx:]['file'].tolist()
self.logger.info(f"将删除 {len(to_delete)} 张断点后的图片")
# 记录被删除的图片
with open(self.log_file, 'a', encoding='utf-8') as f:
for img in to_delete:
f.write(img + '\n')
f.write('\n')
# 保留断点之前的图片
filtered_df = points_df.iloc[:break_idx]
self.logger.info(f"时间过滤完成,保留了 {len(filtered_df)} 张图片")
return filtered_df
self.logger.info("未发现时间断点,保留所有图片")
return points_df