357 lines
12 KiB
Python
357 lines
12 KiB
Python
# ========================================================
|
|
# Media and Cognition
|
|
# Homework 2 Convolutional Neural Network
|
|
# train.py - Train traffic sign classification model
|
|
# Student ID:
|
|
# Name:
|
|
# Tsinghua University
|
|
# (C) Copyright 2024
|
|
# ========================================================
|
|
|
|
# ==== Part 1: import libs
|
|
import argparse # argparse is used to conveniently set our configurations
|
|
import glob
|
|
import os
|
|
import random
|
|
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
|
|
# Import our data loader function and Classifier model defined in other files
|
|
from datasets import get_data_loader
|
|
from networks import Classifier
|
|
|
|
|
|
# ==== Part 2: training and validation
|
|
def train(
|
|
data_root,
|
|
augment,
|
|
in_channels,
|
|
num_classes,
|
|
batch_norm,
|
|
dropout,
|
|
stn,
|
|
n_epochs,
|
|
batch_size,
|
|
lr,
|
|
momentum,
|
|
weight_decay,
|
|
optim_type,
|
|
ckpt_path,
|
|
max_ckpt_save_num,
|
|
ckpt_save_interval,
|
|
val_interval,
|
|
resume,
|
|
device="cpu",
|
|
):
|
|
"""
|
|
The main training procedure
|
|
----------------------------
|
|
:param data_root: path to the root directory of dataset
|
|
:param augment: whether to use data augmentation
|
|
:param in_channels: channel number of image
|
|
:param num_classes: number of classes, in this task it is 26 English letters
|
|
:param batch_norm: whether to use batch normalization in convolutional layers and linear layers
|
|
:param dropout: dropout ratio of dropout layer which ranges from 0 to 1
|
|
:param stn: whether to use spatial transformer network
|
|
:param n_epochs: number of training epochs
|
|
:param batch_size: batch size of training
|
|
:param lr: learning rate
|
|
:param momentum: only used if optim_type == 'sgd'
|
|
:param weight_decay: the factor of L2 penalty on network weights
|
|
:param optim_type: optimizer, which can be set as 'sgd', 'adagrad', 'rmsprop', 'adam', or 'adadelta'
|
|
:param ckpt_path: path to save checkpoint models
|
|
:param max_ckpt_save_num: maximum number of saving checkpoint models
|
|
:param ckpt_save_interval: intervals of saving checkpoint models, e.g., if ckpt_save_interval = 2, then save checkpoint models every 2 epochs
|
|
:param val_interval: intervals of validation, e.g., if val_interval = 5, then do validation after each 5 training epochs
|
|
:param resume: path to resume model
|
|
:param device: 'cpu' or 'cuda', we can use 'cpu' for our homework if GPU with cuda support is not available
|
|
"""
|
|
|
|
# construct training and validation data loader
|
|
train_loader = get_data_loader(
|
|
data_root,
|
|
"train",
|
|
image_size=(32, 32),
|
|
batch_size=batch_size,
|
|
num_workers=2,
|
|
augment=augment,
|
|
)
|
|
val_loader = get_data_loader(
|
|
data_root, "val", image_size=(32, 32), batch_size=batch_size, num_workers=2
|
|
)
|
|
|
|
model = Classifier(in_channels, num_classes, batch_norm, stn, dropout_prob=dropout)
|
|
|
|
# put the model on CPU or GPU
|
|
model = model.to(device)
|
|
|
|
# define loss function and optimizer
|
|
loss_func = nn.CrossEntropyLoss()
|
|
|
|
if optim_type == "sgd":
|
|
optimizer = optim.SGD(
|
|
model.parameters(), lr, momentum=momentum, weight_decay=weight_decay
|
|
)
|
|
elif optim_type == "adagrad":
|
|
optimizer = optim.Adagrad(model.parameters(), lr, weight_decay=weight_decay)
|
|
elif optim_type == "rmsprop":
|
|
optimizer = optim.RMSprop(model.parameters(), lr, weight_decay=weight_decay)
|
|
elif optim_type == "adam":
|
|
optimizer = optim.Adam(model.parameters(), lr, weight_decay=weight_decay)
|
|
elif optim_type == "adadelta":
|
|
optimizer = optim.Adadelta(model.parameters(), lr, weight_decay=weight_decay)
|
|
else:
|
|
print(
|
|
"[Error] optim_type should be one of sgd, adagrad, rmsprop, adam, or adadelta"
|
|
)
|
|
raise NotImplementedError
|
|
|
|
if resume is not None:
|
|
print(f"[Info] resuming model from {resume} ...")
|
|
checkpoint = torch.load(resume)
|
|
model.load_state_dict(checkpoint["model_state"])
|
|
optimizer.load_state_dict(checkpoint["optimizer_state"])
|
|
|
|
# training
|
|
# to save loss of each training epoch in a python "list" data structure
|
|
losses = []
|
|
# to save accuracy on validation set of each training epoch in a python "list" data structure
|
|
accuracy_list = []
|
|
val_epochs = []
|
|
|
|
print("training...")
|
|
for epoch in range(n_epochs):
|
|
# set the model in training mode
|
|
model.train()
|
|
|
|
# to save total loss in one epoch
|
|
total_loss = 0.0
|
|
|
|
for step, (input, label) in enumerate(train_loader): # get a batch of data
|
|
# set data type and device
|
|
input, label = (
|
|
input.type(torch.float).to(device),
|
|
label.type(torch.long).to(device),
|
|
)
|
|
|
|
# clear gradients in the optimizer
|
|
optimizer.zero_grad()
|
|
|
|
# run the model which is the forward process
|
|
out = model(input)
|
|
|
|
# compute the CrossEntropy loss, and call backward propagation function
|
|
loss = loss_func(out, label)
|
|
loss.backward()
|
|
|
|
# update parameters of the model
|
|
optimizer.step()
|
|
|
|
# sum up of total loss, loss.item() return the value of the tensor as a standard python number
|
|
# this operation is not differentiable
|
|
total_loss += loss.item()
|
|
|
|
# average of the total loss for iterations
|
|
avg_loss = total_loss / len(train_loader)
|
|
losses.append(avg_loss)
|
|
|
|
# evaluate model on validation set
|
|
if (epoch + 1) % val_interval == 0:
|
|
val_accuracy = eval_one_epoch(model, val_loader, device)
|
|
accuracy_list.append(val_accuracy)
|
|
val_epochs.append(epoch)
|
|
print(
|
|
"Epoch {:02d}: loss = {:.3f}, accuracy on validation set = {:.3f}".format(
|
|
epoch + 1, avg_loss, val_accuracy
|
|
)
|
|
)
|
|
|
|
if (epoch + 1) % ckpt_save_interval == 0:
|
|
# get info of all saved checkpoints
|
|
ckpt_list = glob.glob(os.path.join(ckpt_path, "ckpt_epoch_*.pth"))
|
|
# sort checkpoints by saving time
|
|
ckpt_list.sort(key=os.path.getmtime)
|
|
# remove surplus ckpt file if the number is larger than max_ckpt_save_num
|
|
if len(ckpt_list) >= max_ckpt_save_num:
|
|
for cur_file_idx in range(0, len(ckpt_list) - max_ckpt_save_num + 1):
|
|
os.remove(ckpt_list[cur_file_idx])
|
|
|
|
# save model parameters in a file
|
|
ckpt_name = os.path.join(ckpt_path, "ckpt_epoch_%d.pth" % (epoch + 1))
|
|
save_dict = {
|
|
"model_state": model.state_dict(),
|
|
"optimizer_state": optimizer.state_dict(),
|
|
"configs": {
|
|
"in_channels": in_channels,
|
|
"num_classes": num_classes,
|
|
"use_batch_norm": batch_norm,
|
|
"use_stn": stn,
|
|
"dropout_prob": dropout,
|
|
},
|
|
}
|
|
|
|
torch.save(save_dict, ckpt_name)
|
|
print("Model saved in {}\n".format(ckpt_name))
|
|
|
|
plot(losses, accuracy_list, val_epochs, ckpt_path)
|
|
|
|
|
|
def eval_one_epoch(model, val_loader, device):
|
|
"""
|
|
Evaluate model performance.
|
|
--------------------------
|
|
:param model: model
|
|
:param val_loader: validation dataloader
|
|
:param device: 'cpu' or 'cuda'
|
|
:return accuracy: performance of model
|
|
"""
|
|
|
|
# enter the evaluation mode
|
|
model.eval()
|
|
correct = 0 # number of images that are correctly classified
|
|
n_samples = 0
|
|
with torch.no_grad(): # we do not need to compute gradients during validation
|
|
for input, label in val_loader:
|
|
# set data type and device
|
|
input, label = (
|
|
input.type(torch.float).to(device),
|
|
label.type(torch.long).to(device),
|
|
)
|
|
# get the prediction result
|
|
pred = model(input)
|
|
pred = torch.argmax(pred, dim=-1)
|
|
correct += torch.sum(pred == label).item()
|
|
n_samples += len(label)
|
|
|
|
# calculate accuracy
|
|
accuracy = correct / n_samples
|
|
|
|
return accuracy
|
|
|
|
|
|
def plot(losses, accuracy_list, val_epochs, ckpt_path):
|
|
"""
|
|
Draw loss and accuracy curve
|
|
------------------
|
|
:param losses: a list with loss of each training epoch
|
|
:param accuracy_list: a list with accuracy on validation set of each training epoch
|
|
"""
|
|
|
|
# create a plot
|
|
f, ax1 = plt.subplots()
|
|
|
|
# draw loss
|
|
ax1.plot(val_epochs, losses)
|
|
ax2 = ax1.twinx()
|
|
ax2.plot(val_epochs, accuracy_list, "r")
|
|
|
|
# set labels
|
|
ax1.set_xlabel("training epoch")
|
|
ax1.set_ylabel("loss")
|
|
# ax2.set_ylim([0, 1])
|
|
ax2.set_ylabel("accuracy")
|
|
|
|
# show the image
|
|
plt.savefig(os.path.join(ckpt_path, "loss_and_acc.jpg"), dpi=300)
|
|
plt.show()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
# set random seed for reproducibility
|
|
seed = 2024
|
|
random.seed(seed)
|
|
np.random.seed(seed)
|
|
torch.manual_seed(seed)
|
|
torch.cuda.manual_seed(seed)
|
|
torch.cuda.manual_seed_all(seed)
|
|
torch.backends.cudnn.deterministic = True
|
|
|
|
# set configurations of the model and training process
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument(
|
|
"--data_root",
|
|
type=str,
|
|
default="data",
|
|
help="file list of training image paths and labels",
|
|
)
|
|
parser.add_argument(
|
|
"--augment", action="store_true", help="whether to use data augmentation"
|
|
)
|
|
parser.add_argument(
|
|
"--epoch", type=int, default=15, help="number of training epochs"
|
|
)
|
|
parser.add_argument("--batchsize", type=int, default=32, help="training batch size")
|
|
parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
|
|
parser.add_argument(
|
|
"--momentum", type=float, default=0.9, help="momentum of SGD optimizer"
|
|
)
|
|
parser.add_argument(
|
|
"--weight_decay",
|
|
type=float,
|
|
default=0,
|
|
help="the factor of L2 penalty on network weights",
|
|
)
|
|
parser.add_argument(
|
|
"--optim_type",
|
|
type=str,
|
|
default="adam",
|
|
help="type of optimizer, can be sgd, adagrad, rmsprop, adam, or adadelta",
|
|
)
|
|
parser.add_argument(
|
|
"--bn", action="store_true", help="whether to use batch normalization"
|
|
)
|
|
parser.add_argument(
|
|
"--stn", action="store_true", help="whether to use spatial transformer network"
|
|
)
|
|
parser.add_argument("--dropout", type=float, default=0.0, help="dropout ratio")
|
|
parser.add_argument(
|
|
"--ckpt_path",
|
|
type=str,
|
|
default="checkpoints/default",
|
|
help="path to save checkpoints",
|
|
)
|
|
parser.add_argument(
|
|
"--max_ckpt_save_num",
|
|
type=int,
|
|
default=10,
|
|
help="maximum number of saving checkpoints",
|
|
)
|
|
parser.add_argument(
|
|
"--val_interval", type=int, default=1, help="intervals of validation"
|
|
)
|
|
parser.add_argument("--resume", type=str, default=None, help="path to resume model")
|
|
parser.add_argument("--device", type=str, help="cpu or cuda")
|
|
|
|
opt = parser.parse_args()
|
|
if opt.device is None:
|
|
opt.device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
os.makedirs(opt.ckpt_path, exist_ok=True)
|
|
|
|
# run the training procedure
|
|
train(
|
|
data_root=opt.data_root,
|
|
augment=opt.augment,
|
|
in_channels=3,
|
|
num_classes=26,
|
|
batch_norm=opt.bn,
|
|
dropout=opt.dropout,
|
|
stn=opt.stn,
|
|
n_epochs=opt.epoch,
|
|
batch_size=opt.batchsize,
|
|
lr=opt.lr,
|
|
momentum=opt.momentum,
|
|
weight_decay=opt.weight_decay,
|
|
optim_type=opt.optim_type,
|
|
ckpt_path=opt.ckpt_path,
|
|
max_ckpt_save_num=opt.max_ckpt_save_num,
|
|
ckpt_save_interval=1,
|
|
val_interval=opt.val_interval,
|
|
resume=opt.resume,
|
|
device=opt.device,
|
|
)
|