first commit
This commit is contained in:
parent
f8436f7769
commit
5e0d438280
138
data_preprocessing/cut_smalltif.py
Normal file
138
data_preprocessing/cut_smalltif.py
Normal file
@ -0,0 +1,138 @@
|
||||
import numpy as np
|
||||
from osgeo import gdal
|
||||
|
||||
|
||||
def read_tif(fileName):
    """Read a GeoTIFF and return its pixel data plus geo-metadata.

    Parameters
    ----------
    fileName : str
        Path to the GeoTIFF file.

    Returns
    -------
    tuple
        (im_data, im_width, im_height, im_bands, im_geotrans, im_proj) where
        im_data is a float32 array shaped (bands, height, width).

    Raises
    ------
    FileNotFoundError
        If GDAL cannot open the file (gdal.Open returns None instead of
        raising, which would otherwise surface as a confusing AttributeError
        on the next line).
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        raise FileNotFoundError("GDAL could not open {!r}".format(fileName))

    im_width = dataset.RasterXSize    # number of raster columns
    im_height = dataset.RasterYSize   # number of raster rows
    im_bands = dataset.RasterCount    # number of bands
    im_data = dataset.ReadAsArray().astype(np.float32)  # pixel data
    if len(im_data.shape) == 2:
        # Single-band rasters come back 2-D; add a leading band axis so the
        # returned shape is always (bands, height, width).
        im_data = im_data[np.newaxis, :]
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()        # projection (WKT)

    return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
|
||||
|
||||
|
||||
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
    """Write an array to a GeoTIFF file at *path*.

    Parameters
    ----------
    im_data : numpy.ndarray
        Pixel data, (bands, height, width) or (height, width).
    im_width, im_height : int
        Nominal size; overwritten below from im_data's actual shape.
    path : str
        Output file path.
    im_geotrans : tuple or None
        Affine geotransform to embed; skipped when None.
    im_proj : str or None
        Projection (WKT) to embed; skipped when None.

    Raises
    ------
    IOError
        If the GTiff driver fails to create the output file (Create returns
        None; the original code silently wrote nothing in that case).
    """
    # Map the numpy dtype name to a GDAL type. Substring match, so 'uint8'
    # also hits the 'int8' branch — 8-bit data becomes GDT_Byte either way.
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # Derive the true dimensions from the array, normalising 2-D input to
    # a single band.
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("GDAL could not create {!r}".format(path))
    if im_geotrans is not None and im_proj is not None:
        dataset.SetGeoTransform(im_geotrans)  # write the affine transform
        dataset.SetProjection(im_proj)        # write the projection
    for i in range(im_bands):
        # GDAL bands are 1-indexed.
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset  # dropping the reference flushes the file to disk
|
||||
|
||||
|
||||
# --- Load the two source rasters and the label mask ------------------------
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
    fileName)

fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
    fileName2)

mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
    mask_fileName)
mask_im_data = np.int8(mask_im_data)


# Per-band min-max normalisation of each geotiff to the 0..255 range.
# NOTE(review): divides by (Max - Min); a constant band would divide by
# zero — presumably the inputs never contain one. Confirm before reuse.
for i in range(im_bands):
    arr = im_data[i, :, :]
    Min = arr.min()
    Max = arr.max()
    normalized_arr = (arr-Min)/(Max-Min)*255
    im_data[i] = normalized_arr

for i in range(im_bands2):
    arr = im_data2[i, :, :]
    Min = arr.min()
    Max = arr.max()
    normalized_arr = (arr-Min)/(Max-Min)*255
    im_data2[i] = normalized_arr


# Print each band's mean/std on the 0..1 scale; these numbers are fed to
# the Normalize() transform in train.py.
im_data = im_data/255
for i in range(im_bands):
    pixels = im_data[i, :, :].ravel()
    print("波段{} mean: {:.4f}, std: {:.4f}".format(
        i, np.mean(pixels), np.std(pixels)))
im_data = im_data*255

im_data2 = im_data2/255
for i in range(im_bands2):
    pixels = im_data2[i, :, :].ravel()
    print("波段{} mean: {:.4f}, std: {:.4f}".format(
        i, np.mean(pixels), np.std(pixels)))
im_data2 = im_data2*255


# --- Cut the large rasters into size x size tiles --------------------------
a = 0       # running tile index used in the output file names
size = 224  # tile edge length in pixels
for i in range(0, int(mask_im_height / size)):
    for j in range(0, int(mask_im_width / size)):
        im_cut = im_data[:, i * size:i * size + size, j * size:j * size + size]
        im_cut2 = im_data2[:, i * size:i *size + size, j * size:j * size + size]
        mask_cut = mask_im_data[:, i * size:i *size + size, j * size:j * size + size]

        # Use the mask as the reference: handle geotiffs and mask together,
        # skipping tiles whose mask is entirely NoData.
        labelfla = np.array(mask_cut).flatten()
        if np.all(labelfla == 15):  # 15 marks NoData
            print("Skip!!!")
        else:
            # 5 m tile: shift the affine origin to this tile's corner.
            left_h = i * size * im_geotrans[5] + im_geotrans[3]
            left_w = j * size * im_geotrans[1] + im_geotrans[0]
            new_geotrans = np.array(im_geotrans)
            new_geotrans[0] = left_w
            new_geotrans[3] = left_h
            out_geotrans = tuple(new_geotrans)

            im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif'
            write_tif(im_cut, size, size, im_out, out_geotrans, im_proj)

            # "dan" (second raster) tile.
            left_h = i * size * im_geotrans2[5] + im_geotrans2[3]
            left_w = j * size * im_geotrans2[1] + im_geotrans2[0]
            new_geotrans = np.array(im_geotrans2)
            new_geotrans[0] = left_w
            new_geotrans[3] = left_h
            out_geotrans = tuple(new_geotrans)

            im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif'
            write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2)

            # Mask tile.
            mask_left_h = i * size * mask_im_geotrans[5] + mask_im_geotrans[3]
            mask_left_w = j * size * mask_im_geotrans[1] + mask_im_geotrans[0]
            mask_new_geotrans = np.array(mask_im_geotrans)
            mask_new_geotrans[0] = mask_left_w
            mask_new_geotrans[3] = mask_left_h
            mask_out_geotrans = tuple(mask_new_geotrans)
            mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif'
            write_tif(mask_cut, size, size, mask_out,
                      mask_out_geotrans, mask_im_proj)

            print(mask_out + 'Cut to complete')

        a = a+1
|
169
data_preprocessing/cut_smalltif_multi.py
Normal file
169
data_preprocessing/cut_smalltif_multi.py
Normal file
@ -0,0 +1,169 @@
|
||||
import numpy as np
|
||||
from osgeo import gdal
|
||||
|
||||
|
||||
def read_tif(fileName):
    """Read a GeoTIFF and return its pixel data plus geo-metadata.

    Returns (im_data, im_width, im_height, im_bands, im_geotrans, im_proj);
    im_data is a float32 array shaped (bands, height, width).

    Raises FileNotFoundError if GDAL cannot open the file — gdal.Open
    returns None instead of raising, which would otherwise surface as a
    confusing AttributeError below.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        raise FileNotFoundError("GDAL could not open {!r}".format(fileName))

    im_width = dataset.RasterXSize    # number of raster columns
    im_height = dataset.RasterYSize   # number of raster rows
    im_bands = dataset.RasterCount    # number of bands
    im_data = dataset.ReadAsArray().astype(np.float32)  # pixel data
    if len(im_data.shape) == 2:
        # Single-band rasters come back 2-D; normalise to (bands, H, W).
        im_data = im_data[np.newaxis, :]
    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()        # projection (WKT)

    return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
