# qwen3-norm-ablation
Normalization ablation on Qwen3-0.6B (~752M parameters). Eight architectural variants × two training stages (base = randomly initialized weights with the modified architecture; finetuned = the same architecture trained on a downstream corpus). The goal is to identify which normalization layers (QK-norm, layer norm, all norms, alternative norms) are load-bearing.
## Variants
Each variant has both a base (`*/`) and a finetuned (`*_finetuned/`) checkpoint:
| Variant | What changed |
|---|---|
| original | Stock Qwen3-0.6B (RMSNorm everywhere); the control |
| no_qk_norm | QK-norm removed |
| no_layer_norm | Pre/post layer RMSNorm removed |
| no_all_norm | All RMSNorm layers removed |
| replace_l2norm | RMSNorm → L2Norm |
| replace_layernorm | RMSNorm → LayerNorm |
| replace_scalenorm | RMSNorm → ScaleNorm |
The original model has 113 `Qwen3RMSNorm` modules; the ablations selectively remove or substitute these.
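For reference, here is a minimal sketch of how the norm count can be checked and an ablation applied. It assumes the stock transformers Qwen3 implementation (norm class named `Qwen3RMSNorm`, QK-norm submodules named `q_norm`/`k_norm`); the project's actual harness is authoritative:

```python
import torch.nn as nn
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

# Enumerate every RMSNorm module; the stock model should report 113.
norm_names = [name for name, mod in model.named_modules()
              if mod.__class__.__name__ == "Qwen3RMSNorm"]
print(len(norm_names))

# Example: the no_all_norm ablation swaps each of them for a pass-through
# Identity. For no_qk_norm, match only names ending in ".q_norm"/".k_norm".
for name in norm_names:
    parent_name, _, child = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child, nn.Identity())
```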
## Files
- `<variant>/model.pt`: base (modified-arch, untrained) state dict, ~1.5 GB
- `<variant>/metadata.json`: variant name, base model id, norm counts, parameter count
- `<variant>_finetuned/model.pt`: finetuned state dict, ~9.9 GB (full precision)
- `<variant>_finetuned/checkpoint-9500/`, `checkpoint-9753/`: intermediate Trainer checkpoints
- `<variant>_finetuned/train_metrics.json`: per-step training metrics
- `<variant>_finetuned/{tokenizer.json,tokenizer_config.json,chat_template.jinja}`: tokenizer (shared across variants, also in `tokenizer/`)
- `all_train_metrics.json`: aggregated final metrics across all variants
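A quick way to inspect a variant's metadata (the exact keys inside the file are not specified here, so the snippet just dumps the dict):

```python
import json

with open("original/metadata.json") as f:
    meta = json.load(f)

# Expected to contain the variant name, base model id, norm counts,
# and parameter count, per the file list above.
print(json.dumps(meta, indent=2))
```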
## Headline numbers
From `all_train_metrics.json` (3 epochs each, identical data):
| Variant | Final train loss |
|---|---|
| original | 1.331 |
| no_qk_norm | 6.567 |
| no_layer_norm | 0.000 (collapsed) |
| no_all_norm | 0.000 (collapsed) |
(Variants reporting train_loss = 0.0 collapsed during training; they remain useful as negative results showing those norms are required.)
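A sketch for reproducing this table from the aggregated file, assuming each variant's entry exposes a `train_loss` field (the schema is an assumption, not verified):

```python
import json

with open("all_train_metrics.json") as f:
    metrics = json.load(f)

# Assumed structure: {variant_name: {"train_loss": float, ...}, ...}
for variant, m in sorted(metrics.items(), key=lambda kv: kv[1]["train_loss"]):
    print(f"{variant:24s} {m['train_loss']:.3f}")
```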
## Usage
```python
import torch
from transformers import AutoModelForCausalLM

# Load the modified architecture (you need the matching arch code from the project)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
state = torch.load("model.pt", map_location="cpu")
model.load_state_dict(state, strict=False)
```
The architectural changes are not in stock transformers; you need to apply the variant's structural patch (drop or replace norm layers) before loading, as sketched below. See `nanochat_adapter.py` and the project's ablation harness for reference.
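Putting it together for one variant, a hedged end-to-end sketch: the no_qk_norm patch below assumes stock transformers Qwen3 module names (`model.model.layers[*].self_attn.{q_norm,k_norm}`); the project's own patch code is the reference implementation.

```python
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")

# Structural patch first: drop QK-norm, as in the no_qk_norm variant.
for layer in model.model.layers:
    layer.self_attn.q_norm = nn.Identity()
    layer.self_attn.k_norm = nn.Identity()

# Then load the matching checkpoint and sanity-check the key diff.
state = torch.load("no_qk_norm/model.pt", map_location="cpu")
missing, unexpected = model.load_state_dict(state, strict=False)
print("missing:", missing)
print("unexpected:", unexpected)
```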
## Notes
This is research scaffolding, not a polished release. Checkpoints are full-precision FP32 (~9.9 GB per finetuned checkpoint); quantize before inference.
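As a simple first step short of proper quantization, a bf16 downcast roughly halves the on-disk and in-memory footprint (the path below is illustrative):

```python
import torch

state = torch.load("original_finetuned/model.pt", map_location="cpu")
# Downcast floating-point tensors only; leave any integer buffers untouched.
state = {k: (v.to(torch.bfloat16) if v.is_floating_point() else v)
         for k, v in state.items()}
torch.save(state, "original_finetuned/model.bf16.pt")
```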