加入jpg图像训练代码
This commit is contained in:
parent
2a40d11a2a
commit
704c24d79d
1
.gitignore
vendored
1
.gitignore
vendored
@ -160,3 +160,4 @@ cython_debug/
|
|||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
|
/train_output
|
138
data_preprocessing/cut_smalltif.bak
Normal file
138
data_preprocessing/cut_smalltif.bak
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import numpy as np
|
||||||
|
from osgeo import gdal
|
||||||
|
|
||||||
|
|
||||||
|
def read_tif(fileName):
|
||||||
|
dataset = gdal.Open(fileName)
|
||||||
|
|
||||||
|
im_width = dataset.RasterXSize # 栅格矩阵的列数
|
||||||
|
im_height = dataset.RasterYSize # 栅格矩阵的行数
|
||||||
|
im_bands = dataset.RasterCount # 波段数
|
||||||
|
im_data = dataset.ReadAsArray().astype(np.float32) # 获取数据
|
||||||
|
if len(im_data.shape) == 2:
|
||||||
|
im_data = im_data[np.newaxis, :]
|
||||||
|
im_geotrans = dataset.GetGeoTransform() # 获取仿射矩阵信息
|
||||||
|
im_proj = dataset.GetProjection() # 获取投影信息
|
||||||
|
|
||||||
|
return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
|
||||||
|
|
||||||
|
|
||||||
|
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
|
||||||
|
if 'int8' in im_data.dtype.name:
|
||||||
|
datatype = gdal.GDT_Byte
|
||||||
|
elif 'int16' in im_data.dtype.name:
|
||||||
|
datatype = gdal.GDT_UInt16
|
||||||
|
else:
|
||||||
|
datatype = gdal.GDT_Float32
|
||||||
|
|
||||||
|
if len(im_data.shape) == 3:
|
||||||
|
im_bands, im_height, im_width = im_data.shape
|
||||||
|
else:
|
||||||
|
im_bands, (im_height, im_width) = 1, im_data.shape
|
||||||
|
# 创建文件
|
||||||
|
driver = gdal.GetDriverByName("GTiff")
|
||||||
|
dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
|
||||||
|
if dataset != None and im_geotrans != None and im_proj != None:
|
||||||
|
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
|
||||||
|
dataset.SetProjection(im_proj) # 写入投影
|
||||||
|
for i in range(im_bands):
|
||||||
|
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
|
||||||
|
del dataset
|
||||||
|
|
||||||
|
|
||||||
|
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
|
||||||
|
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
|
||||||
|
fileName)
|
||||||
|
|
||||||
|
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
|
||||||
|
im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
|
||||||
|
fileName2)
|
||||||
|
|
||||||
|
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
|
||||||
|
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
|
||||||
|
mask_fileName)
|
||||||
|
mask_im_data = np.int8(mask_im_data)
|
||||||
|
|
||||||
|
|
||||||
|
# geotiff归一化
|
||||||
|
for i in range(im_bands):
|
||||||
|
arr = im_data[i, :, :]
|
||||||
|
Min = arr.min()
|
||||||
|
Max = arr.max()
|
||||||
|
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||||
|
im_data[i] = normalized_arr
|
||||||
|
|
||||||
|
for i in range(im_bands2):
|
||||||
|
arr = im_data2[i, :, :]
|
||||||
|
Min = arr.min()
|
||||||
|
Max = arr.max()
|
||||||
|
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||||
|
im_data2[i] = normalized_arr
|
||||||
|
|
||||||
|
|
||||||
|
# 计算大图每个波段的均值和方差,train.py里transform会用到
|
||||||
|
im_data = im_data/255
|
||||||
|
for i in range(im_bands):
|
||||||
|
pixels = im_data[i, :, :].ravel()
|
||||||
|
print("波段{} mean: {:.4f}, std: {:.4f}".format(
|
||||||
|
i, np.mean(pixels), np.std(pixels)))
|
||||||
|
im_data = im_data*255
|
||||||
|
|
||||||
|
im_data2 = im_data2/255
|
||||||
|
for i in range(im_bands2):
|
||||||
|
pixels = im_data2[i, :, :].ravel()
|
||||||
|
print("波段{} mean: {:.4f}, std: {:.4f}".format(
|
||||||
|
i, np.mean(pixels), np.std(pixels)))
|
||||||
|
im_data2 = im_data2*255
|
||||||
|
|
||||||
|
|
||||||
|
# 切成小图
|
||||||
|
a = 0
|
||||||
|
size = 224
|
||||||
|
for i in range(0, int(mask_im_height / size)):
|
||||||
|
for j in range(0, int(mask_im_width / size)):
|
||||||
|
im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size]
|
||||||
|
im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size]
|
||||||
|
mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size]
|
||||||
|
|
||||||
|
# 以mask为判断基准,同时处理geotiff和mask
|
||||||
|
labelfla = np.array(mask_cut).flatten()
|
||||||
|
if np.all(labelfla == 15): # 15为NoData
|
||||||
|
print("Skip!!!")
|
||||||
|
else:
|
||||||
|
# 5m
|
||||||
|
left_h = i * size * im_geotrans[5] + im_geotrans[3]
|
||||||
|
left_w = j * size * im_geotrans[1] + im_geotrans[0]
|
||||||
|
new_geotrans = np.array(im_geotrans)
|
||||||
|
new_geotrans[0] = left_w
|
||||||
|
new_geotrans[3] = left_h
|
||||||
|
out_geotrans = tuple(new_geotrans)
|
||||||
|
|
||||||
|
im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif'
|
||||||
|
write_tif(im_cut, size, size, im_out, out_geotrans, im_proj)
|
||||||
|
|
||||||
|
# dan
|
||||||
|
left_h = i * size * im_geotrans2[5] + im_geotrans2[3]
|
||||||
|
left_w = j * size * im_geotrans2[1] + im_geotrans2[0]
|
||||||
|
new_geotrans = np.array(im_geotrans2)
|
||||||
|
new_geotrans[0] = left_w
|
||||||
|
new_geotrans[3] = left_h
|
||||||
|
out_geotrans = tuple(new_geotrans)
|
||||||
|
|
||||||
|
im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif'
|
||||||
|
write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2)
|
||||||
|
|
||||||
|
# 存mask
|
||||||
|
mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3]
|
||||||
|
mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0]
|
||||||
|
mask_new_geotrans = np.array(mask_im_geotrans)
|
||||||
|
mask_new_geotrans[0] = mask_left_w
|
||||||
|
mask_new_geotrans[3] = mask_left_h
|
||||||
|
mask_out_geotrans = tuple(mask_new_geotrans)
|
||||||
|
mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif'
|
||||||
|
write_tif(mask_cut, size, size, mask_out,
|
||||||
|
mask_out_geotrans, mask_im_proj)
|
||||||
|
|
||||||
|
print(mask_out + 'Cut to complete')
|
||||||
|
|
||||||
|
a = a+1
|
@ -44,9 +44,6 @@ fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
|
|||||||
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
|
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
|
||||||
fileName)
|
fileName)
|
||||||
|
|
||||||
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
|
|
||||||
im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
|
|
||||||
fileName2)
|
|
||||||
|
|
||||||
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
|
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
|
||||||
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
|
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
|
||||||
@ -55,19 +52,14 @@ mask_im_data = np.int8(mask_im_data)
|
|||||||
|
|
||||||
|
|
||||||
# geotiff归一化
|
# geotiff归一化
|
||||||
|
lower_percent = 2
|
||||||
|
upper_percent = 98
|
||||||
for i in range(im_bands):
|
for i in range(im_bands):
|
||||||
arr = im_data[i, :, :]
|
arr = im_data[i, :, :]
|
||||||
Min = arr.min()
|
lower = np.percentile(arr, lower_percent)
|
||||||
Max = arr.max()
|
upper = np.percentile(arr, upper_percent)
|
||||||
normalized_arr = (arr-Min)/(Max-Min)*255
|
stretched = np.clip((arr - lower) / (upper - lower), 0, 1)
|
||||||
im_data[i] = normalized_arr
|
im_data[i] = (stretched * 255).astype(np.uint8)
|
||||||
|
|
||||||
for i in range(im_bands2):
|
|
||||||
arr = im_data2[i, :, :]
|
|
||||||
Min = arr.min()
|
|
||||||
Max = arr.max()
|
|
||||||
normalized_arr = (arr-Min)/(Max-Min)*255
|
|
||||||
im_data2[i] = normalized_arr
|
|
||||||
|
|
||||||
|
|
||||||
# 计算大图每个波段的均值和方差,train.py里transform会用到
|
# 计算大图每个波段的均值和方差,train.py里transform会用到
|
||||||
@ -78,61 +70,42 @@ for i in range(im_bands):
|
|||||||
i, np.mean(pixels), np.std(pixels)))
|
i, np.mean(pixels), np.std(pixels)))
|
||||||
im_data = im_data*255
|
im_data = im_data*255
|
||||||
|
|
||||||
im_data2 = im_data2/255
|
|
||||||
for i in range(im_bands2):
|
|
||||||
pixels = im_data2[i, :, :].ravel()
|
|
||||||
print("波段{} mean: {:.4f}, std: {:.4f}".format(
|
|
||||||
i, np.mean(pixels), np.std(pixels)))
|
|
||||||
im_data2 = im_data2*255
|
|
||||||
|
|
||||||
|
|
||||||
# 切成小图
|
# 切成小图
|
||||||
a = 0
|
a = 0
|
||||||
size = 224
|
size = 448
|
||||||
for i in range(0, int(mask_im_height / size)):
|
for i in range(0, int(mask_im_height / size)):
|
||||||
for j in range(0, int(mask_im_width / size)):
|
for j in range(0, int(mask_im_width / size)):
|
||||||
im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size]
|
im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size]
|
||||||
im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size]
|
mask_cut = mask_im_data[:, i * size:i *
|
||||||
mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size]
|
size + size, j * size:j * size + size]
|
||||||
|
|
||||||
# 以mask为判断基准,同时处理geotiff和mask
|
|
||||||
labelfla = np.array(mask_cut).flatten()
|
labelfla = np.array(mask_cut).flatten()
|
||||||
if np.all(labelfla == 15): # 15为NoData
|
if np.all(labelfla == 15): # 15为NoData
|
||||||
print("Skip!!!")
|
print("Skip!!!")
|
||||||
else:
|
else:
|
||||||
# 5m
|
# 取5、4、3波段,注意顺序
|
||||||
left_h = i * size * im_geotrans[5] + im_geotrans[3]
|
rgb_cut = np.stack([
|
||||||
left_w = j * size * im_geotrans[1] + im_geotrans[0]
|
im_cut[4], # 第5波段
|
||||||
new_geotrans = np.array(im_geotrans)
|
im_cut[3], # 第4波段
|
||||||
new_geotrans[0] = left_w
|
im_cut[2], # 第3波段
|
||||||
new_geotrans[3] = left_h
|
], axis=0) # shape: (3, H, W)
|
||||||
out_geotrans = tuple(new_geotrans)
|
# 转为 (H, W, 3)
|
||||||
|
rgb_cut = np.transpose(rgb_cut, (1, 2, 0))
|
||||||
|
# 归一化到0-255并转uint8
|
||||||
|
rgb_cut = np.clip(rgb_cut, 0, 255)
|
||||||
|
rgb_cut = rgb_cut.astype(np.uint8)
|
||||||
|
# 保存为jpg
|
||||||
|
from PIL import Image
|
||||||
|
rgb_img = Image.fromarray(rgb_cut)
|
||||||
|
rgb_img.save(
|
||||||
|
f'E:/RSdata/wlk_right_448/dataset_5m_jpg/img_{a}.jpg')
|
||||||
|
|
||||||
im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif'
|
# mask只取第一个波段(如果是单通道),转uint8保存为png
|
||||||
write_tif(im_cut, size, size, im_out, out_geotrans, im_proj)
|
mask_arr = mask_cut[0] if mask_cut.shape[0] == 1 else mask_cut
|
||||||
|
mask_arr = np.clip(mask_arr, 0, 255).astype(np.uint8)
|
||||||
# dan
|
mask_img = Image.fromarray(mask_arr)
|
||||||
left_h = i * size * im_geotrans2[5] + im_geotrans2[3]
|
mask_img.save(f'E:/RSdata/wlk_right_448/mask_png/mask_{a}.png')
|
||||||
left_w = j * size * im_geotrans2[1] + im_geotrans2[0]
|
|
||||||
new_geotrans = np.array(im_geotrans2)
|
|
||||||
new_geotrans[0] = left_w
|
|
||||||
new_geotrans[3] = left_h
|
|
||||||
out_geotrans = tuple(new_geotrans)
|
|
||||||
|
|
||||||
im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif'
|
|
||||||
write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2)
|
|
||||||
|
|
||||||
# 存mask
|
|
||||||
mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3]
|
|
||||||
mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0]
|
|
||||||
mask_new_geotrans = np.array(mask_im_geotrans)
|
|
||||||
mask_new_geotrans[0] = mask_left_w
|
|
||||||
mask_new_geotrans[3] = mask_left_h
|
|
||||||
mask_out_geotrans = tuple(mask_new_geotrans)
|
|
||||||
mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif'
|
|
||||||
write_tif(mask_cut, size, size, mask_out,
|
|
||||||
mask_out_geotrans, mask_im_proj)
|
|
||||||
|
|
||||||
print(mask_out + 'Cut to complete')
|
|
||||||
|
|
||||||
|
print(f'img_{a}.jpg and mask_{a}.png saved')
|
||||||
a = a+1
|
a = a+1
|
||||||
|
@ -4,18 +4,18 @@ import random
|
|||||||
|
|
||||||
random.seed(42)
|
random.seed(42)
|
||||||
|
|
||||||
geotiffs = os.listdir('E:\RSdata\wlk_right_224_2\dataset_5m')
|
geotiffs = os.listdir(r'E:\RSdata\wlk_right_448\dataset_5m_jpg')
|
||||||
num = len(geotiffs)
|
num = len(geotiffs)
|
||||||
split_rate = 0.2
|
split_rate = 0.2
|
||||||
|
|
||||||
eval_index = random.sample(geotiffs, k=int(num*split_rate))
|
eval_index = random.sample(geotiffs, k=int(num*split_rate))
|
||||||
|
|
||||||
f_train = open('E:\RSdata\wlk_right_224_2/train.txt', 'w')
|
f_train = open(r'E:\RSdata\wlk_right_448/train.txt', 'w')
|
||||||
f_val = open('E:\RSdata\wlk_right_224_2/val.txt', 'w')
|
f_val = open(r'E:\RSdata\wlk_right_448/val.txt', 'w')
|
||||||
|
|
||||||
# 写入文件
|
# 写入文件
|
||||||
for geotiff in geotiffs:
|
for geotiff in geotiffs:
|
||||||
if geotiff in eval_index:
|
if geotiff in eval_index:
|
||||||
f_val.write(str(geotiff)+'\n')
|
f_train.write(str(geotiff)+'\n')
|
||||||
else:
|
else:
|
||||||
f_train.write(str(geotiff)+'\n')
|
f_val.write(str(geotiff)+'\n')
|
||||||
|
62
datasets_pro_code/datasets_pro.py
Normal file
62
datasets_pro_code/datasets_pro.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# 定义文件夹路径
|
||||||
|
base_dir = r"e:\datasets\wlk_right_448"
|
||||||
|
jpeg_images_dir = os.path.join(base_dir, "JPEGImages")
|
||||||
|
segmentation_class_dir = os.path.join(base_dir, "SegmentationClass")
|
||||||
|
annotations_dir = os.path.join(base_dir, "stare_stuct", "annotations")
|
||||||
|
images_dir = os.path.join(base_dir, "stare_stuct", "images")
|
||||||
|
|
||||||
|
# 定义目标文件夹
|
||||||
|
annotations_training_dir = os.path.join(annotations_dir, "training")
|
||||||
|
annotations_validation_dir = os.path.join(annotations_dir, "validation")
|
||||||
|
images_training_dir = os.path.join(images_dir, "training")
|
||||||
|
images_validation_dir = os.path.join(images_dir, "validation")
|
||||||
|
|
||||||
|
# 创建目标文件夹
|
||||||
|
os.makedirs(annotations_training_dir, exist_ok=True)
|
||||||
|
os.makedirs(annotations_validation_dir, exist_ok=True)
|
||||||
|
os.makedirs(images_training_dir, exist_ok=True)
|
||||||
|
os.makedirs(images_validation_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 读取 train.txt 和 val.txt
|
||||||
|
train_file = os.path.join(base_dir, "train.txt")
|
||||||
|
val_file = os.path.join(base_dir, "val.txt")
|
||||||
|
|
||||||
|
|
||||||
|
def read_file_list(file_path):
|
||||||
|
with open(file_path, "r") as f:
|
||||||
|
return [line.strip() for line in f.readlines()]
|
||||||
|
|
||||||
|
|
||||||
|
train_list = read_file_list(train_file)
|
||||||
|
val_list = read_file_list(val_file)
|
||||||
|
|
||||||
|
# 移动文件函数
|
||||||
|
|
||||||
|
|
||||||
|
def move_files(file_list, src_images_dir, src_labels_dir, dst_images_dir, dst_labels_dir):
|
||||||
|
for file_name in file_list:
|
||||||
|
# 图片文件
|
||||||
|
image_src = os.path.join(src_images_dir, file_name)
|
||||||
|
image_dst = os.path.join(dst_images_dir, file_name)
|
||||||
|
if os.path.exists(image_src):
|
||||||
|
shutil.copy(image_src, image_dst)
|
||||||
|
|
||||||
|
# 标签文件
|
||||||
|
label_src = os.path.join(src_labels_dir, file_name)
|
||||||
|
label_dst = os.path.join(dst_labels_dir, file_name)
|
||||||
|
if os.path.exists(label_src):
|
||||||
|
shutil.copy(label_src, label_dst)
|
||||||
|
|
||||||
|
|
||||||
|
# 移动训练集文件
|
||||||
|
move_files(train_list, jpeg_images_dir, segmentation_class_dir,
|
||||||
|
images_training_dir, annotations_training_dir)
|
||||||
|
|
||||||
|
# 移动验证集文件
|
||||||
|
move_files(val_list, jpeg_images_dir, segmentation_class_dir,
|
||||||
|
images_validation_dir, annotations_validation_dir)
|
||||||
|
|
||||||
|
print("文件组织完成!")
|
22
datasets_pro_code/read_mask.py
Normal file
22
datasets_pro_code/read_mask.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from PIL import Image
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
from osgeo import gdal
|
||||||
|
|
||||||
|
|
||||||
|
mask_dir = r"E:\datasets\wlk_right_448\mask" # 修改为你的mask文件夹路径
|
||||||
|
|
||||||
|
all_labels = set()
|
||||||
|
|
||||||
|
for file in os.listdir(mask_dir):
|
||||||
|
if file.lower().endswith('.tif'):
|
||||||
|
tif_path = os.path.join(mask_dir, file)
|
||||||
|
dataset = gdal.Open(tif_path)
|
||||||
|
if dataset is None:
|
||||||
|
print(f"无法打开: {tif_path}")
|
||||||
|
continue
|
||||||
|
band = dataset.ReadAsArray()
|
||||||
|
unique = np.unique(band)
|
||||||
|
all_labels.update(unique)
|
||||||
|
|
||||||
|
print("所有mask中出现过的标签数字:", sorted(all_labels))
|
6
datasets_pro_code/remove_tif.py
Normal file
6
datasets_pro_code/remove_tif.py
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
input_path = r"E:\datasets\WLKdata_1111\WLK_voc\ImageSets\Segmentation\val.txt"
|
||||||
|
output_path = r"E:\datasets\WLKdata_1111\WLK_voc\ImageSets\Segmentation\val_no.txt"
|
||||||
|
|
||||||
|
with open(input_path, "r", encoding="utf-8") as fin, open(output_path, "w", encoding="utf-8") as fout:
|
||||||
|
for line in fin:
|
||||||
|
fout.write(line.strip().replace(".tif", "") + "\n")
|
30
datasets_pro_code/tif_mask_to_png.py
Normal file
30
datasets_pro_code/tif_mask_to_png.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from osgeo import gdal
|
||||||
|
|
||||||
|
src_dir = r"E:\datasets\WLKdata_1111\WLKdataset\masks_LS"
|
||||||
|
dst_dir = r"E:\datasets\WLKdata_1111\WLKdataset\masks_LS_png"
|
||||||
|
os.makedirs(dst_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for file in os.listdir(src_dir):
|
||||||
|
if file.lower().endswith('.tif'):
|
||||||
|
tif_path = os.path.join(src_dir, file)
|
||||||
|
dataset = gdal.Open(tif_path)
|
||||||
|
if dataset is None:
|
||||||
|
print(f"无法打开: {tif_path}")
|
||||||
|
continue
|
||||||
|
mask = dataset.ReadAsArray()
|
||||||
|
if mask.ndim != 2:
|
||||||
|
print(f"{file} 不是单波段,跳过")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 替换像素值
|
||||||
|
mask = mask.copy()
|
||||||
|
mask[mask == 15] = 255
|
||||||
|
|
||||||
|
png_path = os.path.join(dst_dir, os.path.splitext(file)[0] + ".png")
|
||||||
|
Image.fromarray(mask.astype(np.uint8)).save(png_path)
|
||||||
|
print(f"已保存: {png_path}")
|
||||||
|
|
||||||
|
print("全部转换完成!")
|
33
datasets_pro_code/tif_to_jpg.py
Normal file
33
datasets_pro_code/tif_to_jpg.py
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from osgeo import gdal
|
||||||
|
|
||||||
|
# 输入和输出文件夹
|
||||||
|
src_dir = r"E:\datasets\WLKdata_1111\WLKdataset\images_LS"
|
||||||
|
dst_dir = r"E:\datasets\WLKdata_1111\WLKdataset\images_LS_jpg"
|
||||||
|
os.makedirs(dst_dir, exist_ok=True)
|
||||||
|
|
||||||
|
for file in os.listdir(src_dir):
|
||||||
|
if file.lower().endswith('.tif'):
|
||||||
|
tif_path = os.path.join(src_dir, file)
|
||||||
|
dataset = gdal.Open(tif_path)
|
||||||
|
if dataset is None:
|
||||||
|
print(f"无法打开: {tif_path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 读取所有波段
|
||||||
|
bands = []
|
||||||
|
for i in range(1, dataset.RasterCount + 1):
|
||||||
|
band = dataset.GetRasterBand(i).ReadAsArray()
|
||||||
|
bands.append(band)
|
||||||
|
img = np.stack(bands, axis=-1) if len(bands) > 1 else bands[0]
|
||||||
|
|
||||||
|
# 转换为uint8
|
||||||
|
img = img.astype(np.uint8)
|
||||||
|
|
||||||
|
jpg_path = os.path.join(dst_dir, os.path.splitext(file)[0] + ".jpg")
|
||||||
|
Image.fromarray(img).save(jpg_path, quality=95)
|
||||||
|
print(f"已保存: {jpg_path}")
|
||||||
|
|
||||||
|
print("全部转换完成!")
|
274
train_JL/geotiff_utils.py
Normal file
274
train_JL/geotiff_utils.py
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
|
||||||
|
"""Pascal VOC Semantic Segmentation Dataset."""
|
||||||
|
from PIL import Image, ImageOps, ImageFilter
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
# import gdal
|
||||||
|
from osgeo import gdal
|
||||||
|
import random
|
||||||
|
import torch.utils.data as data
|
||||||
|
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationDataset(object):
|
||||||
|
"""Segmentation Base Dataset"""
|
||||||
|
|
||||||
|
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
|
||||||
|
super(SegmentationDataset, self).__init__()
|
||||||
|
self.root = root
|
||||||
|
self.transform = transform
|
||||||
|
self.split = split
|
||||||
|
self.mode = mode if mode is not None else split
|
||||||
|
self.base_size = base_size
|
||||||
|
self.crop_size = crop_size
|
||||||
|
|
||||||
|
def _val_sync_transform(self, img, mask):
|
||||||
|
outsize = self.crop_size
|
||||||
|
short_size = outsize
|
||||||
|
w, h = img.size
|
||||||
|
if w > h:
|
||||||
|
oh = short_size
|
||||||
|
ow = int(1.0 * w * oh / h)
|
||||||
|
else:
|
||||||
|
ow = short_size
|
||||||
|
oh = int(1.0 * h * ow / w)
|
||||||
|
img = img.resize((ow, oh), Image.BILINEAR)
|
||||||
|
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||||
|
# center crop
|
||||||
|
w, h = img.size
|
||||||
|
x1 = int(round((w - outsize) / 2.))
|
||||||
|
y1 = int(round((h - outsize) / 2.))
|
||||||
|
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||||
|
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
if random.random() < 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
crop_size = self.crop_size
|
||||||
|
# random scale (short edge)
|
||||||
|
short_size = random.randint(
|
||||||
|
int(self.base_size * 0.5), int(self.base_size * 2.0))
|
||||||
|
w, h = img.size
|
||||||
|
if h > w:
|
||||||
|
ow = short_size
|
||||||
|
oh = int(1.0 * h * ow / w)
|
||||||
|
else:
|
||||||
|
oh = short_size
|
||||||
|
ow = int(1.0 * w * oh / h)
|
||||||
|
img = img.resize((ow, oh), Image.BILINEAR)
|
||||||
|
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||||
|
# pad crop
|
||||||
|
if short_size < crop_size:
|
||||||
|
padh = crop_size - oh if oh < crop_size else 0
|
||||||
|
padw = crop_size - ow if ow < crop_size else 0
|
||||||
|
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
|
||||||
|
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
|
||||||
|
# random crop crop_size
|
||||||
|
w, h = img.size
|
||||||
|
x1 = random.randint(0, w - crop_size)
|
||||||
|
y1 = random.randint(0, h - crop_size)
|
||||||
|
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||||
|
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||||
|
# gaussian blur as in PSP
|
||||||
|
if random.random() < 0.5:
|
||||||
|
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform_tif(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform_tif_geofeat(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _val_sync_transform_tif(self, img, mask):
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _img_transform(self, img):
|
||||||
|
return np.array(img)
|
||||||
|
|
||||||
|
# def _mask_transform(self, mask):
|
||||||
|
# return np.array(mask).astype('int32')
|
||||||
|
|
||||||
|
def _mask_transform(self, mask):
|
||||||
|
target = np.array(mask).astype('int32')
|
||||||
|
# target = target[np.newaxis, :]
|
||||||
|
target[target > 12] = 255
|
||||||
|
return torch.from_numpy(target).long()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_class(self):
|
||||||
|
"""Number of categories."""
|
||||||
|
return self.NUM_CLASS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pred_offset(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
class VOCYJSSegmentation(SegmentationDataset):
|
||||||
|
"""Pascal VOC Semantic Segmentation Dataset.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
root : string
|
||||||
|
Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
|
||||||
|
split: string
|
||||||
|
'train', 'val' or 'test'
|
||||||
|
transform : callable, optional
|
||||||
|
A function that transforms the image
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> from torchvision import transforms
|
||||||
|
>>> import torch.utils.data as data
|
||||||
|
>>> # Transforms for Normalization
|
||||||
|
>>> input_transform = transforms.Compose([
|
||||||
|
>>> transforms.ToTensor(),
|
||||||
|
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
|
||||||
|
>>> ])
|
||||||
|
>>> # Create Dataset
|
||||||
|
>>> trainset = VOCSegmentation(split='train', transform=input_transform)
|
||||||
|
>>> # Create Training Loader
|
||||||
|
>>> train_data = data.DataLoader(
|
||||||
|
>>> trainset, 4, shuffle=True,
|
||||||
|
>>> num_workers=4)
|
||||||
|
"""
|
||||||
|
NUM_CLASS = 13
|
||||||
|
|
||||||
|
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
|
||||||
|
super(VOCYJSSegmentation, self).__init__(
|
||||||
|
root, split, mode, transform, **kwargs)
|
||||||
|
_voc_root = root
|
||||||
|
txt_path = os.path.join(root, split+'.txt')
|
||||||
|
self._mask_LS_dir = os.path.join(_voc_root, 'mask')
|
||||||
|
self._image_LS_dir = os.path.join(_voc_root, "dataset_5m")
|
||||||
|
self.image_list = read_text(txt_path)
|
||||||
|
random.shuffle(self.image_list)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
img_LS = gdal.Open(os.path.join(self._image_LS_dir, self.image_list[index])).ReadAsArray(
|
||||||
|
).transpose(1, 2, 0).astype(np.float32)
|
||||||
|
mask = gdal.Open(os.path.join(self._mask_LS_dir,
|
||||||
|
self.image_list[index])).ReadAsArray()
|
||||||
|
# synchronized transform
|
||||||
|
# 只包含两种模式: train 和 val
|
||||||
|
if self.mode == 'train':
|
||||||
|
img_LS, mask = self._sync_transform_tif_geofeat(
|
||||||
|
img_LS, mask)
|
||||||
|
elif self.mode == 'val':
|
||||||
|
img_LS, mask = self._sync_transform_tif_geofeat(
|
||||||
|
img_LS, mask)
|
||||||
|
# general resize, normalize and toTensor
|
||||||
|
if self.transform is not None:
|
||||||
|
img_LS = self.transform(img_LS)
|
||||||
|
return img_LS, mask
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_list)
|
||||||
|
|
||||||
|
def _mask_transform(self, mask):
|
||||||
|
target = np.array(mask).astype('int32')
|
||||||
|
# target = target[np.newaxis, :]
|
||||||
|
target[target > 12] = 255
|
||||||
|
return torch.from_numpy(target).long()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classes(self):
|
||||||
|
"""Category names."""
|
||||||
|
return ('0', '1', '2', '3', '4', '5', '6')
|
||||||
|
|
||||||
|
|
||||||
|
def generator_list_of_imagepath(path):
|
||||||
|
image_list = []
|
||||||
|
for image in os.listdir(path):
|
||||||
|
# print(path)
|
||||||
|
# print(image)
|
||||||
|
if not image == '.DS_Store' and 'tif' == image.split('.')[-1]:
|
||||||
|
image_list.append(image)
|
||||||
|
return image_list
|
||||||
|
|
||||||
|
|
||||||
|
def read_text(textfile):
|
||||||
|
list = []
|
||||||
|
with open(textfile, "r") as lines:
|
||||||
|
for line in lines:
|
||||||
|
list.append(line.rstrip('\n'))
|
||||||
|
return list
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_segmentation(textpath, imagepath, train_percent):
|
||||||
|
image_list = generator_list_of_imagepath(imagepath)
|
||||||
|
num = len(image_list)
|
||||||
|
list = range(num)
|
||||||
|
train_num = int(num * train_percent) # training set num
|
||||||
|
train_list = random.sample(list, train_num)
|
||||||
|
print("train set size", train_num)
|
||||||
|
ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
|
||||||
|
fval = open(os.path.join(textpath, 'val.txt'), 'w')
|
||||||
|
for i in list:
|
||||||
|
name = image_list[i] + '\n'
|
||||||
|
if i in train_list:
|
||||||
|
ftrain.write(name)
|
||||||
|
else:
|
||||||
|
fval.write(name)
|
||||||
|
ftrain.close()
|
||||||
|
fval.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
|
||||||
|
# list=generator_list_of_imagepath(path)
|
||||||
|
# print(list)
|
||||||
|
# 切割数据集
|
||||||
|
|
||||||
|
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
|
||||||
|
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
|
||||||
|
train_percent = 0.8
|
||||||
|
dataset_segmentation(textpath, imagepath, train_percent)
|
||||||
|
# 显示各种图片
|
||||||
|
|
||||||
|
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
|
||||||
|
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
|
||||||
|
# cv2.imshow('img', img)
|
||||||
|
# img = Image.fromarray (img,'RGB')
|
||||||
|
# img.show()
|
||||||
|
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
|
||||||
|
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
|
||||||
|
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
|
||||||
|
# img2 = Image.fromarray (img2,'RGB')
|
||||||
|
# img2.show()
|
||||||
|
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
|
||||||
|
# img3 = gdal.Open(img3).ReadAsArray()
|
||||||
|
# img3 = Image.fromarray (img3)
|
||||||
|
# img3.show()
|
||||||
|
|
||||||
|
# dataset和dataloader的测试
|
||||||
|
|
||||||
|
# 测试dataloader能不能用
|
||||||
|
'''
|
||||||
|
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
|
||||||
|
input_transform = transforms.Compose(
|
||||||
|
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
|
||||||
|
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
|
||||||
|
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
|
||||||
|
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
|
||||||
|
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
|
||||||
|
'''
|
156
train_JL/train_JL.py
Normal file
156
train_JL/train_JL.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
|
||||||
|
from torchvision import transforms
|
||||||
|
import torch.utils.data as data
|
||||||
|
from torch import nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from geotiff_utils import VOCYJSSegmentation
|
||||||
|
import utils as utils
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--data-path", default=r"E:\datasets\wlk_right_448", help="VOCdevkit root")
|
||||||
|
parser.add_argument("--num-classes", default=7, type=int)
|
||||||
|
parser.add_argument("--device", default="cuda", help="training device")
|
||||||
|
parser.add_argument("--batch-size", default=4, type=int)
|
||||||
|
parser.add_argument("--epochs", default=50, type=int, metavar="N",
|
||||||
|
help="number of total epochs to train")
|
||||||
|
parser.add_argument('--lr', default=0.005, type=float,
|
||||||
|
help='initial learning rate')
|
||||||
|
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||||
|
help='momentum')
|
||||||
|
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
||||||
|
metavar='W', help='weight decay (default: 1e-4)',
|
||||||
|
dest='weight_decay')
|
||||||
|
parser.add_argument('-out-dir', type=str,
|
||||||
|
default='DeeplabV3_JL')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
class DeeplabV3_JL(nn.Module):
|
||||||
|
def __init__(self, n_class):
|
||||||
|
super(DeeplabV3_JL, self).__init__()
|
||||||
|
self.n_class = n_class
|
||||||
|
self.conv6_3 = nn.Conv2d(6, 3, kernel_size=1, stride=1)
|
||||||
|
|
||||||
|
self.conv_fc = nn.Conv2d(
|
||||||
|
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
|
||||||
|
|
||||||
|
self.seg = deeplabv3_resnet50(weights='DEFAULT')
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# x = torch.cat([x, x_dan], dim=1)
|
||||||
|
x = self.conv6_3(x)
|
||||||
|
x = self.seg(x)["out"]
|
||||||
|
x = self.conv_fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
input_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([.535, .767, .732, .561, .494, .564],
|
||||||
|
[.0132, .0188, .0181, .0173, .0183, .0259]),
|
||||||
|
])
|
||||||
|
data_kwargs = {'transform': input_transform,
|
||||||
|
'base_size': 448, 'crop_size': 448}
|
||||||
|
|
||||||
|
# 读取geotiff数据,构建训练集、验证集
|
||||||
|
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
|
||||||
|
**data_kwargs)
|
||||||
|
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
|
||||||
|
**data_kwargs)
|
||||||
|
num_workers = min(
|
||||||
|
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
|
||||||
|
train_loader = data.DataLoader(
|
||||||
|
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||||
|
val_loader = data.DataLoader(
|
||||||
|
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||||
|
|
||||||
|
model = DeeplabV3_JL(n_class=args.num_classes)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
|
||||||
|
optimizer = torch.optim.SGD(
|
||||||
|
model.parameters(),
|
||||||
|
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler = utils.create_lr_scheduler(
|
||||||
|
optimizer, len(train_loader), args.epochs, warmup=True)
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
date_time = now.strftime("%Y-%m-%d__%H-%M__")
|
||||||
|
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
|
||||||
|
if not out_dir.exists():
|
||||||
|
out_dir.mkdir()
|
||||||
|
f = open(os.path.join(out_dir, "log.txt"), 'w')
|
||||||
|
start_time = time.time()
|
||||||
|
best_acc = 0
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
print(f"Epoch {epoch+1}\n-------------------------------")
|
||||||
|
model.train()
|
||||||
|
for idx, (image, target) in enumerate(train_loader):
|
||||||
|
image, target = image.to(
|
||||||
|
device), target.to(device)
|
||||||
|
output = model(image)
|
||||||
|
|
||||||
|
loss = criterion(output, target)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
if idx % 100 == 0:
|
||||||
|
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
|
||||||
|
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
confmat = utils.ConfusionMatrix(args.num_classes)
|
||||||
|
with torch.no_grad():
|
||||||
|
for image, target in val_loader:
|
||||||
|
image, target = image.to(device), target.to(device)
|
||||||
|
output = model(image)
|
||||||
|
|
||||||
|
confmat.update(target.flatten(), output.argmax(1).flatten())
|
||||||
|
|
||||||
|
info, mIoU = confmat.get_info()
|
||||||
|
print(info)
|
||||||
|
|
||||||
|
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
|
||||||
|
f.write(info+"\n\n")
|
||||||
|
f.flush()
|
||||||
|
|
||||||
|
# # 保存准确率最好的模型
|
||||||
|
# if mIoU > best_acc:
|
||||||
|
# print("[Save model]")
|
||||||
|
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
|
||||||
|
# best_acc = mIoU
|
||||||
|
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
print("total time:", total_time)
|
||||||
|
torch.save(model, os.path.join(out_dir, "last.pth"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
main(args)
|
279
train_JL_jpg/geotiff_utils.py
Normal file
279
train_JL_jpg/geotiff_utils.py
Normal file
@ -0,0 +1,279 @@
|
|||||||
|
|
||||||
|
"""Pascal VOC Semantic Segmentation Dataset."""
|
||||||
|
from PIL import Image, ImageOps, ImageFilter
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
# import gdal
|
||||||
|
from osgeo import gdal
|
||||||
|
import random
|
||||||
|
import torch.utils.data as data
|
||||||
|
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationDataset(object):
|
||||||
|
"""Segmentation Base Dataset"""
|
||||||
|
|
||||||
|
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
|
||||||
|
super(SegmentationDataset, self).__init__()
|
||||||
|
self.root = root
|
||||||
|
self.transform = transform
|
||||||
|
self.split = split
|
||||||
|
self.mode = mode if mode is not None else split
|
||||||
|
self.base_size = base_size
|
||||||
|
self.crop_size = crop_size
|
||||||
|
|
||||||
|
def _val_sync_transform(self, img, mask):
|
||||||
|
outsize = self.crop_size
|
||||||
|
short_size = outsize
|
||||||
|
w, h = img.size
|
||||||
|
if w > h:
|
||||||
|
oh = short_size
|
||||||
|
ow = int(1.0 * w * oh / h)
|
||||||
|
else:
|
||||||
|
ow = short_size
|
||||||
|
oh = int(1.0 * h * ow / w)
|
||||||
|
img = img.resize((ow, oh), Image.BILINEAR)
|
||||||
|
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||||
|
# center crop
|
||||||
|
w, h = img.size
|
||||||
|
x1 = int(round((w - outsize) / 2.))
|
||||||
|
y1 = int(round((h - outsize) / 2.))
|
||||||
|
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||||
|
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
if random.random() < 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
crop_size = self.crop_size
|
||||||
|
# random scale (short edge)
|
||||||
|
short_size = random.randint(
|
||||||
|
int(self.base_size * 0.5), int(self.base_size * 2.0))
|
||||||
|
w, h = img.size
|
||||||
|
if h > w:
|
||||||
|
ow = short_size
|
||||||
|
oh = int(1.0 * h * ow / w)
|
||||||
|
else:
|
||||||
|
oh = short_size
|
||||||
|
ow = int(1.0 * w * oh / h)
|
||||||
|
img = img.resize((ow, oh), Image.BILINEAR)
|
||||||
|
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||||
|
# pad crop
|
||||||
|
if short_size < crop_size:
|
||||||
|
padh = crop_size - oh if oh < crop_size else 0
|
||||||
|
padw = crop_size - ow if ow < crop_size else 0
|
||||||
|
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
|
||||||
|
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
|
||||||
|
# random crop crop_size
|
||||||
|
w, h = img.size
|
||||||
|
x1 = random.randint(0, w - crop_size)
|
||||||
|
y1 = random.randint(0, h - crop_size)
|
||||||
|
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||||
|
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||||
|
# gaussian blur as in PSP
|
||||||
|
if random.random() < 0.5:
|
||||||
|
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform_tif(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform_tif_geofeat(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _val_sync_transform_tif(self, img, mask):
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _img_transform(self, img):
|
||||||
|
return np.array(img)
|
||||||
|
|
||||||
|
# def _mask_transform(self, mask):
|
||||||
|
# return np.array(mask).astype('int32')
|
||||||
|
|
||||||
|
def _mask_transform(self, mask):
|
||||||
|
target = np.array(mask).astype('int32')
|
||||||
|
# target = target[np.newaxis, :]
|
||||||
|
target[target > 12] = 255
|
||||||
|
return torch.from_numpy(target).long()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_class(self):
|
||||||
|
"""Number of categories."""
|
||||||
|
return self.NUM_CLASS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pred_offset(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
class VOCYJSSegmentation(SegmentationDataset):
|
||||||
|
"""Pascal VOC Semantic Segmentation Dataset.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
root : string
|
||||||
|
Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
|
||||||
|
split: string
|
||||||
|
'train', 'val' or 'test'
|
||||||
|
transform : callable, optional
|
||||||
|
A function that transforms the image
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> from torchvision import transforms
|
||||||
|
>>> import torch.utils.data as data
|
||||||
|
>>> # Transforms for Normalization
|
||||||
|
>>> input_transform = transforms.Compose([
|
||||||
|
>>> transforms.ToTensor(),
|
||||||
|
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
|
||||||
|
>>> ])
|
||||||
|
>>> # Create Dataset
|
||||||
|
>>> trainset = VOCSegmentation(split='train', transform=input_transform)
|
||||||
|
>>> # Create Training Loader
|
||||||
|
>>> train_data = data.DataLoader(
|
||||||
|
>>> trainset, 4, shuffle=True,
|
||||||
|
>>> num_workers=4)
|
||||||
|
"""
|
||||||
|
NUM_CLASS = 13
|
||||||
|
|
||||||
|
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
|
||||||
|
super(VOCYJSSegmentation, self).__init__(
|
||||||
|
root, split, mode, transform, **kwargs)
|
||||||
|
_voc_root = root
|
||||||
|
txt_path = os.path.join(root, split+'.txt')
|
||||||
|
self._mask_LS_dir = os.path.join(_voc_root, 'mask_png')
|
||||||
|
self._image_LS_dir = os.path.join(_voc_root, "dataset_5m_jpg")
|
||||||
|
self.image_list = read_text(txt_path)
|
||||||
|
random.shuffle(self.image_list)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
img_name = self.image_list[index].split('.')[0]+'.jpg'
|
||||||
|
mask_name = self.image_list[index].split('.')[0]+'.png'
|
||||||
|
mask_name = mask_name.replace('img', 'mask')
|
||||||
|
img_LS = np.array(Image.open(os.path.join(
|
||||||
|
self._image_LS_dir, img_name))).astype(np.float32)
|
||||||
|
mask = np.array(Image.open(os.path.join(
|
||||||
|
self._mask_LS_dir, mask_name))).astype(np.int32)
|
||||||
|
mask = torch.from_numpy(mask).long()
|
||||||
|
|
||||||
|
# synchronized transform
|
||||||
|
# 只包含两种模式: train 和 val
|
||||||
|
if self.mode == 'train':
|
||||||
|
img_LS, mask = self._sync_transform_tif_geofeat(
|
||||||
|
img_LS, mask)
|
||||||
|
elif self.mode == 'val':
|
||||||
|
img_LS, mask = self._sync_transform_tif_geofeat(
|
||||||
|
img_LS, mask)
|
||||||
|
# general resize, normalize and toTensor
|
||||||
|
if self.transform is not None:
|
||||||
|
img_LS = self.transform(img_LS)
|
||||||
|
return img_LS, mask
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_list)
|
||||||
|
|
||||||
|
def _mask_transform(self, mask):
|
||||||
|
target = np.array(mask).astype('int32')
|
||||||
|
# target = target[np.newaxis, :]
|
||||||
|
target[target > 12] = 255
|
||||||
|
return torch.from_numpy(target).long()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classes(self):
|
||||||
|
"""Category names."""
|
||||||
|
return ('0', '1', '2', '3', '4', '5', '6')
|
||||||
|
|
||||||
|
|
||||||
|
def generator_list_of_imagepath(path):
|
||||||
|
image_list = []
|
||||||
|
for image in os.listdir(path):
|
||||||
|
# print(path)
|
||||||
|
# print(image)
|
||||||
|
if not image == '.DS_Store' and 'tif' == image.split('.')[-1]:
|
||||||
|
image_list.append(image)
|
||||||
|
return image_list
|
||||||
|
|
||||||
|
|
||||||
|
def read_text(textfile):
|
||||||
|
list = []
|
||||||
|
with open(textfile, "r") as lines:
|
||||||
|
for line in lines:
|
||||||
|
list.append(line.rstrip('\n'))
|
||||||
|
return list
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_segmentation(textpath, imagepath, train_percent):
|
||||||
|
image_list = generator_list_of_imagepath(imagepath)
|
||||||
|
num = len(image_list)
|
||||||
|
list = range(num)
|
||||||
|
train_num = int(num * train_percent) # training set num
|
||||||
|
train_list = random.sample(list, train_num)
|
||||||
|
print("train set size", train_num)
|
||||||
|
ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
|
||||||
|
fval = open(os.path.join(textpath, 'val.txt'), 'w')
|
||||||
|
for i in list:
|
||||||
|
name = image_list[i] + '\n'
|
||||||
|
if i in train_list:
|
||||||
|
ftrain.write(name)
|
||||||
|
else:
|
||||||
|
fval.write(name)
|
||||||
|
ftrain.close()
|
||||||
|
fval.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
|
||||||
|
# list=generator_list_of_imagepath(path)
|
||||||
|
# print(list)
|
||||||
|
# 切割数据集
|
||||||
|
|
||||||
|
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
|
||||||
|
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
|
||||||
|
train_percent = 0.8
|
||||||
|
dataset_segmentation(textpath, imagepath, train_percent)
|
||||||
|
# 显示各种图片
|
||||||
|
|
||||||
|
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
|
||||||
|
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
|
||||||
|
# cv2.imshow('img', img)
|
||||||
|
# img = Image.fromarray (img,'RGB')
|
||||||
|
# img.show()
|
||||||
|
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
|
||||||
|
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
|
||||||
|
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
|
||||||
|
# img2 = Image.fromarray (img2,'RGB')
|
||||||
|
# img2.show()
|
||||||
|
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
|
||||||
|
# img3 = gdal.Open(img3).ReadAsArray()
|
||||||
|
# img3 = Image.fromarray (img3)
|
||||||
|
# img3.show()
|
||||||
|
|
||||||
|
# dataset和dataloader的测试
|
||||||
|
|
||||||
|
# 测试dataloader能不能用
|
||||||
|
'''
|
||||||
|
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
|
||||||
|
input_transform = transforms.Compose(
|
||||||
|
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
|
||||||
|
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
|
||||||
|
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
|
||||||
|
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
|
||||||
|
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
|
||||||
|
'''
|
153
train_JL_jpg/train_JL.py
Normal file
153
train_JL_jpg/train_JL.py
Normal file
@ -0,0 +1,153 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
|
||||||
|
from torchvision import transforms
|
||||||
|
import torch.utils.data as data
|
||||||
|
from torch import nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from geotiff_utils import VOCYJSSegmentation
|
||||||
|
import utils as utils
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--data-path", default=r"E:\RSdata\wlk_right_448", help="VOCdevkit root")
|
||||||
|
parser.add_argument("--num-classes", default=7, type=int)
|
||||||
|
parser.add_argument("--device", default="cuda", help="training device")
|
||||||
|
parser.add_argument("--batch-size", default=4, type=int)
|
||||||
|
parser.add_argument("--epochs", default=50, type=int, metavar="N",
|
||||||
|
help="number of total epochs to train")
|
||||||
|
parser.add_argument('--lr', default=0.005, type=float,
|
||||||
|
help='initial learning rate')
|
||||||
|
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||||
|
help='momentum')
|
||||||
|
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
||||||
|
metavar='W', help='weight decay (default: 1e-4)',
|
||||||
|
dest='weight_decay')
|
||||||
|
parser.add_argument('-out-dir', type=str,
|
||||||
|
default='DeeplabV3_JL')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
class DeeplabV3_JL_3(nn.Module):
|
||||||
|
def __init__(self, n_class):
|
||||||
|
super(DeeplabV3_JL_3, self).__init__()
|
||||||
|
self.n_class = n_class
|
||||||
|
|
||||||
|
self.conv_fc = nn.Conv2d(
|
||||||
|
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
|
||||||
|
|
||||||
|
self.seg = deeplabv3_resnet50(weights='DEFAULT')
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.seg(x)["out"]
|
||||||
|
x = self.conv_fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
input_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([.485, .456, .406],
|
||||||
|
[.229, .224, .225]),
|
||||||
|
])
|
||||||
|
data_kwargs = {'transform': input_transform,
|
||||||
|
'base_size': 448, 'crop_size': 448}
|
||||||
|
|
||||||
|
# 读取geotiff数据,构建训练集、验证集
|
||||||
|
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
|
||||||
|
**data_kwargs)
|
||||||
|
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
|
||||||
|
**data_kwargs)
|
||||||
|
num_workers = min(
|
||||||
|
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
|
||||||
|
train_loader = data.DataLoader(
|
||||||
|
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||||
|
val_loader = data.DataLoader(
|
||||||
|
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||||
|
|
||||||
|
model = DeeplabV3_JL_3(n_class=args.num_classes)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
|
||||||
|
optimizer = torch.optim.SGD(
|
||||||
|
model.parameters(),
|
||||||
|
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler = utils.create_lr_scheduler(
|
||||||
|
optimizer, len(train_loader), args.epochs, warmup=True)
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
date_time = now.strftime("%Y-%m-%d__%H-%M__")
|
||||||
|
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
|
||||||
|
if not out_dir.exists():
|
||||||
|
out_dir.mkdir()
|
||||||
|
f = open(os.path.join(out_dir, "log.txt"), 'w')
|
||||||
|
start_time = time.time()
|
||||||
|
best_acc = 0
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
print(f"Epoch {epoch+1}\n-------------------------------")
|
||||||
|
model.train()
|
||||||
|
for idx, (image, target) in enumerate(train_loader):
|
||||||
|
image, target = image.to(
|
||||||
|
device), target.to(device)
|
||||||
|
output = model(image)
|
||||||
|
|
||||||
|
loss = criterion(output, target)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
if idx % 100 == 0:
|
||||||
|
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
|
||||||
|
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
confmat = utils.ConfusionMatrix(args.num_classes)
|
||||||
|
with torch.no_grad():
|
||||||
|
for image, target in val_loader:
|
||||||
|
image, target = image.to(device), target.to(device)
|
||||||
|
output = model(image)
|
||||||
|
|
||||||
|
confmat.update(target.flatten(), output.argmax(1).flatten())
|
||||||
|
|
||||||
|
info, mIoU = confmat.get_info()
|
||||||
|
print(info)
|
||||||
|
|
||||||
|
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
|
||||||
|
f.write(info+"\n\n")
|
||||||
|
f.flush()
|
||||||
|
|
||||||
|
# # 保存准确率最好的模型
|
||||||
|
# if mIoU > best_acc:
|
||||||
|
# print("[Save model]")
|
||||||
|
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
|
||||||
|
# best_acc = mIoU
|
||||||
|
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
print("total time:", total_time)
|
||||||
|
torch.save(model, os.path.join(out_dir, "last.pth"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
main(args)
|
330
train_JL_jpg/utils.py
Normal file
330
train_JL_jpg/utils.py
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
import datetime
|
||||||
|
import errno
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
class SmoothedValue:
|
||||||
|
"""Track a series of values and provide access to smoothed values over a
|
||||||
|
window or the global series average.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size=20, fmt=None):
|
||||||
|
if fmt is None:
|
||||||
|
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||||
|
self.deque = deque(maxlen=window_size)
|
||||||
|
self.total = 0.0
|
||||||
|
self.count = 0
|
||||||
|
self.fmt = fmt
|
||||||
|
|
||||||
|
def update(self, value, n=1):
|
||||||
|
self.deque.append(value)
|
||||||
|
self.count += n
|
||||||
|
self.total += value * n
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
"""
|
||||||
|
Warning: does not synchronize the deque!
|
||||||
|
"""
|
||||||
|
t = reduce_across_processes([self.count, self.total])
|
||||||
|
t = t.tolist()
|
||||||
|
self.count = int(t[0])
|
||||||
|
self.total = t[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def median(self):
|
||||||
|
d = torch.tensor(list(self.deque))
|
||||||
|
return d.median().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avg(self):
|
||||||
|
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||||
|
return d.mean().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_avg(self):
|
||||||
|
return self.total / self.count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self):
|
||||||
|
return max(self.deque)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
return self.deque[-1]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.fmt.format(
|
||||||
|
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrix:
|
||||||
|
def __init__(self, num_classes):
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.mat = None
|
||||||
|
|
||||||
|
def update(self, a, b):
|
||||||
|
n = self.num_classes
|
||||||
|
if self.mat is None:
|
||||||
|
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||||
|
with torch.inference_mode():
|
||||||
|
k = (a >= 0) & (a < n)
|
||||||
|
inds = n * a[k].to(torch.int64) + b[k]
|
||||||
|
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.mat.zero_()
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
h = self.mat.float()
|
||||||
|
acc_global = torch.diag(h).sum() / h.sum()
|
||||||
|
acc = torch.diag(h) / h.sum(1)
|
||||||
|
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
|
||||||
|
return acc_global, acc, iu
|
||||||
|
|
||||||
|
def reduce_from_all_processes(self):
|
||||||
|
reduce_across_processes(self.mat)
|
||||||
|
|
||||||
|
def get_info(self):
|
||||||
|
acc_global, acc, iu = self.compute()
|
||||||
|
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
|
||||||
|
acc_global.item() * 100,
|
||||||
|
[f"{i:.1f}" for i in (acc * 100).tolist()],
|
||||||
|
[f"{i:.1f}" for i in (iu * 100).tolist()],
|
||||||
|
iu.mean().item() * 100,
|
||||||
|
), iu.mean().item() * 100
|
||||||
|
|
||||||
|
|
||||||
|
class MetricLogger:
|
||||||
|
def __init__(self, delimiter="\t"):
|
||||||
|
self.meters = defaultdict(SmoothedValue)
|
||||||
|
self.delimiter = delimiter
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
v = v.item()
|
||||||
|
if not isinstance(v, (float, int)):
|
||||||
|
raise TypeError(
|
||||||
|
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
|
||||||
|
)
|
||||||
|
self.meters[k].update(v)
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr in self.meters:
|
||||||
|
return self.meters[attr]
|
||||||
|
if attr in self.__dict__:
|
||||||
|
return self.__dict__[attr]
|
||||||
|
raise AttributeError(
|
||||||
|
f"'{type(self).__name__}' object has no attribute '{attr}'")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
loss_str = []
|
||||||
|
for name, meter in self.meters.items():
|
||||||
|
loss_str.append(f"{name}: {str(meter)}")
|
||||||
|
return self.delimiter.join(loss_str)
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
for meter in self.meters.values():
|
||||||
|
meter.synchronize_between_processes()
|
||||||
|
|
||||||
|
def add_meter(self, name, meter):
|
||||||
|
self.meters[name] = meter
|
||||||
|
|
||||||
|
def log_every(self, iterable, print_freq, header=None):
|
||||||
|
i = 0
|
||||||
|
if not header:
|
||||||
|
header = ""
|
||||||
|
start_time = time.time()
|
||||||
|
end = time.time()
|
||||||
|
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[
|
||||||
|
header,
|
||||||
|
"[{0" + space_fmt + "}/{1}]",
|
||||||
|
"eta: {eta}",
|
||||||
|
"{meters}",
|
||||||
|
"time: {time}",
|
||||||
|
"data: {data}",
|
||||||
|
"max mem: {memory:.0f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
|
||||||
|
"{meters}", "time: {time}", "data: {data}"]
|
||||||
|
)
|
||||||
|
MB = 1024.0 * 1024.0
|
||||||
|
for obj in iterable:
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
yield obj
|
||||||
|
iter_time.update(time.time() - end)
|
||||||
|
if i % print_freq == 0:
|
||||||
|
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||||
|
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i,
|
||||||
|
len(iterable),
|
||||||
|
eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time),
|
||||||
|
data=str(data_time),
|
||||||
|
memory=torch.cuda.max_memory_allocated() / MB,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
i += 1
|
||||||
|
end = time.time()
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
print(f"{header} Total time: {total_time_str}")
|
||||||
|
|
||||||
|
|
||||||
|
def cat_list(images, fill_value=0):
|
||||||
|
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||||
|
batch_shape = (len(images),) + max_size
|
||||||
|
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
|
||||||
|
for img, pad_img in zip(images, batched_imgs):
|
||||||
|
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
|
||||||
|
return batched_imgs
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
images, targets = list(zip(*batch))
|
||||||
|
batched_imgs = cat_list(images, fill_value=0)
|
||||||
|
batched_targets = cat_list(targets, fill_value=255)
|
||||||
|
return batched_imgs, batched_targets
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir(path):
|
||||||
|
try:
|
||||||
|
os.makedirs(path)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno != errno.EEXIST:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def setup_for_distributed(is_master):
|
||||||
|
"""
|
||||||
|
This function disables printing when not in master process
|
||||||
|
"""
|
||||||
|
import builtins as __builtin__
|
||||||
|
|
||||||
|
builtin_print = __builtin__.print
|
||||||
|
|
||||||
|
def print(*args, **kwargs):
|
||||||
|
force = kwargs.pop("force", False)
|
||||||
|
if is_master or force:
|
||||||
|
builtin_print(*args, **kwargs)
|
||||||
|
|
||||||
|
__builtin__.print = print
|
||||||
|
|
||||||
|
|
||||||
|
def is_dist_avail_and_initialized():
|
||||||
|
if not dist.is_available():
|
||||||
|
return False
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def save_on_master(*args, **kwargs):
|
||||||
|
if is_main_process():
|
||||||
|
torch.save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_mode(args):
|
||||||
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||||
|
args.rank = int(os.environ["RANK"])
|
||||||
|
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||||
|
# elif "SLURM_PROCID" in os.environ:
|
||||||
|
# args.rank = int(os.environ["SLURM_PROCID"])
|
||||||
|
# args.gpu = args.rank % torch.cuda.device_count()
|
||||||
|
elif hasattr(args, "rank"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print("Not using distributed mode")
|
||||||
|
args.distributed = False
|
||||||
|
return
|
||||||
|
|
||||||
|
args.distributed = True
|
||||||
|
|
||||||
|
torch.cuda.set_device(args.gpu)
|
||||||
|
args.dist_backend = "nccl"
|
||||||
|
print(
|
||||||
|
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||||
|
)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
setup_for_distributed(args.rank == 0)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_across_processes(val):
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
|
||||||
|
return torch.tensor(val)
|
||||||
|
|
||||||
|
t = torch.tensor(val, device="cuda")
|
||||||
|
dist.barrier()
|
||||||
|
dist.all_reduce(t)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def create_lr_scheduler(optimizer,
|
||||||
|
num_step: int,
|
||||||
|
epochs: int,
|
||||||
|
warmup=True,
|
||||||
|
warmup_epochs=1,
|
||||||
|
warmup_factor=1e-3):
|
||||||
|
assert num_step > 0 and epochs > 0
|
||||||
|
if warmup is False:
|
||||||
|
warmup_epochs = 0
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
"""
|
||||||
|
根据step数返回一个学习率倍率因子,
|
||||||
|
注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法
|
||||||
|
"""
|
||||||
|
if warmup is True and x <= (warmup_epochs * num_step):
|
||||||
|
alpha = float(x) / (warmup_epochs * num_step)
|
||||||
|
# warmup过程中lr倍率因子从warmup_factor -> 1
|
||||||
|
return warmup_factor * (1 - alpha) + alpha
|
||||||
|
else:
|
||||||
|
# warmup后lr倍率因子从1 -> 0
|
||||||
|
# 参考deeplab_v2: Learning rate policy
|
||||||
|
return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
|
||||||
|
|
||||||
|
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
|
@ -11,7 +11,7 @@ from torch import nn
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from geotiff_utils import VOCYJSSegmentation
|
from geotiff_utils import VOCYJSSegmentation
|
||||||
import utils
|
import utils as utils
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
@ -21,7 +21,7 @@ def parse_args():
|
|||||||
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--data-path", default="E:/repository/DeepLearning23/datasets/WLKdata_1111/WLKdataset", help="VOCdevkit root")
|
"--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset", help="VOCdevkit root")
|
||||||
parser.add_argument("--num-classes", default=13, type=int)
|
parser.add_argument("--num-classes", default=13, type=int)
|
||||||
parser.add_argument("--device", default="cuda", help="training device")
|
parser.add_argument("--device", default="cuda", help="training device")
|
||||||
parser.add_argument("--batch-size", default=8, type=int)
|
parser.add_argument("--batch-size", default=8, type=int)
|
330
train_LS/utils.py
Normal file
330
train_LS/utils.py
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
import datetime
|
||||||
|
import errno
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
class SmoothedValue:
|
||||||
|
"""Track a series of values and provide access to smoothed values over a
|
||||||
|
window or the global series average.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size=20, fmt=None):
|
||||||
|
if fmt is None:
|
||||||
|
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||||
|
self.deque = deque(maxlen=window_size)
|
||||||
|
self.total = 0.0
|
||||||
|
self.count = 0
|
||||||
|
self.fmt = fmt
|
||||||
|
|
||||||
|
def update(self, value, n=1):
|
||||||
|
self.deque.append(value)
|
||||||
|
self.count += n
|
||||||
|
self.total += value * n
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
"""
|
||||||
|
Warning: does not synchronize the deque!
|
||||||
|
"""
|
||||||
|
t = reduce_across_processes([self.count, self.total])
|
||||||
|
t = t.tolist()
|
||||||
|
self.count = int(t[0])
|
||||||
|
self.total = t[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def median(self):
|
||||||
|
d = torch.tensor(list(self.deque))
|
||||||
|
return d.median().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avg(self):
|
||||||
|
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||||
|
return d.mean().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_avg(self):
|
||||||
|
return self.total / self.count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self):
|
||||||
|
return max(self.deque)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
return self.deque[-1]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.fmt.format(
|
||||||
|
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrix:
|
||||||
|
def __init__(self, num_classes):
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.mat = None
|
||||||
|
|
||||||
|
def update(self, a, b):
|
||||||
|
n = self.num_classes
|
||||||
|
if self.mat is None:
|
||||||
|
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||||
|
with torch.inference_mode():
|
||||||
|
k = (a >= 0) & (a < n)
|
||||||
|
inds = n * a[k].to(torch.int64) + b[k]
|
||||||
|
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.mat.zero_()
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
h = self.mat.float()
|
||||||
|
acc_global = torch.diag(h).sum() / h.sum()
|
||||||
|
acc = torch.diag(h) / h.sum(1)
|
||||||
|
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
|
||||||
|
return acc_global, acc, iu
|
||||||
|
|
||||||
|
def reduce_from_all_processes(self):
|
||||||
|
reduce_across_processes(self.mat)
|
||||||
|
|
||||||
|
def get_info(self):
|
||||||
|
acc_global, acc, iu = self.compute()
|
||||||
|
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
|
||||||
|
acc_global.item() * 100,
|
||||||
|
[f"{i:.1f}" for i in (acc * 100).tolist()],
|
||||||
|
[f"{i:.1f}" for i in (iu * 100).tolist()],
|
||||||
|
iu.mean().item() * 100,
|
||||||
|
), iu.mean().item() * 100
|
||||||
|
|
||||||
|
|
||||||
|
class MetricLogger:
|
||||||
|
def __init__(self, delimiter="\t"):
|
||||||
|
self.meters = defaultdict(SmoothedValue)
|
||||||
|
self.delimiter = delimiter
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
v = v.item()
|
||||||
|
if not isinstance(v, (float, int)):
|
||||||
|
raise TypeError(
|
||||||
|
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
|
||||||
|
)
|
||||||
|
self.meters[k].update(v)
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr in self.meters:
|
||||||
|
return self.meters[attr]
|
||||||
|
if attr in self.__dict__:
|
||||||
|
return self.__dict__[attr]
|
||||||
|
raise AttributeError(
|
||||||
|
f"'{type(self).__name__}' object has no attribute '{attr}'")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
loss_str = []
|
||||||
|
for name, meter in self.meters.items():
|
||||||
|
loss_str.append(f"{name}: {str(meter)}")
|
||||||
|
return self.delimiter.join(loss_str)
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
for meter in self.meters.values():
|
||||||
|
meter.synchronize_between_processes()
|
||||||
|
|
||||||
|
def add_meter(self, name, meter):
|
||||||
|
self.meters[name] = meter
|
||||||
|
|
||||||
|
def log_every(self, iterable, print_freq, header=None):
|
||||||
|
i = 0
|
||||||
|
if not header:
|
||||||
|
header = ""
|
||||||
|
start_time = time.time()
|
||||||
|
end = time.time()
|
||||||
|
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[
|
||||||
|
header,
|
||||||
|
"[{0" + space_fmt + "}/{1}]",
|
||||||
|
"eta: {eta}",
|
||||||
|
"{meters}",
|
||||||
|
"time: {time}",
|
||||||
|
"data: {data}",
|
||||||
|
"max mem: {memory:.0f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
|
||||||
|
"{meters}", "time: {time}", "data: {data}"]
|
||||||
|
)
|
||||||
|
MB = 1024.0 * 1024.0
|
||||||
|
for obj in iterable:
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
yield obj
|
||||||
|
iter_time.update(time.time() - end)
|
||||||
|
if i % print_freq == 0:
|
||||||
|
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||||
|
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i,
|
||||||
|
len(iterable),
|
||||||
|
eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time),
|
||||||
|
data=str(data_time),
|
||||||
|
memory=torch.cuda.max_memory_allocated() / MB,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
i += 1
|
||||||
|
end = time.time()
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
print(f"{header} Total time: {total_time_str}")
|
||||||
|
|
||||||
|
|
||||||
|
def cat_list(images, fill_value=0):
|
||||||
|
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||||
|
batch_shape = (len(images),) + max_size
|
||||||
|
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
|
||||||
|
for img, pad_img in zip(images, batched_imgs):
|
||||||
|
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
|
||||||
|
return batched_imgs
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
images, targets = list(zip(*batch))
|
||||||
|
batched_imgs = cat_list(images, fill_value=0)
|
||||||
|
batched_targets = cat_list(targets, fill_value=255)
|
||||||
|
return batched_imgs, batched_targets
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir(path):
|
||||||
|
try:
|
||||||
|
os.makedirs(path)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno != errno.EEXIST:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def setup_for_distributed(is_master):
|
||||||
|
"""
|
||||||
|
This function disables printing when not in master process
|
||||||
|
"""
|
||||||
|
import builtins as __builtin__
|
||||||
|
|
||||||
|
builtin_print = __builtin__.print
|
||||||
|
|
||||||
|
def print(*args, **kwargs):
|
||||||
|
force = kwargs.pop("force", False)
|
||||||
|
if is_master or force:
|
||||||
|
builtin_print(*args, **kwargs)
|
||||||
|
|
||||||
|
__builtin__.print = print
|
||||||
|
|
||||||
|
|
||||||
|
def is_dist_avail_and_initialized():
|
||||||
|
if not dist.is_available():
|
||||||
|
return False
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def save_on_master(*args, **kwargs):
|
||||||
|
if is_main_process():
|
||||||
|
torch.save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_mode(args):
|
||||||
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||||
|
args.rank = int(os.environ["RANK"])
|
||||||
|
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||||
|
# elif "SLURM_PROCID" in os.environ:
|
||||||
|
# args.rank = int(os.environ["SLURM_PROCID"])
|
||||||
|
# args.gpu = args.rank % torch.cuda.device_count()
|
||||||
|
elif hasattr(args, "rank"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print("Not using distributed mode")
|
||||||
|
args.distributed = False
|
||||||
|
return
|
||||||
|
|
||||||
|
args.distributed = True
|
||||||
|
|
||||||
|
torch.cuda.set_device(args.gpu)
|
||||||
|
args.dist_backend = "nccl"
|
||||||
|
print(
|
||||||
|
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||||
|
)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
setup_for_distributed(args.rank == 0)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_across_processes(val):
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
|
||||||
|
return torch.tensor(val)
|
||||||
|
|
||||||
|
t = torch.tensor(val, device="cuda")
|
||||||
|
dist.barrier()
|
||||||
|
dist.all_reduce(t)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def create_lr_scheduler(optimizer,
|
||||||
|
num_step: int,
|
||||||
|
epochs: int,
|
||||||
|
warmup=True,
|
||||||
|
warmup_epochs=1,
|
||||||
|
warmup_factor=1e-3):
|
||||||
|
assert num_step > 0 and epochs > 0
|
||||||
|
if warmup is False:
|
||||||
|
warmup_epochs = 0
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
"""
|
||||||
|
根据step数返回一个学习率倍率因子,
|
||||||
|
注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法
|
||||||
|
"""
|
||||||
|
if warmup is True and x <= (warmup_epochs * num_step):
|
||||||
|
alpha = float(x) / (warmup_epochs * num_step)
|
||||||
|
# warmup过程中lr倍率因子从warmup_factor -> 1
|
||||||
|
return warmup_factor * (1 - alpha) + alpha
|
||||||
|
else:
|
||||||
|
# warmup后lr倍率因子从1 -> 0
|
||||||
|
# 参考deeplab_v2: Learning rate policy
|
||||||
|
return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
|
||||||
|
|
||||||
|
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
|
270
train_LS_jpg/geotiff_utils.py
Normal file
270
train_LS_jpg/geotiff_utils.py
Normal file
@ -0,0 +1,270 @@
|
|||||||
|
|
||||||
|
"""Pascal VOC Semantic Segmentation Dataset."""
|
||||||
|
from PIL import Image, ImageOps, ImageFilter
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
# import gdal
|
||||||
|
from osgeo import gdal
|
||||||
|
import random
|
||||||
|
import torch.utils.data as data
|
||||||
|
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationDataset(object):
|
||||||
|
"""Segmentation Base Dataset"""
|
||||||
|
|
||||||
|
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
|
||||||
|
super(SegmentationDataset, self).__init__()
|
||||||
|
self.root = root
|
||||||
|
self.transform = transform
|
||||||
|
self.split = split
|
||||||
|
self.mode = mode if mode is not None else split
|
||||||
|
self.base_size = base_size
|
||||||
|
self.crop_size = crop_size
|
||||||
|
|
||||||
|
def _val_sync_transform(self, img, mask):
|
||||||
|
outsize = self.crop_size
|
||||||
|
short_size = outsize
|
||||||
|
w, h = img.size
|
||||||
|
if w > h:
|
||||||
|
oh = short_size
|
||||||
|
ow = int(1.0 * w * oh / h)
|
||||||
|
else:
|
||||||
|
ow = short_size
|
||||||
|
oh = int(1.0 * h * ow / w)
|
||||||
|
img = img.resize((ow, oh), Image.BILINEAR)
|
||||||
|
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||||
|
# center crop
|
||||||
|
w, h = img.size
|
||||||
|
x1 = int(round((w - outsize) / 2.))
|
||||||
|
y1 = int(round((h - outsize) / 2.))
|
||||||
|
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||||
|
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
if random.random() < 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
crop_size = self.crop_size
|
||||||
|
# random scale (short edge)
|
||||||
|
short_size = random.randint(
|
||||||
|
int(self.base_size * 0.5), int(self.base_size * 2.0))
|
||||||
|
w, h = img.size
|
||||||
|
if h > w:
|
||||||
|
ow = short_size
|
||||||
|
oh = int(1.0 * h * ow / w)
|
||||||
|
else:
|
||||||
|
oh = short_size
|
||||||
|
ow = int(1.0 * w * oh / h)
|
||||||
|
img = img.resize((ow, oh), Image.BILINEAR)
|
||||||
|
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||||
|
# pad crop
|
||||||
|
if short_size < crop_size:
|
||||||
|
padh = crop_size - oh if oh < crop_size else 0
|
||||||
|
padw = crop_size - ow if ow < crop_size else 0
|
||||||
|
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
|
||||||
|
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
|
||||||
|
# random crop crop_size
|
||||||
|
w, h = img.size
|
||||||
|
x1 = random.randint(0, w - crop_size)
|
||||||
|
y1 = random.randint(0, h - crop_size)
|
||||||
|
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||||
|
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||||
|
# gaussian blur as in PSP
|
||||||
|
if random.random() < 0.5:
|
||||||
|
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform_tif(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform_tif_geofeat(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _val_sync_transform_tif(self, img, mask):
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _img_transform(self, img):
|
||||||
|
return np.array(img)
|
||||||
|
|
||||||
|
# def _mask_transform(self, mask):
|
||||||
|
# return np.array(mask).astype('int32')
|
||||||
|
|
||||||
|
def _mask_transform(self, mask):
|
||||||
|
target = np.array(mask).astype('int32')
|
||||||
|
# target = target[np.newaxis, :]
|
||||||
|
target[target > 12] = 255
|
||||||
|
return torch.from_numpy(target).long()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_class(self):
|
||||||
|
"""Number of categories."""
|
||||||
|
return self.NUM_CLASS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pred_offset(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
class VOCYJSSegmentation(SegmentationDataset):
|
||||||
|
"""Pascal VOC Semantic Segmentation Dataset.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
root : string
|
||||||
|
Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
|
||||||
|
split: string
|
||||||
|
'train', 'val' or 'test'
|
||||||
|
transform : callable, optional
|
||||||
|
A function that transforms the image
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> from torchvision import transforms
|
||||||
|
>>> import torch.utils.data as data
|
||||||
|
>>> # Transforms for Normalization
|
||||||
|
>>> input_transform = transforms.Compose([
|
||||||
|
>>> transforms.ToTensor(),
|
||||||
|
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
|
||||||
|
>>> ])
|
||||||
|
>>> # Create Dataset
|
||||||
|
>>> trainset = VOCSegmentation(split='train', transform=input_transform)
|
||||||
|
>>> # Create Training Loader
|
||||||
|
>>> train_data = data.DataLoader(
|
||||||
|
>>> trainset, 4, shuffle=True,
|
||||||
|
>>> num_workers=4)
|
||||||
|
"""
|
||||||
|
NUM_CLASS = 13
|
||||||
|
|
||||||
|
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
|
||||||
|
super(VOCYJSSegmentation, self).__init__(
|
||||||
|
root, split, mode, transform, **kwargs)
|
||||||
|
_voc_root = root
|
||||||
|
txt_path = os.path.join(root, split+'.txt')
|
||||||
|
self._mask_LS_dir = os.path.join(_voc_root, "masks_LS_png")
|
||||||
|
self._image_LS_dir = os.path.join(_voc_root, "images_LS_jpg")
|
||||||
|
self.image_list = read_text(txt_path)
|
||||||
|
random.shuffle(self.image_list)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
img_name = self.image_list[index].split('.')[0]+'.jpg'
|
||||||
|
mask_name = self.image_list[index].split('.')[0]+'.png'
|
||||||
|
img_LS = np.array(Image.open(os.path.join(
|
||||||
|
self._image_LS_dir, img_name))).astype(np.float32)
|
||||||
|
mask = np.array(Image.open(os.path.join(
|
||||||
|
self._mask_LS_dir, mask_name))).astype(np.int32)
|
||||||
|
mask = torch.from_numpy(mask).long()
|
||||||
|
|
||||||
|
# general resize, normalize and toTensor
|
||||||
|
if self.transform is not None:
|
||||||
|
img_LS = self.transform(img_LS)
|
||||||
|
return img_LS, mask
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_list)
|
||||||
|
|
||||||
|
def _mask_transform(self, mask):
|
||||||
|
target = np.array(mask).astype('int32')
|
||||||
|
# target = target[np.newaxis, :]
|
||||||
|
target[target > 12] = 255
|
||||||
|
return torch.from_numpy(target).long()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classes(self):
|
||||||
|
"""Category names."""
|
||||||
|
return ('0', '1', '2', '3', '4', '5', '6', '7' '8', '9', '10', '11', '12')
|
||||||
|
|
||||||
|
|
||||||
|
def generator_list_of_imagepath(path):
|
||||||
|
image_list = []
|
||||||
|
for image in os.listdir(path):
|
||||||
|
# print(path)
|
||||||
|
# print(image)
|
||||||
|
if not image == '.DS_Store' and 'tif' == image.split('.')[-1]:
|
||||||
|
image_list.append(image)
|
||||||
|
return image_list
|
||||||
|
|
||||||
|
|
||||||
|
def read_text(textfile):
|
||||||
|
list = []
|
||||||
|
with open(textfile, "r") as lines:
|
||||||
|
for line in lines:
|
||||||
|
list.append(line.rstrip('\n'))
|
||||||
|
return list
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_segmentation(textpath, imagepath, train_percent):
|
||||||
|
image_list = generator_list_of_imagepath(imagepath)
|
||||||
|
num = len(image_list)
|
||||||
|
list = range(num)
|
||||||
|
train_num = int(num * train_percent) # training set num
|
||||||
|
train_list = random.sample(list, train_num)
|
||||||
|
print("train set size", train_num)
|
||||||
|
ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
|
||||||
|
fval = open(os.path.join(textpath, 'val.txt'), 'w')
|
||||||
|
for i in list:
|
||||||
|
name = image_list[i] + '\n'
|
||||||
|
if i in train_list:
|
||||||
|
ftrain.write(name)
|
||||||
|
else:
|
||||||
|
fval.write(name)
|
||||||
|
ftrain.close()
|
||||||
|
fval.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
|
||||||
|
# list=generator_list_of_imagepath(path)
|
||||||
|
# print(list)
|
||||||
|
# 切割数据集
|
||||||
|
|
||||||
|
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
|
||||||
|
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
|
||||||
|
train_percent = 0.8
|
||||||
|
dataset_segmentation(textpath, imagepath, train_percent)
|
||||||
|
# 显示各种图片
|
||||||
|
|
||||||
|
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
|
||||||
|
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
|
||||||
|
# cv2.imshow('img', img)
|
||||||
|
# img = Image.fromarray (img,'RGB')
|
||||||
|
# img.show()
|
||||||
|
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
|
||||||
|
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
|
||||||
|
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
|
||||||
|
# img2 = Image.fromarray (img2,'RGB')
|
||||||
|
# img2.show()
|
||||||
|
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
|
||||||
|
# img3 = gdal.Open(img3).ReadAsArray()
|
||||||
|
# img3 = Image.fromarray (img3)
|
||||||
|
# img3.show()
|
||||||
|
|
||||||
|
# dataset和dataloader的测试
|
||||||
|
|
||||||
|
# 测试dataloader能不能用
|
||||||
|
'''
|
||||||
|
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
|
||||||
|
input_transform = transforms.Compose(
|
||||||
|
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
|
||||||
|
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
|
||||||
|
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
|
||||||
|
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
|
||||||
|
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
|
||||||
|
'''
|
152
train_LS_jpg/train.py
Normal file
152
train_LS_jpg/train.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
|
||||||
|
from torchvision import transforms
|
||||||
|
import torch.utils.data as data
|
||||||
|
from torch import nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from geotiff_utils import VOCYJSSegmentation
|
||||||
|
import utils
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset", help="VOCdevkit root")
|
||||||
|
parser.add_argument("--num-classes", default=13, type=int)
|
||||||
|
parser.add_argument("--device", default="cuda", help="training device")
|
||||||
|
parser.add_argument("--batch-size", default=8, type=int)
|
||||||
|
parser.add_argument("--epochs", default=50, type=int, metavar="N",
|
||||||
|
help="number of total epochs to train")
|
||||||
|
parser.add_argument('--lr', default=0.005, type=float,
|
||||||
|
help='initial learning rate')
|
||||||
|
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||||
|
help='momentum')
|
||||||
|
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
||||||
|
metavar='W', help='weight decay (default: 1e-4)',
|
||||||
|
dest='weight_decay')
|
||||||
|
parser.add_argument('-out-dir', type=str,
|
||||||
|
default='DeeplabV3_LS_3')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
class DeeplabV3_LS_3(nn.Module):
|
||||||
|
def __init__(self, n_class):
|
||||||
|
super(DeeplabV3_LS_3, self).__init__()
|
||||||
|
self.n_class = n_class
|
||||||
|
self.conv_fc = nn.Conv2d(
|
||||||
|
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
|
||||||
|
|
||||||
|
self.seg = deeplabv3_resnet50(weights='DEFAULT')
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.seg(x)["out"]
|
||||||
|
x = self.conv_fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
input_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([.485, .456, .406],
|
||||||
|
[.229, .224, .225]),
|
||||||
|
])
|
||||||
|
data_kwargs = {'transform': input_transform,
|
||||||
|
'base_size': 224, 'crop_size': 224}
|
||||||
|
|
||||||
|
# 读取geotiff数据,构建训练集、验证集
|
||||||
|
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
|
||||||
|
**data_kwargs)
|
||||||
|
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
|
||||||
|
**data_kwargs)
|
||||||
|
num_workers = min(
|
||||||
|
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
|
||||||
|
train_loader = data.DataLoader(
|
||||||
|
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||||
|
val_loader = data.DataLoader(
|
||||||
|
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||||
|
|
||||||
|
model = DeeplabV3_LS_3(n_class=args.num_classes)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
|
||||||
|
optimizer = torch.optim.SGD(
|
||||||
|
model.parameters(),
|
||||||
|
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler = utils.create_lr_scheduler(
|
||||||
|
optimizer, len(train_loader), args.epochs, warmup=True)
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
date_time = now.strftime("%Y-%m-%d__%H-%M__")
|
||||||
|
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
|
||||||
|
if not out_dir.exists():
|
||||||
|
out_dir.mkdir()
|
||||||
|
f = open(os.path.join(out_dir, "log.txt"), 'w')
|
||||||
|
start_time = time.time()
|
||||||
|
best_acc = 0
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
print(f"Epoch {epoch+1}\n-------------------------------")
|
||||||
|
model.train()
|
||||||
|
for idx, (image, target) in enumerate(train_loader):
|
||||||
|
image = image.to(device)
|
||||||
|
target = target.to(device)
|
||||||
|
output = model(image)
|
||||||
|
|
||||||
|
loss = criterion(output, target)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
if idx % 100 == 0:
|
||||||
|
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
|
||||||
|
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
confmat = utils.ConfusionMatrix(args.num_classes)
|
||||||
|
with torch.no_grad():
|
||||||
|
for image, target in val_loader:
|
||||||
|
image, target = image.to(device), target.to(device)
|
||||||
|
output = model(image)
|
||||||
|
|
||||||
|
confmat.update(target.flatten(), output.argmax(1).flatten())
|
||||||
|
|
||||||
|
info, mIoU = confmat.get_info()
|
||||||
|
print(info)
|
||||||
|
|
||||||
|
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
|
||||||
|
f.write(info+"\n\n")
|
||||||
|
f.flush()
|
||||||
|
|
||||||
|
# # 保存准确率最好的模型
|
||||||
|
# if mIoU > best_acc:
|
||||||
|
# print("[Save model]")
|
||||||
|
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
|
||||||
|
# best_acc = mIoU
|
||||||
|
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
print("total time:", total_time)
|
||||||
|
torch.save(model, os.path.join(out_dir, "last.pth"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
main(args)
|
330
train_LS_jpg/utils.py
Normal file
330
train_LS_jpg/utils.py
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
import datetime
|
||||||
|
import errno
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
class SmoothedValue:
|
||||||
|
"""Track a series of values and provide access to smoothed values over a
|
||||||
|
window or the global series average.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size=20, fmt=None):
|
||||||
|
if fmt is None:
|
||||||
|
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||||
|
self.deque = deque(maxlen=window_size)
|
||||||
|
self.total = 0.0
|
||||||
|
self.count = 0
|
||||||
|
self.fmt = fmt
|
||||||
|
|
||||||
|
def update(self, value, n=1):
|
||||||
|
self.deque.append(value)
|
||||||
|
self.count += n
|
||||||
|
self.total += value * n
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
"""
|
||||||
|
Warning: does not synchronize the deque!
|
||||||
|
"""
|
||||||
|
t = reduce_across_processes([self.count, self.total])
|
||||||
|
t = t.tolist()
|
||||||
|
self.count = int(t[0])
|
||||||
|
self.total = t[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def median(self):
|
||||||
|
d = torch.tensor(list(self.deque))
|
||||||
|
return d.median().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avg(self):
|
||||||
|
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||||
|
return d.mean().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_avg(self):
|
||||||
|
return self.total / self.count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self):
|
||||||
|
return max(self.deque)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
return self.deque[-1]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.fmt.format(
|
||||||
|
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrix:
|
||||||
|
def __init__(self, num_classes):
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.mat = None
|
||||||
|
|
||||||
|
def update(self, a, b):
|
||||||
|
n = self.num_classes
|
||||||
|
if self.mat is None:
|
||||||
|
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||||
|
with torch.inference_mode():
|
||||||
|
k = (a >= 0) & (a < n)
|
||||||
|
inds = n * a[k].to(torch.int64) + b[k]
|
||||||
|
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.mat.zero_()
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
h = self.mat.float()
|
||||||
|
acc_global = torch.diag(h).sum() / h.sum()
|
||||||
|
acc = torch.diag(h) / h.sum(1)
|
||||||
|
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
|
||||||
|
return acc_global, acc, iu
|
||||||
|
|
||||||
|
def reduce_from_all_processes(self):
|
||||||
|
reduce_across_processes(self.mat)
|
||||||
|
|
||||||
|
def get_info(self):
|
||||||
|
acc_global, acc, iu = self.compute()
|
||||||
|
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
|
||||||
|
acc_global.item() * 100,
|
||||||
|
[f"{i:.1f}" for i in (acc * 100).tolist()],
|
||||||
|
[f"{i:.1f}" for i in (iu * 100).tolist()],
|
||||||
|
iu.mean().item() * 100,
|
||||||
|
), iu.mean().item() * 100
|
||||||
|
|
||||||
|
|
||||||
|
class MetricLogger:
|
||||||
|
def __init__(self, delimiter="\t"):
|
||||||
|
self.meters = defaultdict(SmoothedValue)
|
||||||
|
self.delimiter = delimiter
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
v = v.item()
|
||||||
|
if not isinstance(v, (float, int)):
|
||||||
|
raise TypeError(
|
||||||
|
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
|
||||||
|
)
|
||||||
|
self.meters[k].update(v)
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr in self.meters:
|
||||||
|
return self.meters[attr]
|
||||||
|
if attr in self.__dict__:
|
||||||
|
return self.__dict__[attr]
|
||||||
|
raise AttributeError(
|
||||||
|
f"'{type(self).__name__}' object has no attribute '{attr}'")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
loss_str = []
|
||||||
|
for name, meter in self.meters.items():
|
||||||
|
loss_str.append(f"{name}: {str(meter)}")
|
||||||
|
return self.delimiter.join(loss_str)
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
for meter in self.meters.values():
|
||||||
|
meter.synchronize_between_processes()
|
||||||
|
|
||||||
|
def add_meter(self, name, meter):
|
||||||
|
self.meters[name] = meter
|
||||||
|
|
||||||
|
def log_every(self, iterable, print_freq, header=None):
|
||||||
|
i = 0
|
||||||
|
if not header:
|
||||||
|
header = ""
|
||||||
|
start_time = time.time()
|
||||||
|
end = time.time()
|
||||||
|
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[
|
||||||
|
header,
|
||||||
|
"[{0" + space_fmt + "}/{1}]",
|
||||||
|
"eta: {eta}",
|
||||||
|
"{meters}",
|
||||||
|
"time: {time}",
|
||||||
|
"data: {data}",
|
||||||
|
"max mem: {memory:.0f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
|
||||||
|
"{meters}", "time: {time}", "data: {data}"]
|
||||||
|
)
|
||||||
|
MB = 1024.0 * 1024.0
|
||||||
|
for obj in iterable:
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
yield obj
|
||||||
|
iter_time.update(time.time() - end)
|
||||||
|
if i % print_freq == 0:
|
||||||
|
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||||
|
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i,
|
||||||
|
len(iterable),
|
||||||
|
eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time),
|
||||||
|
data=str(data_time),
|
||||||
|
memory=torch.cuda.max_memory_allocated() / MB,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
i += 1
|
||||||
|
end = time.time()
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
print(f"{header} Total time: {total_time_str}")
|
||||||
|
|
||||||
|
|
||||||
|
def cat_list(images, fill_value=0):
|
||||||
|
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||||
|
batch_shape = (len(images),) + max_size
|
||||||
|
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
|
||||||
|
for img, pad_img in zip(images, batched_imgs):
|
||||||
|
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
|
||||||
|
return batched_imgs
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
images, targets = list(zip(*batch))
|
||||||
|
batched_imgs = cat_list(images, fill_value=0)
|
||||||
|
batched_targets = cat_list(targets, fill_value=255)
|
||||||
|
return batched_imgs, batched_targets
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir(path):
|
||||||
|
try:
|
||||||
|
os.makedirs(path)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno != errno.EEXIST:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def setup_for_distributed(is_master):
|
||||||
|
"""
|
||||||
|
This function disables printing when not in master process
|
||||||
|
"""
|
||||||
|
import builtins as __builtin__
|
||||||
|
|
||||||
|
builtin_print = __builtin__.print
|
||||||
|
|
||||||
|
def print(*args, **kwargs):
|
||||||
|
force = kwargs.pop("force", False)
|
||||||
|
if is_master or force:
|
||||||
|
builtin_print(*args, **kwargs)
|
||||||
|
|
||||||
|
__builtin__.print = print
|
||||||
|
|
||||||
|
|
||||||
|
def is_dist_avail_and_initialized():
|
||||||
|
if not dist.is_available():
|
||||||
|
return False
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def save_on_master(*args, **kwargs):
|
||||||
|
if is_main_process():
|
||||||
|
torch.save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_mode(args):
|
||||||
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||||
|
args.rank = int(os.environ["RANK"])
|
||||||
|
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||||
|
# elif "SLURM_PROCID" in os.environ:
|
||||||
|
# args.rank = int(os.environ["SLURM_PROCID"])
|
||||||
|
# args.gpu = args.rank % torch.cuda.device_count()
|
||||||
|
elif hasattr(args, "rank"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print("Not using distributed mode")
|
||||||
|
args.distributed = False
|
||||||
|
return
|
||||||
|
|
||||||
|
args.distributed = True
|
||||||
|
|
||||||
|
torch.cuda.set_device(args.gpu)
|
||||||
|
args.dist_backend = "nccl"
|
||||||
|
print(
|
||||||
|
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||||
|
)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
setup_for_distributed(args.rank == 0)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_across_processes(val):
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
|
||||||
|
return torch.tensor(val)
|
||||||
|
|
||||||
|
t = torch.tensor(val, device="cuda")
|
||||||
|
dist.barrier()
|
||||||
|
dist.all_reduce(t)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def create_lr_scheduler(optimizer,
|
||||||
|
num_step: int,
|
||||||
|
epochs: int,
|
||||||
|
warmup=True,
|
||||||
|
warmup_epochs=1,
|
||||||
|
warmup_factor=1e-3):
|
||||||
|
assert num_step > 0 and epochs > 0
|
||||||
|
if warmup is False:
|
||||||
|
warmup_epochs = 0
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
"""
|
||||||
|
根据step数返回一个学习率倍率因子,
|
||||||
|
注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法
|
||||||
|
"""
|
||||||
|
if warmup is True and x <= (warmup_epochs * num_step):
|
||||||
|
alpha = float(x) / (warmup_epochs * num_step)
|
||||||
|
# warmup过程中lr倍率因子从warmup_factor -> 1
|
||||||
|
return warmup_factor * (1 - alpha) + alpha
|
||||||
|
else:
|
||||||
|
# warmup后lr倍率因子从1 -> 0
|
||||||
|
# 参考deeplab_v2: Learning rate policy
|
||||||
|
return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
|
||||||
|
|
||||||
|
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
|
Loading…
Reference in New Issue
Block a user