# Files
# MediaNCognition/hw2/code/train.py
# 2024-04-05 14:04:40 +08:00
#
# 357 lines
# 12 KiB
# Python
#
# ========================================================
# Media and Cognition
# Homework 2 Convolutional Neural Network
# train.py - Train traffic sign classification model
# Student ID:
# Name:
# Tsinghua University
# (C) Copyright 2024
# ========================================================
# ==== Part 1: import libs
import argparse # argparse is used to conveniently set our configurations
import glob
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
# Import our data loader function and Classifier model defined in other files
from datasets import get_data_loader
from networks import Classifier
# ==== Part 2: training and validation
def train(
    data_root,
    augment,
    in_channels,
    num_classes,
    batch_norm,
    dropout,
    stn,
    n_epochs,
    batch_size,
    lr,
    momentum,
    weight_decay,
    optim_type,
    ckpt_path,
    max_ckpt_save_num,
    ckpt_save_interval,
    val_interval,
    resume,
    device="cpu",
):
    """
    The main training procedure
    ----------------------------
    Builds the data loaders, model and optimizer, optionally resumes from a
    checkpoint, trains for ``n_epochs`` epochs while periodically validating and
    saving checkpoints, and finally draws the loss / accuracy curves.

    :param data_root: path to the root directory of dataset
    :param augment: whether to use data augmentation
    :param in_channels: channel number of image
    :param num_classes: number of output classes (the command-line entry point passes 26)
    :param batch_norm: whether to use batch normalization in convolutional layers and linear layers
    :param dropout: dropout ratio of dropout layer which ranges from 0 to 1
    :param stn: whether to use spatial transformer network
    :param n_epochs: number of training epochs
    :param batch_size: batch size of training
    :param lr: learning rate
    :param momentum: only used if optim_type == 'sgd'
    :param weight_decay: the factor of L2 penalty on network weights
    :param optim_type: optimizer, which can be set as 'sgd', 'adagrad', 'rmsprop', 'adam', or 'adadelta'
    :param ckpt_path: directory to save checkpoint models (created if missing)
    :param max_ckpt_save_num: maximum number of checkpoint files kept on disk (values < 1 behave like 1)
    :param ckpt_save_interval: intervals of saving checkpoint models, e.g., if ckpt_save_interval = 2, then save checkpoint models every 2 epochs
    :param val_interval: intervals of validation, e.g., if val_interval = 5, then do validation after each 5 training epochs
    :param resume: path to a checkpoint to resume model/optimizer state from, or None
    :param device: 'cpu' or 'cuda', we can use 'cpu' for our homework if GPU with cuda support is not available
    :raises NotImplementedError: if optim_type is not one of the supported optimizers
    """
    # construct training and validation data loaders
    train_loader = get_data_loader(
        data_root,
        "train",
        image_size=(32, 32),
        batch_size=batch_size,
        num_workers=2,
        augment=augment,
    )
    val_loader = get_data_loader(
        data_root, "val", image_size=(32, 32), batch_size=batch_size, num_workers=2
    )

    model = Classifier(in_channels, num_classes, batch_norm, stn, dropout_prob=dropout)
    # put the model on CPU or GPU
    model = model.to(device)

    # define loss function and optimizer
    loss_func = nn.CrossEntropyLoss()
    optimizer = _build_optimizer(
        model.parameters(), optim_type, lr, momentum, weight_decay
    )

    if resume is not None:
        print(f"[Info] resuming model from {resume} ...")
        # map_location lets a checkpoint written on a GPU be resumed on CPU (and vice versa)
        checkpoint = torch.load(resume, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])

    # model hyper-parameters stored in every checkpoint so it can be rebuilt later
    configs = {
        "in_channels": in_channels,
        "num_classes": num_classes,
        "use_batch_norm": batch_norm,
        "use_stn": stn,
        "dropout_prob": dropout,
    }
    # callers other than the CLI entry point may not have created the directory
    os.makedirs(ckpt_path, exist_ok=True)

    losses = []  # average training loss of every epoch
    accuracy_list = []  # validation accuracy of every validated epoch
    val_epochs = []  # 0-based index of every validated epoch
    print("training...")
    for epoch in range(n_epochs):
        avg_loss = _train_one_epoch(model, train_loader, loss_func, optimizer, device)
        losses.append(avg_loss)

        # evaluate model on validation set
        if (epoch + 1) % val_interval == 0:
            val_accuracy = eval_one_epoch(model, val_loader, device)
            accuracy_list.append(val_accuracy)
            val_epochs.append(epoch)
            print(
                "Epoch {:02d}: loss = {:.3f}, accuracy on validation set = {:.3f}".format(
                    epoch + 1, avg_loss, val_accuracy
                )
            )
        else:
            print("Epoch {:02d}: loss = {:.3f}".format(epoch + 1, avg_loss))

        if (epoch + 1) % ckpt_save_interval == 0:
            ckpt_name = _save_checkpoint(
                model, optimizer, configs, ckpt_path, epoch + 1, max_ckpt_save_num
            )
            print("Model saved in {}\n".format(ckpt_name))

    plot(losses, accuracy_list, val_epochs, ckpt_path)