|
||||
|
||||
|
||||
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
    """Write an array to a GeoTIFF file at *path*.

    im_data may be (bands, height, width) or (height, width); the real
    dimensions are re-derived from the array. The geotransform/projection
    are embedded when both are non-None.

    Raises IOError if the GTiff driver fails to create the file (Create
    returns None; the original code silently wrote nothing in that case).
    """
    # Map the numpy dtype name to a GDAL type. Substring match, so 'uint8'
    # also hits the 'int8' branch — 8-bit data becomes GDT_Byte either way.
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError("GDAL could not create {!r}".format(path))
    if im_geotrans is not None and im_proj is not None:
        dataset.SetGeoTransform(im_geotrans)  # write the affine transform
        dataset.SetProjection(im_proj)        # write the projection
    for i in range(im_bands):
        # GDAL bands are 1-indexed.
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset  # dropping the reference flushes the file to disk
|
||||
|
||||
|
||||
# --- Load the source rasters and the label mask ----------------------------
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right3.tif'
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
    fileName)

fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
im_data2, im_width2, im_height2, im_bands2, im_geotrans2, im_proj2 = read_tif(
    fileName2)

# (Disabled: optional 10 m / 20 m inputs, wlk_right_cj10m.tif -> im_data3
# and wlk_right_cj20m.tif -> im_data4.)

mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
mask_im_data, mask_im_width, mask_im_height, mask_im_bands, mask_im_geotrans, mask_im_proj = read_tif(
    mask_fileName)
mask_im_data = np.int8(mask_im_data)

# Per-band min-max normalisation of each geotiff to the 0..255 range.
# NOTE(review): divides by (Max - Min); a constant band would divide by zero.
for i in range(im_bands):
    arr = im_data[i, :, :]
    Min = arr.min()
    Max = arr.max()
    normalized_arr = (arr-Min)/(Max-Min)*255
    im_data[i] = normalized_arr

for i in range(im_bands2):
    arr = im_data2[i, :, :]
    Min = arr.min()
    Max = arr.max()
    normalized_arr = (arr-Min)/(Max-Min)*255
    im_data2[i] = normalized_arr

# (Disabled: per-band mean/std computation used by train.py's Normalize().)


