semantic-segmentation/train_JL_jpg/geotiff_utils.py
2025-05-26 09:33:01 +08:00

280 lines
9.7 KiB
Python

"""Pascal VOC Semantic Segmentation Dataset."""
import os

# OpenCV picks up OPENCV_IO_MAX_IMAGE_PIXELS when its native library is loaded,
# so the limit has to be in the environment *before* `import cv2` runs;
# setting it afterwards has no effect. `setdefault` keeps any value the user
# has already exported.
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')

from PIL import Image, ImageOps, ImageFilter
import torchvision.transforms as transforms
import torch
import numpy as np
from matplotlib import pyplot as plt
import cv2
# import gdal
from osgeo import gdal
import random
import torch.utils.data as data
class SegmentationDataset:
    """Base class with the shared image/mask preprocessing for segmentation datasets.

    It stores the dataset configuration and provides synchronized transforms
    that apply the same geometric operations to an image and its label mask.
    Subclasses must define ``NUM_CLASS``, which ``num_class`` reads.
    """

    def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
        """Store the configuration; ``mode`` falls back to ``split`` when it is None."""
        super().__init__()
        self.root = root
        self.transform = transform
        self.split = split
        self.mode = split if mode is None else mode
        self.base_size = base_size
        self.crop_size = crop_size

    def _val_sync_transform(self, img, mask):
        """Resize the short edge to ``crop_size`` and center-crop a square.

        ``img`` and ``mask`` are expected to be PIL-style images (the code
        uses ``.size``, ``.resize`` and ``.crop``). Returns the pair after
        ``_img_transform`` / ``_mask_transform``.
        """
        side = self.crop_size
        width, height = img.size
        # Scale so that the shorter edge becomes `side`, keeping the aspect ratio.
        if width > height:
            new_h = side
            new_w = int(1.0 * width * new_h / height)
        else:
            new_w = side
            new_h = int(1.0 * height * new_w / width)
        scaled_size = (new_w, new_h)
        img = img.resize(scaled_size, Image.BILINEAR)
        # Nearest-neighbour keeps the mask values discrete.
        mask = mask.resize(scaled_size, Image.NEAREST)
        # Center crop.
        width, height = img.size
        left = int(round((width - side) / 2.))
        top = int(round((height - side) / 2.))
        box = (left, top, left + side, top + side)
        img = img.crop(box)
        mask = mask.crop(box)
        return self._img_transform(img), self._mask_transform(mask)

    def _sync_transform(self, img, mask):
        """Apply the random training augmentation to an image/mask pair.

        Steps, in order: random horizontal flip, random rescale of the short
        edge to between 0.5x and 2x ``base_size``, zero padding on the right
        and bottom up to ``crop_size`` when the result is too small, random
        ``crop_size`` square crop, and a random Gaussian blur of the image
        only. Returns the pair after ``_img_transform`` / ``_mask_transform``.
        """
        # Random mirror.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        target = self.crop_size

        # Random scale of the short edge.
        lower = int(self.base_size * 0.5)
        upper = int(self.base_size * 2.0)
        short_edge = random.randint(lower, upper)
        width, height = img.size
        if height > width:
            new_w = short_edge
            new_h = int(1.0 * height * new_w / width)
        else:
            new_h = short_edge
            new_w = int(1.0 * width * new_h / height)
        scaled_size = (new_w, new_h)
        img = img.resize(scaled_size, Image.BILINEAR)
        mask = mask.resize(scaled_size, Image.NEAREST)

        # Pad on the right/bottom so a full crop always fits.
        if short_edge < target:
            pad_h = target - new_h if new_h < target else 0
            pad_w = target - new_w if new_w < target else 0
            border = (0, 0, pad_w, pad_h)
            img = ImageOps.expand(img, border=border, fill=0)
            mask = ImageOps.expand(mask, border=border, fill=0)

        # Random crop of size `target` x `target`.
        width, height = img.size
        left = random.randint(0, width - target)
        top = random.randint(0, height - target)
        box = (left, top, left + target, top + target)
        img = img.crop(box)
        mask = mask.crop(box)

        # Gaussian blur as in PSP (image only).
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))

        return self._img_transform(img), self._mask_transform(mask)

    def _sync_transform_tif(self, img, mask):
        """Run only the final array/tensor conversion; no augmentation is applied."""
        return self._img_transform(img), self._mask_transform(mask)

    def _sync_transform_tif_geofeat(self, img, mask):
        """Run only the final array/tensor conversion; no augmentation is applied."""
        return self._img_transform(img), self._mask_transform(mask)

    def _val_sync_transform_tif(self, img, mask):
        """Run only the final array/tensor conversion; no augmentation is applied."""
        return self._img_transform(img), self._mask_transform(mask)

    def _img_transform(self, img):
        """Return ``img`` as a new NumPy array."""
        return np.array(img)

    def _mask_transform(self, mask):
        """Convert ``mask`` to an int64 tensor, mapping every value above 12 to 255."""
        labels = np.array(mask).astype(np.int32)
        out_of_range = labels > 12
        labels[out_of_range] = 255
        return torch.from_numpy(labels).long()

    @property
    def num_class(self):
        """Number of categories."""
        return self.NUM_CLASS

    @property
    def pred_offset(self):
        """Offset reported for predictions; always 0 here."""
        return 0
