代码重构,添加面积过滤

This commit is contained in:
龙澳 2024-12-22 20:10:06 +08:00
parent 989648fcae
commit 3d7ccd815a
11 changed files with 279 additions and 30 deletions

View File

@ -0,0 +1,219 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import matplotlib.pyplot as plt
from datetime import timedelta
import logging
import numpy as np
from preprocess.gps_extractor import GPSExtractor
from preprocess.logger import setup_logger
from shapely.geometry import box
import pandas as pd
import shutil
class TimeGroupOverlapFilter:
"""基于时间组重叠度的图像过滤器"""
def __init__(self, image_dir: str, output_dir: str, overlap_threshold: float = 0.7):
"""
初始化过滤器
Args:
image_dir: 图像目录
output_dir: 输出目录
overlap_threshold: 重叠阈值默认0.7
"""
self.image_dir = image_dir
self.output_dir = output_dir
self.overlap_threshold = overlap_threshold
self.logger = logging.getLogger('UAV_Preprocess.TimeGroupFilter')
def _group_by_time(self, points_df, time_threshold=timedelta(minutes=5)):
"""按时间间隔对点进行分组"""
if 'date' not in points_df.columns:
self.logger.error("数据中缺少date列")
return []
# 将date为空的行单独作为一组
null_date_group = points_df[points_df['date'].isna()]
valid_date_points = points_df[points_df['date'].notna()]
if not null_date_group.empty:
self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组")
if valid_date_points.empty:
self.logger.warning("没有有效的时间戳数据")
return [null_date_group] if not null_date_group.empty else []
# 按时间排序
valid_date_points = valid_date_points.sort_values('date')
# 计算时间差
time_diffs = valid_date_points['date'].diff()
# 找到时间差超过阈值的位置
time_groups = []
current_group_start = 0
for idx, time_diff in enumerate(time_diffs):
if time_diff and time_diff > time_threshold:
# 添加当前组
current_group = valid_date_points.iloc[current_group_start:idx]
time_groups.append(current_group)
current_group_start = idx
# 添加最后一组
last_group = valid_date_points.iloc[current_group_start:]
if not last_group.empty:
time_groups.append(last_group)
# 如果有空时间戳的点,将其作为最后一组
if not null_date_group.empty:
time_groups.append(null_date_group)
return time_groups
def _get_group_bbox(self, group_df):
"""获取组内点的边界框"""
min_lon = group_df['lon'].min()
max_lon = group_df['lon'].max()
min_lat = group_df['lat'].min()
max_lat = group_df['lat'].max()
return box(min_lon, min_lat, max_lon, max_lat)
def _calculate_overlap(self, box1, box2):
"""计算两个边界框的重叠率"""
if box1.intersects(box2):
intersection_area = box1.intersection(box2).area
smaller_area = min(box1.area, box2.area)
return intersection_area / smaller_area
return 0
def filter_overlapping_groups(self, time_threshold=timedelta(minutes=5)):
"""过滤重叠的时间组"""
# 提取GPS数据
extractor = GPSExtractor(self.image_dir)
gps_points = extractor.extract_all_gps()
# 按时间分组
time_groups = self._group_by_time(gps_points, time_threshold)
# 计算每个组的边界框
group_boxes = []
for idx, group in enumerate(time_groups):
if not group['date'].isna().any(): # 只处理有时间戳的组
bbox = self._get_group_bbox(group)
group_boxes.append((idx, group, bbox))
# 找出需要删除的组
groups_to_delete = set()
for i in range(len(group_boxes)):
if i in groups_to_delete:
continue
idx1, group1, box1 = group_boxes[i]
area1 = box1.area
for j in range(i + 1, len(group_boxes)):
if j in groups_to_delete:
continue
idx2, group2, box2 = group_boxes[j]
area2 = box2.area
overlap_ratio = self._calculate_overlap(box1, box2)
if overlap_ratio > self.overlap_threshold:
# 删除面积较小的组
if area1 < area2:
group_to_delete = idx1
smaller_area = area1
larger_area = area2
else:
group_to_delete = idx2
smaller_area = area2
larger_area = area1
groups_to_delete.add(group_to_delete)
self.logger.info(
f"时间组 {group_to_delete + 1} 与时间组 "
f"{idx2 + 1 if group_to_delete == idx1 else idx1 + 1} "
f"重叠率为 {overlap_ratio:.2f}"
f"面积比为 {smaller_area/larger_area:.2f}"
f"将删除较小面积的组 {group_to_delete + 1}"
)
# 创建删除日志文件
log_file = os.path.join(self.output_dir, 'deleted_images.txt')
# 删除重复组的图像
deleted_files = []
for group_idx in groups_to_delete:
group_files = time_groups[group_idx]['file'].tolist()
deleted_files.extend(group_files)
# 写入删除日志
with open(log_file, 'w', encoding='utf-8') as f:
for file in deleted_files:
f.write(f"{file}\n")
self.logger.info(f"共删除 {len(groups_to_delete)} 个重复时间组,"
f"{len(deleted_files)} 张图像")
# 可视化结果
self._visualize_results(time_groups, groups_to_delete)
return deleted_files
def _visualize_results(self, time_groups, groups_to_delete):
"""可视化过滤结果"""
plt.figure(figsize=(15, 10))
# 生成不同的颜色
colors = plt.cm.rainbow(np.linspace(0, 1, len(time_groups)))
# 绘制所有组的边界框
for idx, (group, color) in enumerate(zip(time_groups, colors)):
if not group['date'].isna().any(): # 只处理有时间戳的组
bbox = self._get_group_bbox(group)
x, y = bbox.exterior.xy
if idx in groups_to_delete:
# 被删除的组用虚线表示
plt.plot(x, y, '--', color=color, alpha=0.6,
label=f'Deleted Group {idx + 1}')
else:
# 保留的组用实线表示
plt.plot(x, y, '-', color=color, alpha=0.6,
label=f'Group {idx + 1}')
# 绘制该组的GPS点
plt.scatter(group['lon'], group['lat'], color=color,
s=30, alpha=0.6)
plt.title("Time Groups and Their Bounding Boxes", fontsize=14)
plt.xlabel("Longitude", fontsize=12)
plt.ylabel("Latitude", fontsize=12)
plt.grid(True)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
plt.tight_layout()
# 保存图片
plt.savefig(os.path.join(self.output_dir, 'time_groups_overlap.png'),
dpi=300, bbox_inches='tight')
plt.close()
if __name__ == '__main__':
# 设置路径
DATASET = r'F:\error_data\20241108134711\3D'
output_dir = r'E:\studio2\ODM_pro\test'
os.makedirs(output_dir, exist_ok=True)
# 设置日志
setup_logger(os.path.dirname(output_dir))
# 创建过滤器并执行过滤
filter = TimeGroupOverlapFilter(DATASET, output_dir, overlap_threshold=0.7)
deleted_files = filter.filter_overlapping_groups(time_threshold=timedelta(minutes=5))

