format code

This commit is contained in:
龙澳 2025-01-06 17:18:03 +08:00
parent 8540565f0e
commit 9935d8f789
3 changed files with 23 additions and 49 deletions

View File

@ -1,10 +1,11 @@
from sklearn.cluster import DBSCAN from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler from sklearn.preprocessing import StandardScaler
import os import os
import logging
class GPSCluster: class GPSCluster:
def __init__(self, gps_points, output_dir: str, eps=0.01, min_samples=5): def __init__(self, gps_points, eps=0.01, min_samples=3):
""" """
初始化GPS聚类器 初始化GPS聚类器
@ -17,6 +18,7 @@ class GPSCluster:
self.dbscan = DBSCAN(eps=eps, min_samples=min_samples) self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
self.scaler = StandardScaler() self.scaler = StandardScaler()
self.gps_points = gps_points self.gps_points = gps_points
self.logger = logging.getLogger('UAV_Preprocess.GPSCluster')
def fit(self): def fit(self):
""" """
@ -28,6 +30,7 @@ class GPSCluster:
返回: 返回:
带有聚类标签的DataFrame其中最大类标记为1其他点标记为-1 带有聚类标签的DataFrame其中最大类标记为1其他点标记为-1
""" """
self.logger.info("开始聚类")
# 提取经纬度数据 # 提取经纬度数据
X = self.gps_points[["lon", "lat"]].values X = self.gps_points[["lon", "lat"]].values
@ -65,18 +68,10 @@ class GPSCluster:
返回: 返回:
聚类统计信息的字典 聚类统计信息的字典
""" """
main_cluster_points = sum(clustered_points["cluster"] == 1) main_cluster = clustered_points[clustered_points["cluster"] == 1]
stats = { noise_cluster = clustered_points[clustered_points["cluster"] == -1]
"total_points": len(clustered_points),
"main_cluster_points": main_cluster_points,
"noise_points": sum(clustered_points["cluster"] == -1),
}
noise_cluster = self.get_noise_cluster(clustered_points) self.logger.info(f"聚类完成:主要类别包含 {len(main_cluster)} 个点,"
return stats f"噪声点 {len(noise_cluster)}")
def get_main_cluster(self, clustered_points): return main_cluster
return clustered_points[clustered_points["cluster"] == 1]
def get_noise_cluster(self, clustered_points):
return clustered_points[clustered_points["cluster"] == -1]

View File

@ -97,6 +97,7 @@ class TimeGroupOverlapFilter:
def filter_overlapping_groups(self, gps_points, time_threshold=timedelta(minutes=5)): def filter_overlapping_groups(self, gps_points, time_threshold=timedelta(minutes=5)):
"""过滤重叠的时间组""" """过滤重叠的时间组"""
# 按时间分组 # 按时间分组
self.logger.info("开始过滤重叠时间组")
time_groups = self._group_by_time(gps_points, time_threshold) time_groups = self._group_by_time(gps_points, time_threshold)
# 计算每个组的边界框 # 计算每个组的边界框
@ -156,7 +157,10 @@ class TimeGroupOverlapFilter:
# 可视化结果 # 可视化结果
self._visualize_results(time_groups, groups_to_delete) self._visualize_results(time_groups, groups_to_delete)
return deleted_files retained_points = gps_points[~gps_points['file'].isin(
deleted_files)]
self.logger.info(f"重叠时间组过滤后剩余 {len(retained_points)} 个GPS点")
return retained_points
def _visualize_results(self, time_groups, groups_to_delete): def _visualize_results(self, time_groups, groups_to_delete):
"""可视化过滤结果""" """可视化过滤结果"""

View File

