Add Homework2.
This commit is contained in:
356
hw2/code/train.py
Normal file
356
hw2/code/train.py
Normal file
@@ -0,0 +1,356 @@
|
||||
# ========================================================
|
||||
# Media and Cognition
|
||||
# Homework 2 Convolutional Neural Network
|
||||
# train.py - Train traffic sign classification model
|
||||
# Student ID:
|
||||
# Name:
|
||||
# Tsinghua University
|
||||
# (C) Copyright 2024
|
||||
# ========================================================
|
||||
|
||||
# ==== Part 1: import libs
|
||||
import argparse # argparse is used to conveniently set our configurations
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
# Import our data loader function and Classifier model defined in other files
|
||||
from datasets import get_data_loader
|
||||
from networks import Classifier
|
||||
|
||||
|
||||
# ==== Part 2: training and validation
|
||||
def _build_optimizer(model, optim_type, lr, momentum, weight_decay):
    """Create the optimizer named by ``optim_type`` over the parameters of ``model``."""
    params = model.parameters()
    if optim_type == "sgd":
        return optim.SGD(params, lr, momentum=momentum, weight_decay=weight_decay)
    if optim_type == "adagrad":
        return optim.Adagrad(params, lr, weight_decay=weight_decay)
    if optim_type == "rmsprop":
        return optim.RMSprop(params, lr, weight_decay=weight_decay)
    if optim_type == "adam":
        return optim.Adam(params, lr, weight_decay=weight_decay)
    if optim_type == "adadelta":
        return optim.Adadelta(params, lr, weight_decay=weight_decay)
    raise NotImplementedError(
        "[Error] optim_type should be one of sgd, adagrad, rmsprop, adam, or adadelta"
    )


def _prune_checkpoints(ckpt_path, max_ckpt_save_num):
    """Remove the oldest checkpoints in ``ckpt_path`` to make room for one more file."""
    ckpt_list = glob.glob(os.path.join(ckpt_path, "ckpt_epoch_*.pth"))
    # oldest first, judged by modification time
    ckpt_list.sort(key=os.path.getmtime)
    # "+ 1" reserves a slot for the checkpoint that is about to be written
    n_surplus = len(ckpt_list) - max_ckpt_save_num + 1
    if n_surplus > 0:
        # slicing (instead of indexing) stays in range even if max_ckpt_save_num <= 0
        for stale_ckpt in ckpt_list[:n_surplus]:
            os.remove(stale_ckpt)


