Bf16 Training Instability with Llama-3.1-8B + LoRA/DoRA (PEFT)

When fine-tuning Llama-3.1-8B with LoRA or DoRA adapters using torch.bfloat16, training immediately produces grad_norm: nan and loss: inf or loss: 0.0 within the first few steps, even before the learning-rate warmup phase. Switching to fp32 eliminates the issue.

Code base: NVlabs/DoRA

Modification: replaced the original PEFT implementation with the official Hugging Face peft package.

Below is a minimal example extracted from our full script.

import torch
import transformers
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
    init_lora_weights=True,
)
model = get_peft_model(model, config)

trainer = transformers.Trainer(
    model=model,
    # train_dataset / data collator omitted from this excerpt
    args=transformers.TrainingArguments(
        learning_rate=3e-5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        warmup_ratio=0.06,
        bf16=True,
        fp16=False,
        optim="adamw_torch",
        adam_beta1=0.9,
        adam_beta2=0.999,
        adam_epsilon=1e-8,
        max_grad_norm=1.0,
        num_train_epochs=3,
        lr_scheduler_type="cosine",
        weight_decay=0.0,
        logging_steps=2,
        ddp_find_unused_parameters=False,
        report_to="wandb",
    ),
)
model = torch.compile(model)
trainer.train()

Observed Behavior

Training logs (first few steps):

{'loss': 41.5734, 'grad_norm': nan, 'learning_rate': 1.56e-08, 'epoch': 0.0}

{'loss': 0.0, 'grad_norm': nan, 'learning_rate': 4.69e-08, 'epoch': 0.0}

Symptoms:
• Occurs immediately (steps 0–1), before LR warmup could affect stability.
• Happens in both single-GPU (GH200) and multi-GPU settings.
• Reproducible regardless of gradient_checkpointing, torch.compile, or bf16_full_eval.
• No issue when switching to torch_dtype=torch.float32.

Environment:

Software:

python 3.12

torch 2.9.0+cu126

transformers 4.57.1

peft 0.17.1

Hardware: GH200 on aarch64 Linux

Did I make a configuration error somewhere? Any help solving this would be appreciated, thank you.


It doesn’t seem to be due to parameters or anything like that…


This is a bf16 forward-path failure in the attention stack, not a parameter or LR issue. It shows at step 0 because fused SDPA/Flash softmax overflows before any update. FP32 “fixes” it by avoiding that brittle path. Keep adapters in fp32, keep master weights in fp32, and pick a safe attention backend.

Causes

  1. BF16 SDPA/Flash fused kernels overflow on Hopper.
    SDPA auto-selects a fused backend. In bf16 this path can yield NaNs or 0.0 loss on the first batch. Switching to attn_implementation="eager" removes it; users report instant stability with the same data and model. PyTorch also documents backend-dependent numerics and that the math backend keeps bf16 intermediates in float32. (GitHub)

  2. Backend regressions in cuDNN SDPA.
    cuDNN release notes list SDPA fixes on Ampere/Hopper/Blackwell. Real-world runs stop exploding when disabling cuDNN SDPA. This matches your “step-0 NaN” fingerprint. (NVIDIA Docs)

  3. Hard-casting the whole model to bf16 reduces headroom.
    Proper mixed precision keeps parameters in fp32 and uses AMP for bf16 compute. Hard-casting everything to bf16 tightens QKᵀ and softmax numerics and increases step-0 failure risk. (PyTorch Docs)

  4. PEFT adapters are intentionally fp32.
    PEFT stores adapter weights in fp32 and upcasts inputs for stability, even if the base is bf16/fp16. Disabling this makes instability worse, and your swap to the official PEFT package is not the cause; a quick dtype check is sketched after this list. (Hugging Face)

  5. device_map="auto" is inference-only.
    It is not supported for training. It can interfere with gradients or masking and complicate debugging. Use single-device, DDP, FSDP, or ZeRO instead. (Hugging Face Forums)

  6. Llama-3.x has no PAD by default.
    Token id 0 is a real token (“!”) in common Llama-3 tokenizers. Using 0 as PAD can poison loss or produce odd generations. Add a dedicated PAD and resize embeddings once before wrapping LoRA/DoRA. (GitHub)

  7. torch.compile can amplify SDPA+bf16 issues.
    There are open reports of bf16 SDPA failing under torch.compile. Stabilize first, then re-enable compile. (GitHub)
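
To confirm causes 3 and 4 on your own run, here is a minimal sketch, assuming a model already wrapped with get_peft_model as in your snippet; it counts parameter dtypes separately for the LoRA adapters and the base weights.

from collections import Counter

def dtype_report(model):
    # Count parameter dtypes for LoRA adapter weights vs. base weights.
    adapter, base = Counter(), Counter()
    for name, param in model.named_parameters():
        bucket = adapter if "lora_" in name else base
        bucket[str(param.dtype)] += 1
    print("adapter params by dtype:", dict(adapter))
    print("base params by dtype:   ", dict(base))

dtype_report(model)  # after get_peft_model: expect a bf16 base and fp32 lora_A/lora_B

Seeing the base hard-cast to bf16 while only the adapters are fp32 is exactly the reduced-headroom situation from cause 3.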

Fixes that work

Apply top-down. Stop when training is stable.

A) Force a safe attention path and use fp32 masters

# URLs inside comments for traceability:
# - attn_implementation docs: https://huggingface.co/docs/transformers/en/attention_interface
# - sdpa math vs fused numerics: https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
# - mixed-precision policy rationale: see link above

import torch
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float32,        # fp32 master weights for stability
    attn_implementation="eager",      # safer backend; try FA2 after stack upgrade
    device_map=None,                  # no inference-style sharding for training
    trust_remote_code=True,
)

config = LoraConfig(
    r=32, lora_alpha=64, lora_dropout=0.05,
    target_modules=["q_proj","k_proj","v_proj","up_proj","down_proj"],
    bias="none", task_type="CAUSAL_LM", init_lora_weights=True,
)
model = get_peft_model(model, config)  # PEFT keeps adapters fp32 by default

args = TrainingArguments(
    learning_rate=3e-5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    warmup_ratio=0.06,
    bf16=True, fp16=False,            # bf16 compute via AMP, not by casting params
    optim="adamw_torch",
    adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-8,
    max_grad_norm=1.0,
    num_train_epochs=3,
    lr_scheduler_type="cosine",
    weight_decay=0.0,
    logging_steps=2,
    ddp_find_unused_parameters=False,
    report_to="wandb",
)
model.config.use_cache = False        # training-time best practice

trainer = Trainer(model=model, args=args, train_dataset=your_dataset)
# torch.compile later, after stability is confirmed
trainer.train()

Why this helps: eager avoids fused softmax in bf16; fp32 masters keep projections stable; AMP still uses bf16 GEMMs. (Hugging Face)
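
To see the AMP mechanism in isolation (a toy matmul, not the model): under torch.autocast, which is what Trainer's bf16=True uses, the weights stay fp32 while the GEMM runs in bf16.

import torch

w = torch.randn(4096, 4096, device="cuda")   # fp32 "weight" (master copy)
x = torch.randn(4, 4096, device="cuda")      # fp32 activations
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    y = x @ w                                # GEMM executes in bf16
print(w.dtype, y.dtype)                      # torch.float32 torch.bfloat16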

B) If you must stay on SDPA, fence kernels

# Backend toggles:
# https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
import torch
torch.backends.cuda.enable_cudnn_sdp(False)   # avoid cuDNN SDPA regressions
torch.backends.cuda.enable_flash_sdp(False)   # avoid Flash SDPA if needed
torch.backends.cuda.enable_mem_efficient_sdp(True)

Real-world users report that loss explosions stop after disabling cuDNN SDPA. (GitHub)

Or per-region:

# Context manager to select math/efficient kernels:
# https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html
from torch.nn.attention import sdpa_kernel, SDPBackend
with sdpa_kernel([SDPBackend.MATH]):
    # run one train/eval step here
    pass

Math backend keeps bf16 intermediates in float32. (PyTorch Docs)

C) Add a dedicated PAD and resize before PEFT

# Llama-3.x has no pad_token_id; id=0 is "!" in common tokenizer builds.
# https://github.com/turboderp/exllamav2/issues/415
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")
if tok.pad_token_id is None or tok.pad_token_id == 0:
    tok.add_special_tokens({"pad_token": "<|pad|>"})
    model.resize_token_embeddings(len(tok))
model.config.pad_token_id = tok.pad_token_id

Avoid training on real tokens as PAD. (GitHub)

D) Remove device_map="auto" from training

It is an inference feature. Use single-device, DDP, FSDP, or ZeRO instead. (Hugging Face Forums)
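
A minimal sketch of the training-friendly loading pattern, assuming a single GH200 or a standard torchrun launch and the same model id as above:

# Single GPU: load without device_map; Trainer moves the model to the GPU itself.
# Multi-GPU:  launch the same script with `torchrun --nproc_per_node=<N> train.py`
#             and Trainer wraps the model in DDP automatically.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.float32,
    attn_implementation="eager",
    # no device_map argument; placement is handled by Trainer or the launcher
)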

E) Upgrade path for FA2

If you want FA2 speed on GH200, upgrade to a stack that includes recent SDPA fixes (driver/cuDNN/PyTorch). Re-enable attn_implementation="flash_attention_2" and retest. If NaNs return, fall back to eager and keep the backend fences. (NVIDIA Docs)
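
A minimal retest sketch, assuming flash-attn is installed and the upgraded stack is in place; if NaNs return, drop back to eager and keep the backend fences.

import torch
from transformers import AutoModelForCausalLM

# FA2 kernels expect fp16/bf16 activations, so load in bf16 for this test;
# the fp32-master + AMP policy from section A remains the safer default.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)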

Quick diagnostics

  • Single-batch sanity under math/eager (sketched after this list). If the loss is finite, your failure was the fused attention path. The sdpa math backend explicitly raises intermediate precision. (PyTorch Docs)
  • Disable cuDNN SDPA only. If training stabilizes, you’ve isolated it to cuDNN’s SDPA implementation. (GitHub)
  • Confirm adapters are fp32. This is the PEFT default; keep it. (Hugging Face)
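
A minimal single-batch sanity check, assuming the model and tokenizer from the snippets above, the default sdpa attention implementation, and a single device:

import torch
from torch.nn.attention import sdpa_kernel, SDPBackend

# One forward/backward pass on a tiny batch under the MATH backend.
batch = tok("The quick brown fox jumps over the lazy dog.", return_tensors="pt").to(model.device)
batch["labels"] = batch["input_ids"].clone()

model.train()
with sdpa_kernel([SDPBackend.MATH]):
    loss = model(**batch).loss
    loss.backward()

print("loss finite:", torch.isfinite(loss).item())
print("grads finite:", all(torch.isfinite(p.grad).all().item()
                           for p in model.parameters() if p.grad is not None))
model.zero_grad(set_to_none=True)

If this passes but the same batch produces NaN without the context manager, the fused SDPA path is the culprit.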

Background and context

  • Backend sensitivity is expected. PyTorch states SDPA outputs differ by backend due to fused floating-point ops; the math backend is conservative and often the most stable for bf16 (a standalone probe is sketched after this list). (PyTorch Docs)
  • Community repros match your logs. Several issues document “first step OK, then grad_norm: nan, then loss: 0.0” under FA/SDPA, solved by attn_implementation="eager" or disabling cuDNN SDPA. (GitHub)
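
A standalone probe of that backend sensitivity, assuming only a CUDA device (the tensor shapes are arbitrary but roughly Llama-3-8B-like):

import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Push the same bf16 inputs through each SDPA backend and compare to math.
q, k, v = (torch.randn(1, 32, 512, 128, device="cuda", dtype=torch.bfloat16)
           for _ in range(3))

outs = {}
for name, backend in [("math", SDPBackend.MATH),
                      ("flash", SDPBackend.FLASH_ATTENTION),
                      ("mem_efficient", SDPBackend.EFFICIENT_ATTENTION),
                      ("cudnn", SDPBackend.CUDNN_ATTENTION)]:
    try:
        with sdpa_kernel([backend]):
            outs[name] = F.scaled_dot_product_attention(q, k, v, is_causal=True)
    except RuntimeError as err:
        print(f"{name}: unavailable ({err})")

ref = outs["math"].float()
for name, out in outs.items():
    print(name,
          "max |diff| vs math:", (out.float() - ref).abs().max().item(),
          "has nan:", torch.isnan(out).any().item())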

Curated references

Switching attention backends

  • Transformers attention interface and how to set eager/flash_attention_2. (Hugging Face)
  • PyTorch sdpa_kernel and backend selection. (PyTorch Docs)
  • SDPA numerics and math backend precision rules. (PyTorch Docs)

Instability reports and fixes

  • Training fails with FA2, succeeds with eager. (GitHub)
  • cuDNN SDPA loss explosions; fix by disabling cuDNN SDPA. (GitHub)
  • Flash-attention NaNs during training; disable or swap backend. (Hugging Face)

Tokenizer and padding

  • Llama-3 tokenizer: pad_token_id defaults to 0 which is “!”. Add a dedicated PAD and resize. (GitHub)

PEFT dtype behavior

  • Adapters kept in fp32 and upcast on-the-fly for stability. (Hugging Face)

Training architecture

  • device_map="auto" is for big-model inference; not supported for training. (Hugging Face Forums)