# --- Cut into tiles --------------------------------------------------------
# The loops step in units of `size`, but every slice spans size*4 pixels,
# i.e. tiles are 896 px wide at this resolution.
# NOTE(review): the ranges use im_height/size rather than im_height/(size*4);
# verify this covers the raster as intended and does not run past its edge.
a = 0       # running tile index used in the output file names
size = 224
for i in range(0, int(im_height / size)):
    for j in range(0, int(im_width / size)):
        im_cut = im_data[:, i * size*4:i * size*4 +
                         size*4, j * size*4:j * size*4 + size*4]
        im_cut2 = im_data2[:, i * size*4:i * size*4 +
                           size*4, j * size*4:j * size*4 + size*4]
        # (Disabled: im_cut3 at size*2 and im_cut4 at size for the 10 m /
        # 20 m inputs.)
        mask_cut = mask_im_data[:, i * size*4:i *
                                size*4 + size*4, j * size*4:j * size*4 + size*4]

        # Use the mask as the reference: skip tiles that are all NoData (15).
        labelfla_bool = np.all(np.array(mask_cut).flatten() == 15)

        if labelfla_bool:
            print("False")
        else:
            # First raster tile: shift the affine origin to the tile corner.
            left_h = i * size*4 * im_geotrans[5] + im_geotrans[3]
            left_w = j * size*4 * im_geotrans[1] + im_geotrans[0]
            new_geotrans = np.array(im_geotrans)
            new_geotrans[0] = left_w
            new_geotrans[3] = left_h
            out_geotrans = tuple(new_geotrans)
            im_out = 'E:/RSdata/mask_5m/dataset_5m/geotiff' + \
                str(a) + '.tif'
            write_tif(im_cut, size*4, size*4, im_out, out_geotrans, im_proj)
            print(im_out + 'Cut to complete')

            # "dan" (second raster) tile.
            left_h2 = i * size*4 * im_geotrans2[5] + im_geotrans2[3]
            left_w2 = j * size*4 * im_geotrans2[1] + im_geotrans2[0]
            new_geotrans = np.array(im_geotrans2)
            new_geotrans[0] = left_w2
            new_geotrans[3] = left_h2
            out_geotrans = tuple(new_geotrans)

            im_out = 'E:/RSdata/mask_5m/dataset_dan/geotiff' + \
                str(a) + '.tif'
            write_tif(im_cut2, size*4, size*4, im_out, out_geotrans, im_proj2)
            print(im_out + 'Cut to complete')

            # (Disabled: writes for the optional 10 m / 20 m tiles.)

            # Mask tile.
            mask_left_h = i * size*4 * \
                mask_im_geotrans[5] + mask_im_geotrans[3]
            mask_left_w = j * size*4 * \
                mask_im_geotrans[1] + mask_im_geotrans[0]
            mask_new_geotrans = np.array(mask_im_geotrans)
            mask_new_geotrans[0] = mask_left_w
            mask_new_geotrans[3] = mask_left_h
            mask_out_geotrans = tuple(mask_new_geotrans)
            mask_out = 'E:/RSdata/mask_5m/mask/geotiff' + str(a) + '.tif'
            write_tif(mask_cut, size*4, size*4, mask_out,
                      mask_out_geotrans, mask_im_proj)
            print(mask_out + 'Cut to complete')

        a = a+1
|
54
data_preprocessing/sift_tif.py
Normal file
54
data_preprocessing/sift_tif.py
Normal file
@ -0,0 +1,54 @@
|
||||
from osgeo import gdal
|
||||
import zipfile
|
||||
import os
|
||||
from xml.dom.minidom import parseString
|
||||
import shutil
|
||||
|
||||
|
||||
# 搜索所有硬盘里所有.zip文件
|
||||
def get_all_zipfiles(root):
    """Recursively collect the paths of every .zip file under *root*."""
    found = []
    for dirpath, _dirs, filenames in os.walk(root):
        found.extend(
            os.path.join(dirpath, name)
            for name in filenames
            if name.endswith('.zip')
        )
    return found
|
||||
|
||||
# Locate the .zip archives to scan.
# (Alternative, disabled: restrict to one folder via os.listdir.)
# zip_dir = 'F:/2101-2400/'
# zip_files_list = os.listdir(zip_dir)


zip_dir = 'F:\\'
zip_files_list = get_all_zipfiles(zip_dir)

# Inspect the metadata .xml inside each archive.
for file_name in zip_files_list:
    try:
        file = zipfile.ZipFile(file_name, "r")
    except (zipfile.BadZipFile, OSError):
        # Corrupt or unreadable archive: skip it rather than abort the scan.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        continue
    info_list = file.infolist()
    for info in info_list:
        if info.filename[-4:] == ".xml":
            xml_file = file.read(info.filename)

            # Parse the product-metadata XML.
            domTree = parseString(xml_file)
            rootNode = domTree.documentElement
            infos = rootNode.getElementsByTagName("ProductInfo")[0]

            # Scene centre coordinates. float() instead of eval(): the XML
            # comes from external media, and eval would execute arbitrary
            # code embedded in it.
            CenterLatitude = float(infos.getElementsByTagName(
                "CenterLatitude")[0].childNodes[0].data)
            CenterLongitude = float(infos.getElementsByTagName(
                "CenterLongitude")[0].childNodes[0].data)

            # North-west region bounds: lat max 50, min 37; lon max 123,
            # min 73 — narrowed further here to a sub-window of interest.
            if 37 < CenterLatitude < 44:
                if 76 < CenterLongitude < 78:
                    print(file)

                    # Copy the matching archive into the working folder.
                    shutil.copy(file_name, 'E:/wlk_test/')