def train(
    data_root,
    augment,
    in_channels,
    num_classes,
    batch_norm,
    dropout,
    stn,
    n_epochs,
    batch_size,
    lr,
    momentum,
    weight_decay,
    optim_type,
    ckpt_path,
    max_ckpt_save_num,
    ckpt_save_interval,
    val_interval,
    resume,
    device="cpu",
):
    """
    The main training procedure
    ----------------------------
    Trains a Classifier with cross-entropy loss, periodically evaluates it on the
    validation split, periodically writes checkpoints into ckpt_path, and finally
    plots the loss / accuracy curves. Returns nothing.

    :param data_root: path to the root directory of dataset
    :param augment: whether to use data augmentation (training split only)
    :param in_channels: channel number of image
    :param num_classes: number of output classes of the classifier (the script entry point passes 26)
    :param batch_norm: whether to use batch normalization in convolutional layers and linear layers
    :param dropout: dropout ratio of dropout layer which ranges from 0 to 1
    :param stn: whether to use spatial transformer network
    :param n_epochs: number of training epochs
    :param batch_size: batch size of training (also used for the validation loader)
    :param lr: learning rate
    :param momentum: only used if optim_type == 'sgd'
    :param weight_decay: the factor of L2 penalty on network weights
    :param optim_type: optimizer, which can be set as 'sgd', 'adagrad', 'rmsprop', 'adam', or 'adadelta'
    :param ckpt_path: path to save checkpoint models
    :param max_ckpt_save_num: maximum number of saving checkpoint models
    :param ckpt_save_interval: intervals of saving checkpoint models, e.g., if ckpt_save_interval = 2, then save checkpoint models every 2 epochs
    :param val_interval: intervals of validation, e.g., if val_interval = 5, then do validation after each 5 training epochs
    :param resume: path to resume model, or None to start from scratch
    :param device: 'cpu' or 'cuda', we can use 'cpu' for our homework if GPU with cuda support is not available
    :raises NotImplementedError: if optim_type is not one of the supported optimizers
    """

    # construct training and validation data loader
    train_loader = get_data_loader(
        data_root,
        "train",
        image_size=(32, 32),
        batch_size=batch_size,
        num_workers=2,
        augment=augment,
    )
    val_loader = get_data_loader(
        data_root, "val", image_size=(32, 32), batch_size=batch_size, num_workers=2
    )

    model = Classifier(in_channels, num_classes, batch_norm, stn, dropout_prob=dropout)

    # put the model on CPU or GPU
    model = model.to(device)

    # define loss function and optimizer
    loss_func = nn.CrossEntropyLoss()
    optimizer = _build_optimizer(model, optim_type, lr, momentum, weight_decay)

    if resume is not None:
        print(f"[Info] resuming model from {resume} ...")
        # map_location lets a checkpoint saved on one device (e.g. cuda) be
        # resumed on another (e.g. a CPU-only machine)
        checkpoint = torch.load(resume, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])

    # training
    losses = []  # average training loss of every epoch
    accuracy_list = []  # validation accuracy of every validated epoch
    val_epochs = []  # 0-based indices of the epochs that were validated

    print("training...")
    for epoch in range(n_epochs):
        # set the model in training mode
        model.train()

        # to save total loss in one epoch
        total_loss = 0.0

        for images, labels in train_loader:  # get a batch of data
            # set data type and device
            images = images.type(torch.float).to(device)
            labels = labels.type(torch.long).to(device)

            # clear gradients in the optimizer
            optimizer.zero_grad()

            # forward pass, CrossEntropy loss, backward propagation
            out = model(images)
            loss = loss_func(out, labels)
            loss.backward()

            # update parameters of the model
            optimizer.step()

            # loss.item() returns the loss as a plain python number, so the
            # running sum does not keep the autograd graph alive
            total_loss += loss.item()

        # average of the total loss for iterations
        avg_loss = total_loss / len(train_loader)
        losses.append(avg_loss)

        # evaluate model on validation set
        if (epoch + 1) % val_interval == 0:
            val_accuracy = eval_one_epoch(model, val_loader, device)
            accuracy_list.append(val_accuracy)
            val_epochs.append(epoch)
            print(
                "Epoch {:02d}: loss = {:.3f}, accuracy on validation set = {:.3f}".format(
                    epoch + 1, avg_loss, val_accuracy
                )
            )

        if (epoch + 1) % ckpt_save_interval == 0:
            # remove surplus ckpt files so that at most max_ckpt_save_num remain
            _prune_checkpoints(ckpt_path, max_ckpt_save_num)

            # save model parameters in a file, together with the model
            # configuration stored under "configs"
            ckpt_name = os.path.join(ckpt_path, "ckpt_epoch_%d.pth" % (epoch + 1))
            save_dict = {
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "configs": {
                    "in_channels": in_channels,
                    "num_classes": num_classes,
                    "use_batch_norm": batch_norm,
                    "use_stn": stn,
                    "dropout_prob": dropout,
                },
            }

            torch.save(save_dict, ckpt_name)
            print("Model saved in {}\n".format(ckpt_name))

    plot(losses, accuracy_list, val_epochs, ckpt_path)
|
||||
|
||||
|
||||
def eval_one_epoch(model, val_loader, device):
    """
    Evaluate model performance.
    --------------------------
    :param model: model
    :param val_loader: validation dataloader, an iterable of (input, label) batches
    :param device: 'cpu' or 'cuda'
    :return accuracy: fraction of correctly classified samples, or 0.0 if the loader yields no samples
    """

    # enter the evaluation mode (the model is left in this mode on return)
    model.eval()
    correct = 0  # number of images that are correctly classified
    n_samples = 0
    with torch.no_grad():  # we do not need to compute gradients during validation
        for images, labels in val_loader:
            # set data type and device
            images = images.type(torch.float).to(device)
            labels = labels.type(torch.long).to(device)

            # predicted class = index of the largest score along the last dimension
            pred = torch.argmax(model(images), dim=-1)
            correct += torch.sum(pred == labels).item()
            n_samples += len(labels)

    # an empty loader would otherwise divide by zero
    if n_samples == 0:
        return 0.0

    # calculate accuracy
    return correct / n_samples