View File

@ -1,3 +0,0 @@
2024-12-22 18:58:05 - UAV_Preprocess.GPSExtractor - INFO - 开始从目录提取GPS坐标和拍摄日期: F:\error_data\20241104140457\code\images
2024-12-22 18:58:22 - UAV_Preprocess.GPSExtractor - INFO - GPS坐标和拍摄日期提取完成 - 总图片数: 2708, 成功提取: 2708, 失败: 0
2024-12-22 18:58:22 - UAV_Preprocess.GPSVisualizer - INFO - 已生成包含 14 个时间组的组合可视化图形

View File

@ -1,3 +0,0 @@
2024-12-22 18:59:26 - UAV_Preprocess.GPSExtractor - INFO - 开始从目录提取GPS坐标和拍摄日期: F:\error_data\20241108134711\3D
2024-12-22 19:00:09 - UAV_Preprocess.GPSExtractor - INFO - GPS坐标和拍摄日期提取完成 - 总图片数: 6615, 成功提取: 6615, 失败: 0
2024-12-22 19:00:10 - UAV_Preprocess.GPSVisualizer - INFO - 已生成包含 8 个时间组的组合可视化图形

View File

@ -8,12 +8,13 @@ import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
from tqdm import tqdm from tqdm import tqdm
from preprocess.cluster import GPSCluster from filter.cluster_filter import GPSCluster
from preprocess.command_runner import CommandRunner from utils.command_runner import CommandRunner
from preprocess.gps_extractor import GPSExtractor from utils.gps_extractor import GPSExtractor
from preprocess.gps_filter import GPSFilter from filter.gps_filter import GPSFilter
from preprocess.grid_divider import GridDivider from utils.grid_divider import GridDivider
from preprocess.logger import setup_logger from utils.logger import setup_logger
from filter.time_group_overlap_filter import TimeGroupOverlapFilter
@dataclass @dataclass
@ -25,12 +26,16 @@ class PreprocessConfig:
# 聚类过滤参数 # 聚类过滤参数
cluster_eps: float = 0.01 cluster_eps: float = 0.01
cluster_min_samples: int = 5 cluster_min_samples: int = 5
# 时间组重叠过滤参数
time_group_overlap_threshold: float = 0.7
time_group_interval: timedelta = timedelta(minutes=5)
enable_time_group_filter: bool = True
# 孤立点过滤参数 # 孤立点过滤参数
filter_distance_threshold: float = 0.001 # 经纬度距离 filter_distance_threshold: float = 0.001 # 经纬度距离
filter_min_neighbors: int = 6 filter_min_neighbors: int = 6
# 密集点过滤参数 # 密集点过滤参数
filter_grid_size: float = 0.001 filter_grid_size: float = 0.001
filter_dense_distance_threshold: float = 10 # 普通距离,单位:米 filter_dense_distance_threshold: float = 10 # 普通距离,单位:米
filter_time_threshold: timedelta = timedelta(minutes=5) filter_time_threshold: timedelta = timedelta(minutes=5)
# 网格划分参数 # 网格划分参数
grid_overlap: float = 0.05 grid_overlap: float = 0.05
@ -42,12 +47,14 @@ class PreprocessConfig:
enable_copy_images: bool = True enable_copy_images: bool = True
mode: str = "快拼模式" mode: str = "快拼模式"
class ImagePreprocessor: class ImagePreprocessor:
def __init__(self, config: PreprocessConfig): def __init__(self, config: PreprocessConfig):
self.config = config self.config = config
self.logger = setup_logger(config.output_dir) self.logger = setup_logger(config.output_dir)
self.gps_points = [] self.gps_points = []
self.command_runner = CommandRunner(config.output_dir, mode=config.mode) self.command_runner = CommandRunner(
config.output_dir, mode=config.mode)
def extract_gps(self) -> pd.DataFrame: def extract_gps(self) -> pd.DataFrame:
"""提取GPS数据""" """提取GPS数据"""
@ -74,6 +81,29 @@ class ImagePreprocessor:
f"噪声点 {stats['noise_points']}" f"噪声点 {stats['noise_points']}"
) )
def filter_time_group_overlap(self) -> pd.DataFrame:
"""过滤重叠的时间组"""
if not self.config.enable_time_group_filter:
return self.gps_points
self.logger.info("开始过滤重叠时间组")
filter = TimeGroupOverlapFilter(
self.config.image_dir,
self.config.output_dir,
overlap_threshold=self.config.time_group_overlap_threshold
)
deleted_files = filter.filter_overlapping_groups(
time_threshold=self.config.time_group_interval
)
# 更新GPS点数据移除被删除的图像
self.gps_points = self.gps_points[~self.gps_points['file'].isin(
deleted_files)]
self.logger.info(f"重叠时间组过滤后剩余 {len(self.gps_points)} 个GPS点")
return self.gps_points
# TODO 过滤算法还需要更新 # TODO 过滤算法还需要更新
def filter_points(self) -> pd.DataFrame: def filter_points(self) -> pd.DataFrame:
"""过滤GPS点""" """过滤GPS点"""
@ -200,15 +230,16 @@ class ImagePreprocessor:
try: try:
self.extract_gps() self.extract_gps()
self.cluster() self.cluster()
self.filter_points() self.filter_time_group_overlap()
grid_points = self.divide_grids() # self.filter_points()
self.copy_images(grid_points) # grid_points = self.divide_grids()
self.visualize_results() # self.copy_images(grid_points)
self.logger.info("预处理任务完成") # self.visualize_results()
self.command_runner.run_grid_commands( # self.logger.info("预处理任务完成")
grid_points, # self.command_runner.run_grid_commands(
self.config.enable_grid_division, # grid_points,
) # self.config.enable_grid_division,
# )
# TODO 拼图 # TODO 拼图
except Exception as e: except Exception as e:
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
@ -218,12 +249,17 @@ class ImagePreprocessor:
if __name__ == "__main__": if __name__ == "__main__":
# 创建配置 # 创建配置
config = PreprocessConfig( config = PreprocessConfig(
image_dir=r"F:\error_data\20240930091614\project\images", image_dir=r"F:\error_data\20241016140912\code\images",
output_dir=r"F:\error_data\20240930091614\output", output_dir=r"G:\output",
cluster_eps=0.01, cluster_eps=0.01,
cluster_min_samples=5, cluster_min_samples=5,
# 添加时间组重叠过滤参数
time_group_overlap_threshold=0.7,
time_group_interval=timedelta(minutes=5),
enable_time_group_filter=True,
filter_distance_threshold=0.001, filter_distance_threshold=0.001,
filter_min_neighbors=6, filter_min_neighbors=6,
@ -239,7 +275,7 @@ if __name__ == "__main__":
enable_visualization=True, enable_visualization=True,
enable_copy_images=True, enable_copy_images=True,
mode="sadf模式", mode="快拼模式",
) )
# 创建处理器并执行 # 创建处理器并执行