{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvBlock(nn.Module):\n",
    "    \"\"\"Conv2d -> (optional) BatchNorm2d -> ReLU, with an optional residual connection.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels,\n",
    "        out_channels,\n",
    "        kernel_size,\n",
    "        stride,\n",
    "        padding,\n",
    "        use_batch_norm=False,\n",
    "        use_residual=False,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Convolutional block with optional batch normalization and ReLU activation\n",
    "        ----------------------\n",
    "        :param in_channels: channel number of input image\n",
    "        :param out_channels: channel number of output image\n",
    "        :param kernel_size: size of convolutional kernel\n",
    "        :param stride: stride of convolutional operation\n",
    "        :param padding: padding of convolutional operation\n",
    "        :param use_batch_norm: whether to use batch normalization in convolutional layers\n",
    "            (if False, nn.Identity is used in its place)\n",
    "        :param use_residual: whether to add the block input to its output (residual connection);\n",
    "            the activated conv output must then be addable to the input, e.g.\n",
    "            in_channels == out_channels with stride 1 and size-preserving padding\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        if use_batch_norm:\n",
    "            bn2d = nn.BatchNorm2d\n",
    "        else:\n",
    "            # use identity function to replace batch normalization\n",
    "            bn2d = nn.Identity\n",
    "\n",
    "        self.use_residual = use_residual\n",
    "\n",
    "        # >>> TODO 2.1: complete a convolutional block with batch normalization and ReLU activation\n",
    "        # Hint: use the `bn2d` defined above for batch normalization to adapt to the input parameter `use_batch_norm`\n",
    "        # Network structure:\n",
    "        # conv -> batchnorm -> relu\n",
    "        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)\n",
    "        self.bn = bn2d(out_channels)\n",
    "        self.relu = nn.ReLU()\n",
    "        # <<< TODO 2.1\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        :param x: input accepted by nn.Conv2d, e.g. (N, in_channels, H, W)\n",
    "        :return: relu(bn(conv(x))), with x added on top when use_residual is True\n",
    "        \"\"\"\n",
    "        # >>> TODO 2.2: forward process\n",
    "        # Hint: apply residual connection if `self.use_residual` is True\n",
    "        out = self.relu(self.bn(self.conv(x)))\n",
    "        if self.use_residual:\n",
    "            # Out-of-place add: ReLU keeps its output for the backward pass, so an\n",
    "            # in-place `out += x` would make loss.backward() raise a RuntimeError.\n",
    "            out = out + x\n",
    "\n",
    "        # <<< TODO 2.2\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "in_channels = 3\n",
    "dropout_prob = 0.5\n",
    "conv_net = nn.Sequential(\n",
    "    ConvBlock(\n",
    "        in_channels=in_channels, out_channels=32, kernel_size=5, stride=1, padding=2\n",
    "    ),\n",
    "    ConvBlock(in_channels=32, out_channels=64, kernel_size=5, stride=2, padding=2),\n",
    "    nn.MaxPool2d(kernel_size=2, stride=2, padding=0),\n",
    "    ConvBlock(\n",
    "        in_channels=64,\n",
    "        out_channels=64,\n",
    "        kernel_size=3,\n",
    "        stride=1,\n",
    "        padding=1,\n",
    "        use_residual=True,\n",
    "    ),\n",
    "    ConvBlock(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),\n",
    "    nn.MaxPool2d(kernel_size=2, stride=2, padding=0),\n",
    "    ConvBlock(\n",
    "        in_channels=128,\n",
    "        out_channels=128,\n",
    "        kernel_size=3,\n",
    "        stride=1,\n",
    "        padding=1,\n",
    "        use_residual=True,\n",
    "    ),\n",
    "    nn.Dropout2d(p=dropout_prob),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 128, 4, 4])\n",
      "ConvBlock(\n",
      "  (conv): Conv2d(32, 64, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))\n",
      "  (bn): Identity()\n",
      "  (relu): ReLU()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "a = torch.randn(10, 3, 32, 32)\n",
    "print(conv_net(a).size())\n",
    "print(conv_net[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10, 8, 16, 16])\n",
      "torch.Size([10, 16, 8, 8])\n"
     ]
    }
   ],
   "source": [
    "conv_1 = ConvBlock(in_channels=3, out_channels=8, kernel_size=9, stride=2, padding=4, use_batch_norm=True)\n",
    "conv_2 = ConvBlock(in_channels=8, out_channels=16, kernel_size=5, stride=2, padding=2, use_batch_norm=True)\n",
    "\n",
    "print(conv_1(a).size())\n",
    "print(conv_2(conv_1(a)).size())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "media_cognition",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}