# 170 lines · 6.6 KiB · Python
import numpy as np
|
||
from osgeo import gdal
|
||
|
||
|
||
def read_tif(fileName):
    """Read a raster with GDAL and return its pixels plus georeferencing.

    Parameters
    ----------
    fileName : str
        Path of a raster GDAL can open (e.g. a GeoTIFF).

    Returns
    -------
    tuple
        ``(im_data, im_width, im_height, im_bands, im_geotrans, im_proj)``:

        * ``im_data`` -- float32 array shaped ``(bands, rows, cols)``; a
          single-band raster is promoted to ``(1, rows, cols)``.
        * ``im_width`` / ``im_height`` -- number of columns / rows.
        * ``im_bands`` -- number of bands.
        * ``im_geotrans`` -- the 6-element GDAL affine geotransform.
        * ``im_proj`` -- the projection string reported by GDAL.

    Raises
    ------
    OSError
        If GDAL cannot open ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Without gdal.UseExceptions(), gdal.Open signals failure by returning
        # None; fail here with the path instead of an AttributeError later.
        raise OSError("GDAL could not open raster: {}".format(fileName))

    im_width = dataset.RasterXSize   # number of columns
    im_height = dataset.RasterYSize  # number of rows
    im_bands = dataset.RasterCount   # number of bands
    im_data = dataset.ReadAsArray().astype(np.float32)
    if im_data.ndim == 2:
        # Single-band rasters are read as (rows, cols); add a band axis so
        # callers can always index im_data[band].
        im_data = im_data[np.newaxis, :, :]
    im_geotrans = dataset.GetGeoTransform()  # affine transform coefficients
    im_proj = dataset.GetProjection()        # projection definition

    # Dropping the last reference closes the GDAL dataset / file handle.
    dataset = None
    return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
|
||
|
||
|
||
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
    """Write an array to a GeoTIFF, one GDAL band per leading-axis slice.

    Parameters
    ----------
    im_data : numpy.ndarray
        Array shaped ``(bands, rows, cols)`` or ``(rows, cols)``. The output
        raster size is always taken from this array's shape.
    im_width, im_height : int
        Kept for backward compatibility only; they are ignored in favour of
        ``im_data.shape`` (as in the original implementation).
    path : str
        Output file path.
    im_geotrans : sequence of 6 floats or None
        GDAL affine geotransform.
    im_proj : str or None
        Projection string. Georeferencing is written only when both
        ``im_geotrans`` and ``im_proj`` are given.

    Raises
    ------
    OSError
        If the GTiff driver fails to create ``path``.
    """
    # Map the array dtype to a GDAL band type. int8 shares Byte with uint8
    # (unchanged from the original mapping; GDT_Int8 needs GDAL >= 3.7).
    # Signed int16 gets Int16 -- it used to be stored as UInt16, which cannot
    # hold negative values. Every other dtype is stored as Float32.
    datatype = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
    }.get(im_data.dtype.name, gdal.GDT_Float32)

    if im_data.ndim == 2:
        # Promote a single band to (1, rows, cols). Without this, im_data[0]
        # below would be the first pixel row (1-D), not the band.
        im_data = im_data[np.newaxis, :, :]
    im_bands, im_height, im_width = im_data.shape

    # Create the file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise OSError("GDAL could not create raster: {}".format(path))

    # `is not None` (not `!= None`) so a NumPy-array geotransform also works.
    if im_geotrans is not None and im_proj is not None:
        dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
        dataset.SetProjection(im_proj)        # projection

    for band_index in range(im_bands):
        # GDAL band numbers are 1-based.
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])

    # Dropping the last reference flushes and closes the GeoTIFF.
    dataset = None
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Load the inputs: the main image, a second image and the label mask.
# They are cut with identical pixel windows later on, so they are presumably
# on the same pixel grid -- confirm before reusing with other data.
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right3.tif'
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
# Optional 10 m / 20 m inputs (disabled):
# fileName3 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj10m.tif'
# fileName4 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj20m.tif'

(im_data, im_width, im_height,
 im_bands, im_geotrans, im_proj) = read_tif(fileName)

(im_data2, im_width2, im_height2,
 im_bands2, im_geotrans2, im_proj2) = read_tif(fileName2)

# (im_data3, im_width3, im_height3,
#  im_bands3, im_geotrans3, im_proj3) = read_tif(fileName3)
# (im_data4, im_width4, im_height4,
#  im_bands4, im_geotrans4, im_proj4) = read_tif(fileName4)

(mask_im_data, mask_im_width, mask_im_height,
 mask_im_bands, mask_im_geotrans, mask_im_proj) = read_tif(mask_fileName)

# read_tif returns float32; keep the labels as int8 (write_tif stores int8
# arrays as a Byte raster).
# NOTE(review): float -> int8 is only well defined for values in
# [-128, 127]; confirm the label values fit that range.
mask_im_data = mask_im_data.astype(np.int8)
|
||
|
||
# ---------------------------------------------------------------------------
# GeoTIFF normalisation: min-max stretch every band of both images to
# [0, 255] in place (the arrays stay float32).


def _stretch_bands_to_255(data):
    """Rescale every band of a (bands, rows, cols) array to [0, 255] in place.

    Each band is mapped linearly so that its minimum becomes 0 and its maximum
    becomes 255. A constant band (max == min) is set to 0; the plain formula
    would compute 0/0 there and fill the band with NaN.
    """
    for band in data:  # each `band` is a view, so assignments write into `data`
        low = band.min()
        high = band.max()
        if high == low:
            band[...] = 0
        else:
            band[...] = (band - low) / (high - low) * 255
    # NOTE(review): any NaN in a band makes min/max NaN and the band becomes
    # all-NaN (as before); switch to np.nanmin/np.nanmax if inputs carry NaN
    # nodata.


_stretch_bands_to_255(im_data)
_stretch_bands_to_255(im_data2)


# Per-band mean and standard deviation of the large second image, used by the
# transform in train.py. Uncomment to print them. Note: rebinding im_data2
# here would also change the tiles written afterwards.
# im_data2 = im_data2/255
# for i in range(im_bands2):
#     pixels = im_data2[i, :, :].ravel()
#     print("波段{} mean: {:.4f}, std: {:.4f}".format(
#         i, np.mean(pixels), np.std(pixels)))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Cut the two images and the mask into aligned square tiles.
#
# `size` is the tile edge at the 20 m reference resolution. The inputs here
# are presumably 5 m (see the output folder names), so each tile is size * 4
# pixels and covers the same ground extent as a size-pixel 20 m tile.
size = 224
tile = size * 4

# Number of full tiles that fit inside all three arrays. The old bounds
# divided by `size` instead of `tile`, giving ~16x as many iterations: most
# windows were empty, and those straddling the right/bottom edge were written
# as undersized tiles even though write_tif was told the size was `tile`.
n_rows = min(im_data.shape[1], im_data2.shape[1], mask_im_data.shape[1]) // tile
n_cols = min(im_data.shape[2], im_data2.shape[2], mask_im_data.shape[2]) // tile


def _tile_geotransform(geotrans, row_off, col_off):
    """Return the GDAL geotransform of a window whose top-left pixel is (row_off, col_off).

    Applies the full affine model X = gt0 + col*gt1 + row*gt2 and
    Y = gt3 + col*gt4 + row*gt5, so rotated rasters are handled too. For
    north-up rasters (gt2 == gt4 == 0), only the origin moves.
    """
    origin_x = geotrans[0] + col_off * geotrans[1] + row_off * geotrans[2]
    origin_y = geotrans[3] + col_off * geotrans[4] + row_off * geotrans[5]
    return (origin_x, geotrans[1], geotrans[2],
            origin_y, geotrans[4], geotrans[5])


a = 0  # running index used in every output file name
for i in range(n_rows):
    for j in range(n_cols):
        row_off, col_off = i * tile, j * tile
        rows = slice(row_off, row_off + tile)
        cols = slice(col_off, col_off + tile)
        im_cut = im_data[:, rows, cols]
        im_cut2 = im_data2[:, rows, cols]
        mask_cut = mask_im_data[:, rows, cols]

        # The mask decides for all inputs: tiles whose mask is entirely 15
        # (presumably background / unlabelled -- confirm) are skipped for the
        # images and the mask alike, keeping the output folders index-aligned.
        if np.all(mask_cut == 15):
            print("False")
            continue

        im_out = 'E:/RSdata/mask_5m/dataset_5m/geotiff' + str(a) + '.tif'
        write_tif(im_cut, tile, tile, im_out,
                  _tile_geotransform(im_geotrans, row_off, col_off), im_proj)
        print(im_out + 'Cut to complete')

        im_out = 'E:/RSdata/mask_5m/dataset_dan/geotiff' + str(a) + '.tif'
        write_tif(im_cut2, tile, tile, im_out,
                  _tile_geotransform(im_geotrans2, row_off, col_off), im_proj2)
        print(im_out + 'Cut to complete')

        # 10 m / 20 m variants (disabled; they need im_data3 / im_data4
        # loaded). The same ground window is tile // 2 px at 10 m and
        # tile // 4 px at 20 m, e.g. for 10 m:
        #   t3 = tile // 2
        #   im_cut3 = im_data3[:, i * t3:(i + 1) * t3, j * t3:(j + 1) * t3]
        #   write_tif(im_cut3, t3, t3,
        #             'E:/mask_20m/all/dataset_10m/geotiff' + str(a) + '.tif',
        #             _tile_geotransform(im_geotrans3, i * t3, j * t3), im_proj3)
        # and likewise for 20 m with t4 = tile // 4, im_data4, im_geotrans4,
        # im_proj4 and 'E:/mask_20m/all/dataset_20m/geotiff' + str(a) + '.tif'.

        mask_out = 'E:/RSdata/mask_5m/mask/geotiff' + str(a) + '.tif'
        write_tif(mask_cut, tile, tile, mask_out,
                  _tile_geotransform(mask_im_geotrans, row_off, col_off),
                  mask_im_proj)
        print(mask_out + 'Cut to complete')

        # NOTE(review): the original indentation was lost; the counter is
        # advanced only for written tiles, so file indices are contiguous.
        a += 1
|