105 lines
3.1 KiB
Plaintext
105 lines
3.1 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"collapsed": false
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import os\n",
|
|
"from contextlib import nullcontext\n",
|
|
"import torch\n",
|
|
"from model import GPTConfig, GPT\n",
|
|
"from bertviz import head_view\n",
|
|
"from dataset import Converter, LMDataset\n",
|
|
"\n",
|
"# Seed torch's CPU and CUDA random number generators so runs are repeatable.\n",
"seed = 2024\n",
"torch.manual_seed(seed)\n",
"torch.cuda.manual_seed(seed)\n",
"torch.cuda.manual_seed_all(seed)\n",
"# Backend switches: deterministic cuDNN, TF32 allowed for matmul and for cuDNN.\n",
"torch.backends.cudnn.deterministic = True\n",
"torch.backends.cuda.matmul.allow_tf32 = True\n",
"torch.backends.cudnn.allow_tf32 = True\n",
"\n",
"#################################################\n",
"# Settings to adjust per experiment\n",
"model_name = 'mygpt'\n",
"ckpt_path = 'workdirs/quansongci'\n",
"data_root = 'data/quansongci'\n",
"vis_text_path = 'data/vis/vis_1.txt'\n",
"#################################################\n",
"\n",
"device = 'cpu'\n",
"\n",
"# The converter is built from the train split's stoi/itos attributes\n",
"# (presumably the char<->id lookup tables -- confirm in dataset.py).\n",
"dataset = LMDataset(data_root, 'train')\n",
"converter = Converter(dataset.stoi, dataset.itos)\n",
"\n",
"\n",
"# Read the text to visualise, encode it, and keep its individual characters.\n",
"with open(vis_text_path, 'r', encoding='utf-8') as text_file:\n",
"    start = text_file.read()\n",
"start_ids = converter.single_encode(start)\n",
"start_texts = list(start)\n",
"# unsqueeze(0) prepends a dimension of size 1 in front of the encoded ids.\n",
"x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)\n",
"print(f\"Input texts: {start}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "c0792738",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
"# model\n",
"dtype = 'float16' # 'float32' or 'bfloat16' or 'float16'\n",
"ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]\n",
"# torch.autocast expects the bare device type ('cuda'), not an indexed device\n",
"# string such as 'cuda:0', so derive the type from the configured device.\n",
"device_type = torch.device(device).type\n",
"ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)\n",
"# init from a model saved in a specific directory\n",
"# Use a separate name for the checkpoint file: overwriting ckpt_path here would\n",
"# append 'best.pth' once more every time this cell is re-run.\n",
"ckpt_file = os.path.join(ckpt_path, 'best.pth')\n",
"print(\"loading model params from %s\"%ckpt_file)\n",
"checkpoint = torch.load(ckpt_file, map_location=device)\n",
"# Start from the named entry in GPTConfig; arguments stored in the checkpoint,\n",
"# when present, take precedence.\n",
"gptconf = GPTConfig[model_name]\n",
"if 'model_args' in checkpoint:\n",
"    gptconf = checkpoint['model_args']\n",
"model = GPT(**gptconf)\n",
"state_dict = checkpoint['state_dict']\n",
"model.load_state_dict(state_dict)\n",
"\n",
"model.eval()\n",
"model.to(device)\n",
"\n",
"# Single forward pass without gradients; only the second return value\n",
"# (the attention weights) is used, nothing is sampled or generated here.\n",
"with torch.no_grad():\n",
"    with ctx:\n",
"        _, attn_weights = model(x)\n",
"\n",
"# Render the attention weights, labelling positions with the input characters.\n",
"head_view(attn_weights, start_texts)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|