semantic-segmentation/data_preprocessing/cut_smalltif.bak

139 lines
5.0 KiB
Plaintext
Raw Normal View History

2025-05-26 09:33:01 +08:00
import numpy as np
from osgeo import gdal
def read_tif(fileName):
    """Read a raster with GDAL and return its pixels plus georeferencing.

    Args:
        fileName: Path of the raster file to open.

    Returns:
        A tuple ``(im_data, im_width, im_height, im_bands, im_geotrans,
        im_proj)`` where ``im_data`` is a float32 array that always has
        three dimensions (a leading axis of length 1 is added when GDAL
        returns a 2-D array), ``im_width`` / ``im_height`` / ``im_bands``
        are the dataset's RasterXSize / RasterYSize / RasterCount,
        ``im_geotrans`` is the affine geotransform and ``im_proj`` the
        projection string.

    Raises:
        RuntimeError: If GDAL cannot open the file or read its pixels.
    """
    dataset = gdal.Open(fileName)
    # gdal.Open returns None (instead of raising) when exceptions are not
    # enabled; fail here with a clear message rather than on the next line.
    if dataset is None:
        raise RuntimeError("GDAL could not open raster: {}".format(fileName))

    im_width = dataset.RasterXSize  # number of columns of the raster
    im_height = dataset.RasterYSize  # number of rows of the raster
    im_bands = dataset.RasterCount  # number of bands

    raw = dataset.ReadAsArray()
    if raw is None:
        raise RuntimeError("GDAL could not read pixels from: {}".format(fileName))
    im_data = raw.astype(np.float32)
    # Single-band rasters come back 2-D; add a leading band axis so callers
    # can always index im_data[band, row, col].
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]

    im_geotrans = dataset.GetGeoTransform()  # affine transform coefficients
    im_proj = dataset.GetProjection()  # projection information
    return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
    """Write an array to a GeoTIFF file with GDAL.

    Args:
        im_data: 2-D ``(rows, cols)`` or 3-D ``(bands, rows, cols)`` array.
        im_width: Ignored; the width is taken from ``im_data.shape``.
        im_height: Ignored; the height is taken from ``im_data.shape``.
        path: Output file path.
        im_geotrans: Affine geotransform to store, or None.
        im_proj: Projection string to store, or None.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If GDAL cannot create the output file.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        # Covers both int8 and uint8, as before.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Signed 16-bit data must not be stored as unsigned, otherwise
        # negative values wrap around.
        datatype = gdal.GDT_Int16
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # Promote 2-D input to a single-band 3-D array so that im_data[i] below
    # is always a full 2-D band rather than a single row.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, :, :]
    elif im_data.ndim != 3:
        raise ValueError(
            "im_data must be 2-D or 3-D, got {} dimensions".format(im_data.ndim))
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(
        path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        raise RuntimeError("GDAL could not create raster: {}".format(path))

    # Georeferencing is only written when both pieces are available.
    if im_geotrans is not None and im_proj is not None:
        dataset.SetGeoTransform(im_geotrans)  # affine transform coefficients
        dataset.SetProjection(im_proj)  # projection

    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    # Dropping the last reference closes the dataset and flushes it to disk.
    del dataset
# Paths of the two image rasters and of the label (mask) raster.
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'

# First image raster.
(im_data, im_width, im_height,
 im_bands, im_geotrans, im_proj) = read_tif(fileName)

# Second image raster.
(im_data2, im_width2, im_height2,
 im_bands2, im_geotrans2, im_proj2) = read_tif(fileName2)

# Label raster; read_tif returns float32, so cast the labels to 8-bit ints.
(mask_im_data, mask_im_width, mask_im_height,
 mask_im_bands, mask_im_geotrans, mask_im_proj) = read_tif(mask_fileName)
mask_im_data = np.int8(mask_im_data)
# GeoTIFF normalization: min-max scale every band of both images to 0-255.
def _stretch_bands_to_255(data, band_count):
    """Min-max scale the first ``band_count`` bands of ``data`` in place."""
    for band in range(band_count):
        arr = data[band, :, :]
        low = arr.min()
        high = arr.max()
        span = high - low
        if span == 0:
            # A constant band would otherwise compute 0 / 0 and fill the
            # whole band with NaN; map it to 0 instead.
            data[band] = 0
        else:
            data[band] = (arr - low) / span * 255


_stretch_bands_to_255(im_data, im_bands)
_stretch_bands_to_255(im_data2, im_bands2)
# Print the mean and standard deviation of every band of the large images,
# computed on values scaled to 0-1; the transform in train.py uses them.
# The scaling is applied to a per-band temporary only, so im_data / im_data2
# keep their exact 0-255 values (dividing the arrays by 255 and multiplying
# back would perturb them through float32 rounding).
for i in range(im_bands):
    pixels = (im_data[i, :, :] / 255).ravel()
    print("波段{} mean: {:.4f}, std: {:.4f}".format(
        i, np.mean(pixels), np.std(pixels)))

for i in range(im_bands2):
    pixels = (im_data2[i, :, :] / 255).ravel()
    print("波段{} mean: {:.4f}, std: {:.4f}".format(
        i, np.mean(pixels), np.std(pixels)))
# Cut the rasters into small square tiles.
def _tile_geotransform(geotrans, row_offset, col_offset):
    """Return ``geotrans`` with its origin moved to the given pixel offset.

    Uses the full affine relation, including the rotation/shear terms
    ``geotrans[2]`` and ``geotrans[4]`` (both 0 for north-up rasters).
    """
    shifted = list(geotrans)
    shifted[0] = (geotrans[0] + col_offset * geotrans[1]
                  + row_offset * geotrans[2])
    shifted[3] = (geotrans[3] + col_offset * geotrans[4]
                  + row_offset * geotrans[5])
    return tuple(shifted)


a = 0  # running index of written tiles, used in the output file names
size = 224  # tile edge length in pixels
# Only whole tiles are produced; the grid is derived from the mask size.
for i in range(mask_im_height // size):
    for j in range(mask_im_width // size):
        row_offset = i * size
        col_offset = j * size
        rows = slice(row_offset, row_offset + size)
        cols = slice(col_offset, col_offset + size)

        # The mask decides whether the tile is kept, for images and mask alike.
        mask_cut = mask_im_data[:, rows, cols]
        if np.all(mask_cut == 15):  # 15 is NoData
            print("Skip!!!")
            continue

        im_cut = im_data[:, rows, cols]
        im_cut2 = im_data2[:, rows, cols]

        # 5m
        out_geotrans = _tile_geotransform(im_geotrans, row_offset, col_offset)
        im_out = 'E:/RSdata/wlk_right_224_2/dataset_5m/geotiff' + str(a) + '.tif'
        write_tif(im_cut, size, size, im_out, out_geotrans, im_proj)

        # dan
        out_geotrans = _tile_geotransform(im_geotrans2, row_offset, col_offset)
        im_out = 'E:/RSdata/wlk_right_224_2/dataset_dan/geotiff' + str(a) + '.tif'
        write_tif(im_cut2, size, size, im_out, out_geotrans, im_proj2)

        # Save the mask tile.
        mask_out_geotrans = _tile_geotransform(
            mask_im_geotrans, row_offset, col_offset)
        mask_out = 'E:/RSdata/wlk_right_224_2/mask/geotiff' + str(a) + '.tif'
        write_tif(mask_cut, size, size, mask_out,
                  mask_out_geotrans, mask_im_proj)
        print(mask_out + 'Cut to complete')
        a = a + 1