Implement GPS clustering functionality and refactor odm_preprocess.py
- Added a new GPSCluster class for clustering GPS points using DBSCAN. - Introduced a clustering method in ImagePreprocessor to filter GPS points based on clustering results. - Refactored odm_preprocess.py for improved organization and readability, including reordering imports and enhancing logging messages. - Updated main execution block to include clustering in the preprocessing workflow.
This commit is contained in:
parent
e7cab3c120
commit
5cbc07ea53
@ -1,23 +1,25 @@
|
||||
import os
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from preprocess.cluster import GPSCluster
|
||||
from preprocess.command_runner import CommandRunner
|
||||
from preprocess.gps_extractor import GPSExtractor
|
||||
from preprocess.time_filter import TimeFilter
|
||||
from preprocess.gps_filter import GPSFilter
|
||||
from preprocess.grid_divider import GridDivider
|
||||
from preprocess.logger import setup_logger
|
||||
from preprocess.command_runner import CommandRunner
|
||||
import os
|
||||
import pandas as pd
|
||||
import shutil
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import List, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from tqdm import tqdm
|
||||
import subprocess
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from preprocess.time_filter import TimeFilter
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessConfig:
|
||||
"""预处理配置类"""
|
||||
|
||||
image_dir: str
|
||||
output_dir: str
|
||||
filter_grid_size: float = 0.001
|
||||
@ -55,6 +57,21 @@ class ImagePreprocessor:
|
||||
self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||
return self.gps_points
|
||||
|
||||
def cluster(self) -> pd.DataFrame:
|
||||
"""使用DBSCAN对GPS点进行聚类,只保留最大的类"""
|
||||
self.logger.info("开始聚类")
|
||||
# 创建聚类器并执行聚类
|
||||
clusterer = GPSCluster(self.gps_points, eps=0.01, min_samples=5)
|
||||
# 获取主要类别的点
|
||||
self.gps_points = clusterer.get_main_cluster()
|
||||
# 获取统计信息并记录
|
||||
stats = clusterer.get_cluster_stats()
|
||||
self.logger.info(
|
||||
f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
|
||||
f"噪声点 {stats['noise_points']} 个"
|
||||
)
|
||||
return self.gps_points
|
||||
|
||||
# TODO 过滤密集点算法需要改进
|
||||
def filter_points(self) -> pd.DataFrame:
|
||||
"""过滤GPS点"""
|
||||
@ -65,20 +82,22 @@ class ImagePreprocessor:
|
||||
filter = GPSFilter(self.config.output_dir)
|
||||
|
||||
self.logger.info(
|
||||
f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})")
|
||||
f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})"
|
||||
)
|
||||
self.gps_points = filter.filter_isolated_points(
|
||||
self.gps_points,
|
||||
self.config.filter_distance_threshold,
|
||||
self.config.filter_min_neighbors
|
||||
self.config.filter_min_neighbors,
|
||||
)
|
||||
self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||
|
||||
self.logger.info(
|
||||
f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})")
|
||||
f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})"
|
||||
)
|
||||
self.gps_points = filter.filter_dense_points(
|
||||
self.gps_points,
|
||||
grid_size=self.config.filter_grid_size,
|
||||
distance_threshold=self.config.filter_dense_distance_threshold
|
||||
distance_threshold=self.config.filter_dense_distance_threshold,
|
||||
)
|
||||
self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||
return self.gps_points
|
||||
@ -91,7 +110,8 @@ class ImagePreprocessor:
|
||||
self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
|
||||
grid_divider = GridDivider(overlap=self.config.grid_overlap)
|
||||
grids = grid_divider.divide_grids(
|
||||
self.gps_points, grid_size=self.config.grid_size)
|
||||
self.gps_points, grid_size=self.config.grid_size
|
||||
)
|
||||
grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
|
||||
self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
|
||||
return grid_points
|
||||
@ -106,16 +126,16 @@ class ImagePreprocessor:
|
||||
for grid_idx, points in grid_points.items():
|
||||
if self.config.enable_grid_division:
|
||||
output_dir = os.path.join(
|
||||
self.config.output_dir, f'grid_{grid_idx + 1}', 'project', 'images')
|
||||
self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
|
||||
)
|
||||
else:
|
||||
output_dir = os.path.join(
|
||||
self.config.output_dir, 'project', 'images')
|
||||
output_dir = os.path.join(self.config.output_dir, "project", "images")
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"):
|
||||
src = os.path.join(self.config.image_dir, point['file'])
|
||||
dst = os.path.join(output_dir, point['file'])
|
||||
src = os.path.join(self.config.image_dir, point["file"])
|
||||
dst = os.path.join(output_dir, point["file"])
|
||||
shutil.copy(src, dst)
|
||||
self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")
|
||||
|
||||
@ -129,27 +149,34 @@ class ImagePreprocessor:
|
||||
original_points_df = extractor.extract_all_gps()
|
||||
|
||||
# 读取被过滤的图片列表
|
||||
with open(os.path.join(self.config.output_dir, 'del_imgs.txt'), "r", encoding="utf-8") as file:
|
||||
with open(
|
||||
os.path.join(self.config.output_dir, "del_imgs.txt"), "r", encoding="utf-8"
|
||||
) as file:
|
||||
filtered_files = [line.strip() for line in file if line.strip()]
|
||||
|
||||
# 创建一个新的图形
|
||||
plt.figure(figsize=(20, 16))
|
||||
|
||||
# 绘制所有原始点
|
||||
plt.scatter(original_points_df['lon'],
|
||||
original_points_df['lat'],
|
||||
color='blue',
|
||||
label="Original Points",
|
||||
alpha=0.6)
|
||||
plt.scatter(
|
||||
original_points_df["lon"],
|
||||
original_points_df["lat"],
|
||||
color="blue",
|
||||
label="Original Points",
|
||||
alpha=0.6,
|
||||
)
|
||||
|
||||
# 绘制被过滤的点
|
||||
filtered_points_df = original_points_df[original_points_df['file'].isin(
|
||||
filtered_files)]
|
||||
plt.scatter(filtered_points_df['lon'],
|
||||
filtered_points_df['lat'],
|
||||
color="red",
|
||||
label="Filtered Points",
|
||||
alpha=0.6)
|
||||
filtered_points_df = original_points_df[
|
||||
original_points_df["file"].isin(filtered_files)
|
||||
]
|
||||
plt.scatter(
|
||||
filtered_points_df["lon"],
|
||||
filtered_points_df["lat"],
|
||||
color="red",
|
||||
label="Filtered Points",
|
||||
alpha=0.6,
|
||||
)
|
||||
|
||||
# 设置图形属性
|
||||
plt.title("GPS Coordinates of Images", fontsize=14)
|
||||
@ -159,7 +186,7 @@ class ImagePreprocessor:
|
||||
plt.legend()
|
||||
|
||||
# 保存图形
|
||||
plt.savefig(os.path.join(self.config.output_dir, 'filter_GPS.png'))
|
||||
plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png"))
|
||||
plt.close()
|
||||
self.logger.info("预处理结果图已保存")
|
||||
|
||||
@ -167,26 +194,27 @@ class ImagePreprocessor:
|
||||
"""执行完整的预处理流程"""
|
||||
try:
|
||||
self.extract_gps()
|
||||
self.time_filter()
|
||||
self.filter_points()
|
||||
grid_points = self.divide_grids()
|
||||
self.copy_images(grid_points)
|
||||
self.visualize_results()
|
||||
self.logger.info("预处理任务完成")
|
||||
self.command_runner.run_grid_commands(
|
||||
grid_points,
|
||||
self.config.enable_grid_division
|
||||
)
|
||||
self.cluster()
|
||||
# self.time_filter()
|
||||
# self.filter_points()
|
||||
# grid_points = self.divide_grids()
|
||||
# self.copy_images(grid_points)
|
||||
# self.visualize_results()
|
||||
# self.logger.info("预处理任务完成")
|
||||
# self.command_runner.run_grid_commands(
|
||||
# grid_points,
|
||||
# self.config.enable_grid_division
|
||||
# )
|
||||
except Exception as e:
|
||||
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# 创建配置
|
||||
config = PreprocessConfig(
|
||||
image_dir=r'C:\datasets\1815\images',
|
||||
output_dir=r'C:\datasets\1815\output',
|
||||
image_dir=r"../code/images",
|
||||
output_dir=r"../code/output",
|
||||
filter_grid_size=0.001,
|
||||
filter_dense_distance_threshold=10,
|
||||
filter_distance_threshold=0.001,
|
||||
@ -195,7 +223,7 @@ if __name__ == '__main__':
|
||||
enable_filter=True,
|
||||
enable_grid_division=True,
|
||||
enable_visualization=True,
|
||||
enable_copy_images=True
|
||||
enable_copy_images=True,
|
||||
)
|
||||
|
||||
# 创建处理器并执行
|
||||
|
80
preprocess/cluster.py
Normal file
80
preprocess/cluster.py
Normal file
@ -0,0 +1,80 @@
|
||||
from sklearn.cluster import DBSCAN
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
|
||||
class GPSCluster:
|
||||
def __init__(self, gps_points, eps=0.01, min_samples=5):
|
||||
"""
|
||||
初始化GPS聚类器
|
||||
|
||||
参数:
|
||||
eps: DBSCAN的邻域半径参数
|
||||
min_samples: DBSCAN的最小样本数参数
|
||||
"""
|
||||
self.eps = eps
|
||||
self.min_samples = min_samples
|
||||
self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
|
||||
self.scaler = StandardScaler()
|
||||
self.gps_points = gps_points
|
||||
self.clustered_points = self.fit()
|
||||
|
||||
def fit(self):
|
||||
"""
|
||||
对GPS点进行聚类,只保留最大的类
|
||||
|
||||
参数:
|
||||
gps_points: 包含'lat'和'lon'列的DataFrame
|
||||
|
||||
返回:
|
||||
带有聚类标签的DataFrame,其中最大类标记为1,其他点标记为-1
|
||||
"""
|
||||
# 提取经纬度数据
|
||||
X = self.gps_points[["lon", "lat"]].values
|
||||
|
||||
# # 数据标准化
|
||||
# X_scaled = self.scaler.fit_transform(X)
|
||||
|
||||
# 执行DBSCAN聚类
|
||||
labels = self.dbscan.fit_predict(X)
|
||||
|
||||
# 找出最大类的标签(排除噪声点-1)
|
||||
unique_labels = [l for l in set(labels) if l != -1]
|
||||
if unique_labels: # 如果有聚类
|
||||
label_counts = [(l, sum(labels == l)) for l in unique_labels]
|
||||
largest_label = max(label_counts, key=lambda x: x[1])[0]
|
||||
|
||||
# 将最大类标记为1,其他都标记为-1
|
||||
new_labels = (labels == largest_label).astype(int)
|
||||
new_labels[new_labels == 0] = -1
|
||||
else: # 如果没有聚类,全部标记为-1
|
||||
new_labels = labels
|
||||
|
||||
# 将聚类结果添加到原始数据中
|
||||
result_df = self.gps_points.copy()
|
||||
result_df["cluster"] = new_labels
|
||||
|
||||
return result_df
|
||||
|
||||
def get_cluster_stats(self):
|
||||
"""
|
||||
获取聚类统计信息
|
||||
|
||||
参数:
|
||||
clustered_points: 带有聚类标签的DataFrame
|
||||
|
||||
返回:
|
||||
聚类统计信息的字典
|
||||
"""
|
||||
main_cluster_points = sum(self.clustered_points["cluster"] == 1)
|
||||
stats = {
|
||||
"total_points": len(self.clustered_points),
|
||||
"main_cluster_points": main_cluster_points,
|
||||
"noise_points": sum(self.clustered_points["cluster"] == -1),
|
||||
}
|
||||
return stats
|
||||
|
||||
def get_main_cluster(self):
|
||||
return self.clustered_points[self.clustered_points["cluster"] == 1]
|
||||
|
||||
def get_noise_cluster(self):
|
||||
return self.clustered_points[self.clustered_points["cluster"] == -1]
|
Loading…
Reference in New Issue
Block a user