Files
MediaNCognition/hw3/code/check.py
2024-05-01 17:13:51 +08:00

41 lines
1.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# ========================================================
# Media and Cognition
# Homework 3 Support Vector Machine
# check.py - Check your implementation of several modules
# Tsinghua University
# (C) Copyright 2024
# ========================================================
from svm_hw import SVM_HINGE, LinearFunction, Hinge
import torch
from torch.autograd import gradcheck
def run():
    """Check the LinearFunction, Hinge and SVM_HINGE implementations.

    Three checks are run in order, each printing a message on success:

    1. ``gradcheck`` on ``LinearFunction.apply`` with inputs ``(x, W, b)``.
    2. ``gradcheck`` on ``Hinge.apply`` with inputs ``(output, W, labels, C)``.
    3. A forward pass through ``SVM_HINGE`` that must return two values,
       followed by a check that ``model.W`` and ``model.b`` require gradients.

    Raises:
        Exception: if the SVM_HINGE forward pass or its parameter checks
            fail; the underlying error is chained as ``__cause__``.
            Errors raised by ``gradcheck`` itself propagate unchanged.
    """
    model = SVM_HINGE(2, C=1.0).double()

    # gradcheck compares analytical and numerical gradients, so every input
    # is created directly in double precision. Building the tensors with
    # dtype=torch.double (instead of converting with .double() afterwards)
    # keeps the differentiable inputs as leaf tensors.
    x = torch.randn(50, 2, dtype=torch.double)
    W = torch.randn(1, 2, dtype=torch.double, requires_grad=True)
    b = torch.zeros(1, dtype=torch.double, requires_grad=True)
    test = gradcheck(LinearFunction.apply, (x, W, b), eps=1e-6, atol=1e-4)
    if test:
        print('Linear successfully tested!')

    # Only `output` and `W` are differentiated; `labels` and `C` are constants.
    output = torch.randn(50, 1, dtype=torch.double, requires_grad=True)
    W = torch.randn(1, 2, dtype=torch.double, requires_grad=True)
    labels = torch.ones(1, dtype=torch.double)
    C = torch.tensor([[1.0]], dtype=torch.double)
    test = gradcheck(Hinge.apply, (output, W, labels, C), eps=1e-6, atol=1e-5)
    if test:
        print('Hinge successfully tested')

    x = torch.randn(50, 2, dtype=torch.double)
    labels = torch.ones(50, dtype=torch.double)
    try:
        # Unpacking also verifies that the model returns exactly two values.
        output, loss = model(x, labels)
        assert model.W.requires_grad is True
        assert model.b.requires_grad is True
        print('SVM_HINGE successfully tested')
    except Exception as err:
        # Catch Exception rather than using a bare except so Ctrl-C and
        # SystemExit still work, and chain the cause so the real failure
        # stays visible in the traceback.
        raise Exception('Failed testing SVM_HINGE!') from err
# Script entry point: execute the checks only when this file is run directly,
# not when it is imported as a module.
if __name__ == '__main__':
    run()