# ========================================================
# Media and Cognition
# Homework 2 Convolutional Neural Network
# test.py - Test our model for character classification
# Student ID:
# Name:
# Tsinghua University
# (C) Copyright 2024
# ========================================================

import argparse
import os
import string

import cv2
import torch

from datasets import get_data_loader
from networks import Classifier


def test(data_root, ckpt_path, epoch, save_results, device="cpu"):
    """
    The main testing procedure
    ----------------------------
    Loads the checkpoint ``<ckpt_path>/ckpt_epoch_<epoch>.pth``, rebuilds the
    Classifier from the configs stored in it, runs it over the "test" split
    and prints the classification accuracy. Optionally writes every test
    image to ``<ckpt_path>/results`` named
    ``idx_correct|wrong_<label letter>_<pred letter>.jpg``.

    :param data_root: path to the root directory of dataset
    :param ckpt_path: path to load checkpoints
    :param epoch: epoch of checkpoint you want to load
    :param save_results: whether to save results
    :param device: 'cpu' or 'cuda', we can use 'cpu' for our homework if GPU with cuda support is not available
    :return: accuracy on the test set as a float (0.0 if the test set is empty)
    """
    save_dir = None
    if save_results:
        save_dir = os.path.join(ckpt_path, "results")
        # exist_ok avoids the check-then-create race and a crash on re-runs
        os.makedirs(save_dir, exist_ok=True)

    # construct testing data loader; batch_size=1 means one sample per step
    test_loader = get_data_loader(data_root, "test", image_size=(32, 32), batch_size=1)

    ckpt_file = os.path.join(ckpt_path, "ckpt_epoch_%d.pth" % epoch)
    print("[Info] loading checkpoint from %s ..." % ckpt_file)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only
    # machine (and vice versa)
    checkpoint = torch.load(ckpt_file, map_location=device)
    configs = checkpoint["configs"]
    model = Classifier(
        configs["in_channels"],
        configs["num_classes"],
        configs["use_batch_norm"],
        configs["use_stn"],
        configs["dropout_prob"],
    )
    # load model parameters (checkpoint['model_state']) we saved in model_path using model.load_state_dict()
    model.load_state_dict(checkpoint["model_state"])
    # put the model on CPU or GPU
    model = model.to(device)
    # enter the evaluation mode
    model.eval()

    correct = 0
    n = 0
    # class index i is displayed as the i-th uppercase letter ('A'..'Z')
    letters = string.ascii_uppercase

    # no gradients are needed for evaluation; skipping autograd saves memory/time
    with torch.no_grad():
        for image, label in test_loader:
            # set data type and device
            image = image.type(torch.float).to(device)
            label = label.type(torch.long).to(device)

            # get the prediction result
            logits = model(image)
            pred_idx = int(torch.argmax(logits, dim=-1))
            label_idx = int(label.squeeze(dim=0))

            # set the name of saved images to 'idx_correct/wrong_label_pred.jpg'
            if pred_idx == label_idx:
                correct += 1
                status = "correct"
            else:
                status = "wrong"
            save_name = "%04d_%s_%s_%s.jpg" % (
                n,
                status,
                letters[label_idx],
                letters[pred_idx],
            )

            if save_results:
                # undo the (x - 0.5) / 0.5 style normalization implied by the
                # original code, then convert explicitly to 8-bit because
                # JPEG output expects uint8 pixels
                img = (255 * (image * 0.5 + 0.5)).round().clamp(0, 255)
                img = (
                    img.to(torch.uint8)
                    .squeeze(0)
                    .permute(1, 2, 0)
                    .contiguous()  # OpenCV needs a contiguous buffer
                    .cpu()
                    .numpy()
                )
                # swap channel order for OpenCV's BGR convention; only
                # meaningful (and only valid) for 3-channel images
                if img.shape[2] == 3:
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                cv2.imwrite(os.path.join(save_dir, save_name), img)

            n += 1

    # calculate accuracy over samples actually seen (guards an empty test set)
    accuracy = float(correct) / float(n) if n > 0 else 0.0
    print("accuracy on the test set: %.3f" % accuracy)

    if save_results:
        print("results saved to %s" % save_dir)

    return accuracy


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # set configurations of the testing process
    parser.add_argument("--path", type=str, default="data", help="path to data file")
    parser.add_argument(
        "--epoch", type=int, default=15, help="epoch of checkpoint you want to load"
    )
    parser.add_argument(
        "--ckpt_path", type=str, default="ckpt", help="path to load checkpoints"
    )
    parser.add_argument(
        "--save", action="store_true", default=False, help="whether to save results"
    )
    parser.add_argument("--device", type=str, help="cpu or cuda")

    opt = parser.parse_args()
    # default to the GPU when one is available
    if opt.device is None:
        opt.device = "cuda" if torch.cuda.is_available() else "cpu"

    # run the testing procedure
    test(
        data_root=opt.path,
        ckpt_path=opt.ckpt_path,
        epoch=opt.epoch,
        save_results=opt.save,
        device=opt.device,
    )