"""Training script: DeepLabV3-ResNet50 fine-tuned on a GeoTIFF VOC-style segmentation dataset."""
# Standard library
import os
import time
import warnings
from datetime import datetime
from pathlib import Path

# Third-party
import numpy as np
import torch
import torch.utils.data as data
from torch import nn
from torchvision import transforms
from torchvision.models.segmentation import (
    deeplabv3_resnet50,
    fcn_resnet50,
    lraspp_mobilenet_v3_large,
)

# Project-local
import utils
from geotiff_utils import VOCYJSSegmentation

# Third-party libraries emit noisy deprecation warnings during training; silence them.
warnings.filterwarnings("ignore")
def parse_args(argv=None):
    """Parse command-line arguments for DeepLabV3 training.

    Parameters
    ----------
    argv : list[str] | None
        Argument list to parse. ``None`` (the default) falls back to
        ``sys.argv[1:]``, so existing no-argument callers are unaffected.

    Returns
    -------
    argparse.Namespace
        Parsed arguments (``data_path``, ``num_classes``, ``device``,
        ``batch_size``, ``epochs``, ``lr``, ``momentum``, ``weight_decay``,
        ``out_dir``).
    """
    import argparse

    parser = argparse.ArgumentParser(description="pytorch deeplabv3 training")

    parser.add_argument(
        "--data-path", default=r"E:\RSdata\wlk_right_448", help="VOCdevkit root")
    parser.add_argument("--num-classes", default=7, type=int)
    parser.add_argument("--device", default="cuda", help="training device")
    parser.add_argument("--batch-size", default=4, type=int)
    parser.add_argument("--epochs", default=50, type=int, metavar="N",
                        help="number of total epochs to train")
    parser.add_argument('--lr', default=0.005, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                        metavar='W', help='weight decay (default: 1e-4)',
                        dest='weight_decay')
    # FIX: '-out-dir' was the only single-dash long option. Keep it for
    # backward compatibility and add the conventional '--out-dir' alias
    # (both map to dest 'out_dir').
    parser.add_argument('-out-dir', '--out-dir', type=str,
                        default='DeeplabV3_JL')

    return parser.parse_args(argv)
class DeeplabV3_JL_3(nn.Module):
    """DeepLabV3-ResNet50 with a trailing 1x1 convolution head.

    The pretrained torchvision model's head emits 21 channels (its default
    class count); a 1x1 convolution remaps those 21 channels to ``n_class``
    output channels so the network can be fine-tuned on a different label set.
    """

    def __init__(self, n_class):
        super().__init__()
        self.n_class = n_class

        # 1x1 conv remapping the 21-channel pretrained output to n_class.
        self.conv_fc = nn.Conv2d(
            21, self.n_class, kernel_size=(1, 1), stride=(1, 1))

        # Pretrained backbone + segmentation head.
        self.seg = deeplabv3_resnet50(weights='DEFAULT')

    def forward(self, x):
        """Return per-pixel class logits of shape (N, n_class, H, W)."""
        features = self.seg(x)["out"]
        return self.conv_fc(features)
def _train_one_epoch(model, criterion, optimizer, lr_scheduler, loader, device):
    """Run one training epoch over ``loader``, logging progress every 100 steps."""
    model.train()
    for idx, (image, target) in enumerate(loader):
        image, target = image.to(device), target.to(device)
        output = model(image)

        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Per-iteration schedule (see utils.create_lr_scheduler).
        lr_scheduler.step()

        if idx % 100 == 0:
            print("[ {} / {} ] loss: {:.4f}, lr: {}".format(
                idx, len(loader), loss.item(), optimizer.param_groups[0]["lr"]))


def _evaluate(model, loader, device, num_classes):
    """Evaluate ``model`` on ``loader``; return ``(info_string, mIoU)``."""
    model.eval()
    confmat = utils.ConfusionMatrix(num_classes)
    with torch.no_grad():
        for image, target in loader:
            image, target = image.to(device), target.to(device)
            output = model(image)
            confmat.update(target.flatten(), output.argmax(1).flatten())
    return confmat.get_info()


def main(args):
    """Train DeepLabV3 on the GeoTIFF VOC-style dataset described by ``args``.

    Builds train/val loaders, trains for ``args.epochs`` epochs with SGD and a
    warmup learning-rate schedule, evaluates every epoch, and writes per-epoch
    checkpoints plus a text log under
    ``./train_output/<timestamp>__<args.out_dir>/``.
    """
    # Fall back to CPU when CUDA is unavailable, regardless of --device.
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    # ImageNet normalization statistics (the backbone is ImageNet-pretrained).
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([.485, .456, .406],
                             [.229, .224, .225]),
    ])
    data_kwargs = {'transform': input_transform,
                   'base_size': 448, 'crop_size': 448}

    # Read the GeoTIFF data and build the training / validation sets.
    train_dataset = VOCYJSSegmentation(root=args.data_path, split='train', mode='train',
                                       **data_kwargs)
    val_dataset = VOCYJSSegmentation(root=args.data_path, split='val', mode='val',
                                     **data_kwargs)

    # 0 workers (in-process loading) when batch_size == 1, otherwise up to one
    # worker per sample in a batch, capped at the CPU count.
    num_workers = min(
        [os.cpu_count(), args.batch_size if args.batch_size > 1 else 0])
    train_loader = data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=True)
    # FIX: evaluation metrics are order-independent, so shuffling the
    # validation set was pointless work; use a deterministic pass instead.
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=num_workers,
        pin_memory=True, shuffle=False)

    model = DeeplabV3_JL_3(n_class=args.num_classes)
    model.to(device)

    # 255 marks unlabeled/ignore pixels in the masks — excluded from the loss.
    criterion = torch.nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    lr_scheduler = utils.create_lr_scheduler(
        optimizer, len(train_loader), args.epochs, warmup=True)

    now = datetime.now()
    date_time = now.strftime("%Y-%m-%d__%H-%M__")
    out_dir = Path(os.path.join("./train_output", date_time + args.out_dir))
    # FIX: bare mkdir() raised FileNotFoundError when ./train_output did not
    # exist yet; create missing parents and tolerate a pre-existing directory.
    out_dir.mkdir(parents=True, exist_ok=True)

    start_time = time.time()
    # FIX: the log file was opened but never closed; the context manager
    # guarantees it is flushed and closed even if training raises.
    with open(os.path.join(out_dir, "log.txt"), 'w') as log_file:
        for epoch in range(args.epochs):
            print(f"Epoch {epoch+1}\n-------------------------------")
            _train_one_epoch(model, criterion, optimizer, lr_scheduler,
                             train_loader, device)

            info, mIoU = _evaluate(model, val_loader, device, args.num_classes)
            print(info)

            log_file.write(f"Epoch {epoch+1}\n-------------------------------\n")
            log_file.write(info + "\n\n")
            log_file.flush()

            # Checkpoint every epoch (saves the whole model object, not just
            # the state dict — matches the original behavior).
            torch.save(model, os.path.join(out_dir, f"{epoch+1}.pth"))

    total_time = time.time() - start_time
    print("total time:", total_time)
    torch.save(model, os.path.join(out_dir, "last.pth"))
if __name__ == '__main__':
    # Script entry point: parse CLI arguments, then run training.
    main(parse_args())