# ODM_pro/preprocess/cluster.py

import os

from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

class GPSCluster:
    def __init__(self, gps_points, output_dir: str, eps=0.01, min_samples=5):
        """
        Initialize the GPS clusterer.

        Args:
            gps_points: DataFrame containing 'lat' and 'lon' columns.
            output_dir: directory in which the log of removed images is written.
            eps: DBSCAN neighborhood radius parameter.
            min_samples: DBSCAN minimum-samples parameter.
        """
        self.eps = eps
        self.min_samples = min_samples
        self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        self.scaler = StandardScaler()
        self.gps_points = gps_points
        self.log_file = os.path.join(output_dir, 'del_imgs.txt')
    def fit(self):
        """
        Cluster the GPS points and keep only the largest cluster.

        Returns:
            DataFrame with a 'cluster' column in which points of the largest
            cluster are labelled 1 and all other points are labelled -1.
        """
        # Extract longitude/latitude values
        X = self.gps_points[["lon", "lat"]].values
        # # Standardize the data
        # X_scaled = self.scaler.fit_transform(X)
        # Run DBSCAN clustering
        labels = self.dbscan.fit_predict(X)
        # Find the label of the largest cluster (excluding noise, labelled -1)
        unique_labels = [l for l in set(labels) if l != -1]
        if unique_labels:  # at least one cluster was found
            label_counts = [(l, sum(labels == l)) for l in unique_labels]
            largest_label = max(label_counts, key=lambda x: x[1])[0]
            # Mark the largest cluster as 1 and everything else as -1
            new_labels = (labels == largest_label).astype(int)
            new_labels[new_labels == 0] = -1
        else:  # no clusters found: every point is noise (-1)
            new_labels = labels
        # Attach the cluster labels to a copy of the original data
        result_df = self.gps_points.copy()
        result_df["cluster"] = new_labels
        return result_df
    def get_cluster_stats(self, clustered_points):
        """
        Compute clustering statistics and log the discarded (noise) images.

        Args:
            clustered_points: DataFrame with a 'cluster' label column.

        Returns:
            Dict with the clustering statistics.
        """
        main_cluster_points = sum(clustered_points["cluster"] == 1)
        stats = {
            "total_points": len(clustered_points),
            "main_cluster_points": main_cluster_points,
            "noise_points": sum(clustered_points["cluster"] == -1),
        }
        # Append the file names of the noise points to the deletion log
        noise_cluster = self.get_noise_cluster(clustered_points)
        with open(self.log_file, 'a', encoding='utf-8') as f:
            for _, row in noise_cluster.iterrows():
                f.write(row['file'] + '\n')
            f.write('\n')
        return stats
    def get_main_cluster(self, clustered_points):
        return clustered_points[clustered_points["cluster"] == 1]

    def get_noise_cluster(self, clustered_points):
        return clustered_points[clustered_points["cluster"] == -1]
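

if __name__ == "__main__":
    # Minimal usage sketch -- not part of the original module. It assumes the
    # input DataFrame carries the 'file', 'lon' and 'lat' columns used above,
    # and that pandas is installed alongside scikit-learn.
    import pandas as pd

    demo_points = pd.DataFrame(
        {
            "file": [f"img_{i}.jpg" for i in range(8)],
            "lon": [116.4000, 116.4010, 116.4020, 116.4005, 116.4011, 116.4002, 118.90, 118.91],
            "lat": [39.9000, 39.9010, 39.9005, 39.9002, 39.9011, 39.9008, 41.20, 41.21],
        }
    )
    # With eps=0.01 and min_samples=5 the six nearby points form the main
    # cluster; the two distant points are treated as noise and logged.
    clusterer = GPSCluster(demo_points, output_dir=".", eps=0.01, min_samples=5)
    clustered = clusterer.fit()
    print(clusterer.get_cluster_stats(clustered))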