qwen3-norm-ablation

Normalization ablation on Qwen3-0.6B (~752M parameters). Seven architectural variants (a stock control plus six ablations) × two training stages (base = randomly initialized weights with the modified architecture; finetuned = the same architecture trained on a downstream corpus). The study probes which normalization layers (QK-norm, layer norms, all norms, alternative norms) are load-bearing.

Variants

Each variant ships both a base (`<variant>/`) and a finetuned (`<variant>_finetuned/`) checkpoint:

| Variant | What changed |
|---|---|
| `original` | Stock Qwen3-0.6B (RMSNorm everywhere); control |
| `no_qk_norm` | QK-norm removed |
| `no_layer_norm` | Pre/post-attention layer RMSNorm removed |
| `no_all_norm` | All RMSNorm layers removed |
| `replace_l2norm` | RMSNorm → L2Norm |
| `replace_layernorm` | RMSNorm → LayerNorm |
| `replace_scalenorm` | RMSNorm → ScaleNorm |
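For orientation, the alternative norms can be sketched as below. This is a generic sketch of the standard definitions (ScaleNorm as in Nguyen & Salazar, 2019), not the project's own code; the eps value and ScaleNorm's initialization are common choices, not taken from this repo.

```python
import torch
import torch.nn as nn

class L2Norm(nn.Module):
    """Parameter-free: rescales each vector to unit L2 norm."""
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        return x / x.norm(dim=-1, keepdim=True).clamp_min(self.eps)

class ScaleNorm(nn.Module):
    """One learned scalar g times the unit-normalized vector.
    g is initialized to sqrt(dim), the common choice from the paper."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.g = nn.Parameter(torch.tensor(float(dim) ** 0.5))
        self.eps = eps

    def forward(self, x):
        return self.g * x / x.norm(dim=-1, keepdim=True).clamp_min(self.eps)
```

Compared with RMSNorm's per-dimension gain vector, L2Norm learns nothing and ScaleNorm learns a single scalar, which is what makes them interesting ablation targets.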

The original model has 113 Qwen3RMSNorm modules; the ablations selectively remove or substitute these.
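That count can be reproduced from a checkpoint's state-dict keys without instantiating the model. The key patterns below assume Qwen3's module naming (per-layer `input_layernorm`, `post_attention_layernorm`, `q_norm`, `k_norm`, plus the final `model.norm`); for Qwen3-0.6B's 28 layers that gives 28 × 4 + 1 = 113 norm weights, matching the figure above.

```python
# Assumed Qwen3 norm-weight key patterns: per-layer input_layernorm,
# post_attention_layernorm, q_norm, k_norm, and the final model.norm.
NORM_MARKERS = ("layernorm", "q_norm", "k_norm", "model.norm")

def count_norm_weights(state_dict_keys):
    """Count norm parameters by key name; no model instantiation needed."""
    return sum(
        any(marker in key for marker in NORM_MARKERS)
        for key in state_dict_keys
    )
```

For example, `count_norm_weights(torch.load("original/model.pt", map_location="cpu").keys())` should report 113 for the control and fewer for the `no_*` variants.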

Files

- `<variant>/model.pt` – base (modified-arch, untrained) state dict, ~1.5 GB
- `<variant>/metadata.json` – variant name, base model id, norm counts, parameter count
- `<variant>_finetuned/model.pt` – finetuned state dict, ~9.9 GB (full precision)
- `<variant>_finetuned/checkpoint-9500/`, `checkpoint-9753/` – intermediate Trainer checkpoints
- `<variant>_finetuned/train_metrics.json` – per-step training metrics
- `<variant>_finetuned/{tokenizer.json,tokenizer_config.json,chat_template.jinja}` – tokenizer (shared across variants; also in `tokenizer/`)
- `all_train_metrics.json` – aggregated final metrics across all variants
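A quick integrity check over a downloaded variant directory, using only the file names from the layout above (treat the expected lists as a sketch and extend them as needed):

```python
from pathlib import Path

# File names taken from the layout above; extend as needed.
EXPECTED_BASE = ("model.pt", "metadata.json")
EXPECTED_FINETUNED = ("model.pt", "train_metrics.json", "tokenizer.json")

def missing_files(variant_dir, finetuned=False):
    """Return the expected files absent from a variant directory."""
    expected = EXPECTED_FINETUNED if finetuned else EXPECTED_BASE
    root = Path(variant_dir)
    return [name for name in expected if not (root / name).exists()]
```

Running `missing_files("no_qk_norm_finetuned", finetuned=True)` before loading saves a failed multi-gigabyte `torch.load`.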

Headline numbers

From all_train_metrics.json (3 epochs each, identical data):

| Variant | Final train loss |
|---|---|
| `original` | 1.331 |
| `no_qk_norm` | 6.567 |
| `no_layer_norm` | 0.000 (collapsed) |
| `no_all_norm` | 0.000 (collapsed) |

(A reported train_loss of exactly 0.0 means the run collapsed during training rather than converged; these runs are useful as negative results showing that those norms are required.)
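Collapsed runs can be flagged programmatically. The schema below (a mapping from variant name to a dict with a `train_loss` field) is an assumption about `all_train_metrics.json`, not a documented contract:

```python
def flag_collapsed(metrics):
    """Partition variants by whether the final train loss is usable.
    Treats exactly 0.0 (and NaN) as a collapsed run, matching how the
    table above reports failures. Schema is assumed:
    {"variant_name": {"train_loss": float}, ...}."""
    collapsed, healthy = [], []
    for variant, m in metrics.items():
        loss = m.get("train_loss", float("nan"))
        if loss == 0.0 or loss != loss:  # loss != loss catches NaN
            collapsed.append(variant)
        else:
            healthy.append(variant)
    return collapsed, healthy
```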

Usage

```python
import torch
from transformers import AutoModelForCausalLM

# Start from the stock architecture; for the ablated variants you must
# first apply the matching structural patch from the project (see below).
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")
state = torch.load("model.pt", map_location="cpu")
model.load_state_dict(state, strict=False)
```

The architectural changes are not in stock transformers: apply the variant's structural patch (dropping or replacing norm layers) before calling load_state_dict, since with strict=False mismatched keys are ignored rather than raised. See nanochat_adapter.py and the project's ablation harness for reference.
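A minimal sketch of the kind of structural patch involved, for the `no_*` variants: recursively swap every RMSNorm-like submodule for `nn.Identity` before loading the state dict. The project's own harness is authoritative; the class-name match below is an assumption.

```python
import torch
import torch.nn as nn

def strip_norms(model: nn.Module,
                match=lambda m: type(m).__name__.endswith("RMSNorm")):
    """Replace every matching submodule with nn.Identity(), in place.
    A sketch of the structural patch the ablations need; the match
    predicate (class name ending in "RMSNorm") is an assumption."""
    for name, child in model.named_children():
        if match(child):
            setattr(model, name, nn.Identity())
        else:
            strip_norms(child, match)
    return model
```

Usage would be `strip_norms(model)` followed by `model.load_state_dict(state, strict=False)`; for `no_qk_norm` you would narrow the predicate to only the `q_norm`/`k_norm` modules.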

Notes

This is research scaffolding, not a polished release. Finetuned checkpoints are full-precision FP32 (≈9.9 GB each); quantize before inference.
