"""Pascal VOC Semantic Segmentation Dataset."""
|
|
from PIL import Image, ImageOps, ImageFilter
|
|
import torchvision.transforms as transforms
|
|
import os
|
|
import torch
|
|
import numpy as np
|
|
from matplotlib import pyplot as plt
|
|
from PIL import Image
|
|
import cv2
|
|
# import gdal
|
|
from osgeo import gdal
|
|
import random
|
|
import torch.utils.data as data
|
|
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', '2000000000')
|
|
|
|
|
|
class SegmentationDataset(object):
    """Base class for segmentation datasets.

    Stores the dataset location and sizing options and provides transforms
    that keep a PIL image and its label mask aligned. Subclasses are expected
    to define ``NUM_CLASS``, which the ``num_class`` property reads.
    """

    def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
        """Record the dataset settings.

        Parameters
        ----------
        root : dataset location, stored unchanged.
        split : name of the split.
        mode : operating mode; falls back to ``split`` when ``None``.
        transform : image transform, stored unchanged.
        base_size : reference size for the random rescale in ``_sync_transform``.
        crop_size : side length of the square crop produced by the transforms.
        """
        super(SegmentationDataset, self).__init__()
        self.root, self.split = root, split
        self.transform = transform
        # An explicit mode wins; otherwise the split name doubles as the mode.
        self.mode = split if mode is None else mode
        self.base_size, self.crop_size = base_size, crop_size

    def _val_sync_transform(self, img, mask):
        """Deterministic transform: resize the short edge, then center-crop.

        The shorter edge of ``img`` is resized to ``crop_size`` (aspect ratio
        kept), a centered ``crop_size`` x ``crop_size`` window is cut from both
        inputs, and the results go through the final conversions.
        """
        target = self.crop_size
        width, height = img.size
        if width <= height:
            new_w = target
            new_h = int(1.0 * height * new_w / width)
        else:
            new_h = target
            new_w = int(1.0 * width * new_h / height)
        scaled = (new_w, new_h)
        img = img.resize(scaled, Image.BILINEAR)
        # Nearest-neighbour keeps mask values as valid labels (no blending).
        mask = mask.resize(scaled, Image.NEAREST)

        # Centered square window shared by image and mask.
        width, height = img.size
        left = int(round((width - target) / 2.))
        top = int(round((height - target) / 2.))
        window = (left, top, left + target, top + target)
        img = img.crop(window)
        mask = mask.crop(window)

        return self._img_transform(img), self._mask_transform(mask)

    def _sync_transform(self, img, mask):
        """Random augmentation applied identically to image and mask.

        Steps, in order: random horizontal flip, random rescale of the short
        edge to between 0.5x and 2.0x ``base_size``, zero padding on the
        right/bottom if the result is smaller than the crop, random
        ``crop_size`` square crop, random Gaussian blur (image only), and the
        final conversions.
        """
        # Horizontal flip with probability 0.5.
        if random.random() < 0.5:
            img, mask = (pic.transpose(Image.FLIP_LEFT_RIGHT) for pic in (img, mask))

        side = self.crop_size

        # Random scale of the short edge.
        lower, upper = int(self.base_size * 0.5), int(self.base_size * 2.0)
        short_edge = random.randint(lower, upper)
        width, height = img.size
        if width >= height:
            new_h = short_edge
            new_w = int(1.0 * width * new_h / height)
        else:
            new_w = short_edge
            new_h = int(1.0 * height * new_w / width)
        img = img.resize((new_w, new_h), Image.BILINEAR)
        mask = mask.resize((new_w, new_h), Image.NEAREST)

        # Pad right/bottom with zeros so a full crop always fits.
        if short_edge < side:
            pad_w = max(side - new_w, 0)
            pad_h = max(side - new_h, 0)
            border = (0, 0, pad_w, pad_h)
            img = ImageOps.expand(img, border=border, fill=0)
            mask = ImageOps.expand(mask, border=border, fill=0)

        # Random square crop, same window for both inputs.
        width, height = img.size
        left = random.randint(0, width - side)
        top = random.randint(0, height - side)
        window = (left, top, left + side, top + side)
        img = img.crop(window)
        mask = mask.crop(window)

        # Gaussian blur as in PSP, with probability 0.5 and a random radius in [0, 1).
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))

        return self._img_transform(img), self._mask_transform(mask)

    def _sync_transform_tif(self, img, mask):
        """Apply only the final conversions; no augmentation is performed."""
        return self._img_transform(img), self._mask_transform(mask)

    def _sync_transform_tif_geofeat(self, img, mask):
        """Apply only the final conversions; no augmentation is performed."""
        return self._img_transform(img), self._mask_transform(mask)

    def _val_sync_transform_tif(self, img, mask):
        """Apply only the final conversions; no resizing or cropping is performed."""
        return self._img_transform(img), self._mask_transform(mask)

    def _img_transform(self, img):
        """Convert the image to a NumPy array."""
        return np.array(img)

    def _mask_transform(self, mask):
        """Convert the mask to an int64 tensor, mapping values above 12 to 255."""
        labels = np.array(mask).astype(np.int32)
        out_of_range = labels > 12
        labels[out_of_range] = 255
        return torch.from_numpy(labels).long()

    @property
    def num_class(self):
        """Number of categories."""
        return self.NUM_CLASS

    @property
    def pred_offset(self):
        """Prediction offset; always 0 in this base class."""
        return 0
