Files
MediaNCognition/hw1/HW1-code/losses.py
unlockable 8b657be441 Mac Sync
2024-05-15 20:05:18 +08:00

118 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#========================================================
# Media and Cognition
# Homework 1 Neural network basics
# losses.py - loss functions
# Student ID: 2022010639
# Name: Gao Yixuan
# Tsinghua University
# (C) Copyright 2024
#========================================================
import torch
import torch.nn.functional as F
'''
In this script we will implement our MSE and Cross Entropy loss functions, including both the forward and backward processes.
More details about customizing a backward process can be found in:
https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
'''
# here is the sample code of MSELoss
# you can use this as reference to implement the CrossEntropyLoss
class MSELoss(torch.autograd.Function):
    '''
    Sum-of-squared-errors loss with a hand-written backward pass.

        loss = sum over all elements of (pred - label) ** 2

    Note: despite the name, the result is a SUM, not a mean; it is not
    divided by batch_size or by the number of elements.
    '''

    @staticmethod
    def forward(ctx, pred, label):
        """
        :param pred: prediction with shape [batch_size, *], where * means any number of additional dimensions
        :param label: ground truth, same shape as pred
        :return: scalar tensor holding the sum of squared errors over every element
        """
        # keep both operands around; the gradient needs their difference
        ctx.save_for_backward(pred, label)
        residual = pred - label
        return (residual ** 2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        """
        :param grad_output: gradient of the final objective w.r.t. this loss
                            (1 when the loss is backpropagated directly)
        :return: (dL/dpred with the same shape as pred, None for label)
        """
        pred, label = ctx.saved_tensors
        residual = pred - label
        # d/dpred sum((pred - label)^2) = 2 * (pred - label), chained with the upstream gradient;
        # the label is a fixed target, so it receives no gradient
        return 2 * residual * grad_output, None
#TODO 1: Complete the CrossEntropyLoss loss function
class CrossEntropyLoss(torch.autograd.Function):
    '''
    Cross entropy loss function:
        loss = - log q_i
    where
        q_i = softmax(z_i) = exp(z_i) / (exp(z_0) + exp(z_1) + ...)
    However, when z_i has a large value, exp(z_i) might overflow to infinity.
    So we use the stable softmax:
        softmax(z_i) = A exp(z_i) / A (exp(z_0) + exp(z_1) + ...)
    where
        A = exp(-z_max) = exp(-max{z_0, z_1, ...})
    therefore we have
        softmax(z_i) = softmax(z_i - z_max)
    Likewise, log q_i is computed directly as
        log q_i = (z_i - z_max) - log(sum_j exp(z_j - z_max))
    so that a tiny q_i which underflows to 0 does not turn the loss into inf.

    The loss is averaged over the batch, so its gradient w.r.t. the logits is
        dL/dz = (q - one_hot(label)) / batch_size
    '''

    @staticmethod
    def forward(ctx, logits, label):
        """
        :param logits: logits with shape [batch_size, n_classes], denoted by "z" in the above formula
        :param label: ground-truth class indices (int64) with shape [batch_size], where 0 <= label[i] < n_classes
        :return: scalar cross entropy loss, averaged over batch_size
        """
        n_classes = logits.size(1)
        batch_size = label.size(0)

        # step 1: stable softmax. keepdim=True gives z_max shape [batch_size, 1],
        # so it broadcasts against logits of shape [batch_size, n_classes]
        z_max = torch.max(logits, dim=1, keepdim=True).values
        shifted = logits - z_max                        # [batch_size, n_classes], row max is 0
        exps = torch.exp(shifted)                       # [batch_size, n_classes]
        sums = torch.sum(exps, dim=1, keepdim=True)     # [batch_size, 1]; >= 1 since the max entry gives exp(0)
        q = exps / sums                                 # softmax(z), [batch_size, n_classes]

        # log-softmax from the shifted logits instead of log(q): log(q) is -inf
        # whenever q underflows to 0, while this expression stays finite
        log_q = shifted - torch.log(sums)

        # step 2: one-hot label [batch_size, n_classes], e.g. label = [0, 2] with
        # n_classes = 4 -> [[1,0,0,0], [0,0,1,0]]; cast to the logits' dtype so the
        # subtraction in backward is a plain floating-point op
        one_hot_label = F.one_hot(label, n_classes).to(logits.dtype)

        # step 3: loss = -(1 / batch_size) * sum_n log q_{n, label[n]};
        # gather picks each sample's target column of log_q
        target_log_q = torch.gather(log_q, 1, label.reshape(-1, 1))
        cross_entropy = -torch.sum(target_log_q) / batch_size

        # save softmax result and one-hot label for the gradient computation
        ctx.save_for_backward(q, one_hot_label)
        return cross_entropy

    @staticmethod
    def backward(ctx, grad_output):
        """
        :param grad_output: gradient of the final objective w.r.t. this loss
                            (1 when the loss is backpropagated directly)
        :return: (dL/dlogits with shape [batch_size, n_classes], None for label)
        """
        # step 4: d(mean CE)/dz = (softmax(z) - one_hot(label)) / batch_size;
        # the 1/batch_size mirrors the averaging done in forward
        q, one_hot_label = ctx.saved_tensors
        batch_size = q.size(0)
        grad_input = grad_output * (q - one_hot_label) / batch_size
        # labels are discrete indices, so they get no gradient
        return grad_input, None