From 5e0d4382805ebf2824f1f2f71850f576ad4721cb Mon Sep 17 00:00:00 2001 From: weixin_46229132 Date: Wed, 14 May 2025 20:45:42 +0800 Subject: [PATCH] first commit --- data_preprocessing/cut_smalltif.py | 138 ++++++++++ data_preprocessing/cut_smalltif_multi.py | 169 ++++++++++++ data_preprocessing/sift_tif.py | 54 ++++ data_preprocessing/split_data.py | 21 ++ geotiff_utils.py | 296 ++++++++++++++++++++ train_LS.py | 156 +++++++++++ utils.py | 330 +++++++++++++++++++++++ 7 files changed, 1164 insertions(+) create mode 100644 data_preprocessing/cut_smalltif.py create mode 100644 data_preprocessing/cut_smalltif_multi.py create mode 100644 data_preprocessing/sift_tif.py create mode 100644 data_preprocessing/split_data.py create mode 100644 geotiff_utils.py create mode 100644 train_LS.py create mode 100644 utils.py diff --git a/data_preprocessing/cut_smalltif.py b/data_preprocessing/cut_smalltif.py new file mode 100644 index 0000000..8fb8873 --- /dev/null +++ b/data_preprocessing/cut_smalltif.py @@ -0,0 +1,138 @@ +import numpy as np +from osgeo import gdal + + +def read_tif(fileName): + dataset = gdal.Open(fileName) + + im_width = dataset.RasterXSize # 栅格矩阵的列数 + im_height = dataset.RasterYSize # 栅格矩阵的行数 + im_bands = dataset.RasterCount # 波段数 + im_data = dataset.ReadAsArray().astype(np.float32) # 获取数据 + if len(im_data.shape) == 2: + im_data = im_data[np.newaxis, :] + im_geotrans = dataset.GetGeoTransform() # 获取仿射矩阵信息 + im_proj = dataset.GetProjection() # 获取投影信息 + + return im_data, im_width, im_height, im_bands, im_geotrans, im_proj + + +def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj): + if 'int8' in im_data.dtype.name: + datatype = gdal.GDT_Byte + elif 'int16' in im_data.dtype.name: + datatype = gdal.GDT_UInt16 + else: + datatype = gdal.GDT_Float32 + + if len(im_data.shape) == 3: + im_bands, im_height, im_width = im_data.shape + else: + im_bands, (im_height, im_width) = 1, im_data.shape + # 创建文件 + driver = gdal.GetDriverByName("GTiff") + dataset = driver.Create(path, im_width, im_height, im_bands, datatype) + if dataset != None and im_geotrans != None and im_proj != None: + dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 + dataset.SetProjection(im_proj) # 写入投影 + for i in range(im_bands): + dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) + del dataset + + +fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif' +im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif( + fileName) + +fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif' +im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif( + fileName2) + +mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif' +mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif( + mask_fileName) +mask_im_data = np.int8(mask_im_data) + + +# geotiff归一化 +for i in range(im_bands): + arr = im_data[i, :, :] + Min = arr.min() + Max = arr.max() + normalized_arr = (arr-Min)/(Max-Min)*255 + im_data[i] = normalized_arr + +for i in range(im_bands2): + arr = im_data2[i, :, :] + Min = arr.min() + Max = arr.max() + normalized_arr = (arr-Min)/(Max-Min)*255 + im_data2[i] = normalized_arr + + +# 计算大图每个波段的均值和方差,train.py里transform会用到 +im_data = im_data/255 +for i in range(im_bands): + pixels = im_data[i, :, :].ravel() + print("波段{} mean: {:.4f}, std: {:.4f}".format( + i, np.mean(pixels), np.std(pixels))) +im_data = im_data*255 + +im_data2 = im_data2/255 +for i in range(im_bands2): + pixels = im_data2[i, :, :].ravel() + print("波段{} mean: {:.4f}, std: {:.4f}".format( + i, np.mean(pixels), np.std(pixels))) +im_data2 = im_data2*255 + + +# 切成小图 +a = 0 +size = 224 +for i in range(0, int(mask_im_height / size)): + for j in range(0, int(mask_im_width / size)): + im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size] + im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size] + mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size] + + # 以mask为判断基准,同时处理geotiff和mask + labelfla = np.array(mask_cut).flatten() + if np.all(labelfla == 15): # 15为NoData + print("Skip!!!") + else: + # 5m + left_h = i * size * im_geotrans[5] + im_geotrans[3] + left_w = j * size * im_geotrans[1] + im_geotrans[0] + new_geotrans = np.array(im_geotrans) + new_geotrans[0] = left_w + new_geotrans[3] = left_h + out_geotrans = tuple(new_geotrans) + + im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif' + write_tif(im_cut, size, size, im_out, out_geotrans, im_proj) + + # dan + left_h = i * size * im_geotrans2[5] + im_geotrans2[3] + left_w = j * size * im_geotrans2[1] + im_geotrans2[0] + new_geotrans = np.array(im_geotrans2) + new_geotrans[0] = left_w + new_geotrans[3] = left_h + out_geotrans = tuple(new_geotrans) + + im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif' + write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2) + + # 存mask + mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3] + mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0] + mask_new_geotrans = np.array(mask_im_geotrans) + mask_new_geotrans[0] = mask_left_w + mask_new_geotrans[3] = mask_left_h + mask_out_geotrans = tuple(mask_new_geotrans) + mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif' + write_tif(mask_cut, size, size, mask_out, + mask_out_geotrans, mask_im_proj) + + print(mask_out + 'Cut to complete') + + a = a+1 diff --git a/data_preprocessing/cut_smalltif_multi.py b/data_preprocessing/cut_smalltif_multi.py new file mode 100644 index 0000000..6a29e46 --- /dev/null +++ b/data_preprocessing/cut_smalltif_multi.py @@ -0,0 +1,169 @@ +import numpy as np +from osgeo import gdal + + +def read_tif(fileName): + dataset = gdal.Open(fileName) + + im_width = dataset.RasterXSize # 栅格矩阵的列数 + im_height = dataset.RasterYSize # 栅格矩阵的行数 + im_bands = dataset.RasterCount # 波段数 + im_data = dataset.ReadAsArray().astype(np.float32) # 获取数据 + if len(im_data.shape) == 2: + im_data = im_data[np.newaxis, :] + im_geotrans = dataset.GetGeoTransform() # 获取仿射矩阵信息 + im_proj = dataset.GetProjection() # 获取投影信息 + + return im_data, im_width, im_height, im_bands, im_geotrans, im_proj + + +def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj): + if 'int8' in im_data.dtype.name: + datatype = gdal.GDT_Byte + elif 'int16' in im_data.dtype.name: + datatype = gdal.GDT_UInt16 + else: + datatype = gdal.GDT_Float32 + + if len(im_data.shape) == 3: + im_bands, im_height, im_width = im_data.shape + else: + im_bands, (im_height, im_width) = 1, im_data.shape + # 创建文件 + driver = gdal.GetDriverByName("GTiff") + dataset = driver.Create(path, im_width, im_height, im_bands, datatype) + if dataset != None and im_geotrans != None and im_proj != None: + dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 + dataset.SetProjection(im_proj) # 写入投影 + for i in range(im_bands): + dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) + del dataset + + +fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right3.tif' +im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif( + fileName) + +fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif' +im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif( + fileName2) + +# fileName3 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj10m.tif' +# im_data3, im_width3, im_height3, im_bands3, im_geotrans3, im_proj3 = read_tif( +# fileName3) + +# fileName4 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj20m.tif' +# im_data4, im_width4, im_height4, im_bands4, im_geotrans4, im_proj4 = read_tif( +# fileName4) + +mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif' +mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif( + mask_fileName) +mask_im_data = np.int8(mask_im_data) + +# geotiff归一化 +for i in range(im_bands): + arr = im_data[i, :, :] + Min = arr.min() + Max = arr.max() + normalized_arr = (arr-Min)/(Max-Min)*255 + im_data[i] = normalized_arr + +for i in range(im_bands2): + arr = im_data2[i, :, :] + Min = arr.min() + Max = arr.max() + normalized_arr = (arr-Min)/(Max-Min)*255 + im_data2[i] = normalized_arr + +# # 计算大图每个波段的均值和方差,train.py里transform会用到 +# im_data2 = im_data2/255 +# for i in range(im_bands2): +# pixels = im_data2[i, :, :].ravel() +# print("波段{} mean: {:.4f}, std: {:.4f}".format( +# i, np.mean(pixels), np.std(pixels))) + + +# 切成小图 +a = 0 +size = 224 +for i in range(0, int(im_height / size)): + for j in range(0, int(im_width / size)): + im_cut = im_data[:, i * size*4:i * size*4 + + size*4, j * size*4:j * size*4 + size*4] + im_cut2 = im_data2[:, i * size*4:i * size*4 + + size*4, j * size*4:j * size*4 + size*4] + # im_cut3 = im_data3[:, i * size*2:i * size*2 + + # size*2, j * size*2:j * size*2 + size*2] + # im_cut4 = im_data4[:, i * size:i * size + + # size, j * size:j * size + size] + mask_cut = mask_im_data[:, i * size*4:i * + size*4 + size*4, j * size*4:j * size*4 + size*4] + + # 以20m为判断基准,同时处理geotiff和mask + labelfla_bool = np.all(np.array(mask_cut).flatten() == 15) + + if labelfla_bool: + print("False") + else: + left_h = i * size*4 * im_geotrans[5] + im_geotrans[3] + left_w = j * size*4 * im_geotrans[1] + im_geotrans[0] + new_geotrans = np.array(im_geotrans) + new_geotrans[0] = left_w + new_geotrans[3] = left_h + out_geotrans = tuple(new_geotrans) + im_out = 'E:/RSdata/mask_5m/dataset_5m/geotiff' + \ + str(a) + '.tif' + write_tif(im_cut, size*4, size*4, im_out, out_geotrans, im_proj) + print(im_out + 'Cut to complete') + + left_h2 = i * size*4 * im_geotrans2[5] + im_geotrans2[3] + left_w2 = j * size*4 * im_geotrans2[1] + im_geotrans2[0] + new_geotrans = np.array(im_geotrans2) + new_geotrans[0] = left_w2 + new_geotrans[3] = left_h2 + out_geotrans = tuple(new_geotrans) + + im_out = 'E:/RSdata/mask_5m/dataset_dan/geotiff' + \ + str(a) + '.tif' + write_tif(im_cut2, size*4, size*4, im_out, out_geotrans, im_proj2) + print(im_out + 'Cut to complete') + + # left_h3 = i * size*2 * im_geotrans3[5] + im_geotrans3[3] + # left_w3 = j * size*2 * im_geotrans3[1] + im_geotrans3[0] + # new_geotrans = np.array(im_geotrans3) + # new_geotrans[0] = left_w3 + # new_geotrans[3] = left_h3 + # out_geotrans = tuple(new_geotrans) + + # im_out = 'E:/mask_20m/all/dataset_10m/geotiff' + \ + # str(a) + '.tif' + # write_tif(im_cut3, size*2, size*2, im_out, out_geotrans, im_proj3) + # print(im_out + 'Cut to complete') + + # left_h4 = i * size * im_geotrans4[5] + im_geotrans4[3] + # left_w4 = j * size * im_geotrans4[1] + im_geotrans4[0] + # new_geotrans = np.array(im_geotrans4) + # new_geotrans[0] = left_w4 + # new_geotrans[3] = left_h4 + # out_geotrans = tuple(new_geotrans) + + # im_out = 'E:/mask_20m/all/dataset_20m/geotiff' + \ + # str(a) + '.tif' + # write_tif(im_cut4, size, size, im_out, out_geotrans, im_proj4) + # print(im_out + 'Cut to complete') + + mask_left_h = i * size*4 * \ + mask_im_geotrans[5] + mask_im_geotrans[3] + mask_left_w = j * size*4 * \ + mask_im_geotrans[1] + mask_im_geotrans[0] + mask_new_geotrans = np.array(mask_im_geotrans) + mask_new_geotrans[0] = mask_left_w + mask_new_geotrans[3] = mask_left_h + mask_out_geotrans = tuple(mask_new_geotrans) + mask_out = 'E:/RSdata/mask_5m/mask/geotiff' + str(a) + '.tif' + write_tif(mask_cut, size*4, size*4, mask_out, + mask_out_geotrans, mask_im_proj) + print(mask_out + 'Cut to complete') + + a = a+1 diff --git a/data_preprocessing/sift_tif.py b/data_preprocessing/sift_tif.py new file mode 100644 index 0000000..84392cf --- /dev/null +++ b/data_preprocessing/sift_tif.py @@ -0,0 +1,54 @@ +from osgeo import gdal +import zipfile +import os +from xml.dom.minidom import parseString +import shutil + + +# 搜索所有硬盘里所有.zip文件 +def get_all_zipfiles(root): + zip_lt = [] + for root2, dirs, files in os.walk(root): + for file in files: + if file[-4:] == '.zip': + zip_lt.append(os.path.join(root2, file)) + return zip_lt + +# 定位并读取.zip文件 +# zip_dir = 'F:/2101-2400/' # 只处理某一个文件夹 +# zip_files_list = os.listdir(zip_dir) + + +zip_dir = 'F:\\' +zip_files_list = get_all_zipfiles(zip_dir) + +# 读取.xml文件 +for file_name in zip_files_list: + # print(file_name) + try: + file = zipfile.ZipFile(file_name, "r") + except: + continue + info_list = file.infolist() + for info in info_list: + if info.filename[-4:] == ".xml": + xml_file = file.read(info.filename) + + # 解析.xml文件 + domTree = parseString(xml_file) + rootNode = domTree.documentElement + infos = rootNode.getElementsByTagName("ProductInfo")[0] + + # 筛选西北地区影像 + CenterLatitude = eval(infos.getElementsByTagName( + "CenterLatitude")[0].childNodes[0].data) + CenterLongitude = eval(infos.getElementsByTagName( + "CenterLongitude")[0].childNodes[0].data) + + # 西北地区四个边界信息: 最高纬度:50 ;最低维度:37 ;最高经度:123 ;最低经度:73 ; + if 37 < CenterLatitude < 44: + if 76 < CenterLongitude < 78: + print(file) + + # 将文件拷贝至新的文件夹中 + shutil.copy(file_name,'E:/wlk_test/') diff --git a/data_preprocessing/split_data.py b/data_preprocessing/split_data.py new file mode 100644 index 0000000..762590f --- /dev/null +++ b/data_preprocessing/split_data.py @@ -0,0 +1,21 @@ +import os +import random + + +random.seed(42) + +geotiffs = os.listdir('E:\RSdata\wlk_right_224_2\dataset_5m') +num = len(geotiffs) +split_rate = 0.2 + +eval_index = random.sample(geotiffs, k=int(num*split_rate)) + +f_train = open('E:\RSdata\wlk_right_224_2/train.txt', 'w') +f_val = open('E:\RSdata\wlk_right_224_2/val.txt', 'w') + +# 写入文件 +for geotiff in geotiffs: + if geotiff in eval_index: + f_val.write(str(geotiff)+'\n') + else: + f_train.write(str(geotiff)+'\n') \ No newline at end of file diff --git a/geotiff_utils.py b/geotiff_utils.py new file mode 100644 index 0000000..cf7e169 --- /dev/null +++ b/geotiff_utils.py @@ -0,0 +1,296 @@ + +"""Pascal VOC Semantic Segmentation Dataset.""" +from PIL import Image, ImageOps, ImageFilter +import torchvision.transforms as transforms +import os +import torch +import numpy as np +from matplotlib import pyplot as plt +from PIL import Image +import cv2 +# import gdal +from osgeo import gdal +import random +import torch.utils.data as data +os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000') + + +class SegmentationDataset(object): + """Segmentation Base Dataset""" + + def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): + super(SegmentationDataset, self).__init__() + self.root = root + self.transform = transform + self.split = split + self.mode = mode if mode is not None else split + self.base_size = base_size + self.crop_size = crop_size + + def _val_sync_transform(self, img, mask): + outsize = self.crop_size + short_size = outsize + w, h = img.size + if w > h: + oh = short_size + ow = int(1.0 * w * oh / h) + else: + ow = short_size + oh = int(1.0 * h * ow / w) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # center crop + w, h = img.size + x1 = int(round((w - outsize) / 2.)) + y1 = int(round((h - outsize) / 2.)) + img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) + mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize)) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform(self, img, mask): + # random mirror + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + crop_size = self.crop_size + # random scale (short edge) + short_size = random.randint( + int(self.base_size * 0.5), int(self.base_size * 2.0)) + w, h = img.size + if h > w: + ow = short_size + oh = int(1.0 * h * ow / w) + else: + oh = short_size + ow = int(1.0 * w * oh / h) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # pad crop + if short_size < crop_size: + padh = crop_size - oh if oh < crop_size else 0 + padw = crop_size - ow if ow < crop_size else 0 + img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) + mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) + # random crop crop_size + w, h = img.size + x1 = random.randint(0, w - crop_size) + y1 = random.randint(0, h - crop_size) + img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + # gaussian blur as in PSP + if random.random() < 0.5: + img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform_tif(self, img, mask): + # random mirror + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform_tif_geofeat(self, img, mask, img_feat): + # random mirror + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + img_feat = self._img_transform(img_feat) + return img, mask, img_feat + + def _val_sync_transform_tif(self, img, mask): + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _img_transform(self, img): + return np.array(img) + + # def _mask_transform(self, mask): + # return np.array(mask).astype('int32') + + def _mask_transform(self, mask): + target = np.array(mask).astype('int32') + # target = target[np.newaxis, :] + target[target > 12] = 255 + return torch.from_numpy(target).long() + + @property + def num_class(self): + """Number of categories.""" + return self.NUM_CLASS + + @property + def pred_offset(self): + return 0 + + +class VOCYJSSegmentation(SegmentationDataset): + """Pascal VOC Semantic Segmentation Dataset. + Parameters + ---------- + root : string + Path to VOCdevkit folder. Default is './datasets/VOCdevkit' + split: string + 'train', 'val' or 'test' + transform : callable, optional + A function that transforms the image + Examples + -------- + >>> from torchvision import transforms + >>> import torch.utils.data as data + >>> # Transforms for Normalization + >>> input_transform = transforms.Compose([ + >>> transforms.ToTensor(), + >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), + >>> ]) + >>> # Create Dataset + >>> trainset = VOCSegmentation(split='train', transform=input_transform) + >>> # Create Training Loader + >>> train_data = data.DataLoader( + >>> trainset, 4, shuffle=True, + >>> num_workers=4) + """ + NUM_CLASS = 13 + + def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs): + super(VOCYJSSegmentation, self).__init__( + root, split, mode, transform, **kwargs) + _voc_root = root + txt_path = os.path.join(root, split+'.txt') + self._mask_dir = os.path.join(_voc_root, 'masks_GE') + self._image_dir = os.path.join(_voc_root, 'images_SE') + self._mask_LS_dir = os.path.join(_voc_root, 'masks_LS') + self._image_LS_dir = os.path.join(_voc_root, "images_LS") + self.image_list = read_text(txt_path) + self.transform_SE = transforms.Compose([transforms.ToTensor(), transforms.Normalize( + [.485, .456, .406, .485, .456, .406, .406], [.229, .224, .225, .229, .224, .225, .229]),]) + random.shuffle(self.image_list) + + def __getitem__(self, index): + # print( "image file path is %s "% self.images[index]) + + # 读取两种类型的图片 + img_HR = gdal.Open(os.path.join(self._image_dir, self.image_list[index])).ReadAsArray( + ).transpose(1, 2, 0).astype(np.float32) + img_LS = gdal.Open(os.path.join(self._image_LS_dir, self.image_list[index])).ReadAsArray( + ).transpose(1, 2, 0).astype(np.float32) + # img_LS = cv2.resize(img_LS,(672,672),interpolation=cv2.INTER_CUBIC) + # 读取两种类型的标注 + mask_HR = gdal.Open(os.path.join( + self._mask_dir, self.image_list[index])).ReadAsArray() + mask = gdal.Open(os.path.join(self._mask_LS_dir, + self.image_list[index])).ReadAsArray() + # synchronized transform + # 只包含两种模式: train 和 val + if self.mode == 'train': + # img, mask = self._sync_transform_tif(img, mask) + img_LS, mask, img_HR = self._sync_transform_tif_geofeat( + img_LS, mask, img_HR) + elif self.mode == 'val': + # img, mask = self._val_sync_transform_tif(img, mask) + img_LS, mask, img_HR = self._sync_transform_tif_geofeat( + img_LS, mask, img_HR) + # general resize, normalize and toTensor + if self.transform is not None: + img_HR = cv2.resize(img_HR, (448, 448), + interpolation=cv2.INTER_CUBIC) + img_HR = self.transform_SE(img_HR) + img_LS = self.transform(img_LS) + # img_feat = torch.from_numpy(img_feat) + # 多返回了一个img_feat + # ,transforms.ToTensor()(img_feat), os.path.basename(self.images[index]) + return img_LS, img_HR, mask + + def __len__(self): + return len(self.image_list) + + def _mask_transform(self, mask): + target = np.array(mask).astype('int32') + # target = target[np.newaxis, :] + target[target > 12] = 255 + return torch.from_numpy(target).long() + + @property + def classes(self): + """Category names.""" + return ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12') + + +def generator_list_of_imagepath(path): + image_list = [] + for image in os.listdir(path): + # print(path) + # print(image) + if not image == '.DS_Store' and 'tif' == image.split('.')[-1]: + image_list.append(image) + return image_list + + +def read_text(textfile): + list = [] + with open(textfile, "r") as lines: + for line in lines: + list.append(line.rstrip('\n')) + return list + + +def dataset_segmentation(textpath, imagepath, train_percent): + image_list = generator_list_of_imagepath(imagepath) + num = len(image_list) + list = range(num) + train_num = int(num * train_percent) # training set num + train_list = random.sample(list, train_num) + print("train set size", train_num) + ftrain = open(os.path.join(textpath, 'train.txt'), 'w') + fval = open(os.path.join(textpath, 'val.txt'), 'w') + for i in list: + name = image_list[i] + '\n' + if i in train_list: + ftrain.write(name) + else: + fval.write(name) + ftrain.close() + fval.close() + + +if __name__ == '__main__': + # path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images' + # list=generator_list_of_imagepath(path) + # print(list) + # 切割数据集 + + textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset' + imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE' + train_percent = 0.8 + dataset_segmentation(textpath, imagepath, train_percent) + # 显示各种图片 + + # img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif' + # img = gdal.Open(img).ReadAsArray().transpose(1,2,0) + # cv2.imshow('img', img) + # img = Image.fromarray (img,'RGB') + # img.show() + # img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif' + # img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8) + # img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC) + # img2 = Image.fromarray (img2,'RGB') + # img2.show() + # img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif' + # img3 = gdal.Open(img3).ReadAsArray() + # img3 = Image.fromarray (img3) + # img3.show() + + # dataset和dataloader的测试 + + # 测试dataloader能不能用 + ''' + data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset' + input_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])]) + dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224) + dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224) + train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4) + test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4) + ''' diff --git a/train_LS.py b/train_LS.py new file mode 100644 index 0000000..cf9f95a --- /dev/null +++ b/train_LS.py @@ -0,0 +1,156 @@ +import os +import time +from datetime import datetime +from pathlib import Path + +import torch +from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large +from torchvision import transforms +import torch.utils.data as data +from torch import nn +import numpy as np + +from geotiff_utils import VOCYJSSegmentation +import utils +import warnings +warnings.filterwarnings("ignore") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="pytorch deeplabv3 training") + + parser.add_argument( + "--data-path", default="E:/repository/DeepLearning23/datasets/WLKdata_1111/WLKdataset", help="VOCdevkit root") + parser.add_argument("--num-classes", default=13, type=int) + parser.add_argument("--device", default="cuda", help="training device") + parser.add_argument("--batch-size", default=8, type=int) + parser.add_argument("--epochs", default=50, type=int, metavar="N", + help="number of total epochs to train") + parser.add_argument('--lr', default=0.005, type=float, + help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') + parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + parser.add_argument('-out-dir', type=str, + default='DeeplabV3_LS') + + args = parser.parse_args() + + return args + + +class DeeplabV3_LS(nn.Module): + def __init__(self, n_class): + super(DeeplabV3_LS, self).__init__() + self.n_class = n_class + self.conv7_3 = nn.Conv2d(7, 3, kernel_size=1, stride=1) + + self.conv_fc = nn.Conv2d( + 21, self.n_class, kernel_size=(1, 1), stride=(1, 1)) + + self.seg = deeplabv3_resnet50(weights='DEFAULT') + + def forward(self, x_LS, x_SE): + # x = torch.cat([x, x_dan], dim=1) + # x = self.conv7_3(x) + x = self.seg(x_LS)["out"] + x = self.conv_fc(x) + return x + + +def main(args): + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + + input_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([.485, .456, .406], [.229, .224, .225]), + ]) + data_kwargs = {'transform': input_transform, + 'base_size': 224, 'crop_size': 224} + + # 读取geotiff数据,构建训练集、验证集 + train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train', + **data_kwargs) + val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val', + **data_kwargs) + num_workers = min( + [os.cpu_count(), args.batch_size if args.batch_size > 1 else 0]) + train_loader = data.DataLoader( + train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True) + val_loader = data.DataLoader( + val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True) + + model = DeeplabV3_LS(n_class=args.num_classes) + model.to(device) + + criterion = torch.nn.CrossEntropyLoss(ignore_index=255) + optimizer = torch.optim.SGD( + model.parameters(), + lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + + lr_scheduler = utils.create_lr_scheduler( + optimizer, len(train_loader), args.epochs, warmup=True) + + now = datetime.now() + date_time = now.strftime("%Y-%m-%d__%H-%M__") + out_dir = Path(os.path.join("./train_output", date_time + args.out_dir)) + if not out_dir.exists(): + out_dir.mkdir() + f = open(os.path.join(out_dir, "log.txt"), 'w') + start_time = time.time() + best_acc = 0 + for epoch in range(args.epochs): + print(f"Epoch {epoch+1}\n-------------------------------") + model.train() + for idx, (image, image_dan, target) in enumerate(train_loader): + image, image_dan, target = image.to( + device), image_dan.to(device), target.to(device) + output = model(image, image_dan) + + loss = criterion(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + if idx % 100 == 0: + print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx, + len(train_loader), loss.item(), optimizer.param_groups[0]["lr"])) + + model.eval() + confmat = utils.ConfusionMatrix(args.num_classes) + with torch.no_grad(): + for image, image_dan, target in val_loader: + image, image_dan, target = image.to(device), image_dan.to( + device), target.to(device) + output = model(image, image_dan) + + confmat.update(target.flatten(), output.argmax(1).flatten()) + + info, mIoU = confmat.get_info() + print(info) + + f.write(f"Epoch {epoch+1}\n-------------------------------\n") + f.write(info+"\n\n") + f.flush() + + # # 保存准确率最好的模型 + # if mIoU > best_acc: + # print("[Save model]") + # torch.save(model, os.path.join(out_dir, "best_mIoU.pth")) + # best_acc = mIoU + torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth")) + total_time = time.time() - start_time + print("total time:", total_time) + torch.save(model, os.path.join(out_dir, "last.pth")) + + +if __name__ == '__main__': + args = parse_args() + + main(args) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..b721720 --- /dev/null +++ b/utils.py @@ -0,0 +1,330 @@ +import datetime +import errno +import os +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + t = reduce_across_processes([self.count, self.total]) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) + + +class ConfusionMatrix: + def __init__(self, num_classes): + self.num_classes = num_classes + self.mat = None + + def update(self, a, b): + n = self.num_classes + if self.mat is None: + self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) + with torch.inference_mode(): + k = (a >= 0) & (a < n) + inds = n * a[k].to(torch.int64) + b[k] + self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) + + def reset(self): + self.mat.zero_() + + def compute(self): + h = self.mat.float() + acc_global = torch.diag(h).sum() / h.sum() + acc = torch.diag(h) / h.sum(1) + iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) + return acc_global, acc, iu + + def reduce_from_all_processes(self): + reduce_across_processes(self.mat) + + def get_info(self): + acc_global, acc, iu = self.compute() + return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format( + acc_global.item() * 100, + [f"{i:.1f}" for i in (acc * 100).tolist()], + [f"{i:.1f}" for i in (iu * 100).tolist()], + iu.mean().item() * 100, + ), iu.mean().item() * 100 + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + if not isinstance(v, (float, int)): + raise TypeError( + f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}" + ) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{attr}'") + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {str(meter)}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", + "{meters}", "time: {time}", "data: {data}"] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print(f"{header} Total time: {total_time_str}") + + +def cat_list(images, fill_value=0): + max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) + batch_shape = (len(images),) + max_size + batched_imgs = images[0].new(*batch_shape).fill_(fill_value) + for img, pad_img in zip(images, batched_imgs): + pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img) + return batched_imgs + + +def collate_fn(batch): + images, targets = list(zip(*batch)) + batched_imgs = cat_list(images, fill_value=0) + batched_targets = cat_list(targets, fill_value=255) + return batched_imgs, batched_targets + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + # elif "SLURM_PROCID" in os.environ: + # args.rank = int(os.environ["SLURM_PROCID"]) + # args.gpu = args.rank % torch.cuda.device_count() + elif hasattr(args, "rank"): + pass + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def reduce_across_processes(val): + if not is_dist_avail_and_initialized(): + # nothing to sync, but we still convert to tensor for consistency with the distributed case. + return torch.tensor(val) + + t = torch.tensor(val, device="cuda") + dist.barrier() + dist.all_reduce(t) + return t + + +def create_lr_scheduler(optimizer, + num_step: int, + epochs: int, + warmup=True, + warmup_epochs=1, + warmup_factor=1e-3): + assert num_step > 0 and epochs > 0 + if warmup is False: + warmup_epochs = 0 + + def f(x): + """ + 根据step数返回一个学习率倍率因子, + 注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法 + """ + if warmup is True and x <= (warmup_epochs * num_step): + alpha = float(x) / (warmup_epochs * num_step) + # warmup过程中lr倍率因子从warmup_factor -> 1 + return warmup_factor * (1 - alpha) + alpha + else: + # warmup后lr倍率因子从1 -> 0 + # 参考deeplab_v2: Learning rate policy + return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9 + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f) \ No newline at end of file