104 lines
2.0 KiB
Plaintext
104 lines
2.0 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor([[1, 2, 3],\n",
|
|
" [4, 5, 6]])\n",
|
|
"a_max tensor([[3],\n",
|
|
" [6]])\n",
|
|
"tensor([[-2, -1, 0],\n",
|
|
" [-2, -1, 0]])\n",
|
|
"tensor([ 6, 15])\n",
|
|
"torch.Size([2, 3])\n",
|
|
"3\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"a = torch.tensor([[1, 2, 3], [4, 5, 6]])\n",
|
|
"print(a)\n",
|
|
"a_max = torch.max(a, 1, keepdim=True).values\n",
|
|
"print(\"a_max\", a_max)\n",
|
|
"a_max = torch.reshape(a_max, (2, 1))\n",
|
|
"print(a - a_max)\n",
|
|
"b = torch.sum(a, 1)\n",
|
|
"print(b)\n",
|
|
"print(a.size())\n",
|
|
"print(a.size()[1])\n",
|
|
"print(a.reshape(-1,))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"tensor([[0, 1, 0],\n",
|
|
" [0, 0, 1]])\n",
|
|
"tensor([[0, 2, 0],\n",
|
|
" [0, 0, 6]])\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"label = torch.tensor([1, 2])\n",
|
|
"one_hot = torch.nn.functional.one_hot(label, 3)\n",
|
|
"print(one_hot)\n",
|
|
"print(one_hot * a)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"a = torch.tensor([[1, 2], [3, 4]])\n",
|
|
"b = torch.tensor([5, 6])\n",
|
|
"print(torch.matmul(a, b))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "media_cognition",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.2"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|