|
21
data_preprocessing/split_data.py
Normal file
21
data_preprocessing/split_data.py
Normal file
@ -0,0 +1,21 @@
|
||||
import os
import random


# Fixed seed so the train/val split is reproducible across runs.
random.seed(42)

geotiffs = os.listdir('E:\RSdata\wlk_right_224_2\dataset_5m')
num = len(geotiffs)
split_rate = 0.2

# Draw split_rate of the tile names for validation. A set makes the
# per-tile membership test below O(1) instead of O(n).
eval_index = set(random.sample(geotiffs, k=int(num*split_rate)))

# with-blocks guarantee both text files are flushed and closed even on
# error (the original never closed them).
with open('E:\RSdata\wlk_right_224_2/train.txt', 'w') as f_train, \
        open('E:\RSdata\wlk_right_224_2/val.txt', 'w') as f_val:
    # Write every tile name to exactly one of the two split files.
    for geotiff in geotiffs:
        if geotiff in eval_index:
            f_val.write(str(geotiff)+'\n')
        else:
            f_train.write(str(geotiff)+'\n')
|
296
geotiff_utils.py
Normal file
296
geotiff_utils.py
Normal file
@ -0,0 +1,296 @@
|
||||
|
||||
"""Pascal VOC Semantic Segmentation Dataset."""
|
||||
from PIL import Image, ImageOps, ImageFilter
|
||||
import torchvision.transforms as transforms
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
from PIL import Image
|
||||
import cv2
|
||||
# import gdal
|
||||
from osgeo import gdal
|
||||
import random
|
||||
import torch.utils.data as data
|
||||
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
|
||||
|
||||
|
||||
class SegmentationDataset(object):
    """Base class for the segmentation datasets in this project.

    Holds the common configuration (root directory, split, mode, transform,
    resize/crop sizes) and the shared image/mask transform helpers used by
    subclasses.
    """

    def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
        """Store the dataset configuration.

        Parameters
        ----------
        root : str
            Dataset root directory.
        split : str
            Split name, e.g. 'train' or 'val'.
        mode : str or None
            Transform mode; falls back to *split* when None.
        transform : callable or None
            ToTensor/normalisation transform applied to images.
        base_size : int
            Base size for the random-scale augmentation.
        crop_size : int
            Output crop size.
        """
        super(SegmentationDataset, self).__init__()
        self.root = root
        self.transform = transform
        self.split = split
        self.mode = mode if mode is not None else split
        self.base_size = base_size
        self.crop_size = crop_size

    def _val_sync_transform(self, img, mask):
        """Resize the short edge to crop_size, centre-crop, then convert.

        Applied jointly to a PIL image and its mask so they stay aligned.
        """
        outsize = self.crop_size
        short_size = outsize
        w, h = img.size
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        # NEAREST for the mask keeps the label ids intact.
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - outsize) / 2.))
        y1 = int(round((h - outsize) / 2.))
        img = img.crop((x1, y1, x1 + outsize, y1 + outsize))
        mask = mask.crop((x1, y1, x1 + outsize, y1 + outsize))
        # final transform
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask

    def _sync_transform(self, img, mask):
        """Jointly augment a PIL image and its mask.

        Pipeline: random mirror -> random short-edge scale -> pad (if the
        scaled image is smaller than the crop) -> random crop -> random
        gaussian blur -> conversion.
        """
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        # random scale (short edge): 0.5x .. 2x of base_size
        short_size = random.randint(
            int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad with zeros up to crop_size when the scaled image is smaller
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop to crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        mask = mask.crop((x1, y1, x1 + crop_size, y1 + crop_size))
        # gaussian blur as in PSP
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
        # final transform
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask

    def _sync_transform_tif(self, img, mask):
        """Transform for tif arrays: no augmentation, conversion only."""
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask

    def _sync_transform_tif_geofeat(self, img, mask, img_feat):
        """Like _sync_transform_tif, but also converts an extra feature image."""
        img, mask = self._img_transform(img), self._mask_transform(mask)
        img_feat = self._img_transform(img_feat)
        return img, mask, img_feat

    def _val_sync_transform_tif(self, img, mask):
        """Validation transform for tif arrays: conversion only."""
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask

    def _img_transform(self, img):
        # Plain numpy conversion; normalisation happens later in
        # self.transform.
        return np.array(img)

    def _mask_transform(self, mask):
        """Convert a mask to a long tensor, remapping labels > 12 to 255.

        255 matches the ignore_index used by CrossEntropyLoss in train_LS.py.
        """
        target = np.array(mask).astype('int32')
        target[target > 12] = 255
        return torch.from_numpy(target).long()

    @property
    def num_class(self):
        """Number of categories."""
        return self.NUM_CLASS

    @property
    def pred_offset(self):
        # Offset added to predictions when exporting; 0 means labels map 1:1.
        return 0
||||
|
||||
|
||||
class VOCYJSSegmentation(SegmentationDataset):
    """Pascal VOC Semantic Segmentation Dataset.

    Parameters
    ----------
    root : string
        Path to VOCdevkit folder. Default is './datasets/VOCdevkit'
    split: string
        'train', 'val' or 'test'
    transform : callable, optional
        A function that transforms the image

    Examples
    --------
    >>> from torchvision import transforms
    >>> import torch.utils.data as data
    >>> # Transforms for Normalization
    >>> input_transform = transforms.Compose([
    >>>     transforms.ToTensor(),
    >>>     transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    >>> ])
    >>> # Create Dataset
    >>> trainset = VOCSegmentation(split='train', transform=input_transform)
    >>> # Create Training Loader
    >>> train_data = data.DataLoader(
    >>>     trainset, 4, shuffle=True,
    >>>     num_workers=4)
    """
    # Number of label classes (ids 0..12).
    NUM_CLASS = 13

    def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
        """Resolve the data directories and load the file list for *split*."""
        super(VOCYJSSegmentation, self).__init__(
            root, split, mode, transform, **kwargs)
        _voc_root = root
        # <root>/<split>.txt lists one tile file name per line.
        txt_path = os.path.join(root, split+'.txt')
        self._mask_dir = os.path.join(_voc_root, 'masks_GE')
        self._image_dir = os.path.join(_voc_root, 'images_SE')
        self._mask_LS_dir = os.path.join(_voc_root, 'masks_LS')
        self._image_LS_dir = os.path.join(_voc_root, "images_LS")
        self.image_list = read_text(txt_path)
        # 7-channel normalisation for the SE imagery (the SE tiles have 7
        # bands; mean/std values mirror the ImageNet triplet).
        self.transform_SE = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
            [.485, .456, .406, .485, .456, .406, .406], [.229, .224, .225, .229, .224, .225, .229]),])
        random.shuffle(self.image_list)

    def __getitem__(self, index):
        """Return (img_LS, img_HR, mask) for the sample at *index*."""
        # Read the two image modalities as (H, W, C) float32 arrays.
        img_HR = gdal.Open(os.path.join(self._image_dir, self.image_list[index])).ReadAsArray(
        ).transpose(1, 2, 0).astype(np.float32)
        img_LS = gdal.Open(os.path.join(self._image_LS_dir, self.image_list[index])).ReadAsArray(
        ).transpose(1, 2, 0).astype(np.float32)
        # Read both annotations.
        # NOTE(review): mask_HR is loaded but never used or returned.
        mask_HR = gdal.Open(os.path.join(
            self._mask_dir, self.image_list[index])).ReadAsArray()
        mask = gdal.Open(os.path.join(self._mask_LS_dir,
                                      self.image_list[index])).ReadAsArray()
        # synchronized transform — only two modes exist ('train' and 'val'),
        # and both currently apply the same augmentation-free conversion.
        if self.mode == 'train':
            img_LS, mask, img_HR = self._sync_transform_tif_geofeat(
                img_LS, mask, img_HR)
        elif self.mode == 'val':
            img_LS, mask, img_HR = self._sync_transform_tif_geofeat(
                img_LS, mask, img_HR)
        # Resize, normalise and convert to tensors.
        if self.transform is not None:
            img_HR = cv2.resize(img_HR, (448, 448),
                                interpolation=cv2.INTER_CUBIC)
            img_HR = self.transform_SE(img_HR)
            img_LS = self.transform(img_LS)
        return img_LS, img_HR, mask

    def __len__(self):
        return len(self.image_list)

    def _mask_transform(self, mask):
        """Convert a mask to a long tensor, remapping labels > 12 to 255.

        255 matches the ignore_index used by CrossEntropyLoss in train_LS.py.
        """
        target = np.array(mask).astype('int32')
        target[target > 12] = 255
        return torch.from_numpy(target).long()

    @property
    def classes(self):
        """Category names."""
        return ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12')