def _build_optimizer(params, optim_type, lr, momentum, weight_decay):
    """Create the torch optimizer named by ``optim_type`` over ``params``."""
    if optim_type == "sgd":
        # SGD is the only optimizer here that takes the momentum setting
        return optim.SGD(params, lr, momentum=momentum, weight_decay=weight_decay)
    optimizer_classes = {
        "adagrad": optim.Adagrad,
        "rmsprop": optim.RMSprop,
        "adam": optim.Adam,
        "adadelta": optim.Adadelta,
    }
    if optim_type not in optimizer_classes:
        print(
            "[Error] optim_type should be one of sgd, adagrad, rmsprop, adam, or adadelta"
        )
        raise NotImplementedError(f"unsupported optim_type: {optim_type!r}")
    return optimizer_classes[optim_type](params, lr, weight_decay=weight_decay)


def _train_one_epoch(model, train_loader, loss_func, optimizer, device):
    """Run one optimization pass over ``train_loader``; return the mean per-batch loss."""
    # set the model in training mode (eval_one_epoch switches it to eval mode)
    model.train()
    total_loss = 0.0
    n_batches = 0
    for images, labels in train_loader:
        # set data type and device
        images = images.type(torch.float).to(device)
        labels = labels.type(torch.long).to(device)
        optimizer.zero_grad()
        out = model(images)
        loss = loss_func(out, labels)
        loss.backward()
        optimizer.step()
        # loss.item() detaches the value as a plain python number
        total_loss += loss.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("training loader yielded no batches")
    return total_loss / n_batches


def _save_checkpoint(model, optimizer, configs, ckpt_path, epoch_num, max_ckpt_save_num):
    """Save a checkpoint for ``epoch_num`` after pruning the oldest files; return its path."""
    # get info of all saved checkpoints, oldest first
    ckpt_list = glob.glob(os.path.join(ckpt_path, "ckpt_epoch_*.pth"))
    ckpt_list.sort(key=os.path.getmtime)
    # keep at most (max - 1) old files so that, after this save, at most max remain;
    # clamping avoids negative slice bounds (which would delete the wrong files)
    keep_limit = max(max_ckpt_save_num, 1)
    n_surplus = max(len(ckpt_list) - keep_limit + 1, 0)
    for old_ckpt in ckpt_list[:n_surplus]:
        os.remove(old_ckpt)

    ckpt_name = os.path.join(ckpt_path, "ckpt_epoch_%d.pth" % epoch_num)
    save_dict = {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "configs": configs,
    }
    torch.save(save_dict, ckpt_name)
    return ckpt_name
def eval_one_epoch(model, val_loader, device):
    """
    Evaluate model performance.
    --------------------------
    Puts ``model`` in eval mode and computes top-1 accuracy without tracking
    gradients.

    :param model: model producing class scores along the last dimension
    :param val_loader: iterable yielding (input, label) batches
    :param device: 'cpu' or 'cuda'
    :return accuracy: fraction of samples whose arg-max prediction equals the label
    :raises ValueError: if val_loader yields no samples
    """
    # enter the evaluation mode (affects dropout / batch norm layers)
    model.eval()
    correct = 0  # number of samples that are correctly classified
    n_samples = 0
    with torch.no_grad():  # we do not need to compute gradients during validation
        for images, labels in val_loader:
            # set data type and device
            images = images.type(torch.float).to(device)
            labels = labels.type(torch.long).to(device)
            scores = model(images)
            pred = torch.argmax(scores, dim=-1)
            correct += torch.sum(pred == labels).item()
            n_samples += len(labels)
    if n_samples == 0:
        # an empty split is a configuration error; report it instead of dividing by zero
        raise ValueError("validation loader yielded no samples")
    return correct / n_samples
