semantic-segmentation/data_preprocessing/cut_smalltif.py
2025-05-26 09:33:01 +08:00

112 lines
3.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
from osgeo import gdal
def read_tif(fileName):
    """Read a raster with GDAL and return its pixels plus georeferencing.

    Parameters
    ----------
    fileName : str
        Path of the raster to open.

    Returns
    -------
    tuple
        ``(im_data, im_width, im_height, im_bands, im_geotrans, im_proj)``:
        the pixel data as a float32 array with a leading band axis (a 2-D
        read is promoted to shape ``(1, rows, cols)``), the column count,
        the row count, the band count, the affine geotransform and the
        projection string.

    Raises
    ------
    RuntimeError
        If GDAL cannot open ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # gdal.Open reports failure by returning None (unless GDAL exceptions
        # are enabled); fail here with a clear message instead of an
        # AttributeError on the next attribute access.
        raise RuntimeError(f"GDAL could not open {fileName!r}")
    im_width = dataset.RasterXSize  # number of columns of the raster grid
    im_height = dataset.RasterYSize  # number of rows of the raster grid
    im_bands = dataset.RasterCount  # number of bands
    im_data = dataset.ReadAsArray().astype(np.float32)  # pixel data
    if im_data.ndim == 2:
        # Single-band read: add a band axis so callers can always index
        # im_data[band, row, col].
        im_data = im_data[np.newaxis, :]
    im_geotrans = dataset.GetGeoTransform()  # affine transform parameters
    im_proj = dataset.GetProjection()  # projection definition
    return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
    """Write an array to a GeoTIFF file.

    Parameters
    ----------
    im_data : numpy.ndarray
        Raster data, either ``(bands, rows, cols)`` or ``(rows, cols)``.
    im_width, im_height : int
        Kept for backward compatibility only; the output size is always
        taken from ``im_data.shape``.
    path : str
        Output file path.
    im_geotrans : tuple or None
        Affine geotransform. Written only when ``im_proj`` is also given.
    im_proj : str or None
        Projection definition. Written only when ``im_geotrans`` is also
        given.

    Raises
    ------
    RuntimeError
        If the GTiff driver cannot create ``path``.
    """
    dtype_name = im_data.dtype.name
    if dtype_name in ('uint8', 'int8'):
        # GDAL Byte is unsigned 8-bit; int8 keeps the original mapping.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed data needs Int16; UInt16 would corrupt negative values.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 2:
        # Promote a single band to (1, rows, cols) so im_data[i] below is a
        # whole band, not a single row.
        im_data = im_data[np.newaxis, :, :]
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise RuntimeError(f"GDAL could not create {path!r}")
    if im_geotrans is not None and im_proj is not None:
        dataset.SetGeoTransform(im_geotrans)  # affine transform parameters
        dataset.SetProjection(im_proj)  # projection
    for band_index in range(im_bands):
        # GDAL band numbers are 1-based.
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    # Releasing the dataset closes it so GDAL flushes the file to disk.
    del dataset
# Load the multiband image and its label raster (dataset-specific paths).
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
im_data, im_width, im_height, im_bands, im_geotrans, im_proj = read_tif(
    fileName)
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'
(mask_im_data, mask_im_width, mask_im_height, mask_im_bands,
 mask_im_geotrans, mask_im_proj) = read_tif(mask_fileName)
# Labels end up in 8-bit PNGs (0-255). Clip before casting: np.int8 would
# wrap any label above 127, and NumPy float->int casts of out-of-range
# values are undefined. Labels in 0-127 are unchanged.
mask_im_data = np.clip(mask_im_data, 0, 255).astype(np.uint8)
# GeoTIFF normalization: per band, linearly map the [2nd, 98th] percentile
# range onto 0-255, clip values outside it, and truncate to whole numbers
# (stored back into the float32 array).
lower_percent = 2
upper_percent = 98
for i in range(im_bands):
    arr = im_data[i, :, :]
    lower = np.percentile(arr, lower_percent)
    upper = np.percentile(arr, upper_percent)
    span = upper - lower
    if span > 0:
        stretched = np.clip((arr - lower) / span, 0, 1)
    else:
        # Near-constant band: dividing by zero would give NaN for pixels equal
        # to `lower`, and casting NaN to uint8 is undefined. Match what the
        # division does elsewhere (+inf clips to 1, -inf to 0) and send the
        # 0/0 pixels to 0.
        stretched = (arr > lower).astype(np.float32)
    im_data[i] = (stretched * 255).astype(np.uint8)
# Print the per-band mean and standard deviation of the stretched image on a
# 0-1 scale; these are the normalization statistics used by the transform in
# train.py.
# Each band is scaled on the fly rather than dividing and re-multiplying the
# whole array in place: the x/255*255 round trip is not exact in float32 and
# may leave pixels slightly below their integer values, which the later
# uint8 cast would truncate down by one. It also avoids two full-image copies.
for i in range(im_bands):
    pixels = im_data[i, :, :].ravel() / 255
    print("波段{} mean: {:.4f}, std: {:.4f}".format(
        i, np.mean(pixels), np.std(pixels)))
# Cut the image and label into non-overlapping size x size tiles. Tiles whose
# label is entirely 15 (NoData) are skipped. Each kept tile is saved as an RGB
# JPEG built from bands 5, 4, 3 (indices 4, 3, 2 -- the image needs at least
# five bands) plus a single-band PNG mask.
import os

from PIL import Image

a = 0
size = 448
img_dir = 'E:/RSdata/wlk_right_448/dataset_5m_jpg'
mask_dir = 'E:/RSdata/wlk_right_448/mask_png'
os.makedirs(img_dir, exist_ok=True)
os.makedirs(mask_dir, exist_ok=True)
# Only emit full tiles covered by BOTH rasters; iterating over the mask alone
# would yield undersized image tiles if the image were smaller.
n_rows = min(im_height, mask_im_height) // size
n_cols = min(im_width, mask_im_width) // size
for i in range(n_rows):
    for j in range(n_cols):
        rows = slice(i * size, (i + 1) * size)
        cols = slice(j * size, (j + 1) * size)
        im_cut = im_data[:, rows, cols]
        mask_cut = mask_im_data[:, rows, cols]
        if np.all(mask_cut == 15):  # 15 is NoData
            print("Skip!!!")
            continue
        # Bands 5, 4, 3 (note the order), stacked channel-last: (H, W, 3).
        rgb_cut = np.stack([im_cut[4], im_cut[3], im_cut[2]], axis=-1)
        # Round rather than truncate so float noise (e.g. 254.99998) cannot
        # lower a pixel by one.
        rgb_cut = np.rint(np.clip(rgb_cut, 0, 255)).astype(np.uint8)
        Image.fromarray(rgb_cut).save(f'{img_dir}/img_{a}.jpg')
        # The mask keeps only its first band, saved as an 8-bit PNG.
        mask_arr = np.clip(mask_cut[0], 0, 255).astype(np.uint8)
        Image.fromarray(mask_arr).save(f'{mask_dir}/mask_{a}.png')
        print(f'img_{a}.jpg and mask_{a}.png saved')
        a = a + 1