UAV/filter/cluster_filter.py

78 lines
2.4 KiB
Python
Raw Normal View History

2024-12-23 11:31:20 +08:00
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
import os
2025-01-06 17:18:03 +08:00
import logging
2024-12-23 11:31:20 +08:00
class GPSCluster:
    """Keep only the densest group of GPS points using DBSCAN.

    Points are clustered on their raw ``lon``/``lat`` values; every point in
    the largest DBSCAN cluster is labelled ``1`` and every other point
    (noise and smaller clusters) is labelled ``-1``.
    """

    def __init__(self, gps_points, eps=0.01, min_samples=3):
        """Initialise the GPS clusterer.

        Args:
            gps_points: DataFrame containing ``lon`` and ``lat`` columns.
            eps: DBSCAN neighbourhood radius. Coordinates are not
                standardised before clustering, so this is expressed in the
                same units as ``lon``/``lat`` (presumably decimal degrees —
                confirm against the caller).
            min_samples: DBSCAN minimum number of samples in a neighbourhood
                for a point to be treated as a core point.
        """
        self.eps = eps
        self.min_samples = min_samples
        self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        # Not used by fit(): standardisation is currently disabled. Kept so
        # that external code reading this attribute keeps working.
        self.scaler = StandardScaler()
        self.gps_points = gps_points
        self.logger = logging.getLogger('UAV_Preprocess.GPSCluster')

    def fit(self):
        """Cluster the GPS points and keep only the largest cluster.

        Returns:
            A copy of ``self.gps_points`` with an added ``cluster`` column:
            ``1`` for points in the largest DBSCAN cluster and ``-1`` for all
            other points. If DBSCAN finds no cluster at all, every point is
            ``-1``. If several clusters tie for largest, the one with the
            lowest DBSCAN label wins. An empty input yields an empty copy
            with an (empty) ``cluster`` column instead of raising.
        """
        import numpy as np

        self.logger.info("开始聚类")
        result_df = self.gps_points.copy()

        if len(result_df) == 0:
            # DBSCAN raises on zero-sample input; there is nothing to cluster.
            result_df["cluster"] = np.empty(0, dtype=int)
            return result_df

        # Raw lon/lat are clustered directly (no standardisation), so eps is
        # interpreted in coordinate units.
        coords = result_df[["lon", "lat"]].values
        labels = self.dbscan.fit_predict(coords)

        # Size of each real cluster in one pass; DBSCAN labels noise as -1.
        # np.unique returns sorted ids, so argmax picks the lowest label on ties.
        cluster_ids, counts = np.unique(labels[labels != -1], return_counts=True)
        if cluster_ids.size:
            largest_label = cluster_ids[np.argmax(counts)]
            new_labels = np.where(labels == largest_label, 1, -1)
        else:
            # No cluster found: DBSCAN has already labelled every point -1.
            new_labels = labels

        result_df["cluster"] = new_labels
        return result_df

    def get_cluster_stats(self, clustered_points):
        """Log the cluster sizes and return the points of the main cluster.

        Despite its name, this returns a DataFrame, not a statistics dict:
        the rows of ``clustered_points`` whose ``cluster`` label is ``1``.

        Args:
            clustered_points: DataFrame with a ``cluster`` column, as
                produced by ``fit``.

        Returns:
            The subset of ``clustered_points`` with ``cluster == 1``.
        """
        main_cluster = clustered_points[clustered_points["cluster"] == 1]
        noise_cluster = clustered_points[clustered_points["cluster"] == -1]

        self.logger.info(f"聚类完成:主要类别包含 {len(main_cluster)} 个点,"
                         f"噪声点 {len(noise_cluster)}")

        return main_cluster