first commit
This commit is contained in:
parent
f8436f7769
commit
5e0d438280
138
data_preprocessing/cut_smalltif.py
Normal file
138
data_preprocessing/cut_smalltif.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
import numpy as np
|
||||||
|
from osgeo import gdal
|
||||||
|
|
||||||
|
|
||||||
|
def read_tif(fileName):
|
||||||
|
dataset = gdal.Open(fileName)
|
||||||
|
|
||||||
|
im_width = dataset.RasterXSize # 栅格矩阵的列数
|
||||||
|
im_height = dataset.RasterYSize # 栅格矩阵的行数
|
||||||
|
im_bands = dataset.RasterCount # 波段数
|
||||||
|
im_data = dataset.ReadAsArray().astype(np.float32) # 获取数据
|
||||||
|
if len(im_data.shape) == 2:
|
||||||
|
im_data = im_data[np.newaxis, :]
|
||||||
|
im_geotrans = dataset.GetGeoTransform() # 获取仿射矩阵信息
|
||||||
|
im_proj = dataset.GetProjection() # 获取投影信息
|
||||||
|
|
||||||
|
return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
|
||||||
|
|
||||||
|
|
||||||
|
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
|
||||||
|
if 'int8' in im_data.dtype.name:
|
||||||
|
datatype = gdal.GDT_Byte
|
||||||
|
elif 'int16' in im_data.dtype.name:
|
||||||
|
datatype = gdal.GDT_UInt16
|
||||||
|
else:
|
||||||
|
datatype = gdal.GDT_Float32
|
||||||
|
|
||||||
|
if len(im_data.shape) == 3:
|
||||||
|
im_bands, im_height, im_width = im_data.shape
|
||||||
|
else:
|
||||||
|
im_bands, (im_height, im_width) = 1, im_data.shape
|
||||||
|
# 创建文件
|
||||||
|
driver = gdal.GetDriverByName("GTiff")
|
||||||
|
dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
|
||||||
|
if dataset != None and im_geotrans != None and im_proj != None:
|
||||||
|
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
|
||||||
|
dataset.SetProjection(im_proj) # 写入投影
|
||||||
|
for i in range(im_bands):
|
||||||
|
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
|
||||||
|
del dataset
|
||||||
|
|
||||||
|
|
||||||
|
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
|
||||||
|
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
|
||||||
|
fileName)
|
||||||
|
|
||||||
|
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
|
||||||
|
im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
|
||||||
|
fileName2)
|
||||||
|
|
||||||
|
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
|
||||||
|
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
|
||||||
|
mask_fileName)
|
||||||
|
mask_im_data = np.int8(mask_im_data)
|
||||||
|
|
||||||
|
|
||||||
|
# geotiff归一化
|
||||||
|
for i in range(im_bands):
|
||||||
|
arr = im_data[i, :, :]
|
||||||
|
Min = arr.min()
|
||||||
|
Max = arr.max()
|
||||||
|
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||||
|
im_data[i] = normalized_arr
|
||||||
|
|
||||||
|
for i in range(im_bands2):
|
||||||
|
arr = im_data2[i, :, :]
|
||||||
|
Min = arr.min()
|
||||||
|
Max = arr.max()
|
||||||
|
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||||
|
im_data2[i] = normalized_arr
|
||||||
|
|
||||||
|
|
||||||
|
# 计算大图每个波段的均值和方差,train.py里transform会用到
|
||||||
|
im_data = im_data/255
|
||||||
|
for i in range(im_bands):
|
||||||
|
pixels = im_data[i, :, :].ravel()
|
||||||
|
print("波段{} mean: {:.4f}, std: {:.4f}".format(
|
||||||
|
i, np.mean(pixels), np.std(pixels)))
|
||||||
|
im_data = im_data*255
|
||||||
|
|
||||||
|
im_data2 = im_data2/255
|
||||||
|
for i in range(im_bands2):
|
||||||
|
pixels = im_data2[i, :, :].ravel()
|
||||||
|
print("波段{} mean: {:.4f}, std: {:.4f}".format(
|
||||||
|
i, np.mean(pixels), np.std(pixels)))
|
||||||
|
im_data2 = im_data2*255
|
||||||
|
|
||||||
|
|
||||||
|
# 切成小图
|
||||||
|
a = 0
|
||||||
|
size = 224
|
||||||
|
for i in range(0, int(mask_im_height / size)):
|
||||||
|
for j in range(0, int(mask_im_width / size)):
|
||||||
|
im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size]
|
||||||
|
im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size]
|
||||||
|
mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size]
|
||||||
|
|
||||||
|
# 以mask为判断基准,同时处理geotiff和mask
|
||||||
|
labelfla = np.array(mask_cut).flatten()
|
||||||
|
if np.all(labelfla == 15): # 15为NoData
|
||||||
|
print("Skip!!!")
|
||||||
|
else:
|
||||||
|
# 5m
|
||||||
|
left_h = i * size * im_geotrans[5] + im_geotrans[3]
|
||||||
|
left_w = j * size * im_geotrans[1] + im_geotrans[0]
|
||||||
|
new_geotrans = np.array(im_geotrans)
|
||||||
|
new_geotrans[0] = left_w
|
||||||
|
new_geotrans[3] = left_h
|
||||||
|
out_geotrans = tuple(new_geotrans)
|
||||||
|
|
||||||
|
im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif'
|
||||||
|
write_tif(im_cut, size, size, im_out, out_geotrans, im_proj)
|
||||||
|
|
||||||
|
# dan
|
||||||
|
left_h = i * size * im_geotrans2[5] + im_geotrans2[3]
|
||||||
|
left_w = j * size * im_geotrans2[1] + im_geotrans2[0]
|
||||||
|
new_geotrans = np.array(im_geotrans2)
|
||||||
|
new_geotrans[0] = left_w
|
||||||
|
new_geotrans[3] = left_h
|
||||||
|
out_geotrans = tuple(new_geotrans)
|
||||||
|
|
||||||
|
im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif'
|
||||||
|
write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2)
|
||||||
|
|
||||||
|
# 存mask
|
||||||
|
mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3]
|
||||||
|
mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0]
|
||||||
|
mask_new_geotrans = np.array(mask_im_geotrans)
|
||||||
|
mask_new_geotrans[0] = mask_left_w
|
||||||
|
mask_new_geotrans[3] = mask_left_h
|
||||||
|
mask_out_geotrans = tuple(mask_new_geotrans)
|
||||||
|
mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif'
|
||||||
|
write_tif(mask_cut, size, size, mask_out,
|
||||||
|
mask_out_geotrans, mask_im_proj)
|
||||||
|
|
||||||
|
print(mask_out + 'Cut to complete')
|
||||||
|
|
||||||
|
a = a+1
|
169
data_preprocessing/cut_smalltif_multi.py
Normal file
169
data_preprocessing/cut_smalltif_multi.py
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
import numpy as np
|
||||||
|
from osgeo import gdal
|
||||||
|
|
||||||
|
|
||||||
|
def read_tif(fileName):
|
||||||
|
dataset = gdal.Open(fileName)
|
||||||
|
|
||||||
|
im_width = dataset.RasterXSize # 栅格矩阵的列数
|
||||||
|
im_height = dataset.RasterYSize # 栅格矩阵的行数
|
||||||
|
im_bands = dataset.RasterCount # 波段数
|
||||||
|
im_data = dataset.ReadAsArray().astype(np.float32) # 获取数据
|
||||||
|
if len(im_data.shape) == 2:
|
||||||
|
im_data = im_data[np.newaxis, :]
|
||||||
|
im_geotrans = dataset.GetGeoTransform() # 获取仿射矩阵信息
|
||||||
|
im_proj = dataset.GetProjection() # 获取投影信息
|
||||||
|
|
||||||
|
return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
|
||||||
|
|
||||||
|
|
||||||
|
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
|
||||||
|
if 'int8' in im_data.dtype.name:
|
||||||
|
datatype = gdal.GDT_Byte
|
||||||
|
elif 'int16' in im_data.dtype.name:
|
||||||
|
datatype = gdal.GDT_UInt16
|
||||||
|
else:
|
||||||
|
datatype = gdal.GDT_Float32
|
||||||
|
|
||||||
|
if len(im_data.shape) == 3:
|
||||||
|
im_bands, im_height, im_width = im_data.shape
|
||||||
|
else:
|
||||||
|
im_bands, (im_height, im_width) = 1, im_data.shape
|
||||||
|
# 创建文件
|
||||||
|
driver = gdal.GetDriverByName("GTiff")
|
||||||
|
dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
|
||||||
|
if dataset != None and im_geotrans != None and im_proj != None:
|
||||||
|
dataset.SetGeoTransform(im_geotrans) # 写入仿射变换参数
|
||||||
|
dataset.SetProjection(im_proj) # 写入投影
|
||||||
|
for i in range(im_bands):
|
||||||
|
dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
|
||||||
|
del dataset
|
||||||
|
|
||||||
|
|
||||||
|
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right3.tif'
|
||||||
|
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
|
||||||
|
fileName)
|
||||||
|
|
||||||
|
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
|
||||||
|
im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
|
||||||
|
fileName2)
|
||||||
|
|
||||||
|
# fileName3 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj10m.tif'
|
||||||
|
# im_data3, im_width3, im_height3, im_bands3, im_geotrans3, im_proj3 = read_tif(
|
||||||
|
# fileName3)
|
||||||
|
|
||||||
|
# fileName4 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj20m.tif'
|
||||||
|
# im_data4, im_width4, im_height4, im_bands4, im_geotrans4, im_proj4 = read_tif(
|
||||||
|
# fileName4)
|
||||||
|
|
||||||
|
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
|
||||||
|
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
|
||||||
|
mask_fileName)
|
||||||
|
mask_im_data = np.int8(mask_im_data)
|
||||||
|
|
||||||
|
# geotiff归一化
|
||||||
|
for i in range(im_bands):
|
||||||
|
arr = im_data[i, :, :]
|
||||||
|
Min = arr.min()
|
||||||
|
Max = arr.max()
|
||||||
|
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||||
|
im_data[i] = normalized_arr
|
||||||
|
|
||||||
|
for i in range(im_bands2):
|
||||||
|
arr = im_data2[i, :, :]
|
||||||
|
Min = arr.min()
|
||||||
|
Max = arr.max()
|
||||||
|
normalized_arr = (arr-Min)/(Max-Min)*255
|
||||||
|
im_data2[i] = normalized_arr
|
||||||
|
|
||||||
|
# # 计算大图每个波段的均值和方差,train.py里transform会用到
|
||||||
|
# im_data2 = im_data2/255
|
||||||
|
# for i in range(im_bands2):
|
||||||
|
# pixels = im_data2[i, :, :].ravel()
|
||||||
|
# print("波段{} mean: {:.4f}, std: {:.4f}".format(
|
||||||
|
# i, np.mean(pixels), np.std(pixels)))
|
||||||
|
|
||||||
|
|
||||||
|
# 切成小图
|
||||||
|
a = 0
|
||||||
|
size = 224
|
||||||
|
for i in range(0, int(im_height / size)):
|
||||||
|
for j in range(0, int(im_width / size)):
|
||||||
|
im_cut = im_data[:, i * size*4:i * size*4 +
|
||||||
|
size*4, j * size*4:j * size*4 + size*4]
|
||||||
|
im_cut2 = im_data2[:, i * size*4:i * size*4 +
|
||||||
|
size*4, j * size*4:j * size*4 + size*4]
|
||||||
|
# im_cut3 = im_data3[:, i * size*2:i * size*2 +
|
||||||
|
# size*2, j * size*2:j * size*2 + size*2]
|
||||||
|
# im_cut4 = im_data4[:, i * size:i * size +
|
||||||
|
# size, j * size:j * size + size]
|
||||||
|
mask_cut = mask_im_data[:, i * size*4:i *
|
||||||
|
size*4 + size*4, j * size*4:j * size*4 + size*4]
|
||||||
|
|
||||||
|
# 以20m为判断基准,同时处理geotiff和mask
|
||||||
|
labelfla_bool = np.all(np.array(mask_cut).flatten() == 15)
|
||||||
|
|
||||||
|
if labelfla_bool:
|
||||||
|
print("False")
|
||||||
|
else:
|
||||||
|
left_h = i * size*4 * im_geotrans[5] + im_geotrans[3]
|
||||||
|
left_w = j * size*4 * im_geotrans[1] + im_geotrans[0]
|
||||||
|
new_geotrans = np.array(im_geotrans)
|
||||||
|
new_geotrans[0] = left_w
|
||||||
|
new_geotrans[3] = left_h
|
||||||
|
out_geotrans = tuple(new_geotrans)
|
||||||
|
im_out = 'E:/RSdata/mask_5m/dataset_5m/geotiff' + \
|
||||||
|
str(a) + '.tif'
|
||||||
|
write_tif(im_cut, size*4, size*4, im_out, out_geotrans, im_proj)
|
||||||
|
print(im_out + 'Cut to complete')
|
||||||
|
|
||||||
|
left_h2 = i * size*4 * im_geotrans2[5] + im_geotrans2[3]
|
||||||
|
left_w2 = j * size*4 * im_geotrans2[1] + im_geotrans2[0]
|
||||||
|
new_geotrans = np.array(im_geotrans2)
|
||||||
|
new_geotrans[0] = left_w2
|
||||||
|
new_geotrans[3] = left_h2
|
||||||
|
out_geotrans = tuple(new_geotrans)
|
||||||
|
|
||||||
|
im_out = 'E:/RSdata/mask_5m/dataset_dan/geotiff' + \
|
||||||
|
str(a) + '.tif'
|
||||||
|
write_tif(im_cut2, size*4, size*4, im_out, out_geotrans, im_proj2)
|
||||||
|
print(im_out + 'Cut to complete')
|
||||||
|
|
||||||
|
# left_h3 = i * size*2 * im_geotrans3[5] + im_geotrans3[3]
|
||||||
|
# left_w3 = j * size*2 * im_geotrans3[1] + im_geotrans3[0]
|
||||||
|
# new_geotrans = np.array(im_geotrans3)
|
||||||
|
# new_geotrans[0] = left_w3
|
||||||
|
# new_geotrans[3] = left_h3
|
||||||
|
# out_geotrans = tuple(new_geotrans)
|
||||||
|
|
||||||
|
# im_out = 'E:/mask_20m/all/dataset_10m/geotiff' + \
|
||||||
|
# str(a) + '.tif'
|
||||||
|
# write_tif(im_cut3, size*2, size*2, im_out, out_geotrans, im_proj3)
|
||||||
|
# print(im_out + 'Cut to complete')
|
||||||
|
|
||||||
|
# left_h4 = i * size * im_geotrans4[5] + im_geotrans4[3]
|
||||||
|
# left_w4 = j * size * im_geotrans4[1] + im_geotrans4[0]
|
||||||
|
# new_geotrans = np.array(im_geotrans4)
|
||||||
|
# new_geotrans[0] = left_w4
|
||||||
|
# new_geotrans[3] = left_h4
|
||||||
|
# out_geotrans = tuple(new_geotrans)
|
||||||
|
|
||||||
|
# im_out = 'E:/mask_20m/all/dataset_20m/geotiff' + \
|
||||||
|
# str(a) + '.tif'
|
||||||
|
# write_tif(im_cut4, size, size, im_out, out_geotrans, im_proj4)
|
||||||
|
# print(im_out + 'Cut to complete')
|
||||||
|
|
||||||
|
mask_left_h = i * size*4 * \
|
||||||
|
mask_im_geotrans[5] + mask_im_geotrans[3]
|
||||||
|
mask_left_w = j * size*4 * \
|
||||||
|
mask_im_geotrans[1] + mask_im_geotrans[0]
|
||||||
|
mask_new_geotrans = np.array(mask_im_geotrans)
|
||||||
|
mask_new_geotrans[0] = mask_left_w
|
||||||
|
mask_new_geotrans[3] = mask_left_h
|
||||||
|
mask_out_geotrans = tuple(mask_new_geotrans)
|
||||||
|
mask_out = 'E:/RSdata/mask_5m/mask/geotiff' + str(a) + '.tif'
|
||||||
|
write_tif(mask_cut, size*4, size*4, mask_out,
|
||||||
|
mask_out_geotrans, mask_im_proj)
|
||||||
|
print(mask_out + 'Cut to complete')
|
||||||
|
|
||||||
|
a = a+1
|
54
data_preprocessing/sift_tif.py
Normal file
54
data_preprocessing/sift_tif.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
from osgeo import gdal
|
||||||
|
import zipfile
|
||||||
|
import os
|
||||||
|
from xml.dom.minidom import parseString
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
# 搜索所有硬盘里所有.zip文件
|
||||||
|
def get_all_zipfiles(root):
|
||||||
|
zip_lt = []
|
||||||
|
for root2, dirs, files in os.walk(root):
|
||||||
|
for file in files:
|
||||||
|
if file[-4:] == '.zip':
|
||||||
|
zip_lt.append(os.path.join(root2, file))
|
||||||
|
return zip_lt
|
||||||
|
|
||||||
|
# 定位并读取.zip文件
|
||||||
|
# zip_dir = 'F:/2101-2400/' # 只处理某一个文件夹
|
||||||
|
# zip_files_list = os.listdir(zip_dir)
|
||||||
|
|
||||||
|
|
||||||
|
zip_dir = 'F:\\'
|
||||||
|
zip_files_list = get_all_zipfiles(zip_dir)
|
||||||
|
|
||||||
|
# 读取.xml文件
|
||||||
|
for file_name in zip_files_list:
|
||||||
|
# print(file_name)
|
||||||
|
try:
|
||||||
|
file = zipfile.ZipFile(file_name, "r")
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
info_list = file.infolist()
|
||||||
|
for info in info_list:
|
||||||
|
if info.filename[-4:] == ".xml":
|
||||||
|
xml_file = file.read(info.filename)
|
||||||
|
|
||||||
|
# 解析.xml文件
|
||||||
|
domTree = parseString(xml_file)
|
||||||
|
rootNode = domTree.documentElement
|
||||||
|
infos = rootNode.getElementsByTagName("ProductInfo")[0]
|
||||||
|
|
||||||
|
# 筛选西北地区影像
|
||||||
|
CenterLatitude = eval(infos.getElementsByTagName(
|
||||||
|
"CenterLatitude")[0].childNodes[0].data)
|
||||||
|
CenterLongitude = eval(infos.getElementsByTagName(
|
||||||
|
"CenterLongitude")[0].childNodes[0].data)
|
||||||
|
|
||||||
|
# 西北地区四个边界信息: 最高纬度:50 ;最低维度:37 ;最高经度:123 ;最低经度:73 ;
|
||||||
|
if 37 < CenterLatitude < 44:
|
||||||
|
if 76 < CenterLongitude < 78:
|
||||||
|
print(file)
|
||||||
|
|
||||||
|
# 将文件拷贝至新的文件夹中
|
||||||
|
shutil.copy(file_name,'E:/wlk_test/')
|
21
data_preprocessing/split_data.py
Normal file
21
data_preprocessing/split_data.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
|
||||||
|
random.seed(42)
|
||||||
|
|
||||||
|
geotiffs = os.listdir('E:\RSdata\wlk_right_224_2\dataset_5m')
|
||||||
|
num = len(geotiffs)
|
||||||
|
split_rate = 0.2
|
||||||
|
|
||||||
|
eval_index = random.sample(geotiffs, k=int(num*split_rate))
|
||||||
|
|
||||||
|
f_train = open('E:\RSdata\wlk_right_224_2/train.txt', 'w')
|
||||||
|
f_val = open('E:\RSdata\wlk_right_224_2/val.txt', 'w')
|
||||||
|
|
||||||
|
# 写入文件
|
||||||
|
for geotiff in geotiffs:
|
||||||
|
if geotiff in eval_index:
|
||||||
|
f_val.write(str(geotiff)+'\n')
|
||||||
|
else:
|
||||||
|
f_train.write(str(geotiff)+'\n')
|
296
geotiff_utils.py
Normal file
296
geotiff_utils.py
Normal file
@ -0,0 +1,296 @@
|
|||||||
|
|
||||||
|
"""Pascal VOC Semantic Segmentation Dataset."""
|
||||||
|
from PIL import Image, ImageOps, ImageFilter
|
||||||
|
import torchvision.transforms as transforms
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from PIL import Image
|
||||||
|
import cv2
|
||||||
|
# import gdal
|
||||||
|
from osgeo import gdal
|
||||||
|
import random
|
||||||
|
import torch.utils.data as data
|
||||||
|
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
|
||||||
|
|
||||||
|
|
||||||
|
class SegmentationDataset(object):
|
||||||
|
"""Segmentation Base Dataset"""
|
||||||
|
|
||||||
|
def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
|
||||||
|
super(SegmentationDataset, self).__init__()
|
||||||
|
self.root = root
|
||||||
|
self.transform = transform
|
||||||
|
self.split = split
|
||||||
|
self.mode = mode if mode is not None else split
|
||||||
|
self.base_size = base_size
|
||||||
|
self.crop_size = crop_size
|
||||||
|
|
||||||
|
def _val_sync_transform(self, img, mask):
|
||||||
|
outsize = self.crop_size
|
||||||
|
short_size = outsize
|
||||||
|
w, h = img.size
|
||||||
|
if w > h:
|
||||||
|
oh = short_size
|
||||||
|
ow = int(1.0 * w * oh / h)
|
||||||
|
else:
|
||||||
|
ow = short_size
|
||||||
|
oh = int(1.0 * h * ow / w)
|
||||||
|
img = img.resize((ow, oh), Image.BILINEAR)
|
||||||
|
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||||
|
# center crop
|
||||||
|
w, h = img.size
|
||||||
|
x1 = int(round((w - outsize) / 2.))
|
||||||
|
y1 = int(round((h - outsize) / 2.))
|
||||||
|
img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||||
|
mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
if random.random() < 0.5:
|
||||||
|
img = img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
||||||
|
crop_size = self.crop_size
|
||||||
|
# random scale (short edge)
|
||||||
|
short_size = random.randint(
|
||||||
|
int(self.base_size * 0.5), int(self.base_size * 2.0))
|
||||||
|
w, h = img.size
|
||||||
|
if h > w:
|
||||||
|
ow = short_size
|
||||||
|
oh = int(1.0 * h * ow / w)
|
||||||
|
else:
|
||||||
|
oh = short_size
|
||||||
|
ow = int(1.0 * w * oh / h)
|
||||||
|
img = img.resize((ow, oh), Image.BILINEAR)
|
||||||
|
mask = mask.resize((ow, oh), Image.NEAREST)
|
||||||
|
# pad crop
|
||||||
|
if short_size < crop_size:
|
||||||
|
padh = crop_size - oh if oh < crop_size else 0
|
||||||
|
padw = crop_size - ow if ow < crop_size else 0
|
||||||
|
img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
|
||||||
|
mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
|
||||||
|
# random crop crop_size
|
||||||
|
w, h = img.size
|
||||||
|
x1 = random.randint(0, w - crop_size)
|
||||||
|
y1 = random.randint(0, h - crop_size)
|
||||||
|
img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||||
|
mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
|
||||||
|
# gaussian blur as in PSP
|
||||||
|
if random.random() < 0.5:
|
||||||
|
img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform_tif(self, img, mask):
|
||||||
|
# random mirror
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _sync_transform_tif_geofeat(self, img, mask, img_feat):
|
||||||
|
# random mirror
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
img_feat = self._img_transform(img_feat)
|
||||||
|
return img, mask, img_feat
|
||||||
|
|
||||||
|
def _val_sync_transform_tif(self, img, mask):
|
||||||
|
# final transform
|
||||||
|
img, mask = self._img_transform(img), self._mask_transform(mask)
|
||||||
|
return img, mask
|
||||||
|
|
||||||
|
def _img_transform(self, img):
|
||||||
|
return np.array(img)
|
||||||
|
|
||||||
|
# def _mask_transform(self, mask):
|
||||||
|
# return np.array(mask).astype('int32')
|
||||||
|
|
||||||
|
def _mask_transform(self, mask):
|
||||||
|
target = np.array(mask).astype('int32')
|
||||||
|
# target = target[np.newaxis, :]
|
||||||
|
target[target > 12] = 255
|
||||||
|
return torch.from_numpy(target).long()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def num_class(self):
|
||||||
|
"""Number of categories."""
|
||||||
|
return self.NUM_CLASS
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pred_offset(self):
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
class VOCYJSSegmentation(SegmentationDataset):
|
||||||
|
"""Pascal VOC Semantic Segmentation Dataset.
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
root : string
|
||||||
|
Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
|
||||||
|
split: string
|
||||||
|
'train', 'val' or 'test'
|
||||||
|
transform : callable, optional
|
||||||
|
A function that transforms the image
|
||||||
|
Examples
|
||||||
|
--------
|
||||||
|
>>> from torchvision import transforms
|
||||||
|
>>> import torch.utils.data as data
|
||||||
|
>>> # Transforms for Normalization
|
||||||
|
>>> input_transform = transforms.Compose([
|
||||||
|
>>> transforms.ToTensor(),
|
||||||
|
>>> transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
|
||||||
|
>>> ])
|
||||||
|
>>> # Create Dataset
|
||||||
|
>>> trainset = VOCSegmentation(split='train', transform=input_transform)
|
||||||
|
>>> # Create Training Loader
|
||||||
|
>>> train_data = data.DataLoader(
|
||||||
|
>>> trainset, 4, shuffle=True,
|
||||||
|
>>> num_workers=4)
|
||||||
|
"""
|
||||||
|
NUM_CLASS = 13
|
||||||
|
|
||||||
|
def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
|
||||||
|
super(VOCYJSSegmentation, self).__init__(
|
||||||
|
root, split, mode, transform, **kwargs)
|
||||||
|
_voc_root = root
|
||||||
|
txt_path = os.path.join(root, split+'.txt')
|
||||||
|
self._mask_dir = os.path.join(_voc_root, 'masks_GE')
|
||||||
|
self._image_dir = os.path.join(_voc_root, 'images_SE')
|
||||||
|
self._mask_LS_dir = os.path.join(_voc_root, 'masks_LS')
|
||||||
|
self._image_LS_dir = os.path.join(_voc_root, "images_LS")
|
||||||
|
self.image_list = read_text(txt_path)
|
||||||
|
self.transform_SE = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
|
||||||
|
[.485, .456, .406, .485, .456, .406, .406], [.229, .224, .225, .229, .224, .225, .229]),])
|
||||||
|
random.shuffle(self.image_list)
|
||||||
|
|
||||||
|
def __getitem__(self, index):
|
||||||
|
# print( "image file path is %s "% self.images[index])
|
||||||
|
|
||||||
|
# 读取两种类型的图片
|
||||||
|
img_HR = gdal.Open(os.path.join(self._image_dir, self.image_list[index])).ReadAsArray(
|
||||||
|
).transpose(1, 2, 0).astype(np.float32)
|
||||||
|
img_LS = gdal.Open(os.path.join(self._image_LS_dir, self.image_list[index])).ReadAsArray(
|
||||||
|
).transpose(1, 2, 0).astype(np.float32)
|
||||||
|
# img_LS = cv2.resize(img_LS,(672,672),interpolation=cv2.INTER_CUBIC)
|
||||||
|
# 读取两种类型的标注
|
||||||
|
mask_HR = gdal.Open(os.path.join(
|
||||||
|
self._mask_dir, self.image_list[index])).ReadAsArray()
|
||||||
|
mask = gdal.Open(os.path.join(self._mask_LS_dir,
|
||||||
|
self.image_list[index])).ReadAsArray()
|
||||||
|
# synchronized transform
|
||||||
|
# 只包含两种模式: train 和 val
|
||||||
|
if self.mode == 'train':
|
||||||
|
# img, mask = self._sync_transform_tif(img, mask)
|
||||||
|
img_LS, mask, img_HR = self._sync_transform_tif_geofeat(
|
||||||
|
img_LS, mask, img_HR)
|
||||||
|
elif self.mode == 'val':
|
||||||
|
# img, mask = self._val_sync_transform_tif(img, mask)
|
||||||
|
img_LS, mask, img_HR = self._sync_transform_tif_geofeat(
|
||||||
|
img_LS, mask, img_HR)
|
||||||
|
# general resize, normalize and toTensor
|
||||||
|
if self.transform is not None:
|
||||||
|
img_HR = cv2.resize(img_HR, (448, 448),
|
||||||
|
interpolation=cv2.INTER_CUBIC)
|
||||||
|
img_HR = self.transform_SE(img_HR)
|
||||||
|
img_LS = self.transform(img_LS)
|
||||||
|
# img_feat = torch.from_numpy(img_feat)
|
||||||
|
# 多返回了一个img_feat
|
||||||
|
# ,transforms.ToTensor()(img_feat), os.path.basename(self.images[index])
|
||||||
|
return img_LS, img_HR, mask
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_list)
|
||||||
|
|
||||||
|
def _mask_transform(self, mask):
|
||||||
|
target = np.array(mask).astype('int32')
|
||||||
|
# target = target[np.newaxis, :]
|
||||||
|
target[target > 12] = 255
|
||||||
|
return torch.from_numpy(target).long()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def classes(self):
|
||||||
|
"""Category names."""
|
||||||
|
return ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12')
|
||||||
|
|
||||||
|
|
||||||
|
def generator_list_of_imagepath(path):
|
||||||
|
image_list = []
|
||||||
|
for image in os.listdir(path):
|
||||||
|
# print(path)
|
||||||
|
# print(image)
|
||||||
|
if not image == '.DS_Store' and 'tif' == image.split('.')[-1]:
|
||||||
|
image_list.append(image)
|
||||||
|
return image_list
|
||||||
|
|
||||||
|
|
||||||
|
def read_text(textfile):
|
||||||
|
list = []
|
||||||
|
with open(textfile, "r") as lines:
|
||||||
|
for line in lines:
|
||||||
|
list.append(line.rstrip('\n'))
|
||||||
|
return list
|
||||||
|
|
||||||
|
|
||||||
|
def dataset_segmentation(textpath, imagepath, train_percent):
|
||||||
|
image_list = generator_list_of_imagepath(imagepath)
|
||||||
|
num = len(image_list)
|
||||||
|
list = range(num)
|
||||||
|
train_num = int(num * train_percent) # training set num
|
||||||
|
train_list = random.sample(list, train_num)
|
||||||
|
print("train set size", train_num)
|
||||||
|
ftrain = open(os.path.join(textpath, 'train.txt'), 'w')
|
||||||
|
fval = open(os.path.join(textpath, 'val.txt'), 'w')
|
||||||
|
for i in list:
|
||||||
|
name = image_list[i] + '\n'
|
||||||
|
if i in train_list:
|
||||||
|
ftrain.write(name)
|
||||||
|
else:
|
||||||
|
fval.write(name)
|
||||||
|
ftrain.close()
|
||||||
|
fval.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
|
||||||
|
# list=generator_list_of_imagepath(path)
|
||||||
|
# print(list)
|
||||||
|
# 切割数据集
|
||||||
|
|
||||||
|
textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
|
||||||
|
imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
|
||||||
|
train_percent = 0.8
|
||||||
|
dataset_segmentation(textpath, imagepath, train_percent)
|
||||||
|
# 显示各种图片
|
||||||
|
|
||||||
|
# img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
|
||||||
|
# img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
|
||||||
|
# cv2.imshow('img', img)
|
||||||
|
# img = Image.fromarray (img,'RGB')
|
||||||
|
# img.show()
|
||||||
|
# img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
|
||||||
|
# img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
|
||||||
|
# img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
|
||||||
|
# img2 = Image.fromarray (img2,'RGB')
|
||||||
|
# img2.show()
|
||||||
|
# img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
|
||||||
|
# img3 = gdal.Open(img3).ReadAsArray()
|
||||||
|
# img3 = Image.fromarray (img3)
|
||||||
|
# img3.show()
|
||||||
|
|
||||||
|
# dataset和dataloader的测试
|
||||||
|
|
||||||
|
# 测试dataloader能不能用
|
||||||
|
'''
|
||||||
|
data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
|
||||||
|
input_transform = transforms.Compose(
|
||||||
|
[transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
|
||||||
|
dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
|
||||||
|
dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
|
||||||
|
train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
|
||||||
|
test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
|
||||||
|
'''
|
156
train_LS.py
Normal file
156
train_LS.py
Normal file
@ -0,0 +1,156 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
|
||||||
|
from torchvision import transforms
|
||||||
|
import torch.utils.data as data
|
||||||
|
from torch import nn
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from geotiff_utils import VOCYJSSegmentation
|
||||||
|
import utils
|
||||||
|
import warnings
|
||||||
|
warnings.filterwarnings("ignore")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--data-path", default="E:/repository/DeepLearning23/datasets/WLKdata_1111/WLKdataset", help="VOCdevkit root")
|
||||||
|
parser.add_argument("--num-classes", default=13, type=int)
|
||||||
|
parser.add_argument("--device", default="cuda", help="training device")
|
||||||
|
parser.add_argument("--batch-size", default=8, type=int)
|
||||||
|
parser.add_argument("--epochs", default=50, type=int, metavar="N",
|
||||||
|
help="number of total epochs to train")
|
||||||
|
parser.add_argument('--lr', default=0.005, type=float,
|
||||||
|
help='initial learning rate')
|
||||||
|
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
|
||||||
|
help='momentum')
|
||||||
|
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
|
||||||
|
metavar='W', help='weight decay (default: 1e-4)',
|
||||||
|
dest='weight_decay')
|
||||||
|
parser.add_argument('-out-dir', type=str,
|
||||||
|
default='DeeplabV3_LS')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
class DeeplabV3_LS(nn.Module):
|
||||||
|
def __init__(self, n_class):
|
||||||
|
super(DeeplabV3_LS, self).__init__()
|
||||||
|
self.n_class = n_class
|
||||||
|
self.conv7_3 = nn.Conv2d(7, 3, kernel_size=1, stride=1)
|
||||||
|
|
||||||
|
self.conv_fc = nn.Conv2d(
|
||||||
|
21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
|
||||||
|
|
||||||
|
self.seg = deeplabv3_resnet50(weights='DEFAULT')
|
||||||
|
|
||||||
|
def forward(self, x_LS, x_SE):
|
||||||
|
# x = torch.cat([x, x_dan], dim=1)
|
||||||
|
# x = self.conv7_3(x)
|
||||||
|
x = self.seg(x_LS)["out"]
|
||||||
|
x = self.conv_fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
input_transform = transforms.Compose([
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
|
||||||
|
])
|
||||||
|
data_kwargs = {'transform': input_transform,
|
||||||
|
'base_size': 224, 'crop_size': 224}
|
||||||
|
|
||||||
|
# 读取geotiff数据,构建训练集、验证集
|
||||||
|
train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
|
||||||
|
**data_kwargs)
|
||||||
|
val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
|
||||||
|
**data_kwargs)
|
||||||
|
num_workers = min(
|
||||||
|
[os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
|
||||||
|
train_loader = data.DataLoader(
|
||||||
|
train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||||
|
val_loader = data.DataLoader(
|
||||||
|
val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
|
||||||
|
|
||||||
|
model = DeeplabV3_LS(n_class=args.num_classes)
|
||||||
|
model.to(device)
|
||||||
|
|
||||||
|
criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
|
||||||
|
optimizer = torch.optim.SGD(
|
||||||
|
model.parameters(),
|
||||||
|
lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
|
||||||
|
)
|
||||||
|
|
||||||
|
lr_scheduler = utils.create_lr_scheduler(
|
||||||
|
optimizer, len(train_loader), args.epochs, warmup=True)
|
||||||
|
|
||||||
|
now = datetime.now()
|
||||||
|
date_time = now.strftime("%Y-%m-%d__%H-%M__")
|
||||||
|
out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
|
||||||
|
if not out_dir.exists():
|
||||||
|
out_dir.mkdir()
|
||||||
|
f = open(os.path.join(out_dir, "log.txt"), 'w')
|
||||||
|
start_time = time.time()
|
||||||
|
best_acc = 0
|
||||||
|
for epoch in range(args.epochs):
|
||||||
|
print(f"Epoch {epoch+1}\n-------------------------------")
|
||||||
|
model.train()
|
||||||
|
for idx, (image, image_dan, target) in enumerate(train_loader):
|
||||||
|
image, image_dan, target = image.to(
|
||||||
|
device), image_dan.to(device), target.to(device)
|
||||||
|
output = model(image, image_dan)
|
||||||
|
|
||||||
|
loss = criterion(output, target)
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
|
if idx % 100 == 0:
|
||||||
|
print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
|
||||||
|
len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
confmat = utils.ConfusionMatrix(args.num_classes)
|
||||||
|
with torch.no_grad():
|
||||||
|
for image, image_dan, target in val_loader:
|
||||||
|
image, image_dan, target = image.to(device), image_dan.to(
|
||||||
|
device), target.to(device)
|
||||||
|
output = model(image, image_dan)
|
||||||
|
|
||||||
|
confmat.update(target.flatten(), output.argmax(1).flatten())
|
||||||
|
|
||||||
|
info, mIoU = confmat.get_info()
|
||||||
|
print(info)
|
||||||
|
|
||||||
|
f.write(f"Epoch {epoch+1}\n-------------------------------\n")
|
||||||
|
f.write(info+"\n\n")
|
||||||
|
f.flush()
|
||||||
|
|
||||||
|
# # 保存准确率最好的模型
|
||||||
|
# if mIoU > best_acc:
|
||||||
|
# print("[Save model]")
|
||||||
|
# torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
|
||||||
|
# best_acc = mIoU
|
||||||
|
torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
print("total time:", total_time)
|
||||||
|
torch.save(model, os.path.join(out_dir, "last.pth"))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
main(args)
|
330
utils.py
Normal file
330
utils.py
Normal file
@ -0,0 +1,330 @@
|
|||||||
|
import datetime
|
||||||
|
import errno
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
class SmoothedValue:
|
||||||
|
"""Track a series of values and provide access to smoothed values over a
|
||||||
|
window or the global series average.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, window_size=20, fmt=None):
|
||||||
|
if fmt is None:
|
||||||
|
fmt = "{median:.4f} ({global_avg:.4f})"
|
||||||
|
self.deque = deque(maxlen=window_size)
|
||||||
|
self.total = 0.0
|
||||||
|
self.count = 0
|
||||||
|
self.fmt = fmt
|
||||||
|
|
||||||
|
def update(self, value, n=1):
|
||||||
|
self.deque.append(value)
|
||||||
|
self.count += n
|
||||||
|
self.total += value * n
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
"""
|
||||||
|
Warning: does not synchronize the deque!
|
||||||
|
"""
|
||||||
|
t = reduce_across_processes([self.count, self.total])
|
||||||
|
t = t.tolist()
|
||||||
|
self.count = int(t[0])
|
||||||
|
self.total = t[1]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def median(self):
|
||||||
|
d = torch.tensor(list(self.deque))
|
||||||
|
return d.median().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def avg(self):
|
||||||
|
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
||||||
|
return d.mean().item()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def global_avg(self):
|
||||||
|
return self.total / self.count
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max(self):
|
||||||
|
return max(self.deque)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def value(self):
|
||||||
|
return self.deque[-1]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.fmt.format(
|
||||||
|
median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfusionMatrix:
|
||||||
|
def __init__(self, num_classes):
|
||||||
|
self.num_classes = num_classes
|
||||||
|
self.mat = None
|
||||||
|
|
||||||
|
def update(self, a, b):
|
||||||
|
n = self.num_classes
|
||||||
|
if self.mat is None:
|
||||||
|
self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
|
||||||
|
with torch.inference_mode():
|
||||||
|
k = (a >= 0) & (a < n)
|
||||||
|
inds = n * a[k].to(torch.int64) + b[k]
|
||||||
|
self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.mat.zero_()
|
||||||
|
|
||||||
|
def compute(self):
|
||||||
|
h = self.mat.float()
|
||||||
|
acc_global = torch.diag(h).sum() / h.sum()
|
||||||
|
acc = torch.diag(h) / h.sum(1)
|
||||||
|
iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
|
||||||
|
return acc_global, acc, iu
|
||||||
|
|
||||||
|
def reduce_from_all_processes(self):
|
||||||
|
reduce_across_processes(self.mat)
|
||||||
|
|
||||||
|
def get_info(self):
|
||||||
|
acc_global, acc, iu = self.compute()
|
||||||
|
return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
|
||||||
|
acc_global.item() * 100,
|
||||||
|
[f"{i:.1f}" for i in (acc * 100).tolist()],
|
||||||
|
[f"{i:.1f}" for i in (iu * 100).tolist()],
|
||||||
|
iu.mean().item() * 100,
|
||||||
|
), iu.mean().item() * 100
|
||||||
|
|
||||||
|
|
||||||
|
class MetricLogger:
|
||||||
|
def __init__(self, delimiter="\t"):
|
||||||
|
self.meters = defaultdict(SmoothedValue)
|
||||||
|
self.delimiter = delimiter
|
||||||
|
|
||||||
|
def update(self, **kwargs):
|
||||||
|
for k, v in kwargs.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
v = v.item()
|
||||||
|
if not isinstance(v, (float, int)):
|
||||||
|
raise TypeError(
|
||||||
|
f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
|
||||||
|
)
|
||||||
|
self.meters[k].update(v)
|
||||||
|
|
||||||
|
def __getattr__(self, attr):
|
||||||
|
if attr in self.meters:
|
||||||
|
return self.meters[attr]
|
||||||
|
if attr in self.__dict__:
|
||||||
|
return self.__dict__[attr]
|
||||||
|
raise AttributeError(
|
||||||
|
f"'{type(self).__name__}' object has no attribute '{attr}'")
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
loss_str = []
|
||||||
|
for name, meter in self.meters.items():
|
||||||
|
loss_str.append(f"{name}: {str(meter)}")
|
||||||
|
return self.delimiter.join(loss_str)
|
||||||
|
|
||||||
|
def synchronize_between_processes(self):
|
||||||
|
for meter in self.meters.values():
|
||||||
|
meter.synchronize_between_processes()
|
||||||
|
|
||||||
|
def add_meter(self, name, meter):
|
||||||
|
self.meters[name] = meter
|
||||||
|
|
||||||
|
def log_every(self, iterable, print_freq, header=None):
|
||||||
|
i = 0
|
||||||
|
if not header:
|
||||||
|
header = ""
|
||||||
|
start_time = time.time()
|
||||||
|
end = time.time()
|
||||||
|
iter_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
data_time = SmoothedValue(fmt="{avg:.4f}")
|
||||||
|
space_fmt = ":" + str(len(str(len(iterable)))) + "d"
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[
|
||||||
|
header,
|
||||||
|
"[{0" + space_fmt + "}/{1}]",
|
||||||
|
"eta: {eta}",
|
||||||
|
"{meters}",
|
||||||
|
"time: {time}",
|
||||||
|
"data: {data}",
|
||||||
|
"max mem: {memory:.0f}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
log_msg = self.delimiter.join(
|
||||||
|
[header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
|
||||||
|
"{meters}", "time: {time}", "data: {data}"]
|
||||||
|
)
|
||||||
|
MB = 1024.0 * 1024.0
|
||||||
|
for obj in iterable:
|
||||||
|
data_time.update(time.time() - end)
|
||||||
|
yield obj
|
||||||
|
iter_time.update(time.time() - end)
|
||||||
|
if i % print_freq == 0:
|
||||||
|
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
||||||
|
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i,
|
||||||
|
len(iterable),
|
||||||
|
eta=eta_string,
|
||||||
|
meters=str(self),
|
||||||
|
time=str(iter_time),
|
||||||
|
data=str(data_time),
|
||||||
|
memory=torch.cuda.max_memory_allocated() / MB,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
print(
|
||||||
|
log_msg.format(
|
||||||
|
i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
i += 1
|
||||||
|
end = time.time()
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
||||||
|
print(f"{header} Total time: {total_time_str}")
|
||||||
|
|
||||||
|
|
||||||
|
def cat_list(images, fill_value=0):
|
||||||
|
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
|
||||||
|
batch_shape = (len(images),) + max_size
|
||||||
|
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
|
||||||
|
for img, pad_img in zip(images, batched_imgs):
|
||||||
|
pad_img[..., : img.shape[-2], : img.shape[-1]].copy_(img)
|
||||||
|
return batched_imgs
|
||||||
|
|
||||||
|
|
||||||
|
def collate_fn(batch):
|
||||||
|
images, targets = list(zip(*batch))
|
||||||
|
batched_imgs = cat_list(images, fill_value=0)
|
||||||
|
batched_targets = cat_list(targets, fill_value=255)
|
||||||
|
return batched_imgs, batched_targets
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir(path):
|
||||||
|
try:
|
||||||
|
os.makedirs(path)
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno != errno.EEXIST:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def setup_for_distributed(is_master):
|
||||||
|
"""
|
||||||
|
This function disables printing when not in master process
|
||||||
|
"""
|
||||||
|
import builtins as __builtin__
|
||||||
|
|
||||||
|
builtin_print = __builtin__.print
|
||||||
|
|
||||||
|
def print(*args, **kwargs):
|
||||||
|
force = kwargs.pop("force", False)
|
||||||
|
if is_master or force:
|
||||||
|
builtin_print(*args, **kwargs)
|
||||||
|
|
||||||
|
__builtin__.print = print
|
||||||
|
|
||||||
|
|
||||||
|
def is_dist_avail_and_initialized():
|
||||||
|
if not dist.is_available():
|
||||||
|
return False
|
||||||
|
if not dist.is_initialized():
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def get_world_size():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 1
|
||||||
|
return dist.get_world_size()
|
||||||
|
|
||||||
|
|
||||||
|
def get_rank():
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
return 0
|
||||||
|
return dist.get_rank()
|
||||||
|
|
||||||
|
|
||||||
|
def is_main_process():
|
||||||
|
return get_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def save_on_master(*args, **kwargs):
|
||||||
|
if is_main_process():
|
||||||
|
torch.save(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def init_distributed_mode(args):
|
||||||
|
if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
|
||||||
|
args.rank = int(os.environ["RANK"])
|
||||||
|
args.world_size = int(os.environ["WORLD_SIZE"])
|
||||||
|
args.gpu = int(os.environ["LOCAL_RANK"])
|
||||||
|
# elif "SLURM_PROCID" in os.environ:
|
||||||
|
# args.rank = int(os.environ["SLURM_PROCID"])
|
||||||
|
# args.gpu = args.rank % torch.cuda.device_count()
|
||||||
|
elif hasattr(args, "rank"):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
print("Not using distributed mode")
|
||||||
|
args.distributed = False
|
||||||
|
return
|
||||||
|
|
||||||
|
args.distributed = True
|
||||||
|
|
||||||
|
torch.cuda.set_device(args.gpu)
|
||||||
|
args.dist_backend = "nccl"
|
||||||
|
print(
|
||||||
|
f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
|
||||||
|
)
|
||||||
|
torch.distributed.barrier()
|
||||||
|
setup_for_distributed(args.rank == 0)
|
||||||
|
|
||||||
|
|
||||||
|
def reduce_across_processes(val):
|
||||||
|
if not is_dist_avail_and_initialized():
|
||||||
|
# nothing to sync, but we still convert to tensor for consistency with the distributed case.
|
||||||
|
return torch.tensor(val)
|
||||||
|
|
||||||
|
t = torch.tensor(val, device="cuda")
|
||||||
|
dist.barrier()
|
||||||
|
dist.all_reduce(t)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
|
def create_lr_scheduler(optimizer,
|
||||||
|
num_step: int,
|
||||||
|
epochs: int,
|
||||||
|
warmup=True,
|
||||||
|
warmup_epochs=1,
|
||||||
|
warmup_factor=1e-3):
|
||||||
|
assert num_step > 0 and epochs > 0
|
||||||
|
if warmup is False:
|
||||||
|
warmup_epochs = 0
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
"""
|
||||||
|
根据step数返回一个学习率倍率因子,
|
||||||
|
注意在训练开始之前,pytorch会提前调用一次lr_scheduler.step()方法
|
||||||
|
"""
|
||||||
|
if warmup is True and x <= (warmup_epochs * num_step):
|
||||||
|
alpha = float(x) / (warmup_epochs * num_step)
|
||||||
|
# warmup过程中lr倍率因子从warmup_factor -> 1
|
||||||
|
return warmup_factor * (1 - alpha) + alpha
|
||||||
|
else:
|
||||||
|
# warmup后lr倍率因子从1 -> 0
|
||||||
|
# 参考deeplab_v2: Learning rate policy
|
||||||
|
return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 0.9
|
||||||
|
|
||||||
|
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)
|
Loading…
Reference in New Issue
Block a user