Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation
Paper
โข
2108.12409
โข
Published
โข
5
MPTK-1B๋ ํ๊ตญ์ด/์์ด์ฝ๋ ๋ฐ์ดํฐ์ ์์ ํ์ต๋ 1.3B ํ๋ผ๋ฏธํฐ์ decoder-only transformer ์ธ์ด๋ชจ๋ธ์ ๋๋ค.
์ด ๋ชจ๋ธ์ ๊ตฌ๊ธ์ TPU Research Cloud(TRC)๋ฅผ ํตํด ์ง์๋ฐ์ Cloud TPU๋ก ํ์ต๋์์ต๋๋ค.
๋ค๋ฅธ decoder-only transformer์์ ์ผ๋ถ ์์ ๋ ์ํคํ ์ฒ์ธ MPT๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํฉ๋๋ค.
| Hyperparameter | Value |
|---|---|
| n_parameters | 1.3B |
| n_layers | 24 |
| n_heads | 16 |
| d_model | 2048 |
| vocab size | 50432 |
| sequence length | 2048 |
fp16์ผ๋ก ์คํ ์ NaN์ด ๋ฐ์ํ ์ ์์ต๋๋ค. ๋ฐ๋ผ์ fp32 ํน์ bf16๋ก ์คํํ๊ธฐ๋ฅผ ๊ถ์ฅํฉ๋๋ค.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
tokenizer = AutoTokenizer.from_pretrained("team-lucid/mptk-1b")
model = AutoModelForCausalLM.from_pretrained("team-lucid/mptk-1b")
pipe = pipeline('text-generation', model=model, tokenizer=tokenizer, device='cuda:0')
with torch.autocast('cuda', dtype=torch.bfloat16):
print(
pipe(
'๋ํ๋ฏผ๊ตญ์ ์๋๋',
max_new_tokens=100,
do_sample=True,
)
)
OSCAR, mC4, wikipedia, namuwiki ๋ฑ ํ๊ตญ์ด ๋ฐ์ดํฐ์ RefinedWeb, The Stack ์์ ์ผ๋ถ๋ฅผ ์ถ๊ฐํด ํ์ตํ์์ต๋๋ค.
| Hyperparameter | Value |
|---|---|
| Precision | bfloat16 |
| Optimizer | Lion |
| Learning rate | 2e-4 |
| Batch size | 1024 |