From 3747678e61453966d7392c15b6a4d4e6847343b5 Mon Sep 17 00:00:00 2001 From: unlockable Date: Tue, 9 Apr 2024 21:42:26 +0800 Subject: [PATCH] Correct image normalization --- .gitignore | 3 ++- hw2/code/datasets.py | 4 +--- hw2/code/networks.py | 6 +++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/.gitignore b/.gitignore index b14374f..13d9ea4 100644 --- a/.gitignore +++ b/.gitignore @@ -7,4 +7,5 @@ __pycache__/ *.synctex.gz(buzy) *.out *.pdf -.DS_Store \ No newline at end of file +.DS_Store +hw2/code/checkpoints/ \ No newline at end of file diff --git a/hw2/code/datasets.py b/hw2/code/datasets.py index cd25bcc..21153cb 100644 --- a/hw2/code/datasets.py +++ b/hw2/code/datasets.py @@ -44,7 +44,7 @@ def get_data_loader( transforms.Resize(image_size), transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True), - transforms.Normalize(mean=[-127.0, -127.0, -127.0], std=[128.0, 128.0, 128.0]) + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ] # You should insert some data augmentation techniques to `data_transforms` when `augment` is True @@ -58,8 +58,6 @@ def get_data_loader( # Use `transforms.Compose` to compose the list of transforms into a single transform data_transforms = transforms.Compose(data_transforms) - print(type(data_transforms)) - # >>> TODO 1.2: Define the dataset. # You should build the path to the selected dataset according to the `mode` parameter, # and use the `ImageFolder` class from `torchvision.datasets` to load the datasets. diff --git a/hw2/code/networks.py b/hw2/code/networks.py index a4b2099..138f756 100644 --- a/hw2/code/networks.py +++ b/hw2/code/networks.py @@ -187,7 +187,7 @@ class Classifier(nn.Module): # Step 3: use `Tensor.view()` to flatten the tensor to match the size of the input of the # fully connected layers. - x = x.view(-1, 2048) + x = x.view(x.shape[0], -1) # Step 4: forward process for the fully connected network out = self.fc_net(x) @@ -241,8 +241,8 @@ class STN(nn.Module): # Suggested structure: 2 linear layers with one BN and ReLU. self.localization_fc = nn.Sequential( nn.Linear(16, 256), - nn.Linear(256, 360), - nn.BatchNorm1d(360), + nn.Linear(256, 6), + nn.BatchNorm1d(6), nn.ReLU() ) # <<< TODO 4.1