加入jpg图像训练代码
This commit is contained in:
parent
2a40d11a2a
commit
704c24d79d
1
.gitignore
vendored
1
.gitignore
vendored
@ -160,3 +160,4 @@ cython_debug/
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
/train_output
|
138
data_preprocessing/cut_smalltif.bak
Normal file
138
data_preprocessing/cut_smalltif.bak
Normal file
@ -0,0 +1,138 @@
|
||||
import numpy as np
|
||||
from osgeo import gdal
|
||||
|
||||
|
||||
def read_tif(fileName):
|
||||
dataset = gdal.Open(fileName)
|
||||
|
||||
im_width = dataset.RasterXSize # 栅格矩阵的列数
|
||||
im_height = dataset.RasterYSize # 栅格矩阵的行数
|
||||
im_bands = dataset.RasterCount # 波段数
|
||||
im_data = dataset.ReadAsArray().astype(np.float32) # 获取数据
|
||||
if len(im_data.shape) == 2:
|
||||
im_data = im_data[np.newaxis, :]
|
||||
im_geotrans = dataset.GetGeoTransform() # 获取仿射矩阵信息
|
||||
im_proj = dataset.GetProjection() # 获取投影信息
|
||||
|
||||
return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
|
||||
|
||||
|
||||
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
|
||||
if 'int8' in im_data.dtype.name:
|
||||
datatype = gdal.GDT_Byte
|
||||
elif 'int16' in im_data.dtype.name:
|
||||
datatype = gdal.GDT_UInt16
|
||||
else:
|
||||
datatype = gdal.GDT_Float32
|
||||
|
||||
if len(im_data.shape) == 3:
|
||||
im_bands, im_height, im_width = im_data.shape
|
||||
else:
|
||||
im_bands, (im_height, im_width) = 1, im_data.shape
|
||||
# 创建文件
|
||||
driver = gdal.GetDriverByName("GTiff")
|
||||
dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
|
||||
if dataset != None and im_geotrans != None and im_proj != None:
|
||||
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
|
||||
dataset.SetProjection(im_proj) # 写入投影
|
||||
for i in range(im_bands):
|
||||
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
|
||||
del dataset
|
||||
|
||||
|
||||
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
|
||||
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
|
||||
fileName)
|
||||
|
||||
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
|
||||
im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
|
||||
fileName2)
|
||||
|
||||
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
|
||||
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
|
||||
mask_fileName)
|
||||
mask_im_data = np.int8(mask_im_data)
|
||||
|
||||
|
||||
# geotiff归一化
|
||||
for i in range(im_bands):
|
||||
arr = im_data[i, :, :]
|
||||
Min = arr.min()
|
||||
Max = arr.max()
|
||||
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||
im_data[i] = normalized_arr
|
||||
|
||||
for i in range(im_bands2):
|
||||
arr = im_data2[i, :, :]
|
||||
Min = arr.min()
|
||||
Max = arr.max()
|
||||
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||
im_data2[i] = normalized_arr
|
||||
|
||||
|
||||
# 计算大图每个波段的均值和方差,train.py里transform会用到
|
||||
im_data = im_data/255
|
||||
for i in range(im_bands):
|
||||
pixels = im_data[i, :, :].ravel()
|
||||
print("波段{} mean: {:.4f}, std: {:.4f}".format(
|
||||
i, np.mean(pixels), np.std(pixels)))
|
||||
im_data = im_data*255
|
||||
|
||||
im_data2 = im_data2/255
|
||||
for i in range(im_bands2):
|
||||
pixels = im_data2[i, :, :].ravel()
|
||||
print("波段{} mean: {:.4f}, std: {:.4f}".format(
|
||||
i, np.mean(pixels), np.std(pixels)))
|
||||
im_data2 = im_data2*255
|
||||
|
||||
|
||||
# 切成小图
|
||||
a = 0
|
||||
size = 224
|
||||
for i in range(0, int(mask_im_height / size)):
|
||||
for j in range(0, int(mask_im_width / size)):
|
||||
im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size]
|
||||
im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size]
|
||||
mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size]
|
||||
|
||||
# 以mask为判断基准,同时处理geotiff和mask
|
||||
labelfla = np.array(mask_cut).flatten()
|
||||
if np.all(labelfla == 15): # 15为NoData
|
||||
print("Skip!!!")
|
||||
else:
|
||||
# 5m
|
||||
left_h = i * size * im_geotrans[5] + im_geotrans[3]
|
||||
left_w = j * size * im_geotrans[1] + im_geotrans[0]
|
||||
new_geotrans = np.array(im_geotrans)
|
||||
new_geotrans[0] = left_w
|
||||
new_geotrans[3] = left_h
|
||||
out_geotrans = tuple(new_geotrans)
|
||||
|
||||
im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif'
|
||||
write_tif(im_cut, size, size, im_out, out_geotrans, im_proj)
|
||||
|
||||
# dan
|
||||
left_h = i * size * im_geotrans2[5] + im_geotrans2[3]
|
||||
left_w = j * size * im_geotrans2[1] + im_geotrans2[0]
|
||||
new_geotrans = np.array(im_geotrans2)
|
||||
new_geotrans[0] = left_w
|
||||
new_geotrans[3] = left_h
|
||||
out_geotrans = tuple(new_geotrans)
|
||||
|
||||
im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif'
|
||||
write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2)
|
||||
|
||||
# 存mask
|
||||
mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3]
|
||||
mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0]
|
||||
mask_new_geotrans = np.array(mask_im_geotrans)
|
||||
mask_new_geotrans[0] = mask_left_w
|
||||
mask_new_geotrans[3] = mask_left_h
|
||||
mask_out_geotrans = tuple(mask_new_geotrans)
|
||||
mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif'
|
||||
write_tif(mask_cut, size, size, mask_out,
|
||||
mask_out_geotrans, mask_im_proj)
|
||||
|
||||
print(mask_out + 'Cut to complete')
|
||||
|
||||
a = a+1
|
@ -44,9 +44,6 @@ fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
|
||||
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
|
||||
fileName)
|
||||
|
||||
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
|
||||
im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
|
||||
fileName2)
|
||||
|
||||
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
|
||||
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
|
||||
@ -55,19 +52,14 @@ mask_im_data = np.int8(mask_im_data)
|
||||
|
||||
|
||||
# geotiff归一化
|
||||
lower_percent = 2
|
||||
upper_percent = 98
|
||||
for i in range(im_bands):
|
||||
arr = im_data[i, :, :]
|
||||
Min = arr.min()
|
||||
Max = arr.max()
|
||||
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||
im_data[i] = normalized_arr
|
||||
|
||||
for i in range(im_bands2):
|
||||
arr = im_data2[i, :, :]
|
||||
Min = arr.min()
|
||||
Max = arr.max()
|
||||
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||
im_data2[i] = normalized_arr
|
||||
lower = np.percentile(arr, lower_percent)
|
||||
upper = np.percentile(arr, upper_percent)
|
||||
stretched = np.clip((arr - lower) / (upper - lower), 0, 1)
|
||||
im_data[i] = (stretched * 255).astype(np.uint8)
|
||||
|
||||
|
||||
# 计算大图每个波段的均值和方差,train.py里transform会用到
|
||||
@ -78,61 +70,42 @@ for i in range(im_bands):
|
||||
i, np.mean(pixels), np.std(pixels)))
|
||||
im_data = im_data*255
|
||||
|
||||
im_data2 = im_data2/255
|
||||
for i in range(im_bands2):
|
||||
pixels = im_data2[i, :, :].ravel()
|
||||
print("波段{} mean: {:.4f}, std: {:.4f}".format(
|
||||
i, np.mean(pixels), np.std(pixels)))
|
||||
im_data2 = im_data2*255
|
||||
|
||||
|
||||
# 切成小图
|
||||
a = 0
|
||||
size = 224
|
||||
size = 448
|
||||
for i in range(0, int(mask_im_height / size)):
|
||||
for j in range(0, int(mask_im_width / size)):
|
||||
im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size]
|
||||
im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size]
|
||||
mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size]
|
||||
mask_cut = mask_im_data[:, i * size:i *
|
||||
size + size, j * size:j * size + size]
|
||||
|
||||
# 以mask为判断基准,同时处理geotiff和mask
|
||||
labelfla = np.array(mask_cut).flatten()
|
||||
if np.all(labelfla == 15): # 15为NoData
|
||||
print("Skip!!!")
|
||||
else:
|
||||
# 5m
|
||||
left_h = i * size * im_geotrans[5] + im_geotrans[3]
|
||||
left_w = j * size * im_geotrans[1] + im_geotrans[0]
|
||||
new_geotrans = np.array(im_geotrans)
|
||||
new_geotrans[0] = left_w
|
||||
new_geotrans[3] = left_h
|
||||
out_geotrans = tuple(new_geotrans)
|
||||
# 取5、4、3波段,注意顺序
|
||||
rgb_cut = np.stack([
|
||||
im_cut[4], # 第5波段
|
||||
im_cut[3], # 第4波段
|
||||
im_cut[2], # 第3波段
|
||||
], axis=0) # shape: (3, H, W)
|
||||
# 转为 (H, W, 3)
|
||||
rgb_cut = np.transpose(rgb_cut, (1, 2, 0))
|
||||
# 归一化到0-255并转uint8
|
||||
rgb_cut = np.clip(rgb_cut, 0, 255)
|
||||
rgb_cut = rgb_cut.astype(np.uint8)
|
||||
# 保存为jpg
|
||||
from PIL import Image
|
||||
rgb_img = Image.fromarray(rgb_cut)
|
||||
rgb_img.save(
|
||||
f'E:/RSdata/wlk_right_448/dataset_5m_jpg/img_{a}.jpg')
|
||||
|
||||
im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif'
|
||||
write_tif(im_cut, size, size, im_out, out_geotrans, im_proj)
|
||||
|
||||
# dan
|
||||
left_h = i * size * im_geotrans2[5] + im_geotrans2[3]
|
||||
left_w = j * size * im_geotrans2[1] + im_geotrans2[0]
|
||||
new_geotrans = np.array(im_geotrans2)
|
||||
new_geotrans[0] = left_w
|
||||
new_geotrans[3] = left_h
|
||||
out_geotrans = tuple(new_geotrans)
|
||||
|
||||
im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif'
|
||||
write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2)
|
||||
|
||||
# 存mask
|
||||
mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3]
|
||||
mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0]
|
||||
mask_new_geotrans = np.array(mask_im_geotrans)
|
||||
mask_new_geotrans[0] = mask_left_w
|
||||
mask_new_geotrans[3] = mask_left_h
|
||||
mask_out_geotrans = tuple(mask_new_geotrans)
|
||||
mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif'
|
||||
write_tif(mask_cut, size, size, mask_out,
|
||||
mask_out_geotrans, mask_im_proj)
|
||||
|
||||
print(mask_out + 'Cut to complete')
|
||||
# mask只取第一个波段(如果是单通道),转uint8保存为png
|
||||
mask_arr = mask_cut[0] if mask_cut.shape[0] == 1 else mask_cut
|
||||
mask_arr = np.clip(mask_arr, 0, 255).astype(np.uint8)
|
||||
mask_img = Image.fromarray(mask_arr)
|
||||
mask_img.save(f'E:/RSdata/wlk_right_448/mask_png/mask_{a}.png')
|
||||
|
||||
print(f'img_{a}.jpg and mask_{a}.png saved')
|
||||
a = a+1
|
||||
|
@ -4,18 +4,18 @@ import random
|
||||
|
||||
random.seed(42)
|
||||
|
||||
geotiffs = os.listdir('E:\RSdata\wlk_right_224_2\dataset_5m')
|
||||
geotiffs = os.listdir(r'E:\RSdata\wlk_right_448\dataset_5m_jpg')
|
||||
num = len(geotiffs)
|
||||
split_rate = 0.2
|
||||
|
||||
eval_index = random.sample(geotiffs, k=int(num*split_rate))
|
||||
|
||||
f_train = open('E:\RSdata\wlk_right_224_2/train.txt', 'w')
|
||||
f_val = open('E:\RSdata\wlk_right_224_2/val.txt', 'w')
|
||||
f_train = open(r'E:\RSdata\wlk_right_448/train.txt', 'w')
|
||||
f_val = open(r'E:\RSdata\wlk_right_448/val.txt', 'w')
|
||||
|
||||
# 写入文件
|
||||
for geotiff in geotiffs:
|
||||
if geotiff in eval_index:
|
||||
f_val.write(str(geotiff)+'\n')
|
||||
f_train.write(str(geotiff)+'\n')
|
||||
else:
|
||||
f_train.write(str(geotiff)+'\n')
|
||||
f_val.write(str(geotiff)+'\n')
|
||||
|
62
datasets_pro_code/datasets_pro.py
Normal file
62
datasets_pro_code/datasets_pro.py
Normal file
@ -0,0 +1,62 @@
|
||||
import os
|
||||
import shutil
|
||||
|
||||
# 定义文件夹路径
|
||||
base_dir = r"e:\datasets\wlk_right_448"
|
||||
jpeg_images_dir = os.path.join(base_dir, "JPEGImages")
|
||||
segmentation_class_dir = os.path.join(base_dir, "SegmentationClass")
|
||||
annotations_dir = os.path.join(base_dir, "stare_stuct", "annotations")
|
||||
images_dir = os.path.join(base_dir, "stare_stuct", "images")
|
||||
|
||||
# 定义目标文件夹
|
||||
annotations_training_dir = os.path.join(annotations_dir, "training")
|
||||
annotations_validation_dir = os.path.join(annotations_dir, "validation")
|
||||
images_training_dir = os.path.join(images_dir, "training")
|
||||
images_validation_dir = os.path.join(images_dir, "validation")
|
||||
|
||||
# 创建目标文件夹
|
||||
os.makedirs(annotations_training_dir, exist_ok=True)
|
||||
os.makedirs(annotations_validation_dir, exist_ok=True)
|
||||
os.makedirs(images_training_dir, exist_ok=True)
|
||||
os.makedirs(images_validation_dir, exist_ok=True)
|
||||
|
||||
# 读取 train.txt 和 val.txt
|
||||
train_file = os.path.join(base_dir, "train.txt")
|
||||
val_file = os.path.join(base_dir, "val.txt")
|
||||
|
||||
|
||||
def read_file_list(file_path):
|
||||
with open(file_path, "r") as f:
|
||||
return [line.strip() for line in f.readlines()]
|
||||
|
||||
|
||||
train_list = read_file_list(train_file)
|
||||
val_list = read_file_list(val_file)
|
||||
|
||||
# 移动文件函数
|
||||
|
||||
|
||||
def move_files(file_list, src_images_dir, src_labels_dir, dst_images_dir, dst_labels_dir):
|
||||
for file_name in file_list:
|
||||
# 图片文件
|
||||
image_src = os.path.join(src_images_dir, file_name)
|
||||
image_dst = os.path.join(dst_images_dir, file_name)
|
||||
if os.path.exists(image_src):
|
||||
shutil.copy(image_src, image_dst)
|
||||
|
||||
# 标签文件
|
||||
label_src = os.path.join(src_labels_dir, file_name)
|
||||
label_dst = os.path.join(dst_labels_dir, file_name)
|
||||
if os.path.exists(label_src):
|
||||
shutil.copy(label_src, label_dst)
|
||||
|
||||
|
||||
# 移动训练集文件
|
||||
move_files(train_list, jpeg_images_dir, segmentation_class_dir,
|
||||
images_training_dir, annotations_training_dir)
|
||||
|
||||
# 移动验证集文件
|
||||
move_files(val_list, jpeg_images_dir, segmentation_class_dir,
|
||||
images_validation_dir, annotations_validation_dir)
|
||||
|
||||
print("文件组织完成!")
|
22
datasets_pro_code/read_mask.py
Normal file
22
datasets_pro_code/read_mask.py
Normal file
@ -0,0 +1,22 @@
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import os
|
||||
from osgeo import gdal
|
||||
|
||||
|
||||
mask_dir = r"E:\datasets\wlk_right_448\mask" # 修改为你的mask文件夹路径
|
||||
|
||||
all_labels = set()
|
||||
|
||||
for file in os.listdir(mask_dir):
|
||||
if file.lower().endswith('.tif'):
|
||||
tif_path = os.path.join(mask_dir, file)
|
||||
dataset = gdal.Open(tif_path)
|
||||
if dataset is None:
|
||||
print(f"无法打开: {tif_path}")
|
||||
continue
|
||||
band = dataset.ReadAsArray()
|
||||
unique = np.unique(band)
|
||||
all_labels.update(unique)
|
||||
|
||||
print("所有mask中出现过的标签数字:", sorted(all_labels))
|
6
datasets_pro_code/remove_tif.py
Normal file
6
datasets_pro_code/remove_tif.py
Normal file
@ -0,0 +1,6 @@
|
||||
input_path = r"E:\datasets\WLKdata_1111\WLK_voc\ImageSets\Segmentation\val.txt"
|
||||
output_path = r"E:\datasets\WLKdata_1111\WLK_voc\ImageSets\Segmentation\val_no.txt"
|
||||
|
||||
with open(input_path, "r", encoding="utf-8") as fin, open(output_path, "w", encoding="utf-8") as fout:
|
||||
for line in fin:
|
||||
fout.write(line.strip().replace(".tif", "") + "\n")
|
30
datasets_pro_code/tif_mask_to_png.py
Normal file
30
datasets_pro_code/tif_mask_to_png.py
Normal file
@ -0,0 +1,30 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from osgeo import gdal
|
||||
|
||||
src_dir = r"E:\datasets\WLKdata_1111\WLKdataset\masks_LS"
|
||||
dst_dir = r"E:\datasets\WLKdata_1111\WLKdataset\masks_LS_png"
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
|
||||
for file in os.listdir(src_dir):
|
||||
if file.lower().endswith('.tif'):
|
||||
tif_path = os.path.join(src_dir, file)
|
||||
dataset = gdal.Open(tif_path)
|
||||
if dataset is None:
|
||||
print(f"无法打开: {tif_path}")
|
||||
continue
|
||||
mask = dataset.ReadAsArray()
|
||||
if mask.ndim != 2:
|
||||
print(f"{file} 不是单波段,跳过")
|
||||
continue
|
||||
|
||||
# 替换像素值
|
||||
mask = mask.copy()
|
||||
mask[mask == 15] = 255
|
||||
|
||||
png_path = os.path.join(dst_dir, os.path.splitext(file)[0] + ".png")
|
||||
Image.fromarray(mask.astype(np.uint8)).save(png_path)
|
||||
print(f"已保存: {png_path}")
|
||||
|
||||
print("全部转换完成!")
|
33
datasets_pro_code/tif_to_jpg.py
Normal file
33
datasets_pro_code/tif_to_jpg.py
Normal file
@ -0,0 +1,33 @@
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from osgeo import gdal
|
||||
|
||||
# 输入和输出文件夹
|
||||
src_dir = r"E:\datasets\WLKdata_1111\WLKdataset\images_LS"
|
||||
dst_dir = r"E:\datasets\WLKdata_1111\WLKdataset\images_LS_jpg"
|
||||
os.makedirs(dst_dir, exist_ok=True)
|
||||
|
||||
for file in os.listdir(src_dir):
|
||||
if file.lower().endswith('.tif'):
|
||||
tif_path = os.path.join(src_dir, file)
|
||||
dataset = gdal.Open(tif_path)
|
||||
if dataset is None:
|
||||
print(f"无法打开: {tif_path}")
|
||||
continue
|
||||
|
||||
# 读取所有波段
|
||||
bands = []
|
||||
for i in range(1, dataset.RasterCount + 1):
|
||||
band = dataset.GetRasterBand(i).ReadAsArray()
|
||||
bands.append(band)
|
||||
img = np.stack(bands, axis=-1) if len(bands) > 1 else bands[0]
|
||||
|
||||
# 转换为uint8
|
||||
img = img.astype(np.uint8)
|
||||
|
||||
jpg_path = os.path.join(dst_dir, os.path.splitext(file)[0] + ".jpg")
|
||||
Image.fromarray(img).save(jpg_path, quality=95)
|
||||
print(f"已保存: {jpg_path}")
|
||||
|
||||
print("全部转换完成!")
|
274
train_JL/geotiff_utils.py
Normal file
274
train_JL/geotiff_utils.py
Normal file
@ -0,0 +1,274 @@
|
||||
|
||||
"""Pascal VOC Semantic Segmentation Dataset."""
|
||||
from PIL import Image, ImageOps, ImageFilter
|
||||
import torchvision.transforms as transforms
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
import cv2
|
||||
# import gdal
|
||||
from osgeo import gdal
|
||||
import random
|
||||
import torch.utils.data as data
|
||||
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
|
||||
|
||||
|
||||
class SegmentationDataset(object):
|
||||
"""Segmentation Base Dataset"""
|
||||
|
||||
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
|
||||
super(SegmentationDataset, self).__init__()
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.split = split
|
||||
self.mode = mode if mode is not None else split
|
||||
self.base_size = base_size
|
||||
self.crop_size = crop_size
|
||||
|
||||
def _val_sync_transform(self, img, mask):
|
||||
outsize = self.crop_size
|
||||
short_size = outsize
|
||||
w, h = img.size
|
||||
if w > h:
|
||||
oh = short_size
|
||||
ow = int(1.0 * w * oh / h)
|
||||
else:
|
||||
ow = short_size
|
||||
oh = int(1.0 * h * ow / w)
|
||||
img = img.resize((ow, oh), Image.BILINEAR)
|
||||
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||
# center crop
|
||||
w, h = img.size
|
||||
x1 = int(round((w - outsize) / 2.))
|
||||
y1 = int(round((h - outsize) / 2.))
|
||||
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _sync_transform(self, img, mask):
|
||||
# random mirror
|
||||
if random.random() < 0.5:
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
crop_size = self.crop_size
|
||||
# random scale (short edge)
|
||||
short_size = random.randint(
|
||||
int(self.base_size * 0.5), int(self.base_size * 2.0))
|
||||
w, h = img.size
|
||||
if h > w:
|
||||
ow = short_size
|
||||
oh = int(1.0 * h * ow / w)
|
||||
else:
|
||||
oh = short_size
|
||||
ow = int(1.0 * w * oh / h)
|
||||
img = img.resize((ow, oh), Image.BILINEAR)
|
||||
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||
# pad crop
|
||||
if short_size < crop_size:
|
||||
padh = crop_size - oh if oh < crop_size else 0
|
||||
padw = crop_size - ow if ow < crop_size else 0
|
||||
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
|
||||
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
|
||||
# random crop crop_size
|
||||
w, h = img.size
|
||||
x1 = random.randint(0, w - crop_size)
|
||||
y1 = random.randint(0, h - crop_size)
|
||||
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||
# gaussian blur as in PSP
|
||||
if random.random() < 0.5:
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _sync_transform_tif(self, img, mask):
|
||||
# random mirror
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _sync_transform_tif_geofeat(self, img, mask):
|
||||
# random mirror
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _val_sync_transform_tif(self, img, mask):
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _img_transform(self, img):
|
||||
return np.array(img)
|
||||
|
||||
# def _mask_transform(self, mask):
|
||||
# return np.array(mask).astype('int32')
|
||||
|
||||
def _mask_transform(self, mask):
|
||||
target = np.array(mask).astype('int32')
|
||||
# target = target[np.newaxis, :]
|
||||
target[target > 12] = 255
|
||||
return torch.from_numpy(target).long()
|
||||
|
||||
@property
|
||||
def num_class(self):
|
||||
"""Number of categories."""
|
||||
return self.NUM_CLASS
|
||||
|
||||
@property
|
||||
def pred_offset(self):
|
||||
return 0
|
||||
|
||||
|
||||
class VOCYJSSegmentation(SegmentationDataset):
|
||||
"""Pascal VOC Semantic Segmentation Dataset.
|
||||
Parameters
|
||||
----------
|
||||
root : string
|
||||
Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
|
||||
split: string
|
||||
'train', 'val' or 'test'
|
||||
transform : callable, optional
|
||||
A function that transforms the image
|
||||
Examples
|
||||
--------
|
||||
>>> from torchvision import transforms
|
||||
>>> import torch.utils.data as data
|
||||
>>> # Transforms for Normalization
|
||||
>>> input_transform = transforms.Compose([
|
||||
>>> transforms.ToTensor(),
|
||||
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
|
||||
>>> ])
|
||||
>>> # Create Dataset
|
||||
>>> trainset = VOCSegmentation(split='train', transform=input_transform)
|
||||
>>> # Create Training Loader
|
||||
>>> train_data = data.DataLoader(
|
||||
>>> trainset, 4, shuffle=True,
|
||||
>>> num_workers=4)
|
||||
"""
|
||||
NUM_CLASS = 13
|
||||
|
||||
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
|
||||
super(VOCYJSSegmentation, self).__init__(
|
||||
root, split, mode, transform, **kwargs)
|
||||
_voc_root = root
|
||||
txt_path = os.path.join(root, split+'.txt')
|
||||
self._mask_LS_dir = os.path.join(_voc_root, 'mask')
|
||||
self._image_LS_dir = os.path.join(_voc_root, "dataset_5m")
|
||||
self.image_list = read_text(txt_path)
|
||||
random.shuffle(self.image_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_LS = gdal.Open(os.path.join(self._image_LS_dir, self.image_list[index])).ReadAsArray(
|
||||
).transpose(1, 2, 0).astype(np.float32)
|
||||
mask = gdal.Open(os.path.join(self._mask_LS_dir,
|
||||
self.image_list[index])).ReadAsArray()
|
||||
# synchronized transform
|
||||
# 只包含两种模式: train 和 val
|
||||
if self.mode == 'train':
|
||||
img_LS, mask = self._sync_transform_tif_geofeat(
|
||||
img_LS, mask)
|
||||
elif self.mode == 'val':
|
||||
img_LS, mask = self._sync_transform_tif_geofeat(
|
||||
img_LS, mask)
|
||||
# general resize, normalize and toTensor
|
||||
if self.transform is not None:
|
||||
img_LS = self.transform(img_LS)
|
||||
return img_LS, mask
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_list)
|
||||
|
||||
def _mask_transform(self, mask):
|
||||
target = np.array(mask).astype('int32')
|
||||
# target = target[np.newaxis, :]
|
||||
target[target > 12] = 255
|
||||
return torch.from_numpy(target).long()
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
"""Category names."""
|
||||
return ('0', '1', '2', '3', '4', '5', '6')
|
||||
|
||||
|
||||
def generator_list_of_imagepath(path):
|
||||
image_list = []
|
||||
for image in os.listdir(path):
|
||||
# print(path)
|
||||
# print(image)
|
||||
if not image == '.DS_Store' and 'tif' == image.split('.')[-1]:
|
||||
image_list.append(image)
|
||||
return image_list
|
||||
|
||||
|
||||
def read_text(textfile):
|
||||
list = []
|
||||
with open(textfile, "r") as lines:
|
||||
for line in lines:
|
||||
list.append(line.rstrip('\n'))
|
||||
return list
|
||||
|
||||
|
||||
def dataset_segmentation(textpath, imagepath, train_percent):
|
||||
image_list = generator_list_of_imagepath(imagepath)
|
||||
num = len(image_list)
|
||||
list = range(num)
|
||||
train_num = int(num * train_percent) # training set num
|
||||
train_list = random.sample(list, train_num)
|
||||
print("train set size", train_num)
|
||||
ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
|
||||
fval = open(os.path.join(textpath, 'val.txt'), 'w')
|
||||
for i in list:
|
||||
name = image_list[i] + '\n'
|
||||
if i in train_list:
|
||||
ftrain.write(name)
|
||||
else:
|
||||
fval.write(name)
|
||||
ftrain.close()
|
||||
fval.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
|
||||
# list=generator_list_of_imagepath(path)
|
||||
# print(list)
|
||||
# 切割数据集
|
||||
|
||||
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
|
||||
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
|
||||
train_percent = 0.8
|
||||
dataset_segmentation(textpath, imagepath, train_percent)
|
||||
# 显示各种图片
|
||||
|
||||
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
|
||||
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
|
||||
# cv2.imshow('img', img)
|
||||
# img = Image.fromarray (img,'RGB')
|
||||
# img.show()
|
||||
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
|
||||
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
|
||||
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
|
||||
# img2 = Image.fromarray (img2,'RGB')
|
||||
# img2.show()
|
||||
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
|
||||
# img3 = gdal.Open(img3).ReadAsArray()
|
||||
# img3 = Image.fromarray (img3)
|
||||
# img3.show()
|
||||
|
||||
# dataset和dataloader的测试
|
||||
|
||||
# 测试dataloader能不能用
|
||||
'''
|
||||
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
|
||||
input_transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
|
||||
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
|
||||
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
|
||||
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
|
||||
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
|
||||
'''
|
156
train_JL/train_JL.py
Normal file
156
train_JL/train_JL.py
Normal file
@ -0,0 +1,156 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
|
||||
from torchvision import transforms
|
||||
import torch.utils.data as data
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from geotiff_utils import VOCYJSSegmentation
|
||||
import utils as utils
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def parse_args():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
||||
|
||||
parser.add_argument(
|
||||
"--data-path", default=r"E:\datasets\wlk_right_448", help="VOCdevkit root")
|
||||
parser.add_argument("--num-classes", default=7, type=int)
|
||||
parser.add_argument("--device", default="cuda", help="training device")
|
||||
parser.add_argument("--batch-size", default=4, type=int)
|
||||
parser.add_argument("--epochs", default=50, type=int, metavar="N",
|
||||
help="number of total epochs to train")
|
||||
parser.add_argument('--lr', default=0.005, type=float,
|
||||
help='initial learning rate')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||
help='momentum')
|
||||
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
||||
metavar='W', help='weight decay (default: 1e-4)',
|
||||
dest='weight_decay')
|
||||
parser.add_argument('-out-dir', type=str,
|
||||
default='DeeplabV3_JL')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class DeeplabV3_JL(nn.Module):
|
||||
def __init__(self, n_class):
|
||||
super(DeeplabV3_JL, self).__init__()
|
||||
self.n_class = n_class
|
||||
self.conv6_3 = nn.Conv2d(6, 3, kernel_size=1, stride=1)
|
||||
|
||||
self.conv_fc = nn.Conv2d(
|
||||
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
|
||||
|
||||
self.seg = deeplabv3_resnet50(weights='DEFAULT')
|
||||
|
||||
def forward(self, x):
|
||||
# x = torch.cat([x, x_dan], dim=1)
|
||||
x = self.conv6_3(x)
|
||||
x = self.seg(x)["out"]
|
||||
x = self.conv_fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def main(args):
|
||||
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||
|
||||
input_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([.535, .767, .732, .561, .494, .564],
|
||||
[.0132, .0188, .0181, .0173, .0183, .0259]),
|
||||
])
|
||||
data_kwargs = {'transform': input_transform,
|
||||
'base_size': 448, 'crop_size': 448}
|
||||
|
||||
# 读取geotiff数据,构建训练集、验证集
|
||||
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
|
||||
**data_kwargs)
|
||||
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
|
||||
**data_kwargs)
|
||||
num_workers = min(
|
||||
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
|
||||
train_loader = data.DataLoader(
|
||||
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||
val_loader = data.DataLoader(
|
||||
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||
|
||||
model = DeeplabV3_JL(n_class=args.num_classes)
|
||||
model.to(device)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
lr_scheduler = utils.create_lr_scheduler(
|
||||
optimizer, len(train_loader), args.epochs, warmup=True)
|
||||
|
||||
now = datetime.now()
|
||||
date_time = now.strftime("%Y-%m-%d__%H-%M__")
|
||||
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
|
||||
if not out_dir.exists():
|
||||
out_dir.mkdir()
|
||||
f = open(os.path.join(out_dir, "log.txt"), 'w')
|
||||
start_time = time.time()
|
||||
best_acc = 0
|
||||
for epoch in range(args.epochs):
|
||||
print(f"Epoch {epoch+1}\n-------------------------------")
|
||||
model.train()
|
||||
for idx, (image, target) in enumerate(train_loader):
|
||||
image, target = image.to(
|
||||
device), target.to(device)
|
||||
output = model(image)
|
||||
|
||||
loss = criterion(output, target)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
if idx % 100 == 0:
|
||||
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
|
||||
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
|
||||
|
||||
model.eval()
|
||||
confmat = utils.ConfusionMatrix(args.num_classes)
|
||||
with torch.no_grad():
|
||||
for image, target in val_loader:
|
||||
image, target = image.to(device), target.to(device)
|
||||
output = model(image)
|
||||
|
||||
confmat.update(target.flatten(), output.argmax(1).flatten())
|
||||
|
||||
info, mIoU = confmat.get_info()
|
||||
print(info)
|
||||
|
||||
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
|
||||
f.write(info+"\n\n")
|
||||
f.flush()
|
||||
|
||||
# # 保存准确率最好的模型
|
||||
# if mIoU > best_acc:
|
||||
# print("[Save model]")
|
||||
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
|
||||
# best_acc = mIoU
|
||||
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
|
||||
total_time = time.time() - start_time
|
||||
print("total time:", total_time)
|
||||
torch.save(model, os.path.join(out_dir, "last.pth"))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
main(args)
|
279
train_JL_jpg/geotiff_utils.py
Normal file
279
train_JL_jpg/geotiff_utils.py
Normal file
@ -0,0 +1,279 @@
|
||||
|
||||
"""Pascal VOC Semantic Segmentation Dataset."""
|
||||
from PIL import Image, ImageOps, ImageFilter
|
||||
import torchvision.transforms as transforms
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
import cv2
|
||||
# import gdal
|
||||
from osgeo import gdal
|
||||
import random
|
||||
import torch.utils.data as data
|
||||
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
|
||||
|
||||
|
||||
class SegmentationDataset(object):
|
||||
"""Segmentation Base Dataset"""
|
||||
|
||||
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
|
||||
super(SegmentationDataset, self).__init__()
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.split = split
|
||||
self.mode = mode if mode is not None else split
|
||||
self.base_size = base_size
|
||||
self.crop_size = crop_size
|
||||
|
||||
def _val_sync_transform(self, img, mask):
|
||||
outsize = self.crop_size
|
||||
short_size = outsize
|
||||
w, h = img.size
|
||||
if w > h:
|
||||
oh = short_size
|
||||
ow = int(1.0 * w * oh / h)
|
||||
else:
|
||||
ow = short_size
|
||||
oh = int(1.0 * h * ow / w)
|
||||
img = img.resize((ow, oh), Image.BILINEAR)
|
||||
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||
# center crop
|
||||
w, h = img.size
|
||||
x1 = int(round((w - outsize) / 2.))
|
||||
y1 = int(round((h - outsize) / 2.))
|
||||
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _sync_transform(self, img, mask):
|
||||
# random mirror
|
||||
if random.random() < 0.5:
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
crop_size = self.crop_size
|
||||
# random scale (short edge)
|
||||
short_size = random.randint(
|
||||
int(self.base_size * 0.5), int(self.base_size * 2.0))
|
||||
w, h = img.size
|
||||
if h > w:
|
||||
ow = short_size
|
||||
oh = int(1.0 * h * ow / w)
|
||||
else:
|
||||
oh = short_size
|
||||
ow = int(1.0 * w * oh / h)
|
||||
img = img.resize((ow, oh), Image.BILINEAR)
|
||||
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||
# pad crop
|
||||
if short_size < crop_size:
|
||||
padh = crop_size - oh if oh < crop_size else 0
|
||||
padw = crop_size - ow if ow < crop_size else 0
|
||||
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
|
||||
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
|
||||
# random crop crop_size
|
||||
w, h = img.size
|
||||
x1 = random.randint(0, w - crop_size)
|
||||
y1 = random.randint(0, h - crop_size)
|
||||
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||
# gaussian blur as in PSP
|
||||
if random.random() < 0.5:
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _sync_transform_tif(self, img, mask):
|
||||
# random mirror
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _sync_transform_tif_geofeat(self, img, mask):
|
||||
# random mirror
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _val_sync_transform_tif(self, img, mask):
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _img_transform(self, img):
|
||||
return np.array(img)
|
||||
|
||||
# def _mask_transform(self, mask):
|
||||
# return np.array(mask).astype('int32')
|
||||
|
||||
def _mask_transform(self, mask):
|
||||
target = np.array(mask).astype('int32')
|
||||
# target = target[np.newaxis, :]
|
||||
target[target > 12] = 255
|
||||
return torch.from_numpy(target).long()
|
||||
|
||||
@property
|
||||
def num_class(self):
|
||||
"""Number of categories."""
|
||||
return self.NUM_CLASS
|
||||
|
||||
@property
|
||||
def pred_offset(self):
|
||||
return 0
|
||||
|
||||
|
||||
class VOCYJSSegmentation(SegmentationDataset):
|
||||
"""Pascal VOC Semantic Segmentation Dataset.
|
||||
Parameters
|
||||
----------
|
||||
root : string
|
||||
Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
|
||||
split: string
|
||||
'train', 'val' or 'test'
|
||||
transform : callable, optional
|
||||
A function that transforms the image
|
||||
Examples
|
||||
--------
|
||||
>>> from torchvision import transforms
|
||||
>>> import torch.utils.data as data
|
||||
>>> # Transforms for Normalization
|
||||
>>> input_transform = transforms.Compose([
|
||||
>>> transforms.ToTensor(),
|
||||
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
|
||||
>>> ])
|
||||
>>> # Create Dataset
|
||||
>>> trainset = VOCSegmentation(split='train', transform=input_transform)
|
||||
>>> # Create Training Loader
|
||||
>>> train_data = data.DataLoader(
|
||||
>>> trainset, 4, shuffle=True,
|
||||
>>> num_workers=4)
|
||||
"""
|
||||
NUM_CLASS = 13
|
||||
|
||||
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
|
||||
super(VOCYJSSegmentation, self).__init__(
|
||||
root, split, mode, transform, **kwargs)
|
||||
_voc_root = root
|
||||
txt_path = os.path.join(root, split+'.txt')
|
||||
self._mask_LS_dir = os.path.join(_voc_root, 'mask_png')
|
||||
self._image_LS_dir = os.path.join(_voc_root, "dataset_5m_jpg")
|
||||
self.image_list = read_text(txt_path)
|
||||
random.shuffle(self.image_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_name = self.image_list[index].split('.')[0]+'.jpg'
|
||||
mask_name = self.image_list[index].split('.')[0]+'.png'
|
||||
mask_name = mask_name.replace('img', 'mask')
|
||||
img_LS = np.array(Image.open(os.path.join(
|
||||
self._image_LS_dir, img_name))).astype(np.float32)
|
||||
mask = np.array(Image.open(os.path.join(
|
||||
self._mask_LS_dir, mask_name))).astype(np.int32)
|
||||
mask = torch.from_numpy(mask).long()
|
||||
|
||||
# synchronized transform
|
||||
# 只包含两种模式: train 和 val
|
||||
if self.mode == 'train':
|
||||
img_LS, mask = self._sync_transform_tif_geofeat(
|
||||
img_LS, mask)
|
||||
elif self.mode == 'val':
|
||||
img_LS, mask = self._sync_transform_tif_geofeat(
|
||||
img_LS, mask)
|
||||
# general resize, normalize and toTensor
|
||||
if self.transform is not None:
|
||||
img_LS = self.transform(img_LS)
|
||||
return img_LS, mask
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_list)
|
||||
|
||||
def _mask_transform(self, mask):
|
||||
target = np.array(mask).astype('int32')
|
||||
# target = target[np.newaxis, :]
|
||||
target[target > 12] = 255
|
||||
return torch.from_numpy(target).long()
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
"""Category names."""
|
||||
return ('0', '1', '2', '3', '4', '5', '6')
|
||||
|
||||
|
||||
def generator_list_of_imagepath(path):
|
||||
image_list = []
|
||||
for image in os.listdir(path):
|
||||
# print(path)
|
||||
# print(image)
|
||||
if not image == '.DS_Store' and 'tif' == image.split('.')[-1]:
|
||||
image_list.append(image)
|
||||
return image_list
|
||||
|
||||
|
||||
def read_text(textfile):
|
||||
list = []
|
||||
with open(textfile, "r") as lines:
|
||||
for line in lines:
|
||||
list.append(line.rstrip('\n'))
|
||||
return list
|
||||
|
||||
|
||||
def dataset_segmentation(textpath, imagepath, train_percent):
|
||||
image_list = generator_list_of_imagepath(imagepath)
|
||||
num = len(image_list)
|
||||
list = range(num)
|
||||
train_num = int(num * train_percent) # training set num
|
||||
train_list = random.sample(list, train_num)
|
||||
print("train set size", train_num)
|
||||
ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
|
||||
fval = open(os.path.join(textpath, 'val.txt'), 'w')
|
||||
for i in list:
|
||||
name = image_list[i] + '\n'
|
||||
if i in train_list:
|
||||
ftrain.write(name)
|
||||
else:
|
||||
fval.write(name)
|
||||
ftrain.close()
|
||||
fval.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
|
||||
# list=generator_list_of_imagepath(path)
|
||||
# print(list)
|
||||
# 切割数据集
|
||||
|
||||
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
|
||||
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
|
||||
train_percent = 0.8
|
||||
dataset_segmentation(textpath, imagepath, train_percent)
|
||||
# 显示各种图片
|
||||
|
||||
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
|
||||
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
|
||||
# cv2.imshow('img', img)
|
||||
# img = Image.fromarray (img,'RGB')
|
||||
# img.show()
|
||||
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
|
||||
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
|
||||
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
|
||||
# img2 = Image.fromarray (img2,'RGB')
|
||||
# img2.show()
|
||||
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
|
||||
# img3 = gdal.Open(img3).ReadAsArray()
|
||||
# img3 = Image.fromarray (img3)
|
||||
# img3.show()
|
||||
|
||||
# dataset和dataloader的测试
|
||||
|
||||
# 测试dataloader能不能用
|
||||
'''
|
||||
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
|
||||
input_transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
|
||||
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
|
||||
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
|
||||
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
|
||||
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
|
||||
'''
|
153
train_JL_jpg/train_JL.py
Normal file
153
train_JL_jpg/train_JL.py
Normal file
@ -0,0 +1,153 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
|
||||
from torchvision import transforms
|
||||
import torch.utils.data as data
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from geotiff_utils import VOCYJSSegmentation
|
||||
import utils as utils
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def parse_args():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
||||
|
||||
parser.add_argument(
|
||||
"--data-path", default=r"E:\RSdata\wlk_right_448", help="VOCdevkit root")
|
||||
parser.add_argument("--num-classes", default=7, type=int)
|
||||
parser.add_argument("--device", default="cuda", help="training device")
|
||||
parser.add_argument("--batch-size", default=4, type=int)
|
||||
parser.add_argument("--epochs", default=50, type=int, metavar="N",
|
||||
help="number of total epochs to train")
|
||||
parser.add_argument('--lr', default=0.005, type=float,
|
||||
help='initial learning rate')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||
help='momentum')
|
||||
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
||||
metavar='W', help='weight decay (default: 1e-4)',
|
||||
dest='weight_decay')
|
||||
parser.add_argument('-out-dir', type=str,
|
||||
default='DeeplabV3_JL')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class DeeplabV3_JL_3(nn.Module):
|
||||
def __init__(self, n_class):
|
||||
super(DeeplabV3_JL_3, self).__init__()
|
||||
self.n_class = n_class
|
||||
|
||||
self.conv_fc = nn.Conv2d(
|
||||
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
|
||||
|
||||
self.seg = deeplabv3_resnet50(weights='DEFAULT')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.seg(x)["out"]
|
||||
x = self.conv_fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def main(args):
|
||||
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||
|
||||
input_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([.485, .456, .406],
|
||||
[.229, .224, .225]),
|
||||
])
|
||||
data_kwargs = {'transform': input_transform,
|
||||
'base_size': 448, 'crop_size': 448}
|
||||
|
||||
# 读取geotiff数据,构建训练集、验证集
|
||||
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
|
||||
**data_kwargs)
|
||||
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
|
||||
**data_kwargs)
|
||||
num_workers = min(
|
||||
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
|
||||
train_loader = data.DataLoader(
|
||||
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||
val_loader = data.DataLoader(
|
||||
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||
|
||||
model = DeeplabV3_JL_3(n_class=args.num_classes)
|
||||
model.to(device)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
lr_scheduler = utils.create_lr_scheduler(
|
||||
optimizer, len(train_loader), args.epochs, warmup=True)
|
||||
|
||||
now = datetime.now()
|
||||
date_time = now.strftime("%Y-%m-%d__%H-%M__")
|
||||
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
|
||||
if not out_dir.exists():
|
||||
out_dir.mkdir()
|
||||
f = open(os.path.join(out_dir, "log.txt"), 'w')
|
||||
start_time = time.time()
|
||||
best_acc = 0
|
||||
for epoch in range(args.epochs):
|
||||
print(f"Epoch {epoch+1}\n-------------------------------")
|
||||
model.train()
|
||||
for idx, (image, target) in enumerate(train_loader):
|
||||
image, target = image.to(
|
||||
device), target.to(device)
|
||||
output = model(image)
|
||||
|
||||
loss = criterion(output, target)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
if idx % 100 == 0:
|
||||
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
|
||||
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
|
||||
|
||||
model.eval()
|
||||
confmat = utils.ConfusionMatrix(args.num_classes)
|
||||
with torch.no_grad():
|
||||
for image, target in val_loader:
|
||||
image, target = image.to(device), target.to(device)
|
||||
output = model(image)
|
||||
|
||||
confmat.update(target.flatten(), output.argmax(1).flatten())
|
||||
|
||||
info, mIoU = confmat.get_info()
|
||||
print(info)
|
||||
|
||||
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
|
||||
f.write(info+"\n\n")
|
||||
f.flush()
|
||||
|
||||
# # 保存准确率最好的模型
|
||||
# if mIoU > best_acc:
|
||||
# print("[Save model]")
|
||||
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
|
||||
# best_acc = mIoU
|
||||
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
|
||||
total_time = time.time() - start_time
|
||||
print("total time:", total_time)
|
||||
torch.save(model, os.path.join(out_dir, "last.pth"))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
main(args)
|
330
train_JL_jpg/utils.py
Normal file
330
train_JL_jpg/utils.py
Normal file
@ -0,0 +1,330 @@
|
||||
import datetime
|
||||
import errno
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class SmoothedValue:
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
t = reduce_across_processes([self.count, self.total])
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||
)
|
||||
|
||||
|
||||
class ConfusionMatrix:
|
||||
def __init__(self, num_classes):
|
||||
self.num_classes = num_classes
|
||||
self.mat = None
|
||||
|
||||
def update(self, a, b):
|
||||
n = self.num_classes
|
||||
if self.mat is None:
|
||||
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||
with torch.inference_mode():
|
||||
k = (a >= 0) & (a < n)
|
||||
inds = n * a[k].to(torch.int64) + b[k]
|
||||
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||
|
||||
def reset(self):
|
||||
self.mat.zero_()
|
||||
|
||||
def compute(self):
|
||||
h = self.mat.float()
|
||||
acc_global = torch.diag(h).sum() / h.sum()
|
||||
acc = torch.diag(h) / h.sum(1)
|
||||
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
|
||||
return acc_global, acc, iu
|
||||
|
||||
def reduce_from_all_processes(self):
|
||||
reduce_across_processes(self.mat)
|
||||
|
||||
def get_info(self):
|
||||
acc_global, acc, iu = self.compute()
|
||||
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
|
||||
acc_global.item() * 100,
|
||||
[f"{i:.1f}" for i in (acc * 100).tolist()],
|
||||
[f"{i:.1f}" for i in (iu * 100).tolist()],
|
||||
iu.mean().item() * 100,
|
||||
), iu.mean().item() * 100
|
||||
|
||||
|
||||
class MetricLogger:
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
if not isinstance(v, (float, int)):
|
||||
raise TypeError(
|
||||
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
|
||||
)
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{attr}'")
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(f"{name}: {str(meter)}")
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ""
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
"max mem: {memory:.0f}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = self.delimiter.join(
|
||||
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
|
||||
"{meters}", "time: {time}", "data: {data}"]
|
||||
)
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
log_msg.format(
|
||||
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
||||
)
|
||||
)
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print(f"{header} Total time: {total_time_str}")
|
||||
|
||||
|
||||
def cat_list(images, fill_value=0):
|
||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||
batch_shape = (len(images),) + max_size
|
||||
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
|
||||
for img, pad_img in zip(images, batched_imgs):
|
||||
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
|
||||
return batched_imgs
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
images, targets = list(zip(*batch))
|
||||
batched_imgs = cat_list(images, fill_value=0)
|
||||
batched_targets = cat_list(targets, fill_value=255)
|
||||
return batched_imgs, batched_targets
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop("force", False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||
# elif "SLURM_PROCID" in os.environ:
|
||||
# args.rank = int(os.environ["SLURM_PROCID"])
|
||||
# args.gpu = args.rank % torch.cuda.device_count()
|
||||
elif hasattr(args, "rank"):
|
||||
pass
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = "nccl"
|
||||
print(
|
||||
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||
)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
def reduce_across_processes(val):
|
||||
if not is_dist_avail_and_initialized():
|
||||
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
|
||||
return torch.tensor(val)
|
||||
|
||||
t = torch.tensor(val, device="cuda")
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
return t
|
||||
|
||||
|
||||
def create_lr_scheduler(optimizer,
|
||||
num_step: int,
|
||||
epochs: int,
|
||||
warmup=True,
|
||||
warmup_epochs=1,
|
||||
warmup_factor=1e-3):
|
||||
assert num_step > 0 and epochs > 0
|
||||
if warmup is False:
|
||||
warmup_epochs = 0
|
||||
|
||||
def f(x):
|
||||
"""
|
||||
根据step数返回一个学习率倍率因子,
|
||||
注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法
|
||||
"""
|
||||
if warmup is True and x <= (warmup_epochs * num_step):
|
||||
alpha = float(x) / (warmup_epochs * num_step)
|
||||
# warmup过程中lr倍率因子从warmup_factor -> 1
|
||||
return warmup_factor * (1 - alpha) + alpha
|
||||
else:
|
||||
# warmup后lr倍率因子从1 -> 0
|
||||
# 参考deeplab_v2: Learning rate policy
|
||||
return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
|
||||
|
||||
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
|
@ -11,7 +11,7 @@ from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from geotiff_utils import VOCYJSSegmentation
|
||||
import utils
|
||||
import utils as utils
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
@ -21,7 +21,7 @@ def parse_args():
|
||||
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
||||
|
||||
parser.add_argument(
|
||||
"--data-path", default="E:/repository/DeepLearning23/datasets/WLKdata_1111/WLKdataset", help="VOCdevkit root")
|
||||
"--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset", help="VOCdevkit root")
|
||||
parser.add_argument("--num-classes", default=13, type=int)
|
||||
parser.add_argument("--device", default="cuda", help="training device")
|
||||
parser.add_argument("--batch-size", default=8, type=int)
|
330
train_LS/utils.py
Normal file
330
train_LS/utils.py
Normal file
@ -0,0 +1,330 @@
|
||||
import datetime
|
||||
import errno
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class SmoothedValue:
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
t = reduce_across_processes([self.count, self.total])
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||
)
|
||||
|
||||
|
||||
class ConfusionMatrix:
|
||||
def __init__(self, num_classes):
|
||||
self.num_classes = num_classes
|
||||
self.mat = None
|
||||
|
||||
def update(self, a, b):
|
||||
n = self.num_classes
|
||||
if self.mat is None:
|
||||
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||
with torch.inference_mode():
|
||||
k = (a >= 0) & (a < n)
|
||||
inds = n * a[k].to(torch.int64) + b[k]
|
||||
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||
|
||||
def reset(self):
|
||||
self.mat.zero_()
|
||||
|
||||
def compute(self):
|
||||
h = self.mat.float()
|
||||
acc_global = torch.diag(h).sum() / h.sum()
|
||||
acc = torch.diag(h) / h.sum(1)
|
||||
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
|
||||
return acc_global, acc, iu
|
||||
|
||||
def reduce_from_all_processes(self):
|
||||
reduce_across_processes(self.mat)
|
||||
|
||||
def get_info(self):
|
||||
acc_global, acc, iu = self.compute()
|
||||
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
|
||||
acc_global.item() * 100,
|
||||
[f"{i:.1f}" for i in (acc * 100).tolist()],
|
||||
[f"{i:.1f}" for i in (iu * 100).tolist()],
|
||||
iu.mean().item() * 100,
|
||||
), iu.mean().item() * 100
|
||||
|
||||
|
||||
class MetricLogger:
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
if not isinstance(v, (float, int)):
|
||||
raise TypeError(
|
||||
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
|
||||
)
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{attr}'")
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(f"{name}: {str(meter)}")
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ""
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
"max mem: {memory:.0f}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = self.delimiter.join(
|
||||
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
|
||||
"{meters}", "time: {time}", "data: {data}"]
|
||||
)
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
log_msg.format(
|
||||
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
||||
)
|
||||
)
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print(f"{header} Total time: {total_time_str}")
|
||||
|
||||
|
||||
def cat_list(images, fill_value=0):
|
||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||
batch_shape = (len(images),) + max_size
|
||||
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
|
||||
for img, pad_img in zip(images, batched_imgs):
|
||||
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
|
||||
return batched_imgs
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
images, targets = list(zip(*batch))
|
||||
batched_imgs = cat_list(images, fill_value=0)
|
||||
batched_targets = cat_list(targets, fill_value=255)
|
||||
return batched_imgs, batched_targets
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop("force", False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||
# elif "SLURM_PROCID" in os.environ:
|
||||
# args.rank = int(os.environ["SLURM_PROCID"])
|
||||
# args.gpu = args.rank % torch.cuda.device_count()
|
||||
elif hasattr(args, "rank"):
|
||||
pass
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = "nccl"
|
||||
print(
|
||||
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||
)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
def reduce_across_processes(val):
|
||||
if not is_dist_avail_and_initialized():
|
||||
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
|
||||
return torch.tensor(val)
|
||||
|
||||
t = torch.tensor(val, device="cuda")
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
return t
|
||||
|
||||
|
||||
def create_lr_scheduler(optimizer,
|
||||
num_step: int,
|
||||
epochs: int,
|
||||
warmup=True,
|
||||
warmup_epochs=1,
|
||||
warmup_factor=1e-3):
|
||||
assert num_step > 0 and epochs > 0
|
||||
if warmup is False:
|
||||
warmup_epochs = 0
|
||||
|
||||
def f(x):
|
||||
"""
|
||||
根据step数返回一个学习率倍率因子,
|
||||
注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法
|
||||
"""
|
||||
if warmup is True and x <= (warmup_epochs * num_step):
|
||||
alpha = float(x) / (warmup_epochs * num_step)
|
||||
# warmup过程中lr倍率因子从warmup_factor -> 1
|
||||
return warmup_factor * (1 - alpha) + alpha
|
||||
else:
|
||||
# warmup后lr倍率因子从1 -> 0
|
||||
# 参考deeplab_v2: Learning rate policy
|
||||
return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
|
||||
|
||||
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
|
270
train_LS_jpg/geotiff_utils.py
Normal file
270
train_LS_jpg/geotiff_utils.py
Normal file
@ -0,0 +1,270 @@
|
||||
|
||||
"""Pascal VOC Semantic Segmentation Dataset."""
|
||||
from PIL import Image, ImageOps, ImageFilter
|
||||
import torchvision.transforms as transforms
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
import cv2
|
||||
# import gdal
|
||||
from osgeo import gdal
|
||||
import random
|
||||
import torch.utils.data as data
|
||||
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
|
||||
|
||||
|
||||
class SegmentationDataset(object):
|
||||
"""Segmentation Base Dataset"""
|
||||
|
||||
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
|
||||
super(SegmentationDataset, self).__init__()
|
||||
self.root = root
|
||||
self.transform = transform
|
||||
self.split = split
|
||||
self.mode = mode if mode is not None else split
|
||||
self.base_size = base_size
|
||||
self.crop_size = crop_size
|
||||
|
||||
def _val_sync_transform(self, img, mask):
|
||||
outsize = self.crop_size
|
||||
short_size = outsize
|
||||
w, h = img.size
|
||||
if w > h:
|
||||
oh = short_size
|
||||
ow = int(1.0 * w * oh / h)
|
||||
else:
|
||||
ow = short_size
|
||||
oh = int(1.0 * h * ow / w)
|
||||
img = img.resize((ow, oh), Image.BILINEAR)
|
||||
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||
# center crop
|
||||
w, h = img.size
|
||||
x1 = int(round((w - outsize) / 2.))
|
||||
y1 = int(round((h - outsize) / 2.))
|
||||
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _sync_transform(self, img, mask):
|
||||
# random mirror
|
||||
if random.random() < 0.5:
|
||||
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
crop_size = self.crop_size
|
||||
# random scale (short edge)
|
||||
short_size = random.randint(
|
||||
int(self.base_size * 0.5), int(self.base_size * 2.0))
|
||||
w, h = img.size
|
||||
if h > w:
|
||||
ow = short_size
|
||||
oh = int(1.0 * h * ow / w)
|
||||
else:
|
||||
oh = short_size
|
||||
ow = int(1.0 * w * oh / h)
|
||||
img = img.resize((ow, oh), Image.BILINEAR)
|
||||
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||
# pad crop
|
||||
if short_size < crop_size:
|
||||
padh = crop_size - oh if oh < crop_size else 0
|
||||
padw = crop_size - ow if ow < crop_size else 0
|
||||
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
|
||||
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
|
||||
# random crop crop_size
|
||||
w, h = img.size
|
||||
x1 = random.randint(0, w - crop_size)
|
||||
y1 = random.randint(0, h - crop_size)
|
||||
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||
# gaussian blur as in PSP
|
||||
if random.random() < 0.5:
|
||||
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _sync_transform_tif(self, img, mask):
|
||||
# random mirror
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _sync_transform_tif_geofeat(self, img, mask):
|
||||
# random mirror
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _val_sync_transform_tif(self, img, mask):
|
||||
# final transform
|
||||
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||
return img, mask
|
||||
|
||||
def _img_transform(self, img):
|
||||
return np.array(img)
|
||||
|
||||
# def _mask_transform(self, mask):
|
||||
# return np.array(mask).astype('int32')
|
||||
|
||||
def _mask_transform(self, mask):
|
||||
target = np.array(mask).astype('int32')
|
||||
# target = target[np.newaxis, :]
|
||||
target[target > 12] = 255
|
||||
return torch.from_numpy(target).long()
|
||||
|
||||
@property
|
||||
def num_class(self):
|
||||
"""Number of categories."""
|
||||
return self.NUM_CLASS
|
||||
|
||||
@property
|
||||
def pred_offset(self):
|
||||
return 0
|
||||
|
||||
|
||||
class VOCYJSSegmentation(SegmentationDataset):
|
||||
"""Pascal VOC Semantic Segmentation Dataset.
|
||||
Parameters
|
||||
----------
|
||||
root : string
|
||||
Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
|
||||
split: string
|
||||
'train', 'val' or 'test'
|
||||
transform : callable, optional
|
||||
A function that transforms the image
|
||||
Examples
|
||||
--------
|
||||
>>> from torchvision import transforms
|
||||
>>> import torch.utils.data as data
|
||||
>>> # Transforms for Normalization
|
||||
>>> input_transform = transforms.Compose([
|
||||
>>> transforms.ToTensor(),
|
||||
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
|
||||
>>> ])
|
||||
>>> # Create Dataset
|
||||
>>> trainset = VOCSegmentation(split='train', transform=input_transform)
|
||||
>>> # Create Training Loader
|
||||
>>> train_data = data.DataLoader(
|
||||
>>> trainset, 4, shuffle=True,
|
||||
>>> num_workers=4)
|
||||
"""
|
||||
NUM_CLASS = 13
|
||||
|
||||
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
|
||||
super(VOCYJSSegmentation, self).__init__(
|
||||
root, split, mode, transform, **kwargs)
|
||||
_voc_root = root
|
||||
txt_path = os.path.join(root, split+'.txt')
|
||||
self._mask_LS_dir = os.path.join(_voc_root, "masks_LS_png")
|
||||
self._image_LS_dir = os.path.join(_voc_root, "images_LS_jpg")
|
||||
self.image_list = read_text(txt_path)
|
||||
random.shuffle(self.image_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
img_name = self.image_list[index].split('.')[0]+'.jpg'
|
||||
mask_name = self.image_list[index].split('.')[0]+'.png'
|
||||
img_LS = np.array(Image.open(os.path.join(
|
||||
self._image_LS_dir, img_name))).astype(np.float32)
|
||||
mask = np.array(Image.open(os.path.join(
|
||||
self._mask_LS_dir, mask_name))).astype(np.int32)
|
||||
mask = torch.from_numpy(mask).long()
|
||||
|
||||
# general resize, normalize and toTensor
|
||||
if self.transform is not None:
|
||||
img_LS = self.transform(img_LS)
|
||||
return img_LS, mask
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_list)
|
||||
|
||||
def _mask_transform(self, mask):
|
||||
target = np.array(mask).astype('int32')
|
||||
# target = target[np.newaxis, :]
|
||||
target[target > 12] = 255
|
||||
return torch.from_numpy(target).long()
|
||||
|
||||
@property
|
||||
def classes(self):
|
||||
"""Category names."""
|
||||
return ('0', '1', '2', '3', '4', '5', '6', '7' '8', '9', '10', '11', '12')
|
||||
|
||||
|
||||
def generator_list_of_imagepath(path):
|
||||
image_list = []
|
||||
for image in os.listdir(path):
|
||||
# print(path)
|
||||
# print(image)
|
||||
if not image == '.DS_Store' and 'tif' == image.split('.')[-1]:
|
||||
image_list.append(image)
|
||||
return image_list
|
||||
|
||||
|
||||
def read_text(textfile):
|
||||
list = []
|
||||
with open(textfile, "r") as lines:
|
||||
for line in lines:
|
||||
list.append(line.rstrip('\n'))
|
||||
return list
|
||||
|
||||
|
||||
def dataset_segmentation(textpath, imagepath, train_percent):
|
||||
image_list = generator_list_of_imagepath(imagepath)
|
||||
num = len(image_list)
|
||||
list = range(num)
|
||||
train_num = int(num * train_percent) # training set num
|
||||
train_list = random.sample(list, train_num)
|
||||
print("train set size", train_num)
|
||||
ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
|
||||
fval = open(os.path.join(textpath, 'val.txt'), 'w')
|
||||
for i in list:
|
||||
name = image_list[i] + '\n'
|
||||
if i in train_list:
|
||||
ftrain.write(name)
|
||||
else:
|
||||
fval.write(name)
|
||||
ftrain.close()
|
||||
fval.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
|
||||
# list=generator_list_of_imagepath(path)
|
||||
# print(list)
|
||||
# 切割数据集
|
||||
|
||||
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
|
||||
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
|
||||
train_percent = 0.8
|
||||
dataset_segmentation(textpath, imagepath, train_percent)
|
||||
# 显示各种图片
|
||||
|
||||
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
|
||||
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
|
||||
# cv2.imshow('img', img)
|
||||
# img = Image.fromarray (img,'RGB')
|
||||
# img.show()
|
||||
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
|
||||
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
|
||||
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
|
||||
# img2 = Image.fromarray (img2,'RGB')
|
||||
# img2.show()
|
||||
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
|
||||
# img3 = gdal.Open(img3).ReadAsArray()
|
||||
# img3 = Image.fromarray (img3)
|
||||
# img3.show()
|
||||
|
||||
# dataset和dataloader的测试
|
||||
|
||||
# 测试dataloader能不能用
|
||||
'''
|
||||
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
|
||||
input_transform = transforms.Compose(
|
||||
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
|
||||
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
|
||||
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
|
||||
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
|
||||
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
|
||||
'''
|
152
train_LS_jpg/train.py
Normal file
152
train_LS_jpg/train.py
Normal file
@ -0,0 +1,152 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
|
||||
from torchvision import transforms
|
||||
import torch.utils.data as data
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from geotiff_utils import VOCYJSSegmentation
|
||||
import utils
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def parse_args():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
||||
|
||||
parser.add_argument(
|
||||
"--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset", help="VOCdevkit root")
|
||||
parser.add_argument("--num-classes", default=13, type=int)
|
||||
parser.add_argument("--device", default="cuda", help="training device")
|
||||
parser.add_argument("--batch-size", default=8, type=int)
|
||||
parser.add_argument("--epochs", default=50, type=int, metavar="N",
|
||||
help="number of total epochs to train")
|
||||
parser.add_argument('--lr', default=0.005, type=float,
|
||||
help='initial learning rate')
|
||||
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||
help='momentum')
|
||||
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
||||
metavar='W', help='weight decay (default: 1e-4)',
|
||||
dest='weight_decay')
|
||||
parser.add_argument('-out-dir', type=str,
|
||||
default='DeeplabV3_LS_3')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
class DeeplabV3_LS_3(nn.Module):
|
||||
def __init__(self, n_class):
|
||||
super(DeeplabV3_LS_3, self).__init__()
|
||||
self.n_class = n_class
|
||||
self.conv_fc = nn.Conv2d(
|
||||
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
|
||||
|
||||
self.seg = deeplabv3_resnet50(weights='DEFAULT')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.seg(x)["out"]
|
||||
x = self.conv_fc(x)
|
||||
return x
|
||||
|
||||
|
||||
def main(args):
|
||||
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||
|
||||
input_transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([.485, .456, .406],
|
||||
[.229, .224, .225]),
|
||||
])
|
||||
data_kwargs = {'transform': input_transform,
|
||||
'base_size': 224, 'crop_size': 224}
|
||||
|
||||
# 读取geotiff数据,构建训练集、验证集
|
||||
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
|
||||
**data_kwargs)
|
||||
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
|
||||
**data_kwargs)
|
||||
num_workers = min(
|
||||
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
|
||||
train_loader = data.DataLoader(
|
||||
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||
val_loader = data.DataLoader(
|
||||
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||
|
||||
model = DeeplabV3_LS_3(n_class=args.num_classes)
|
||||
model.to(device)
|
||||
|
||||
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
|
||||
optimizer = torch.optim.SGD(
|
||||
model.parameters(),
|
||||
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
|
||||
)
|
||||
|
||||
lr_scheduler = utils.create_lr_scheduler(
|
||||
optimizer, len(train_loader), args.epochs, warmup=True)
|
||||
|
||||
now = datetime.now()
|
||||
date_time = now.strftime("%Y-%m-%d__%H-%M__")
|
||||
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
|
||||
if not out_dir.exists():
|
||||
out_dir.mkdir()
|
||||
f = open(os.path.join(out_dir, "log.txt"), 'w')
|
||||
start_time = time.time()
|
||||
best_acc = 0
|
||||
for epoch in range(args.epochs):
|
||||
print(f"Epoch {epoch+1}\n-------------------------------")
|
||||
model.train()
|
||||
for idx, (image, target) in enumerate(train_loader):
|
||||
image = image.to(device)
|
||||
target = target.to(device)
|
||||
output = model(image)
|
||||
|
||||
loss = criterion(output, target)
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
|
||||
if idx % 100 == 0:
|
||||
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
|
||||
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
|
||||
|
||||
model.eval()
|
||||
confmat = utils.ConfusionMatrix(args.num_classes)
|
||||
with torch.no_grad():
|
||||
for image, target in val_loader:
|
||||
image, target = image.to(device), target.to(device)
|
||||
output = model(image)
|
||||
|
||||
confmat.update(target.flatten(), output.argmax(1).flatten())
|
||||
|
||||
info, mIoU = confmat.get_info()
|
||||
print(info)
|
||||
|
||||
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
|
||||
f.write(info+"\n\n")
|
||||
f.flush()
|
||||
|
||||
# # 保存准确率最好的模型
|
||||
# if mIoU > best_acc:
|
||||
# print("[Save model]")
|
||||
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
|
||||
# best_acc = mIoU
|
||||
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
|
||||
total_time = time.time() - start_time
|
||||
print("total time:", total_time)
|
||||
torch.save(model, os.path.join(out_dir, "last.pth"))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
|
||||
main(args)
|
330
train_LS_jpg/utils.py
Normal file
330
train_LS_jpg/utils.py
Normal file
@ -0,0 +1,330 @@
|
||||
import datetime
|
||||
import errno
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class SmoothedValue:
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||
self.deque = deque(maxlen=window_size)
|
||||
self.total = 0.0
|
||||
self.count = 0
|
||||
self.fmt = fmt
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
"""
|
||||
Warning: does not synchronize the deque!
|
||||
"""
|
||||
t = reduce_across_processes([self.count, self.total])
|
||||
t = t.tolist()
|
||||
self.count = int(t[0])
|
||||
self.total = t[1]
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
d = torch.tensor(list(self.deque))
|
||||
return d.median().item()
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||
return d.mean().item()
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||
)
|
||||
|
||||
|
||||
class ConfusionMatrix:
|
||||
def __init__(self, num_classes):
|
||||
self.num_classes = num_classes
|
||||
self.mat = None
|
||||
|
||||
def update(self, a, b):
|
||||
n = self.num_classes
|
||||
if self.mat is None:
|
||||
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||
with torch.inference_mode():
|
||||
k = (a >= 0) & (a < n)
|
||||
inds = n * a[k].to(torch.int64) + b[k]
|
||||
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||
|
||||
def reset(self):
|
||||
self.mat.zero_()
|
||||
|
||||
def compute(self):
|
||||
h = self.mat.float()
|
||||
acc_global = torch.diag(h).sum() / h.sum()
|
||||
acc = torch.diag(h) / h.sum(1)
|
||||
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
|
||||
return acc_global, acc, iu
|
||||
|
||||
def reduce_from_all_processes(self):
|
||||
reduce_across_processes(self.mat)
|
||||
|
||||
def get_info(self):
|
||||
acc_global, acc, iu = self.compute()
|
||||
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
|
||||
acc_global.item() * 100,
|
||||
[f"{i:.1f}" for i in (acc * 100).tolist()],
|
||||
[f"{i:.1f}" for i in (iu * 100).tolist()],
|
||||
iu.mean().item() * 100,
|
||||
), iu.mean().item() * 100
|
||||
|
||||
|
||||
class MetricLogger:
|
||||
def __init__(self, delimiter="\t"):
|
||||
self.meters = defaultdict(SmoothedValue)
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, **kwargs):
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
v = v.item()
|
||||
if not isinstance(v, (float, int)):
|
||||
raise TypeError(
|
||||
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
|
||||
)
|
||||
self.meters[k].update(v)
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.meters:
|
||||
return self.meters[attr]
|
||||
if attr in self.__dict__:
|
||||
return self.__dict__[attr]
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{attr}'")
|
||||
|
||||
def __str__(self):
|
||||
loss_str = []
|
||||
for name, meter in self.meters.items():
|
||||
loss_str.append(f"{name}: {str(meter)}")
|
||||
return self.delimiter.join(loss_str)
|
||||
|
||||
def synchronize_between_processes(self):
|
||||
for meter in self.meters.values():
|
||||
meter.synchronize_between_processes()
|
||||
|
||||
def add_meter(self, name, meter):
|
||||
self.meters[name] = meter
|
||||
|
||||
def log_every(self, iterable, print_freq, header=None):
|
||||
i = 0
|
||||
if not header:
|
||||
header = ""
|
||||
start_time = time.time()
|
||||
end = time.time()
|
||||
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||
if torch.cuda.is_available():
|
||||
log_msg = self.delimiter.join(
|
||||
[
|
||||
header,
|
||||
"[{0" + space_fmt + "}/{1}]",
|
||||
"eta: {eta}",
|
||||
"{meters}",
|
||||
"time: {time}",
|
||||
"data: {data}",
|
||||
"max mem: {memory:.0f}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
log_msg = self.delimiter.join(
|
||||
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
|
||||
"{meters}", "time: {time}", "data: {data}"]
|
||||
)
|
||||
MB = 1024.0 * 1024.0
|
||||
for obj in iterable:
|
||||
data_time.update(time.time() - end)
|
||||
yield obj
|
||||
iter_time.update(time.time() - end)
|
||||
if i % print_freq == 0:
|
||||
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||
if torch.cuda.is_available():
|
||||
print(
|
||||
log_msg.format(
|
||||
i,
|
||||
len(iterable),
|
||||
eta=eta_string,
|
||||
meters=str(self),
|
||||
time=str(iter_time),
|
||||
data=str(data_time),
|
||||
memory=torch.cuda.max_memory_allocated() / MB,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
log_msg.format(
|
||||
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
||||
)
|
||||
)
|
||||
i += 1
|
||||
end = time.time()
|
||||
total_time = time.time() - start_time
|
||||
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||
print(f"{header} Total time: {total_time_str}")
|
||||
|
||||
|
||||
def cat_list(images, fill_value=0):
|
||||
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||
batch_shape = (len(images),) + max_size
|
||||
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
|
||||
for img, pad_img in zip(images, batched_imgs):
|
||||
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
|
||||
return batched_imgs
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
images, targets = list(zip(*batch))
|
||||
batched_imgs = cat_list(images, fill_value=0)
|
||||
batched_targets = cat_list(targets, fill_value=255)
|
||||
return batched_imgs, batched_targets
|
||||
|
||||
|
||||
def mkdir(path):
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EEXIST:
|
||||
raise
|
||||
|
||||
|
||||
def setup_for_distributed(is_master):
|
||||
"""
|
||||
This function disables printing when not in master process
|
||||
"""
|
||||
import builtins as __builtin__
|
||||
|
||||
builtin_print = __builtin__.print
|
||||
|
||||
def print(*args, **kwargs):
|
||||
force = kwargs.pop("force", False)
|
||||
if is_master or force:
|
||||
builtin_print(*args, **kwargs)
|
||||
|
||||
__builtin__.print = print
|
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
|
||||
if not dist.is_available():
|
||||
return False
|
||||
if not dist.is_initialized():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_world_size():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 1
|
||||
return dist.get_world_size()
|
||||
|
||||
|
||||
def get_rank():
|
||||
if not is_dist_avail_and_initialized():
|
||||
return 0
|
||||
return dist.get_rank()
|
||||
|
||||
|
||||
def is_main_process():
|
||||
return get_rank() == 0
|
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
|
||||
if is_main_process():
|
||||
torch.save(*args, **kwargs)
|
||||
|
||||
|
||||
def init_distributed_mode(args):
|
||||
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||
args.rank = int(os.environ["RANK"])
|
||||
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||
# elif "SLURM_PROCID" in os.environ:
|
||||
# args.rank = int(os.environ["SLURM_PROCID"])
|
||||
# args.gpu = args.rank % torch.cuda.device_count()
|
||||
elif hasattr(args, "rank"):
|
||||
pass
|
||||
else:
|
||||
print("Not using distributed mode")
|
||||
args.distributed = False
|
||||
return
|
||||
|
||||
args.distributed = True
|
||||
|
||||
torch.cuda.set_device(args.gpu)
|
||||
args.dist_backend = "nccl"
|
||||
print(
|
||||
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
||||
torch.distributed.init_process_group(
|
||||
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||
)
|
||||
torch.distributed.barrier()
|
||||
setup_for_distributed(args.rank == 0)
|
||||
|
||||
|
||||
def reduce_across_processes(val):
|
||||
if not is_dist_avail_and_initialized():
|
||||
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
|
||||
return torch.tensor(val)
|
||||
|
||||
t = torch.tensor(val, device="cuda")
|
||||
dist.barrier()
|
||||
dist.all_reduce(t)
|
||||
return t
|
||||
|
||||
|
||||
def create_lr_scheduler(optimizer,
|
||||
num_step: int,
|
||||
epochs: int,
|
||||
warmup=True,
|
||||
warmup_epochs=1,
|
||||
warmup_factor=1e-3):
|
||||
assert num_step > 0 and epochs > 0
|
||||
if warmup is False:
|
||||
warmup_epochs = 0
|
||||
|
||||
def f(x):
|
||||
"""
|
||||
根据step数返回一个学习率倍率因子,
|
||||
注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法
|
||||
"""
|
||||
if warmup is True and x <= (warmup_epochs * num_step):
|
||||
alpha = float(x) / (warmup_epochs * num_step)
|
||||
# warmup过程中lr倍率因子从warmup_factor -> 1
|
||||
return warmup_factor * (1 - alpha) + alpha
|
||||
else:
|
||||
# warmup后lr倍率因子从1 -> 0
|
||||
# 参考deeplab_v2: Learning rate policy
|
||||
return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
|
||||
|
||||
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
|
Loading…
Reference in New Issue
Block a user