代码重构,添加面积过滤
This commit is contained in:
parent
989648fcae
commit
3d7ccd815a
219
filter/time_group_overlap_filter.py
Normal file
219
filter/time_group_overlap_filter.py
Normal file
@ -0,0 +1,219 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from datetime import timedelta
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
from preprocess.gps_extractor import GPSExtractor
|
||||||
|
from preprocess.logger import setup_logger
|
||||||
|
from shapely.geometry import box
|
||||||
|
import pandas as pd
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
class TimeGroupOverlapFilter:
    """Image filter that discards time groups whose footprint overlaps another group.

    Images are grouped by capture time, each group is reduced to the
    axis-aligned bounding box of its GPS points, and whenever two boxes overlap
    by more than ``overlap_threshold`` the group with the smaller box is
    selected for removal.
    """

    def __init__(self, image_dir: str, output_dir: str, overlap_threshold: float = 0.7):
        """Store the filter configuration.

        Args:
            image_dir: Directory handed to ``GPSExtractor`` to read the images.
            output_dir: Directory that receives ``deleted_images.txt`` and
                ``time_groups_overlap.png``.
            overlap_threshold: Overlap ratio above which the smaller of two
                groups is selected for removal. Defaults to 0.7.
        """
        self.image_dir = image_dir
        self.output_dir = output_dir
        self.overlap_threshold = overlap_threshold
        self.logger = logging.getLogger('UAV_Preprocess.TimeGroupFilter')

    def _group_by_time(self, points_df, time_threshold=timedelta(minutes=5)):
        """Split the points into groups separated by large capture-time gaps.

        Args:
            points_df: DataFrame with a ``date`` column.
            time_threshold: A gap strictly larger than this between two
                consecutive (time-sorted) points starts a new group.

        Returns:
            A list of DataFrames in chronological order. Rows without a
            timestamp are returned together as one extra, final group. An
            empty list is returned when the ``date`` column is missing.
        """
        if 'date' not in points_df.columns:
            self.logger.error("数据中缺少date列")
            return []

        # Rows without a timestamp cannot be ordered; keep them as their own group.
        missing_mask = points_df['date'].isna()
        null_date_group = points_df[missing_mask]
        valid_date_points = points_df[~missing_mask]

        if not null_date_group.empty:
            self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组")

        if valid_date_points.empty:
            self.logger.warning("没有有效的时间戳数据")
            return [null_date_group] if not null_date_group.empty else []

        valid_date_points = valid_date_points.sort_values('date')

        time_groups = []
        current_group_start = 0
        for position, time_diff in enumerate(valid_date_points['date'].diff()):
            # The first difference is always missing (NaT/NaN); test for that
            # explicitly instead of relying on the truthiness of the value.
            if pd.notna(time_diff) and time_diff > time_threshold:
                time_groups.append(
                    valid_date_points.iloc[current_group_start:position])
                current_group_start = position

        # The trailing slice is never empty because valid_date_points is not.
        time_groups.append(valid_date_points.iloc[current_group_start:])

        if not null_date_group.empty:
            time_groups.append(null_date_group)

        return time_groups

    def _get_group_bbox(self, group_df):
        """Return the bounding box of a group as a shapely polygon.

        Args:
            group_df: DataFrame with ``lon`` and ``lat`` columns.

        Returns:
            ``box(min_lon, min_lat, max_lon, max_lat)`` for the group.
        """
        lon = group_df['lon']
        lat = group_df['lat']
        return box(lon.min(), lat.min(), lon.max(), lat.max())

    def _calculate_overlap(self, box1, box2):
        """Return the intersection area divided by the smaller box's area.

        Returns 0.0 when the boxes do not intersect or when the smaller box is
        degenerate (zero area, e.g. a group with a single point or with
        collinear points), because no meaningful ratio exists in that case.
        """
        if not box1.intersects(box2):
            return 0.0

        smaller_area = min(box1.area, box2.area)
        if smaller_area <= 0:
            # Avoids ZeroDivisionError for point/line shaped bounding boxes.
            return 0.0

        return box1.intersection(box2).area / smaller_area

    def _find_groups_to_delete(self, group_boxes):
        """Select the time-group indices that should be removed.

        Args:
            group_boxes: List of ``(group_index, group_df, bbox)`` tuples.

        Returns:
            Set of ``group_index`` values. For every pair of still-kept groups
            overlapping by more than ``self.overlap_threshold`` the one with
            the smaller bounding-box area is selected (the later one on ties).
        """
        groups_to_delete = set()

        for position, (idx1, _, box1) in enumerate(group_boxes):
            if idx1 in groups_to_delete:
                continue
            area1 = box1.area

            for idx2, _, box2 in group_boxes[position + 1:]:
                if idx2 in groups_to_delete:
                    continue
                area2 = box2.area

                overlap_ratio = self._calculate_overlap(box1, box2)
                if overlap_ratio <= self.overlap_threshold:
                    continue

                if area1 < area2:
                    group_to_delete, kept_group = idx1, idx2
                    smaller_area, larger_area = area1, area2
                else:
                    group_to_delete, kept_group = idx2, idx1
                    smaller_area, larger_area = area2, area1

                groups_to_delete.add(group_to_delete)
                area_ratio = smaller_area / larger_area if larger_area else 0.0
                self.logger.info(
                    f"时间组 {group_to_delete + 1} 与时间组 "
                    f"{kept_group + 1} "
                    f"重叠率为 {overlap_ratio:.2f},"
                    f"面积比为 {area_ratio:.2f},"
                    f"将删除较小面积的组 {group_to_delete + 1}"
                )

                if group_to_delete == idx1:
                    # The outer group is gone; comparing it with further groups
                    # would remove groups only because they overlap a group
                    # that has already been discarded.
                    break

        return groups_to_delete

    def filter_overlapping_groups(self, time_threshold=timedelta(minutes=5)):
        """Find the images that belong to redundant, overlapping time groups.

        GPS data is extracted from ``self.image_dir``, grouped by time, and the
        groups are compared pairwise by bounding box. The selected file names
        are written to ``deleted_images.txt`` in ``self.output_dir`` and a
        plot of all groups is saved next to it. No image file is removed from
        disk by this method.

        Args:
            time_threshold: Time gap that separates two groups.

        Returns:
            List with the ``file`` values of every image in a removed group.
        """
        extractor = GPSExtractor(self.image_dir)
        # NOTE(review): assumes extract_all_gps() returns a DataFrame with
        # 'file', 'lat', 'lon' and 'date' columns - confirm against GPSExtractor.
        gps_points = extractor.extract_all_gps()

        time_groups = self._group_by_time(gps_points, time_threshold)

        # Only groups in which every point has a timestamp take part.
        group_boxes = [
            (idx, group, self._get_group_bbox(group))
            for idx, group in enumerate(time_groups)
            if not group['date'].isna().any()
        ]

        groups_to_delete = self._find_groups_to_delete(group_boxes)

        # Sorted so that the log file has a reproducible order.
        deleted_files = []
        for group_idx in sorted(groups_to_delete):
            deleted_files.extend(time_groups[group_idx]['file'].tolist())

        os.makedirs(self.output_dir, exist_ok=True)
        log_file = os.path.join(self.output_dir, 'deleted_images.txt')
        with open(log_file, 'w', encoding='utf-8') as f:
            f.writelines(f"{file}\n" for file in deleted_files)

        self.logger.info(f"共删除 {len(groups_to_delete)} 个重复时间组,"
                         f"{len(deleted_files)} 张图像")

        self._visualize_results(time_groups, groups_to_delete)

        return deleted_files

    def _visualize_results(self, time_groups, groups_to_delete):
        """Plot every timestamped group and save the figure to ``output_dir``.

        Args:
            time_groups: Groups as returned by ``_group_by_time``.
            groups_to_delete: Indices into ``time_groups`` of removed groups;
                their bounding boxes are drawn dashed, kept ones solid.
        """
        plt.figure(figsize=(15, 10))
        try:
            # One distinct colour per group.
            colors = plt.cm.rainbow(np.linspace(0, 1, len(time_groups)))

            for idx, (group, color) in enumerate(zip(time_groups, colors)):
                if group['date'].isna().any():
                    # Groups without timestamps were not part of the filtering.
                    continue

                x, y = self._get_group_bbox(group).exterior.xy
                if idx in groups_to_delete:
                    line_style, label = '--', f'Deleted Group {idx + 1}'
                else:
                    line_style, label = '-', f'Group {idx + 1}'
                plt.plot(x, y, line_style, color=color, alpha=0.6, label=label)

                # GPS points of the group.
                plt.scatter(group['lon'], group['lat'], color=color,
                            s=30, alpha=0.6)

            plt.title("Time Groups and Their Bounding Boxes", fontsize=14)
            plt.xlabel("Longitude", fontsize=12)
            plt.ylabel("Latitude", fontsize=12)
            plt.grid(True)
            plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
            plt.tight_layout()

            plt.savefig(os.path.join(self.output_dir, 'time_groups_overlap.png'),
                        dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Manual run against a local dataset; adjust both paths before use.
    DATASET = r'F:\error_data\20241108134711\3D'
    output_dir = r'E:\studio2\ODM_pro\test'
    os.makedirs(output_dir, exist_ok=True)

    # Logging is configured with the parent directory of output_dir.
    setup_logger(os.path.dirname(output_dir))

    # Named overlap_filter so that the builtin filter() is not shadowed.
    overlap_filter = TimeGroupOverlapFilter(
        DATASET, output_dir, overlap_threshold=0.7)
    deleted_files = overlap_filter.filter_overlapping_groups(
        time_threshold=timedelta(minutes=5))
|
@ -1,3 +0,0 @@
|
|||||||
2024-12-22 18:58:05 - UAV_Preprocess.GPSExtractor - INFO - 开始从目录提取GPS坐标和拍摄日期: F:\error_data\20241104140457\code\images
|
|
||||||
2024-12-22 18:58:22 - UAV_Preprocess.GPSExtractor - INFO - GPS坐标和拍摄日期提取完成 - 总图片数: 2708, 成功提取: 2708, 失败: 0
|
|
||||||
2024-12-22 18:58:22 - UAV_Preprocess.GPSVisualizer - INFO - 已生成包含 14 个时间组的组合可视化图形
|
|
@ -1,3 +0,0 @@
|
|||||||
2024-12-22 18:59:26 - UAV_Preprocess.GPSExtractor - INFO - 开始从目录提取GPS坐标和拍摄日期: F:\error_data\20241108134711\3D
|
|
||||||
2024-12-22 19:00:09 - UAV_Preprocess.GPSExtractor - INFO - GPS坐标和拍摄日期提取完成 - 总图片数: 6615, 成功提取: 6615, 失败: 0
|
|
||||||
2024-12-22 19:00:10 - UAV_Preprocess.GPSVisualizer - INFO - 已生成包含 8 个时间组的组合可视化图形
|
|
@ -8,12 +8,13 @@ import matplotlib.pyplot as plt
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from preprocess.cluster import GPSCluster
|
from filter.cluster_filter import GPSCluster
|
||||||
from preprocess.command_runner import CommandRunner
|
from utils.command_runner import CommandRunner
|
||||||
from preprocess.gps_extractor import GPSExtractor
|
from utils.gps_extractor import GPSExtractor
|
||||||
from preprocess.gps_filter import GPSFilter
|
from filter.gps_filter import GPSFilter
|
||||||
from preprocess.grid_divider import GridDivider
|
from utils.grid_divider import GridDivider
|
||||||
from preprocess.logger import setup_logger
|
from utils.logger import setup_logger
|
||||||
|
from filter.time_group_overlap_filter import TimeGroupOverlapFilter
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -25,12 +26,16 @@ class PreprocessConfig:
|
|||||||
# 聚类过滤参数
|
# 聚类过滤参数
|
||||||
cluster_eps: float = 0.01
|
cluster_eps: float = 0.01
|
||||||
cluster_min_samples: int = 5
|
cluster_min_samples: int = 5
|
||||||
|
# 时间组重叠过滤参数
|
||||||
|
time_group_overlap_threshold: float = 0.7
|
||||||
|
time_group_interval: timedelta = timedelta(minutes=5)
|
||||||
|
enable_time_group_filter: bool = True
|
||||||
# 孤立点过滤参数
|
# 孤立点过滤参数
|
||||||
filter_distance_threshold: float = 0.001 # 经纬度距离
|
filter_distance_threshold: float = 0.001 # 经纬度距离
|
||||||
filter_min_neighbors: int = 6
|
filter_min_neighbors: int = 6
|
||||||
# 密集点过滤参数
|
# 密集点过滤参数
|
||||||
filter_grid_size: float = 0.001
|
filter_grid_size: float = 0.001
|
||||||
filter_dense_distance_threshold: float = 10 # 普通距离,单位:米
|
filter_dense_distance_threshold: float = 10 # 普通距离,单位:米
|
||||||
filter_time_threshold: timedelta = timedelta(minutes=5)
|
filter_time_threshold: timedelta = timedelta(minutes=5)
|
||||||
# 网格划分参数
|
# 网格划分参数
|
||||||
grid_overlap: float = 0.05
|
grid_overlap: float = 0.05
|
||||||
@ -42,12 +47,14 @@ class PreprocessConfig:
|
|||||||
enable_copy_images: bool = True
|
enable_copy_images: bool = True
|
||||||
mode: str = "快拼模式"
|
mode: str = "快拼模式"
|
||||||
|
|
||||||
|
|
||||||
class ImagePreprocessor:
|
class ImagePreprocessor:
|
||||||
def __init__(self, config: PreprocessConfig):
|
def __init__(self, config: PreprocessConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.logger = setup_logger(config.output_dir)
|
self.logger = setup_logger(config.output_dir)
|
||||||
self.gps_points = []
|
self.gps_points = []
|
||||||
self.command_runner = CommandRunner(config.output_dir, mode=config.mode)
|
self.command_runner = CommandRunner(
|
||||||
|
config.output_dir, mode=config.mode)
|
||||||
|
|
||||||
def extract_gps(self) -> pd.DataFrame:
|
def extract_gps(self) -> pd.DataFrame:
|
||||||
"""提取GPS数据"""
|
"""提取GPS数据"""
|
||||||
@ -65,7 +72,7 @@ class ImagePreprocessor:
|
|||||||
self.gps_points, output_dir=self.config.output_dir,
|
self.gps_points, output_dir=self.config.output_dir,
|
||||||
eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples)
|
eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples)
|
||||||
# 获取主要类别的点
|
# 获取主要类别的点
|
||||||
self.clustered_points = clusterer.fit()
|
self.clustered_points = clusterer.fit()
|
||||||
self.gps_points = clusterer.get_main_cluster(self.clustered_points)
|
self.gps_points = clusterer.get_main_cluster(self.clustered_points)
|
||||||
# 获取统计信息并记录
|
# 获取统计信息并记录
|
||||||
stats = clusterer.get_cluster_stats(self.clustered_points)
|
stats = clusterer.get_cluster_stats(self.clustered_points)
|
||||||
@ -74,6 +81,29 @@ class ImagePreprocessor:
|
|||||||
f"噪声点 {stats['noise_points']} 个"
|
f"噪声点 {stats['noise_points']} 个"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def filter_time_group_overlap(self) -> pd.DataFrame:
    """Remove the GPS points of images that sit in overlapping time groups.

    Returns:
        ``self.gps_points``: untouched when ``enable_time_group_filter`` is
        off in the configuration, otherwise without the rows whose ``file``
        was reported by ``TimeGroupOverlapFilter.filter_overlapping_groups``.
    """
    cfg = self.config
    if cfg.enable_time_group_filter:
        self.logger.info("开始过滤重叠时间组")
        overlap_filter = TimeGroupOverlapFilter(
            cfg.image_dir,
            cfg.output_dir,
            overlap_threshold=cfg.time_group_overlap_threshold,
        )
        removed_files = overlap_filter.filter_overlapping_groups(
            time_threshold=cfg.time_group_interval,
        )

        # Keep only the rows whose image was not reported as removed.
        keep_mask = ~self.gps_points['file'].isin(removed_files)
        self.gps_points = self.gps_points[keep_mask]
        self.logger.info(f"重叠时间组过滤后剩余 {len(self.gps_points)} 个GPS点")

    return self.gps_points
