semantic-segmentation/train_JL/train_JL.py

import os
import time
from datetime import datetime
from pathlib import Path
import torch
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large
from torchvision import transforms
import torch.utils.data as data
from torch import nn
import numpy as np
from geotiff_utils import VOCYJSSegmentation
import utils
import warnings
warnings.filterwarnings("ignore")


def parse_args():
    import argparse

    parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")
    parser.add_argument("--data-path", default=r"E:\datasets\wlk_right_448",
                        help="VOCdevkit root")
    parser.add_argument("--num-classes", default=7, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("--batch-size", default=4, type=int)
    parser.add_argument("--epochs", default=50, type=int, metavar="N",
                        help="number of total epochs to train")
    parser.add_argument('--lr', default=0.005, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    parser.add_argument('--out-dir', type=str, default='DeeplabV3_JL')
    args = parser.parse_args()
    return args
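
# Example invocation (the path and values below are placeholders taken from the
# argument defaults, not prescriptions):
#
#     python train_JL.py --data-path E:\datasets\wlk_right_448 --num-classes 7 \
#         --batch-size 4 --epochs 50 --out-dir DeeplabV3_JL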


class DeeplabV3_JL(nn.Module):
    def __init__(self, n_class):
        super(DeeplabV3_JL, self).__init__()
        self.n_class = n_class
        # 1x1 convolution that projects the 6-band input down to the 3 channels
        # expected by the ImageNet-pretrained backbone.
        self.conv6_3 = nn.Conv2d(6, 3, kernel_size=1, stride=1)
        # 1x1 convolution that remaps the 21-class output of the pretrained
        # DeepLabV3 head to the target number of classes.
        self.conv_fc = nn.Conv2d(
            21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
        self.seg = deeplabv3_resnet50(weights='DEFAULT')

    def forward(self, x):
        # x = torch.cat([x, x_dan], dim=1)
        x = self.conv6_3(x)
        x = self.seg(x)["out"]
        x = self.conv_fc(x)
        return x
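
# A quick shape check for the wrapper above (an illustrative sketch, not part of
# the training pipeline):
#
#     m = DeeplabV3_JL(n_class=7)
#     x = torch.randn(2, 6, 448, 448)   # two 6-band 448x448 tiles
#     y = m(x)                          # -> torch.Size([2, 7, 448, 448])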


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        # Per-band mean/std statistics of the 6-channel imagery
        transforms.Normalize([.535, .767, .732, .561, .494, .564],
                             [.0132, .0188, .0181, .0173, .0183, .0259]),
    ])
    data_kwargs = {'transform': input_transform,
                   'base_size': 448, 'crop_size': 448}
    # Build the training and validation sets from the GeoTIFF data
    train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
                                       **data_kwargs)
    val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
                                     **data_kwargs)
    # Cap the number of workers at the batch size; fall back to loading in the
    # main process when the batch size is 1.
    num_workers = min(
        [os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
    train_loader = data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=num_workers, pin_memory=True, shuffle=True)

    model = DeeplabV3_JL(n_class=args.num_classes)
    model.to(device)

    # Targets are integer label maps; pixels labelled 255 are ignored by the loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )
    lr_scheduler = utils.create_lr_scheduler(
        optimizer, len(train_loader), args.epochs, warmup=True)

    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d__%H-%M__")
    out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
    if not out_dir.exists():
        out_dir.mkdir(parents=True)
    f = open(os.path.join(out_dir, "log.txt"), 'w')

    start_time = time.time()
    best_acc = 0
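
    # utils.create_lr_scheduler is presumably a per-iteration schedule (it is
    # built from len(train_loader) and args.epochs), so it is stepped once per
    # batch below rather than once per epoch.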
    for epoch in range(args.epochs):
        print(f"Epoch {epoch+1}\n-------------------------------")
        model.train()
        for idx, (image, target) in enumerate(train_loader):
            image, target = image.to(device), target.to(device)
            output = model(image)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            if idx % 100 == 0:
                print("[ {} / {} ] loss: {:.4f}, lr: {}".format(
                    idx, len(train_loader), loss.item(), optimizer.param_groups[0]["lr"]))

        # Evaluate on the validation set and log the confusion-matrix metrics.
        model.eval()
        confmat = utils.ConfusionMatrix(args.num_classes)
        with torch.no_grad():
            for image, target in val_loader:
                image, target = image.to(device), target.to(device)
                output = model(image)
                confmat.update(target.flatten(), output.argmax(1).flatten())
        info, mIoU = confmat.get_info()
        print(info)
        f.write(f"Epoch {epoch+1}\n-------------------------------\n")
        f.write(info + "\n\n")
        f.flush()

        # # Save the model with the best accuracy (mIoU)
        # if mIoU > best_acc:
        #     print("[Save model]")
        #     torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
        #     best_acc = mIoU
        torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))

    total_time = time.time() - start_time
    print("total time:", total_time)
    torch.save(model, os.path.join(out_dir, "last.pth"))
    f.close()
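
# Note: torch.save(model, ...) above pickles the entire nn.Module, so reloading a
# checkpoint requires the DeeplabV3_JL class to be importable. A rough reload
# sketch (the path is illustrative; on newer PyTorch versions pass
# weights_only=False to torch.load for full-module checkpoints):
#
#     model = torch.load("train_output/<run>/last.pth", map_location="cpu")
#     model.eval()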

if __name__ == '__main__':
    args = parse_args()
    main(args)