@ -3,7 +3,7 @@ import shutil
from datetime import timedelta from datetime import timedelta
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Tuple from typing import Dict, Tuple
import psutil # 需要添加到 requirements.txt import psutil
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
@ -143,52 +143,35 @@ class ImagePreprocessor:
def cluster(self): def cluster(self):
"""使用DBSCAN对GPS点进行聚类只保留最大的类""" """使用DBSCAN对GPS点进行聚类只保留最大的类"""
self.logger.info("开始聚类")
previous_points = self.gps_points.copy() previous_points = self.gps_points.copy()
# 创建聚类器并执行聚类
clusterer = GPSCluster( clusterer = GPSCluster(
self.gps_points, output_dir=self.config.output_dir, self.gps_points,
eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples) eps=self.config.cluster_eps,
min_samples=self.config.cluster_min_samples
# 获取主要类别的点
self.clustered_points = clusterer.fit()
self.gps_points = clusterer.get_main_cluster(self.clustered_points)
# 获取统计信息并记录
stats = clusterer.get_cluster_stats(self.clustered_points)
self.logger.info(
f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
f"噪声点 {stats['noise_points']}"
) )
# 可视化聚类结果 self.clustered_points = clusterer.fit()
self.gps_points = clusterer.get_cluster_stats(self.clustered_points)
self.visualizer.visualize_filter_step( self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "1-Clustering") self.gps_points, previous_points, "1-Clustering")
def filter_isolated_points(self): def filter_isolated_points(self):
"""过滤孤立点""" """过滤孤立点"""
self.logger.info("开始过滤孤立点")
filter = GPSFilter(self.config.output_dir) filter = GPSFilter(self.config.output_dir)
previous_points = self.gps_points.copy() previous_points = self.gps_points.copy()
self.logger.info(
f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, "
f"最小邻居数: {self.config.filter_min_neighbors})"
)
self.gps_points = filter.filter_isolated_points( self.gps_points = filter.filter_isolated_points(
self.gps_points, self.gps_points,
self.config.filter_distance_threshold, self.config.filter_distance_threshold,
self.config.filter_min_neighbors, self.config.filter_min_neighbors,
) )
self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")
# 可视化孤立点过滤结果
self.visualizer.visualize_filter_step( self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "2-Isolated Points") self.gps_points, previous_points, "2-Isolated Points")
def filter_time_group_overlap(self): def filter_time_group_overlap(self):
"""过滤重叠的时间组""" """过滤重叠的时间组"""
self.logger.info("开始过滤重叠时间组")
previous_points = self.gps_points.copy() previous_points = self.gps_points.copy()
filter = TimeGroupOverlapFilter( filter = TimeGroupOverlapFilter(
@ -197,17 +180,11 @@ class ImagePreprocessor:
overlap_threshold=self.config.time_group_overlap_threshold overlap_threshold=self.config.time_group_overlap_threshold
) )
deleted_files = filter.filter_overlapping_groups( self.gps_points = filter.filter_overlapping_groups(
self.gps_points, self.gps_points,
time_threshold=self.config.time_group_interval time_threshold=self.config.time_group_interval
) )
# 更新GPS点数据移除被删除的图像
self.gps_points = self.gps_points[~self.gps_points['file'].isin(
deleted_files)]
self.logger.info(f"重叠时间组过滤后剩余 {len(self.gps_points)} 个GPS点")
# 可视化过滤结果
self.visualizer.visualize_filter_step( self.visualizer.visualize_filter_step(
self.gps_points, previous_points, "3-Time Group Overlap") self.gps_points, previous_points, "3-Time Group Overlap")
@ -218,7 +195,6 @@ class ImagePreprocessor:
- grid_points: 网格点数据字典 - grid_points: 网格点数据字典
- translations: 网格平移量字典 - translations: 网格平移量字典
""" """
self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
grid_divider = GridDivider( grid_divider = GridDivider(
overlap=self.config.grid_overlap, overlap=self.config.grid_overlap,
grid_size=self.config.grid_size, grid_size=self.config.grid_size,
@ -227,7 +203,6 @@ class ImagePreprocessor:
grids, translations, grid_points = grid_divider.adjust_grid_size_and_overlap( grids, translations, grid_points = grid_divider.adjust_grid_size_and_overlap(
self.gps_points self.gps_points
) )
self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
grid_divider.visualize_grids(self.gps_points, grids) grid_divider.visualize_grids(self.gps_points, grids)
return grid_points, translations return grid_points, translations