from utils.gps_extractor import GPSExtractor
import os
import sys
import shutil
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.widgets import RectangleSelector
import pandas as pd
from matplotlib.font_manager import FontProperties


class GPSSelector:
    """Interactive matplotlib tool for discarding images by GPS position.

    The GPS points extracted from the images in ``image_dir`` are drawn as a
    scatter plot. Points enclosed by a rectangle selection are removed from
    the working set; on confirmation the images that remain are copied to
    ``output_dir`` together with a CSV of their coordinates.

    Controls (as connected in ``setup_plot``):
        * left-button drag  -- rectangle-select points to delete
        * right-button drag -- pan the view
        * scroll wheel      -- zoom around the cursor
        * Enter             -- save the results and close the figure
        * Escape            -- close the figure without saving

    NOTE(review): ``GPSExtractor.extract_all_gps`` is assumed to return a
    pandas DataFrame with ``file``, ``lon`` and ``lat`` columns, since that
    is how the result is used throughout this class -- confirm against the
    extractor.
    """

    # Figure width in inches; the height is derived from the data extent.
    _FIG_WIDTH = 12
    # Bounds on the derived height so that very elongated point clouds do
    # not yield an unusably flat or absurdly tall window.
    _MIN_FIG_HEIGHT = 2
    _MAX_FIG_HEIGHT = 24
    # Fraction of the data extent added as padding on every side.
    _MARGIN_FRACTION = 0.1
    # Padding (in data units) used when the extent along an axis is zero,
    # e.g. a single remaining point. Arbitrary small value.
    _FALLBACK_MARGIN = 1e-4
    # Zoom step applied per scroll-wheel notch.
    _ZOOM_STEP = 1.1

    def __init__(self, image_dir: str, output_dir: str = None):
        """Create the figure and wire up the interaction callbacks.

        Args:
            image_dir: Directory containing the source images.
            output_dir: Directory that receives the kept images and the CSV.
                When falsy, ``save_results`` does nothing.
        """
        self.image_dir = image_dir
        self.output_dir = output_dir
        self.gps_points = None
        # File names of every point removed so far.
        self.selected_points = []
        self.fig, self.ax = plt.subplots(figsize=(12, 8))
        self.scatter = None
        self.rs = None
        # Data coordinates where the current right-button pan started.
        self._pan_start = None
        self.setup_plot()

    def extract_gps(self):
        """Extract the GPS data for ``image_dir`` into ``self.gps_points``."""
        extractor = GPSExtractor(self.image_dir)
        self.gps_points = extractor.extract_all_gps()
        print(f"成功提取 {len(self.gps_points)} 个GPS点")

    def setup_plot(self):
        """Configure the axes, the rectangle selector and event callbacks."""
        self.ax.set_title('GPS Points - Use mouse to drag and select points to delete')
        self.ax.set_xlabel('Longitude')
        self.ax.set_ylabel('Latitude')
        self.ax.grid(True)

        # Use the same scale on both axes.
        self.ax.set_aspect('equal')

        # Rectangle selector; it reacts to the left mouse button only.
        self.rs = RectangleSelector(
            self.ax, self.on_select,
            interactive=True,
            useblit=True,
            button=[1],
            props=dict(facecolor='red', alpha=0.3)
        )

        canvas = self.fig.canvas
        # Keyboard: confirm / cancel.
        canvas.mpl_connect('key_press_event', self.on_key_press)
        # Mouse: zoom with the wheel, pan with the right button.
        canvas.mpl_connect('scroll_event', self.on_scroll)
        canvas.mpl_connect('button_press_event', self.on_press)
        canvas.mpl_connect('button_release_event', self.on_release)
        canvas.mpl_connect('motion_notify_event', self.on_motion)

        # No pan is in progress initially.
        self._pan_start = None

    def plot_gps_points(self):
        """Redraw the remaining GPS points and fit the view to them.

        When no points remain (or none have been extracted yet) the old
        scatter is removed and the view is left untouched. A zero extent
        along an axis is padded with a small fixed margin instead of being
        used as a divisor.
        """
        if self.scatter is not None:
            self.scatter.remove()
            self.scatter = None

        points = self.gps_points
        if points is None or len(points) == 0:
            # Nothing to draw; min()/max() of an empty column would be NaN
            # and cannot be used as axis limits.
            self.fig.canvas.draw_idle()
            return

        lon = points['lon']
        lat = points['lat']

        self.scatter = self.ax.scatter(lon, lat, c='blue', s=20, alpha=0.6)

        lon_min, lon_max = lon.min(), lon.max()
        lat_min, lat_max = lat.min(), lat.max()
        if pd.isna([lon_min, lon_max, lat_min, lat_max]).any():
            # Coordinates are all missing; there is no extent to fit.
            self.fig.canvas.draw_idle()
            return

        lon_range = lon_max - lon_min
        lat_range = lat_max - lat_min

        # Size the figure to follow the lon/lat extent ratio. Skipped for a
        # degenerate extent, where the ratio is undefined.
        if lon_range > 0 and lat_range > 0:
            fig_height = self._FIG_WIDTH / (lon_range / lat_range)
            fig_height = min(max(float(fig_height), self._MIN_FIG_HEIGHT),
                             self._MAX_FIG_HEIGHT)
            self.fig.set_size_inches(self._FIG_WIDTH, fig_height)

        # Pad the view so points on the edge stay visible.
        if lon_range > 0:
            x_margin = lon_range * self._MARGIN_FRACTION
        else:
            x_margin = self._FALLBACK_MARGIN
        if lat_range > 0:
            y_margin = lat_range * self._MARGIN_FRACTION
        else:
            y_margin = self._FALLBACK_MARGIN

        self.ax.set_xlim(lon_min - x_margin, lon_max + x_margin)
        self.ax.set_ylim(lat_min - y_margin, lat_max + y_margin)

        # Keep the limits set above instead of letting matplotlib rescale.
        self.ax.autoscale(False)

        # Show full coordinate values on the ticks rather than an offset.
        self.ax.ticklabel_format(useOffset=False, style='plain')

        self.fig.canvas.draw_idle()

    def on_select(self, eclick, erelease):
        """Remove every point inside the selected rectangle.

        Args:
            eclick: Mouse-press event of the selection.
            erelease: Mouse-release event of the selection.
        """
        if self.gps_points is None:
            return

        x1, y1 = eclick.xdata, eclick.ydata
        x2, y2 = erelease.xdata, erelease.ydata
        if None in (x1, y1, x2, y2):
            # Without data coordinates there is no rectangle to test.
            return

        x_lo, x_hi = min(x1, x2), max(x1, x2)
        y_lo, y_hi = min(y1, y2), max(y1, y2)

        lon = self.gps_points['lon']
        lat = self.gps_points['lat']
        # Points on the rectangle border count as selected.
        mask = (lon >= x_lo) & (lon <= x_hi) & (lat >= y_lo) & (lat <= y_hi)

        selected = self.gps_points[mask]
        self.selected_points.extend(selected['file'].tolist())

        # Drop the selected points from the working set.
        self.gps_points = self.gps_points[~mask]

        self.plot_gps_points()
        print(f"选中 {len(selected)} 个点,剩余 {len(self.gps_points)} 个点")

    def on_key_press(self, event):
        """Handle Enter (save and close) and Escape (close only)."""
        if event.key == 'enter':
            self.save_results()
            # Close this selector's figure, not whichever one is current.
            plt.close(self.fig)
        elif event.key == 'escape':
            plt.close(self.fig)

    def save_results(self):
        """Copy the kept images to ``output_dir`` and write their CSV.

        Does nothing when ``output_dir`` is not set. The source images are
        copied, not moved, so ``image_dir`` is left intact.
        """
        if not self.output_dir:
            return

        os.makedirs(self.output_dir, exist_ok=True)

        # File names of all images that were not deleted.
        remaining_files = self.gps_points['file'].tolist()

        for img_name in remaining_files:
            src = os.path.join(self.image_dir, img_name)
            dst = os.path.join(self.output_dir, img_name)
            # copy2 also preserves the file metadata.
            shutil.copy2(src, dst)

        # Persist the coordinates of the remaining points.
        self.gps_points.to_csv(
            os.path.join(self.output_dir, "remaining_points.csv"),
            index=False
        )

        print(f"已选择删除 {len(self.selected_points)} 张图片")
        print(f"已复制 {len(remaining_files)} 张保留的图片到 {self.output_dir}")

    def run(self):
        """Extract the GPS data, draw it and enter the matplotlib event loop."""
        self.extract_gps()
        self.plot_gps_points()
        plt.show()

    def on_scroll(self, event):
        """Zoom in or out around the cursor with the mouse wheel."""
        if event.inaxes is not self.ax:
            return
        if event.xdata is None or event.ydata is None:
            return

        # Scrolling up zooms in (shrinks the visible span), anything else
        # zooms out.
        if event.button == 'up':
            scale_factor = 1 / self._ZOOM_STEP
        else:
            scale_factor = self._ZOOM_STEP

        x_lo, x_hi = self.ax.get_xlim()
        y_lo, y_hi = self.ax.get_ylim()
        cx, cy = event.xdata, event.ydata

        # Scale the distance from the cursor to each edge, which keeps the
        # data point under the cursor fixed on screen.
        self.ax.set_xlim(cx - (cx - x_lo) * scale_factor,
                         cx + (x_hi - cx) * scale_factor)
        self.ax.set_ylim(cy - (cy - y_lo) * scale_factor,
                         cy + (y_hi - cy) * scale_factor)

        self.fig.canvas.draw_idle()

    def on_press(self, event):
        """Start a pan when the right mouse button is pressed in the axes."""
        if event.inaxes is not self.ax or event.button != 3:
            return
        if event.xdata is None or event.ydata is None:
            return
        self._pan_start = (event.xdata, event.ydata)

    def on_release(self, event):
        """End any pan in progress when a mouse button is released."""
        self._pan_start = None

    def on_motion(self, event):
        """Shift the view while a right-button pan is in progress."""
        if self._pan_start is None or event.inaxes is not self.ax:
            return
        if event.xdata is None or event.ydata is None:
            return

        # Displacement of the cursor from the pan origin, in data units.
        dx = event.xdata - self._pan_start[0]
        dy = event.ydata - self._pan_start[1]

        # get_xlim/get_ylim return tuples, so shift each bound explicitly.
        x_lo, x_hi = self.ax.get_xlim()
        y_lo, y_hi = self.ax.get_ylim()

        self.ax.set_xlim(x_lo - dx, x_hi - dx)
        self.ax.set_ylim(y_lo - dy, y_hi - dy)

        self.fig.canvas.draw_idle()


if __name__ == "__main__":
    # Example usage:
    #     python <this file> [IMAGE_DIR [OUTPUT_DIR]]
    # Both directories are optional positional arguments; when omitted, the
    # example paths below are used.
    default_image_dir = r"G:\error_data\20240930091614\project\images"
    default_output_dir = r"C:\datasets\ODM_output\error1_L"

    cli_args = sys.argv[1:]
    selector = GPSSelector(
        image_dir=cli_args[0] if len(cli_args) > 0 else default_image_dir,
        output_dir=cli_args[1] if len(cli_args) > 1 else default_output_dir,
    )
    selector.run()