diff --git a/.gitignore b/.gitignore index 5d381cc..0d56ad3 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +/train_output \ No newline at end of file diff --git a/data_preprocessing/cut_smalltif.bak b/data_preprocessing/cut_smalltif.bak new file mode 100644 index 0000000..8fb8873 --- /dev/null +++ b/data_preprocessing/cut_smalltif.bak @@ -0,0 +1,138 @@ +import numpy as np +from osgeo import gdal + + +def read_tif(fileName): + dataset = gdal.Open(fileName) + + im_width = dataset.RasterXSize # 栅格矩阵的列数 + im_height = dataset.RasterYSize # 栅格矩阵的行数 + im_bands = dataset.RasterCount # 波段数 + im_data = dataset.ReadAsArray().astype(np.float32) # 获取数据 + if len(im_data.shape) == 2: + im_data = im_data[np.newaxis, :] + im_geotrans = dataset.GetGeoTransform() # 获取仿射矩阵信息 + im_proj = dataset.GetProjection() # 获取投影信息 + + return im_data, im_width, im_height, im_bands, im_geotrans, im_proj + + +def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj): + if 'int8' in im_data.dtype.name: + datatype = gdal.GDT_Byte + elif 'int16' in im_data.dtype.name: + datatype = gdal.GDT_UInt16 + else: + datatype = gdal.GDT_Float32 + + if len(im_data.shape) == 3: + im_bands, im_height, im_width = im_data.shape + else: + im_bands, (im_height, im_width) = 1, im_data.shape + # 创建文件 + driver = gdal.GetDriverByName("GTiff") + dataset = driver.Create(path, im_width, im_height, im_bands, datatype) + if dataset != None and im_geotrans != None and im_proj != None: + dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数 + dataset.SetProjection(im_proj) # 写入投影 + for i in range(im_bands): + dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) + del dataset + + +fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif' +im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif( + fileName) + +fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif' +im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif( + fileName2) + +mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif' +mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif( + mask_fileName) +mask_im_data = np.int8(mask_im_data) + + +# geotiff归一化 +for i in range(im_bands): + arr = im_data[i, :, :] + Min = arr.min() + Max = arr.max() + normalized_arr = (arr-Min)/(Max-Min)*255 + im_data[i] = normalized_arr + +for i in range(im_bands2): + arr = im_data2[i, :, :] + Min = arr.min() + Max = arr.max() + normalized_arr = (arr-Min)/(Max-Min)*255 + im_data2[i] = normalized_arr + + +# 计算大图每个波段的均值和方差,train.py里transform会用到 +im_data = im_data/255 +for i in range(im_bands): + pixels = im_data[i, :, :].ravel() + print("波段{} mean: {:.4f}, std: {:.4f}".format( + i, np.mean(pixels), np.std(pixels))) +im_data = im_data*255 + +im_data2 = im_data2/255 +for i in range(im_bands2): + pixels = im_data2[i, :, :].ravel() + print("波段{} mean: {:.4f}, std: {:.4f}".format( + i, np.mean(pixels), np.std(pixels))) +im_data2 = im_data2*255 + + +# 切成小图 +a = 0 +size = 224 +for i in range(0, int(mask_im_height / size)): + for j in range(0, int(mask_im_width / size)): + im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size] + im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size] + mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size] + + # 以mask为判断基准,同时处理geotiff和mask + labelfla = np.array(mask_cut).flatten() + if np.all(labelfla == 15): # 15为NoData + print("Skip!!!") + else: + # 5m + left_h = i * size * im_geotrans[5] + im_geotrans[3] + left_w = j * size * im_geotrans[1] + im_geotrans[0] + new_geotrans = np.array(im_geotrans) + new_geotrans[0] = left_w + new_geotrans[3] = left_h + out_geotrans = tuple(new_geotrans) + + im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif' + write_tif(im_cut, size, size, im_out, out_geotrans, im_proj) + + # dan + left_h = i * size * im_geotrans2[5] + im_geotrans2[3] + left_w = j * size * im_geotrans2[1] + im_geotrans2[0] + new_geotrans = np.array(im_geotrans2) + new_geotrans[0] = left_w + new_geotrans[3] = left_h + out_geotrans = tuple(new_geotrans) + + im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif' + write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2) + + # 存mask + mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3] + mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0] + mask_new_geotrans = np.array(mask_im_geotrans) + mask_new_geotrans[0] = mask_left_w + mask_new_geotrans[3] = mask_left_h + mask_out_geotrans = tuple(mask_new_geotrans) + mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif' + write_tif(mask_cut, size, size, mask_out, + mask_out_geotrans, mask_im_proj) + + print(mask_out + 'Cut to complete') + + a = a+1 diff --git a/data_preprocessing/cut_smalltif.py b/data_preprocessing/cut_smalltif.py index 8fb8873..5290d2b 100644 --- a/data_preprocessing/cut_smalltif.py +++ b/data_preprocessing/cut_smalltif.py @@ -44,9 +44,6 @@ fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif' im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif( fileName) -fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif' -im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif( - fileName2) mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif' mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif( @@ -55,19 +52,14 @@ mask_im_data = np.int8(mask_im_data) # geotiff归一化 +lower_percent = 2 +upper_percent = 98 for i in range(im_bands): arr = im_data[i, :, :] - Min = arr.min() - Max = arr.max() - normalized_arr = (arr-Min)/(Max-Min)*255 - im_data[i] = normalized_arr - -for i in range(im_bands2): - arr = im_data2[i, :, :] - Min = arr.min() - Max = arr.max() - normalized_arr = (arr-Min)/(Max-Min)*255 - im_data2[i] = normalized_arr + lower = np.percentile(arr, lower_percent) + upper = np.percentile(arr, upper_percent) + stretched = np.clip((arr - lower) / (upper - lower), 0, 1) + im_data[i] = (stretched * 255).astype(np.uint8) # 计算大图每个波段的均值和方差,train.py里transform会用到 @@ -78,61 +70,42 @@ for i in range(im_bands): i, np.mean(pixels), np.std(pixels))) im_data = im_data*255 -im_data2 = im_data2/255 -for i in range(im_bands2): - pixels = im_data2[i, :, :].ravel() - print("波段{} mean: {:.4f}, std: {:.4f}".format( - i, np.mean(pixels), np.std(pixels))) -im_data2 = im_data2*255 - # 切成小图 a = 0 -size = 224 +size = 448 for i in range(0, int(mask_im_height / size)): for j in range(0, int(mask_im_width / size)): im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size] - im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size] - mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size] + mask_cut = mask_im_data[:, i * size:i * + size + size, j * size:j * size + size] - # 以mask为判断基准,同时处理geotiff和mask labelfla = np.array(mask_cut).flatten() if np.all(labelfla == 15): # 15为NoData print("Skip!!!") else: - # 5m - left_h = i * size * im_geotrans[5] + im_geotrans[3] - left_w = j * size * im_geotrans[1] + im_geotrans[0] - new_geotrans = np.array(im_geotrans) - new_geotrans[0] = left_w - new_geotrans[3] = left_h - out_geotrans = tuple(new_geotrans) + # 取5、4、3波段,注意顺序 + rgb_cut = np.stack([ + im_cut[4], # 第5波段 + im_cut[3], # 第4波段 + im_cut[2], # 第3波段 + ], axis=0) # shape: (3, H, W) + # 转为 (H, W, 3) + rgb_cut = np.transpose(rgb_cut, (1, 2, 0)) + # 归一化到0-255并转uint8 + rgb_cut = np.clip(rgb_cut, 0, 255) + rgb_cut = rgb_cut.astype(np.uint8) + # 保存为jpg + from PIL import Image + rgb_img = Image.fromarray(rgb_cut) + rgb_img.save( + f'E:/RSdata/wlk_right_448/dataset_5m_jpg/img_{a}.jpg') - im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif' - write_tif(im_cut, size, size, im_out, out_geotrans, im_proj) - - # dan - left_h = i * size * im_geotrans2[5] + im_geotrans2[3] - left_w = j * size * im_geotrans2[1] + im_geotrans2[0] - new_geotrans = np.array(im_geotrans2) - new_geotrans[0] = left_w - new_geotrans[3] = left_h - out_geotrans = tuple(new_geotrans) - - im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif' - write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2) - - # 存mask - mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3] - mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0] - mask_new_geotrans = np.array(mask_im_geotrans) - mask_new_geotrans[0] = mask_left_w - mask_new_geotrans[3] = mask_left_h - mask_out_geotrans = tuple(mask_new_geotrans) - mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif' - write_tif(mask_cut, size, size, mask_out, - mask_out_geotrans, mask_im_proj) - - print(mask_out + 'Cut to complete') + # mask只取第一个波段(如果是单通道),转uint8保存为png + mask_arr = mask_cut[0] if mask_cut.shape[0] == 1 else mask_cut + mask_arr = np.clip(mask_arr, 0, 255).astype(np.uint8) + mask_img = Image.fromarray(mask_arr) + mask_img.save(f'E:/RSdata/wlk_right_448/mask_png/mask_{a}.png') + print(f'img_{a}.jpg and mask_{a}.png saved') a = a+1 diff --git a/data_preprocessing/split_data.py b/data_preprocessing/split_data.py index 762590f..551bfcc 100644 --- a/data_preprocessing/split_data.py +++ b/data_preprocessing/split_data.py @@ -4,18 +4,18 @@ import random random.seed(42) -geotiffs = os.listdir('E:\RSdata\wlk_right_224_2\dataset_5m') +geotiffs = os.listdir(r'E:\RSdata\wlk_right_448\dataset_5m_jpg') num = len(geotiffs) split_rate = 0.2 eval_index = random.sample(geotiffs, k=int(num*split_rate)) -f_train = open('E:\RSdata\wlk_right_224_2/train.txt', 'w') -f_val = open('E:\RSdata\wlk_right_224_2/val.txt', 'w') +f_train = open(r'E:\RSdata\wlk_right_448/train.txt', 'w') +f_val = open(r'E:\RSdata\wlk_right_448/val.txt', 'w') # 写入文件 for geotiff in geotiffs: if geotiff in eval_index: - f_val.write(str(geotiff)+'\n') + f_train.write(str(geotiff)+'\n') else: - f_train.write(str(geotiff)+'\n') \ No newline at end of file + f_val.write(str(geotiff)+'\n') diff --git a/datasets_pro_code/datasets_pro.py b/datasets_pro_code/datasets_pro.py new file mode 100644 index 0000000..92efbc3 --- /dev/null +++ b/datasets_pro_code/datasets_pro.py @@ -0,0 +1,62 @@ +import os +import shutil + +# 定义文件夹路径 +base_dir = r"e:\datasets\wlk_right_448" +jpeg_images_dir = os.path.join(base_dir, "JPEGImages") +segmentation_class_dir = os.path.join(base_dir, "SegmentationClass") +annotations_dir = os.path.join(base_dir, "stare_stuct", "annotations") +images_dir = os.path.join(base_dir, "stare_stuct", "images") + +# 定义目标文件夹 +annotations_training_dir = os.path.join(annotations_dir, "training") +annotations_validation_dir = os.path.join(annotations_dir, "validation") +images_training_dir = os.path.join(images_dir, "training") +images_validation_dir = os.path.join(images_dir, "validation") + +# 创建目标文件夹 +os.makedirs(annotations_training_dir, exist_ok=True) +os.makedirs(annotations_validation_dir, exist_ok=True) +os.makedirs(images_training_dir, exist_ok=True) +os.makedirs(images_validation_dir, exist_ok=True) + +# 读取 train.txt 和 val.txt +train_file = os.path.join(base_dir, "train.txt") +val_file = os.path.join(base_dir, "val.txt") + + +def read_file_list(file_path): + with open(file_path, "r") as f: + return [line.strip() for line in f.readlines()] + + +train_list = read_file_list(train_file) +val_list = read_file_list(val_file) + +# 移动文件函数 + + +def move_files(file_list, src_images_dir, src_labels_dir, dst_images_dir, dst_labels_dir): + for file_name in file_list: + # 图片文件 + image_src = os.path.join(src_images_dir, file_name) + image_dst = os.path.join(dst_images_dir, file_name) + if os.path.exists(image_src): + shutil.copy(image_src, image_dst) + + # 标签文件 + label_src = os.path.join(src_labels_dir, file_name) + label_dst = os.path.join(dst_labels_dir, file_name) + if os.path.exists(label_src): + shutil.copy(label_src, label_dst) + + +# 移动训练集文件 +move_files(train_list, jpeg_images_dir, segmentation_class_dir, + images_training_dir, annotations_training_dir) + +# 移动验证集文件 +move_files(val_list, jpeg_images_dir, segmentation_class_dir, + images_validation_dir, annotations_validation_dir) + +print("文件组织完成!") diff --git a/datasets_pro_code/read_mask.py b/datasets_pro_code/read_mask.py new file mode 100644 index 0000000..1e8fa2a --- /dev/null +++ b/datasets_pro_code/read_mask.py @@ -0,0 +1,22 @@ +from PIL import Image +import numpy as np +import os +from osgeo import gdal + + +mask_dir = r"E:\datasets\wlk_right_448\mask" # 修改为你的mask文件夹路径 + +all_labels = set() + +for file in os.listdir(mask_dir): + if file.lower().endswith('.tif'): + tif_path = os.path.join(mask_dir, file) + dataset = gdal.Open(tif_path) + if dataset is None: + print(f"无法打开: {tif_path}") + continue + band = dataset.ReadAsArray() + unique = np.unique(band) + all_labels.update(unique) + +print("所有mask中出现过的标签数字:", sorted(all_labels)) diff --git a/datasets_pro_code/remove_tif.py b/datasets_pro_code/remove_tif.py new file mode 100644 index 0000000..63563d1 --- /dev/null +++ b/datasets_pro_code/remove_tif.py @@ -0,0 +1,6 @@ +input_path = r"E:\datasets\WLKdata_1111\WLK_voc\ImageSets\Segmentation\val.txt" +output_path = r"E:\datasets\WLKdata_1111\WLK_voc\ImageSets\Segmentation\val_no.txt" + +with open(input_path, "r", encoding="utf-8") as fin, open(output_path, "w", encoding="utf-8") as fout: + for line in fin: + fout.write(line.strip().replace(".tif", "") + "\n") \ No newline at end of file diff --git a/datasets_pro_code/tif_mask_to_png.py b/datasets_pro_code/tif_mask_to_png.py new file mode 100644 index 0000000..bb9b63f --- /dev/null +++ b/datasets_pro_code/tif_mask_to_png.py @@ -0,0 +1,30 @@ +import os +import numpy as np +from PIL import Image +from osgeo import gdal + +src_dir = r"E:\datasets\WLKdata_1111\WLKdataset\masks_LS" +dst_dir = r"E:\datasets\WLKdata_1111\WLKdataset\masks_LS_png" +os.makedirs(dst_dir, exist_ok=True) + +for file in os.listdir(src_dir): + if file.lower().endswith('.tif'): + tif_path = os.path.join(src_dir, file) + dataset = gdal.Open(tif_path) + if dataset is None: + print(f"无法打开: {tif_path}") + continue + mask = dataset.ReadAsArray() + if mask.ndim != 2: + print(f"{file} 不是单波段,跳过") + continue + + # 替换像素值 + mask = mask.copy() + mask[mask == 15] = 255 + + png_path = os.path.join(dst_dir, os.path.splitext(file)[0] + ".png") + Image.fromarray(mask.astype(np.uint8)).save(png_path) + print(f"已保存: {png_path}") + +print("全部转换完成!") \ No newline at end of file diff --git a/datasets_pro_code/tif_to_jpg.py b/datasets_pro_code/tif_to_jpg.py new file mode 100644 index 0000000..9f1e3b4 --- /dev/null +++ b/datasets_pro_code/tif_to_jpg.py @@ -0,0 +1,33 @@ +import os +import numpy as np +from PIL import Image +from osgeo import gdal + +# 输入和输出文件夹 +src_dir = r"E:\datasets\WLKdata_1111\WLKdataset\images_LS" +dst_dir = r"E:\datasets\WLKdata_1111\WLKdataset\images_LS_jpg" +os.makedirs(dst_dir, exist_ok=True) + +for file in os.listdir(src_dir): + if file.lower().endswith('.tif'): + tif_path = os.path.join(src_dir, file) + dataset = gdal.Open(tif_path) + if dataset is None: + print(f"无法打开: {tif_path}") + continue + + # 读取所有波段 + bands = [] + for i in range(1, dataset.RasterCount + 1): + band = dataset.GetRasterBand(i).ReadAsArray() + bands.append(band) + img = np.stack(bands, axis=-1) if len(bands) > 1 else bands[0] + + # 转换为uint8 + img = img.astype(np.uint8) + + jpg_path = os.path.join(dst_dir, os.path.splitext(file)[0] + ".jpg") + Image.fromarray(img).save(jpg_path, quality=95) + print(f"已保存: {jpg_path}") + +print("全部转换完成!") diff --git a/train_JL/geotiff_utils.py b/train_JL/geotiff_utils.py new file mode 100644 index 0000000..9494658 --- /dev/null +++ b/train_JL/geotiff_utils.py @@ -0,0 +1,274 @@ + +"""Pascal VOC Semantic Segmentation Dataset.""" +from PIL import Image, ImageOps, ImageFilter +import torchvision.transforms as transforms +import os +import torch +import numpy as np +from matplotlib import pyplot as plt +from PIL import Image +import cv2 +# import gdal +from osgeo import gdal +import random +import torch.utils.data as data +os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000') + + +class SegmentationDataset(object): + """Segmentation Base Dataset""" + + def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): + super(SegmentationDataset, self).__init__() + self.root = root + self.transform = transform + self.split = split + self.mode = mode if mode is not None else split + self.base_size = base_size + self.crop_size = crop_size + + def _val_sync_transform(self, img, mask): + outsize = self.crop_size + short_size = outsize + w, h = img.size + if w > h: + oh = short_size + ow = int(1.0 * w * oh / h) + else: + ow = short_size + oh = int(1.0 * h * ow / w) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # center crop + w, h = img.size + x1 = int(round((w - outsize) / 2.)) + y1 = int(round((h - outsize) / 2.)) + img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) + mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize)) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform(self, img, mask): + # random mirror + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + crop_size = self.crop_size + # random scale (short edge) + short_size = random.randint( + int(self.base_size * 0.5), int(self.base_size * 2.0)) + w, h = img.size + if h > w: + ow = short_size + oh = int(1.0 * h * ow / w) + else: + oh = short_size + ow = int(1.0 * w * oh / h) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # pad crop + if short_size < crop_size: + padh = crop_size - oh if oh < crop_size else 0 + padw = crop_size - ow if ow < crop_size else 0 + img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) + mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) + # random crop crop_size + w, h = img.size + x1 = random.randint(0, w - crop_size) + y1 = random.randint(0, h - crop_size) + img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + # gaussian blur as in PSP + if random.random() < 0.5: + img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform_tif(self, img, mask): + # random mirror + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform_tif_geofeat(self, img, mask): + # random mirror + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _val_sync_transform_tif(self, img, mask): + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _img_transform(self, img): + return np.array(img) + + # def _mask_transform(self, mask): + # return np.array(mask).astype('int32') + + def _mask_transform(self, mask): + target = np.array(mask).astype('int32') + # target = target[np.newaxis, :] + target[target > 12] = 255 + return torch.from_numpy(target).long() + + @property + def num_class(self): + """Number of categories.""" + return self.NUM_CLASS + + @property + def pred_offset(self): + return 0 + + +class VOCYJSSegmentation(SegmentationDataset): + """Pascal VOC Semantic Segmentation Dataset. + Parameters + ---------- + root : string + Path to VOCdevkit folder. Default is './datasets/VOCdevkit' + split: string + 'train', 'val' or 'test' + transform : callable, optional + A function that transforms the image + Examples + -------- + >>> from torchvision import transforms + >>> import torch.utils.data as data + >>> # Transforms for Normalization + >>> input_transform = transforms.Compose([ + >>> transforms.ToTensor(), + >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), + >>> ]) + >>> # Create Dataset + >>> trainset = VOCSegmentation(split='train', transform=input_transform) + >>> # Create Training Loader + >>> train_data = data.DataLoader( + >>> trainset, 4, shuffle=True, + >>> num_workers=4) + """ + NUM_CLASS = 13 + + def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs): + super(VOCYJSSegmentation, self).__init__( + root, split, mode, transform, **kwargs) + _voc_root = root + txt_path = os.path.join(root, split+'.txt') + self._mask_LS_dir = os.path.join(_voc_root, 'mask') + self._image_LS_dir = os.path.join(_voc_root, "dataset_5m") + self.image_list = read_text(txt_path) + random.shuffle(self.image_list) + + def __getitem__(self, index): + img_LS = gdal.Open(os.path.join(self._image_LS_dir, self.image_list[index])).ReadAsArray( + ).transpose(1, 2, 0).astype(np.float32) + mask = gdal.Open(os.path.join(self._mask_LS_dir, + self.image_list[index])).ReadAsArray() + # synchronized transform + # 只包含两种模式: train 和 val + if self.mode == 'train': + img_LS, mask = self._sync_transform_tif_geofeat( + img_LS, mask) + elif self.mode == 'val': + img_LS, mask = self._sync_transform_tif_geofeat( + img_LS, mask) + # general resize, normalize and toTensor + if self.transform is not None: + img_LS = self.transform(img_LS) + return img_LS, mask + + def __len__(self): + return len(self.image_list) + + def _mask_transform(self, mask): + target = np.array(mask).astype('int32') + # target = target[np.newaxis, :] + target[target > 12] = 255 + return torch.from_numpy(target).long() + + @property + def classes(self): + """Category names.""" + return ('0', '1', '2', '3', '4', '5', '6') + + +def generator_list_of_imagepath(path): + image_list = [] + for image in os.listdir(path): + # print(path) + # print(image) + if not image == '.DS_Store' and 'tif' == image.split('.')[-1]: + image_list.append(image) + return image_list + + +def read_text(textfile): + list = [] + with open(textfile, "r") as lines: + for line in lines: + list.append(line.rstrip('\n')) + return list + + +def dataset_segmentation(textpath, imagepath, train_percent): + image_list = generator_list_of_imagepath(imagepath) + num = len(image_list) + list = range(num) + train_num = int(num * train_percent) # training set num + train_list = random.sample(list, train_num) + print("train set size", train_num) + ftrain = open(os.path.join(textpath, 'train.txt'), 'w') + fval = open(os.path.join(textpath, 'val.txt'), 'w') + for i in list: + name = image_list[i] + '\n' + if i in train_list: + ftrain.write(name) + else: + fval.write(name) + ftrain.close() + fval.close() + + +if __name__ == '__main__': + # path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images' + # list=generator_list_of_imagepath(path) + # print(list) + # 切割数据集 + + textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset' + imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE' + train_percent = 0.8 + dataset_segmentation(textpath, imagepath, train_percent) + # 显示各种图片 + + # img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif' + # img = gdal.Open(img).ReadAsArray().transpose(1,2,0) + # cv2.imshow('img', img) + # img = Image.fromarray (img,'RGB') + # img.show() + # img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif' + # img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8) + # img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC) + # img2 = Image.fromarray (img2,'RGB') + # img2.show() + # img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif' + # img3 = gdal.Open(img3).ReadAsArray() + # img3 = Image.fromarray (img3) + # img3.show() + + # dataset和dataloader的测试 + + # 测试dataloader能不能用 + ''' + data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset' + input_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])]) + dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224) + dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224) + train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4) + test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4) + ''' diff --git a/train_JL/train_JL.py b/train_JL/train_JL.py new file mode 100644 index 0000000..f2d178d --- /dev/null +++ b/train_JL/train_JL.py @@ -0,0 +1,156 @@ +import os +import time +from datetime import datetime +from pathlib import Path + +import torch +from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large +from torchvision import transforms +import torch.utils.data as data +from torch import nn +import numpy as np + +from geotiff_utils import VOCYJSSegmentation +import utils as utils +import warnings +warnings.filterwarnings("ignore") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="pytorch deeplabv3 training") + + parser.add_argument( + "--data-path", default=r"E:\datasets\wlk_right_448", help="VOCdevkit root") + parser.add_argument("--num-classes", default=7, type=int) + parser.add_argument("--device", default="cuda", help="training device") + parser.add_argument("--batch-size", default=4, type=int) + parser.add_argument("--epochs", default=50, type=int, metavar="N", + help="number of total epochs to train") + parser.add_argument('--lr', default=0.005, type=float, + help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') + parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + parser.add_argument('-out-dir', type=str, + default='DeeplabV3_JL') + + args = parser.parse_args() + + return args + + +class DeeplabV3_JL(nn.Module): + def __init__(self, n_class): + super(DeeplabV3_JL, self).__init__() + self.n_class = n_class + self.conv6_3 = nn.Conv2d(6, 3, kernel_size=1, stride=1) + + self.conv_fc = nn.Conv2d( + 21, self.n_class, kernel_size=(1, 1), stride=(1, 1)) + + self.seg = deeplabv3_resnet50(weights='DEFAULT') + + def forward(self, x): + # x = torch.cat([x, x_dan], dim=1) + x = self.conv6_3(x) + x = self.seg(x)["out"] + x = self.conv_fc(x) + return x + + +def main(args): + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + + input_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([.535, .767, .732, .561, .494, .564], + [.0132, .0188, .0181, .0173, .0183, .0259]), + ]) + data_kwargs = {'transform': input_transform, + 'base_size': 448, 'crop_size': 448} + + # 读取geotiff数据,构建训练集、验证集 + train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train', + **data_kwargs) + val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val', + **data_kwargs) + num_workers = min( + [os.cpu_count(), args.batch_size if args.batch_size > 1 else 0]) + train_loader = data.DataLoader( + train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True) + val_loader = data.DataLoader( + val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True) + + model = DeeplabV3_JL(n_class=args.num_classes) + model.to(device) + + criterion = torch.nn.CrossEntropyLoss(ignore_index=255) + optimizer = torch.optim.SGD( + model.parameters(), + lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + + lr_scheduler = utils.create_lr_scheduler( + optimizer, len(train_loader), args.epochs, warmup=True) + + now = datetime.now() + date_time = now.strftime("%Y-%m-%d__%H-%M__") + out_dir = Path(os.path.join("./train_output", date_time + args.out_dir)) + if not out_dir.exists(): + out_dir.mkdir() + f = open(os.path.join(out_dir, "log.txt"), 'w') + start_time = time.time() + best_acc = 0 + for epoch in range(args.epochs): + print(f"Epoch {epoch+1}\n-------------------------------") + model.train() + for idx, (image, target) in enumerate(train_loader): + image, target = image.to( + device), target.to(device) + output = model(image) + + loss = criterion(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + if idx % 100 == 0: + print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx, + len(train_loader), loss.item(), optimizer.param_groups[0]["lr"])) + + model.eval() + confmat = utils.ConfusionMatrix(args.num_classes) + with torch.no_grad(): + for image, target in val_loader: + image, target = image.to(device), target.to(device) + output = model(image) + + confmat.update(target.flatten(), output.argmax(1).flatten()) + + info, mIoU = confmat.get_info() + print(info) + + f.write(f"Epoch {epoch+1}\n-------------------------------\n") + f.write(info+"\n\n") + f.flush() + + # # 保存准确率最好的模型 + # if mIoU > best_acc: + # print("[Save model]") + # torch.save(model, os.path.join(out_dir, "best_mIoU.pth")) + # best_acc = mIoU + torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth")) + total_time = time.time() - start_time + print("total time:", total_time) + torch.save(model, os.path.join(out_dir, "last.pth")) + + +if __name__ == '__main__': + args = parse_args() + + main(args) diff --git a/utils.py b/train_JL/utils.py similarity index 100% rename from utils.py rename to train_JL/utils.py diff --git a/train_JL_jpg/geotiff_utils.py b/train_JL_jpg/geotiff_utils.py new file mode 100644 index 0000000..dbbc433 --- /dev/null +++ b/train_JL_jpg/geotiff_utils.py @@ -0,0 +1,279 @@ + +"""Pascal VOC Semantic Segmentation Dataset.""" +from PIL import Image, ImageOps, ImageFilter +import torchvision.transforms as transforms +import os +import torch +import numpy as np +from matplotlib import pyplot as plt +from PIL import Image +import cv2 +# import gdal +from osgeo import gdal +import random +import torch.utils.data as data +os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000') + + +class SegmentationDataset(object): + """Segmentation Base Dataset""" + + def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): + super(SegmentationDataset, self).__init__() + self.root = root + self.transform = transform + self.split = split + self.mode = mode if mode is not None else split + self.base_size = base_size + self.crop_size = crop_size + + def _val_sync_transform(self, img, mask): + outsize = self.crop_size + short_size = outsize + w, h = img.size + if w > h: + oh = short_size + ow = int(1.0 * w * oh / h) + else: + ow = short_size + oh = int(1.0 * h * ow / w) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # center crop + w, h = img.size + x1 = int(round((w - outsize) / 2.)) + y1 = int(round((h - outsize) / 2.)) + img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) + mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize)) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform(self, img, mask): + # random mirror + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + crop_size = self.crop_size + # random scale (short edge) + short_size = random.randint( + int(self.base_size * 0.5), int(self.base_size * 2.0)) + w, h = img.size + if h > w: + ow = short_size + oh = int(1.0 * h * ow / w) + else: + oh = short_size + ow = int(1.0 * w * oh / h) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # pad crop + if short_size < crop_size: + padh = crop_size - oh if oh < crop_size else 0 + padw = crop_size - ow if ow < crop_size else 0 + img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) + mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) + # random crop crop_size + w, h = img.size + x1 = random.randint(0, w - crop_size) + y1 = random.randint(0, h - crop_size) + img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + # gaussian blur as in PSP + if random.random() < 0.5: + img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform_tif(self, img, mask): + # random mirror + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform_tif_geofeat(self, img, mask): + # random mirror + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _val_sync_transform_tif(self, img, mask): + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _img_transform(self, img): + return np.array(img) + + # def _mask_transform(self, mask): + # return np.array(mask).astype('int32') + + def _mask_transform(self, mask): + target = np.array(mask).astype('int32') + # target = target[np.newaxis, :] + target[target > 12] = 255 + return torch.from_numpy(target).long() + + @property + def num_class(self): + """Number of categories.""" + return self.NUM_CLASS + + @property + def pred_offset(self): + return 0 + + +class VOCYJSSegmentation(SegmentationDataset): + """Pascal VOC Semantic Segmentation Dataset. + Parameters + ---------- + root : string + Path to VOCdevkit folder. Default is './datasets/VOCdevkit' + split: string + 'train', 'val' or 'test' + transform : callable, optional + A function that transforms the image + Examples + -------- + >>> from torchvision import transforms + >>> import torch.utils.data as data + >>> # Transforms for Normalization + >>> input_transform = transforms.Compose([ + >>> transforms.ToTensor(), + >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), + >>> ]) + >>> # Create Dataset + >>> trainset = VOCSegmentation(split='train', transform=input_transform) + >>> # Create Training Loader + >>> train_data = data.DataLoader( + >>> trainset, 4, shuffle=True, + >>> num_workers=4) + """ + NUM_CLASS = 13 + + def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs): + super(VOCYJSSegmentation, self).__init__( + root, split, mode, transform, **kwargs) + _voc_root = root + txt_path = os.path.join(root, split+'.txt') + self._mask_LS_dir = os.path.join(_voc_root, 'mask_png') + self._image_LS_dir = os.path.join(_voc_root, "dataset_5m_jpg") + self.image_list = read_text(txt_path) + random.shuffle(self.image_list) + + def __getitem__(self, index): + img_name = self.image_list[index].split('.')[0]+'.jpg' + mask_name = self.image_list[index].split('.')[0]+'.png' + mask_name = mask_name.replace('img', 'mask') + img_LS = np.array(Image.open(os.path.join( + self._image_LS_dir, img_name))).astype(np.float32) + mask = np.array(Image.open(os.path.join( + self._mask_LS_dir, mask_name))).astype(np.int32) + mask = torch.from_numpy(mask).long() + + # synchronized transform + # 只包含两种模式: train 和 val + if self.mode == 'train': + img_LS, mask = self._sync_transform_tif_geofeat( + img_LS, mask) + elif self.mode == 'val': + img_LS, mask = self._sync_transform_tif_geofeat( + img_LS, mask) + # general resize, normalize and toTensor + if self.transform is not None: + img_LS = self.transform(img_LS) + return img_LS, mask + + def __len__(self): + return len(self.image_list) + + def _mask_transform(self, mask): + target = np.array(mask).astype('int32') + # target = target[np.newaxis, :] + target[target > 12] = 255 + return torch.from_numpy(target).long() + + @property + def classes(self): + """Category names.""" + return ('0', '1', '2', '3', '4', '5', '6') + + +def generator_list_of_imagepath(path): + image_list = [] + for image in os.listdir(path): + # print(path) + # print(image) + if not image == '.DS_Store' and 'tif' == image.split('.')[-1]: + image_list.append(image) + return image_list + + +def read_text(textfile): + list = [] + with open(textfile, "r") as lines: + for line in lines: + list.append(line.rstrip('\n')) + return list + + +def dataset_segmentation(textpath, imagepath, train_percent): + image_list = generator_list_of_imagepath(imagepath) + num = len(image_list) + list = range(num) + train_num = int(num * train_percent) # training set num + train_list = random.sample(list, train_num) + print("train set size", train_num) + ftrain = open(os.path.join(textpath, 'train.txt'), 'w') + fval = open(os.path.join(textpath, 'val.txt'), 'w') + for i in list: + name = image_list[i] + '\n' + if i in train_list: + ftrain.write(name) + else: + fval.write(name) + ftrain.close() + fval.close() + + +if __name__ == '__main__': + # path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images' + # list=generator_list_of_imagepath(path) + # print(list) + # 切割数据集 + + textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset' + imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE' + train_percent = 0.8 + dataset_segmentation(textpath, imagepath, train_percent) + # 显示各种图片 + + # img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif' + # img = gdal.Open(img).ReadAsArray().transpose(1,2,0) + # cv2.imshow('img', img) + # img = Image.fromarray (img,'RGB') + # img.show() + # img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif' + # img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8) + # img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC) + # img2 = Image.fromarray (img2,'RGB') + # img2.show() + # img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif' + # img3 = gdal.Open(img3).ReadAsArray() + # img3 = Image.fromarray (img3) + # img3.show() + + # dataset和dataloader的测试 + + # 测试dataloader能不能用 + ''' + data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset' + input_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])]) + dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224) + dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224) + train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4) + test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4) + ''' diff --git a/train_JL_jpg/train_JL.py b/train_JL_jpg/train_JL.py new file mode 100644 index 0000000..97952ad --- /dev/null +++ b/train_JL_jpg/train_JL.py @@ -0,0 +1,153 @@ +import os +import time +from datetime import datetime +from pathlib import Path + +import torch +from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large +from torchvision import transforms +import torch.utils.data as data +from torch import nn +import numpy as np + +from geotiff_utils import VOCYJSSegmentation +import utils as utils +import warnings +warnings.filterwarnings("ignore") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="pytorch deeplabv3 training") + + parser.add_argument( + "--data-path", default=r"E:\RSdata\wlk_right_448", help="VOCdevkit root") + parser.add_argument("--num-classes", default=7, type=int) + parser.add_argument("--device", default="cuda", help="training device") + parser.add_argument("--batch-size", default=4, type=int) + parser.add_argument("--epochs", default=50, type=int, metavar="N", + help="number of total epochs to train") + parser.add_argument('--lr', default=0.005, type=float, + help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') + parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + parser.add_argument('-out-dir', type=str, + default='DeeplabV3_JL') + + args = parser.parse_args() + + return args + + +class DeeplabV3_JL_3(nn.Module): + def __init__(self, n_class): + super(DeeplabV3_JL_3, self).__init__() + self.n_class = n_class + + self.conv_fc = nn.Conv2d( + 21, self.n_class, kernel_size=(1, 1), stride=(1, 1)) + + self.seg = deeplabv3_resnet50(weights='DEFAULT') + + def forward(self, x): + x = self.seg(x)["out"] + x = self.conv_fc(x) + return x + + +def main(args): + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + + input_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([.485, .456, .406], + [.229, .224, .225]), + ]) + data_kwargs = {'transform': input_transform, + 'base_size': 448, 'crop_size': 448} + + # 读取geotiff数据,构建训练集、验证集 + train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train', + **data_kwargs) + val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val', + **data_kwargs) + num_workers = min( + [os.cpu_count(), args.batch_size if args.batch_size > 1 else 0]) + train_loader = data.DataLoader( + train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True) + val_loader = data.DataLoader( + val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True) + + model = DeeplabV3_JL_3(n_class=args.num_classes) + model.to(device) + + criterion = torch.nn.CrossEntropyLoss(ignore_index=255) + optimizer = torch.optim.SGD( + model.parameters(), + lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + + lr_scheduler = utils.create_lr_scheduler( + optimizer, len(train_loader), args.epochs, warmup=True) + + now = datetime.now() + date_time = now.strftime("%Y-%m-%d__%H-%M__") + out_dir = Path(os.path.join("./train_output", date_time + args.out_dir)) + if not out_dir.exists(): + out_dir.mkdir() + f = open(os.path.join(out_dir, "log.txt"), 'w') + start_time = time.time() + best_acc = 0 + for epoch in range(args.epochs): + print(f"Epoch {epoch+1}\n-------------------------------") + model.train() + for idx, (image, target) in enumerate(train_loader): + image, target = image.to( + device), target.to(device) + output = model(image) + + loss = criterion(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + if idx % 100 == 0: + print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx, + len(train_loader), loss.item(), optimizer.param_groups[0]["lr"])) + + model.eval() + confmat = utils.ConfusionMatrix(args.num_classes) + with torch.no_grad(): + for image, target in val_loader: + image, target = image.to(device), target.to(device) + output = model(image) + + confmat.update(target.flatten(), output.argmax(1).flatten()) + + info, mIoU = confmat.get_info() + print(info) + + f.write(f"Epoch {epoch+1}\n-------------------------------\n") + f.write(info+"\n\n") + f.flush() + + # # 保存准确率最好的模型 + # if mIoU > best_acc: + # print("[Save model]") + # torch.save(model, os.path.join(out_dir, "best_mIoU.pth")) + # best_acc = mIoU + torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth")) + total_time = time.time() - start_time + print("total time:", total_time) + torch.save(model, os.path.join(out_dir, "last.pth")) + + +if __name__ == '__main__': + args = parse_args() + + main(args) diff --git a/train_JL_jpg/utils.py b/train_JL_jpg/utils.py new file mode 100644 index 0000000..b721720 --- /dev/null +++ b/train_JL_jpg/utils.py @@ -0,0 +1,330 @@ +import datetime +import errno +import os +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + t = reduce_across_processes([self.count, self.total]) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) + + +class ConfusionMatrix: + def __init__(self, num_classes): + self.num_classes = num_classes + self.mat = None + + def update(self, a, b): + n = self.num_classes + if self.mat is None: + self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) + with torch.inference_mode(): + k = (a >= 0) & (a < n) + inds = n * a[k].to(torch.int64) + b[k] + self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) + + def reset(self): + self.mat.zero_() + + def compute(self): + h = self.mat.float() + acc_global = torch.diag(h).sum() / h.sum() + acc = torch.diag(h) / h.sum(1) + iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) + return acc_global, acc, iu + + def reduce_from_all_processes(self): + reduce_across_processes(self.mat) + + def get_info(self): + acc_global, acc, iu = self.compute() + return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format( + acc_global.item() * 100, + [f"{i:.1f}" for i in (acc * 100).tolist()], + [f"{i:.1f}" for i in (iu * 100).tolist()], + iu.mean().item() * 100, + ), iu.mean().item() * 100 + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + if not isinstance(v, (float, int)): + raise TypeError( + f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}" + ) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{attr}'") + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {str(meter)}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", + "{meters}", "time: {time}", "data: {data}"] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print(f"{header} Total time: {total_time_str}") + + +def cat_list(images, fill_value=0): + max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) + batch_shape = (len(images),) + max_size + batched_imgs = images[0].new(*batch_shape).fill_(fill_value) + for img, pad_img in zip(images, batched_imgs): + pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img) + return batched_imgs + + +def collate_fn(batch): + images, targets = list(zip(*batch)) + batched_imgs = cat_list(images, fill_value=0) + batched_targets = cat_list(targets, fill_value=255) + return batched_imgs, batched_targets + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + # elif "SLURM_PROCID" in os.environ: + # args.rank = int(os.environ["SLURM_PROCID"]) + # args.gpu = args.rank % torch.cuda.device_count() + elif hasattr(args, "rank"): + pass + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def reduce_across_processes(val): + if not is_dist_avail_and_initialized(): + # nothing to sync, but we still convert to tensor for consistency with the distributed case. + return torch.tensor(val) + + t = torch.tensor(val, device="cuda") + dist.barrier() + dist.all_reduce(t) + return t + + +def create_lr_scheduler(optimizer, + num_step: int, + epochs: int, + warmup=True, + warmup_epochs=1, + warmup_factor=1e-3): + assert num_step > 0 and epochs > 0 + if warmup is False: + warmup_epochs = 0 + + def f(x): + """ + 根据step数返回一个学习率倍率因子, + 注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法 + """ + if warmup is True and x <= (warmup_epochs * num_step): + alpha = float(x) / (warmup_epochs * num_step) + # warmup过程中lr倍率因子从warmup_factor -> 1 + return warmup_factor * (1 - alpha) + alpha + else: + # warmup后lr倍率因子从1 -> 0 + # 参考deeplab_v2: Learning rate policy + return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9 + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f) \ No newline at end of file diff --git a/geotiff_utils.py b/train_LS/geotiff_utils.py similarity index 100% rename from geotiff_utils.py rename to train_LS/geotiff_utils.py diff --git a/train_LS.py b/train_LS/train_LS.py similarity index 97% rename from train_LS.py rename to train_LS/train_LS.py index cf9f95a..b66137e 100644 --- a/train_LS.py +++ b/train_LS/train_LS.py @@ -11,7 +11,7 @@ from torch import nn import numpy as np from geotiff_utils import VOCYJSSegmentation -import utils +import utils as utils import warnings warnings.filterwarnings("ignore") @@ -21,7 +21,7 @@ def parse_args(): parser = argparse.ArgumentParser(description="pytorch deeplabv3 training") parser.add_argument( - "--data-path", default="E:/repository/DeepLearning23/datasets/WLKdata_1111/WLKdataset", help="VOCdevkit root") + "--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset", help="VOCdevkit root") parser.add_argument("--num-classes", default=13, type=int) parser.add_argument("--device", default="cuda", help="training device") parser.add_argument("--batch-size", default=8, type=int) diff --git a/train_LS/utils.py b/train_LS/utils.py new file mode 100644 index 0000000..b721720 --- /dev/null +++ b/train_LS/utils.py @@ -0,0 +1,330 @@ +import datetime +import errno +import os +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + t = reduce_across_processes([self.count, self.total]) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) + + +class ConfusionMatrix: + def __init__(self, num_classes): + self.num_classes = num_classes + self.mat = None + + def update(self, a, b): + n = self.num_classes + if self.mat is None: + self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) + with torch.inference_mode(): + k = (a >= 0) & (a < n) + inds = n * a[k].to(torch.int64) + b[k] + self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) + + def reset(self): + self.mat.zero_() + + def compute(self): + h = self.mat.float() + acc_global = torch.diag(h).sum() / h.sum() + acc = torch.diag(h) / h.sum(1) + iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) + return acc_global, acc, iu + + def reduce_from_all_processes(self): + reduce_across_processes(self.mat) + + def get_info(self): + acc_global, acc, iu = self.compute() + return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format( + acc_global.item() * 100, + [f"{i:.1f}" for i in (acc * 100).tolist()], + [f"{i:.1f}" for i in (iu * 100).tolist()], + iu.mean().item() * 100, + ), iu.mean().item() * 100 + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + if not isinstance(v, (float, int)): + raise TypeError( + f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}" + ) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{attr}'") + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {str(meter)}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", + "{meters}", "time: {time}", "data: {data}"] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print(f"{header} Total time: {total_time_str}") + + +def cat_list(images, fill_value=0): + max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) + batch_shape = (len(images),) + max_size + batched_imgs = images[0].new(*batch_shape).fill_(fill_value) + for img, pad_img in zip(images, batched_imgs): + pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img) + return batched_imgs + + +def collate_fn(batch): + images, targets = list(zip(*batch)) + batched_imgs = cat_list(images, fill_value=0) + batched_targets = cat_list(targets, fill_value=255) + return batched_imgs, batched_targets + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + # elif "SLURM_PROCID" in os.environ: + # args.rank = int(os.environ["SLURM_PROCID"]) + # args.gpu = args.rank % torch.cuda.device_count() + elif hasattr(args, "rank"): + pass + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def reduce_across_processes(val): + if not is_dist_avail_and_initialized(): + # nothing to sync, but we still convert to tensor for consistency with the distributed case. + return torch.tensor(val) + + t = torch.tensor(val, device="cuda") + dist.barrier() + dist.all_reduce(t) + return t + + +def create_lr_scheduler(optimizer, + num_step: int, + epochs: int, + warmup=True, + warmup_epochs=1, + warmup_factor=1e-3): + assert num_step > 0 and epochs > 0 + if warmup is False: + warmup_epochs = 0 + + def f(x): + """ + 根据step数返回一个学习率倍率因子, + 注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法 + """ + if warmup is True and x <= (warmup_epochs * num_step): + alpha = float(x) / (warmup_epochs * num_step) + # warmup过程中lr倍率因子从warmup_factor -> 1 + return warmup_factor * (1 - alpha) + alpha + else: + # warmup后lr倍率因子从1 -> 0 + # 参考deeplab_v2: Learning rate policy + return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9 + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f) \ No newline at end of file diff --git a/train_LS_jpg/geotiff_utils.py b/train_LS_jpg/geotiff_utils.py new file mode 100644 index 0000000..26a6fe9 --- /dev/null +++ b/train_LS_jpg/geotiff_utils.py @@ -0,0 +1,270 @@ + +"""Pascal VOC Semantic Segmentation Dataset.""" +from PIL import Image, ImageOps, ImageFilter +import torchvision.transforms as transforms +import os +import torch +import numpy as np +from matplotlib import pyplot as plt +from PIL import Image +import cv2 +# import gdal +from osgeo import gdal +import random +import torch.utils.data as data +os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000') + + +class SegmentationDataset(object): + """Segmentation Base Dataset""" + + def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): + super(SegmentationDataset, self).__init__() + self.root = root + self.transform = transform + self.split = split + self.mode = mode if mode is not None else split + self.base_size = base_size + self.crop_size = crop_size + + def _val_sync_transform(self, img, mask): + outsize = self.crop_size + short_size = outsize + w, h = img.size + if w > h: + oh = short_size + ow = int(1.0 * w * oh / h) + else: + ow = short_size + oh = int(1.0 * h * ow / w) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # center crop + w, h = img.size + x1 = int(round((w - outsize) / 2.)) + y1 = int(round((h - outsize) / 2.)) + img = img.crop((x1, y1, x1 + outsize, y1 + outsize)) + mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize)) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform(self, img, mask): + # random mirror + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + crop_size = self.crop_size + # random scale (short edge) + short_size = random.randint( + int(self.base_size * 0.5), int(self.base_size * 2.0)) + w, h = img.size + if h > w: + ow = short_size + oh = int(1.0 * h * ow / w) + else: + oh = short_size + ow = int(1.0 * w * oh / h) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # pad crop + if short_size < crop_size: + padh = crop_size - oh if oh < crop_size else 0 + padw = crop_size - ow if ow < crop_size else 0 + img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) + mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0) + # random crop crop_size + w, h = img.size + x1 = random.randint(0, w - crop_size) + y1 = random.randint(0, h - crop_size) + img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size)) + # gaussian blur as in PSP + if random.random() < 0.5: + img = img.filter(ImageFilter.GaussianBlur(radius=random.random())) + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform_tif(self, img, mask): + # random mirror + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _sync_transform_tif_geofeat(self, img, mask): + # random mirror + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _val_sync_transform_tif(self, img, mask): + # final transform + img, mask = self._img_transform(img), self._mask_transform(mask) + return img, mask + + def _img_transform(self, img): + return np.array(img) + + # def _mask_transform(self, mask): + # return np.array(mask).astype('int32') + + def _mask_transform(self, mask): + target = np.array(mask).astype('int32') + # target = target[np.newaxis, :] + target[target > 12] = 255 + return torch.from_numpy(target).long() + + @property + def num_class(self): + """Number of categories.""" + return self.NUM_CLASS + + @property + def pred_offset(self): + return 0 + + +class VOCYJSSegmentation(SegmentationDataset): + """Pascal VOC Semantic Segmentation Dataset. + Parameters + ---------- + root : string + Path to VOCdevkit folder. Default is './datasets/VOCdevkit' + split: string + 'train', 'val' or 'test' + transform : callable, optional + A function that transforms the image + Examples + -------- + >>> from torchvision import transforms + >>> import torch.utils.data as data + >>> # Transforms for Normalization + >>> input_transform = transforms.Compose([ + >>> transforms.ToTensor(), + >>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]), + >>> ]) + >>> # Create Dataset + >>> trainset = VOCSegmentation(split='train', transform=input_transform) + >>> # Create Training Loader + >>> train_data = data.DataLoader( + >>> trainset, 4, shuffle=True, + >>> num_workers=4) + """ + NUM_CLASS = 13 + + def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs): + super(VOCYJSSegmentation, self).__init__( + root, split, mode, transform, **kwargs) + _voc_root = root + txt_path = os.path.join(root, split+'.txt') + self._mask_LS_dir = os.path.join(_voc_root, "masks_LS_png") + self._image_LS_dir = os.path.join(_voc_root, "images_LS_jpg") + self.image_list = read_text(txt_path) + random.shuffle(self.image_list) + + def __getitem__(self, index): + img_name = self.image_list[index].split('.')[0]+'.jpg' + mask_name = self.image_list[index].split('.')[0]+'.png' + img_LS = np.array(Image.open(os.path.join( + self._image_LS_dir, img_name))).astype(np.float32) + mask = np.array(Image.open(os.path.join( + self._mask_LS_dir, mask_name))).astype(np.int32) + mask = torch.from_numpy(mask).long() + + # general resize, normalize and toTensor + if self.transform is not None: + img_LS = self.transform(img_LS) + return img_LS, mask + + def __len__(self): + return len(self.image_list) + + def _mask_transform(self, mask): + target = np.array(mask).astype('int32') + # target = target[np.newaxis, :] + target[target > 12] = 255 + return torch.from_numpy(target).long() + + @property + def classes(self): + """Category names.""" + return ('0', '1', '2', '3', '4', '5', '6', '7' '8', '9', '10', '11', '12') + + +def generator_list_of_imagepath(path): + image_list = [] + for image in os.listdir(path): + # print(path) + # print(image) + if not image == '.DS_Store' and 'tif' == image.split('.')[-1]: + image_list.append(image) + return image_list + + +def read_text(textfile): + list = [] + with open(textfile, "r") as lines: + for line in lines: + list.append(line.rstrip('\n')) + return list + + +def dataset_segmentation(textpath, imagepath, train_percent): + image_list = generator_list_of_imagepath(imagepath) + num = len(image_list) + list = range(num) + train_num = int(num * train_percent) # training set num + train_list = random.sample(list, train_num) + print("train set size", train_num) + ftrain = open(os.path.join(textpath, 'train.txt'), 'w') + fval = open(os.path.join(textpath, 'val.txt'), 'w') + for i in list: + name = image_list[i] + '\n' + if i in train_list: + ftrain.write(name) + else: + fval.write(name) + ftrain.close() + fval.close() + + +if __name__ == '__main__': + # path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images' + # list=generator_list_of_imagepath(path) + # print(list) + # 切割数据集 + + textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset' + imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE' + train_percent = 0.8 + dataset_segmentation(textpath, imagepath, train_percent) + # 显示各种图片 + + # img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif' + # img = gdal.Open(img).ReadAsArray().transpose(1,2,0) + # cv2.imshow('img', img) + # img = Image.fromarray (img,'RGB') + # img.show() + # img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif' + # img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8) + # img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC) + # img2 = Image.fromarray (img2,'RGB') + # img2.show() + # img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif' + # img3 = gdal.Open(img3).ReadAsArray() + # img3 = Image.fromarray (img3) + # img3.show() + + # dataset和dataloader的测试 + + # 测试dataloader能不能用 + ''' + data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset' + input_transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])]) + dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224) + dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224) + train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4) + test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4) + ''' diff --git a/train_LS_jpg/train.py b/train_LS_jpg/train.py new file mode 100644 index 0000000..5e5cf82 --- /dev/null +++ b/train_LS_jpg/train.py @@ -0,0 +1,152 @@ +import os +import time +from datetime import datetime +from pathlib import Path + +import torch +from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large +from torchvision import transforms +import torch.utils.data as data +from torch import nn +import numpy as np + +from geotiff_utils import VOCYJSSegmentation +import utils +import warnings +warnings.filterwarnings("ignore") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description="pytorch deeplabv3 training") + + parser.add_argument( + "--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset", help="VOCdevkit root") + parser.add_argument("--num-classes", default=13, type=int) + parser.add_argument("--device", default="cuda", help="training device") + parser.add_argument("--batch-size", default=8, type=int) + parser.add_argument("--epochs", default=50, type=int, metavar="N", + help="number of total epochs to train") + parser.add_argument('--lr', default=0.005, type=float, + help='initial learning rate') + parser.add_argument('--momentum', default=0.9, type=float, metavar='M', + help='momentum') + parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float, + metavar='W', help='weight decay (default: 1e-4)', + dest='weight_decay') + parser.add_argument('-out-dir', type=str, + default='DeeplabV3_LS_3') + + args = parser.parse_args() + + return args + + +class DeeplabV3_LS_3(nn.Module): + def __init__(self, n_class): + super(DeeplabV3_LS_3, self).__init__() + self.n_class = n_class + self.conv_fc = nn.Conv2d( + 21, self.n_class, kernel_size=(1, 1), stride=(1, 1)) + + self.seg = deeplabv3_resnet50(weights='DEFAULT') + + def forward(self, x): + x = self.seg(x)["out"] + x = self.conv_fc(x) + return x + + +def main(args): + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + + input_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([.485, .456, .406], + [.229, .224, .225]), + ]) + data_kwargs = {'transform': input_transform, + 'base_size': 224, 'crop_size': 224} + + # 读取geotiff数据,构建训练集、验证集 + train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train', + **data_kwargs) + val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val', + **data_kwargs) + num_workers = min( + [os.cpu_count(), args.batch_size if args.batch_size > 1 else 0]) + train_loader = data.DataLoader( + train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True) + val_loader = data.DataLoader( + val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True) + + model = DeeplabV3_LS_3(n_class=args.num_classes) + model.to(device) + + criterion = torch.nn.CrossEntropyLoss(ignore_index=255) + optimizer = torch.optim.SGD( + model.parameters(), + lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay + ) + + lr_scheduler = utils.create_lr_scheduler( + optimizer, len(train_loader), args.epochs, warmup=True) + + now = datetime.now() + date_time = now.strftime("%Y-%m-%d__%H-%M__") + out_dir = Path(os.path.join("./train_output", date_time + args.out_dir)) + if not out_dir.exists(): + out_dir.mkdir() + f = open(os.path.join(out_dir, "log.txt"), 'w') + start_time = time.time() + best_acc = 0 + for epoch in range(args.epochs): + print(f"Epoch {epoch+1}\n-------------------------------") + model.train() + for idx, (image, target) in enumerate(train_loader): + image = image.to(device) + target = target.to(device) + output = model(image) + + loss = criterion(output, target) + + optimizer.zero_grad() + loss.backward() + optimizer.step() + lr_scheduler.step() + + if idx % 100 == 0: + print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx, + len(train_loader), loss.item(), optimizer.param_groups[0]["lr"])) + + model.eval() + confmat = utils.ConfusionMatrix(args.num_classes) + with torch.no_grad(): + for image, target in val_loader: + image, target = image.to(device), target.to(device) + output = model(image) + + confmat.update(target.flatten(), output.argmax(1).flatten()) + + info, mIoU = confmat.get_info() + print(info) + + f.write(f"Epoch {epoch+1}\n-------------------------------\n") + f.write(info+"\n\n") + f.flush() + + # # 保存准确率最好的模型 + # if mIoU > best_acc: + # print("[Save model]") + # torch.save(model, os.path.join(out_dir, "best_mIoU.pth")) + # best_acc = mIoU + torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth")) + total_time = time.time() - start_time + print("total time:", total_time) + torch.save(model, os.path.join(out_dir, "last.pth")) + + +if __name__ == '__main__': + args = parse_args() + + main(args) diff --git a/train_LS_jpg/utils.py b/train_LS_jpg/utils.py new file mode 100644 index 0000000..b721720 --- /dev/null +++ b/train_LS_jpg/utils.py @@ -0,0 +1,330 @@ +import datetime +import errno +import os +import time +from collections import defaultdict, deque + +import torch +import torch.distributed as dist + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + t = reduce_across_processes([self.count, self.total]) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value + ) + + +class ConfusionMatrix: + def __init__(self, num_classes): + self.num_classes = num_classes + self.mat = None + + def update(self, a, b): + n = self.num_classes + if self.mat is None: + self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device) + with torch.inference_mode(): + k = (a >= 0) & (a < n) + inds = n * a[k].to(torch.int64) + b[k] + self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n) + + def reset(self): + self.mat.zero_() + + def compute(self): + h = self.mat.float() + acc_global = torch.diag(h).sum() / h.sum() + acc = torch.diag(h) / h.sum(1) + iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h)) + return acc_global, acc, iu + + def reduce_from_all_processes(self): + reduce_across_processes(self.mat) + + def get_info(self): + acc_global, acc, iu = self.compute() + return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format( + acc_global.item() * 100, + [f"{i:.1f}" for i in (acc * 100).tolist()], + [f"{i:.1f}" for i in (iu * 100).tolist()], + iu.mean().item() * 100, + ), iu.mean().item() * 100 + + +class MetricLogger: + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + if not isinstance(v, (float, int)): + raise TypeError( + f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}" + ) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{attr}'") + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append(f"{name}: {str(meter)}") + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.4f}") + data_time = SmoothedValue(fmt="{avg:.4f}") + space_fmt = ":" + str(len(str(len(iterable)))) + "d" + if torch.cuda.is_available(): + log_msg = self.delimiter.join( + [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + "max mem: {memory:.0f}", + ] + ) + else: + log_msg = self.delimiter.join( + [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", + "{meters}", "time: {time}", "data: {data}"] + ) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print( + log_msg.format( + i, + len(iterable), + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + print( + log_msg.format( + i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time) + ) + ) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print(f"{header} Total time: {total_time_str}") + + +def cat_list(images, fill_value=0): + max_size = tuple(max(s) for s in zip(*[img.shape for img in images])) + batch_shape = (len(images),) + max_size + batched_imgs = images[0].new(*batch_shape).fill_(fill_value) + for img, pad_img in zip(images, batched_imgs): + pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img) + return batched_imgs + + +def collate_fn(batch): + images, targets = list(zip(*batch)) + batched_imgs = cat_list(images, fill_value=0) + batched_targets = cat_list(targets, fill_value=255) + return batched_imgs, batched_targets + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.gpu = int(os.environ["LOCAL_RANK"]) + # elif "SLURM_PROCID" in os.environ: + # args.rank = int(os.environ["SLURM_PROCID"]) + # args.gpu = args.rank % torch.cuda.device_count() + elif hasattr(args, "rank"): + pass + else: + print("Not using distributed mode") + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = "nccl" + print( + f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) + torch.distributed.init_process_group( + backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank + ) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +def reduce_across_processes(val): + if not is_dist_avail_and_initialized(): + # nothing to sync, but we still convert to tensor for consistency with the distributed case. + return torch.tensor(val) + + t = torch.tensor(val, device="cuda") + dist.barrier() + dist.all_reduce(t) + return t + + +def create_lr_scheduler(optimizer, + num_step: int, + epochs: int, + warmup=True, + warmup_epochs=1, + warmup_factor=1e-3): + assert num_step > 0 and epochs > 0 + if warmup is False: + warmup_epochs = 0 + + def f(x): + """ + 根据step数返回一个学习率倍率因子, + 注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法 + """ + if warmup is True and x <= (warmup_epochs * num_step): + alpha = float(x) / (warmup_epochs * num_step) + # warmup过程中lr倍率因子从warmup_factor -> 1 + return warmup_factor * (1 - alpha) + alpha + else: + # warmup后lr倍率因子从1 -> 0 + # 参考deeplab_v2: Learning rate policy + return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9 + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f) \ No newline at end of file