|
|
|
|
|
|
class VOCYJSSegmentation(SegmentationDataset):
    """Semantic segmentation dataset stored in a Pascal-VOC-like folder.

    Sample names are read from ``<root>/<split>.txt``. For every listed name
    the image is loaded from ``<root>/images_LS_jpg/<stem>.jpg`` and the label
    mask from ``<root>/masks_LS_png/<stem>.png``.

    Parameters
    ----------
    root : string
        Dataset folder holding the split files and both image folders.
        Default is '../VOC/'.
    split : string
        Name of the split file without its '.txt' extension,
        e.g. 'train' or 'val'.
    mode : string, optional
        Forwarded to the base class; falls back to ``split`` when ``None``.
    transform : callable, optional
        A function that transforms the image.
    **kwargs
        Forwarded to ``SegmentationDataset`` (e.g. ``base_size``,
        ``crop_size``).

    Examples
    --------
    >>> from torchvision import transforms
    >>> import torch.utils.data as data
    >>> # Transforms for Normalization
    >>> input_transform = transforms.Compose([
    ...     transforms.ToTensor(),
    ...     transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ... ])
    >>> # Create Dataset
    >>> trainset = VOCYJSSegmentation(split='train', transform=input_transform)
    >>> # Create Training Loader
    >>> train_data = data.DataLoader(
    ...     trainset, 4, shuffle=True,
    ...     num_workers=4)
    """

    NUM_CLASS = 13

    def __init__(self, root='../VOC/', split='train', mode=None, transform=None, **kwargs):
        super(VOCYJSSegmentation, self).__init__(
            root, split, mode, transform, **kwargs)
        txt_path = os.path.join(root, split + '.txt')
        self._mask_LS_dir = os.path.join(root, "masks_LS_png")
        self._image_LS_dir = os.path.join(root, "images_LS_jpg")
        self.image_list = read_text(txt_path)
        # The sample order is randomised once, at construction time.
        random.shuffle(self.image_list)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for the sample at ``index``.

        The image is a float32 array (or whatever ``transform`` returns when a
        transform is set); the mask is an int64 tensor.
        """
        # splitext drops only the final extension, so stems that themselves
        # contain dots (e.g. "tile.01.tif") are no longer truncated.
        stem = os.path.splitext(self.image_list[index])[0]

        with Image.open(os.path.join(self._image_LS_dir, stem + '.jpg')) as image_file:
            img_LS = np.array(image_file).astype(np.float32)
        with Image.open(os.path.join(self._mask_LS_dir, stem + '.png')) as mask_file:
            mask = np.array(mask_file).astype(np.int32)
        # NOTE(review): the mask is not passed through _mask_transform here, so
        # out-of-range labels are returned as stored - confirm this is intended.
        mask = torch.from_numpy(mask).long()

        # general resize, normalize and toTensor
        if self.transform is not None:
            img_LS = self.transform(img_LS)
        return img_LS, mask

    def __len__(self):
        """Number of samples listed in the split file."""
        return len(self.image_list)

    def _mask_transform(self, mask):
        """Convert the mask to an int64 tensor.

        Values above the highest class index (``NUM_CLASS - 1``, i.e. 12) are
        replaced by 255.
        """
        target = np.array(mask).astype('int32')
        target[target > self.NUM_CLASS - 1] = 255
        return torch.from_numpy(target).long()

    @property
    def classes(self):
        """Category names: the class indices '0' .. '12' as strings."""
        # Generated from NUM_CLASS; the hand-written tuple was missing a comma
        # ('7' '8' concatenated to '78'), which yielded only 12 names.
        return tuple(str(index) for index in range(self.NUM_CLASS))
|
|
|
|
|
|
def generator_list_of_imagepath(path):
    """Return the names of the entries in ``path`` whose extension is 'tif'.

    Only the text after the last dot is compared (case-sensitively), and the
    macOS '.DS_Store' entry is skipped. Names are returned in the order
    ``os.listdir`` yields them.
    """
    return [
        entry
        for entry in os.listdir(path)
        if entry != '.DS_Store' and entry.rpartition('.')[2] == 'tif'
    ]
|
|
|
|
|
|
def read_text(textfile):
    """Read ``textfile`` and return its lines with the trailing newline removed."""
    with open(textfile, "r") as handle:
        return [row.rstrip('\n') for row in handle]
|
|
|
|
|
|
def dataset_segmentation(textpath, imagepath, train_percent):
    """Randomly split the tif images of a folder into train and val lists.

    Parameters
    ----------
    textpath : folder in which 'train.txt' and 'val.txt' are written
        (existing files are overwritten).
    imagepath : folder whose '.tif' file names are split.
    train_percent : fraction of the images assigned to the training list;
        the training count is ``int(num_images * train_percent)``.

    Each output file holds one image file name per line. Names keep the order
    in which the folder listing returned them.
    """
    image_list = generator_list_of_imagepath(imagepath)
    num = len(image_list)
    train_num = int(num * train_percent)  # training set num
    # A set gives O(1) membership tests in the loop below.
    train_indices = set(random.sample(range(num), train_num))
    print("train set size", train_num)

    # The context manager closes both files even if a write fails.
    with open(os.path.join(textpath, 'train.txt'), 'w') as ftrain, \
            open(os.path.join(textpath, 'val.txt'), 'w') as fval:
        for index, image_name in enumerate(image_list):
            destination = ftrain if index in train_indices else fval
            destination.write(image_name + '\n')
|
|
|
|
|
|
if __name__ == '__main__':
    # path = r'C:\Users\51440\Desktop\WLKdata\googleEarth\train\images'
    # list=generator_list_of_imagepath(path)
    # print(list)

    # Split the dataset into training and validation lists.
    dataset_segmentation(
        textpath=r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset',
        imagepath=r'C:\Users\51440\Desktop\WLKdata\WLKdata_1111\WLKdataset\images_GE',
        train_percent=0.8,
    )

    # Display the various images (disabled).
    # img=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_GE\\322.tif'
    # img = gdal.Open(img).ReadAsArray().transpose(1,2,0)
    # cv2.imshow('img', img)
    # img = Image.fromarray (img,'RGB')
    # img.show()
    # img2=r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\images_LS\\322.tif'
    # img2 = gdal.Open(img2).ReadAsArray().transpose(1,2,0).astype(np.uint8)
    # img2 = cv2.resize(img2, (672, 672), interpolation=cv2.INTER_CUBIC)
    # img2 = Image.fromarray (img2,'RGB')
    # img2.show()
    # img3 = r'C:\\Users\\51440\\Desktop\\WLKdata\\WLKdata_1111\\train\\masks_LS\\322.tif'
    # img3 = gdal.Open(img3).ReadAsArray()
    # img3 = Image.fromarray (img3)
    # img3.show()

    # Tests for the dataset and dataloader.
    # Check whether the dataloader works (kept as an inert string literal).
    '''
    data_dir = r'C:/Users/51440/Desktop/WLKdata/WLKdata_1111/WLKdataset'
    input_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize([.485, .456, .406], [.229, .224, .225])])
    dataset_train = VOCYJSSegmentation(data_dir, 'train',mode='train',transform=input_transform, base_size=224, crop_size=224)
    dataset_val = VOCYJSSegmentation(data_dir, 'val', mode='val', transform=input_transform, base_size=224, crop_size=224)
    train_data = data.DataLoader(dataset_train, 4, shuffle=True, num_workers=4)
    test_data = data.DataLoader(dataset_val, 4, shuffle=True, num_workers=4)
    '''
|