2024-12-19 20:48:53 +08:00
|
|
|
|
from sklearn.cluster import DBSCAN
|
|
|
|
|
from sklearn.preprocessing import StandardScaler
|
2024-12-20 20:57:01 +08:00
|
|
|
|
import os
|
2024-12-19 20:48:53 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GPSCluster:
    """Cluster GPS points with DBSCAN and keep only the largest cluster.

    Every point outside the largest cluster (including DBSCAN noise) is
    treated as an outlier and labelled ``-1``; points inside it are
    labelled ``1``.
    """

    def __init__(self, gps_points, output_dir: str, eps=0.01, min_samples=5):
        """Initialise the GPS clusterer.

        Args:
            gps_points: DataFrame with ``'lon'`` and ``'lat'`` columns. A
                ``'file'`` column is also required by
                :meth:`get_cluster_stats` whenever outliers are logged.
            output_dir: Directory in which ``del_imgs.txt`` (the list of
                outlier files) is written.
            eps: DBSCAN neighbourhood radius. Clustering runs on the raw
                lon/lat values, so this is expressed in the same units as
                those columns.
            min_samples: DBSCAN minimum number of samples per core point.
        """
        self.eps = eps
        self.min_samples = min_samples
        self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        # Kept for backward compatibility; standardisation is currently
        # not applied before clustering.
        self.scaler = StandardScaler()
        self.gps_points = gps_points
        self.log_file = os.path.join(output_dir, 'del_imgs.txt')

    def fit(self):
        """Cluster ``self.gps_points`` and keep only the largest cluster.

        Returns:
            A copy of ``self.gps_points`` with an added ``'cluster'``
            column: ``1`` for members of the largest DBSCAN cluster and
            ``-1`` for every other point. If DBSCAN finds no cluster at
            all, every point is labelled ``-1``. An empty input yields an
            empty result with the ``'cluster'`` column present.
        """
        # Extract the coordinates as an (n_points, 2) array of raw values.
        coords = self.gps_points[["lon", "lat"]].values

        # DBSCAN rejects arrays with zero samples; short-circuit instead
        # of letting the estimator raise.
        if len(coords) == 0:
            result_df = self.gps_points.copy()
            result_df["cluster"] = -1
            return result_df

        labels = self.dbscan.fit_predict(coords)

        # Real cluster ids, excluding the noise label -1. Sorting makes
        # the tie-break deterministic: the smallest id wins.
        cluster_ids = sorted(
            {label for label in labels.tolist() if label != -1}
        )

        if cluster_ids:
            largest_label = max(
                cluster_ids, key=lambda cid: (labels == cid).sum()
            )
            # Map True -> 1 and False -> -1 in a single vectorised step.
            new_labels = (labels == largest_label).astype(int) * 2 - 1
        else:
            # No cluster found: DBSCAN already labelled everything -1.
            new_labels = labels

        result_df = self.gps_points.copy()
        result_df["cluster"] = new_labels
        return result_df

    def get_cluster_stats(self, clustered_points):
        """Compute cluster statistics and log the outlier files.

        As a side effect, the ``'file'`` value of every outlier row is
        appended to ``self.log_file`` (one per line), followed by a blank
        separator line. The log directory is created if it is missing.

        Args:
            clustered_points: DataFrame with a ``'cluster'`` column, as
                returned by :meth:`fit`.

        Returns:
            dict with the plain-int entries ``'total_points'``,
            ``'main_cluster_points'`` and ``'noise_points'``.
        """
        cluster_col = clustered_points["cluster"]
        stats = {
            "total_points": len(clustered_points),
            "main_cluster_points": int((cluster_col == 1).sum()),
            "noise_points": int((cluster_col == -1).sum()),
        }

        noise_cluster = self.get_noise_cluster(clustered_points)

        # Make sure the target directory exists before appending.
        log_dir = os.path.dirname(self.log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        with open(self.log_file, 'a', encoding='utf-8') as f:
            if not noise_cluster.empty:
                f.writelines(f"{name}\n" for name in noise_cluster["file"])
            f.write('\n')

        return stats

    def get_main_cluster(self, clustered_points):
        """Return the rows labelled as the main cluster (``cluster == 1``)."""
        return clustered_points[clustered_points["cluster"] == 1]

    def get_noise_cluster(self, clustered_points):
        """Return the rows labelled as outliers (``cluster == -1``)."""
        return clustered_points[clustered_points["cluster"] == -1]
|