Files
MediaNCognition/hw2/code/datasets.py
2024-04-11 14:20:28 +08:00

78 lines
3.5 KiB
Python

# ========================================================
# Media and Cognition
# Homework 2 Convolutional Neural Network
# datasets.py - Define the data loader for the traffic sign classification dataset
# Student ID: 2022010639
# Name: Gao Yixuan
# Tsinghua University
# (C) Copyright 2024
# ========================================================
import os
import numpy as np
import torchvision.transforms.v2 as transforms
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
def get_data_loader(
    data_root, mode, image_size, batch_size, num_workers=0, augment=False
):
    """
    Get the data loader for the specified dataset and mode.

    :param data_root: the root directory of the whole dataset; the images of
        the selected split are read from the sub-directory named after `mode`
    :param mode: the mode of the dataset, which can be 'train', 'val', or 'test'
    :param image_size: the target image size for resizing; an int `s` is
        treated as the square size `(s, s)`, any other value is passed to
        `Resize` unchanged
    :param batch_size: the batch size
    :param num_workers: the number of workers for loading data in multiple processes
    :param augment: whether to use data augmentation; it only takes effect
        when `mode` is 'train'
    :return: a data loader
    """
    # `Resize` given a bare int only matches the shorter edge and keeps the
    # aspect ratio, so non-square images would come out with different shapes
    # and could not be stacked into a batch. Force an explicit (h, w) instead.
    if isinstance(image_size, int):
        target_size = (image_size, image_size)
    else:
        target_size = image_size

    data_transforms = [transforms.Resize(target_size)]

    # Augmentation is applied to the training split only, and it must run
    # BEFORE the tensor conversion / normalization below: autocontrast
    # stretches every channel to the full value range of its input, so running
    # it after `Normalize` would overwrite the normalization on the images it
    # fires on and feed the model inputs on two different scales.
    # No flips are used: mirroring changes the meaning of directional signs.
    if mode == "train" and augment:
        data_transforms.extend(
            [
                transforms.RandomAffine(degrees=30, shear=10),
                transforms.RandomAutocontrast(),
            ]
        )

    # Shared tail for every mode: convert to a float tensor scaled to [0, 1],
    # then normalize each channel.
    # NOTE(review): these are the ImageNet channel statistics, which map
    # [0, 1] to roughly [-2.1, 2.6] rather than the [-1, 1] range the
    # assignment text asks for (that would be mean=std=[0.5, 0.5, 0.5]).
    # Kept as-is so existing checkpoints stay valid; changing it requires
    # retraining.
    data_transforms.extend(
        [
            transforms.ToImage(),
            transforms.ToDtype(torch.float32, scale=True),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            ),
        ]
    )

    # Compose the list of transforms into a single transform.
    data_transforms = transforms.Compose(data_transforms)

    # `ImageFolder` loads the images of the selected split and derives the
    # labels from the class sub-directories.
    dataset = ImageFolder(
        root=os.path.join(data_root, mode), transform=data_transforms
    )

    # Shuffle only the training split so that evaluation order stays fixed.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=(mode == "train"),
    )
    return loader