|
||||
|
||||
|
||||
def generator_list_of_imagepath(path):
    """Return the names (not full paths) of the .tif files directly in *path*.

    macOS '.DS_Store' entries are excluded; the extension check compares the
    text after the last dot, so '.tiff' files are excluded too.
    """
    return [
        name for name in os.listdir(path)
        if name != '.DS_Store' and name.split('.')[-1] == 'tif'
    ]
|
||||
|
||||
|
||||
def read_text(textfile):
    """Read *textfile* and return its lines with trailing newlines stripped."""
    with open(textfile, "r") as fh:
        return [line.rstrip('\n') for line in fh]
|
||||
|
||||
|
||||
def dataset_segmentation(textpath, imagepath, train_percent):
    """Split the .tif images in *imagepath* into train.txt / val.txt.

    A random train_percent fraction of the image names goes to
    <textpath>/train.txt, the rest to <textpath>/val.txt, one name per line.
    """
    image_list = generator_list_of_imagepath(imagepath)
    num = len(image_list)
    train_num = int(num * train_percent)  # training set size
    train_list = random.sample(range(num), train_num)
    print("train set size", train_num)
    # with-blocks close both files even if a write fails.
    with open(os.path.join(textpath, 'train.txt'), 'w') as ftrain, \
            open(os.path.join(textpath, 'val.txt'), 'w') as fval:
        for idx in range(num):
            name = image_list[idx] + '\n'
            target = ftrain if idx in train_list else fval
            target.write(name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Build the train/val split text files for the GE image tiles.
    # (Disabled: a quick listing check of generator_list_of_imagepath.)
    textpath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset'
    imagepath = r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE'
    train_percent = 0.8
    dataset_segmentation(textpath, imagepath, train_percent)

    # (Disabled: visual checks that display sample GE/LS images and masks
    # with cv2/PIL.)

    # Dataset/DataLoader smoke test — kept as an inert string literal so it
    # can be pasted into a REPL; it is never executed here.
    '''
    data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
    input_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
    dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
    dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
    train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
    test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
    '''
|
156
train_LS.py
Normal file
156
train_LS.py
Normal file
@ -0,0 +1,156 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
|
||||
from torchvision import transforms
|
||||
import torch.utils.data as data
|
||||
from torch import nn
|
||||
import numpy as np
|
||||
|
||||
from geotiff_utils import VOCYJSSegmentation
|
||||
import utils
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def parse_args():
    """Build and parse the command-line arguments for training.

    Returns the populated argparse.Namespace.
    """
    import argparse
    parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")

    parser.add_argument(
        "--data-path", default="E:/repository/DeepLearning23/datasets/WLKdata_1111/WLKdataset", help="VOCdevkit root")
    parser.add_argument("--num-classes", default=13, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=50, type=int, metavar="N",
                        help="number of total epochs to train")
    parser.add_argument('--lr', default=0.005, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # Output sub-directory name under ./train_output. Note the single dash:
    # the flag is literally "-out-dir", and argparse exposes it as
    # args.out_dir.
    parser.add_argument('-out-dir', type=str,
                        default='DeeplabV3_LS')

    args = parser.parse_args()

    return args
|
||||
|
||||
|
||||
class DeeplabV3_LS(nn.Module):
    """DeepLabV3-ResNet50 wrapper that remaps the pretrained head's 21 VOC
    logits to n_class output channels with a 1x1 convolution.
    """

    def __init__(self, n_class):
        """n_class: number of segmentation classes to predict."""
        super(DeeplabV3_LS, self).__init__()
        self.n_class = n_class
        # 7->3 channel reducer. Currently unused: the concatenation path in
        # forward() that would feed it is commented out.
        self.conv7_3 = nn.Conv2d(7, 3, kernel_size=1, stride=1)

        # Maps the pretrained head's 21 output channels to n_class.
        self.conv_fc = nn.Conv2d(
            21, self.n_class, kernel_size=(1, 1), stride=(1, 1))

        # Pretrained backbone + segmentation head.
        self.seg = deeplabv3_resnet50(weights='DEFAULT')

    def forward(self, x_LS, x_SE):
        """Segment x_LS. x_SE is accepted but currently ignored — it only
        exists for the disabled channel-concatenation path below.
        """
        # x = torch.cat([x, x_dan], dim=1)
        # x = self.conv7_3(x)
        x = self.seg(x_LS)["out"]
        x = self.conv_fc(x)
        return x
|
||||
|
||||
|
||||
def main(args):
    """Train DeeplabV3_LS on the WLK dataset, validating every epoch.

    Writes per-epoch confusion-matrix stats to <out_dir>/log.txt and saves a
    checkpoint after every epoch plus a final "last.pth".
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    data_kwargs = {'transform': input_transform,
                   'base_size': 224, 'crop_size': 224}

    # Build the train/val datasets from the geotiff tiles.
    train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
                                       **data_kwargs)
    val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
                                     **data_kwargs)
    num_workers = min(
        [os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
    train_loader = data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)

    model = DeeplabV3_LS(n_class=args.num_classes)
    model.to(device)

    # 255 is the ignore label produced by the dataset's _mask_transform.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    # Per-iteration LR schedule with warmup (see utils.create_lr_scheduler).
    lr_scheduler = utils.create_lr_scheduler(
        optimizer, len(train_loader), args.epochs, warmup=True)

    # Time-stamped output directory for checkpoints and the log file.
    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d__%H-%M__")
    out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
    if not out_dir.exists():
        out_dir.mkdir()
    f = open(os.path.join(out_dir, "log.txt"), 'w')
    start_time = time.time()
    best_acc = 0  # NOTE(review): unused — best-model saving below is disabled
    for epoch in range(args.epochs):
        print(f"Epoch {epoch+1}\n-------------------------------")
        model.train()
        for idx, (image, image_dan, target) in enumerate(train_loader):
            image, image_dan, target = image.to(
                device), image_dan.to(device), target.to(device)
            output = model(image, image_dan)

            loss = criterion(output, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()  # stepped per iteration, not per epoch

            if idx % 100 == 0:
                print("[ {} / {} ] loss: {:.4f}, lr: {}".format(idx,
                                                                len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))

        # Validation: accumulate a confusion matrix over the whole val set.
        model.eval()
        confmat = utils.ConfusionMatrix(args.num_classes)
        with torch.no_grad():
            for image, image_dan, target in val_loader:
                image, image_dan, target = image.to(device), image_dan.to(
                    device), target.to(device)
                output = model(image, image_dan)

                confmat.update(target.flatten(), output.argmax(1).flatten())

        info, mIoU = confmat.get_info()
        print(info)

        f.write(f"Epoch {epoch+1}\n-------------------------------\n")
        f.write(info+"\n\n")
        f.flush()

        # (Disabled: keep only the best-mIoU checkpoint.)
        # if mIoU > best_acc:
        #     print("[Save model]")
        #     torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
        #     best_acc = mIoU
        torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
    total_time = time.time() - start_time
    print("total time:", total_time)
    torch.save(model, os.path.join(out_dir, "last.pth"))
|
||||
|
||||
|
||||
# Script entry point: parse CLI arguments, then run the training loop.
if __name__ == '__main__':
    args = parse_args()

    main(args)
|
330
utils.py
Normal file
330
utils.py
Normal file
@ -0,0 +1,330 @@
|
||||
import datetime
|
||||
import errno
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
class SmoothedValue:
    """Track a stream of scalar values.

    Exposes smoothed statistics over a sliding window (median / avg / max /
    last value) as well as the running average over the whole series
    (global_avg).
    """

    def __init__(self, window_size=20, fmt=None):
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        # Default display: windowed median plus global average.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt

    def update(self, value, n=1):
        # A single call may stand for n observations of the same value.
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum count/total across distributed workers.

        Warning: does not synchronize the deque!
        """
        combined = reduce_across_processes([self.count, self.total])
        count, total = combined.tolist()
        self.count = int(count)
        self.total = total

    @property
    def median(self):
        # Median over the current window only.
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        # Mean over the current window only.
        return torch.tensor(list(self.deque), dtype=torch.float32).mean().item()

    @property
    def global_avg(self):
        # Mean over every value ever fed in (weighted by n).
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        stats = {
            "median": self.median,
            "avg": self.avg,
            "global_avg": self.global_avg,
            "max": self.max,
            "value": self.value,
        }
        return self.fmt.format(**stats)
||||
|
||||
|
||||
class ConfusionMatrix:
    """Accumulate a confusion matrix for segmentation/classification eval.

    Rows index the ground-truth class, columns the predicted class. The
    matrix is allocated lazily on the first `update` call, on the same
    device as the incoming targets.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # Lazily-created (num_classes, num_classes) int64 count tensor.
        self.mat = None

    def update(self, a, b):
        """Accumulate label pairs: `a` is ground truth, `b` is prediction."""
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.inference_mode():
            # Drop out-of-range targets (e.g. an ignore label such as 255).
            k = (a >= 0) & (a < n)
            # Encode each (truth, prediction) pair as one index in [0, n*n).
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n**2).reshape(n, n)

    def reset(self):
        """Zero the counts; a no-op if nothing has been accumulated yet.

        (Previously this raised AttributeError when called before the
        first `update`, because `self.mat` was still None.)
        """
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        """Return (global accuracy, per-class accuracy, per-class IoU)."""
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iu = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iu

    def reduce_from_all_processes(self):
        # Sum the matrices accumulated on every distributed worker.
        reduce_across_processes(self.mat)

    def get_info(self):
        """Return (formatted summary string, mean IoU in percent)."""
        acc_global, acc, iu = self.compute()
        return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
            acc_global.item() * 100,
            [f"{i:.1f}" for i in (acc * 100).tolist()],
            [f"{i:.1f}" for i in (iu * 100).tolist()],
            iu.mean().item() * 100,
        ), iu.mean().item() * 100
||||
|
||||
|
||||
class MetricLogger:
    """Collect named SmoothedValue meters and print periodic progress lines."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names get a default SmoothedValue on first use.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one scalar per keyword argument into the meter of that name."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            if not isinstance(v, (float, int)):
                raise TypeError(
                    f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
                )
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Allow `logger.loss`-style access to meters; only called when normal
        # attribute lookup has already failed.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            f"'{type(self).__name__}' object has no attribute '{attr}'")

    def __str__(self):
        # One "name: value" entry per meter, joined by the delimiter.
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(f"{name}: {str(meter)}")
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        # Aggregate each meter's count/total across distributed workers.
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register a meter with a custom format or window size."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield items from `iterable`, printing progress every `print_freq` steps.

        Tracks per-iteration time and data-loading time, estimates an ETA,
        and (when CUDA is available) reports peak GPU memory use.
        Requires `iterable` to support len().
        """
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        # Pad the iteration counter to the width of len(iterable).
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}",
                 "{meters}", "time: {time}", "data: {data}"]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent before the yield is attributed to data loading.
            data_time.update(time.time() - end)
            yield obj
            # Time since the previous iteration's end is the full step time.
            iter_time.update(time.time() - end)
            if i % print_freq == 0:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(f"{header} Total time: {total_time_str}")
||||
|
||||
|
||||
def cat_list(images, fill_value=0):
    """Stack variably-sized tensors into one batch, padding with fill_value.

    Each image is copied into the top-left corner (last two dims) of a slot
    large enough for the biggest image in the batch.
    """
    shapes = [img.shape for img in images]
    # Per-dimension maximum across the whole batch.
    max_size = tuple(max(dims) for dims in zip(*shapes))
    batched_imgs = images[0].new(len(images), *max_size).fill_(fill_value)
    for src, dst in zip(images, batched_imgs):
        h, w = src.shape[-2], src.shape[-1]
        dst[..., :h, :w].copy_(src)
    return batched_imgs
||||
|
||||
|
||||
def collate_fn(batch):
    """Collate (image, target) samples into padded batch tensors.

    Targets are padded with 255 so the extra pixels can be treated as an
    ignore label downstream.
    """
    images, targets = zip(*batch)
    batched_imgs = cat_list(images, fill_value=0)
    batched_targets = cat_list(targets, fill_value=255)
    return batched_imgs, batched_targets
||||
|
||||
|
||||
def mkdir(path):
    """Create `path` (including parents); ignore it already existing."""
    try:
        os.makedirs(path)
    except FileExistsError:
        # FileExistsError is the errno.EEXIST case -- already there is fine.
        pass
||||
|
||||
|
||||
def setup_for_distributed(is_master):
    """Silence `print` on non-master ranks.

    Replaces the builtin print with a wrapper that only emits output on the
    master process, unless the caller passes force=True.
    """
    import builtins as __builtin__

    original_print = __builtin__.print

    def gated_print(*args, **kwargs):
        # `force` is always consumed so it never reaches the real print.
        force = kwargs.pop("force", False)
        if force or is_master:
            original_print(*args, **kwargs)

    __builtin__.print = gated_print
||||
|
||||
|
||||
def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is usable and initialized."""
    return dist.is_available() and dist.is_initialized()
||||
|
||||
|
||||
def get_world_size():
    """Number of distributed processes (1 when not running distributed)."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
||||
|
||||
|
||||
def get_rank():
    """Rank of the current process (0 when not running distributed)."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
||||
|
||||
|
||||
def is_main_process():
    """True on rank 0 (and always true in non-distributed runs)."""
    rank = get_rank()
    return rank == 0
||||
|
||||
|
||||
def save_on_master(*args, **kwargs):
    """torch.save, but only on the main process to avoid duplicate writes."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
||||
|
||||
|
||||
def init_distributed_mode(args):
    """Initialize torch.distributed from launcher environment variables.

    Mutates `args` in place: sets rank, world_size, gpu, distributed and
    dist_backend. Falls back to single-process mode when no launcher
    environment variables are found and no rank was supplied.
    """
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        # Launched via torchrun / torch.distributed.launch.
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    # elif "SLURM_PROCID" in os.environ:
    #     args.rank = int(os.environ["SLURM_PROCID"])
    #     args.gpu = args.rank % torch.cuda.device_count()
    elif hasattr(args, "rank"):
        # Rank was provided by the caller; keep it as-is.
        pass
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    # NOTE(review): backend is hard-coded to NCCL, so this path assumes
    # one CUDA device per process -- confirm for CPU-only deployments.
    args.dist_backend = "nccl"
    print(
        f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
    )
    torch.distributed.barrier()
    # After this point only rank 0 keeps printing.
    setup_for_distributed(args.rank == 0)
||||
|
||||
|
||||
def reduce_across_processes(val):
    """Sum `val` over all distributed workers; returns a tensor either way."""
    if is_dist_avail_and_initialized():
        t = torch.tensor(val, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        return t
    # Single-process run: nothing to sync, but still convert to a tensor
    # for consistency with the distributed case.
    return torch.tensor(val)
||||
|
||||
|
||||
def create_lr_scheduler(optimizer,
                        num_step: int,
                        epochs: int,
                        warmup=True,
                        warmup_epochs=1,
                        warmup_factor=1e-3):
    """Build a LambdaLR with optional linear warmup followed by poly decay.

    The returned scheduler multiplies the base learning rate by a per-step
    factor: during warmup it ramps linearly from `warmup_factor` to 1;
    afterwards it decays from 1 towards 0 with the "poly" policy (power
    0.9) from the DeepLab v2 learning-rate schedule.
    Note: PyTorch invokes lr_scheduler.step() once before training starts.
    """
    assert num_step > 0 and epochs > 0
    # Mirror the original `is False` identity check exactly.
    effective_warmup_epochs = 0 if warmup is False else warmup_epochs
    warmup_steps = effective_warmup_epochs * num_step

    def multiplier(step):
        if warmup is True and step <= warmup_steps:
            # Linear ramp: factor goes warmup_factor -> 1 over warmup_steps.
            progress = float(step) / warmup_steps
            return warmup_factor * (1 - progress) + progress
        # Poly decay: factor goes 1 -> 0 over the remaining steps.
        decay_steps = (epochs - effective_warmup_epochs) * num_step
        return (1 - (step - warmup_steps) / decay_steps) ** 0.9

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=multiplier)
Loading…
Reference in New Issue
Block a user