commit f6a5068350dfad20b9eedf0a8dcca21c3cf8481d Author: 龙澳 Date: Mon Dec 30 17:34:21 2024 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..775aa57 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +# 忽略所有__pycache__目录 +**/__pycache__/ +*.pyc +*.pyo +*.pyd + +test/ \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..257cece --- /dev/null +++ b/README.md @@ -0,0 +1,22 @@ +# ODM_Pro +无人机三维重建 + +## Install + +```bash +conda install fiona +pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple +``` + +Centos7安装open3d失败执行 + +```bash +conda install -c conda-forge open3d +``` + +## TODO + +- command_runner中rerun需要更新 +- grid要动态分割大小 +- 任务队列 +- 目前obj转osgb的软件windows没有装上,linux成功了,后续需要做一个docker镜像 \ No newline at end of file diff --git a/filter/cluster_filter.py b/filter/cluster_filter.py new file mode 100644 index 0000000..eadebf1 --- /dev/null +++ b/filter/cluster_filter.py @@ -0,0 +1,82 @@ +from sklearn.cluster import DBSCAN +from sklearn.preprocessing import StandardScaler +import os + + +class GPSCluster: + def __init__(self, gps_points, output_dir: str, eps=0.01, min_samples=5): + """ + 初始化GPS聚类器 + + 参数: + eps: DBSCAN的邻域半径参数 + min_samples: DBSCAN的最小样本数参数 + """ + self.eps = eps + self.min_samples = min_samples + self.dbscan = DBSCAN(eps=eps, min_samples=min_samples) + self.scaler = StandardScaler() + self.gps_points = gps_points + + def fit(self): + """ + 对GPS点进行聚类,只保留最大的类 + + 参数: + gps_points: 包含'lat'和'lon'列的DataFrame + + 返回: + 带有聚类标签的DataFrame,其中最大类标记为1,其他点标记为-1 + """ + # 提取经纬度数据 + X = self.gps_points[["lon", "lat"]].values + + # # 数据标准化 + # X_scaled = self.scaler.fit_transform(X) + + # 执行DBSCAN聚类 + labels = self.dbscan.fit_predict(X) + + # 找出最大类的标签(排除噪声点-1) + unique_labels = [l for l in set(labels) if l != -1] + if unique_labels: # 如果有聚类 + label_counts = [(l, sum(labels == l)) for l in unique_labels] + largest_label = max(label_counts, key=lambda x: x[1])[0] + + # 将最大类标记为1,其他都标记为-1 + new_labels = (labels == largest_label).astype(int) + new_labels[new_labels == 0] = -1 + else: # 如果没有聚类,全部标记为-1 + new_labels = labels + + # 将聚类结果添加到原始数据中 + result_df = self.gps_points.copy() + result_df["cluster"] = new_labels + + return result_df + + def get_cluster_stats(self, clustered_points): + """ + 获取聚类统计信息 + + 参数: + clustered_points: 带有聚类标签的DataFrame + + 返回: + 聚类统计信息的字典 + """ + main_cluster_points = sum(clustered_points["cluster"] == 1) + stats = { + "total_points": len(clustered_points), + "main_cluster_points": main_cluster_points, + "noise_points": sum(clustered_points["cluster"] == -1), + } + + noise_cluster = self.get_noise_cluster(clustered_points) + return stats + + def get_main_cluster(self, clustered_points): + return clustered_points[clustered_points["cluster"] == 1] + + def get_noise_cluster(self, clustered_points): + return clustered_points[clustered_points["cluster"] == -1] diff --git a/filter/gps_filter.py b/filter/gps_filter.py new file mode 100644 index 0000000..95d1e1b --- /dev/null +++ b/filter/gps_filter.py @@ -0,0 +1,248 @@ +import os +import math +from itertools import combinations +import numpy as np +from scipy.spatial import KDTree +import logging +import pandas as pd +from datetime import datetime, timedelta + + +class GPSFilter: + """过滤密集点及孤立点""" + + def __init__(self, output_dir): + self.logger = logging.getLogger('UAV_Preprocess.GPSFilter') + + @staticmethod + def _haversine(lat1, lon1, lat2, lon2): + """计算两点之间的地理距离(单位:米)""" + R = 6371000 # 地球平均半径,单位:米 + phi1, phi2 = 
math.radians(lat1), math.radians(lat2) + delta_phi = math.radians(lat2 - lat1) + delta_lambda = math.radians(lon2 - lon1) + + a = math.sin(delta_phi / 2) ** 2 + math.cos(phi1) * \ + math.cos(phi2) * math.sin(delta_lambda / 2) ** 2 + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + return R * c + + @staticmethod + def _assign_to_grid(lat, lon, grid_size, min_lat, min_lon): + """根据经纬度和网格大小,将点分配到网格""" + grid_x = int((lat - min_lat) // grid_size) + grid_y = int((lon - min_lon) // grid_size) + return grid_x, grid_y + + def _get_distances(self, points_df, grid_size): + """读取图片 GPS 坐标,计算点对之间的距离并排序""" + # 确定经纬度范围 + min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max() + min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max() + self.logger.info( + f"经纬度范围:纬度[{min_lat:.6f}, {max_lat:.6f}],纬度范围[{max_lat-min_lat:.6f}]," + f"经度[{min_lon:.6f}, {max_lon:.6f}],经度范围[{max_lon-min_lon:.6f}]") + + # 分配到网格 + grid_map = {} + for _, row in points_df.iterrows(): + grid = self._assign_to_grid( + row['lat'], row['lon'], grid_size, min_lat, min_lon) + if grid not in grid_map: + grid_map[grid] = [] + grid_map[grid].append((row['file'], row['lat'], row['lon'])) + + self.logger.info(f"图像点已分配到 {len(grid_map)} 个网格中") + + # 在每个网格中计算两两距离并排序 + sorted_distances = {} + for grid, images in grid_map.items(): + distances = [] + for (img1, lat1, lon1), (img2, lat2, lon2) in combinations(images, 2): + dist = self._haversine(lat1, lon1, lat2, lon2) + distances.append((img1, img2, dist)) + distances.sort(key=lambda x: x[2]) # 按距离升序排序 + sorted_distances[grid] = distances + self.logger.debug(f"网格 {grid} 中计算了 {len(distances)} 个距离对") + + return sorted_distances + + def _group_by_time(self, points_df: pd.DataFrame, time_threshold: timedelta) -> list: + """根据拍摄时间分组图片 + + 如果相邻两张图片的拍摄时间差超过5分钟,则进行切分 + + Args: + points_df: 包含图片信息的DataFrame,必须包含'file'和'date'列 + time_threshold: 时间间隔阈值,默认5分钟 + + Returns: + list: 每个元素为时间组内的点数据 + """ + if 'date' not in points_df.columns: + self.logger.error("数据中缺少date列") + return [points_df] + + # 将date为空的行单独作为一组 + null_date_group = points_df[points_df['date'].isna()] + valid_date_points = points_df[points_df['date'].notna()] + + if not null_date_group.empty: + self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组") + + if valid_date_points.empty: + self.logger.warning("没有有效的时间戳数据") + return [null_date_group] if not null_date_group.empty else [] + + # 按时间排序 + valid_date_points = valid_date_points.sort_values('date') + self.logger.info( + f"有效时间范围: {valid_date_points['date'].min()} 到 {valid_date_points['date'].max()}") + + # 计算时间差 + time_diffs = valid_date_points['date'].diff() + + # 找到时间差超过阈值的位置 + time_groups = [] + current_group_start = 0 + + for idx, time_diff in enumerate(time_diffs): + if time_diff and time_diff > time_threshold: + # 添加当前组 + current_group = valid_date_points.iloc[current_group_start:idx] + time_groups.append(current_group) + + # 记录断点信息 + break_time = valid_date_points.iloc[idx]['date'] + group_start_time = current_group.iloc[0]['date'] + group_end_time = current_group.iloc[-1]['date'] + + self.logger.info( + f"时间组 {len(time_groups)}: {len(current_group)} 个点, " + f"时间范围 [{group_start_time} - {group_end_time}]" + ) + self.logger.info( + f"在时间 {break_time} 处发现断点,时间差为 {time_diff}") + + current_group_start = idx + + # 添加最后一组 + last_group = valid_date_points.iloc[current_group_start:] + if not last_group.empty: + time_groups.append(last_group) + self.logger.info( + f"时间组 {len(time_groups)}: {len(last_group)} 个点, " + f"时间范围 [{last_group.iloc[0]['date']} - 
{last_group.iloc[-1]['date']}]" + ) + + # 如果有空时间戳的点,将其作为最后一组 + if not null_date_group.empty: + time_groups.append(null_date_group) + self.logger.info(f"添加无时间戳组: {len(null_date_group)} 个点") + + self.logger.info(f"共分为 {len(time_groups)} 个时间组") + return time_groups + + def filter_dense_points(self, points_df, grid_size=0.001, distance_threshold=13, time_threshold=timedelta(minutes=5)): + """ + 过滤密集点,先按时间分组,再在每个时间组内过滤。 + 空时间戳的点不进行过滤。 + + Args: + points_df: 点数据 + grid_size: 网格大小 + distance_threshold: 距离阈值(米) + time_interval: 时间间隔(秒) + """ + self.logger.info(f"开始按时间分组过滤密集点 (网格大小: {grid_size}, " + f"距离阈值: {distance_threshold}米, 分组时间间隔: {time_threshold}秒)") + + # 按时间分组 + time_groups = self._group_by_time(points_df, time_threshold) + + # 存储所有要删除的图片 + all_to_del_imgs = [] + + # 对每个时间组进行密集点过滤 + for group_idx, group_points in enumerate(time_groups): + # 检查是否为空时间戳组(最后一组) + if group_idx == len(time_groups) - 1 and group_points['date'].isna().any(): + self.logger.info(f"跳过无时间戳组 (包含 {len(group_points)} 个点)") + continue + + self.logger.info( + f"处理时间组 {group_idx + 1} (包含 {len(group_points)} 个点)") + + # 计算该组内的点间距离 + sorted_distances = self._get_distances(group_points, grid_size) + group_to_del_imgs = [] + + # 在每个网格中过滤密集点 + for grid, distances in sorted_distances.items(): + grid_del_count = 0 + while distances: + candidate_img1, candidate_img2, dist = distances[0] + if dist < distance_threshold: + distances.pop(0) + + # 获取候选图片的其他最短距离 + candidate_img1_dist = None + candidate_img2_dist = None + for distance in distances: + if candidate_img1 in distance: + candidate_img1_dist = distance[2] + break + for distance in distances: + if candidate_img2 in distance: + candidate_img2_dist = distance[2] + break + + # 选择要删除的点 + if candidate_img1_dist and candidate_img2_dist: + to_del_img = candidate_img1 if candidate_img1_dist < candidate_img2_dist else candidate_img2 + group_to_del_imgs.append(to_del_img) + grid_del_count += 1 + self.logger.debug( + f"时间组 {group_idx + 1} 网格 {grid} 删除密集点: {to_del_img} (距离: {dist:.2f}米)") + distances = [ + d for d in distances if to_del_img not in d] + else: + break + + if grid_del_count > 0: + self.logger.info( + f"时间组 {group_idx + 1} 网格 {grid} 删除了 {grid_del_count} 个密集点") + + all_to_del_imgs.extend(group_to_del_imgs) + self.logger.info( + f"时间组 {group_idx + 1} 共删除 {len(group_to_del_imgs)} 个密集点") + + + # 过滤数据 + filtered_df = points_df[~points_df['file'].isin(all_to_del_imgs)] + self.logger.info( + f"密集点过滤完成,共删除 {len(all_to_del_imgs)} 个点,剩余 {len(filtered_df)} 个点") + + return filtered_df + + def filter_isolated_points(self, points_df, threshold_distance=0.001, min_neighbors=6): + """过滤孤立点""" + self.logger.info( + f"开始过滤孤立点 (距离阈值: {threshold_distance}, 最小邻居数: {min_neighbors})") + + coords = points_df[['lat', 'lon']].values + kdtree = KDTree(coords) + neighbors_count = [len(kdtree.query_ball_point( + coord, threshold_distance)) for coord in coords] + + isolated_points = [] + for i, (_, row) in enumerate(points_df.iterrows()): + if neighbors_count[i] < min_neighbors: + isolated_points.append(row['file']) + self.logger.debug( + f"删除孤立点: {row['file']} (邻居数: {neighbors_count[i]})") + + filtered_df = points_df[~points_df['file'].isin(isolated_points)] + self.logger.info( + f"孤立点过滤完成,共删除 {len(isolated_points)} 个点,剩余 {len(filtered_df)} 个点") + return filtered_df diff --git a/filter/time_group_overlap_filter.py b/filter/time_group_overlap_filter.py new file mode 100644 index 0000000..15ece6f --- /dev/null +++ b/filter/time_group_overlap_filter.py @@ -0,0 +1,212 @@ +import shutil +import pandas as pd 
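As a reference for the neighbour-count rule in `GPSFilter.filter_isolated_points` above, here is a minimal, self-contained sketch; the coordinates and the `0.001` / `6` thresholds are illustrative stand-ins for the defaults, not project data.

```python
# Illustrative sketch: count neighbours within a radius with a KDTree and
# drop points that have fewer than min_neighbors (the isolated-point rule).
import pandas as pd
from scipy.spatial import KDTree

points_df = pd.DataFrame({
    'file': [f'IMG_{i:04d}.jpg' for i in range(8)],
    'lat':  [39.9000 + 0.0001 * i for i in range(7)] + [39.9500],
    'lon':  [116.3000 + 0.0001 * i for i in range(7)] + [116.3500],
})

threshold_distance, min_neighbors = 0.001, 6
coords = points_df[['lat', 'lon']].values
kdtree = KDTree(coords)
# query_ball_point includes the query point itself, as in filter_isolated_points
neighbors_count = [len(kdtree.query_ball_point(c, threshold_distance)) for c in coords]
isolated = [points_df.iloc[i]['file'] for i, n in enumerate(neighbors_count) if n < min_neighbors]
filtered_df = points_df[~points_df['file'].isin(isolated)]
print(isolated)          # only the far-away last point is dropped
print(len(filtered_df))  # 7
```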
+from shapely.geometry import box +from utils.logger import setup_logger +from utils.gps_extractor import GPSExtractor +import numpy as np +import logging +from datetime import timedelta +import matplotlib.pyplot as plt +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class TimeGroupOverlapFilter: + """基于时间组重叠度的图像过滤器""" + + def __init__(self, image_dir: str, output_dir: str, overlap_threshold: float = 0.7): + """ + 初始化过滤器 + + Args: + image_dir: 图像目录 + output_dir: 输出目录 + overlap_threshold: 重叠阈值,默认0.7 + """ + self.image_dir = image_dir + self.output_dir = output_dir + self.overlap_threshold = overlap_threshold + self.logger = logging.getLogger('UAV_Preprocess.TimeGroupFilter') + + def _group_by_time(self, points_df, time_threshold=timedelta(minutes=5)): + """按时间间隔对点进行分组""" + if 'date' not in points_df.columns: + self.logger.error("数据中缺少date列") + return [] + + # 将date为空的行单独作为一组 + null_date_group = points_df[points_df['date'].isna()] + valid_date_points = points_df[points_df['date'].notna()] + + if not null_date_group.empty: + self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组") + + if valid_date_points.empty: + self.logger.warning("没有有效的时间戳数据") + return [null_date_group] if not null_date_group.empty else [] + + # 按时间排序 + valid_date_points = valid_date_points.sort_values('date') + + # 计算时间差 + time_diffs = valid_date_points['date'].diff() + + # 找到时间差超过阈值的位置 + time_groups = [] + current_group_start = 0 + + for idx, time_diff in enumerate(time_diffs): + if time_diff and time_diff > time_threshold: + # 添加当前组 + current_group = valid_date_points.iloc[current_group_start:idx] + time_groups.append(current_group) + current_group_start = idx + + # 添加最后一组 + last_group = valid_date_points.iloc[current_group_start:] + if not last_group.empty: + time_groups.append(last_group) + + # 如果有空时间戳的点,将其作为最后一组 + if not null_date_group.empty: + time_groups.append(null_date_group) + + return time_groups + + def _get_group_bbox(self, group_df): + """获取组内点的边界框""" + min_lon = group_df['lon'].min() + max_lon = group_df['lon'].max() + min_lat = group_df['lat'].min() + max_lat = group_df['lat'].max() + return box(min_lon, min_lat, max_lon, max_lat) + + def _calculate_overlap(self, box1, box2): + """计算两个边界框的重叠率""" + if box1.intersects(box2): + intersection_area = box1.intersection(box2).area + smaller_area = min(box1.area, box2.area) + return intersection_area / smaller_area + return 0 + + def filter_overlapping_groups(self, time_threshold=timedelta(minutes=5)): + """过滤重叠的时间组""" + # 提取GPS数据 + extractor = GPSExtractor(self.image_dir) + gps_points = extractor.extract_all_gps() + + # 按时间分组 + time_groups = self._group_by_time(gps_points, time_threshold) + + # 计算每个组的边界框 + group_boxes = [] + for idx, group in enumerate(time_groups): + if not group['date'].isna().any(): # 只处理有时间戳的组 + bbox = self._get_group_bbox(group) + group_boxes.append((idx, group, bbox)) + + # 找出需要删除的组 + groups_to_delete = set() + for i in range(len(group_boxes)): + if i in groups_to_delete: + continue + + idx1, group1, box1 = group_boxes[i] + area1 = box1.area + + for j in range(i + 1, len(group_boxes)): + if j in groups_to_delete: + continue + + idx2, group2, box2 = group_boxes[j] + area2 = box2.area + + overlap_ratio = self._calculate_overlap(box1, box2) + + if overlap_ratio > self.overlap_threshold: + # 删除面积较小的组 + if area1 < area2: + group_to_delete = idx1 + smaller_area = area1 + larger_area = area2 + else: + group_to_delete = idx2 + smaller_area = area2 + larger_area = area1 + + 
groups_to_delete.add(group_to_delete) + self.logger.info( + f"时间组 {group_to_delete + 1} 与时间组 " + f"{idx2 + 1 if group_to_delete == idx1 else idx1 + 1} " + f"重叠率为 {overlap_ratio:.2f}," + f"面积比为 {smaller_area/larger_area:.2f}," + f"将删除较小面积的组 {group_to_delete + 1}" + ) + + # 删除重复组的图像 + deleted_files = [] + for group_idx in groups_to_delete: + group_files = time_groups[group_idx]['file'].tolist() + deleted_files.extend(group_files) + + self.logger.info(f"共删除 {len(groups_to_delete)} 个重复时间组," + f"{len(deleted_files)} 张图像") + + # 可视化结果 + self._visualize_results(time_groups, groups_to_delete) + + return deleted_files + + def _visualize_results(self, time_groups, groups_to_delete): + """可视化过滤结果""" + plt.figure(figsize=(15, 10)) + + # 生成不同的颜色 + colors = plt.cm.rainbow(np.linspace(0, 1, len(time_groups))) + + # 绘制所有组的边界框 + for idx, (group, color) in enumerate(zip(time_groups, colors)): + if not group['date'].isna().any(): # 只处理有时间戳的组 + bbox = self._get_group_bbox(group) + x, y = bbox.exterior.xy + + if idx in groups_to_delete: + # 被删除的组用虚线表示 + plt.plot(x, y, '--', color=color, alpha=0.6, + label=f'Deleted Group {idx + 1}') + else: + # 保留的组用实线表示 + plt.plot(x, y, '-', color=color, alpha=0.6, + label=f'Group {idx + 1}') + + # 绘制该组的GPS点 + plt.scatter(group['lon'], group['lat'], color=color, + s=30, alpha=0.6) + + plt.title("Time Groups and Their Bounding Boxes", fontsize=14) + plt.xlabel("Longitude", fontsize=12) + plt.ylabel("Latitude", fontsize=12) + plt.grid(True) + plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10) + plt.tight_layout() + + # 保存图片 + plt.savefig(os.path.join(self.output_dir, 'filter_imgs', 'time_groups_overlap_bbox.png'), + dpi=300, bbox_inches='tight') + plt.close() + + +if __name__ == '__main__': + # 设置路径 + DATASET = r'F:\error_data\20241108134711\3D' + output_dir = r'E:\studio2\ODM_pro\test' + os.makedirs(output_dir, exist_ok=True) + + # 设置日志 + setup_logger(os.path.dirname(output_dir)) + + # 创建过滤器并执行过滤 + filter = TimeGroupOverlapFilter(DATASET, output_dir, overlap_threshold=0.7) + deleted_files = filter.filter_overlapping_groups( + time_threshold=timedelta(minutes=5)) diff --git a/odm_preprocess.py b/odm_preprocess.py new file mode 100644 index 0000000..6034e5f --- /dev/null +++ b/odm_preprocess.py @@ -0,0 +1,319 @@ +import os +import shutil +from datetime import timedelta +from dataclasses import dataclass +from typing import Dict + +import matplotlib.pyplot as plt +import pandas as pd +from tqdm import tqdm + +from filter.cluster_filter import GPSCluster +from filter.time_group_overlap_filter import TimeGroupOverlapFilter +from filter.gps_filter import GPSFilter +from utils.odm_monitor import ODMProcessMonitor +from utils.gps_extractor import GPSExtractor +from utils.grid_divider import GridDivider +from utils.logger import setup_logger +from utils.visualizer import FilterVisualizer +from post_pro.merge_tif import MergeTif +from tools.test_docker_run import run_docker_command +from post_pro.merge_obj import MergeObj +from post_pro.merge_ply import MergePly + + +@dataclass +class PreprocessConfig: + """预处理配置类""" + + image_dir: str + output_dir: str + # 聚类过滤参数 + cluster_eps: float = 0.01 + cluster_min_samples: int = 5 + # 时间组重叠过滤参数 + time_group_overlap_threshold: float = 0.7 + time_group_interval: timedelta = timedelta(minutes=5) + # 孤立点过滤参数 + filter_distance_threshold: float = 0.001 # 经纬度距离 + filter_min_neighbors: int = 6 + # 密集点过滤参数 + filter_grid_size: float = 0.001 + filter_dense_distance_threshold: float = 10 # 普通距离,单位:米 + filter_time_threshold: timedelta = 
timedelta(minutes=5) + # 网格划分参数 + grid_overlap: float = 0.05 + grid_size: float = 500 + # 几个pipline过程是否开启 + mode: str = "快拼模式" + + +class ImagePreprocessor: + def __init__(self, config: PreprocessConfig): + self.config = config + + # 清理并重建输出目录 + if os.path.exists(config.output_dir): + self._clean_output_dir() + self._setup_output_dirs() + + # 初始化其他组件 + self.logger = setup_logger(config.output_dir) + self.gps_points = None + self.odm_monitor = ODMProcessMonitor( + config.output_dir, mode=config.mode) + self.visualizer = FilterVisualizer(config.output_dir) + + def _clean_output_dir(self): + """清理输出目录""" + try: + shutil.rmtree(self.config.output_dir) + print(f"已清理输出目录: {self.config.output_dir}") + except Exception as e: + print(f"清理输出目录时发生错误: {str(e)}") + raise + + def _setup_output_dirs(self): + """创建必要的输出目录结构""" + try: + # 创建主输出目录 + os.makedirs(self.config.output_dir) + + # 创建过滤图像保存目录 + os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs')) + + # 创建日志目录 + os.makedirs(os.path.join(self.config.output_dir, 'logs')) + + print(f"已创建输出目录结构: {self.config.output_dir}") + except Exception as e: + print(f"创建输出目录时发生错误: {str(e)}") + raise + + def extract_gps(self) -> pd.DataFrame: + """提取GPS数据""" + self.logger.info("开始提取GPS数据") + extractor = GPSExtractor(self.config.image_dir) + self.gps_points = extractor.extract_all_gps() + self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点") + return self.gps_points + + def cluster(self) -> pd.DataFrame: + """使用DBSCAN对GPS点进行聚类,只保留最大的类""" + self.logger.info("开始聚类") + previous_points = self.gps_points.copy() + + # 创建聚类器并执行聚类 + clusterer = GPSCluster( + self.gps_points, output_dir=self.config.output_dir, + eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples) + + # 获取主要类别的点 + self.clustered_points = clusterer.fit() + self.gps_points = clusterer.get_main_cluster(self.clustered_points) + + # 获取统计信息并记录 + stats = clusterer.get_cluster_stats(self.clustered_points) + self.logger.info( + f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点," + f"噪声点 {stats['noise_points']} 个" + ) + + # 可视化聚类结果 + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "1-Clustering") + + return self.gps_points + + def filter_time_group_overlap(self) -> pd.DataFrame: + """过滤重叠的时间组""" + self.logger.info("开始过滤重叠时间组") + + self.logger.info("开始过滤重叠时间组") + previous_points = self.gps_points.copy() + + filter = TimeGroupOverlapFilter( + self.config.image_dir, + self.config.output_dir, + overlap_threshold=self.config.time_group_overlap_threshold + ) + + deleted_files = filter.filter_overlapping_groups( + time_threshold=self.config.time_group_interval + ) + + # 更新GPS点数据,移除被删除的图像 + self.gps_points = self.gps_points[~self.gps_points['file'].isin( + deleted_files)] + self.logger.info(f"重叠时间组过滤后剩余 {len(self.gps_points)} 个GPS点") + + # 可视化过滤结果 + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "2-Time Group Overlap") + + return self.gps_points + + # TODO 过滤算法还需要更新 + def filter_points(self) -> pd.DataFrame: + """过滤GPS点""" + + self.logger.info("开始过滤GPS点") + filter = GPSFilter(self.config.output_dir) + + # 过滤孤立点 + previous_points = self.gps_points.copy() + self.logger.info( + f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, " + f"最小邻居数: {self.config.filter_min_neighbors})" + ) + self.gps_points = filter.filter_isolated_points( + self.gps_points, + self.config.filter_distance_threshold, + self.config.filter_min_neighbors, + ) + self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点") + + # 可视化孤立点过滤结果 + 
self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "3-Isolated Points") + + # # 过滤密集点 + # previous_points = self.gps_points.copy() + # self.logger.info( + # f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, " + # f"距离阈值: {self.config.filter_dense_distance_threshold})" + # ) + # self.gps_points = filter.filter_dense_points( + # self.gps_points, + # grid_size=self.config.filter_grid_size, + # distance_threshold=self.config.filter_dense_distance_threshold, + # time_threshold=self.config.filter_time_threshold, + # ) + # self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点") + + # # 可视化密集点过滤结果 + # self.visualizer.visualize_filter_step( + # self.gps_points, previous_points, "4-Dense Points") + + return self.gps_points + + def divide_grids(self) -> Dict[int, pd.DataFrame]: + """划分网格""" + self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})") + grid_divider = GridDivider( + overlap=self.config.grid_overlap, + output_dir=self.config.output_dir + ) + grids = grid_divider.divide_grids( + self.gps_points, grid_size=self.config.grid_size + ) + grid_points = grid_divider.assign_to_grids(self.gps_points, grids) + self.logger.info(f"成功划分为 {len(grid_points)} 个网格") + + # 生成image_groups.txt文件 + try: + groups_file = os.path.join(self.config.output_dir, "image_groups.txt") + self.logger.info(f"开始生成分组文件: {groups_file}") + + with open(groups_file, 'w') as f: + for grid_idx, points_lt in grid_points.items(): + # 使用ASCII字母作为组标识(A, B, C...) + group_letter = chr(65 + grid_idx) # 65是ASCII中'A'的编码 + + # 为每个网格中的图像写入分组信息 + for point in points_lt: + f.write(f"{point['file']} {group_letter}\n") + + self.logger.info(f"分组文件生成成功: {groups_file}") + except Exception as e: + self.logger.error(f"生成分组文件时发生错误: {str(e)}", exc_info=True) + raise + + return grid_points + + def copy_images(self, grid_points: Dict[int, pd.DataFrame]): + """复制图像到目标文件夹""" + self.logger.info("开始复制图像文件") + self.logger.info("开始复制图像文件") + + for grid_idx, points in grid_points.items(): + output_dir = os.path.join( + self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images" + ) + + os.makedirs(output_dir, exist_ok=True) + + for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"): + src = os.path.join(self.config.image_dir, point["file"]) + dst = os.path.join(output_dir, point["file"]) + shutil.copy(src, dst) + self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像") + + def merge_tif(self, grid_points: Dict[int, pd.DataFrame]): + """合并所有网格的影像产品""" + self.logger.info("开始合并所有影像产品") + merger = MergeTif(self.config.output_dir) + merger.merge_all_tifs(grid_points) + + def merge_obj(self, grid_points: Dict[int, pd.DataFrame]): + """合并所有网格的OBJ模型""" + self.logger.info("开始合并OBJ模型") + merger = MergeObj(self.config.output_dir) + merger.merge_grid_obj(grid_points) + + def merge_ply(self, grid_points: Dict[int, pd.DataFrame]): + """合并所有网格的PLY点云""" + self.logger.info("开始合并PLY点云") + merger = MergePly(self.config.output_dir) + merger.merge_grid_ply(grid_points) + + def process(self): + """执行完整的预处理流程""" + try: + self.extract_gps() + self.cluster() + # self.filter_time_group_overlap() + self.filter_points() + grid_points = self.divide_grids() + self.copy_images(grid_points) + self.logger.info("预处理任务完成") + + self.odm_monitor.process_all_grids(grid_points) + self.merge_tif(grid_points) + self.merge_obj(grid_points) + self.merge_ply(grid_points) + except Exception as e: + self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) + raise + + +if __name__ == "__main__": + # 创建配置 + config = PreprocessConfig( + 
image_dir=r"E:\datasets\UAV\134\project\images", + output_dir=r"G:\ODM_output\134_test", + + cluster_eps=0.01, + cluster_min_samples=5, + + # 添加时间组重叠过滤参数 + time_group_overlap_threshold=0.7, + time_group_interval=timedelta(minutes=5), + + filter_distance_threshold=0.001, + filter_min_neighbors=6, + + filter_grid_size=0.001, + filter_dense_distance_threshold=10, + filter_time_threshold=timedelta(minutes=5), + + grid_size=300, + grid_overlap=0.1, + + + mode="重建模式", + ) + + # 创建处理器并执行 + processor = ImagePreprocessor(config) + processor.process() diff --git a/odm_preprocess_fast.py b/odm_preprocess_fast.py new file mode 100644 index 0000000..067b706 --- /dev/null +++ b/odm_preprocess_fast.py @@ -0,0 +1,319 @@ +import os +import shutil +from datetime import timedelta +from dataclasses import dataclass +from typing import Dict + +import matplotlib.pyplot as plt +import pandas as pd +from tqdm import tqdm + +from filter.cluster_filter import GPSCluster +from filter.time_group_overlap_filter import TimeGroupOverlapFilter +from filter.gps_filter import GPSFilter +from utils.odm_monitor import ODMProcessMonitor +from utils.gps_extractor import GPSExtractor +from utils.grid_divider import GridDivider +from utils.logger import setup_logger +from utils.visualizer import FilterVisualizer +from post_pro.merge_tif import MergeTif +from tools.test_docker_run import run_docker_command +from post_pro.merge_obj import MergeObj +from post_pro.merge_ply import MergePly + + +@dataclass +class PreprocessConfig: + """预处理配置类""" + + image_dir: str + output_dir: str + # 聚类过滤参数 + cluster_eps: float = 0.01 + cluster_min_samples: int = 5 + # 时间组重叠过滤参数 + time_group_overlap_threshold: float = 0.7 + time_group_interval: timedelta = timedelta(minutes=5) + # 孤立点过滤参数 + filter_distance_threshold: float = 0.001 # 经纬度距离 + filter_min_neighbors: int = 6 + # 密集点过滤参数 + filter_grid_size: float = 0.001 + filter_dense_distance_threshold: float = 10 # 普通距离,单位:米 + filter_time_threshold: timedelta = timedelta(minutes=5) + # 网格划分参数 + grid_overlap: float = 0.05 + grid_size: float = 500 + # 几个pipline过程是否开启 + mode: str = "快拼模式" + + +class ImagePreprocessor: + def __init__(self, config: PreprocessConfig): + self.config = config + + # # 清理并重建输出目录 + # if os.path.exists(config.output_dir): + # self._clean_output_dir() + # self._setup_output_dirs() + + # 初始化其他组件 + self.logger = setup_logger(config.output_dir) + self.gps_points = None + self.odm_monitor = ODMProcessMonitor( + config.output_dir, mode=config.mode) + self.visualizer = FilterVisualizer(config.output_dir) + + def _clean_output_dir(self): + """清理输出目录""" + try: + shutil.rmtree(self.config.output_dir) + print(f"已清理输出目录: {self.config.output_dir}") + except Exception as e: + print(f"清理输出目录时发生错误: {str(e)}") + raise + + def _setup_output_dirs(self): + """创建必要的输出目录结构""" + try: + # 创建主输出目录 + os.makedirs(self.config.output_dir) + + # 创建过滤图像保存目录 + os.makedirs(os.path.join(self.config.output_dir, 'filter_imgs')) + + # 创建日志目录 + os.makedirs(os.path.join(self.config.output_dir, 'logs')) + + print(f"已创建输出目录结构: {self.config.output_dir}") + except Exception as e: + print(f"创建输出目录时发生错误: {str(e)}") + raise + + def extract_gps(self) -> pd.DataFrame: + """提取GPS数据""" + self.logger.info("开始提取GPS数据") + extractor = GPSExtractor(self.config.image_dir) + self.gps_points = extractor.extract_all_gps() + self.logger.info(f"成功提取 {len(self.gps_points)} 个GPS点") + return self.gps_points + + def cluster(self) -> pd.DataFrame: + """使用DBSCAN对GPS点进行聚类,只保留最大的类""" + self.logger.info("开始聚类") + previous_points = 
self.gps_points.copy() + + # 创建聚类器并执行聚类 + clusterer = GPSCluster( + self.gps_points, output_dir=self.config.output_dir, + eps=self.config.cluster_eps, min_samples=self.config.cluster_min_samples) + + # 获取主要类别的点 + self.clustered_points = clusterer.fit() + self.gps_points = clusterer.get_main_cluster(self.clustered_points) + + # 获取统计信息并记录 + stats = clusterer.get_cluster_stats(self.clustered_points) + self.logger.info( + f"聚类完成:主要类别包含 {stats['main_cluster_points']} 个点," + f"噪声点 {stats['noise_points']} 个" + ) + + # 可视化聚类结果 + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "1-Clustering") + + return self.gps_points + + def filter_time_group_overlap(self) -> pd.DataFrame: + """过滤重叠的时间组""" + self.logger.info("开始过滤重叠时间组") + + self.logger.info("开始过滤重叠时间组") + previous_points = self.gps_points.copy() + + filter = TimeGroupOverlapFilter( + self.config.image_dir, + self.config.output_dir, + overlap_threshold=self.config.time_group_overlap_threshold + ) + + deleted_files = filter.filter_overlapping_groups( + time_threshold=self.config.time_group_interval + ) + + # 更新GPS点数据,移除被删除的图像 + self.gps_points = self.gps_points[~self.gps_points['file'].isin( + deleted_files)] + self.logger.info(f"重叠时间组过滤后剩余 {len(self.gps_points)} 个GPS点") + + # 可视化过滤结果 + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "2-Time Group Overlap") + + return self.gps_points + + # TODO 过滤算法还需要更新 + def filter_points(self) -> pd.DataFrame: + """过滤GPS点""" + + self.logger.info("开始过滤GPS点") + filter = GPSFilter(self.config.output_dir) + + # 过滤孤立点 + previous_points = self.gps_points.copy() + self.logger.info( + f"开始过滤孤立点(距离阈值: {self.config.filter_distance_threshold}, " + f"最小邻居数: {self.config.filter_min_neighbors})" + ) + self.gps_points = filter.filter_isolated_points( + self.gps_points, + self.config.filter_distance_threshold, + self.config.filter_min_neighbors, + ) + self.logger.info(f"孤立点过滤后剩余 {len(self.gps_points)} 个GPS点") + + # 可视化孤立点过滤结果 + self.visualizer.visualize_filter_step( + self.gps_points, previous_points, "3-Isolated Points") + + # # 过滤密集点 + # previous_points = self.gps_points.copy() + # self.logger.info( + # f"开始过滤密集点(网格大小: {self.config.filter_grid_size}, " + # f"距离阈值: {self.config.filter_dense_distance_threshold})" + # ) + # self.gps_points = filter.filter_dense_points( + # self.gps_points, + # grid_size=self.config.filter_grid_size, + # distance_threshold=self.config.filter_dense_distance_threshold, + # time_threshold=self.config.filter_time_threshold, + # ) + # self.logger.info(f"密集点过滤后剩余 {len(self.gps_points)} 个GPS点") + + # # 可视化密集点过滤结果 + # self.visualizer.visualize_filter_step( + # self.gps_points, previous_points, "4-Dense Points") + + return self.gps_points + + def divide_grids(self) -> Dict[int, pd.DataFrame]: + """划分网格""" + self.logger.info(f"开始划分网格 (重叠率: {self.config.grid_overlap})") + grid_divider = GridDivider( + overlap=self.config.grid_overlap, + output_dir=self.config.output_dir + ) + grids = grid_divider.divide_grids( + self.gps_points, grid_size=self.config.grid_size + ) + grid_points = grid_divider.assign_to_grids(self.gps_points, grids) + self.logger.info(f"成功划分为 {len(grid_points)} 个网格") + + # 生成image_groups.txt文件 + try: + groups_file = os.path.join(self.config.output_dir, "image_groups.txt") + self.logger.info(f"开始生成分组文件: {groups_file}") + + with open(groups_file, 'w') as f: + for grid_idx, points_lt in grid_points.items(): + # 使用ASCII字母作为组标识(A, B, C...) 
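+                    # e.g. grid_idx 0 -> 'A', 1 -> 'B'; chr(65 + grid_idx) only yields letters for the first 26 grids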
+ group_letter = chr(65 + grid_idx) # 65是ASCII中'A'的编码 + + # 为每个网格中的图像写入分组信息 + for point in points_lt: + f.write(f"{point['file']} {group_letter}\n") + + self.logger.info(f"分组文件生成成功: {groups_file}") + except Exception as e: + self.logger.error(f"生成分组文件时发生错误: {str(e)}", exc_info=True) + raise + + return grid_points + + def copy_images(self, grid_points: Dict[int, pd.DataFrame]): + """复制图像到目标文件夹""" + self.logger.info("开始复制图像文件") + self.logger.info("开始复制图像文件") + + for grid_idx, points in grid_points.items(): + output_dir = os.path.join( + self.config.output_dir, f"grid_{grid_idx + 1}", "project", "images" + ) + + os.makedirs(output_dir, exist_ok=True) + + for point in tqdm(points, desc=f"复制网格 {grid_idx + 1} 的图像"): + src = os.path.join(self.config.image_dir, point["file"]) + dst = os.path.join(output_dir, point["file"]) + shutil.copy(src, dst) + self.logger.info(f"网格 {grid_idx + 1} 包含 {len(points)} 张图像") + + def merge_tif(self, grid_points: Dict[int, pd.DataFrame]): + """合并所有网格的影像产品""" + self.logger.info("开始合并所有影像产品") + merger = MergeTif(self.config.output_dir) + merger.merge_all_tifs(grid_points) + + def merge_obj(self, grid_points: Dict[int, pd.DataFrame]): + """合并所有网格的OBJ模型""" + self.logger.info("开始合并OBJ模型") + merger = MergeObj(self.config.output_dir) + merger.merge_grid_obj(grid_points) + + def merge_ply(self, grid_points: Dict[int, pd.DataFrame]): + """合并所有网格的PLY点云""" + self.logger.info("开始合并PLY点云") + merger = MergePly(self.config.output_dir) + merger.merge_grid_ply(grid_points) + + def process(self): + """执行完整的预处理流程""" + try: + self.extract_gps() + self.cluster() + # self.filter_time_group_overlap() + self.filter_points() + grid_points = self.divide_grids() + # self.copy_images(grid_points) + self.logger.info("预处理任务完成") + + # self.odm_monitor.process_all_grids(grid_points) + # self.merge_tif(grid_points) + self.merge_ply(grid_points) + self.merge_obj(grid_points) + except Exception as e: + self.logger.error(f"处理过程中发生错误: {str(e)}", exc_info=True) + raise + + +if __name__ == "__main__": + # 创建配置 + config = PreprocessConfig( + image_dir=r"E:\datasets\UAV\1009\project\images", + output_dir=r"G:\ODM_output\1009", + + cluster_eps=0.01, + cluster_min_samples=5, + + # 添加时间组重叠过滤参数 + time_group_overlap_threshold=0.7, + time_group_interval=timedelta(minutes=5), + + filter_distance_threshold=0.001, + filter_min_neighbors=6, + + filter_grid_size=0.001, + filter_dense_distance_threshold=10, + filter_time_threshold=timedelta(minutes=5), + + grid_size=300, + grid_overlap=0.1, + + + mode="重建模式", + ) + + # 创建处理器并执行 + processor = ImagePreprocessor(config) + processor.process() diff --git a/post_pro/merge_obj.py b/post_pro/merge_obj.py new file mode 100644 index 0000000..08f38f1 --- /dev/null +++ b/post_pro/merge_obj.py @@ -0,0 +1,142 @@ +import os +import logging +import numpy as np +from typing import Dict +import pandas as pd +import open3d as o3d + + +class MergeObj: + def __init__(self, output_dir: str): + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.MergeObj') + + def merge_two_objs(self, obj1_path: str, obj2_path: str, output_path: str): + """使用Open3D合并两个OBJ文件""" + try: + self.logger.info("开始合并OBJ模型") + self.logger.info(f"输入模型1: {obj1_path}") + self.logger.info(f"输入模型2: {obj2_path}") + self.logger.info(f"输出模型: {output_path}") + + # 检查输入文件是否存在 + if not os.path.exists(obj1_path) or not os.path.exists(obj2_path): + raise FileNotFoundError("输入模型文件不存在") + + # 读取OBJ文件 + mesh1 = o3d.io.read_triangle_mesh(obj1_path) + mesh2 = o3d.io.read_triangle_mesh(obj2_path) + + if 
mesh1.is_empty() or mesh2.is_empty(): + raise ValueError("无法读取OBJ文件或文件为空") + + # # 计算并对齐中心点 + # center1 = mesh1.get_center() + # center2 = mesh2.get_center() + # translation_vector = center2 - center1 + # mesh2.translate(translation_vector) + + # 不对齐,直接合并网格 + combined_mesh = mesh1 + mesh2 + + # 优化合并后的网格 + combined_mesh.remove_duplicated_vertices() + combined_mesh.remove_duplicated_triangles() + combined_mesh.compute_vertex_normals() + + # 保存合并后的模型 + if not o3d.io.write_triangle_mesh(output_path, combined_mesh): + raise RuntimeError("保存合并后的模型失败") + + self.logger.info(f"模型合并成功,已保存至: {output_path}") + + except Exception as e: + self.logger.error(f"合并OBJ模型时发生错误: {str(e)}", exc_info=True) + raise + + def merge_grid_obj(self, grid_points: Dict[int, pd.DataFrame]): + """合并所有网格的OBJ模型""" + self.logger.info("开始合并所有网格的OBJ模型") + + if len(grid_points) < 2: + self.logger.info("只有一个网格,无需合并") + return + + input_obj1, input_obj2 = None, None + merge_count = 0 + + try: + for grid_idx, points in grid_points.items(): + grid_obj = os.path.join( + self.output_dir, + f"grid_{grid_idx + 1}", + "project", + "odm_texturing_25d", + "odm_textured_model_geo.obj" + ) + + if not os.path.exists(grid_obj): + self.logger.warning( + f"网格 {grid_idx + 1} 的OBJ文件不存在: {grid_obj}") + continue + + if input_obj1 is None: + input_obj1 = grid_obj + self.logger.info(f"设置第一个输入OBJ: {input_obj1}") + else: + input_obj2 = grid_obj + output_obj = os.path.join( + self.output_dir, "merged_model.obj") + + self.logger.info( + f"开始合并第 {merge_count + 1} 次:\n" + f"输入1: {input_obj1}\n" + f"输入2: {input_obj2}\n" + f"输出: {output_obj}" + ) + + self.merge_two_objs(input_obj1, input_obj2, output_obj) + merge_count += 1 + + input_obj1 = output_obj + input_obj2 = None + + self.logger.info( + f"OBJ模型合并完成,共执行 {merge_count} 次合并," + f"最终输出文件: {input_obj1}" + ) + + except Exception as e: + self.logger.error(f"OBJ模型合并过程中发生错误: {str(e)}", exc_info=True) + raise + + +if __name__ == "__main__": + import sys + sys.path.append(os.path.dirname( + os.path.dirname(os.path.abspath(__file__)))) + from utils.logger import setup_logger + import pandas as pd + + # 设置输出目录和日志 + output_dir = r"G:\ODM_output\1009" + setup_logger(output_dir) + + # 构造测试用的grid_points字典 + # 假设我们有两个网格,每个网格包含一些GPS点的DataFrame + grid_points = { + 0: pd.DataFrame({ + 'latitude': [39.9, 39.91], + 'longitude': [116.3, 116.31], + 'altitude': [100, 101] + }), + 1: pd.DataFrame({ + 'latitude': [39.92, 39.93], + 'longitude': [116.32, 116.33], + 'altitude': [102, 103] + }) + } + + # 创建MergeObj实例并执行合并 + merge_obj = MergeObj(output_dir) + merge_obj.merge_grid_obj(grid_points) diff --git a/post_pro/merge_ply.py b/post_pro/merge_ply.py new file mode 100644 index 0000000..688bc91 --- /dev/null +++ b/post_pro/merge_ply.py @@ -0,0 +1,188 @@ +import os +import logging +import numpy as np +from typing import Dict, Tuple +import pandas as pd +import open3d as o3d + + +class MergePly: + def __init__(self, output_dir: str): + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.MergePly') + + def read_corners_file(self, grid_idx: int) -> Tuple[float, float]: + """读取角点文件并计算中心点坐标 + 角点文件格式:xmin ymin xmax ymax + """ + corners_file = os.path.join( + self.output_dir, + f"grid_{grid_idx + 1}", + "project", + "odm_orthophoto", + "odm_orthophoto_corners.txt" + ) + + try: + if not os.path.exists(corners_file): + raise FileNotFoundError(f"角点文件不存在: {corners_file}") + + # 读取角点文件 + with open(corners_file, 'r') as f: + line = f.readline().strip() + if not line: + raise ValueError(f"角点文件为空: {corners_file}") + + # 
解析四个角点值:xmin ymin xmax ymax + xmin, ymin, xmax, ymax = map(float, line.split()) + + # 计算中心点坐标 + center_x = (xmin + xmax) / 2 + center_y = (ymin + ymax) / 2 + + self.logger.info( + f"网格 {grid_idx + 1} 边界坐标: \n" + f"xmin={xmin:.2f}, ymin={ymin:.2f}\n" + f"xmax={xmax:.2f}, ymax={ymax:.2f}\n" + f"中心点: x={center_x:.2f}, y={center_y:.2f}" + ) + return center_x, center_y + + except Exception as e: + self.logger.error(f"读取角点文件时发生错误: {str(e)}", exc_info=True) + raise + + def merge_two_plys(self, ply1_path: str, ply2_path: str, output_path: str, + center1: Tuple[float, float], + center2: Tuple[float, float]): + """合并两个PLY文件,使用中心点坐标进行对齐""" + try: + self.logger.info("开始合并PLY点云") + self.logger.info(f"输入点云1: {ply1_path}") + self.logger.info(f"输入点云2: {ply2_path}") + self.logger.info(f"输出点云: {output_path}") + + # 检查输入文件是否存在 + if not os.path.exists(ply1_path) or not os.path.exists(ply2_path): + raise FileNotFoundError("输入点云文件不存在") + + # 读取点云 + pcd1 = o3d.io.read_point_cloud(ply1_path) + pcd2 = o3d.io.read_point_cloud(ply2_path) + + if pcd1 is None or pcd2 is None: + raise ValueError("无法读取点云文件") + + # 计算平移向量(直接使用中心点坐标差) + translation = np.array([ + center2[0] - center1[0], # x方向的平移 + center2[1] - center1[1], # y方向的平移 + 0.0 # z方向不平移 + ]) + + # 对第二个点云进行平移 + pcd2.translate(translation*100) + + # 合并点云 + combined_pcd = pcd1 + pcd2 + + # 保存合并后的点云 + if not o3d.io.write_point_cloud(output_path, combined_pcd): + raise RuntimeError("保存合并后的点云失败") + + self.logger.info(f"点云合并成功,已保存至: {output_path}") + + except Exception as e: + self.logger.error(f"合并PLY点云时发生错误: {str(e)}", exc_info=True) + raise + + def merge_grid_ply(self, grid_points: Dict[int, list]): + """合并所有网格的PLY点云,以第一个网格为参考点""" + self.logger.info("开始合并所有网格的PLY点云") + + if len(grid_points) < 2: + self.logger.info("只有一个网格,无需合并") + return + + try: + # 获取网格索引列表并排序 + grid_indices = sorted(grid_points.keys()) + + # 读取第一个网格作为参考网格 + ref_idx = grid_indices[0] + ref_ply = os.path.join( + self.output_dir, + f"grid_{ref_idx + 1}", + "project", + "odm_filterpoints", + "point_cloud.ply" + ) + + if not os.path.exists(ref_ply): + raise FileNotFoundError(f"参考网格的PLY文件不存在: {ref_ply}") + + # 获取参考网格的中心点坐标 + ref_center = self.read_corners_file(ref_idx) + self.logger.info(f"参考网格(grid_{ref_idx + 1})中心点: x={ref_center[0]:.2f}, y={ref_center[1]:.2f}") + + # 将参考点云复制到输出位置作为初始合并结果 + output_ply = os.path.join(self.output_dir, "merged_pointcloud.ply") + import shutil + shutil.copy2(ref_ply, output_ply) + + # 依次处理其他网格 + for grid_idx in grid_indices[1:]: + current_ply = os.path.join( + self.output_dir, + f"grid_{grid_idx + 1}", + "project", + "odm_filterpoints", + "point_cloud.ply" + ) + + if not os.path.exists(current_ply): + self.logger.warning(f"网格 {grid_idx + 1} 的PLY文件不存在: {current_ply}") + continue + + # 读取当前网格的中心点坐标 + current_center = self.read_corners_file(grid_idx) + + self.logger.info( + f"处理网格 {grid_idx + 1}:\n" + f"合并点云: {current_ply}\n" + f"当前网格中心点: x={current_center[0]:.2f}, y={current_center[1]:.2f}" + ) + + # 合并点云,始终使用第一个网格的中心点作为参考点 + self.merge_two_plys( + output_ply, # 当前合并结果 + current_ply, # 要合并的新点云 + output_ply, # 覆盖原有的合并结果 + ref_center, # 参考网格中心点(始终不变) + current_center # 当前网格中心点 + ) + + self.logger.info(f"PLY点云合并完成,最终输出文件: {output_ply}") + + except Exception as e: + self.logger.error(f"PLY点云合并过程中发生错误: {str(e)}", exc_info=True) + raise + + +if __name__ == "__main__": + from utils.logger import setup_logger + import open3d as o3d + + # 设置输出目录和日志 + output_dir = r"G:\ODM_output\1009" + setup_logger(output_dir) + + # 构造测试用的grid_points字典 + grid_points = { + 0: [], # 不再需要GPS点信息 + 1: 
[] + } + + # 创建MergePly实例并执行合并 + merge_ply = MergePly(output_dir) + merge_ply.merge_grid_ply(grid_points) diff --git a/post_pro/merge_tif.py b/post_pro/merge_tif.py new file mode 100644 index 0000000..688b5e7 --- /dev/null +++ b/post_pro/merge_tif.py @@ -0,0 +1,198 @@ +from osgeo import gdal +import logging +import os +from typing import Dict +import pandas as pd + + +class MergeTif: + def __init__(self, output_dir: str): + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.MergeTif') + + def merge_two_tifs(self, input_tif1: str, input_tif2: str, output_tif: str): + """合并两张TIF影像""" + try: + self.logger.info("开始合并TIF影像") + self.logger.info(f"输入影像1: {input_tif1}") + self.logger.info(f"输入影像2: {input_tif2}") + self.logger.info(f"输出影像: {output_tif}") + + # 检查输入文件是否存在 + if not os.path.exists(input_tif1) or not os.path.exists(input_tif2): + error_msg = "输入影像文件不存在" + self.logger.error(error_msg) + raise FileNotFoundError(error_msg) + + # 打开影像,检查投影是否一致 + datasets = [gdal.Open(tif) for tif in [input_tif1, input_tif2]] + if None in datasets: + error_msg = "无法打开输入影像文件" + self.logger.error(error_msg) + raise ValueError(error_msg) + + projections = [dataset.GetProjection() for dataset in datasets] + self.logger.debug(f"影像1投影: {projections[0]}") + self.logger.debug(f"影像2投影: {projections[1]}") + + # 检查投影是否一致 + if len(set(projections)) != 1: + error_msg = "影像的投影不一致,请先进行重投影!" + self.logger.error(error_msg) + raise ValueError(error_msg) + + # 创建 GDAL Warp 选项 + warp_options = gdal.WarpOptions( + format="GTiff", + resampleAlg="average", + srcNodata=0, + dstNodata=0, + multithread=True + ) + + self.logger.info("开始执行影像拼接...") + result = gdal.Warp( + output_tif, [input_tif1, input_tif2], options=warp_options) + + if result is None: + error_msg = "影像拼接失败" + self.logger.error(error_msg) + raise RuntimeError(error_msg) + + # 获取输出影像的基本信息 + output_dataset = gdal.Open(output_tif) + if output_dataset: + width = output_dataset.RasterXSize + height = output_dataset.RasterYSize + bands = output_dataset.RasterCount + self.logger.info(f"拼接完成,输出影像大小: {width}x{height},波段数: {bands}") + + self.logger.info(f"影像拼接成功,输出文件保存至: {output_tif}") + + except Exception as e: + self.logger.error(f"影像拼接过程中发生错误: {str(e)}", exc_info=True) + raise + + def merge_grid_tif(self, grid_points: Dict[int, pd.DataFrame], product_info: dict): + """合并指定产品的所有网格""" + product_name = product_info['name'] + product_path = product_info['path'] + filename = product_info['filename'] + + self.logger.info(f"开始合并{product_name}") + + if len(grid_points) < 2: + self.logger.info("只有一个网格,无需合并") + return + + input_tif1, input_tif2 = None, None + merge_count = 0 + + try: + for grid_idx, points in grid_points.items(): + grid_tif = os.path.join( + self.output_dir, + f"grid_{grid_idx + 1}", + "project", + product_path, + filename + ) + + if not os.path.exists(grid_tif): + self.logger.warning( + f"网格 {grid_idx + 1} 的{product_name}不存在: {grid_tif}") + continue + + if input_tif1 is None: + input_tif1 = grid_tif + self.logger.info(f"设置第一个输入{product_name}: {input_tif1}") + else: + input_tif2 = grid_tif + output_tif = os.path.join( + self.output_dir, f"merged_{product_info['output']}") + + self.logger.info( + f"开始合并{product_name}第 {merge_count + 1} 次:\n" + f"输入1: {input_tif1}\n" + f"输入2: {input_tif2}\n" + f"输出: {output_tif}" + ) + + self.merge_two_tifs(input_tif1, input_tif2, output_tif) + merge_count += 1 + + input_tif1 = output_tif + input_tif2 = None + + self.logger.info( + f"{product_name}合并完成,共执行 {merge_count} 次合并," + f"最终输出文件: {input_tif1}" + ) + 
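+            # Each pass merges the running mosaic (input_tif1) with the next grid's
+            # product and then re-points input_tif1 at the merged file, so the final
+            # merged_{output} accumulates all grids pairwise.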
+ except Exception as e: + self.logger.error( + f"{product_name}合并过程中发生错误: {str(e)}", exc_info=True) + raise + + def merge_all_tifs(self, grid_points: Dict[int, pd.DataFrame]): + """合并所有产品(正射影像、DSM和DTM)""" + try: + products = [ + { + 'name': '正射影像', + 'path': 'odm_orthophoto', + 'filename': 'odm_orthophoto.original.tif', + 'output': 'orthophoto.tif' + }, + { + 'name': 'DSM', + 'path': 'odm_dem', + 'filename': 'dsm.original.tif', + 'output': 'dsm.tif' + }, + { + 'name': 'DTM', + 'path': 'odm_dem', + 'filename': 'dtm.original.tif', + 'output': 'dtm.tif' + } + ] + + for product in products: + self.merge_grid_tif(grid_points, product) + + self.logger.info("所有产品合并完成") + except Exception as e: + self.logger.error(f"产品合并过程中发生错误: {str(e)}", exc_info=True) + raise + + +if __name__ == "__main__": + import sys + sys.path.append(os.path.dirname( + os.path.dirname(os.path.abspath(__file__)))) + from utils.logger import setup_logger + import pandas as pd + + # 设置输出目录和日志 + output_dir = r"G:\ODM_output\1009" + setup_logger(output_dir) + + # 构造测试用的grid_points字典 + # 假设我们有两个网格,每个网格包含一些GPS点的DataFrame + grid_points = { + 0: pd.DataFrame({ + 'latitude': [39.9, 39.91], + 'longitude': [116.3, 116.31], + 'altitude': [100, 101] + }), + 1: pd.DataFrame({ + 'latitude': [39.92, 39.93], + 'longitude': [116.32, 116.33], + 'altitude': [102, 103] + }) + } + + # 创建MergeTif实例并执行合并 + merge_tif = MergeTif(output_dir) + merge_tif.merge_all_tifs(grid_points) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..9bc2977 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +numpy +pandas +scikit-learn +matplotlib +piexif +geopy +psutil +docker>=6.1.3 +open3d diff --git a/tools/convert_jpg.py b/tools/convert_jpg.py new file mode 100644 index 0000000..402a8cb --- /dev/null +++ b/tools/convert_jpg.py @@ -0,0 +1,14 @@ +from PIL import Image +import os + +convert_format = "tif" +img_dir = r"E:\datasets\UAV\134\project\images" +output_dir = r"E:\datasets\UAV\134_tif\project\images" + +for file_name in os.listdir(img_dir): + img = Image.open(os.path.join(img_dir, file_name)) + if convert_format == "png": + img.save(os.path.join(output_dir, file_name.replace(".jpg", ".png"))) + elif convert_format == "tif": + img.save(os.path.join(output_dir, file_name.replace(".jpg", ".tif")), "TIFF") + diff --git a/tools/merge_obj.py b/tools/merge_obj.py new file mode 100644 index 0000000..2b3e714 --- /dev/null +++ b/tools/merge_obj.py @@ -0,0 +1,34 @@ +import open3d as o3d +import numpy as np + +# 读取 .obj 文件 +def load_obj(file_path): + mesh = o3d.io.read_triangle_mesh(file_path) + if not mesh.is_empty(): + return mesh + else: + raise ValueError(f"Failed to load {file_path}") + +# 合并两个网格 +def merge_meshes(mesh1, mesh2): + # 直接合并网格 + combined_mesh = mesh1 + mesh2 + return combined_mesh + +# 保存合并后的网格 +def save_merged_mesh(mesh, output_path): + o3d.io.write_triangle_mesh(output_path, mesh) + print(f"Saved merged mesh to {output_path}") + +# 示例用法 +mesh1 = load_obj("model1.obj") +mesh2 = load_obj("model2.obj") + +# 合并两个网格 +merged_mesh = merge_meshes(mesh1, mesh2) + +# 保存合并后的网格 +save_merged_mesh(merged_mesh, "merged_model.obj") + +# 可视化合并后的网格 +o3d.visualization.draw_geometries([merged_mesh]) diff --git a/tools/merge_ply.py b/tools/merge_ply.py new file mode 100644 index 0000000..a6330ff --- /dev/null +++ b/tools/merge_ply.py @@ -0,0 +1,22 @@ +import open3d as o3d +import numpy as np + +# 读取第一个PLY文件 +pcd1 = o3d.io.read_point_cloud("path_to_first_file.ply") + +# 读取第二个PLY文件 +pcd2 = 
o3d.io.read_point_cloud("path_to_second_file.ply") + +# 可选:如果需要调整坐标系,可以通过平移、旋转来对齐点云 +# 例如,平移第二个点云 +offset = np.array([1000, 2000, 3000]) +pcd2.translate(offset) + +# 合并点云 +combined_pcd = pcd1 + pcd2 + +# 保存合并后的点云为PLY文件 +o3d.io.write_point_cloud("merged_output.ply", combined_pcd) + +# 可视化 +o3d.visualization.draw_geometries([combined_pcd]) diff --git a/tools/odm_pip_time.py b/tools/odm_pip_time.py new file mode 100644 index 0000000..2dba550 --- /dev/null +++ b/tools/odm_pip_time.py @@ -0,0 +1,55 @@ +from datetime import datetime +import json + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="ODM log time") + + parser.add_argument( + "--path", default=r"E:\datasets\UAV\134\project\log.json") + args = parser.parse_args() + return args + + +def main(args): + # 读取 JSON 文件 + with open(args.path, 'r') as file: + data = json.load(file) + + # 提取 "stages" 中每个步骤的开始时间和持续时间 + stage_timings = [] + for i, stage in enumerate(data.get("stages", [])): + stage_name = stage.get("name", "Unnamed Stage") + start_time = stage.get("startTime") + + # 获取当前阶段的开始时间 + if start_time: + start_dt = datetime.fromisoformat(start_time) + + # 获取阶段的结束时间:可以是下一个阶段的开始时间,或当前阶段的 `endTime`(如果存在) + if i + 1 < len(data["stages"]): + end_time = data["stages"][i + 1].get("startTime") + else: + end_time = stage.get("endTime") or data.get("endTime") + + if end_time: + end_dt = datetime.fromisoformat(end_time) + duration = (end_dt - start_dt).total_seconds() + stage_timings.append((stage_name, duration)) + + # 输出每个阶段的持续时间,调整为对齐格式 + total_time = 0 + print(f"{'Stage Name':<25} {'Duration (seconds)':>15}") + print("=" * 45) + for stage_name, duration in stage_timings: + print(f"{stage_name:<25} {duration:>15.2f}") + total_time += duration + + print('Total Time:', total_time) + + +if __name__ == '__main__': + args = parse_args() + + main(args) diff --git a/tools/show_GPS.py b/tools/show_GPS.py new file mode 100644 index 0000000..8ea9664 --- /dev/null +++ b/tools/show_GPS.py @@ -0,0 +1,51 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import matplotlib.pyplot as plt +from utils.gps_extractor import GPSExtractor + +DATASET = r'E:\datasets\UAV\1009\project\images' + +if __name__ == '__main__': + extractor = GPSExtractor(DATASET) + gps_points = extractor.extract_all_gps() + + # 创建两个子图 + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8)) + + # 左图:原始散点图 + ax1.scatter(gps_points['lon'], gps_points['lat'], + color='blue', marker='o', label='GPS Points') + ax1.set_title("GPS Coordinates of Images", fontsize=14) + ax1.set_xlabel("Longitude", fontsize=12) + ax1.set_ylabel("Latitude", fontsize=12) + ax1.grid(True) + ax1.legend() + + # # 右图:按时间排序的轨迹图 + # gps_points_sorted = gps_points.sort_values('date') + + # # 绘制飞行轨迹线 + # ax2.plot(gps_points_sorted['lon'][300:600], gps_points_sorted['lat'][300:600], + # color='blue', linestyle='-', linewidth=1, alpha=0.6) + + # # 绘制GPS点 + # ax2.scatter(gps_points_sorted['lon'][300:600], gps_points_sorted['lat'][300:600], + # color='red', marker='o', s=30, label='GPS Points') + + # 标记起点和终点 + # ax2.scatter(gps_points_sorted['lon'].iloc[0], gps_points_sorted['lat'].iloc[0], + # color='green', marker='^', s=100, label='Start') + # ax2.scatter(gps_points_sorted['lon'].iloc[-1], gps_points_sorted['lat'].iloc[-1], + # color='purple', marker='s', s=100, label='End') + + ax2.set_title("UAV Flight Trajectory", fontsize=14) + ax2.set_xlabel("Longitude", fontsize=12) + ax2.set_ylabel("Latitude", fontsize=12) + ax2.grid(True) + 
ax2.legend() + + # 调整子图之间的间距 + plt.tight_layout() + plt.show() diff --git a/tools/show_GPS_by_time.py b/tools/show_GPS_by_time.py new file mode 100644 index 0000000..635b0a3 --- /dev/null +++ b/tools/show_GPS_by_time.py @@ -0,0 +1,138 @@ +import os +import sys +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import matplotlib.pyplot as plt +from datetime import timedelta +import logging +import numpy as np +from utils.gps_extractor import GPSExtractor +from utils.logger import setup_logger + +class GPSTimeVisualizer: + """按时间组可视化GPS点""" + + def __init__(self, image_dir: str, output_dir: str): + self.image_dir = image_dir + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.GPSVisualizer') + + def _group_by_time(self, points_df, time_threshold=timedelta(minutes=5)): + """按时间间隔对点进行分组""" + if 'date' not in points_df.columns: + self.logger.error("数据中缺少date列") + return [points_df] + + # 将date为空的行单独作为一组 + null_date_group = points_df[points_df['date'].isna()] + valid_date_points = points_df[points_df['date'].notna()] + + if not null_date_group.empty: + self.logger.info(f"发现 {len(null_date_group)} 个无时间戳的点,将作为单独分组") + + if valid_date_points.empty: + self.logger.warning("没有有效的时间戳数据") + return [null_date_group] if not null_date_group.empty else [] + + # 按时间排序 + valid_date_points = valid_date_points.sort_values('date') + + # 计算时间差 + time_diffs = valid_date_points['date'].diff() + + # 找到时间差超过阈值的位置 + time_groups = [] + current_group_start = 0 + + for idx, time_diff in enumerate(time_diffs): + if time_diff and time_diff > time_threshold: + # 添加当前组 + current_group = valid_date_points.iloc[current_group_start:idx] + time_groups.append(current_group) + current_group_start = idx + + # 添加最后一组 + last_group = valid_date_points.iloc[current_group_start:] + if not last_group.empty: + time_groups.append(last_group) + + # 如果有空时间戳的点,将其作为最后一组 + if not null_date_group.empty: + time_groups.append(null_date_group) + + return time_groups + + def visualize_time_groups(self, time_threshold=timedelta(minutes=5)): + """在同一张图上显示所有时间组,用不同颜色区分""" + # 提取GPS数据 + extractor = GPSExtractor(self.image_dir) + gps_points = extractor.extract_all_gps() + + # 按时间分组 + time_groups = self._group_by_time(gps_points, time_threshold) + + # 创建图形 + plt.figure(figsize=(15, 10)) + + # 生成不同的颜色 + colors = plt.cm.rainbow(np.linspace(0, 1, len(time_groups))) + + # 为每个时间组绘制点和轨迹 + for idx, (group, color) in enumerate(zip(time_groups, colors)): + if not group['date'].isna().any(): + # 有时间戳的组 + sorted_group = group.sort_values('date') + + # 绘制轨迹线 + plt.plot(sorted_group['lon'], sorted_group['lat'], + color=color, linestyle='-', linewidth=1.5, alpha=0.6, + label=f'Flight Path {idx + 1}') + + # 绘制GPS点 + plt.scatter(sorted_group['lon'], sorted_group['lat'], + color=color, marker='o', s=30, alpha=0.6) + + # 标记起点和终点 + plt.scatter(sorted_group['lon'].iloc[0], sorted_group['lat'].iloc[0], + color=color, marker='^', s=100, + label=f'Start {idx + 1} ({sorted_group["date"].min().strftime("%H:%M:%S")})') + plt.scatter(sorted_group['lon'].iloc[-1], sorted_group['lat'].iloc[-1], + color=color, marker='s', s=100, + label=f'End {idx + 1} ({sorted_group["date"].max().strftime("%H:%M:%S")})') + else: + # 无时间戳的组 + plt.scatter(group['lon'], group['lat'], + color=color, marker='x', s=50, alpha=0.6, + label='No Timestamp Points') + + plt.title("GPS Points by Time Groups", fontsize=14) + plt.xlabel("Longitude", fontsize=12) + plt.ylabel("Latitude", fontsize=12) + plt.grid(True) + + # 调整图例位置和大小 + 
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10) + + # 调整布局以适应图例 + plt.tight_layout() + + # 保存图片 + plt.savefig(os.path.join(self.output_dir, 'gps_time_groups_combined.png'), + dpi=300, bbox_inches='tight') + plt.close() + + self.logger.info(f"已生成包含 {len(time_groups)} 个时间组的组合可视化图形") + + +if __name__ == '__main__': + # 设置数据集路径 + DATASET = r'F:\error_data\20241108134711\3D' + output_dir = r'E:\studio2\ODM_pro\test' + os.makedirs(output_dir, exist_ok=True) + + # 设置日志 + setup_logger(os.path.dirname(output_dir)) + + # 创建可视化器并生成图形 + visualizer = GPSTimeVisualizer(DATASET, output_dir) + visualizer.visualize_time_groups(time_threshold=timedelta(minutes=5)) \ No newline at end of file diff --git a/tools/test_docker_run.py b/tools/test_docker_run.py new file mode 100644 index 0000000..9f3a2f8 --- /dev/null +++ b/tools/test_docker_run.py @@ -0,0 +1,12 @@ +import subprocess + + +def run_docker_command(command): + result = subprocess.run(command, shell=True, + stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return result.stdout.decode('utf-8'), result.stderr.decode('utf-8') + +if __name__ == "__main__": + command = "docker run -ti --rm -v g:/ODM_output/20241024100834/grid_1:/datasets opendronemap/odm --project-path /datasets project --max-concurrency 10 --force-gps --feature-quality lowest --orthophoto-resolution 10 --fast-orthophoto --skip-3dmodel --rerun-all" + stdout, stderr = run_docker_command(command) + print(stdout) diff --git a/utils/gps_extractor.py b/utils/gps_extractor.py new file mode 100644 index 0000000..9bb2e57 --- /dev/null +++ b/utils/gps_extractor.py @@ -0,0 +1,93 @@ +import os +from PIL import Image +import piexif +import logging +import pandas as pd +from datetime import datetime + + +class GPSExtractor: + """从图像文件提取GPS坐标和拍摄日期""" + + def __init__(self, image_dir): + self.image_dir = image_dir + self.logger = logging.getLogger('UAV_Preprocess.GPSExtractor') + + @staticmethod + def _dms_to_decimal(dms): + """将DMS格式转换为十进制度""" + return dms[0][0] / dms[0][1] + (dms[1][0] / dms[1][1]) / 60 + (dms[2][0] / dms[2][1]) / 3600 + + @staticmethod + def _parse_datetime(datetime_str): + """解析EXIF中的日期时间字符串""" + try: + # EXIF日期格式通常为 'YYYY:MM:DD HH:MM:SS' + return datetime.strptime(datetime_str.decode(), '%Y:%m:%d %H:%M:%S') + except Exception: + return None + + def get_gps_and_date(self, image_path): + """提取单张图片的GPS坐标和拍摄日期""" + try: + image = Image.open(image_path) + exif_data = piexif.load(image.info['exif']) + + # 提取GPS信息 + gps_info = exif_data.get("GPS", {}) + lat = lon = None + if gps_info: + lat = self._dms_to_decimal(gps_info.get(2, [])) + lon = self._dms_to_decimal(gps_info.get(4, [])) + self.logger.debug(f"成功提取图片GPS坐标: {image_path} - 纬度: {lat}, 经度: {lon}") + + # 提取拍摄日期 + date_info = None + if "Exif" in exif_data: + # 优先使用DateTimeOriginal + date_str = exif_data["Exif"].get(36867) # DateTimeOriginal + if not date_str: + # 备选DateTime + date_str = exif_data["Exif"].get(36868) # DateTimeDigitized + if not date_str: + # 最后使用基本DateTime + date_str = exif_data["0th"].get(306) # DateTime + + if date_str: + date_info = self._parse_datetime(date_str) + self.logger.debug(f"成功提取图片拍摄日期: {image_path} - {date_info}") + + if not gps_info: + self.logger.warning(f"图片无GPS信息: {image_path}") + if not date_info: + self.logger.warning(f"图片无拍摄日期信息: {image_path}") + + return lat, lon, date_info + + except Exception as e: + self.logger.error(f"提取图片信息时发生错误: {image_path} - {str(e)}") + return None, None, None + + def extract_all_gps(self): + """提取所有图片的GPS坐标和拍摄日期""" + 
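+        # get_gps_and_date returns (None, None, None) when extraction fails; images
+        # without GPS are skipped below, so the resulting DataFrame only contains
+        # geotagged photos (date may still be None for them).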
self.logger.info(f"开始从目录提取GPS坐标和拍摄日期: {self.image_dir}") + gps_data = [] + total_images = 0 + successful_extractions = 0 + + for image_file in os.listdir(self.image_dir): + if image_file.lower().endswith('.jpg'): + total_images += 1 + image_path = os.path.join(self.image_dir, image_file) + lat, lon, date = self.get_gps_and_date(image_path) + if lat and lon: # 仍然以GPS信息作为主要判断依据 + successful_extractions += 1 + gps_data.append({ + 'file': image_file, + 'lat': lat, + 'lon': lon, + 'date': date + }) + + self.logger.info(f"GPS坐标和拍摄日期提取完成 - 总图片数: {total_images}, 成功提取: {successful_extractions}, 失败: {total_images - successful_extractions}") + return pd.DataFrame(gps_data) diff --git a/utils/grid_divider.py b/utils/grid_divider.py new file mode 100644 index 0000000..f4fd251 --- /dev/null +++ b/utils/grid_divider.py @@ -0,0 +1,136 @@ +import logging +from geopy.distance import geodesic +import matplotlib.pyplot as plt +import os + + +class GridDivider: + """划分九宫格,并将图片分配到对应网格""" + + def __init__(self, overlap=0.1, output_dir=None): + self.overlap = overlap + self.output_dir = output_dir + self.logger = logging.getLogger('UAV_Preprocess.GridDivider') + self.logger.info(f"初始化网格划分器,重叠率: {overlap}") + + def divide_grids(self, points_df, grid_size=500): + """计算边界框并划分九宫格""" + self.logger.info("开始划分九宫格") + + min_lat, max_lat = points_df['lat'].min(), points_df['lat'].max() + min_lon, max_lon = points_df['lon'].min(), points_df['lon'].max() + + # 计算区域的实际距离(米) + width = geodesic((min_lat, min_lon), (min_lat, max_lon)).meters + height = geodesic((min_lat, min_lon), (max_lat, min_lon)).meters + + self.logger.info( + f"区域宽度: {width:.2f}米, 高度: {height:.2f}米" + ) + + # 计算需要划分的网格数量 + num_grids_width = int( + width / grid_size) if int(width / grid_size) > 0 else 1 + num_grids_height = int( + height / grid_size) if int(height / grid_size) > 0 else 1 + + # 计算每个网格对应的经纬度步长 + lat_step = (max_lat - min_lat) / num_grids_height + lon_step = (max_lon - min_lon) / num_grids_width + + grids = [] + for i in range(num_grids_height): + for j in range(num_grids_width): + grid_min_lat = min_lat + i * lat_step - self.overlap * lat_step + grid_max_lat = min_lat + \ + (i + 1) * lat_step + self.overlap * lat_step + grid_min_lon = min_lon + j * lon_step - self.overlap * lon_step + grid_max_lon = min_lon + \ + (j + 1) * lon_step + self.overlap * lon_step + grids.append((grid_min_lat, grid_max_lat, + grid_min_lon, grid_max_lon)) + + self.logger.debug( + f"网格[{i},{j}]: 纬度[{grid_min_lat:.6f}, {grid_max_lat:.6f}], " + f"经度[{grid_min_lon:.6f}, {grid_max_lon:.6f}]" + ) + + self.logger.info( + f"成功划分为 {len(grids)} 个网格 ({num_grids_width}x{num_grids_height})") + + # 添加可视化调用 + self.visualize_grids(points_df, grids) + + return grids + + def assign_to_grids(self, points_df, grids): + """将点分配到对应网格""" + self.logger.info(f"开始将 {len(points_df)} 个点分配到网格中") + + grid_points = {i: [] for i in range(len(grids))} + points_assigned = 0 + multiple_grid_points = 0 + + for _, point in points_df.iterrows(): + point_assigned = False + for i, (min_lat, max_lat, min_lon, max_lon) in enumerate(grids): + if min_lat <= point['lat'] <= max_lat and min_lon <= point['lon'] <= max_lon: + grid_points[i].append(point.to_dict()) + if point_assigned: + multiple_grid_points += 1 + else: + points_assigned += 1 + point_assigned = True + + self.logger.debug( + f"点 {point['file']} (纬度: {point['lat']:.6f}, 经度: {point['lon']:.6f}) " + f"被分配到网格" + ) + + # 记录每个网格的点数 + for grid_idx, points in grid_points.items(): + self.logger.info(f"网格 {grid_idx} 包含 {len(points)} 个点") + + 
+        self.logger.info(
+            f"点分配完成: 总点数 {len(points_df)}, "
+            f"成功分配 {points_assigned} 个点, "
+            f"{multiple_grid_points} 个点被分配到多个网格"
+        )
+
+        return grid_points
+
+    def visualize_grids(self, points_df, grids):
+        """可视化网格划分和GPS点的分布"""
+        self.logger.info("开始可视化网格划分")
+
+        plt.figure(figsize=(12, 8))
+
+        # 绘制GPS点
+        plt.scatter(points_df['lon'], points_df['lat'],
+                    c='blue', s=10, alpha=0.6, label='GPS点')
+
+        # 绘制网格
+        for i, (min_lat, max_lat, min_lon, max_lon) in enumerate(grids):
+            plt.plot([min_lon, max_lon, max_lon, min_lon, min_lon],
+                     [min_lat, min_lat, max_lat, max_lat, min_lat],
+                     'r-', alpha=0.5)
+            # 在网格中心添加网格编号
+            center_lon = (min_lon + max_lon) / 2
+            center_lat = (min_lat + max_lat) / 2
+            plt.text(center_lon, center_lat, str(i),
+                     horizontalalignment='center', verticalalignment='center')
+
+        plt.title('网格划分与GPS点分布图')
+        plt.xlabel('经度')
+        plt.ylabel('纬度')
+        plt.legend()
+        plt.grid(True)
+
+        # 如果提供了输出目录,保存图像
+        if self.output_dir:
+            save_path = os.path.join(
+                self.output_dir, 'filter_imgs', 'grid_division.png')
+            plt.savefig(save_path, dpi=300, bbox_inches='tight')
+            self.logger.info(f"网格划分可视化图已保存至: {save_path}")
+
+        plt.close()
diff --git a/utils/logger.py b/utils/logger.py
new file mode 100644
index 0000000..9aee1a6
--- /dev/null
+++ b/utils/logger.py
@@ -0,0 +1,36 @@
+import logging
+import os
+from datetime import datetime
+
+def setup_logger(output_dir):
+    # 创建logs目录
+    log_dir = os.path.join(output_dir, 'logs')
+    os.makedirs(log_dir, exist_ok=True)
+
+    # 创建日志文件名(包含时间戳)
+    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+    log_file = os.path.join(log_dir, f'preprocess_{timestamp}.log')
+
+    # 配置日志格式
+    formatter = logging.Formatter(
+        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+        datefmt='%Y-%m-%d %H:%M:%S'
+    )
+
+    # 配置文件处理器
+    file_handler = logging.FileHandler(log_file, encoding='utf-8')
+    file_handler.setFormatter(formatter)
+
+    # 配置控制台处理器
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(formatter)
+
+    # 获取日志记录器
+    logger = logging.getLogger('UAV_Preprocess')
+    logger.setLevel(logging.INFO)
+
+    # 添加处理器
+    logger.addHandler(file_handler)
+    logger.addHandler(console_handler)
+
+    return logger
\ No newline at end of file
diff --git a/utils/odm_monitor.py b/utils/odm_monitor.py
new file mode 100644
index 0000000..9b50c1a
--- /dev/null
+++ b/utils/odm_monitor.py
@@ -0,0 +1,78 @@
+import os
+import logging
+import subprocess
+from typing import Dict, Tuple
+import pandas as pd
+
+
+class ODMProcessMonitor:
+    """ODM处理监控器"""
+
+    def __init__(self, output_dir: str, mode: str = "快拼模式"):
+        self.output_dir = output_dir
+        self.logger = logging.getLogger('UAV_Preprocess.ODMMonitor')
+        self.mode = mode
+
+    def _check_success(self, grid_dir: str) -> bool:
+        """检查ODM是否执行成功"""
+        success_markers = ['odm_orthophoto', 'odm_georeferencing']
+        if self.mode != "快拼模式":
+            success_markers.append('odm_texturing')
+        return all(os.path.exists(os.path.join(grid_dir, 'project', marker)) for marker in success_markers)
+
+    def run_odm_with_monitor(self, grid_dir: str, grid_idx: int, fast_mode: bool = True) -> Tuple[bool, str]:
+        """运行ODM命令"""
+        self.logger.info(f"开始处理网格 {grid_idx + 1}")
+
+        # 构建Docker命令
+        grid_dir = grid_dir[0].lower()+grid_dir[1:].replace('\\', '/')
+        docker_command = (
+            f"docker run --gpus all -ti --rm "
+            f"-v {grid_dir}:/datasets "
+            f"opendronemap/odm:gpu "
+            f"--project-path /datasets project "
+            f"--max-concurrency 10 "
+            f"--force-gps "
+            f"--feature-quality lowest "
+            f"--orthophoto-resolution 10 "
+        )
+
+        if fast_mode:
+            docker_command += (
+                f"--fast-orthophoto "
+                f"--skip-3dmodel "
+            )
+
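+        # --rerun-all tells ODM to discard any previous results in the project
+        # directory and recompute every stage from scratch.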
+        docker_command += "--rerun-all"
+        self.logger.info(docker_command)
+        result = subprocess.run(
+            docker_command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        stdout, stderr = result.stdout.decode(
+            'utf-8'), result.stderr.decode('utf-8')
+
+        self.logger.info(f"==========stdout==========: {stdout}")
+        self.logger.error(f"==========stderr==========: {stderr}")
+        # 检查执行结果
+        if self._check_success(grid_dir):
+            self.logger.info(f"网格 {grid_idx + 1} 处理成功")
+            return True, ""
+        else:
+            self.logger.error(f"网格 {grid_idx + 1} 处理失败")
+            return False, f"网格 {grid_idx + 1} 处理失败"
+
+    def process_all_grids(self, grid_points: Dict[int, pd.DataFrame]):
+        """处理所有网格"""
+        self.logger.info("开始执行网格处理")
+        for grid_idx in grid_points.keys():
+            grid_dir = os.path.join(
+                self.output_dir, f'grid_{grid_idx + 1}'
+            )
+
+            success, error_msg = self.run_odm_with_monitor(
+                grid_dir=grid_dir,
+                grid_idx=grid_idx,
+                fast_mode=(self.mode == "快拼模式")
+            )
+
+            if not success:
+                raise Exception(f"网格 {grid_idx + 1} 处理失败: {error_msg}")
diff --git a/utils/visualizer.py b/utils/visualizer.py
new file mode 100644
index 0000000..f74f5a9
--- /dev/null
+++ b/utils/visualizer.py
@@ -0,0 +1,121 @@
+import os
+import matplotlib.pyplot as plt
+import pandas as pd
+import logging
+from typing import Optional
+
+
+class FilterVisualizer:
+    """过滤结果可视化器"""
+
+    def __init__(self, output_dir: str):
+        """
+        初始化可视化器
+
+        Args:
+            output_dir: 输出目录路径
+        """
+        self.output_dir = output_dir
+        self.logger = logging.getLogger('UAV_Preprocess.Visualizer')
+
+    def visualize_filter_step(self,
+                              current_points: pd.DataFrame,
+                              previous_points: pd.DataFrame,
+                              step_name: str,
+                              save_name: Optional[str] = None):
+        """
+        可视化单个过滤步骤的结果
+
+        Args:
+            current_points: 当前步骤后的点
+            previous_points: 上一步骤的点
+            step_name: 步骤名称
+            save_name: 保存文件名,默认为step_name
+        """
+        self.logger.info(f"开始生成{step_name}的可视化结果")
+
+        # 找出被过滤掉的点
+        filtered_files = set(previous_points['file']) - set(current_points['file'])
+        filtered_points = previous_points[previous_points['file'].isin(filtered_files)]
+
+        # 创建图形
+        plt.figure(figsize=(20, 16))
+
+        # 绘制保留的点
+        plt.scatter(current_points['lon'], current_points['lat'],
+                    color='blue', label='Retained Points',
+                    alpha=0.6, s=50)
+
+        # 绘制被过滤的点
+        if not filtered_points.empty:
+            plt.scatter(filtered_points['lon'], filtered_points['lat'],
+                        color='red', marker='x', label='Filtered Points',
+                        alpha=0.6, s=100)
+
+        # 设置图形属性
+        plt.title(f"GPS Points After {step_name}\n"
+                  f"(Filtered: {len(filtered_points)}, Retained: {len(current_points)})",
+                  fontsize=14)
+        plt.xlabel("Longitude", fontsize=12)
+        plt.ylabel("Latitude", fontsize=12)
+        plt.grid(True)
+
+        # 添加统计信息
+        stats_text = (
+            f"Original Points: {len(previous_points)}\n"
+            f"Filtered Points: {len(filtered_points)}\n"
+            f"Remaining Points: {len(current_points)}\n"
+            f"Filter Rate: {len(filtered_points)/len(previous_points)*100:.1f}%"
+        )
+        plt.figtext(0.02, 0.02, stats_text, fontsize=10,
+                    bbox=dict(facecolor='white', alpha=0.8))
+
+        # 添加图例
+        plt.legend(loc='upper right', fontsize=10)
+
+        # 调整布局
+        plt.tight_layout()
+
+        # 保存图形
+        save_name = save_name or step_name.lower().replace(' ', '_')
+        save_path = os.path.join(self.output_dir, 'filter_imgs', f'filter_{save_name}.png')
+        plt.savefig(save_path, dpi=300, bbox_inches='tight')
+        plt.close()
+
+        self.logger.info(
+            f"{step_name}过滤可视化结果已保存至 {save_path}\n"
+            f"过滤掉 {len(filtered_points)} 个点,"
+            f"保留 {len(current_points)} 个点,"
+            f"过滤率 {len(filtered_points)/len(previous_points)*100:.1f}%"
+        )
+
+
+if __name__ == '__main__':
+    # 测试代码
+    import numpy as np
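+    # The self-test below builds 1000 synthetic GPS points in a 1°x1° box near
+    # 30°N, 120°E, keeps a random subset of 800 as the "filtered" result, and
+    # renders the retained vs. removed points with visualize_filter_step.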
+    from datetime import datetime
+
+    # 创建测试数据
+    np.random.seed(42)
+    n_points = 1000
+
+    # 生成随机点
+    test_data = pd.DataFrame({
+        'lon': np.random.uniform(120, 121, n_points),
+        'lat': np.random.uniform(30, 31, n_points),
+        'file': [f'img_{i}.jpg' for i in range(n_points)],
+        'date': [datetime.now() for _ in range(n_points)]
+    })
+
+    # 随机选择点作为过滤后的结果
+    filtered_data = test_data.sample(n=800)
+
+    # 测试可视化
+    visualizer = FilterVisualizer('test_output')
+    os.makedirs(os.path.join('test_output', 'filter_imgs'), exist_ok=True)
+
+    visualizer.visualize_filter_step(
+        filtered_data,
+        test_data,
+        "Test Filter"
+    )
\ No newline at end of file
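The utilities in this commit are meant to be chained into a preprocessing pipeline (EXIF GPS extraction → filtering → grid division → per-grid ODM runs). A minimal wiring sketch follows, reusing the input/output paths from the tools scripts above; the orchestrating script and the step that copies each grid's images into its grid_<n> directory are not part of this commit, so those parts are assumptions.

```python
import os

from utils.logger import setup_logger
from utils.gps_extractor import GPSExtractor
from utils.grid_divider import GridDivider

if __name__ == '__main__':
    image_dir = r'F:\error_data\20241108134711\3D'  # assumed input directory
    output_dir = r'E:\studio2\ODM_pro\test'         # assumed output directory
    os.makedirs(os.path.join(output_dir, 'filter_imgs'), exist_ok=True)

    setup_logger(output_dir)

    # 1. Read GPS coordinates and timestamps from the image EXIF headers
    gps_points = GPSExtractor(image_dir).extract_all_gps()

    # 2. Cut the survey area into ~500 m grids with 10% overlap and assign
    #    every image to the grid(s) it falls in; a grid_division.png overview
    #    is written to <output_dir>/filter_imgs
    divider = GridDivider(overlap=0.1, output_dir=output_dir)
    grids = divider.divide_grids(gps_points, grid_size=500)
    grid_points = divider.assign_to_grids(gps_points, grids)

    # 3. Each grid would then be handed to ODM, e.g.:
    # from utils.odm_monitor import ODMProcessMonitor
    # ODMProcessMonitor(output_dir, mode="快拼模式").process_all_grids(grid_points)
    # (this assumes the images for grid i were already copied into
    #  <output_dir>/grid_<i+1>, which is not shown in this commit)
```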