Implement GPS clustering functionality and refactor odm_preprocess.py
- Added a new GPSCluster class for clustering GPS points using DBSCAN.
- Introduced a clustering method in ImagePreprocessor to filter GPS points based on clustering results.
- Refactored odm_preprocess.py for improved organization and readability, including reordering imports and enhancing logging messages.
- Updated main execution block to include clustering in the preprocessing workflow.
This commit is contained in:
parent
e7cab3c120
commit
5cbc07ea53
@ -1,23 +1,25 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from preprocess.cluster import GPSCluster
|
||||||
|
from preprocess.command_runner import CommandRunner
|
||||||
from preprocess.gps_extractor import GPSExtractor
|
from preprocess.gps_extractor import GPSExtractor
|
||||||
from preprocess.time_filter import TimeFilter
|
|
||||||
from preprocess.gps_filter import GPSFilter
|
from preprocess.gps_filter import GPSFilter
|
||||||
from preprocess.grid_divider import GridDivider
|
from preprocess.grid_divider import GridDivider
|
||||||
from preprocess.logger import setup_logger
|
from preprocess.logger import setup_logger
|
||||||
from preprocess.command_runner import CommandRunner
|
from preprocess.time_filter import TimeFilter
|
||||||
import os
|
|
||||||
import pandas as pd
|
|
||||||
import shutil
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from typing import List, Dict, Optional
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from tqdm import tqdm
|
|
||||||
import subprocess
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PreprocessConfig:
|
class PreprocessConfig:
|
||||||
"""预处理配置类"""
|
"""预处理配置类"""
|
||||||
|
|
||||||
image_dir: str
|
image_dir: str
|
||||||
output_dir: str
|
output_dir: str
|
||||||
filter_grid_size: float = 0.001
|
filter_grid_size: float = 0.001
|
||||||
@ -55,6 +57,21 @@ class ImagePreprocessor:
|
|||||||
self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
|
self.logger.info(f"时间过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||||
return self.gps_points
|
return self.gps_points
|
||||||
|
|
||||||
|
def cluster(self, eps: float = 0.01, min_samples: int = 5) -> pd.DataFrame:
    """Cluster the GPS points with DBSCAN and keep only the largest cluster.

    The points currently held in ``self.gps_points`` are handed to
    ``GPSCluster``; afterwards ``self.gps_points`` is replaced by the rows
    that ``GPSCluster.get_main_cluster()`` returns, and a summary line is
    written to the logger.

    Args:
        eps: Neighbourhood radius forwarded to ``GPSCluster``. Defaults to
            the value that used to be hard-coded here.
        min_samples: Minimum neighbour count forwarded to ``GPSCluster``.
            Defaults to the value that used to be hard-coded here.

    Returns:
        The filtered points (also stored in ``self.gps_points``).
    """
    self.logger.info("开始聚类")

    # Clustering happens inside the GPSCluster constructor.
    clusterer = GPSCluster(self.gps_points, eps=eps, min_samples=min_samples)

    # Keep only the rows that belong to the main cluster.
    self.gps_points = clusterer.get_main_cluster()

    # Report how many points were kept and how many were discarded.
    stats = clusterer.get_cluster_stats()
    self.logger.info(
        f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点,"
        f"噪声点 {stats['noise_points']} 个"
    )
    return self.gps_points
|
||||||
|
|
||||||
# TODO 过滤密集点算法需要改进
|
# TODO 过滤密集点算法需要改进
|
||||||
def filter_points(self) -> pd.DataFrame:
|
def filter_points(self) -> pd.DataFrame:
|
||||||
"""过滤GPS点"""
|
"""过滤GPS点"""
|
||||||
@ -65,20 +82,22 @@ class ImagePreprocessor:
|
|||||||
filter = GPSFilter(self.config.output_dir)
|
filter = GPSFilter(self.config.output_dir)
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})")
|
f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, 最小邻居数: {self.config.filter_min_neighbors})"
|
||||||
|
)
|
||||||
self.gps_points = filter.filter_isolated_points(
|
self.gps_points = filter.filter_isolated_points(
|
||||||
self.gps_points,
|
self.gps_points,
|
||||||
self.config.filter_distance_threshold,
|
self.config.filter_distance_threshold,
|
||||||
self.config.filter_min_neighbors
|
self.config.filter_min_neighbors,
|
||||||
)
|
)
|
||||||
self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")
|
self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})")
|
f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, 距离阈值: {self.config.filter_dense_distance_threshold})"
|
||||||
|
)
|
||||||
self.gps_points = filter.filter_dense_points(
|
self.gps_points = filter.filter_dense_points(
|
||||||
self.gps_points,
|
self.gps_points,
|
||||||
grid_size=self.config.filter_grid_size,
|
grid_size=self.config.filter_grid_size,
|
||||||
distance_threshold=self.config.filter_dense_distance_threshold
|
distance_threshold=self.config.filter_dense_distance_threshold,
|
||||||
)
|
)
|
||||||
self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
|
self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点")
|
||||||
return self.gps_points
|
return self.gps_points
|
||||||
@ -91,7 +110,8 @@ class ImagePreprocessor:
|
|||||||
self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
|
self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})")
|
||||||
grid_divider = GridDivider(overlap=self.config.grid_overlap)
|
grid_divider = GridDivider(overlap=self.config.grid_overlap)
|
||||||
grids = grid_divider.divide_grids(
|
grids = grid_divider.divide_grids(
|
||||||
self.gps_points, grid_size=self.config.grid_size)
|
self.gps_points, grid_size=self.config.grid_size
|
||||||
|
)
|
||||||
grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
|
grid_points = grid_divider.assign_to_grids(self.gps_points, grids)
|
||||||
self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
|
self.logger.info(f"成功划分为 {len(grid_points)} 个网格")
|
||||||
return grid_points
|
return grid_points
|
||||||
@ -106,16 +126,16 @@ class ImagePreprocessor:
|
|||||||
for grid_idx, points in grid_points.items():
|
for grid_idx, points in grid_points.items():
|
||||||
if self.config.enable_grid_division:
|
if self.config.enable_grid_division:
|
||||||
output_dir = os.path.join(
|
output_dir = os.path.join(
|
||||||
self.config.output_dir, f'grid_{grid_idx + 1}', 'project', 'images')
|
self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
output_dir = os.path.join(
|
output_dir = os.path.join(self.config.output_dir, "project", "images")
|
||||||
self.config.output_dir, 'project', 'images')
|
|
||||||
|
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"):
|
for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"):
|
||||||
src = os.path.join(self.config.image_dir, point['file'])
|
src = os.path.join(self.config.image_dir, point["file"])
|
||||||
dst = os.path.join(output_dir, point['file'])
|
dst = os.path.join(output_dir, point["file"])
|
||||||
shutil.copy(src, dst)
|
shutil.copy(src, dst)
|
||||||
self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")
|
self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像")
|
||||||
|
|
||||||
@ -129,27 +149,34 @@ class ImagePreprocessor:
|
|||||||
original_points_df = extractor.extract_all_gps()
|
original_points_df = extractor.extract_all_gps()
|
||||||
|
|
||||||
# 读取被过滤的图片列表
|
# 读取被过滤的图片列表
|
||||||
with open(os.path.join(self.config.output_dir, 'del_imgs.txt'), "r", encoding="utf-8") as file:
|
with open(
|
||||||
|
os.path.join(self.config.output_dir, "del_imgs.txt"), "r", encoding="utf-8"
|
||||||
|
) as file:
|
||||||
filtered_files = [line.strip() for line in file if line.strip()]
|
filtered_files = [line.strip() for line in file if line.strip()]
|
||||||
|
|
||||||
# 创建一个新的图形
|
# 创建一个新的图形
|
||||||
plt.figure(figsize=(20, 16))
|
plt.figure(figsize=(20, 16))
|
||||||
|
|
||||||
# 绘制所有原始点
|
# 绘制所有原始点
|
||||||
plt.scatter(original_points_df['lon'],
|
plt.scatter(
|
||||||
original_points_df['lat'],
|
original_points_df["lon"],
|
||||||
color='blue',
|
original_points_df["lat"],
|
||||||
|
color="blue",
|
||||||
label="Original Points",
|
label="Original Points",
|
||||||
alpha=0.6)
|
alpha=0.6,
|
||||||
|
)
|
||||||
|
|
||||||
# 绘制被过滤的点
|
# 绘制被过滤的点
|
||||||
filtered_points_df = original_points_df[original_points_df['file'].isin(
|
filtered_points_df = original_points_df[
|
||||||
filtered_files)]
|
original_points_df["file"].isin(filtered_files)
|
||||||
plt.scatter(filtered_points_df['lon'],
|
]
|
||||||
filtered_points_df['lat'],
|
plt.scatter(
|
||||||
|
filtered_points_df["lon"],
|
||||||
|
filtered_points_df["lat"],
|
||||||
color="red",
|
color="red",
|
||||||
label="Filtered Points",
|
label="Filtered Points",
|
||||||
alpha=0.6)
|
alpha=0.6,
|
||||||
|
)
|
||||||
|
|
||||||
# 设置图形属性
|
# 设置图形属性
|
||||||
plt.title("GPS Coordinates of Images", fontsize=14)
|
plt.title("GPS Coordinates of Images", fontsize=14)
|
||||||
@ -159,7 +186,7 @@ class ImagePreprocessor:
|
|||||||
plt.legend()
|
plt.legend()
|
||||||
|
|
||||||
# 保存图形
|
# 保存图形
|
||||||
plt.savefig(os.path.join(self.config.output_dir, 'filter_GPS.png'))
|
plt.savefig(os.path.join(self.config.output_dir, "filter_GPS.png"))
|
||||||
plt.close()
|
plt.close()
|
||||||
self.logger.info("预处理结果图已保存")
|
self.logger.info("预处理结果图已保存")
|
||||||
|
|
||||||
@ -167,26 +194,27 @@ class ImagePreprocessor:
|
|||||||
"""执行完整的预处理流程"""
|
"""执行完整的预处理流程"""
|
||||||
try:
|
try:
|
||||||
self.extract_gps()
|
self.extract_gps()
|
||||||
self.time_filter()
|
self.cluster()
|
||||||
self.filter_points()
|
# self.time_filter()
|
||||||
grid_points = self.divide_grids()
|
# self.filter_points()
|
||||||
self.copy_images(grid_points)
|
# grid_points = self.divide_grids()
|
||||||
self.visualize_results()
|
# self.copy_images(grid_points)
|
||||||
self.logger.info("预处理任务完成")
|
# self.visualize_results()
|
||||||
self.command_runner.run_grid_commands(
|
# self.logger.info("预处理任务完成")
|
||||||
grid_points,
|
# self.command_runner.run_grid_commands(
|
||||||
self.config.enable_grid_division
|
# grid_points,
|
||||||
)
|
# self.config.enable_grid_division
|
||||||
|
# )
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
|
self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
# 创建配置
|
# 创建配置
|
||||||
config = PreprocessConfig(
|
config = PreprocessConfig(
|
||||||
image_dir=r'C:\datasets\1815\images',
|
image_dir=r"../code/images",
|
||||||
output_dir=r'C:\datasets\1815\output',
|
output_dir=r"../code/output",
|
||||||
filter_grid_size=0.001,
|
filter_grid_size=0.001,
|
||||||
filter_dense_distance_threshold=10,
|
filter_dense_distance_threshold=10,
|
||||||
filter_distance_threshold=0.001,
|
filter_distance_threshold=0.001,
|
||||||
@ -195,7 +223,7 @@ if __name__ == '__main__':
|
|||||||
enable_filter=True,
|
enable_filter=True,
|
||||||
enable_grid_division=True,
|
enable_grid_division=True,
|
||||||
enable_visualization=True,
|
enable_visualization=True,
|
||||||
enable_copy_images=True
|
enable_copy_images=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建处理器并执行
|
# 创建处理器并执行
|
||||||
|
80
preprocess/cluster.py
Normal file
80
preprocess/cluster.py
Normal file
@ -0,0 +1,80 @@
|
|||||||
|
from sklearn.cluster import DBSCAN
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
|
||||||
|
|
||||||
|
class GPSCluster:
    """Cluster GPS points with DBSCAN and single out the largest cluster.

    Clustering runs once, in the constructor. The labelled result is stored
    in ``clustered_points``: a copy of the input frame with an extra
    ``cluster`` column that is ``1`` for rows in the largest cluster and
    ``-1`` for every other row.
    """

    def __init__(self, gps_points, eps=0.01, min_samples=5):
        """Store the configuration and immediately cluster ``gps_points``.

        Args:
            gps_points: DataFrame with ``lon`` and ``lat`` columns.
            eps: DBSCAN neighbourhood radius. It is applied to the raw
                ``lon``/``lat`` values, because scaling is disabled in ``fit``.
            min_samples: DBSCAN minimum number of samples for a core point.
        """
        self.eps = eps
        self.min_samples = min_samples
        self.dbscan = DBSCAN(eps=eps, min_samples=min_samples)
        # Created but not used by this class: the scaling step is disabled.
        # Kept so the attribute remains available to existing callers.
        self.scaler = StandardScaler()
        self.gps_points = gps_points
        self.clustered_points = self.fit()

    def fit(self):
        """Label every point as main cluster (1) or not (-1).

        Returns:
            A copy of ``self.gps_points`` with an added ``cluster`` column.
            The largest DBSCAN cluster is marked ``1``; every other row
            (DBSCAN noise and members of smaller clusters) is marked ``-1``.
            An empty input yields an empty frame that still has the
            ``cluster`` column.

        Raises:
            KeyError: if the ``lon`` or ``lat`` column is missing.
        """
        from collections import Counter

        # Select the coordinates first so a missing column is still reported
        # even when the frame has no rows.
        X = self.gps_points[["lon", "lat"]].values
        result_df = self.gps_points.copy()

        # DBSCAN refuses an empty sample matrix; there is nothing to cluster.
        if len(X) == 0:
            result_df["cluster"] = -1
            return result_df

        labels = self.dbscan.fit_predict(X)

        # Count cluster sizes in one pass, ignoring DBSCAN's noise label (-1).
        counts = Counter(labels[labels != -1].tolist())
        if counts:
            # Iterating the labels in sorted order makes ties resolve to the
            # smallest label, deterministically.
            largest_label = max(sorted(counts), key=counts.__getitem__)
            new_labels = (labels == largest_label).astype(int)
            new_labels[new_labels == 0] = -1
        else:
            # No cluster found: every label is already -1.
            new_labels = labels

        result_df["cluster"] = new_labels
        return result_df

    def get_cluster_stats(self):
        """Summarise the clustering result.

        Returns:
            A dict with three integer entries:
            ``total_points`` is the number of rows that were clustered,
            ``main_cluster_points`` is the number of rows labelled ``1``, and
            ``noise_points`` is the number of rows labelled ``-1``. The last
            figure includes members of smaller clusters, not only DBSCAN noise.
        """
        cluster_column = self.clustered_points["cluster"]
        return {
            "total_points": len(self.clustered_points),
            "main_cluster_points": int((cluster_column == 1).sum()),
            "noise_points": int((cluster_column == -1).sum()),
        }

    def get_main_cluster(self):
        """Return the rows labelled ``1`` (the largest cluster)."""
        return self.clustered_points[self.clustered_points["cluster"] == 1]

    def get_noise_cluster(self):
        """Return the rows labelled ``-1`` (everything outside the largest cluster)."""
        return self.clustered_points[self.clustered_points["cluster"] == -1]
|
Loading…
Reference in New Issue
Block a user