|
||||||
|
|
||||||
# TODO 过滤算法还需要更新
|
# TODO 过滤算法还需要更新
|
||||||
def filter_points(self) -> pd.DataFrame:
|
def filter_points(self) -> pd.DataFrame:
|
||||||
"""过滤GPS点"""
|
"""过滤GPS点"""
|
||||||
@ -200,15 +230,16 @@ class ImagePreprocessor:
|
|||||||
try:
|
try:
|
||||||
self.extract_gps()
|
self.extract_gps()
|
||||||
self.cluster()
|
self.cluster()
|
||||||
self.filter_points()
|
self.filter_time_group_overlap()
|
||||||
grid_points = self.divide_grids()
|
# self.filter_points()
|
||||||
self.copy_images(grid_points)
|
# grid_points = self.divide_grids()
|
||||||
self.visualize_results()
|
# self.copy_images(grid_points)
|
||||||
self.logger.info("预处理任务完成")
|
# self.visualize_results()
|
||||||
self.command_runner.run_grid_commands(
|
# self.logger.info("预处理任务完成")
|
||||||
grid_points,
|
# self.command_runner.run_grid_commands(
|
||||||
self.config.enable_grid_division,
|
# grid_points,
|
||||||
)
|
# self.config.enable_grid_division,
|
||||||
|
# )
|
||||||
# TODO 拼图
|
# TODO 拼图
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
|
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
|
||||||
@ -218,12 +249,17 @@ class ImagePreprocessor:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 创建配置
|
# 创建配置
|
||||||
config = PreprocessConfig(
|
config = PreprocessConfig(
|
||||||
image_dir=r"F:\error_data\20240930091614\project\images",
|
image_dir=r"F:\error_data\20241016140912\code\images",
|
||||||
output_dir=r"F:\error_data\20240930091614\output",
|
output_dir=r"G:\output",
|
||||||
|
|
||||||
cluster_eps=0.01,
|
cluster_eps=0.01,
|
||||||
cluster_min_samples=5,
|
cluster_min_samples=5,
|
||||||
|
|
||||||
|
# 添加时间组重叠过滤参数
|
||||||
|
time_group_overlap_threshold=0.7,
|
||||||
|
time_group_interval=timedelta(minutes=5),
|
||||||
|
enable_time_group_filter=True,
|
||||||
|
|
||||||
filter_distance_threshold=0.001,
|
filter_distance_threshold=0.001,
|
||||||
filter_min_neighbors=6,
|
filter_min_neighbors=6,
|
||||||
|
|
||||||
@ -238,8 +274,8 @@ if __name__ == "__main__":
|
|||||||
enable_grid_division=True,
|
enable_grid_division=True,
|
||||||
enable_visualization=True,
|
enable_visualization=True,
|
||||||
enable_copy_images=True,
|
enable_copy_images=True,
|
||||||
|
|
||||||
mode="sadf模式",
|
mode="快拼模式",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建处理器并执行
|
# 创建处理器并执行
|
||||||
|
Loading…
Reference in New Issue
Block a user