"""Training script: DeeplabV3-ResNet50 segmentation on the WLK GeoTIFF dataset.

Builds train/val loaders from VOCYJSSegmentation, trains with SGD + a warmup
LR schedule from the project-local ``utils`` module, evaluates with a
confusion matrix each epoch, and checkpoints the full model every epoch.
"""
import os
import time
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.utils.data as data
from torch import nn
from torchvision import transforms
from torchvision.models.segmentation import (
    deeplabv3_resnet50,
    fcn_resnet50,
    lraspp_mobilenet_v3_large,
)

from geotiff_utils import VOCYJSSegmentation
import utils as utils

warnings.filterwarnings("ignore")


def parse_args():
    """Parse command-line arguments for training.

    Returns:
        argparse.Namespace with data path, model/optimizer hyperparameters
        and the output-directory suffix.
    """
    import argparse
    parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
    parser.add_argument(
        "--data-path",
        default=r"E:\datasets\WLKdata_1111\WLKdataset",
        help="VOCdevkit root")
    parser.add_argument("--num-classes", default=13, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=50, type=int, metavar="N",
                        help="number of total epochs to train")
    parser.add_argument('--lr', default=0.005, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # Keep the original single-dash spelling for backward compatibility and
    # add the conventional GNU-style double-dash alias; both map to out_dir.
    parser.add_argument('-out-dir', '--out-dir', type=str,
                        default='DeeplabV3_LS', dest='out_dir')
    args = parser.parse_args()
    return args


class DeeplabV3_LS(nn.Module):
    """DeeplabV3-ResNet50 wrapper remapping the 21 VOC logits to n_class.

    The pretrained head emits 21 channels (Pascal VOC); ``conv_fc`` projects
    them to ``n_class`` with a 1x1 convolution.  ``conv7_3`` and the second
    forward input ``x_SE`` are currently unused (the 7->3 fusion path is
    disabled) but are kept so existing checkpoints keep loading.
    """

    def __init__(self, n_class):
        super(DeeplabV3_LS, self).__init__()
        self.n_class = n_class
        # Unused at the moment: intended to fuse a 7-channel input down to 3.
        self.conv7_3 = nn.Conv2d(7, 3, kernel_size=1, stride=1)
        # 1x1 projection from the 21 VOC classes to the target class count.
        self.conv_fc = nn.Conv2d(
            21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
        self.seg = deeplabv3_resnet50(weights='DEFAULT')

    def forward(self, x_LS, x_SE):
        """Segment ``x_LS``; ``x_SE`` is accepted but currently ignored."""
        x = self.seg(x_LS)["out"]
        x = self.conv_fc(x)
        return x


def main(args):
    """Run the full training/evaluation loop described by ``args``."""
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    data_kwargs = {'transform': input_transform,
                   'base_size': 224, 'crop_size': 224}

    # Read the GeoTIFF data and build the training / validation sets.
    train_dataset = VOCYJSSegmentation(root=args.data_path, split='train',
                                       mode='train', **data_kwargs)
    val_dataset = VOCYJSSegmentation(root=args.data_path, split='val',
                                     mode='val', **data_kwargs)
    num_workers = min(
        [os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
    train_loader = data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=True)
    # Evaluation is order-independent (confusion-matrix accumulation), so
    # there is no reason to shuffle the validation set.
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=False)

    model = DeeplabV3_LS(n_class=args.num_classes)
    model.to(device)

    # 255 marks "ignore" pixels in the ground-truth masks.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )
    lr_scheduler = utils.create_lr_scheduler(
        optimizer, len(train_loader), args.epochs, warmup=True)

    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d__%H-%M__")
    out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
    # parents=True: also create ./train_output on first run;
    # exist_ok=True: tolerate a pre-existing directory.
    out_dir.mkdir(parents=True, exist_ok=True)

    start_time = time.time()
    # Context manager guarantees the log file is flushed and closed even if
    # training raises.
    with open(os.path.join(out_dir, "log.txt"), 'w') as f:
        for epoch in range(args.epochs):
            print(f"Epoch {epoch+1}\n-------------------------------")
            model.train()
            for idx, (image, image_dan, target) in enumerate(train_loader):
                image, image_dan, target = image.to(
                    device), image_dan.to(device), target.to(device)
                output = model(image, image_dan)
                loss = criterion(output, target)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # Per-iteration schedule (warmup), not per-epoch.
                lr_scheduler.step()
                if idx % 100 == 0:
                    print("[ {} / {} ] loss: {:.4f}, lr: {}".format(
                        idx, len(train_loader), loss.item(),
                        optimizer.param_groups[0]["lr"]))

            model.eval()
            confmat = utils.ConfusionMatrix(args.num_classes)
            with torch.no_grad():
                for image, image_dan, target in val_loader:
                    image, image_dan, target = image.to(device), image_dan.to(
                        device), target.to(device)
                    output = model(image, image_dan)
                    confmat.update(target.flatten(), output.argmax(1).flatten())
                info, mIoU = confmat.get_info()
                print(info)
                f.write(f"Epoch {epoch+1}\n-------------------------------\n")
                f.write(info + "\n\n")
                f.flush()

            # NOTE: saves the full module object (not a state_dict), matching
            # the original behavior so downstream loading code keeps working.
            torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))

        total_time = time.time() - start_time
        print("total time:", total_time)
        torch.save(model, os.path.join(out_dir, "last.pth"))


if __name__ == '__main__':
    args = parse_args()
    main(args)