Files
MediaNCognition/hw4/code/attnvis.ipynb
2024-05-22 20:22:47 +08:00

105 lines
3.1 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"import os\n",
"from contextlib import nullcontext\n",
"import torch\n",
"from model import GPTConfig, GPT\n",
"from bertviz import head_view\n",
"from dataset import Converter, LMDataset\n",
"\n",
"# Seed every torch RNG entry point with the same value for reproducibility.\n",
"seed = 2024\n",
"for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):\n",
"    seed_fn(seed)\n",
"torch.backends.cudnn.deterministic = True\n",
"# TF32 is allowed for both matmul and cuDNN kernels.\n",
"torch.backends.cuda.matmul.allow_tf32 = True\n",
"torch.backends.cudnn.allow_tf32 = True\n",
"\n",
"#################################################\n",
"# User-editable settings\n",
"model_name = 'mygpt'\n",
"ckpt_path = 'workdirs/quansongci'\n",
"data_root = 'data/quansongci'\n",
"vis_text_path = 'data/vis/vis_1.txt'\n",
"#################################################\n",
"\n",
"device = 'cpu'\n",
"\n",
"# The stoi/itos tables used for encoding are taken from the 'train' split dataset.\n",
"dataset = LMDataset(data_root, 'train')\n",
"converter = Converter(dataset.stoi, dataset.itos)\n",
"\n",
"\n",
"# Read the text to visualise, encode it, and add a leading batch dimension.\n",
"with open(vis_text_path, 'r', encoding='utf-8') as fh:\n",
"    start = fh.read()\n",
"start_ids = converter.single_encode(start)\n",
"start_texts = list(start)  # one display token per character of the input\n",
"x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)\n",
"print(f\"Input texts: {start}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0792738",
"metadata": {},
"outputs": [],
"source": [
"# Autocast precision; it only takes effect when the device is not the CPU.\n",
"dtype = 'float16'  # 'float32' or 'bfloat16' or 'float16'\n",
"ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]\n",
"# torch.autocast wants the bare device type ('cuda'), not an indexed string such as 'cuda:0'.\n",
"device_type = torch.device(device).type\n",
"ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)\n",
"\n",
"# Keep the checkpoint file path under its own name: rebinding ckpt_path here would\n",
"# append 'best.pth' a second time whenever this cell is re-run.\n",
"ckpt_file = os.path.join(ckpt_path, 'best.pth')\n",
"print(\"loading model params from %s\" % ckpt_file)\n",
"# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.\n",
"checkpoint = torch.load(ckpt_file, map_location=device)\n",
"\n",
"# Start from the named preset; arguments stored in the checkpoint take precedence.\n",
"gptconf = GPTConfig[model_name]\n",
"if 'model_args' in checkpoint:\n",
"    gptconf = checkpoint['model_args']\n",
"model = GPT(**gptconf)\n",
"model.load_state_dict(checkpoint['state_dict'])\n",
"\n",
"model.eval()\n",
"model.to(device)\n",
"\n",
"# One forward pass with gradients disabled; the second return value feeds bertviz below.\n",
"with torch.no_grad(), ctx:\n",
"    _, attn_weights = model(x)\n",
"\n",
"head_view(attn_weights, start_texts)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}