class VOCYJSSegmentation(SegmentationDataset):
    """Semantic segmentation dataset of JPEG images paired with PNG label masks.

    Sample names are read from ``<root>/<split>.txt``. For every entry the
    image is loaded from ``<root>/dataset_5m_jpg/<stem>.jpg`` and the mask from
    ``<root>/mask_png/<stem>.png``, where every ``img`` in the mask file name
    is replaced by ``mask``.

    Parameters
    ----------
    root : string
        Dataset folder containing the list files and the two data folders.
        Default is '../VOC/'
    split : string
        Name of the list file without ``.txt``, e.g. 'train' or 'val'
    mode : string, optional
        'train' or 'val'; falls back to ``split`` when None
    transform : callable, optional
        A function that transforms the image
    **kwargs
        Forwarded to ``SegmentationDataset`` (e.g. ``base_size``, ``crop_size``)

    Examples
    --------
    >>> from torchvision import transforms
    >>> import torch.utils.data as data
    >>> # Transforms for Normalization
    >>> input_transform = transforms.Compose([
    >>>     transforms.ToTensor(),
    >>>     transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    >>> ])
    >>> # Create Dataset
    >>> trainset = VOCYJSSegmentation(split='train', transform=input_transform)
    >>> # Create Training Loader
    >>> train_data = data.DataLoader(
    >>>     trainset, 4, shuffle=True,
    >>>     num_workers=4)
    """
    # Valid labels are 0 .. NUM_CLASS - 1; anything larger becomes 255.
    NUM_CLASS = 13

    def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
        """Resolve the data folders and load (then shuffle) the sample list."""
        super(VOCYJSSegmentation, self).__init__(
            root, split, mode, transform, **kwargs)
        _voc_root = root
        txt_path = os.path.join(root, split + '.txt')
        self._mask_LS_dir = os.path.join(_voc_root, 'mask_png')
        self._image_LS_dir = os.path.join(_voc_root, "dataset_5m_jpg")
        self.image_list = read_text(txt_path)
        # Shuffled in place using the global `random` state.
        random.shuffle(self.image_list)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, mask)``.

        The image is read as a float32 array and passed through
        ``self.transform`` when one is set; the mask is returned as an int64
        tensor. In 'train' and 'val' mode mask values of ``NUM_CLASS`` or more
        are mapped to 255; in any other mode the mask is returned unchanged.
        """
        # splitext drops only the final extension, so names that contain
        # additional dots (or a dotted directory part) keep their full stem.
        stem = os.path.splitext(self.image_list[index])[0]
        img_name = stem + '.jpg'
        mask_name = (stem + '.png').replace('img', 'mask')

        # NOTE(review): the image reaches `transform` as float32 with the raw
        # pixel values; torchvision's ToTensor appears to rescale only uint8
        # input -- confirm the normalisation constants match this range.
        img_LS = np.array(Image.open(os.path.join(
            self._image_LS_dir, img_name))).astype(np.float32)
        mask = np.array(Image.open(os.path.join(
            self._mask_LS_dir, mask_name))).astype(np.int32)

        # Synchronized transform. Only two modes are handled: train and val.
        # Both run the same conversion (no augmentation).
        if self.mode in ('train', 'val'):
            img_LS, mask = self._sync_transform_tif_geofeat(img_LS, mask)
        else:
            mask = torch.from_numpy(mask).long()

        # general resize, normalize and toTensor
        if self.transform is not None:
            img_LS = self.transform(img_LS)
        return img_LS, mask

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.image_list)

    def _mask_transform(self, mask):
        """Convert ``mask`` to an int64 tensor, mapping labels >= NUM_CLASS to 255."""
        target = np.array(mask).astype('int32')
        target[target >= self.NUM_CLASS] = 255
        return torch.from_numpy(target).long()

    @property
    def classes(self):
        """Category names: one string per label index, ``NUM_CLASS`` in total."""
        return tuple(str(label) for label in range(self.NUM_CLASS))
def generator_list_of_imagepath(path):
    """Return the sorted names of the ``.tif`` entries directly inside *path*.

    Only the entry names are returned, not full paths. Matching is
    case-sensitive and requires the literal ``.tif`` suffix, so an entry
    named just ``tif`` or ``.DS_Store`` is skipped. The result is sorted
    because ``os.listdir`` gives no ordering guarantee; a stable order keeps
    any seeded split built from this list reproducible.

    Raises whatever ``os.listdir`` raises when *path* cannot be listed.
    """
    return sorted(
        entry for entry in os.listdir(path) if entry.endswith('.tif')
    )
def read_text(textfile):
    """Read *textfile* and return its lines as a list of strings.

    Only the trailing newline character of each line is stripped; other
    whitespace and empty lines are kept as they are.
    """
    with open(textfile, "r") as handle:
        return [row.rstrip('\n') for row in handle]
def dataset_segmentation(textpath, imagepath, train_percent):
    """Randomly split the ``.tif`` images of *imagepath* into train and val lists.

    Parameters
    ----------
    textpath : string
        Folder in which ``train.txt`` and ``val.txt`` are (over)written.
    imagepath : string
        Folder whose image names are collected with
        ``generator_list_of_imagepath``.
    train_percent : float
        Fraction of the images assigned to the training list; the count is
        ``int(num_images * train_percent)``.

    Each output file holds one image name per line. The selection uses the
    global ``random`` state.
    """
    image_list = generator_list_of_imagepath(imagepath)
    num = len(image_list)
    train_num = int(num * train_percent)  # training set num
    # A set gives O(1) membership tests in the loop below.
    train_indices = set(random.sample(range(num), train_num))
    print("train set size", train_num)
    # `with` closes both files even if a write fails part-way through.
    with open(os.path.join(textpath, 'train.txt'), 'w') as ftrain, \
            open(os.path.join(textpath, 'val.txt'), 'w') as fval:
        for i, image_name in enumerate(image_list):
            destination = ftrain if i in train_indices else fval
            destination.write(image_name + '\n')
if __name__ == '__main__':
    # path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
    # list=generator_list_of_imagepath(path)
    # print(list)

    # Split the dataset: write train.txt / val.txt next to the image folder,
    # sending 80% of the images to the training list.
    dataset_segmentation(
        textpath=r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset',
        imagepath=r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE',
        train_percent=0.8,
    )

    # Display the different images (kept for manual debugging).
    # img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
    # img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
    # cv2.imshow('img', img)
    # img = Image.fromarray (img,'RGB')
    # img.show()
    # img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
    # img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
    # img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
    # img2 = Image.fromarray (img2,'RGB')
    # img2.show()
    # img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
    # img3 = gdal.Open(img3).ReadAsArray()
    # img3 = Image.fromarray (img3)
    # img3.show()

    # Dataset and DataLoader test: check whether the DataLoader can be used.
    # The snippet below is an inert string literal, not executed code.
    '''
    data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
    input_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
    dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
    dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
    train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
    test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
    '''