2025-05-14 20:45:42 +08:00
|
|
|
|
import numpy as np
|
|
|
|
|
from osgeo import gdal
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def read_tif(fileName):
    """Read a GDAL-readable raster into a float32 band stack.

    Args:
        fileName: Path of the raster to open.

    Returns:
        Tuple ``(im_data, im_width, im_height, im_bands, im_geotrans, im_proj)``:
        ``im_data`` is a float32 array of shape (bands, rows, cols) -- a
        single-band raster is promoted to (1, rows, cols); ``im_width`` and
        ``im_height`` are the column and row counts; ``im_bands`` is the band
        count; ``im_geotrans`` and ``im_proj`` are the values returned by
        ``GetGeoTransform()`` and ``GetProjection()``.

    Raises:
        IOError: If GDAL cannot open ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Without gdal.UseExceptions(), Open() reports failure by returning
        # None; fail here with a clear message instead of an AttributeError.
        raise IOError(f"GDAL could not open raster: {fileName}")

    im_width = dataset.RasterXSize  # number of columns
    im_height = dataset.RasterYSize  # number of rows
    im_bands = dataset.RasterCount  # number of bands
    im_data = dataset.ReadAsArray().astype(np.float32)  # pixel data
    if im_data.ndim == 2:
        # Single-band rasters come back as (rows, cols); add a band axis so
        # callers can always index im_data[band].
        im_data = im_data[np.newaxis, :]
    im_geotrans = dataset.GetGeoTransform()  # affine transform
    im_proj = dataset.GetProjection()  # projection
    dataset = None  # release the GDAL handle
    return im_data, im_width, im_height, im_bands, im_geotrans, im_proj
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def write_tif(im_data, im_width, im_height, path, im_geotrans, im_proj):
    """Write a band stack to ``path`` as a GeoTIFF.

    Args:
        im_data: Array of shape (bands, rows, cols), or (rows, cols) for a
            single band.
        im_width: Kept for backward compatibility; the real width is taken
            from ``im_data.shape``.
        im_height: Kept for backward compatibility; the real height is taken
            from ``im_data.shape``.
        path: Output file path.
        im_geotrans: Affine geotransform to attach; georeferencing is skipped
            when this or ``im_proj`` is None.
        im_proj: Projection to attach; georeferencing is skipped when this or
            ``im_geotrans`` is None.

    Raises:
        IOError: If the GTiff driver cannot create ``path``.
    """
    dtype_name = im_data.dtype.name
    if dtype_name in ('uint8', 'int8'):
        # NOTE: GDT_Byte is unsigned 8-bit; negative int8 values do not survive.
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        # Signed 16-bit data must not go to UInt16, or negatives wrap around.
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim == 2:
        # Promote a single band to (1, rows, cols) so im_data[i] below is a
        # 2-D band rather than a single row.
        im_data = im_data[np.newaxis, :, :]
    im_bands, im_height, im_width = im_data.shape

    # Create the output file
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, im_width, im_height, im_bands, datatype)
    if dataset is None:
        raise IOError(f"GDAL could not create raster: {path}")

    if im_geotrans is not None and im_proj is not None:
        dataset.SetGeoTransform(im_geotrans)  # affine transform
        dataset.SetProjection(im_proj)  # projection

    for band_index in range(im_bands):
        dataset.GetRasterBand(band_index + 1).WriteArray(im_data[band_index])
    dataset.FlushCache()
    del dataset  # dropping the handle closes and finalises the file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Input rasters: the multispectral image and its label mask (local paths).
fileName = 'E:/RSdata/wlk_tif/wlk_right/wlk_right_cj.tif'
mask_fileName = 'E:/RSdata/wlk_tif/wlk_right/label_right_cj.tif'

(im_data, im_width, im_height, im_bands,
 im_geotrans, im_proj) = read_tif(fileName)

(mask_im_data, mask_im_width, mask_im_height, mask_im_bands,
 mask_im_geotrans, mask_im_proj) = read_tif(mask_fileName)

# read_tif yields float32; keep the labels as 8-bit signed integers.
# NOTE(review): int8 only holds -128..127 -- confirm all label values fit.
mask_im_data = mask_im_data.astype(np.int8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# GeoTIFF normalisation: linearly map each band's [lower_percent,
# upper_percent] percentile range onto 0-255, clipping values outside it.
lower_percent = 2
upper_percent = 98

for i in range(im_bands):
    arr = im_data[i, :, :]
    lower, upper = np.percentile(arr, [lower_percent, upper_percent])
    if upper > lower:
        stretched = np.clip((arr - lower) / (upper - lower), 0, 1)
    else:
        # Degenerate band (both percentiles equal, or NaN): the plain formula
        # would divide by zero and yield NaN/inf. Send values above the shared
        # percentile to 1 and everything else to 0 instead.
        stretched = (arr > lower).astype(np.float32)
    # The uint8 cast truncates to whole numbers 0..255; they are stored back
    # into the float32 band stack.
    im_data[i] = (stretched * 255).astype(np.uint8)
|
2025-05-14 20:45:42 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-band mean and std of the whole stretched image on a 0-1 scale; the
# transform in train.py uses these values.
# Work on a scaled temporary instead of dividing im_data by 255 and
# multiplying back: that float round trip can leave values slightly below
# their original integers, which a later uint8 cast would truncate.
for i in range(im_bands):
    pixels = im_data[i, :, :].ravel() / 255
    print("波段{} mean: {:.4f}, std: {:.4f}".format(
        i, np.mean(pixels), np.std(pixels)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Cut the image and mask into size x size tiles. Tiles whose mask is entirely
# NoData (15) are skipped. Every other tile is saved twice: the image as an
# RGB JPEG built from bands 5/4/3, and the mask as a single-channel PNG.
import os

from PIL import Image

a = 0  # index of the next saved tile, used in the output file names
size = 448
img_dir = 'E:/RSdata/wlk_right_448/dataset_5m_jpg'
mask_dir = 'E:/RSdata/wlk_right_448/mask_png'
os.makedirs(img_dir, exist_ok=True)
os.makedirs(mask_dir, exist_ok=True)

# Only tile the area covered by both rasters so image and mask tiles are
# always full-size and aligned.
n_rows = min(mask_im_height, im_data.shape[1]) // size
n_cols = min(mask_im_width, im_data.shape[2]) // size

for i in range(n_rows):
    for j in range(n_cols):
        rows = slice(i * size, (i + 1) * size)
        cols = slice(j * size, (j + 1) * size)
        im_cut = im_data[:, rows, cols]
        mask_cut = mask_im_data[:, rows, cols]

        if np.all(mask_cut == 15):  # 15 is NoData
            print("Skip!!!")
            continue

        # Bands 5, 4, 3 (0-based indices 4, 3, 2) in that order, as (H, W, 3)
        rgb_cut = np.stack([im_cut[4], im_cut[3], im_cut[2]], axis=-1)
        # Round before the uint8 cast so float error cannot knock a value
        # down by one (e.g. 2.9999998 -> 2).
        rgb_cut = np.rint(np.clip(rgb_cut, 0, 255)).astype(np.uint8)
        Image.fromarray(rgb_cut).save(f'{img_dir}/img_{a}.jpg')

        # Save only the first mask band, as a single-channel 8-bit PNG
        mask_arr = np.clip(mask_cut[0], 0, 255).astype(np.uint8)
        Image.fromarray(mask_arr).save(f'{mask_dir}/mask_{a}.png')

        print(f'img_{a}.jpg and mask_{a}.png saved')
        a = a + 1
|