Implement GPS clustering functionality and refactor odm_preprocess.py

- Added a new GPSCluster class for clustering GPS points using DBSCAN.
- Introduced a clustering method in ImagePreprocessor to filter GPS points based on clustering results.
- Refactored odm_preprocess.py for improved organization and readability, including reordering imports and enhancing logging messages.
- Updated main execution block to include clustering in the preprocessing workflow (a minimal usage sketch of GPSCluster follows below).
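For reference, a minimal usage sketch of the new class, not part of the commit itself. It assumes a pandas DataFrame with 'file', 'lon' and 'lat' columns (as produced by GPSExtractor) and uses toy coordinates plus a smaller min_samples so a cluster actually forms; the pipeline itself uses eps=0.01, min_samples=5.

import pandas as pd
from preprocess.cluster import GPSCluster

# Toy GPS points; in the pipeline these come from GPSExtractor.extract_all_gps()
points = pd.DataFrame(
    {"file": ["a.jpg", "b.jpg", "c.jpg"],
     "lon": [116.400, 116.401, 120.000],
     "lat": [39.900, 39.901, 30.000]}
)

# min_samples=2 only so this tiny example produces a cluster
clusterer = GPSCluster(points, eps=0.01, min_samples=2)
main_points = clusterer.get_main_cluster()   # rows labelled cluster == 1
stats = clusterer.get_cluster_stats()        # e.g. {'total_points': 3, 'main_cluster_points': 2, 'noise_points': 1}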
欧阳植斌 2024-12-19 20:48:53 +08:00
parent e7cab3c120
commit 5cbc07ea53
2 changed files with 157 additions and 49 deletions

odm_preprocess.py

@@ -1,23 +1,25 @@
import os
import shutil
from dataclasses import dataclass
from typing import Dict
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from preprocess.cluster import GPSCluster
from preprocess.command_runner import CommandRunner
from preprocess.gps_extractor import GPSExtractor
from preprocess.time_filter import TimeFilter
from preprocess.gps_filter import GPSFilter
from preprocess.grid_divider import GridDivider
from preprocess.logger import setup_logger
from preprocess.command_runner import CommandRunner
import os
import pandas as pd
import shutil
import matplotlib.pyplot as plt
from typing import List, Dict, Optional
from dataclasses import dataclass
from tqdm import tqdm
import subprocess
from concurrent.futures import ThreadPoolExecutor
from preprocess.time_filter import TimeFilter
@dataclass
class PreprocessConfig:
"""预处理配置类"""
image_dir: str
output_dir: str
filter_grid_size: float = 0.001
@@ -55,6 +57,21 @@ class ImagePreprocessor:
self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
return self.gps_points
def cluster(self) -> pd.DataFrame:
"""Cluster the GPS points with DBSCAN and keep only the largest cluster"""
self.logger.info("Starting clustering")
# Create the clusterer and run the clustering
clusterer = GPSCluster(self.gps_points, eps=0.01, min_samples=5)
# Keep only the points of the main cluster
self.gps_points = clusterer.get_main_cluster()
# Collect and log the clustering statistics
stats = clusterer.get_cluster_stats()
self.logger.info(
f"Clustering finished: main cluster contains {stats['main_cluster_points']} points, "
f"{stats['noise_points']} noise points"
)
return self.gps_points
# TODO: the dense point filtering algorithm needs improvement
def filter_points(self) -> pd.DataFrame:
"""Filter GPS points"""
@@ -65,20 +82,22 @@ class ImagePreprocessor:
filter = GPSFilter(self.config.output_dir)
self.logger.info(
f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})")
f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})"
)
self.gps_points = filter.filter_isolated_points(
self.gps_points,
self.config.filter_distance_threshold,
self.config.filter_min_neighbors
self.config.filter_min_neighbors,
)
self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")
self.logger.info(
f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})")
f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})"
)
self.gps_points = filter.filter_dense_points(
self.gps_points,
grid_size=self.config.filter_grid_size,
distance_threshold=self.config.filter_dense_distance_threshold
distance_threshold=self.config.filter_dense_distance_threshold,
)
self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
return self.gps_points
@@ -91,7 +110,8 @@ class ImagePreprocessor:
self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
grid_divider = GridDivider(overlap=self.config.grid_overlap)
grids = grid_divider.divide_grids(
self.gps_points, grid_size=self.config.grid_size)
self.gps_points, grid_size=self.config.grid_size
)
grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
return grid_points
@@ -106,16 +126,16 @@ class ImagePreprocessor:
for grid_idx, points in grid_points.items():
if self.config.enable_grid_division:
output_dir = os.path.join(
self.config.output_dir, f'grid_{grid_idx + 1}', 'project', 'images')
self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
)
else:
output_dir = os.path.join(
self.config.output_dir, 'project', 'images')
output_dir = os.path.join(self.config.output_dir, "project", "images")
os.makedirs(output_dir, exist_ok=True)
for point in tqdm(points, desc=f"Copying images for grid {grid_idx + 1}"):
src = os.path.join(self.config.image_dir, point['file'])
dst = os.path.join(output_dir, point['file'])
src = os.path.join(self.config.image_dir, point["file"])
dst = os.path.join(output_dir, point["file"])
shutil.copy(src, dst)
self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")
@@ -129,27 +149,34 @@ class ImagePreprocessor:
original_points_df = extractor.extract_all_gps()
# Read the list of filtered-out images
with open(os.path.join(self.config.output_dir, 'del_imgs.txt'), "r", encoding="utf-8") as file:
with open(
os.path.join(self.config.output_dir, "del_imgs.txt"), "r", encoding="utf-8"
) as file:
filtered_files = [line.strip() for line in file if line.strip()]
# Create a new figure
plt.figure(figsize=(20, 16))
# Plot all original points
plt.scatter(original_points_df['lon'],
original_points_df['lat'],
color='blue',
plt.scatter(
original_points_df["lon"],
original_points_df["lat"],
color="blue",
label="Original Points",
alpha=0.6)
alpha=0.6,
)
# Plot the filtered-out points
filtered_points_df = original_points_df[original_points_df['file'].isin(
filtered_files)]
plt.scatter(filtered_points_df['lon'],
filtered_points_df['lat'],
filtered_points_df = original_points_df[
original_points_df["file"].isin(filtered_files)
]
plt.scatter(
filtered_points_df["lon"],
filtered_points_df["lat"],
color="red",
label="Filtered Points",
alpha=0.6)
alpha=0.6,
)
# Set figure properties
plt.title("GPS Coordinates of Images", fontsize=14)
@@ -159,7 +186,7 @@ class ImagePreprocessor:
plt.legend()
# Save the figure
plt.savefig(os.path.join(self.config.output_dir, 'filter_GPS.png'))
plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png"))
plt.close()
self.logger.info("预处理结果图已保存")
@@ -167,26 +194,27 @@ class ImagePreprocessor:
"""执行完整的预处理流程"""
try:
self.extract_gps()
self.time_filter()
self.filter_points()
grid_points = self.divide_grids()
self.copy_images(grid_points)
self.visualize_results()
self.logger.info("预处理任务完成")
self.command_runner.run_grid_commands(
grid_points,
self.config.enable_grid_division
)
self.cluster()
# self.time_filter()
# self.filter_points()
# grid_points = self.divide_grids()
# self.copy_images(grid_points)
# self.visualize_results()
# self.logger.info("预处理任务完成")
# self.command_runner.run_grid_commands(
# grid_points,
# self.config.enable_grid_division
# )
except Exception as e:
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
raise
if __name__ == '__main__':
if __name__ == "__main__":
# Create the configuration
config = PreprocessConfig(
image_dir=r'C:\datasets\1815\images',
output_dir=r'C:\datasets\1815\output',
image_dir=r"../code/images",
output_dir=r"../code/output",
filter_grid_size=0.001,
filter_dense_distance_threshold=10,
filter_distance_threshold=0.001,
@@ -195,7 +223,7 @@ if __name__ == '__main__':
enable_filter=True,
enable_grid_division=True,
enable_visualization=True,
enable_copy_images=True
enable_copy_images=True,
)
# Create the processor and run it

preprocess/cluster.py · new file · +80 lines

@@ -0,0 +1,80 @@
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler
class GPSCluster:
def __init__(self, gps_points, eps=0.01, min_samples=5):
"""
Initialize the GPS clusterer
Args:
gps_points: DataFrame containing 'lon' and 'lat' columns
eps: neighborhood radius parameter for DBSCAN
min_samples: minimum number of samples parameter for DBSCAN
"""
self.eps = eps
self.min_samples = min_samples
self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
self.scaler = StandardScaler()
self.gps_points = gps_points
self.clustered_points = self.fit()
def fit(self):
"""
Cluster the GPS points and keep only the largest cluster
Uses:
self.gps_points: DataFrame containing 'lat' and 'lon' columns
Returns:
DataFrame with a 'cluster' label column, where the largest cluster is labeled 1 and all other points -1
"""
# Extract the longitude/latitude columns
X = self.gps_points[["lon", "lat"]].values
# # Standardize the data (currently disabled)
# X_scaled = self.scaler.fit_transform(X)
# Run DBSCAN clustering
labels = self.dbscan.fit_predict(X)
# Find the label of the largest cluster (excluding noise points, -1)
unique_labels = [l for l in set(labels) if l != -1]
if unique_labels:  # at least one cluster was found
label_counts = [(l, sum(labels == l)) for l in unique_labels]
largest_label = max(label_counts, key=lambda x: x[1])[0]
# Label the largest cluster 1 and everything else -1
new_labels = (labels == largest_label).astype(int)
new_labels[new_labels == 0] = -1
else:  # no clusters found; everything stays -1
new_labels = labels
# Attach the cluster labels to the original data
result_df = self.gps_points.copy()
result_df["cluster"] = new_labels
return result_df
def get_cluster_stats(self):
"""
Get clustering statistics
Uses:
self.clustered_points: DataFrame with cluster labels
Returns:
dict with the clustering statistics
"""
main_cluster_points = sum(self.clustered_points["cluster"] == 1)
stats = {
"total_points": len(self.clustered_points),
"main_cluster_points": main_cluster_points,
"noise_points": sum(self.clustered_points["cluster"] == -1),
}
return stats
def get_main_cluster(self):
"""Return the rows belonging to the largest cluster (cluster == 1)"""
return self.clustered_points[self.clustered_points["cluster"] == 1]
def get_noise_cluster(self):
"""Return the rows labeled as noise (cluster == -1)"""
return self.clustered_points[self.clustered_points["cluster"] == -1]
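A note on the parameters: eps=0.01 is measured in raw degrees (roughly 1.1 km in the latitude direction), since the points are clustered on unscaled lon/lat values. If a metric radius were preferred, a possible variant, not part of this commit and using a hypothetical helper name, would run DBSCAN with the haversine metric on coordinates in radians:

import numpy as np
from sklearn.cluster import DBSCAN

EARTH_RADIUS_M = 6_371_000.0

def dbscan_labels_metres(gps_points, eps_m=500.0, min_samples=5):
    # Haversine distance expects [lat, lon] in radians; eps becomes a fraction of the Earth's radius
    coords = np.radians(gps_points[["lat", "lon"]].values)
    db = DBSCAN(eps=eps_m / EARTH_RADIUS_M, min_samples=min_samples,
                metric="haversine", algorithm="ball_tree")
    return db.fit_predict(coords)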