|
||||
|
||||
|
||||
def plot(losses, accuracy_list, val_epochs, ckpt_path):
    """
    Draw loss and accuracy curve
    ------------------
    The figure is saved as "loss_and_acc.jpg" inside ckpt_path and then shown.

    :param losses: a list with loss of each training epoch
    :param accuracy_list: a list with accuracy on validation set, one entry per validation run
    :param val_epochs: a list with the epoch index of each entry of accuracy_list
    :param ckpt_path: directory in which the figure is saved
    """

    # create a plot
    fig, ax1 = plt.subplots()

    # draw loss: there is one loss value per epoch, which may be more than the
    # number of validation runs, so the loss needs its own x axis values
    loss_epochs = list(range(len(losses)))
    ax1.plot(loss_epochs, losses)

    # draw accuracy in red on a second y axis sharing the same x axis
    ax2 = ax1.twinx()
    ax2.plot(val_epochs, accuracy_list, "r")

    # set labels
    ax1.set_xlabel("training epoch")
    ax1.set_ylabel("loss")
    ax2.set_ylabel("accuracy")

    # save the image, then show it
    plt.savefig(os.path.join(ckpt_path, "loss_and_acc.jpg"), dpi=300)
    plt.show()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # set random seed for reproducibility
    seed = 2024
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

    # set configurations of the model and training process
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_root",
        type=str,
        default="data",
        help="path to the root directory of dataset",
    )
    parser.add_argument(
        "--augment", action="store_true", help="whether to use data augmentation"
    )
    parser.add_argument(
        "--epoch", type=int, default=15, help="number of training epochs"
    )
    parser.add_argument("--batchsize", type=int, default=32, help="training batch size")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    parser.add_argument(
        "--momentum", type=float, default=0.9, help="momentum of SGD optimizer"
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0,
        help="the factor of L2 penalty on network weights",
    )
    parser.add_argument(
        "--optim_type",
        type=str,
        default="adam",
        help="type of optimizer, can be sgd, adagrad, rmsprop, adam, or adadelta",
    )
    parser.add_argument(
        "--bn", action="store_true", help="whether to use batch normalization"
    )
    parser.add_argument(
        "--stn", action="store_true", help="whether to use spatial transformer network"
    )
    parser.add_argument("--dropout", type=float, default=0.0, help="dropout ratio")
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="checkpoints/default",
        help="path to save checkpoints",
    )
    parser.add_argument(
        "--max_ckpt_save_num",
        type=int,
        default=10,
        help="maximum number of saving checkpoints",
    )
    # the default of 1 keeps the previous behavior of saving after every epoch
    parser.add_argument(
        "--ckpt_save_interval",
        type=int,
        default=1,
        help="intervals of saving checkpoints",
    )
    parser.add_argument(
        "--val_interval", type=int, default=1, help="intervals of validation"
    )
    parser.add_argument("--resume", type=str, default=None, help="path to resume model")
    parser.add_argument("--device", type=str, help="cpu or cuda")

    opt = parser.parse_args()

    # pick the device automatically when it is not given on the command line
    if opt.device is None:
        opt.device = "cuda" if torch.cuda.is_available() else "cpu"

    # make sure the checkpoint directory exists before training writes into it
    os.makedirs(opt.ckpt_path, exist_ok=True)

    # run the training procedure
    train(
        data_root=opt.data_root,
        augment=opt.augment,
        in_channels=3,
        num_classes=26,
        batch_norm=opt.bn,
        dropout=opt.dropout,
        stn=opt.stn,
        n_epochs=opt.epoch,
        batch_size=opt.batchsize,
        lr=opt.lr,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
        optim_type=opt.optim_type,
        ckpt_path=opt.ckpt_path,
        max_ckpt_save_num=opt.max_ckpt_save_num,
        ckpt_save_interval=opt.ckpt_save_interval,
        val_interval=opt.val_interval,
        resume=opt.resume,
        device=opt.device,
    )
|
||||
Reference in New Issue
Block a user