It doesn’t seem to be due to parameters or anything like that…
This is a bf16 forward-path failure in the attention stack, not a parameter or LR issue. It shows at step 0 because fused SDPA/Flash softmax overflows before any update. FP32 “fixes” it by avoiding that brittle path. Keep adapters in fp32, keep master weights in fp32, and pick a safe attention backend.
Causes
- BF16 SDPA/Flash fused kernels overflow on Hopper. SDPA auto-selects a fused backend, and in bf16 that path can yield NaNs or a 0.0 loss on the first batch. Switching to attn_implementation="eager" removes it; users report instant stability with the same data and model. PyTorch also documents backend-dependent numerics and that the math backend keeps bf16 intermediates in float32. (GitHub)
- Backend regressions in cuDNN SDPA. cuDNN release notes list SDPA fixes on Ampere/Hopper/Blackwell, and real-world runs stop exploding once cuDNN SDPA is disabled. This matches your "step-0 NaN" fingerprint. (NVIDIA Docs)
- Hard-casting the whole model to bf16 reduces headroom. Proper mixed precision keeps parameters in fp32 and uses AMP for bf16 compute. Hard-casting everything to bf16 tightens QK^T and softmax numerics and raises the risk of step-0 failure. (PyTorch Docs)
- PEFT adapters are intentionally fp32. PEFT stores adapter weights in fp32 and upcasts inputs for stability, even when the base model is bf16/fp16. Disabling this makes instability worse. Your swap to official PEFT is not the cause. (Hugging Face)
- device_map="auto" is inference-only. It is not supported for training, can interfere with gradients or masking, and complicates debugging. Use single-device, DDP, FSDP, or ZeRO instead. (Hugging Face Forums)
- Llama-3.x has no PAD token by default. Token id 0 is a real token ("!") in common Llama-3 tokenizers, so using 0 as PAD can poison the loss or produce odd generations. Add a dedicated PAD token and resize embeddings once, before wrapping with LoRA/DoRA. (GitHub)
- torch.compile can amplify SDPA+bf16 issues. There are open reports of bf16 SDPA failing under torch.compile. Stabilize first, then re-enable compile (see the sketch after this list). (GitHub)
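A simple way to keep compile out of the picture while you stabilize is to gate it behind a switch (a sketch; the ENABLE_COMPILE flag is hypothetical, and `model` refers to the model built in section A below):
```python
import os
import torch

# Hypothetical switch: keep the first runs on the uncompiled path, flip it on later.
if os.environ.get("ENABLE_COMPILE", "0") == "1":
    model = torch.compile(model)  # re-enable only after bf16 + eager training is confirmed stable
```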
Fixes that work
Apply top-down. Stop when training is stable.
A) Force a safe attention path and use fp32 masters
# URLs inside comments for traceability:
# - attn_implementation docs: https://huggingface.co/docs/transformers/en/attention_interface
# - sdpa math vs fused numerics: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# - mixed-precision policy rationale: see link above
import torch
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float32,    # fp32 master weights for stability
    attn_implementation="eager",  # safer backend; try FA2 after stack upgrade
    device_map=None,              # no inference-style sharding for training
    trust_remote_code=True,
)

config = LoraConfig(
    r=32, lora_alpha=64, lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "up_proj", "down_proj"],
    bias="none", task_type="CAUSAL_LM", init_lora_weights=True,
)
model = get_peft_model(model, config)  # PEFT keeps adapters fp32 by default

args = TrainingArguments(
    output_dir="outputs",         # older transformers versions require output_dir explicitly
    learning_rate=3e-5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    warmup_ratio=0.06,
    bf16=True, fp16=False,        # bf16 compute via AMP, not by casting params
    optim="adamw_torch",
    adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-8,
    max_grad_norm=1.0,
    num_train_epochs=3,
    lr_scheduler_type="cosine",
    weight_decay=0.0,
    logging_steps=2,
    ddp_find_unused_parameters=False,
    report_to="wandb",
)
model.config.use_cache = False # training-time best practice
trainer = Trainer(model=model, args=args, train_dataset=your_dataset)
# torch.compile later, after stability is confirmed
trainer.train()
Why this helps: eager avoids fused softmax in bf16; fp32 masters keep projections stable; AMP still uses bf16 GEMMs. (Hugging Face)
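A quick check that the safe path is actually active before committing to a long run (a sketch; `_attn_implementation` is a transformers-internal config attribute and may change between versions):
```python
# Confirm the attention backend and that only the LoRA adapters are trainable.
print(model.config._attn_implementation)  # expect "eager" (internal attribute, version-dependent)
model.print_trainable_parameters()        # PEFT helper: reports trainable vs. total parameters
```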
B) If you must stay on SDPA, fence kernels
# Backend toggles:
# https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
import torch
torch.backends.cuda.enable_cudnn_sdp(False) # avoid cuDNN SDPA regressions
torch.backends.cuda.enable_flash_sdp(False) # avoid Flash SDPA if needed
torch.backends.cuda.enable_mem_efficient_sdp(True)
Real users stop loss explosions after disabling cuDNN SDPA. (GitHub)
Or per-region:
# Context manager to select math/efficient kernels:
# https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html
from torch.nn.attention import sdpa_kernel, SDPBackend
with sdpa_kernel([SDPBackend.MATH]):
    # run one train/eval step here
    pass
Math backend keeps bf16 intermediates in float32. (PyTorch Docs)
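The same fence can cover a whole run rather than a single step; the context manager simply wraps the training call (assuming the `trainer` built in section A):
```python
from torch.nn.attention import sdpa_kernel, SDPBackend

# Every SDPA call inside this scope uses the math backend, which keeps bf16 intermediates in fp32.
with sdpa_kernel([SDPBackend.MATH]):
    trainer.train()
```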
C) Add a dedicated PAD and resize before PEFT
# Llama-3.x has no pad_token_id; id=0 is "!" in common tokenizer builds.
# https://github.com/turboderp/exllamav2/issues/415
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
if tok.pad_token_id is None or tok.pad_token_id == 0:
    tok.add_special_tokens({"pad_token": "<|pad|>"})

# Resize on the base model before wrapping with get_peft_model, then record the new PAD id.
model.resize_token_embeddings(len(tok))
model.config.pad_token_id = tok.pad_token_id
Avoid training on real tokens as PAD. (GitHub)
D) Remove device_map="auto" from training
It is an inference feature. Use single-device, DDP, FSDP, or ZeRO instead. (Hugging Face Forums)
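A minimal sketch of the replacement pattern, assuming a single node with multiple GPUs launched via torchrun (Trainer handles the DDP wrapping itself):
```python
# Launch with: torchrun --nproc_per_node=4 train.py
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float32,
    attn_implementation="eager",
    # no device_map: each rank holds a full replica and Trainer moves it to that rank's GPU
)
```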
E) Upgrade path for FA2
If you want FA2 speed on GH200, upgrade to a stack that includes recent SDPA fixes (driver/cuDNN/PyTorch). Re-enable attn_implementation="flash_attention_2" and retest. If NaNs return, fall back to eager and keep the backend fences. (NVIDIA Docs)
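After the upgrade, the retest is a one-argument change (a sketch; requires the flash-attn package built for your CUDA/PyTorch stack):
```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float32,                # transformers will warn that FA2 expects fp16/bf16;
    attn_implementation="flash_attention_2",  # under AMP (bf16=True) the attention inputs are bf16
)
# If grad_norm goes NaN at step 0 again, fall back to attn_implementation="eager"
# and keep the backend fences from section B.
```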
Quick diagnostics
- Single-batch sanity under math/eager. If the loss is finite, your failure was the fused attention path; the SDPA math backend explicitly raises intermediate precision (see the combined sketch after this list). (PyTorch Docs)
- Disable cuDNN SDPA only. If training stabilizes, you’ve isolated it to cuDNN’s SDPA implementation. (GitHub)
- Confirm adapters are fp32. This is the PEFT default; keep it. (Hugging Face)
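The three checks above as one combined sketch (assumes the model and trainer from section A; the collator must supply labels, and the sdpa_kernel fence is a no-op if the model already runs the eager backend):
```python
import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

# 1) Single-batch sanity under the math backend.
batch = next(iter(trainer.get_train_dataloader()))
batch = {k: v.to(model.device) for k, v in batch.items()}
with torch.no_grad(), sdpa_kernel([SDPBackend.MATH]):
    out = model(**batch)
print("loss is finite:", torch.isfinite(out.loss).item())

# 2) Disable cuDNN SDPA only, then rerun the same step to isolate that backend.
torch.backends.cuda.enable_cudnn_sdp(False)

# 3) Confirm the LoRA adapters are fp32 (the PEFT default).
for name, param in model.named_parameters():
    if "lora_" in name:
        assert param.dtype == torch.float32, f"{name} is {param.dtype}"
```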
Background and context
- Backend sensitivity is expected. PyTorch states SDPA outputs differ by backend due to fused floating-point ops. The math backend is conservative and often the most stable for bf16. (PyTorch Docs)
- Community repros match your logs. Several issues document "first step OK, then grad_norm: nan, then loss: 0.0" under FA/SDPA, solved by attn_implementation="eager" or disabling cuDNN SDPA. (GitHub)
Curated references
Switching attention backends
- Transformers attention interface and how to set eager/flash_attention_2. (Hugging Face)
- PyTorch sdpa_kernel and backend selection. (PyTorch Docs)
- SDPA numerics and math backend precision rules. (PyTorch Docs)
Instability reports and fixes
- Training fails with FA2, succeeds with eager. (GitHub)
- cuDNN SDPA loss explosions; fix by disabling cuDNN SDPA. (GitHub)
- Flash-attention NaNs during training; disable or swap backend. (Hugging Face)
Tokenizer and padding
- Llama-3 tokenizer: pad_token_id is unset or defaults to 0, which is "!"; add a dedicated PAD and resize. (GitHub)
PEFT dtype behavior
- Adapters kept in fp32 and upcast on-the-fly for stability. (Hugging Face)
Training architecture
- device_map="auto" is for big-model inference; not supported for training. (Hugging Face Forums)