更新过滤密集点算法
This commit is contained in:
parent
1d3c1bd71b
commit
ebe818125e
@ -13,7 +13,6 @@ from preprocess.gps_extractor import GPSExtractor
|
||||
from preprocess.gps_filter import GPSFilter
|
||||
from preprocess.grid_divider import GridDivider
|
||||
from preprocess.logger import setup_logger
|
||||
from preprocess.time_filter import TimeFilter
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -51,20 +50,12 @@ class ImagePreprocessor:
|
||||
self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点")
|
||||
return self.gps_points
|
||||
|
||||
def time_filter(self) -> pd.DataFrame:
|
||||
"""时间过滤"""
|
||||
self.logger.info("开始时间过滤")
|
||||
time_filter = TimeFilter(self.config.output_dir)
|
||||
self.gps_points = time_filter.filter_by_date(self.gps_points)
|
||||
self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||
return self.gps_points
|
||||
|
||||
# TODO 添加聚类参数
|
||||
def cluster(self) -> pd.DataFrame:
|
||||
"""使用DBSCAN对GPS点进行聚类,只保留最大的类"""
|
||||
self.logger.info("开始聚类")
|
||||
# 创建聚类器并执行聚类
|
||||
clusterer = GPSCluster(self.gps_points, output_dir=self.config.output_dir)
|
||||
clusterer = GPSCluster(
|
||||
self.gps_points, output_dir=self.config.output_dir)
|
||||
# 获取主要类别的点
|
||||
self.gps_points = clusterer.get_main_cluster()
|
||||
# 获取统计信息并记录
|
||||
@ -73,9 +64,7 @@ class ImagePreprocessor:
|
||||
f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
|
||||
f"噪声点 {stats['noise_points']} 个"
|
||||
)
|
||||
return self.gps_points
|
||||
|
||||
# TODO 过滤密集点算法需要改进
|
||||
def filter_points(self) -> pd.DataFrame:
|
||||
"""过滤GPS点"""
|
||||
if not self.config.enable_filter:
|
||||
@ -132,7 +121,8 @@ class ImagePreprocessor:
|
||||
self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
|
||||
)
|
||||
else:
|
||||
output_dir = os.path.join(self.config.output_dir, "project", "images")
|
||||
output_dir = os.path.join(
|
||||
self.config.output_dir, "project", "images")
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
|
@ -5,6 +5,7 @@ import numpy as np
|
||||
from scipy.spatial import KDTree
|
||||
import logging
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class GPSFilter:
|
||||
@ -67,58 +68,158 @@ class GPSFilter:
|
||||
|
||||
return sorted_distances
|
||||
|
||||
def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13):
|
||||
"""过滤密集点,根据提供的距离阈值"""
|
||||
self.logger.info(f"开始过滤密集点 (网格大小: {grid_size}, 距离阈值: {distance_threshold}米)")
|
||||
def _group_by_time(self, points_df: pd.DataFrame, time_threshold: timedelta = timedelta(minutes=5)) -> list:
|
||||
"""根据拍摄时间分组图片
|
||||
|
||||
sorted_distances = self._get_distances(points_df, grid_size)
|
||||
to_del_imgs = []
|
||||
如果相邻两张图片的拍摄时间差超过5分钟,则进行切分
|
||||
|
||||
"""遍历每个网格,删除网格中距离小于阈值的点"""
|
||||
for grid, distances in sorted_distances.items():
|
||||
grid_del_count = 0
|
||||
while distances:
|
||||
candidate_img1, candidate_img2, dist = distances[0]
|
||||
if dist < distance_threshold:
|
||||
# 有小于阈值的距离,先删除最近的距离,这两个点作为候选点,要删掉一个
|
||||
distances.pop(0)
|
||||
Args:
|
||||
points_df: 包含图片信息的DataFrame,必须包含'file'和'date'列
|
||||
time_threshold: 时间间隔阈值,默认5分钟
|
||||
|
||||
# 获取候选图片1和图片2倒数第二短的距离
|
||||
candidate_img1_dist = None
|
||||
candidate_img2_dist = None
|
||||
for distance in distances:
|
||||
if candidate_img1 in distance:
|
||||
candidate_img1_dist = distance[2]
|
||||
break
|
||||
for distance in distances:
|
||||
if candidate_img2 in distance:
|
||||
candidate_img2_dist = distance[2]
|
||||
break
|
||||
Returns:
|
||||
list: 每个元素为时间组内的点数据
|
||||
"""
|
||||
if 'date' not in points_df.columns:
|
||||
self.logger.error("数据中缺少date列")
|
||||
return [points_df]
|
||||
|
||||
# 谁短删掉谁
|
||||
if candidate_img1_dist and candidate_img2_dist:
|
||||
if candidate_img1_dist < candidate_img2_dist:
|
||||
to_del_img = candidate_img1
|
||||
else:
|
||||
to_del_img = candidate_img2
|
||||
to_del_imgs.append(to_del_img)
|
||||
grid_del_count += 1
|
||||
self.logger.debug(f"在网格 {grid} 中删除密集点: {to_del_img} (距离: {dist:.2f}米)")
|
||||
# 从距离列表中删除与被删除图片相关的记录
|
||||
distances = [
|
||||
distance for distance in distances if to_del_img not in distance]
|
||||
else:
|
||||
break
|
||||
if grid_del_count > 0:
|
||||
self.logger.info(f"网格 {grid} 中删除了 {grid_del_count} 个密集点")
|
||||
# 将date为空的行单独作为一组
|
||||
null_date_group = points_df[points_df['date'].isna()]
|
||||
valid_date_points = points_df[points_df['date'].notna()]
|
||||
|
||||
# 过滤掉删除的图片,写入日志
|
||||
if not null_date_group.empty:
|
||||
self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组")
|
||||
|
||||
if valid_date_points.empty:
|
||||
self.logger.warning("没有有效的时间戳数据")
|
||||
return [null_date_group] if not null_date_group.empty else []
|
||||
|
||||
# 按时间排序
|
||||
valid_date_points = valid_date_points.sort_values('date')
|
||||
self.logger.info(f"有效时间范围: {valid_date_points['date'].min()} 到 {valid_date_points['date'].max()}")
|
||||
|
||||
# 计算时间差
|
||||
time_diffs = valid_date_points['date'].diff()
|
||||
|
||||
# 找到时间差超过阈值的位置
|
||||
time_groups = []
|
||||
current_group_start = 0
|
||||
|
||||
for idx, time_diff in enumerate(time_diffs):
|
||||
if time_diff and time_diff > time_threshold:
|
||||
# 添加当前组
|
||||
current_group = valid_date_points.iloc[current_group_start:idx]
|
||||
time_groups.append(current_group)
|
||||
|
||||
# 记录断点信息
|
||||
break_time = valid_date_points.iloc[idx]['date']
|
||||
group_start_time = current_group.iloc[0]['date']
|
||||
group_end_time = current_group.iloc[-1]['date']
|
||||
|
||||
self.logger.info(
|
||||
f"时间组 {len(time_groups)}: {len(current_group)} 个点, "
|
||||
f"时间范围 [{group_start_time} - {group_end_time}]"
|
||||
)
|
||||
self.logger.info(f"在时间 {break_time} 处发现断点,时间差为 {time_diff}")
|
||||
|
||||
current_group_start = idx
|
||||
|
||||
# 添加最后一组
|
||||
last_group = valid_date_points.iloc[current_group_start:]
|
||||
if not last_group.empty:
|
||||
time_groups.append(last_group)
|
||||
self.logger.info(
|
||||
f"时间组 {len(time_groups)}: {len(last_group)} 个点, "
|
||||
f"时间范围 [{last_group.iloc[0]['date']} - {last_group.iloc[-1]['date']}]"
|
||||
)
|
||||
|
||||
# 如果有空时间戳的点,将其作为最后一组
|
||||
if not null_date_group.empty:
|
||||
time_groups.append(null_date_group)
|
||||
self.logger.info(f"添加无时间戳组: {len(null_date_group)} 个点")
|
||||
|
||||
self.logger.info(f"共分为 {len(time_groups)} 个时间组")
|
||||
return time_groups
|
||||
|
||||
def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13, time_interval=60):
|
||||
"""
|
||||
过滤密集点,先按时间分组,再在每个时间组内过滤。
|
||||
空时间戳的点不进行过滤。
|
||||
|
||||
Args:
|
||||
points_df: 点数据
|
||||
grid_size: 网格大小
|
||||
distance_threshold: 距离阈值(米)
|
||||
time_interval: 时间间隔(秒)
|
||||
"""
|
||||
self.logger.info(f"开始按时间分组过滤密集点 (网格大小: {grid_size}, "
|
||||
f"距离阈值: {distance_threshold}米, 时间间隔: {time_interval}秒)")
|
||||
|
||||
# 按时间分组
|
||||
time_groups = self._group_by_time(points_df, time_interval)
|
||||
|
||||
# 存储所有要删除的图片
|
||||
all_to_del_imgs = []
|
||||
|
||||
# 对每个时间组进行密集点过滤
|
||||
for group_idx, group_points in enumerate(time_groups):
|
||||
# 检查是否为空时间戳组(最后一组)
|
||||
if group_idx == len(time_groups) - 1 and group_points['date'].isna().any():
|
||||
self.logger.info(f"跳过无时间戳组 (包含 {len(group_points)} 个点)")
|
||||
continue
|
||||
|
||||
self.logger.info(f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)")
|
||||
|
||||
# 计算该组内的点间距离
|
||||
sorted_distances = self._get_distances(group_points, grid_size)
|
||||
group_to_del_imgs = []
|
||||
|
||||
# 在每个网格中过滤密集点
|
||||
for grid, distances in sorted_distances.items():
|
||||
grid_del_count = 0
|
||||
while distances:
|
||||
candidate_img1, candidate_img2, dist = distances[0]
|
||||
if dist < distance_threshold:
|
||||
distances.pop(0)
|
||||
|
||||
# 获取候选图片的其他最短距离
|
||||
candidate_img1_dist = None
|
||||
candidate_img2_dist = None
|
||||
for distance in distances:
|
||||
if candidate_img1 in distance:
|
||||
candidate_img1_dist = distance[2]
|
||||
break
|
||||
for distance in distances:
|
||||
if candidate_img2 in distance:
|
||||
candidate_img2_dist = distance[2]
|
||||
break
|
||||
|
||||
# 选择要删除的点
|
||||
if candidate_img1_dist and candidate_img2_dist:
|
||||
to_del_img = candidate_img1 if candidate_img1_dist < candidate_img2_dist else candidate_img2
|
||||
group_to_del_imgs.append(to_del_img)
|
||||
grid_del_count += 1
|
||||
self.logger.debug(f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)")
|
||||
distances = [d for d in distances if to_del_img not in d]
|
||||
else:
|
||||
break
|
||||
|
||||
if grid_del_count > 0:
|
||||
self.logger.info(f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点")
|
||||
|
||||
all_to_del_imgs.extend(group_to_del_imgs)
|
||||
self.logger.info(f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点")
|
||||
|
||||
# 写入删除日志
|
||||
with open(self.log_file, 'a', encoding='utf-8') as f:
|
||||
for img in to_del_imgs:
|
||||
f.write(img+'\n')
|
||||
for img in all_to_del_imgs:
|
||||
f.write(img + '\n')
|
||||
|
||||
# 过滤数据
|
||||
filtered_df = points_df[~points_df['file'].isin(all_to_del_imgs)]
|
||||
self.logger.info(f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点")
|
||||
|
||||
filtered_df = points_df[~points_df['file'].isin(to_del_imgs)]
|
||||
self.logger.info(f"密集点过滤完成,共删除 {len(to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点")
|
||||
return filtered_df
|
||||
|
||||
def filter_isolated_points(self, points_df, threshold_distance=0.001, min_neighbors=6):
|
||||
|
@ -1,73 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
class TimeFilter:
|
||||
"""基于拍摄时间过滤图片"""
|
||||
|
||||
def __init__(self, output_dir):
|
||||
self.log_file = os.path.join(output_dir, 'del_imgs.txt')
|
||||
self.logger = logging.getLogger('UAV_Preprocess.TimeFilter')
|
||||
self.time_threshold = timedelta(minutes=5) # 5分钟阈值
|
||||
|
||||
def filter_by_date(self, points_df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""根据拍摄时间过滤图片
|
||||
|
||||
如果相邻两张图片的拍摄时间差超过5分钟,过滤掉后续所有图片
|
||||
|
||||
Args:
|
||||
points_df: 包含图片信息的DataFrame,必须包含'file'和'date'列
|
||||
|
||||
Returns:
|
||||
过滤后的DataFrame
|
||||
"""
|
||||
self.logger.info("开始基于拍摄时间进行过滤")
|
||||
|
||||
# 确保date列存在且不为空
|
||||
if 'date' not in points_df.columns:
|
||||
self.logger.error("输入数据中缺少date列")
|
||||
return points_df
|
||||
|
||||
# 删除date为空的行
|
||||
points_df = points_df.dropna(subset=['date'])
|
||||
|
||||
if len(points_df) == 0:
|
||||
self.logger.warning("没有有效的拍摄时间数据")
|
||||
return points_df
|
||||
|
||||
# 按时间排序
|
||||
points_df = points_df.sort_values('date')
|
||||
self.logger.info(f"排序后的时间范围: {points_df['date'].min()} 到 {points_df['date'].max()}")
|
||||
|
||||
# 计算时间差
|
||||
time_diffs = points_df['date'].diff()
|
||||
|
||||
# 找到第一个时间差超过阈值的位置
|
||||
break_idx = None
|
||||
for idx, time_diff in enumerate(time_diffs):
|
||||
if time_diff and time_diff > self.time_threshold:
|
||||
break_idx = idx
|
||||
break_time = points_df.iloc[idx]['date']
|
||||
self.logger.info(f"在索引 {idx} 处发现时间断点,时间差为 {time_diff}")
|
||||
self.logger.info(f"断点时间: {break_time}")
|
||||
break
|
||||
|
||||
# 如果找到断点,过滤掉后续图片
|
||||
if break_idx is not None:
|
||||
to_delete = points_df.iloc[break_idx:]['file'].tolist()
|
||||
self.logger.info(f"将删除 {len(to_delete)} 张断点后的图片")
|
||||
|
||||
# 记录被删除的图片
|
||||
with open(self.log_file, 'a', encoding='utf-8') as f:
|
||||
for img in to_delete:
|
||||
f.write(img + '\n')
|
||||
f.write('\n')
|
||||
|
||||
# 保留断点之前的图片
|
||||
filtered_df = points_df.iloc[:break_idx]
|
||||
self.logger.info(f"时间过滤完成,保留了 {len(filtered_df)} 张图片")
|
||||
return filtered_df
|
||||
|
||||
self.logger.info("未发现时间断点,保留所有图片")
|
||||
return points_df
|
Loading…
Reference in New Issue
Block a user