Add JPEG image training code

weixin_46229132 2025-05-26 09:33:01 +08:00
parent 2a40d11a2a
commit 704c24d79d
21 changed files with 2604 additions and 65 deletions

.gitignore vendored

@@ -160,3 +160,4 @@ cython_debug/
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+/train_output


@@ -0,0 +1,138 @@
import numpy as np
from osgeo import gdal
def read_tif(fileName):
dataset = gdal.Open(fileName)
im_width = dataset.RasterXSize  # number of raster columns
im_height = dataset.RasterYSize  # number of raster rows
im_bands = dataset.RasterCount  # number of bands
im_data = dataset.ReadAsArray().astype(np.float32)  # read the pixel data
if len(im_data.shape) == 2:
im_data = im_data[np.newaxis, :]
im_geotrans = dataset.GetGeoTransform()  # affine geotransform
im_proj = dataset.GetProjection()  # projection info
return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
if 'int8' in im_data.dtype.name:
datatype = gdal.GDT_Byte
elif 'int16' in im_data.dtype.name:
datatype = gdal.GDT_UInt16
else:
datatype = gdal.GDT_Float32
if len(im_data.shape) == 3:
im_bands, im_height, im_width = im_data.shape
else:
im_bands, (im_height, im_width) = 1, im_data.shape
im_data = im_data[np.newaxis, :]  # promote to (bands, H, W) so the band loop below also handles 2D input
# create the output file
driver = gdal.GetDriverByName("GTiff")
dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
if dataset is not None and im_geotrans is not None and im_proj is not None:
dataset.SetGeoTransform(im_geotrans)  # write the affine geotransform
dataset.SetProjection(im_proj)  # write the projection
for i in range(im_bands):
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
del dataset
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
fileName)
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
fileName2)
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
mask_fileName)
mask_im_data = np.int8(mask_im_data)
# normalize the GeoTIFF bands
for i in range(im_bands):
arr = im_data[i, :, :]
Min = arr.min()
Max = arr.max()
normalized_arr = (arr-Min)/(Max-Min)*255
im_data[i] = normalized_arr
for i in range(im_bands2):
arr = im_data2[i, :, :]
Min = arr.min()
Max = arr.max()
normalized_arr = (arr-Min)/(Max-Min)*255
im_data2[i] = normalized_arr
# compute per-band mean/std of the full image (used by the transform in train.py)
im_data = im_data/255
for i in range(im_bands):
pixels = im_data[i, :, :].ravel()
print("波段{} mean: {:.4f}, std: {:.4f}".format(
i, np.mean(pixels), np.std(pixels)))
im_data = im_data*255
im_data2 = im_data2/255
for i in range(im_bands2):
pixels = im_data2[i, :, :].ravel()
print("波段{} mean: {:.4f}, std: {:.4f}".format(
i, np.mean(pixels), np.std(pixels)))
im_data2 = im_data2*255
# tile the large image into patches
a = 0
size = 224
for i in range(0, int(mask_im_height / size)):
for j in range(0, int(mask_im_width / size)):
im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size]
im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size]
mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size]
# use the mask to decide whether to keep a tile; process the GeoTIFF and mask together
labelfla = np.array(mask_cut).flatten()
if np.all(labelfla == 15):  # 15 is the NoData label
print("Skip!!!")
else:
# 5m
left_h = i * size * im_geotrans[5] + im_geotrans[3]
left_w = j * size * im_geotrans[1] + im_geotrans[0]
new_geotrans = np.array(im_geotrans)
new_geotrans[0] = left_w
new_geotrans[3] = left_h
out_geotrans = tuple(new_geotrans)
im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif'
write_tif(im_cut, size, size, im_out, out_geotrans, im_proj)
# dan
left_h = i * size * im_geotrans2[5] + im_geotrans2[3]
left_w = j * size * im_geotrans2[1] + im_geotrans2[0]
new_geotrans = np.array(im_geotrans2)
new_geotrans[0] = left_w
new_geotrans[3] = left_h
out_geotrans = tuple(new_geotrans)
im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif'
write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2)
# save the mask
mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3]
mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0]
mask_new_geotrans = np.array(mask_im_geotrans)
mask_new_geotrans[0] = mask_left_w
mask_new_geotrans[3] = mask_left_h
mask_out_geotrans = tuple(mask_new_geotrans)
mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif'
write_tif(mask_cut, size, size, mask_out,
mask_out_geotrans, mask_im_proj)
print(mask_out + 'Cut to complete')
a = a+1
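For reference, the per-tile geotransform arithmetic above only shifts the origin; every other coefficient is inherited from the parent raster. A minimal worked sketch (the geotransform values here are hypothetical):

# GDAL geotransform layout: (x0, px_w, rot1, y0, rot2, px_h); px_h is negative for north-up rasters
gt = (500000.0, 5.0, 0.0, 4200000.0, 0.0, -5.0)  # hypothetical 5 m raster
size, i, j = 224, 3, 7                           # tile at row i, column j
tile_x0 = gt[0] + j * size * gt[1]               # shift east by j*size pixels
tile_y0 = gt[3] + i * size * gt[5]               # shift south by i*size pixels (gt[5] < 0)
print(tile_x0, tile_y0)                          # 507840.0 4196640.0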


@@ -44,9 +44,6 @@ fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
 im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
     fileName)
-fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
-im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
-    fileName2)
 mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
 mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
@@ -55,19 +52,14 @@ mask_im_data = np.int8(mask_im_data)
 # normalize the GeoTIFF bands
+lower_percent = 2
+upper_percent = 98
 for i in range(im_bands):
     arr = im_data[i, :, :]
-    Min = arr.min()
-    Max = arr.max()
-    normalized_arr = (arr-Min)/(Max-Min)*255
-    im_data[i] = normalized_arr
-for i in range(im_bands2):
-    arr = im_data2[i, :, :]
-    Min = arr.min()
-    Max = arr.max()
-    normalized_arr = (arr-Min)/(Max-Min)*255
-    im_data2[i] = normalized_arr
+    lower = np.percentile(arr, lower_percent)
+    upper = np.percentile(arr, upper_percent)
+    stretched = np.clip((arr - lower) / (upper - lower), 0, 1)
+    im_data[i] = (stretched * 255).astype(np.uint8)
 # compute per-band mean/std of the full image (used by the transform in train.py)
@@ -78,61 +70,42 @@ for i in range(im_bands):
     i, np.mean(pixels), np.std(pixels)))
 im_data = im_data*255
-im_data2 = im_data2/255
-for i in range(im_bands2):
-    pixels = im_data2[i, :, :].ravel()
-    print("Band {} mean: {:.4f}, std: {:.4f}".format(
-        i, np.mean(pixels), np.std(pixels)))
-im_data2 = im_data2*255
 # tile the large image into patches
 a = 0
-size = 224
+size = 448
 for i in range(0, int(mask_im_height / size)):
     for j in range(0, int(mask_im_width / size)):
         im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size]
-        im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size]
-        mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size]
+        mask_cut = mask_im_data[:, i * size:i *
+                                size + size, j * size:j * size + size]
         # use the mask to decide whether to keep a tile; process the GeoTIFF and mask together
         labelfla = np.array(mask_cut).flatten()
         if np.all(labelfla == 15):  # 15 is the NoData label
             print("Skip!!!")
         else:
             # 5m
             left_h = i * size * im_geotrans[5] + im_geotrans[3]
             left_w = j * size * im_geotrans[1] + im_geotrans[0]
             new_geotrans = np.array(im_geotrans)
             new_geotrans[0] = left_w
             new_geotrans[3] = left_h
             out_geotrans = tuple(new_geotrans)
+            # take bands 5, 4, 3 (note the order)
+            rgb_cut = np.stack([
+                im_cut[4],  # band 5
+                im_cut[3],  # band 4
+                im_cut[2],  # band 3
+            ], axis=0)  # shape: (3, H, W)
+            # convert to (H, W, 3)
+            rgb_cut = np.transpose(rgb_cut, (1, 2, 0))
+            # clip to 0-255 and convert to uint8
+            rgb_cut = np.clip(rgb_cut, 0, 255)
+            rgb_cut = rgb_cut.astype(np.uint8)
+            # save as JPEG
+            from PIL import Image
+            rgb_img = Image.fromarray(rgb_cut)
+            rgb_img.save(
+                f'E:/RSdata/wlk_right_448/dataset_5m_jpg/img_{a}.jpg')
-            im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif'
-            write_tif(im_cut, size, size, im_out, out_geotrans, im_proj)
-            # dan
-            left_h = i * size * im_geotrans2[5] + im_geotrans2[3]
-            left_w = j * size * im_geotrans2[1] + im_geotrans2[0]
-            new_geotrans = np.array(im_geotrans2)
-            new_geotrans[0] = left_w
-            new_geotrans[3] = left_h
-            out_geotrans = tuple(new_geotrans)
-            im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif'
-            write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2)
-            # save the mask
-            mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3]
-            mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0]
-            mask_new_geotrans = np.array(mask_im_geotrans)
-            mask_new_geotrans[0] = mask_left_w
-            mask_new_geotrans[3] = mask_left_h
-            mask_out_geotrans = tuple(mask_new_geotrans)
-            mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif'
-            write_tif(mask_cut, size, size, mask_out,
-                      mask_out_geotrans, mask_im_proj)
-            print(mask_out + 'Cut to complete')
+            # take the first band of the mask (if single-channel), convert to uint8, save as PNG
+            mask_arr = mask_cut[0] if mask_cut.shape[0] == 1 else mask_cut
+            mask_arr = np.clip(mask_arr, 0, 255).astype(np.uint8)
+            mask_img = Image.fromarray(mask_arr)
+            mask_img.save(f'E:/RSdata/wlk_right_448/mask_png/mask_{a}.png')
+            print(f'img_{a}.jpg and mask_{a}.png saved')
         a = a+1
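The 2-98% percentile stretch introduced in this diff is more robust to outliers than the min-max scaling it replaces. As a standalone sketch (the function name is mine):

import numpy as np

def percentile_stretch(band, lower_percent=2, upper_percent=98):
    # clip to the [p2, p98] range, then rescale to 0-255 uint8 (assumes upper > lower)
    lower = np.percentile(band, lower_percent)
    upper = np.percentile(band, upper_percent)
    stretched = np.clip((band - lower) / (upper - lower), 0, 1)
    return (stretched * 255).astype(np.uint8)

band = np.random.rand(224, 224).astype(np.float32) * 10000  # synthetic reflectance-like data
out = percentile_stretch(band)
print(out.min(), out.max())  # 0 255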


@@ -4,18 +4,18 @@ import random
 random.seed(42)
-geotiffs = os.listdir('E:\RSdata\wlk_right_224_2\dataset_5m')
+geotiffs = os.listdir(r'E:\RSdata\wlk_right_448\dataset_5m_jpg')
 num = len(geotiffs)
 split_rate = 0.2
 eval_index = random.sample(geotiffs, k=int(num*split_rate))
-f_train = open('E:\RSdata\wlk_right_224_2/train.txt', 'w')
-f_val = open('E:\RSdata\wlk_right_224_2/val.txt', 'w')
+f_train = open(r'E:\RSdata\wlk_right_448/train.txt', 'w')
+f_val = open(r'E:\RSdata\wlk_right_448/val.txt', 'w')
 # write the split lists to file
 for geotiff in geotiffs:
     if geotiff in eval_index:
-        f_val.write(str(geotiff)+'\n')
+        f_train.write(str(geotiff)+'\n')
     else:
-        f_train.write(str(geotiff)+'\n')
+        f_val.write(str(geotiff)+'\n')
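A quick sanity check of the resulting split, as a sketch (paths as in the script above):

with open(r'E:\RSdata\wlk_right_448/train.txt') as f:
    train = [line.strip() for line in f if line.strip()]
with open(r'E:\RSdata\wlk_right_448/val.txt') as f:
    val = [line.strip() for line in f if line.strip()]
print(len(train), len(val))  # the two counts should sum to the number of tiles
assert not set(train) & set(val), "train/val overlap"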


@@ -0,0 +1,62 @@
import os
import shutil
# folder paths
base_dir = r"e:\datasets\wlk_right_448"
jpeg_images_dir = os.path.join(base_dir, "JPEGImages")
segmentation_class_dir = os.path.join(base_dir, "SegmentationClass")
annotations_dir = os.path.join(base_dir, "stare_stuct", "annotations")
images_dir = os.path.join(base_dir, "stare_stuct", "images")
# target folders
annotations_training_dir = os.path.join(annotations_dir, "training")
annotations_validation_dir = os.path.join(annotations_dir, "validation")
images_training_dir = os.path.join(images_dir, "training")
images_validation_dir = os.path.join(images_dir, "validation")
# create the target folders
os.makedirs(annotations_training_dir, exist_ok=True)
os.makedirs(annotations_validation_dir, exist_ok=True)
os.makedirs(images_training_dir, exist_ok=True)
os.makedirs(images_validation_dir, exist_ok=True)
# read train.txt and val.txt
train_file = os.path.join(base_dir, "train.txt")
val_file = os.path.join(base_dir, "val.txt")
def read_file_list(file_path):
with open(file_path, "r") as f:
return [line.strip() for line in f.readlines()]
train_list = read_file_list(train_file)
val_list = read_file_list(val_file)
# helper to copy files into the split folders (shutil.copy keeps the source files)
def move_files(file_list, src_images_dir, src_labels_dir, dst_images_dir, dst_labels_dir):
for file_name in file_list:
# image file
image_src = os.path.join(src_images_dir, file_name)
image_dst = os.path.join(dst_images_dir, file_name)
if os.path.exists(image_src):
shutil.copy(image_src, image_dst)
# label file
label_src = os.path.join(src_labels_dir, file_name)
label_dst = os.path.join(dst_labels_dir, file_name)
if os.path.exists(label_src):
shutil.copy(label_src, label_dst)
# copy training-set files
move_files(train_list, jpeg_images_dir, segmentation_class_dir,
images_training_dir, annotations_training_dir)
# copy validation-set files
move_files(val_list, jpeg_images_dir, segmentation_class_dir,
images_validation_dir, annotations_validation_dir)
print("File organization complete!")


@@ -0,0 +1,22 @@
from PIL import Image
import numpy as np
import os
from osgeo import gdal
mask_dir = r"E:\datasets\wlk_right_448\mask" # 修改为你的mask文件夹路径
all_labels = set()
for file in os.listdir(mask_dir):
if file.lower().endswith('.tif'):
tif_path = os.path.join(mask_dir, file)
dataset = gdal.Open(tif_path)
if dataset is None:
print(f"无法打开: {tif_path}")
continue
band = dataset.ReadAsArray()
unique = np.unique(band)
all_labels.update(unique)
print("所有mask中出现过的标签数字", sorted(all_labels))


@@ -0,0 +1,6 @@
input_path = r"E:\datasets\WLKdata_1111\WLK_voc\ImageSets\Segmentation\val.txt"
output_path = r"E:\datasets\WLKdata_1111\WLK_voc\ImageSets\Segmentation\val_no.txt"
with open(input_path, "r", encoding="utf-8") as fin, open(output_path, "w", encoding="utf-8") as fout:
for line in fin:
fout.write(line.strip().replace(".tif", "") + "\n")


@@ -0,0 +1,30 @@
import os
import numpy as np
from PIL import Image
from osgeo import gdal
src_dir = r"E:\datasets\WLKdata_1111\WLKdataset\masks_LS"
dst_dir = r"E:\datasets\WLKdata_1111\WLKdataset\masks_LS_png"
os.makedirs(dst_dir, exist_ok=True)
for file in os.listdir(src_dir):
if file.lower().endswith('.tif'):
tif_path = os.path.join(src_dir, file)
dataset = gdal.Open(tif_path)
if dataset is None:
print(f"Cannot open: {tif_path}")
continue
mask = dataset.ReadAsArray()
if mask.ndim != 2:
print(f"{file} is not a single band, skipping")
continue
# remap pixel values: 15 (NoData) -> 255
mask = mask.copy()
mask[mask == 15] = 255
png_path = os.path.join(dst_dir, os.path.splitext(file)[0] + ".png")
Image.fromarray(mask.astype(np.uint8)).save(png_path)
print(f"Saved: {png_path}")
print("All conversions complete!")


@@ -0,0 +1,33 @@
import os
import numpy as np
from PIL import Image
from osgeo import gdal
# input and output folders
src_dir = r"E:\datasets\WLKdata_1111\WLKdataset\images_LS"
dst_dir = r"E:\datasets\WLKdata_1111\WLKdataset\images_LS_jpg"
os.makedirs(dst_dir, exist_ok=True)
for file in os.listdir(src_dir):
if file.lower().endswith('.tif'):
tif_path = os.path.join(src_dir, file)
dataset = gdal.Open(tif_path)
if dataset is None:
print(f"Cannot open: {tif_path}")
continue
# read all bands
bands = []
for i in range(1, dataset.RasterCount + 1):
band = dataset.GetRasterBand(i).ReadAsArray()
bands.append(band)
img = np.stack(bands, axis=-1) if len(bands) > 1 else bands[0]
# convert to uint8 (assumes pixel values already fit in 0-255; see the sketch below)
img = img.astype(np.uint8)
jpg_path = os.path.join(dst_dir, os.path.splitext(file)[0] + ".jpg")
Image.fromarray(img).save(jpg_path, quality=95)
print(f"Saved: {jpg_path}")
print("All conversions complete!")

train_JL/geotiff_utils.py Normal file

@@ -0,0 +1,274 @@
"""Pascal VOC Semantic Segmentation Dataset."""
from PIL import Image, ImageOps, ImageFilter
import torchvision.transforms as transforms
import os
import torch
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import cv2
# import gdal
from osgeo import gdal
import random
import torch.utils.data as data
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
class SegmentationDataset(object):
"""Segmentation Base Dataset"""
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
super(SegmentationDataset, self).__init__()
self.root = root
self.transform = transform
self.split = split
self.mode = mode if mode is not None else split
self.base_size = base_size
self.crop_size = crop_size
def _val_sync_transform(self, img, mask):
outsize = self.crop_size
short_size = outsize
w, h = img.size
if w > h:
oh = short_size
ow = int(1.0 * w * oh / h)
else:
ow = short_size
oh = int(1.0 * h * ow / w)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# center crop
w, h = img.size
x1 = int(round((w - outsize) / 2.))
y1 = int(round((h - outsize) / 2.))
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _sync_transform(self, img, mask):
# random mirror
if random.random() < 0.5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
crop_size = self.crop_size
# random scale (short edge)
short_size = random.randint(
int(self.base_size * 0.5), int(self.base_size * 2.0))
w, h = img.size
if h > w:
ow = short_size
oh = int(1.0 * h * ow / w)
else:
oh = short_size
ow = int(1.0 * w * oh / h)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# pad crop
if short_size < crop_size:
padh = crop_size - oh if oh < crop_size else 0
padw = crop_size - ow if ow < crop_size else 0
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
# random crop crop_size
w, h = img.size
x1 = random.randint(0, w - crop_size)
y1 = random.randint(0, h - crop_size)
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
# gaussian blur as in PSP
if random.random() < 0.5:
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _sync_transform_tif(self, img, mask):
# random mirror
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _sync_transform_tif_geofeat(self, img, mask):
# random mirror
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _val_sync_transform_tif(self, img, mask):
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _img_transform(self, img):
return np.array(img)
# def _mask_transform(self, mask):
# return np.array(mask).astype('int32')
def _mask_transform(self, mask):
target = np.array(mask).astype('int32')
# target = target[np.newaxis, :]
target[target > 12] = 255
return torch.from_numpy(target).long()
@property
def num_class(self):
"""Number of categories."""
return self.NUM_CLASS
@property
def pred_offset(self):
return 0
class VOCYJSSegmentation(SegmentationDataset):
"""Pascal VOC Semantic Segmentation Dataset.
Parameters
----------
root : string
Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
split: string
'train', 'val' or 'test'
transform : callable, optional
A function that transforms the image
Examples
--------
>>> from torchvision import transforms
>>> import torch.utils.data as data
>>> # Transforms for Normalization
>>> input_transform = transforms.Compose([
>>> transforms.ToTensor(),
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
>>> ])
>>> # Create Dataset
>>> trainset = VOCSegmentation(split='train', transform=input_transform)
>>> # Create Training Loader
>>> train_data = data.DataLoader(
>>> trainset, 4, shuffle=True,
>>> num_workers=4)
"""
NUM_CLASS = 13
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
super(VOCYJSSegmentation, self).__init__(
root, split, mode, transform, **kwargs)
_voc_root = root
txt_path = os.path.join(root, split+'.txt')
self._mask_LS_dir = os.path.join(_voc_root, 'mask')
self._image_LS_dir = os.path.join(_voc_root, "dataset_5m")
self.image_list = read_text(txt_path)
random.shuffle(self.image_list)
def __getitem__(self, index):
img_LS = gdal.Open(os.path.join(self._image_LS_dir, self.image_list[index])).ReadAsArray(
).transpose(1, 2, 0).astype(np.float32)
mask = gdal.Open(os.path.join(self._mask_LS_dir,
self.image_list[index])).ReadAsArray()
# synchronized transform
# only two modes here: train and val
if self.mode == 'train':
img_LS, mask = self._sync_transform_tif_geofeat(
img_LS, mask)
elif self.mode == 'val':
img_LS, mask = self._sync_transform_tif_geofeat(
img_LS, mask)
# general resize, normalize and toTensor
if self.transform is not None:
img_LS = self.transform(img_LS)
return img_LS, mask
def __len__(self):
return len(self.image_list)
def _mask_transform(self, mask):
target = np.array(mask).astype('int32')
# target = target[np.newaxis, :]
target[target > 12] = 255
return torch.from_numpy(target).long()
@property
def classes(self):
"""Category names."""
return ('0', '1', '2', '3', '4', '5', '6')
def generator_list_of_imagepath(path):
image_list = []
for image in os.listdir(path):
# print(path)
# print(image)
if not image == '.DS_Store' and 'tif' == image.split('.')[-1]:
image_list.append(image)
return image_list
def read_text(textfile):
list = []
with open(textfile, "r") as lines:
for line in lines:
list.append(line.rstrip('\n'))
return list
def dataset_segmentation(textpath, imagepath, train_percent):
image_list = generator_list_of_imagepath(imagepath)
num = len(image_list)
list = range(num)
train_num = int(num * train_percent) # training set num
train_list = random.sample(list, train_num)
print("train set size", train_num)
ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
fval = open(os.path.join(textpath, 'val.txt'), 'w')
for i in list:
name = image_list[i] + '\n'
if i in train_list:
ftrain.write(name)
else:
fval.write(name)
ftrain.close()
fval.close()
if __name__ == '__main__':
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
# list=generator_list_of_imagepath(path)
# print(list)
# split the dataset
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
train_percent = 0.8
dataset_segmentation(textpath, imagepath, train_percent)
# display sample images
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
# cv2.imshow('img', img)
# img = Image.fromarray (img,'RGB')
# img.show()
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
# img2 = Image.fromarray (img2,'RGB')
# img2.show()
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
# img3 = gdal.Open(img3).ReadAsArray()
# img3 = Image.fromarray (img3)
# img3.show()
# tests for the dataset and dataloader
# check that the dataloader works
'''
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
input_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
'''
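A runnable version of the commented-out loader test above, adapted to the six-band 448 tiles this file is paired with (a sketch; assumes the dataset root used by train_JL.py exists):

input_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.535, .767, .732, .561, .494, .564],
                         [.0132, .0188, .0181, .0173, .0183, .0259])])
ds = VOCYJSSegmentation(root=r'E:\datasets\wlk_right_448', split='train',
                        mode='train', transform=input_transform,
                        base_size=448, crop_size=448)
img, mask = ds[0]
print(img.shape, mask.shape)  # expect a (6, 448, 448) float tensor and a (448, 448) mask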

train_JL/train_JL.py Normal file

@@ -0,0 +1,156 @@
import os
import time
from datetime import datetime
from pathlib import Path
import torch
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
from torchvision import transforms
import torch.utils.data as data
from torch import nn
import numpy as np
from geotiff_utils import VOCYJSSegmentation
import utils as utils
import warnings
warnings.filterwarnings("ignore")
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
parser.add_argument(
"--data-path", default=r"E:\datasets\wlk_right_448", help="VOCdevkit root")
parser.add_argument("--num-classes", default=7, type=int)
parser.add_argument("--device", default="cuda", help="training device")
parser.add_argument("--batch-size", default=4, type=int)
parser.add_argument("--epochs", default=50, type=int, metavar="N",
help="number of total epochs to train")
parser.add_argument('--lr', default=0.005, type=float,
help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('-out-dir', type=str,
default='DeeplabV3_JL')
args = parser.parse_args()
return args
class DeeplabV3_JL(nn.Module):
def __init__(self, n_class):
super(DeeplabV3_JL, self).__init__()
self.n_class = n_class
self.conv6_3 = nn.Conv2d(6, 3, kernel_size=1, stride=1)
self.conv_fc = nn.Conv2d(
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
self.seg = deeplabv3_resnet50(weights='DEFAULT')
def forward(self, x):
# x = torch.cat([x, x_dan], dim=1)
x = self.conv6_3(x)
x = self.seg(x)["out"]
x = self.conv_fc(x)
return x
def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
input_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([.535, .767, .732, .561, .494, .564],
[.0132, .0188, .0181, .0173, .0183, .0259]),
])
data_kwargs = {'transform': input_transform,
'base_size': 448, 'crop_size': 448}
# read the GeoTIFF data and build the train/val datasets
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
**data_kwargs)
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
**data_kwargs)
num_workers = min(
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
train_loader = data.DataLoader(
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
val_loader = data.DataLoader(
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
model = DeeplabV3_JL(n_class=args.num_classes)
model.to(device)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
optimizer = torch.optim.SGD(
model.parameters(),
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
)
lr_scheduler = utils.create_lr_scheduler(
optimizer, len(train_loader), args.epochs, warmup=True)
now = datetime.now()
date_time = now.strftime("%Y-%m-%d__%H-%M__")
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
if not out_dir.exists():
out_dir.mkdir()
f = open(os.path.join(out_dir, "log.txt"), 'w')
start_time = time.time()
best_acc = 0
for epoch in range(args.epochs):
print(f"Epoch {epoch+1}\n-------------------------------")
model.train()
for idx, (image, target) in enumerate(train_loader):
image, target = image.to(
device), target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
if idx % 100 == 0:
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
model.eval()
confmat = utils.ConfusionMatrix(args.num_classes)
with torch.no_grad():
for image, target in val_loader:
image, target = image.to(device), target.to(device)
output = model(image)
confmat.update(target.flatten(), output.argmax(1).flatten())
info, mIoU = confmat.get_info()
print(info)
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
f.write(info+"\n\n")
f.flush()
# # save the model with the best mIoU
# if mIoU > best_acc:
# print("[Save model]")
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
# best_acc = mIoU
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
total_time = time.time() - start_time
print("total time:", total_time)
torch.save(model, os.path.join(out_dir, "last.pth"))
if __name__ == '__main__':
args = parse_args()
main(args)
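A quick shape check of the six-band wrapper above (a sketch; deeplabv3_resnet50 downloads pretrained weights on first use):

model = DeeplabV3_JL(n_class=7).eval()
with torch.no_grad():
    x = torch.randn(1, 6, 448, 448)  # six input bands, matching the Normalize stats
    y = model(x)
print(y.shape)  # torch.Size([1, 7, 448, 448]): conv6_3 maps 6->3, the 21-class head is mapped to n_class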


@@ -0,0 +1,279 @@
"""Pascal VOC Semantic Segmentation Dataset."""
from PIL import Image, ImageOps, ImageFilter
import torchvision.transforms as transforms
import os
import torch
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import cv2
# import gdal
from osgeo import gdal
import random
import torch.utils.data as data
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
class SegmentationDataset(object):
"""Segmentation Base Dataset"""
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
super(SegmentationDataset, self).__init__()
self.root = root
self.transform = transform
self.split = split
self.mode = mode if mode is not None else split
self.base_size = base_size
self.crop_size = crop_size
def _val_sync_transform(self, img, mask):
outsize = self.crop_size
short_size = outsize
w, h = img.size
if w > h:
oh = short_size
ow = int(1.0 * w * oh / h)
else:
ow = short_size
oh = int(1.0 * h * ow / w)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# center crop
w, h = img.size
x1 = int(round((w - outsize) / 2.))
y1 = int(round((h - outsize) / 2.))
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _sync_transform(self, img, mask):
# random mirror
if random.random() < 0.5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
crop_size = self.crop_size
# random scale (short edge)
short_size = random.randint(
int(self.base_size * 0.5), int(self.base_size * 2.0))
w, h = img.size
if h > w:
ow = short_size
oh = int(1.0 * h * ow / w)
else:
oh = short_size
ow = int(1.0 * w * oh / h)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# pad crop
if short_size < crop_size:
padh = crop_size - oh if oh < crop_size else 0
padw = crop_size - ow if ow < crop_size else 0
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
# random crop crop_size
w, h = img.size
x1 = random.randint(0, w - crop_size)
y1 = random.randint(0, h - crop_size)
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
# gaussian blur as in PSP
if random.random() < 0.5:
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _sync_transform_tif(self, img, mask):
# random mirror
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _sync_transform_tif_geofeat(self, img, mask):
# random mirror
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _val_sync_transform_tif(self, img, mask):
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _img_transform(self, img):
return np.array(img)
# def _mask_transform(self, mask):
# return np.array(mask).astype('int32')
def _mask_transform(self, mask):
target = np.array(mask).astype('int32')
# target = target[np.newaxis, :]
target[target > 12] = 255
return torch.from_numpy(target).long()
@property
def num_class(self):
"""Number of categories."""
return self.NUM_CLASS
@property
def pred_offset(self):
return 0
class VOCYJSSegmentation(SegmentationDataset):
"""Pascal VOC Semantic Segmentation Dataset.
Parameters
----------
root : string
Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
split: string
'train', 'val' or 'test'
transform : callable, optional
A function that transforms the image
Examples
--------
>>> from torchvision import transforms
>>> import torch.utils.data as data
>>> # Transforms for Normalization
>>> input_transform = transforms.Compose([
>>> transforms.ToTensor(),
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
>>> ])
>>> # Create Dataset
>>> trainset = VOCSegmentation(split='train', transform=input_transform)
>>> # Create Training Loader
>>> train_data = data.DataLoader(
>>> trainset, 4, shuffle=True,
>>> num_workers=4)
"""
NUM_CLASS = 13
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
super(VOCYJSSegmentation, self).__init__(
root, split, mode, transform, **kwargs)
_voc_root = root
txt_path = os.path.join(root, split+'.txt')
self._mask_LS_dir = os.path.join(_voc_root, 'mask_png')
self._image_LS_dir = os.path.join(_voc_root, "dataset_5m_jpg")
self.image_list = read_text(txt_path)
random.shuffle(self.image_list)
def __getitem__(self, index):
img_name = self.image_list[index].split('.')[0]+'.jpg'
mask_name = self.image_list[index].split('.')[0]+'.png'
mask_name = mask_name.replace('img', 'mask')
img_LS = np.array(Image.open(os.path.join(
self._image_LS_dir, img_name))).astype(np.float32)
mask = np.array(Image.open(os.path.join(
self._mask_LS_dir, mask_name))).astype(np.int32)
mask = torch.from_numpy(mask).long()
# synchronized transform
# only two modes here: train and val
if self.mode == 'train':
img_LS, mask = self._sync_transform_tif_geofeat(
img_LS, mask)
elif self.mode == 'val':
img_LS, mask = self._sync_transform_tif_geofeat(
img_LS, mask)
# general resize, normalize and toTensor
if self.transform is not None:
img_LS = self.transform(img_LS)
return img_LS, mask
def __len__(self):
return len(self.image_list)
def _mask_transform(self, mask):
target = np.array(mask).astype('int32')
# target = target[np.newaxis, :]
target[target > 12] = 255
return torch.from_numpy(target).long()
@property
def classes(self):
"""Category names."""
return ('0', '1', '2', '3', '4', '5', '6')
def generator_list_of_imagepath(path):
image_list = []
for image in os.listdir(path):
# print(path)
# print(image)
if not image == '.DS_Store' and 'tif' == image.split('.')[-1]:
image_list.append(image)
return image_list
def read_text(textfile):
list = []
with open(textfile, "r") as lines:
for line in lines:
list.append(line.rstrip('\n'))
return list
def dataset_segmentation(textpath, imagepath, train_percent):
image_list = generator_list_of_imagepath(imagepath)
num = len(image_list)
list = range(num)
train_num = int(num * train_percent) # training set num
train_list = random.sample(list, train_num)
print("train set size", train_num)
ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
fval = open(os.path.join(textpath, 'val.txt'), 'w')
for i in list:
name = image_list[i] + '\n'
if i in train_list:
ftrain.write(name)
else:
fval.write(name)
ftrain.close()
fval.close()
if __name__ == '__main__':
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
# list=generator_list_of_imagepath(path)
# print(list)
# split the dataset
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
train_percent = 0.8
dataset_segmentation(textpath, imagepath, train_percent)
# display sample images
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
# cv2.imshow('img', img)
# img = Image.fromarray (img,'RGB')
# img.show()
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
# img2 = Image.fromarray (img2,'RGB')
# img2.show()
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
# img3 = gdal.Open(img3).ReadAsArray()
# img3 = Image.fromarray (img3)
# img3.show()
# tests for the dataset and dataloader
# check that the dataloader works
'''
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
input_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
'''

train_JL_jpg/train_JL.py Normal file

@@ -0,0 +1,153 @@
import os
import time
from datetime import datetime
from pathlib import Path
import torch
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
from torchvision import transforms
import torch.utils.data as data
from torch import nn
import numpy as np
from geotiff_utils import VOCYJSSegmentation
import utils as utils
import warnings
warnings.filterwarnings("ignore")
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
parser.add_argument(
"--data-path", default=r"E:\RSdata\wlk_right_448", help="VOCdevkit root")
parser.add_argument("--num-classes", default=7, type=int)
parser.add_argument("--device", default="cuda", help="training device")
parser.add_argument("--batch-size", default=4, type=int)
parser.add_argument("--epochs", default=50, type=int, metavar="N",
help="number of total epochs to train")
parser.add_argument('--lr', default=0.005, type=float,
help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('-out-dir', type=str,
default='DeeplabV3_JL')
args = parser.parse_args()
return args
class DeeplabV3_JL_3(nn.Module):
def __init__(self, n_class):
super(DeeplabV3_JL_3, self).__init__()
self.n_class = n_class
self.conv_fc = nn.Conv2d(
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
self.seg = deeplabv3_resnet50(weights='DEFAULT')
def forward(self, x):
x = self.seg(x)["out"]
x = self.conv_fc(x)
return x
def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
input_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([.485, .456, .406],
[.229, .224, .225]),
])
data_kwargs = {'transform': input_transform,
'base_size': 448, 'crop_size': 448}
# read the GeoTIFF data and build the train/val datasets
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
**data_kwargs)
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
**data_kwargs)
num_workers = min(
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
train_loader = data.DataLoader(
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
val_loader = data.DataLoader(
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
model = DeeplabV3_JL_3(n_class=args.num_classes)
model.to(device)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
optimizer = torch.optim.SGD(
model.parameters(),
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
)
lr_scheduler = utils.create_lr_scheduler(
optimizer, len(train_loader), args.epochs, warmup=True)
now = datetime.now()
date_time = now.strftime("%Y-%m-%d__%H-%M__")
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
if not out_dir.exists():
out_dir.mkdir()
f = open(os.path.join(out_dir, "log.txt"), 'w')
start_time = time.time()
best_acc = 0
for epoch in range(args.epochs):
print(f"Epoch {epoch+1}\n-------------------------------")
model.train()
for idx, (image, target) in enumerate(train_loader):
image, target = image.to(
device), target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
if idx % 100 == 0:
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
model.eval()
confmat = utils.ConfusionMatrix(args.num_classes)
with torch.no_grad():
for image, target in val_loader:
image, target = image.to(device), target.to(device)
output = model(image)
confmat.update(target.flatten(), output.argmax(1).flatten())
info, mIoU = confmat.get_info()
print(info)
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
f.write(info+"\n\n")
f.flush()
# # save the model with the best mIoU
# if mIoU > best_acc:
# print("[Save model]")
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
# best_acc = mIoU
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
total_time = time.time() - start_time
print("total time:", total_time)
torch.save(model, os.path.join(out_dir, "last.pth"))
if __name__ == '__main__':
args = parse_args()
main(args)
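Typical invocation of this JPEG variant, as an example (paths are illustrative):

python train_JL.py --data-path E:\RSdata\wlk_right_448 --num-classes 7 --batch-size 4 --epochs 50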

train_JL_jpg/utils.py Normal file

@@ -0,0 +1,330 @@
import datetime
import errno
import os
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
class ConfusionMatrix:
def __init__(self, num_classes):
self.num_classes = num_classes
self.mat = None
def update(self, a, b):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.inference_mode():
k = (a >= 0) & (a < n)
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
def reset(self):
self.mat.zero_()
def compute(self):
h = self.mat.float()
acc_global = torch.diag(h).sum() / h.sum()
acc = torch.diag(h) / h.sum(1)
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
return acc_global, acc, iu
def reduce_from_all_processes(self):
reduce_across_processes(self.mat)
def get_info(self):
acc_global, acc, iu = self.compute()
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
acc_global.item() * 100,
[f"{i:.1f}" for i in (acc * 100).tolist()],
[f"{i:.1f}" for i in (iu * 100).tolist()],
iu.mean().item() * 100,
), iu.mean().item() * 100
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
"{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")
def cat_list(images, fill_value=0):
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
batch_shape = (len(images),) + max_size
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
for img, pad_img in zip(images, batched_imgs):
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
return batched_imgs
def collate_fn(batch):
images, targets = list(zip(*batch))
batched_imgs = cat_list(images, fill_value=0)
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
# elif "SLURM_PROCID" in os.environ:
# args.rank = int(os.environ["SLURM_PROCID"])
# args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print(
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def reduce_across_processes(val):
if not is_dist_avail_and_initialized():
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
return torch.tensor(val)
t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
def create_lr_scheduler(optimizer,
num_step: int,
epochs: int,
warmup=True,
warmup_epochs=1,
warmup_factor=1e-3):
assert num_step > 0 and epochs > 0
if warmup is False:
warmup_epochs = 0
def f(x):
"""
根据step数返回一个学习率倍率因子
注意在训练开始之前pytorch会提前调用一次lr_scheduler.step()方法
"""
if warmup is True and x <= (warmup_epochs * num_step):
alpha = float(x) / (warmup_epochs * num_step)
# warmup过程中lr倍率因子从warmup_factor -> 1
return warmup_factor * (1 - alpha) + alpha
else:
# warmup后lr倍率因子从1 -> 0
# 参考deeplab_v2: Learning rate policy
return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
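To see the warmup-then-poly decay that create_lr_scheduler produces, a short sketch (numbers are illustrative):

import torch
opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.005)
sched = create_lr_scheduler(opt, num_step=100, epochs=5, warmup=True)
lrs = []
for _ in range(5 * 100):
    opt.step()
    sched.step()
    lrs.append(opt.param_groups[0]["lr"])
print(lrs[0], lrs[99], lrs[-1])  # ramps up to 0.005 by the end of epoch 1, then decays toward 0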


@@ -11,7 +11,7 @@ from torch import nn
 import numpy as np
 from geotiff_utils import VOCYJSSegmentation
-import utils
+import utils as utils
 import warnings
 warnings.filterwarnings("ignore")
@@ -21,7 +21,7 @@ def parse_args():
 parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
 parser.add_argument(
-    "--data-path", default="E:/repository/DeepLearning23/datasets/WLKdata_1111/WLKdataset", help="VOCdevkit root")
+    "--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset", help="VOCdevkit root")
 parser.add_argument("--num-classes", default=13, type=int)
 parser.add_argument("--device", default="cuda", help="training device")
 parser.add_argument("--batch-size", default=8, type=int)

train_LS/utils.py Normal file

@@ -0,0 +1,330 @@
import datetime
import errno
import os
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
class ConfusionMatrix:
def __init__(self, num_classes):
self.num_classes = num_classes
self.mat = None
def update(self, a, b):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.inference_mode():
k = (a >= 0) & (a < n)
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
def reset(self):
self.mat.zero_()
def compute(self):
h = self.mat.float()
acc_global = torch.diag(h).sum() / h.sum()
acc = torch.diag(h) / h.sum(1)
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
return acc_global, acc, iu
def reduce_from_all_processes(self):
reduce_across_processes(self.mat)
def get_info(self):
acc_global, acc, iu = self.compute()
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
acc_global.item() * 100,
[f"{i:.1f}" for i in (acc * 100).tolist()],
[f"{i:.1f}" for i in (iu * 100).tolist()],
iu.mean().item() * 100,
), iu.mean().item() * 100
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
"{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")
def cat_list(images, fill_value=0):
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
batch_shape = (len(images),) + max_size
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
for img, pad_img in zip(images, batched_imgs):
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
return batched_imgs
def collate_fn(batch):
images, targets = list(zip(*batch))
batched_imgs = cat_list(images, fill_value=0)
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
# elif "SLURM_PROCID" in os.environ:
# args.rank = int(os.environ["SLURM_PROCID"])
# args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print(
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def reduce_across_processes(val):
if not is_dist_avail_and_initialized():
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
return torch.tensor(val)
t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
def create_lr_scheduler(optimizer,
num_step: int,
epochs: int,
warmup=True,
warmup_epochs=1,
warmup_factor=1e-3):
assert num_step > 0 and epochs > 0
if warmup is False:
warmup_epochs = 0
def f(x):
"""
根据step数返回一个学习率倍率因子
注意在训练开始之前pytorch会提前调用一次lr_scheduler.step()方法
"""
if warmup is True and x <= (warmup_epochs * num_step):
alpha = float(x) / (warmup_epochs * num_step)
# warmup过程中lr倍率因子从warmup_factor -> 1
return warmup_factor * (1 - alpha) + alpha
else:
# warmup后lr倍率因子从1 -> 0
# 参考deeplab_v2: Learning rate policy
return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)


@@ -0,0 +1,270 @@
"""Pascal VOC Semantic Segmentation Dataset."""
from PIL import Image, ImageOps, ImageFilter
import torchvision.transforms as transforms
import os
import torch
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
import cv2
# import gdal
from osgeo import gdal
import random
import torch.utils.data as data
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
class SegmentationDataset(object):
"""Segmentation Base Dataset"""
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
super(SegmentationDataset, self).__init__()
self.root = root
self.transform = transform
self.split = split
self.mode = mode if mode is not None else split
self.base_size = base_size
self.crop_size = crop_size
def _val_sync_transform(self, img, mask):
outsize = self.crop_size
short_size = outsize
w, h = img.size
if w > h:
oh = short_size
ow = int(1.0 * w * oh / h)
else:
ow = short_size
oh = int(1.0 * h * ow / w)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# center crop
w, h = img.size
x1 = int(round((w - outsize) / 2.))
y1 = int(round((h - outsize) / 2.))
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
def _sync_transform(self, img, mask):
# random mirror
if random.random() < 0.5:
img = img.transpose(Image.FLIP_LEFT_RIGHT)
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
crop_size = self.crop_size
# random scale (short edge)
short_size = random.randint(
int(self.base_size * 0.5), int(self.base_size * 2.0))
w, h = img.size
if h > w:
ow = short_size
oh = int(1.0 * h * ow / w)
else:
oh = short_size
ow = int(1.0 * w * oh / h)
img = img.resize((ow, oh), Image.BILINEAR)
mask = mask.resize((ow, oh), Image.NEAREST)
# pad crop
if short_size < crop_size:
padh = crop_size - oh if oh < crop_size else 0
padw = crop_size - ow if ow < crop_size else 0
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
# random crop crop_size
w, h = img.size
x1 = random.randint(0, w - crop_size)
y1 = random.randint(0, h - crop_size)
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
# gaussian blur as in PSP
if random.random() < 0.5:
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
# final transform
img, mask = self._img_transform(img), self._mask_transform(mask)
return img, mask
    def _sync_transform_tif(self, img, mask):
        # no geometric augmentation for tif inputs; apply only the final transform
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask
    def _sync_transform_tif_geofeat(self, img, mask):
        # no geometric augmentation; apply only the final transform
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask
    def _val_sync_transform_tif(self, img, mask):
        # final transform only
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask
def _img_transform(self, img):
return np.array(img)
# def _mask_transform(self, mask):
# return np.array(mask).astype('int32')
def _mask_transform(self, mask):
target = np.array(mask).astype('int32')
# target = target[np.newaxis, :]
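        # labels above 12 are remapped to 255, the ignore_index used by the loss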
target[target > 12] = 255
return torch.from_numpy(target).long()
@property
def num_class(self):
"""Number of categories."""
return self.NUM_CLASS
@property
def pred_offset(self):
return 0
class VOCYJSSegmentation(SegmentationDataset):
"""Pascal VOC Semantic Segmentation Dataset.
Parameters
----------
root : string
        Path to the dataset root folder. Default is '../VOC/'
split: string
'train', 'val' or 'test'
transform : callable, optional
A function that transforms the image
Examples
--------
>>> from torchvision import transforms
>>> import torch.utils.data as data
>>> # Transforms for Normalization
>>> input_transform = transforms.Compose([
>>> transforms.ToTensor(),
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
>>> ])
>>> # Create Dataset
    >>> trainset = VOCYJSSegmentation(split='train', transform=input_transform)
>>> # Create Training Loader
>>> train_data = data.DataLoader(
>>> trainset, 4, shuffle=True,
>>> num_workers=4)
"""
NUM_CLASS = 13
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
super(VOCYJSSegmentation, self).__init__(
root, split, mode, transform, **kwargs)
_voc_root = root
txt_path = os.path.join(root, split+'.txt')
self._mask_LS_dir = os.path.join(_voc_root, "masks_LS_png")
self._image_LS_dir = os.path.join(_voc_root, "images_LS_jpg")
self.image_list = read_text(txt_path)
random.shuffle(self.image_list)
def __getitem__(self, index):
img_name = self.image_list[index].split('.')[0]+'.jpg'
mask_name = self.image_list[index].split('.')[0]+'.png'
img_LS = np.array(Image.open(os.path.join(
self._image_LS_dir, img_name))).astype(np.float32)
mask = np.array(Image.open(os.path.join(
self._mask_LS_dir, mask_name))).astype(np.int32)
mask = torch.from_numpy(mask).long()
# general resize, normalize and toTensor
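        # note: torchvision's ToTensor rescales to [0, 1] only for uint8 inputs;
        # this float32 array passes through unscaled (values stay in [0, 255])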
if self.transform is not None:
img_LS = self.transform(img_LS)
return img_LS, mask
def __len__(self):
return len(self.image_list)
def _mask_transform(self, mask):
target = np.array(mask).astype('int32')
# target = target[np.newaxis, :]
target[target > 12] = 255
return torch.from_numpy(target).long()
@property
def classes(self):
"""Category names."""
        return ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12')
def generator_list_of_imagepath(path):
    image_list = []
    for image in os.listdir(path):
        if image != '.DS_Store' and image.split('.')[-1] == 'tif':
            image_list.append(image)
    return image_list
def read_text(textfile):
    names = []
    with open(textfile, "r") as lines:
        for line in lines:
            names.append(line.rstrip('\n'))
    return names
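# randomly split the image filenames into train.txt / val.txt under textpath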
def dataset_segmentation(textpath, imagepath, train_percent):
    image_list = generator_list_of_imagepath(imagepath)
    num = len(image_list)
    indices = range(num)
    train_num = int(num * train_percent)  # size of the training set
    train_list = random.sample(indices, train_num)
    print("train set size", train_num)
    ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
    fval = open(os.path.join(textpath, 'val.txt'), 'w')
    for i in indices:
        name = image_list[i] + '\n'
        if i in train_list:
            ftrain.write(name)
        else:
            fval.write(name)
    ftrain.close()
    fval.close()
if __name__ == '__main__':
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
# list=generator_list_of_imagepath(path)
# print(list)
    # split the dataset into train / val lists
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
train_percent = 0.8
dataset_segmentation(textpath, imagepath, train_percent)
    # display the various images
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
# cv2.imshow('img', img)
# img = Image.fromarray (img,'RGB')
# img.show()
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
# img2 = Image.fromarray (img2,'RGB')
# img2.show()
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
# img3 = gdal.Open(img3).ReadAsArray()
# img3 = Image.fromarray (img3)
# img3.show()
    # dataset / dataloader tests
    # check that the dataloader works
'''
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
input_transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
'''

152
train_LS_jpg/train.py Normal file

@ -0,0 +1,152 @@
import os
import time
from datetime import datetime
from pathlib import Path
import torch
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
from torchvision import transforms
import torch.utils.data as data
from torch import nn
import numpy as np
from geotiff_utils import VOCYJSSegmentation
import utils
import warnings
warnings.filterwarnings("ignore")
def parse_args():
import argparse
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
parser.add_argument(
"--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset", help="VOCdevkit root")
parser.add_argument("--num-classes", default=13, type=int)
parser.add_argument("--device", default="cuda", help="training device")
parser.add_argument("--batch-size", default=8, type=int)
parser.add_argument("--epochs", default=50, type=int, metavar="N",
help="number of total epochs to train")
parser.add_argument('--lr', default=0.005, type=float,
help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
    parser.add_argument('--out-dir', type=str,
                        default='DeeplabV3_LS_3')
args = parser.parse_args()
return args
class DeeplabV3_LS_3(nn.Module):
def __init__(self, n_class):
super(DeeplabV3_LS_3, self).__init__()
self.n_class = n_class
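        # the pretrained deeplabv3_resnet50 head outputs 21 classes (Pascal VOC);
        # a 1x1 conv remaps those 21 logits to n_class outputs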
self.conv_fc = nn.Conv2d(
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
self.seg = deeplabv3_resnet50(weights='DEFAULT')
def forward(self, x):
x = self.seg(x)["out"]
x = self.conv_fc(x)
return x
def main(args):
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
input_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([.485, .456, .406],
[.229, .224, .225]),
])
data_kwargs = {'transform': input_transform,
'base_size': 224, 'crop_size': 224}
    # build the training and validation sets (jpg images and png masks)
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
**data_kwargs)
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
**data_kwargs)
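    # cap dataloader workers at the CPU count and the batch size
    # (batch size 1 falls back to loading in the main process, num_workers=0)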
num_workers = min(
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
train_loader = data.DataLoader(
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
val_loader = data.DataLoader(
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
model = DeeplabV3_LS_3(n_class=args.num_classes)
model.to(device)
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
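    # pixels labeled 255 (NoData / remapped classes) are excluded from the loss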
optimizer = torch.optim.SGD(
model.parameters(),
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
)
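    # the scheduler is stepped once per training batch, so num_step is len(train_loader)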
lr_scheduler = utils.create_lr_scheduler(
optimizer, len(train_loader), args.epochs, warmup=True)
now = datetime.now()
date_time = now.strftime("%Y-%m-%d__%H-%M__")
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
    if not out_dir.exists():
        out_dir.mkdir(parents=True)
f = open(os.path.join(out_dir, "log.txt"), 'w')
start_time = time.time()
best_acc = 0
for epoch in range(args.epochs):
print(f"Epoch {epoch+1}\n-------------------------------")
model.train()
for idx, (image, target) in enumerate(train_loader):
image = image.to(device)
target = target.to(device)
output = model(image)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
lr_scheduler.step()
if idx % 100 == 0:
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
model.eval()
confmat = utils.ConfusionMatrix(args.num_classes)
with torch.no_grad():
for image, target in val_loader:
image, target = image.to(device), target.to(device)
output = model(image)
confmat.update(target.flatten(), output.argmax(1).flatten())
info, mIoU = confmat.get_info()
print(info)
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
f.write(info+"\n\n")
f.flush()
        # # save the model with the best mIoU
# if mIoU > best_acc:
# print("[Save model]")
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
# best_acc = mIoU
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
total_time = time.time() - start_time
print("total time:", total_time)
torch.save(model, os.path.join(out_dir, "last.pth"))
if __name__ == '__main__':
args = parse_args()
main(args)

330
train_LS_jpg/utils.py Normal file

@ -0,0 +1,330 @@
import datetime
import errno
import os
import time
from collections import defaultdict, deque
import torch
import torch.distributed as dist
class SmoothedValue:
"""Track a series of values and provide access to smoothed values over a
window or the global series average.
"""
def __init__(self, window_size=20, fmt=None):
if fmt is None:
fmt = "{median:.4f} ({global_avg:.4f})"
self.deque = deque(maxlen=window_size)
self.total = 0.0
self.count = 0
self.fmt = fmt
def update(self, value, n=1):
self.deque.append(value)
self.count += n
self.total += value * n
def synchronize_between_processes(self):
"""
Warning: does not synchronize the deque!
"""
t = reduce_across_processes([self.count, self.total])
t = t.tolist()
self.count = int(t[0])
self.total = t[1]
@property
def median(self):
d = torch.tensor(list(self.deque))
return d.median().item()
@property
def avg(self):
d = torch.tensor(list(self.deque), dtype=torch.float32)
return d.mean().item()
@property
def global_avg(self):
return self.total / self.count
@property
def max(self):
return max(self.deque)
@property
def value(self):
return self.deque[-1]
def __str__(self):
return self.fmt.format(
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
)
class ConfusionMatrix:
def __init__(self, num_classes):
self.num_classes = num_classes
self.mat = None
def update(self, a, b):
n = self.num_classes
if self.mat is None:
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
with torch.inference_mode():
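            # keep valid labels only (ignored 255 pixels fall outside [0, n));
            # encode each (target, pred) pair as n * target + pred and
            # histogram the codes into the n x n confusion matrix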
k = (a >= 0) & (a < n)
inds = n * a[k].to(torch.int64) + b[k]
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
def reset(self):
self.mat.zero_()
def compute(self):
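        # acc_global: overall pixel accuracy; acc: per-class recall;
        # iu: per-class intersection-over-union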
h = self.mat.float()
acc_global = torch.diag(h).sum() / h.sum()
acc = torch.diag(h) / h.sum(1)
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
return acc_global, acc, iu
    def reduce_from_all_processes(self):
        # reduce_across_processes returns a reduced copy, so keep the result
        self.mat = reduce_across_processes(self.mat)
def get_info(self):
acc_global, acc, iu = self.compute()
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
acc_global.item() * 100,
[f"{i:.1f}" for i in (acc * 100).tolist()],
[f"{i:.1f}" for i in (iu * 100).tolist()],
iu.mean().item() * 100,
), iu.mean().item() * 100
class MetricLogger:
def __init__(self, delimiter="\t"):
self.meters = defaultdict(SmoothedValue)
self.delimiter = delimiter
def update(self, **kwargs):
for k, v in kwargs.items():
if isinstance(v, torch.Tensor):
v = v.item()
if not isinstance(v, (float, int)):
raise TypeError(
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
)
self.meters[k].update(v)
def __getattr__(self, attr):
if attr in self.meters:
return self.meters[attr]
if attr in self.__dict__:
return self.__dict__[attr]
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{attr}'")
def __str__(self):
loss_str = []
for name, meter in self.meters.items():
loss_str.append(f"{name}: {str(meter)}")
return self.delimiter.join(loss_str)
def synchronize_between_processes(self):
for meter in self.meters.values():
meter.synchronize_between_processes()
def add_meter(self, name, meter):
self.meters[name] = meter
def log_every(self, iterable, print_freq, header=None):
i = 0
if not header:
header = ""
start_time = time.time()
end = time.time()
iter_time = SmoothedValue(fmt="{avg:.4f}")
data_time = SmoothedValue(fmt="{avg:.4f}")
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
if torch.cuda.is_available():
log_msg = self.delimiter.join(
[
header,
"[{0" + space_fmt + "}/{1}]",
"eta: {eta}",
"{meters}",
"time: {time}",
"data: {data}",
"max mem: {memory:.0f}",
]
)
else:
log_msg = self.delimiter.join(
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
"{meters}", "time: {time}", "data: {data}"]
)
MB = 1024.0 * 1024.0
for obj in iterable:
data_time.update(time.time() - end)
yield obj
iter_time.update(time.time() - end)
if i % print_freq == 0:
eta_seconds = iter_time.global_avg * (len(iterable) - i)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if torch.cuda.is_available():
print(
log_msg.format(
i,
len(iterable),
eta=eta_string,
meters=str(self),
time=str(iter_time),
data=str(data_time),
memory=torch.cuda.max_memory_allocated() / MB,
)
)
else:
print(
log_msg.format(
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
)
)
i += 1
end = time.time()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print(f"{header} Total time: {total_time_str}")
def cat_list(images, fill_value=0):
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
batch_shape = (len(images),) + max_size
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
for img, pad_img in zip(images, batched_imgs):
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
return batched_imgs
def collate_fn(batch):
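    # pad every image in the batch to the largest H x W; padded target pixels
    # are filled with 255 so the loss ignores them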
images, targets = list(zip(*batch))
batched_imgs = cat_list(images, fill_value=0)
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets
def mkdir(path):
try:
os.makedirs(path)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop("force", False)
if is_master or force:
builtin_print(*args, **kwargs)
__builtin__.print = print
def is_dist_avail_and_initialized():
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True
def get_world_size():
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()
def get_rank():
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()
def is_main_process():
return get_rank() == 0
def save_on_master(*args, **kwargs):
if is_main_process():
torch.save(*args, **kwargs)
def init_distributed_mode(args):
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ["WORLD_SIZE"])
args.gpu = int(os.environ["LOCAL_RANK"])
# elif "SLURM_PROCID" in os.environ:
# args.rank = int(os.environ["SLURM_PROCID"])
# args.gpu = args.rank % torch.cuda.device_count()
elif hasattr(args, "rank"):
pass
else:
print("Not using distributed mode")
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = "nccl"
print(
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
torch.distributed.init_process_group(
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
)
torch.distributed.barrier()
setup_for_distributed(args.rank == 0)
def reduce_across_processes(val):
if not is_dist_avail_and_initialized():
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
return torch.tensor(val)
t = torch.tensor(val, device="cuda")
dist.barrier()
dist.all_reduce(t)
return t
def create_lr_scheduler(optimizer,
num_step: int,
epochs: int,
warmup=True,
warmup_epochs=1,
warmup_factor=1e-3):
assert num_step > 0 and epochs > 0
    if not warmup:
        warmup_epochs = 0
    def f(x):
        """
        Return a learning-rate multiplier for the given step count.
        Note: PyTorch calls lr_scheduler.step() once before training starts.
        """
        if warmup is True and x <= (warmup_epochs * num_step):
            alpha = float(x) / (warmup_epochs * num_step)
            # during warmup the lr multiplier grows from warmup_factor -> 1
            return warmup_factor * (1 - alpha) + alpha
        else:
            # after warmup the lr multiplier decays from 1 -> 0
            # poly policy, cf. DeepLab v2: "Learning rate policy"
            return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
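# Worked example (assumed values): with num_step=100, epochs=50, warmup_epochs=1,
# and warmup_factor=1e-3, f(0) = 1e-3 and f(100) = 1.0 at the end of warmup;
# the poly term (1 - progress) ** 0.9 then decays the factor from 1 toward 0
# at the final step (x = 5000).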