{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0fc5df88",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 75.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TOPTransformerForCausalLM(\n",
      " (model): TOPTransformerModel(\n",
      " (embeddings): Embedding(32000, 2048, padding_idx=2)\n",
      " (layers): ModuleList(\n",
      " (0-31): 32 x TOPTransformerBlock(\n",
      " (attn_norm): RMSNorm(2048, eps=1e-06)\n",
      " (attn): Attention(\n",
      " (q_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
      " (k_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
      " (v_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
      " (o_proj): Linear(in_features=2048, out_features=2048, bias=False)\n",
      " (rotary): RotaryEmbedding(dim=64, base=10000.0, interleaved=False, pos_idx_in_fp32=True)\n",
      " )\n",
      " (mlp_norm): RMSNorm(2048, eps=1e-06)\n",
      " (mlp): GatedMLP(\n",
      " (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)\n",
      " (up_proj): Linear(in_features=2048, out_features=5632, bias=False)\n",
      " (down_proj): Linear(in_features=5632, out_features=2048, bias=False)\n",
      " (swiglu_linear): SwiGLULinear()\n",
      " )\n",
      " )\n",
      " )\n",
      " (norm): RMSNorm(2048, eps=1e-06)\n",
      " )\n",
      " (lm_head): Linear(in_features=2048, out_features=32000, bias=False)\n",
      " (top_head): Linear(in_features=2048, out_features=32000, bias=False)\n",
      " (top_criterion): FusedLinearListNetLoss()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# `fla` is not referenced below; it is imported for its side effect. It\n",
    "# presumably registers the fla model classes (e.g. TOPTransformerForCausalLM)\n",
    "# with transformers' Auto* factories used by AutoModelForCausalLM -- verify.\n",
    "import fla  # noqa: F401\n",
    "from torch import float16\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "# Hugging Face Hub repo id of the checkpoint to load.\n",
    "MODEL_NAME = 'zaydzuhri/top-code-1.8B-4096-model'\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)\n",
    "# Load the weights directly in fp16 rather than building an fp32 model and\n",
    "# calling .half() afterwards: avoids holding a full fp32 copy in host memory.\n",
    "model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=float16).cuda()\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "90237353",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len input ids 103\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "You are an expert Python programmer, and here is your task: Write a function to find sequences of lowercase letters joined with an underscore using regex. Your code should pass these tests:\n",
      "\n",
      "assert text_match(\"aab_cbbbc\") == 'Found a match!'\n",
      "assert text_match(\"aab_Abbbc\") == 'Not matched!'\n",
      "assert text_match(\"Aaab_abbbc\") == 'Not matched!'\n",
      "[BEGIN]\n",
      "def text_match(text):\n",
      " return 'Found a match!' if re.search(r'^[a-z]+_[a-z]+$', text) else 'Not matched!'\n",
      "[END]\n",
      "\n",
      "#!/bin/python3\n",
      "\n",
      "import math\n",
      "import os\n",
      "import random\n",
      "import re\n",
      "import sys\n",
      "\n",
      "# Complete the text_match function below.\n",
      "def text_match(text):\n",
      " return 'Found a match!' if re.search(r'^[a-z]+_[a-z]+$', text) else 'Not matched!'\n",
      "\n",
      "if __name__ == '__main__':\n",
      " fptr = open(os.environ['OUTPUT_PATH'], 'w')\n",
      "\n",
      " text = input()\n",
      "\n",
      " result = text_match(text)\n",
      "\n",
      " fptr.write(result + '\\n')\n",
      "\n",
      " fptr.close()\n",
      " import numpy\n"
     ]
    }
   ],
   "source": [
    "from transformers import set_seed\n",
    "\n",
    "# Prompts tried against the model; choose one with PROMPT_KEY instead of\n",
    "# commenting/uncommenting alternatives.\n",
    "EXAMPLE_PROMPTS = {\n",
    "    \"bee\": \"According to all known laws of aviation, there is no way a bee should be able to\",\n",
    "    \"fibonacci\": \"def fibonacci(n):\",\n",
    "    \"pot_of_greed\": \"And I summon pot of greed, which allows me to\",\n",
    "    \"mean_absolute_deviation\": \"from typing import List\\n\\n\\ndef mean_absolute_deviation(numbers: List[float]) -> float:\\n \\\"\\\"\\\" For a given list of input numbers, calculate Mean Absolute Deviation\\n around the mean of this dataset.\\n Mean Absolute Deviation is the average absolute difference between each\\n element and a centerpoint (mean in this case):\\n MAD = average | x - x_mean |\\n >>> mean_absolute_deviation([1.0, 2.0, 3.0, 4.0])\\n 1.0\\n \\\"\\\"\\\"\\n\",\n",
    "    \"sum_product\": \"from typing import List, Tuple\\n\\n\\ndef sum_product(numbers: List[int]) -> Tuple[int, int]:\\n \\\"\\\"\\\" For a given list of integers, return a tuple consisting of a sum and a product of all the integers in a list.\\n Empty sum should be equal to 0 and empty product should be equal to 1.\\n >>> sum_product([])\\n (0, 1)\\n >>> sum_product([1, 2, 3, 4])\\n (10, 24)\\n \\\"\\\"\\\"\\n\",\n",
    "    \"text_match\": \"You are an expert Python programmer, and here is your task: Write a function to find sequences of lowercase letters joined with an underscore using regex. Your code should pass these tests:\\n\\nassert text_match(\\\"aab_cbbbc\\\") == 'Found a match!'\\nassert text_match(\\\"aab_Abbbc\\\") == 'Not matched!'\\nassert text_match(\\\"Aaab_abbbc\\\") == 'Not matched!'\\n[BEGIN]\\n\",\n",
    "}\n",
    "PROMPT_KEY = \"text_match\"\n",
    "\n",
    "# Generation settings.\n",
    "MAX_NEW_TOKENS = 200\n",
    "TEMPERATURE = 0.2\n",
    "SEED = 0  # sampling is stochastic; a fixed seed makes reruns reproducible\n",
    "# The few-shot format wraps the answer in [BEGIN] ... [END]; stop there instead\n",
    "# of spending the rest of the token budget on unrelated continuation.\n",
    "STOP_STRINGS = [\"[END]\"]\n",
    "\n",
    "input_prompt = EXAMPLE_PROMPTS[PROMPT_KEY]\n",
    "inputs = tokenizer(input_prompt, return_tensors=\"pt\")\n",
    "input_ids = inputs.input_ids.cuda()\n",
    "print(\"len input ids\", input_ids.shape[-1])\n",
    "\n",
    "# NOTE(review): no attention mask is passed. That is fine for a single\n",
    "# unpadded prompt, but batching padded prompts would need\n",
    "# `inputs.attention_mask.cuda()` -- confirm the model's masked path works first.\n",
    "attention_mask = None\n",
    "\n",
    "set_seed(SEED)\n",
    "outputs = model.generate(\n",
    "    input_ids,\n",
    "    attention_mask=attention_mask,\n",
    "    max_new_tokens=MAX_NEW_TOKENS,\n",
    "    do_sample=True,\n",
    "    temperature=TEMPERATURE,\n",
    "    use_cache=True,\n",
    "    stop_strings=STOP_STRINGS,\n",
    "    tokenizer=tokenizer,  # generate() needs the tokenizer to match stop_strings\n",
    ")\n",
    "print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "flame-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}