semantic-segmentation/data_preprocessing/cut_smalltif_multi.py
weixin_46229132 5e0d438280 first commit
2025-05-14 20:45:42 +08:00

170 lines
6.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from osgeo import gdal
def read_tif(fileName):
    """Read a raster with GDAL and return its pixels and georeferencing.

    Args:
        fileName: Path of the raster file to open.

    Returns:
        A 6-tuple ``(im_data, im_width, im_height, im_bands, im_geotrans,
        im_proj)``:

        * ``im_data`` -- pixel values cast to float32; a 2-D read result
          gets a leading axis of length 1, so the array is always 3-D.
        * ``im_width`` / ``im_height`` -- ``RasterXSize`` / ``RasterYSize``
          (number of columns / rows).
        * ``im_bands`` -- ``RasterCount`` (number of bands).
        * ``im_geotrans`` -- the value of ``GetGeoTransform()``.
        * ``im_proj`` -- the value of ``GetProjection()``.

    Raises:
        RuntimeError: If GDAL cannot open the file or cannot read its pixels.
    """
    dataset = gdal.Open(fileName)
    # Without GDAL exceptions enabled, a failed open yields None; report it
    # here instead of failing later with an opaque AttributeError.
    if dataset is None:
        raise RuntimeError(
            'GDAL could not open raster: {}'.format(fileName))

    im_width = dataset.RasterXSize    # number of columns
    im_height = dataset.RasterYSize   # number of rows
    im_bands = dataset.RasterCount    # number of bands

    raw = dataset.ReadAsArray()
    if raw is None:
        raise RuntimeError(
            'GDAL could not read pixel data from: {}'.format(fileName))
    im_data = raw.astype(np.float32)
    if im_data.ndim == 2:
        # Give a 2-D result an explicit band axis so callers can always
        # index the data as [band, row, col].
        im_data = im_data[np.newaxis, :]

    im_geotrans = dataset.GetGeoTransform()  # affine geotransform
    im_proj = dataset.GetProjection()        # projection definition

    # Drop the dataset reference once everything has been copied out,
    # mirroring how write_tif releases its dataset.
    del dataset
    return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
    """Write an array to a GeoTIFF file, one raster band per leading index.

    Args:
        im_data: Array shaped ``(bands, rows, cols)`` or ``(rows, cols)``.
            A 2-D array is written as a single band.
        im_width: Ignored; the width is taken from ``im_data.shape``. Kept
            so existing callers keep working.
        im_height: Ignored; the height is taken from ``im_data.shape``. Kept
            so existing callers keep working.
        path: Destination file path.
        im_geotrans: Geotransform to store, or ``None``.
        im_proj: Projection to store, or ``None``.
            Georeferencing is written only when both ``im_geotrans`` and
            ``im_proj`` are given.

    Raises:
        ValueError: If ``im_data`` is neither 2-D nor 3-D.
        RuntimeError: If GDAL cannot create the output file.
    """
    dtype_name = im_data.dtype.name
    if dtype_name in ('int8', 'uint8'):
        # NOTE(review): int8 is stored as unsigned Byte, as before; negative
        # values cannot be represented in that type -- confirm inputs are
        # non-negative.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Signed 16-bit data needs the signed GDAL type; it was previously
        # written as UInt16.
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 2:
        # Add a band axis so im_data[i] below is a whole 2-D band rather
        # than a single row of the image.
        im_data = im_data[None, :, :]
    elif im_data.ndim != 3:
        raise ValueError(
            'im_data must be 2-D or 3-D, got {} dimensions'.format(
                im_data.ndim))
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError(
            'GDAL could not create raster: {}'.format(path))

    if im_geotrans is not None and im_proj is not None:
        dataset.SetGeoTransform(im_geotrans)  # store the affine transform
        dataset.SetProjection(im_proj)        # store the projection

    for band_index in range(im_bands):
        # GDAL band numbers start at 1.
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])

    # Drop the dataset reference so GDAL finalises the file.
    del dataset
# ---------------------------------------------------------------------------
# Load the rasters to be tiled. Paths are hard-coded for the author's machine.
# ---------------------------------------------------------------------------

# First image.
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right3.tif'
(im_data, im_width, im_height,
 im_bands, im_geotrans, im_proj) = read_tif(fileName)

# Second image.
fileName2 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cjdan.tif'
(im_data2, im_width2, im_height2,
 im_bands2, im_geotrans2, im_proj2) = read_tif(fileName2)

