semantic-segmentation/train_LS/train_LS.py
2025-05-26 09:33:01 +08:00

157 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import time
from datetime import datetime
from pathlib import Path
import torch
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
from torchvision import transforms
import torch.utils.data as data
from torch import nn
import numpy as np
from geotiff_utils import VOCYJSSegmentation
import utils as utils
import warnings
warnings.filterwarnings("ignore")
def parse_args(argv=None):
    """Parse command-line arguments for DeeplabV3 training.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse falls back to ``sys.argv[1:]``, so existing callers are
            unaffected.

    Returns:
        argparse.Namespace with the training configuration.
    """
    import argparse
    parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
    parser.add_argument(
        "--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset",
        help="VOCdevkit root")
    parser.add_argument("--num-classes", default=13, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=50, type=int, metavar="N",
                        help="number of total epochs to train")
    parser.add_argument('--lr', default=0.005, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # '-out-dir' (single dash) is kept for backward compatibility with
    # existing invocations; the conventional '--out-dir' spelling is now
    # accepted as well. Both map to args.out_dir.
    parser.add_argument('-out-dir', '--out-dir', type=str,
                        default='DeeplabV3_LS')
    args = parser.parse_args(argv)
    return args
class DeeplabV3_LS(nn.Module):
    """DeeplabV3-ResNet50 wrapper remapping the pretrained head's 21 logits
    to ``n_class`` output channels via a 1x1 convolution.

    NOTE(review): ``x_SE`` is accepted but currently ignored, and
    ``conv7_3`` is never called — the commented-out lines in ``forward``
    suggest a disabled 7-channel fusion variant (ablation); confirm with
    the authors before removing either.
    """

    def __init__(self, n_class):
        super(DeeplabV3_LS, self).__init__()
        self.n_class = n_class
        # 1x1 conv meant to squeeze a 7-channel fused input down to the
        # 3 channels the pretrained backbone expects (currently unused).
        self.conv7_3 = nn.Conv2d(7, 3, kernel_size=1, stride=1)
        # Projects the pretrained head's 21 output channels onto n_class.
        self.conv_fc = nn.Conv2d(
            21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
        self.seg = deeplabv3_resnet50(weights='DEFAULT')

    def forward(self, x_LS, x_SE):
        # x = torch.cat([x, x_dan], dim=1)
        # x = self.conv7_3(x)
        features = self.seg(x_LS)["out"]
        logits = self.conv_fc(features)
        return logits
def _train_one_epoch(model, criterion, optimizer, lr_scheduler, loader,
                     device):
    """Run one training epoch, printing loss/lr every 100 batches."""
    model.train()
    for idx, (image, image_dan, target) in enumerate(loader):
        image = image.to(device)
        image_dan = image_dan.to(device)
        target = target.to(device)
        output = model(image, image_dan)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Scheduler is stepped per batch, matching its construction with
        # len(train_loader) steps per epoch.
        lr_scheduler.step()
        if idx % 100 == 0:
            print("[ {} / {} ] loss: {:.4f}, lr: {}".format(
                idx, len(loader), loss.item(),
                optimizer.param_groups[0]["lr"]))


def _evaluate(model, loader, device, num_classes):
    """Evaluate on the validation set; returns (info_string, mIoU)."""
    model.eval()
    confmat = utils.ConfusionMatrix(num_classes)
    with torch.no_grad():
        for image, image_dan, target in loader:
            image = image.to(device)
            image_dan = image_dan.to(device)
            target = target.to(device)
            output = model(image, image_dan)
            confmat.update(target.flatten(), output.argmax(1).flatten())
    return confmat.get_info()


def main(args):
    """Train DeeplabV3_LS on the WLK geotiff dataset.

    Side effects: creates ./train_output/<timestamp>__<out_dir>/ containing
    log.txt, one checkpoint per epoch (<epoch>.pth) and a final last.pth.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    data_kwargs = {'transform': input_transform,
                   'base_size': 224, 'crop_size': 224}
    # Build training and validation sets from geotiff data.
    train_dataset = VOCYJSSegmentation(root=args.data_path, split='train',
                                       mode='train', **data_kwargs)
    val_dataset = VOCYJSSegmentation(root=args.data_path, split='val',
                                     mode='val', **data_kwargs)
    # 0 workers when batch_size <= 1; otherwise capped by the CPU count.
    num_workers = min(
        os.cpu_count(), args.batch_size if args.batch_size > 1 else 0)
    train_loader = data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=True)
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=True)

    model = DeeplabV3_LS(n_class=args.num_classes)
    model.to(device)
    # 255 is the dataset's ignore label (void/unlabeled pixels).
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )
    lr_scheduler = utils.create_lr_scheduler(
        optimizer, len(train_loader), args.epochs, warmup=True)

    date_time = datetime.now().strftime("%Y-%m-%d__%H-%M__")
    out_dir = Path("./train_output") / (date_time + args.out_dir)
    # parents=True also creates ./train_output on the first run — the
    # original bare mkdir() crashed when the parent did not exist yet.
    out_dir.mkdir(parents=True, exist_ok=True)

    start_time = time.time()
    # 'with' guarantees the log file is closed (and buffers flushed) even
    # if training raises — the original handle was never closed.
    with open(os.path.join(out_dir, "log.txt"), 'w') as f:
        for epoch in range(args.epochs):
            print(f"Epoch {epoch+1}\n-------------------------------")
            _train_one_epoch(model, criterion, optimizer, lr_scheduler,
                             train_loader, device)
            info, mIoU = _evaluate(model, val_loader, device,
                                   args.num_classes)
            print(info)
            f.write(f"Epoch {epoch+1}\n-------------------------------\n")
            f.write(info + "\n\n")
            f.flush()
            # # Save the best-mIoU model
            # if mIoU > best_acc:
            #     print("[Save model]")
            #     torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
            #     best_acc = mIoU
            torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))
    total_time = time.time() - start_time
    print("total time:", total_time)
    torch.save(model, os.path.join(out_dir, "last.pth"))
if __name__ == '__main__':
    # Script entry point: parse CLI options and launch training.
    main(parse_args())