# Imports grouped stdlib / third-party / project-local.
import os
import time
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import torch
import torch.utils.data as data
from torch import nn
from torchvision import transforms
from torchvision.models.segmentation import deeplabv3_resnet50, fcn_resnet50, lraspp_mobilenet_v3_large

from geotiff_utils import VOCYJSSegmentation
import utils as utils

# Third-party libraries emit noisy deprecation warnings during training;
# silence them so the per-epoch logs stay readable.
warnings.filterwarnings("ignore")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
    """Build and parse the command-line arguments for DeepLabV3 training.

    Returns:
        argparse.Namespace: parsed training configuration (data_path,
        num_classes, device, batch_size, epochs, lr, momentum,
        weight_decay, out_dir).
    """
    import argparse

    parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")

    parser.add_argument(
        "--data-path", default=r"E:\datasets\WLKdata_1111\WLKdataset",
        help="VOCdevkit root")
    parser.add_argument("--num-classes", default=13, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=50, type=int, metavar="N",
                        help="number of total epochs to train")
    parser.add_argument('--lr', default=0.005, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # Keep the legacy single-dash '-out-dir' spelling for backward
    # compatibility and add the conventional '--out-dir' alias; both map
    # to args.out_dir.
    parser.add_argument('-out-dir', '--out-dir', type=str,
                        default='DeeplabV3_LS',
                        help='suffix of the training output directory name')

    return parser.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DeeplabV3_LS(nn.Module):
    """DeepLabV3 (ResNet-50 backbone) wrapper that remaps the pretrained
    head's 21 output channels to the dataset's class count with a 1x1 conv."""

    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class
        # 1x1 conv that would squeeze a 7-channel stack down to the 3
        # channels the backbone expects; not used by forward() as written.
        self.conv7_3 = nn.Conv2d(7, 3, kernel_size=1, stride=1)
        # Final remap: 21 pretrained output channels -> n_class.
        self.conv_fc = nn.Conv2d(
            21, self.n_class, kernel_size=(1, 1), stride=(1, 1))
        self.seg = deeplabv3_resnet50(weights='DEFAULT')

    def forward(self, x_LS, x_SE):
        # x_SE is accepted for interface compatibility but ignored here.
        # x = torch.cat([x, x_dan], dim=1)
        # x = self.conv7_3(x)
        logits = self.seg(x_LS)["out"]
        return self.conv_fc(logits)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main(args):
    """Train DeeplabV3_LS on the GeoTIFF VOC-style dataset.

    Logs per-epoch confusion-matrix stats to
    ./train_output/<timestamp><out_dir>/log.txt and saves a full-model
    checkpoint after every epoch plus a final "last.pth".
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # ImageNet normalization to match the pretrained backbone.
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    data_kwargs = {'transform': input_transform,
                   'base_size': 224, 'crop_size': 224}

    # Build the GeoTIFF-backed train / validation datasets.
    train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
                                       **data_kwargs)
    val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
                                     **data_kwargs)
    num_workers = min(
        [os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
    train_loader = data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=True)
    # No shuffling for validation: the confusion matrix aggregates over the
    # whole set, so sample order cannot affect the reported metrics.
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=False)

    model = DeeplabV3_LS(n_class=args.num_classes)
    model.to(device)

    # 255 marks "ignore" pixels in the VOC-style masks.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )
    # Per-iteration schedule with warmup (stepped once per batch below).
    lr_scheduler = utils.create_lr_scheduler(
        optimizer, len(train_loader), args.epochs, warmup=True)

    date_time = datetime.now().strftime("%Y-%m-%d__%H-%M__")
    out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
    # parents=True also creates ./train_output on a fresh checkout;
    # exist_ok=True avoids racing a concurrent launch in the same minute.
    out_dir.mkdir(parents=True, exist_ok=True)

    start_time = time.time()
    best_acc = 0
    # 'with' guarantees the log file is closed even if training raises.
    with open(os.path.join(out_dir, "log.txt"), 'w') as f:
        for epoch in range(args.epochs):
            print(f"Epoch {epoch+1}\n-------------------------------")
            model.train()
            for idx, (image, image_dan, target) in enumerate(train_loader):
                image, image_dan, target = image.to(
                    device), image_dan.to(device), target.to(device)
                output = model(image, image_dan)

                loss = criterion(output, target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                lr_scheduler.step()

                if idx % 100 == 0:
                    print("[ {} / {} ] loss: {:.4f}, lr: {}".format(
                        idx, len(train_loader), loss.item(),
                        optimizer.param_groups[0]["lr"]))

            # Validation: accumulate one confusion matrix over the full set.
            model.eval()
            confmat = utils.ConfusionMatrix(args.num_classes)
            with torch.no_grad():
                for image, image_dan, target in val_loader:
                    image, image_dan, target = image.to(device), image_dan.to(
                        device), target.to(device)
                    output = model(image, image_dan)

                    confmat.update(target.flatten(), output.argmax(1).flatten())

            info, mIoU = confmat.get_info()
            print(info)

            f.write(f"Epoch {epoch+1}\n-------------------------------\n")
            f.write(info + "\n\n")
            f.flush()

            # Best-mIoU checkpointing, currently disabled in favor of
            # saving every epoch below:
            # if mIoU > best_acc:
            #     print("[Save model]")
            #     torch.save(model, os.path.join(out_dir, "best_mIoU.pth"))
            #     best_acc = mIoU
            torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))

        total_time = time.time() - start_time
        print("total time:", total_time)
        torch.save(model, os.path.join(out_dir, "last.pth"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Parse CLI arguments and hand them straight to the training loop.
    main(parse_args())
|