# Two further inputs are currently disabled:
# fileName3 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj10m.tif'
# im_data3, im_width3, im_height3, im_bands3, im_geotrans3, im_proj3 = read_tif(
#     fileName3)
# fileName4 = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj20m.tif'
# im_data4, im_width4, im_height4, im_bands4, im_geotrans4, im_proj4 = read_tif(
#     fileName4)

# Label mask; read_tif returns float32, so cast it to int8 afterwards.
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
(mask_im_data, mask_im_width, mask_im_height,
 mask_im_bands, mask_im_geotrans, mask_im_proj) = read_tif(mask_fileName)
mask_im_data = mask_im_data.astype(np.int8)
# Normalise the GeoTIFF bands.
def _stretch_bands_to_255(bands):
    """Min-max stretch every band of ``bands`` to the range 0-255, in place.

    Args:
        bands: Array indexed as ``[band, row, col]``; it is modified in
            place.

    A band whose minimum equals its maximum is set to 0 instead of being
    divided by a zero range, which would fill it with NaN.
    """
    for band in bands:
        lowest = band.min()
        span = band.max() - lowest
        if span == 0:
            band[...] = 0
        else:
            band[...] = (band - lowest) / span * 255


_stretch_bands_to_255(im_data)
_stretch_bands_to_255(im_data2)

# Disabled: compute the per-band mean and std of the full image (the
# original note says the transform in train.py uses these values).
# im_data2 = im_data2/255
# for i in range(im_bands2):
#     pixels = im_data2[i, :, :].ravel()
#     print("波段{} mean: {:.4f}, std: {:.4f}".format(
#         i, np.mean(pixels), np.std(pixels)))
# Cut the rasters into small tiles.
def _tile_geotransform(geotrans, row_off, col_off):
    """Return ``geotrans`` re-anchored at pixel (``row_off``, ``col_off``).

    Only the origin terms (indices 0 and 3) change. The rotation terms
    (indices 2 and 4) are included in the origin computation so that rotated
    rasters are handled as well; when they are zero they contribute nothing.
    """
    origin_x = geotrans[0] + col_off * geotrans[1] + row_off * geotrans[2]
    origin_y = geotrans[3] + col_off * geotrans[4] + row_off * geotrans[5]
    return (origin_x, geotrans[1], geotrans[2],
            origin_y, geotrans[4], geotrans[5])


a = 0        # running index used in the output file names
size = 224
# Every raster cut here uses a tile edge of size*4 pixels. The disabled
# 10 m / 20 m variants of this script used size*2 and size respectively.
tile = size * 4

# Count only full tiles that exist in all three rasters. The loop used to run
# im_height / size times while stepping by size*4, so it walked far past the
# raster edge and produced empty slices and ragged edge tiles.
n_tile_rows = min(im_height, im_height2, mask_im_height) // tile
n_tile_cols = min(im_width, im_width2, mask_im_width) // tile

for i in range(n_tile_rows):
    row_off = i * tile
    rows = slice(row_off, row_off + tile)
    for j in range(n_tile_cols):
        col_off = j * tile
        cols = slice(col_off, col_off + tile)

        im_cut = im_data[:, rows, cols]
        im_cut2 = im_data2[:, rows, cols]
        mask_cut = mask_im_data[:, rows, cols]

        # The mask decides whether a tile is kept, for the images and the
        # mask alike: a tile whose mask is 15 everywhere is skipped.
        # NOTE(review): 15 is presumably the background / no-data label --
        # confirm against the labelling scheme.
        if np.all(mask_cut == 15):
            print("False")
            continue

        # Tile of the first image.
        im_out = 'E:/RSdata/mask_5m/dataset_5m/geotiff' + str(a) + '.tif'
        write_tif(im_cut, tile, tile, im_out,
                  _tile_geotransform(im_geotrans, row_off, col_off),
                  im_proj)
        print(im_out + 'Cut to complete')

        # Tile of the second image.
        im_out = 'E:/RSdata/mask_5m/dataset_dan/geotiff' + str(a) + '.tif'
        write_tif(im_cut2, tile, tile, im_out,
                  _tile_geotransform(im_geotrans2, row_off, col_off),
                  im_proj2)
        print(im_out + 'Cut to complete')

        # Tile of the label mask.
        mask_out = 'E:/RSdata/mask_5m/mask/geotiff' + str(a) + '.tif'
        write_tif(mask_cut, tile, tile, mask_out,
                  _tile_geotransform(mask_im_geotrans, row_off, col_off),
                  mask_im_proj)
        print(mask_out + 'Cut to complete')

        a = a + 1