def plot(losses, accuracy_list, val_epochs, ckpt_path):
    """
    Draw loss and accuracy curves and save them as an image.
    ------------------
    :param losses: a list with the average training loss of every epoch
    :param accuracy_list: a list with accuracy on the validation set, one entry per validated epoch
    :param val_epochs: 0-based indices of the epochs at which validation ran
        (same length as accuracy_list)
    :param ckpt_path: directory in which "loss_and_acc.jpg" is written
    """
    # Losses exist for every epoch but accuracies only for validated epochs, so
    # each curve needs its own x-axis (sharing one raises a length-mismatch
    # error whenever val_interval > 1). Both are shown 1-based to match the
    # "Epoch NN" numbering printed during training.
    loss_epochs = list(range(1, len(losses) + 1))
    acc_epochs = [epoch + 1 for epoch in val_epochs]

    fig, ax1 = plt.subplots()
    # draw loss on the left axis
    ax1.plot(loss_epochs, losses)
    # draw accuracy in red on a second y-axis sharing the epoch axis
    ax2 = ax1.twinx()
    ax2.plot(acc_epochs, accuracy_list, "r")
    # set labels
    ax1.set_xlabel("training epoch")
    ax1.set_ylabel("loss")
    ax2.set_ylabel("accuracy")
    # save, show, then release the figure's memory
    fig.savefig(os.path.join(ckpt_path, "loss_and_acc.jpg"), dpi=300)
    plt.show()
    plt.close(fig)
if __name__ == "__main__":
    # set random seed for reproducibility
    seed = 2024
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning can select different algorithms between runs; keep it off
    torch.backends.cudnn.benchmark = False

    # set configurations of the model and training process
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_root",
        type=str,
        default="data",
        help="root directory of the dataset",
    )
    parser.add_argument(
        "--augment", action="store_true", help="whether to use data augmentation"
    )
    parser.add_argument(
        "--epoch", type=int, default=15, help="number of training epochs"
    )
    parser.add_argument("--batchsize", type=int, default=32, help="training batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--momentum", type=float, default=0.9, help="momentum of SGD optimizer"
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0,
        help="the factor of L2 penalty on network weights",
    )
    parser.add_argument(
        "--optim_type",
        type=str,
        default="adam",
        choices=["sgd", "adagrad", "rmsprop", "adam", "adadelta"],
        help="type of optimizer, can be sgd, adagrad, rmsprop, adam, or adadelta",
    )
    parser.add_argument(
        "--bn", action="store_true", help="whether to use batch normalization"
    )
    parser.add_argument(
        "--stn", action="store_true", help="whether to use spatial transformer network"
    )
    parser.add_argument("--dropout", type=float, default=0.0, help="dropout ratio")
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="checkpoints/default",
        help="path to save checkpoints",
    )
    parser.add_argument(
        "--max_ckpt_save_num",
        type=int,
        default=10,
        help="maximum number of saving checkpoints",
    )
    parser.add_argument(
        "--ckpt_save_interval",
        type=int,
        default=1,
        help="intervals (in epochs) of saving checkpoints",
    )
    parser.add_argument(
        "--val_interval", type=int, default=1, help="intervals of validation"
    )
    parser.add_argument("--resume", type=str, default=None, help="path to resume model")
    parser.add_argument("--device", type=str, help="cpu or cuda")
    opt = parser.parse_args()

    # intervals are used as modulo divisors and the count bounds checkpoint pruning,
    # so non-positive values would crash mid-training; reject them up front
    for option_name in ("val_interval", "ckpt_save_interval", "max_ckpt_save_num"):
        if getattr(opt, option_name) < 1:
            parser.error(f"--{option_name} must be a positive integer")

    if opt.device is None:
        opt.device = "cuda" if torch.cuda.is_available() else "cpu"
    os.makedirs(opt.ckpt_path, exist_ok=True)

    # run the training procedure
    train(
        data_root=opt.data_root,
        augment=opt.augment,
        in_channels=3,
        num_classes=26,
        batch_norm=opt.bn,
        dropout=opt.dropout,
        stn=opt.stn,
        n_epochs=opt.epoch,
        batch_size=opt.batchsize,
        lr=opt.lr,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
        optim_type=opt.optim_type,
        ckpt_path=opt.ckpt_path,
        max_ckpt_save_num=opt.max_ckpt_save_num,
        ckpt_save_interval=opt.ckpt_save_interval,
        val_interval=opt.val_interval,
        resume=opt.resume,
        device=opt.device,
    )