---
license: mit
tags:
- gpt
- hypernetwork
- lora
- adaptive-attention
- diffusion
- language-model
- from-scratch
language:
- en
---

# Adaptive GPTs

Five small GPT variants trained from scratch for 12k steps on a ~300M-token mixed corpus (TinyStories, Cosmopedia WikiHow, WikiText-2, Python code). The core experiment asks whether a hypernetwork that reads the residual stream can generate dynamic LoRA-style updates to the value heads at inference time.
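
One way to read "dynamic LoRA-style updates to the value heads" is the sketch below. The notation is an assumption made here for illustration and is not taken from the repository: $X$ is the residual-stream input to an attention layer, $W_V^{(h)}$ is the static value projection of head $h$, $g_\phi^{(h)}$ is that head's hypernetwork, $d$ is the model width, $d_h$ the head width, and $r$ a small rank.

$$
\begin{aligned}
\left(A^{(h)},\, B^{(h)}\right) &= g_\phi^{(h)}(X), \qquad A^{(h)} \in \mathbb{R}^{d \times r},\quad B^{(h)} \in \mathbb{R}^{r \times d_h} \\
V^{(h)} &= X \left( W_V^{(h)} + A^{(h)} B^{(h)} \right)
\end{aligned}
$$

In standard LoRA the low-rank factors are fixed once fine-tuning ends; in this reading they are recomputed from the input on every forward pass. Whether the actual models produce the factors per token or per sequence, and what rank they use, is not stated here.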

**Blog post:** [Adaptive Attention at Inference Time: Does It Actually Work?](https://teendifferent.substack.com/p/adaptive-attention-at-inference-time)

**Code:** [REDDITARUN/a-gpt](https://github.com/REDDITARUN/a-gpt)

## Checkpoints

| File | Model | Params | Description |
|---|---|---|---|
| `base_best.pt` | Base GPT | 28.9M | Vanilla causal GPT. 4 layers, 4 heads, 256 dim. |
| `matched_best.pt` | Matched GPT | 30.5M | 6 layers, 4 heads, 256 dim. Parameter-matched to adaptive. |
| `adaptive_best.pt` | Adaptive GPT | 30.5M | Base + per-head TinyHeadTransformer hypernetwork on V. |
| `diffusion_best.pt` | Diffusion GPT | 28.9M | Bidirectional denoising (discrete diffusion). |
| `adaptive_diffusion_best.pt` | Adaptive Diffusion GPT | 30.5M | Diffusion + per-head hypernetwork on V. |
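
The parameter counts for the two plain GPTs can be sanity-checked with a rough estimate. The sketch below is not read from the repository's code; it assumes the GPT-2 vocabulary of 50,257 tokens, untied input and output embeddings, a 4x MLP expansion, and it ignores biases and normalization parameters. The function name `estimate_params` is hypothetical.

```python
# Back-of-the-envelope parameter estimate. Assumptions (not from the repo):
# - GPT-2 vocabulary of 50,257 tokens
# - untied input embedding and output head (2 * vocab * d_model)
# - per block: 4*d^2 attention weights + 8*d^2 MLP weights (4x expansion)
# - biases and norm parameters ignored; RoPE adds no learned parameters
VOCAB_SIZE = 50_257


def estimate_params(n_layers: int, d_model: int, vocab_size: int = VOCAB_SIZE) -> int:
    """Return an approximate parameter count for a decoder-only transformer."""
    embeddings = 2 * vocab_size * d_model
    blocks = n_layers * 12 * d_model**2
    return embeddings + blocks


base = estimate_params(n_layers=4, d_model=256)
matched = estimate_params(n_layers=6, d_model=256)
print(f"base ~ {base / 1e6:.1f}M, matched ~ {matched / 1e6:.1f}M")
```

Under these assumptions the estimates land at roughly 28.9M and 30.5M, in line with the table. If the estimate holds, the gap of about 1.6M between the base and adaptive models would be the hypernetwork's share.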

## Training

All models share the same setup: AdamW, learning rate 3e-4 with cosine decay, batch size 8 with 4 gradient accumulation steps, a 256-token context window, the GPT-2 tokenizer, and RoPE.
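
As a rough guide to scale, the token budget implied by these settings can be worked out directly. This is a sketch under stated assumptions: that "batch size 8" counts sequences per micro-batch, that the 12k steps are optimizer steps, and that every sequence fills the 256-token context.

```python
# Token budget implied by the stated hyperparameters.
# Assumptions: batch size is in sequences per micro-batch, the 12k steps are
# optimizer steps, and every sequence is a full 256 tokens long.
micro_batch = 8
grad_accum_steps = 4
context_len = 256
train_steps = 12_000

sequences_per_step = micro_batch * grad_accum_steps  # effective batch size
tokens_per_step = sequences_per_step * context_len
tokens_seen = tokens_per_step * train_steps

corpus_tokens = 300_000_000  # the "~300M" corpus size, approximate
print(f"{sequences_per_step} sequences/step, {tokens_per_step} tokens/step")
print(f"{tokens_seen / 1e6:.1f}M tokens seen in total")
print(f"about {tokens_seen / corpus_tokens:.2f} passes over the corpus")
```

If those assumptions hold, each model sees on the order of 98M tokens, which is well under one pass over the corpus.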

Data mixture: Cosmopedia WikiHow (40%), Python code (30%), WikiText-2 (20%), TinyStories (10%).
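
A mixture like this is often implemented by choosing the source of each training sample at random with the given weights. The sketch below illustrates that idea only; it is not the repository's data loader, the dictionary keys and `sample_sources` are hypothetical names, and whether the percentages apply per token, per document, or per batch is not stated here.

```python
import random
from collections import Counter

# Mixture weights from the percentages above; the keys are illustrative labels.
MIXTURE = {
    "cosmopedia_wikihow": 0.40,
    "python_code": 0.30,
    "wikitext2": 0.20,
    "tinystories": 0.10,
}


def sample_sources(n: int, seed: int = 0) -> list[str]:
    """Draw n source names with probability proportional to the mixture weights."""
    rng = random.Random(seed)
    names = list(MIXTURE)
    weights = list(MIXTURE.values())
    return rng.choices(names, weights=weights, k=n)


# Empirical shares should approach the target weights as n grows.
counts = Counter(sample_sources(10_000))
for name, count in counts.most_common():
    print(f"{name}: {count / 10_000:.1%}")
```

Sampling per example keeps every batch close to the target proportions on average, at the cost of small random fluctuations from batch to batch.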

## Loading

```python
import torch
# The `models` package comes from the a-gpt repo and must be on the import path.
from models.base_gpt import GPT_Base, GPT_Base_Config
from models.a_gpt import GPT_Custom, GPT_Custom_Config, HyperConfig

# Example: load the adaptive checkpoint
config = GPT_Custom_Config()
hyper_config = HyperConfig()
model = GPT_Custom(config, hyper_config=hyper_config)

# The checkpoint is a dict; the weights are stored under the "model" key.
ckpt = torch.load("adaptive_best.pt", map_location="cpu")
model.load_state_dict(ckpt["model"])
model.eval()  # inference mode: disables dropout and similar training-only behavior