SFTTrainer loss function and formatting_func

Hmm… Overfitting?


What you are seeing now (the loss drops fast, and the model starts outputting almost always the same tokens) is very typical of:

  1. A tiny dataset (≈700 examples)
  2. Very aggressive training (250 epochs)
  3. A short-answer classification task ((A)/(B))
  4. A large pretrained LM (Tx Gemma) with strong prior biases

Even if the masking is now correct and the loss is computed only on the completion, this setup is enough to cause severe overfitting and catastrophic forgetting, where the model memorizes a narrow pattern and “forgets” how to behave like a general model. This looks like “the model is destroyed” and “always produces the same tokens”, even though the training loss looks great. (arXiv)

Below I’ll split things into:

  1. What should happen with a prompt–completion dataset (masking)
  2. Likely causes of your collapse (even with correct masks)
  3. How to verify the masks concretely
  4. Concrete fixes for your setting (tiny hERG dataset + Tx Gemma)
  5. If it still outputs the same answer after fixing hyperparameters

1. What should happen with prompt–completion + SFTTrainer

In current TRL, if you:

  • provide a prompt–completion dataset (fields like {"prompt": "...", "completion": "..."}), and
  • set (or keep default) completion_only_loss=True in SFTConfig,

then the docs are very explicit:

“To train on completion only, use a prompt-completion dataset. By default, the trainer computes the loss on the completion tokens only, ignoring the prompt tokens. If you want to train on the full sequence, set completion_only_loss=False.” (Hugging Face)

So if:

  • your dataset really has prompt and completion fields,
  • you are not passing a formatting_func anymore, and
  • you have not set completion_only_loss=False,

then the mask should be:

  • labels = -100 (ignored) for all prompt tokens
  • labels = token_id for completion tokens only

This is the behaviour you originally wanted.
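
For reference, a minimal sketch of that setup (model_id and the example prompt/completion strings are placeholders, and argument names can shift slightly between TRL versions):

from datasets import Dataset
from trl import SFTConfig, SFTTrainer

model_id = "google/txgemma-2b-predict"   # placeholder; use whatever Tx Gemma checkpoint you already load

# prompt–completion format: no formatting_func, no manual masking
train_ds = Dataset.from_list([
    {"prompt": "…instructions + SMILES… Answer:", "completion": " (B)"},
    # ... remaining examples ...
])

training_args = SFTConfig(
    output_dir="txgemma-herg-lora",      # placeholder path
    completion_only_loss=True,           # already the default for prompt–completion data
)

trainer = SFTTrainer(
    model=model_id,
    args=training_args,
    train_dataset=train_ds,
)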

There are two common configuration pitfalls:

  1. Using formatting_func and prompt–completion together.
    In modern TRL this is considered incompatible with completion_only_loss=True: using a formatter converts the dataset into a plain language-modeling format. Current TRL versions raise an explicit error for this combination: “A formatting function was provided while completion_only_loss=True, which is incompatible…” (GitHub)

  2. Accidentally setting completion_only_loss=False.
    Then you do train on both prompt and completion, even with prompt–completion data.

If you’ve removed formatting_func and explicitly set completion_only_loss=True, you are probably masked correctly now. The fact that the model collapsed is then much more about overfitting and forgetting than about the prompt mask.


2. Likely causes of “model destroyed” after 250 epochs on 700 examples

Assuming the masking is now correct, there are three big causes to focus on.

2.1 Extreme overfitting and catastrophic forgetting

Fine-tuning LLMs on very small datasets with many epochs is exactly the setting where catastrophic forgetting shows up: the model rapidly adapts to the tiny dataset and overwrites useful general behaviours. Recent work on catastrophic forgetting in foundation models notes that overfitting to small fine-tuning sets is a primary cause, and that simply tuning longer or harder on small data pushes the model to forget its original capabilities. (arXiv)

For 700 examples:

  • 250 epochs means each example is seen 250 times.
  • If your effective batch size is small (say 4–16), that is tens of thousands of gradient updates on the same 700 samples (e.g. with an effective batch size of 8: 700 / 8 ≈ 88 steps per epoch × 250 epochs ≈ 22,000 updates).
  • With a typical LR for LoRA (e.g. 5e-5–2e-4), that is more than enough to drive the adapter to essentially memorize a narrow behaviour such as “when asked this hERG-style question, always answer X”.

Because your completions are extremely short (e.g. (A) or (B)), the model can reduce the loss significantly by:

  • Pushing logits for one of the tokens (say (B)) very high in general, so that the answer is essentially always (B).

The training loss goes down, but the model is no longer a good general language model and no longer sensitive to the input SMILES in a meaningful way.

This is overfitting + forgetting, not necessarily a masking bug.

2.2 Dataset structure: short labels, possible imbalance, strong prior bias

Your task is:

  • Binary (A/B)
  • Labels very short (a few tokens)
  • Prompt long and almost constant across examples
  • SMILES are varied but quite opaque to the LM a priori

Even with correct completion-only loss, if the dataset is:

  • Imbalanced (e.g. 75% “(B) inhibits hERG”, 25% “(A) does not”), or
  • Very small (700 examples) and noisy,

the cross-entropy optimum for a generative model might be something trivial like “always predict the majority label”. That behaviour gives very low loss on the training set and fits exactly what you see: “answers always the same tokens”.
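
A quick way to quantify this on your own data (assuming train_ds is your prompt–completion dataset) is to count the completions and compute the cross-entropy a constant "predict the prior" model would already achieve:

import math
from collections import Counter

counts = Counter(ex["completion"].strip() for ex in train_ds)
print(counts)                                    # e.g. Counter({'(B)': 525, '(A)': 175})

p = max(counts.values()) / sum(counts.values())  # majority-class fraction
baseline = -(p * math.log(p) + (1 - p) * math.log(1 - p))
print(f"constant-prior cross-entropy ≈ {baseline:.2f} nats")

The trainer's reported loss is averaged per token rather than per example, so the numbers won't match exactly, but this gives an order of magnitude for how "good" the trivial solution already is.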

On top of this, Tx Gemma itself has a documented positional bias in multiple-choice hERG prompts: tests have shown that it tends to favor the first option in (A)/(B)/0/1 style prompts almost regardless of semantics. (arXiv)

If your fine-tuning doesn’t provide a strong, diverse signal, the model can easily:

  • Keep or amplify that bias (e.g. “always A” or “always B”),
  • While still reducing loss, because the dataset does not strongly contradict that behaviour.

2.3 Hyperparameters (LR, LoRA rank, etc.)

With a tiny dataset, common hyperparameters that are harmless on larger datasets can be destructive:

  • Learning rate too high for LoRA (e.g. 1e-4–2e-4 or more)
  • LoRA rank too large, or applied to too many layers (so too many degrees of freedom)
  • No early stopping and training for fixed 250 epochs
  • No regularization (weight decay, dropout, etc.)

This combination means each small batch can substantially change the adapter weights, so the model quickly converges to a narrow, degenerate solution.

Even parameter-efficient methods like LoRA can still suffer from forgetting under such small-data regimes; analyses of PEFT methods explicitly point out that dataset size is often more critical than the exact adapter mechanism. (Obsidian)


3. How to verify definitively whether loss is only on the completion

You’re already planning to inspect the masks; that’s the right move. Do it once and you’ll remove all doubt.

3.1 Inspect a training batch

After instantiating your trainer (with prompt–completion, and completion_only_loss=True):

batch = next(iter(trainer.get_train_dataloader()))

for k, v in batch.items():
    print(k, v.shape)

print("input_ids:", batch["input_ids"][0])
print("labels:", batch["labels"][0])

Then:

  1. Decode the first sequence to see the text:

    print(tokenizer.decode(batch["input_ids"][0], skip_special_tokens=False))
    
  2. Manually locate the boundary between prompt and completion in that decoded text.

  3. Look at labels[0]:

    • All positions corresponding to the prompt should be -100.
    • All positions corresponding to the completion should be ≥ 0 (true token IDs).

If this is true, then masking is correct and the issue is not “loss on the entire prompt”, but overfitting / forgetting / dataset structure.

If you see labels ≥ 0 for the prompt tokens as well, then:

  • Check that your dataset really uses {"prompt", "completion"} fields.
  • Check that you are not passing a formatting_func anymore.
  • Check that completion_only_loss=True in SFTConfig. (Hugging Face)
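
Either way, a compact follow-up check is to decode only the tokens that actually contribute to the loss; for a correctly masked batch this should print just the completion (e.g. "(B)" plus an EOS token), never the instructions or the SMILES:

batch = next(iter(trainer.get_train_dataloader()))
input_ids, labels = batch["input_ids"][0], batch["labels"][0]

supervised = input_ids[labels != -100]           # only the positions the loss is computed on
print(tokenizer.decode(supervised, skip_special_tokens=False))
print(f"{(labels != -100).sum().item()} supervised tokens out of {labels.numel()}")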

4. Concrete solutions for your exact setup

Assuming you confirm the masks are correct, here is what I would change for Tx Gemma + hERG + 700 examples.

4.1 Drastically reduce training intensity

For 700 examples, something like:

  • Epochs: start with 1–3 epochs, not 250.
  • Batch size: as large as fits in memory (e.g. 16–64 effective batch via gradient accumulation).
  • Learning rate (LoRA): small, e.g. 5e-5 or even 1e-5.
  • Warmup ratio: 0.03–0.1, to avoid large updates early.
  • Max steps: consider capping total steps explicitly instead of epochs.

This is consistent with general recommendations for preventing catastrophic forgetting: use a smaller learning rate, some regularization, and avoid extensive over-training on tiny data. (Hugging Face Forums)
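
In SFTConfig terms, a reasonable starting point might look like this (values are illustrative, not tuned):

from trl import SFTConfig

training_args = SFTConfig(
    output_dir="txgemma-herg-lora",
    num_train_epochs=3,                  # instead of 250
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,       # effective batch size 32
    learning_rate=5e-5,                  # small LR for LoRA
    warmup_ratio=0.05,
    weight_decay=0.01,
    logging_steps=10,
    completion_only_loss=True,
)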

4.2 Make LoRA truly “small” and localized

If you’re not already doing so, constrain LoRA:

  • Use a small LoRA rank (e.g. r=4–8 instead of 16–64).
  • Apply LoRA only to a few later layers or attention projections (q_proj, v_proj, o_proj) rather than everything.
  • Keep the base model frozen.

This limits how much the adapter can distort the model’s behaviour and reduces forgetting. Parameter-efficient methods are exactly about updating a small subset to keep the base model stable. (SuperAnnotate)
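
With PEFT that could look like the following sketch (the target module names are the usual Gemma attention projections; double-check them against your checkpoint):

from peft import LoraConfig

peft_config = LoraConfig(
    r=8,                                 # small rank
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"], # a few attention projections only
    task_type="CAUSAL_LM",
)

# pass it straight to the trainer; the base weights stay frozen
trainer = SFTTrainer(
    model=model_id,
    args=training_args,
    train_dataset=train_ds,
    peft_config=peft_config,
)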

4.3 Add some regularization and early stopping

  • Use weight decay (e.g. 0.01).
  • Enable early stopping based on validation loss (or accuracy on a held-out subset).
  • Consider a small amount of dropout in the adapter layers if available.

This won’t fix an absurd 250-epoch regime, but it helps once you bring the epoch count down.
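
Early stopping is the standard transformers callback; it needs a held-out eval split plus a few extra flags in the config (argument names differ slightly across transformers versions, e.g. eval_strategy vs evaluation_strategy):

from transformers import EarlyStoppingCallback
from trl import SFTConfig, SFTTrainer

training_args = SFTConfig(
    output_dir="txgemma-herg-lora",
    num_train_epochs=10,                 # upper bound; early stopping cuts it short
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    weight_decay=0.01,
)

trainer = SFTTrainer(
    model=model_id,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,                # held-out subset of the 700 examples
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)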

4.4 Stabilize the classification formulation

For such a tiny dataset and very short labels, consider:

  1. Always use one fixed label surface form.
    E.g. always (A) for non-inhibitor, (B) for inhibitor, never switch them. This fights the known positional bias of Tx Gemma in multiple-choice prompts. (arXiv)

  2. Use log-prob scoring instead of free generation during evaluation.
    At inference time, for a given SMILES prompt:

    • Construct the same prompt you used for training, ending in "Answer:".
    • Compute log P("(A)" | prompt) and log P("(B)" | prompt) using the model (no sampling).
    • Predict the label with higher log-probability.

    This avoids generation randomness and focuses on the key logits for A/B (see the sketch after this list).

  3. Check label balance.
    If your dataset is heavily skewed (e.g. 80% “(B)”), then “always (B)” is actually a strong local optimum for cross-entropy. You may need to:

    • Downsample the majority class, or
    • Use a class-weighted loss (not trivial with SFTTrainer, but possible if you write a custom loss wrapper), or
    • Accept that with 700 examples you might be mostly learning a prior.
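
For point 2, a minimal scoring sketch (build_herg_prompt is a hypothetical helper that reproduces your training-time prompt, ending in "Answer:"; match the exact surface form of the options, including any leading space, to what is in your completions):

import torch

@torch.no_grad()
def option_logprob(model, tokenizer, prompt, option):
    """Sum of log-probabilities of the option tokens, conditioned on the prompt."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    option_ids = tokenizer(option, add_special_tokens=False, return_tensors="pt").input_ids.to(model.device)
    logits = model(torch.cat([prompt_ids, option_ids], dim=1)).logits
    # the logit at position t predicts token t+1, so start from the last prompt position
    logprobs = torch.log_softmax(logits[0, prompt_ids.shape[1] - 1 : -1], dim=-1)
    return logprobs.gather(1, option_ids[0].unsqueeze(1)).sum().item()

prompt = build_herg_prompt(smiles)                 # hypothetical helper, same template as training
scores = {opt: option_logprob(model, tokenizer, prompt, opt) for opt in ["(A)", "(B)"]}
print(scores, "->", max(scores, key=scores.get))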

4.5 Consider mixing in some “generic” text (optional)

Research on catastrophic forgetting in foundation models suggests that mixing some of the original pretraining or instruction data during fine-tuning helps preserve general capability. (arXiv)

Practically:

  • Build a mixed dataset: e.g. 70–80% generic Tx Gemma–style prompts, 20–30% your hERG classification prompts.
  • Train on this mixture with appropriate sampling.

This way, the model is gently nudged toward the hERG task without being forced to adapt solely to a tiny, narrow dataset.
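
With the datasets library this kind of mixture is a one-liner; generic_ds here is an assumed prompt–completion dataset of generic instruction-style examples:

from datasets import interleave_datasets

mixed_ds = interleave_datasets(
    [generic_ds, herg_ds],               # herg_ds: your 700 examples
    probabilities=[0.75, 0.25],          # ~75% generic, ~25% hERG
    seed=42,
    stopping_strategy="all_exhausted",
)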

4.6 If fine-tuning continues to be unstable, use Tx Gemma as a frozen feature extractor

Given only 700 samples, an alternative is:

  1. Use Tx Gemma to encode your full prompt (instructions + SMILES but no answer).
  2. Extract a fixed representation (e.g. mean-pooled last-layer hidden state).
  3. Train a small separate classifier (logistic regression, MLP) on top of these embeddings, instead of fine-tuning the LM at all.

This completely avoids catastrophic forgetting in the LM (it stays frozen), and the classifier is easy to train robustly on 700 examples.
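
A minimal sketch of that pipeline, assuming prompts (instruction + SMILES, no answer) and labels (0/1) are plain Python lists and the checkpoint id is a placeholder:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.linear_model import LogisticRegression

model_id = "google/txgemma-2b-predict"   # placeholder; use the checkpoint you already have
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, output_hidden_states=True, torch_dtype=torch.bfloat16, device_map="auto"
).eval()

@torch.no_grad()
def embed(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    hidden = model(**inputs).hidden_states[-1]                  # [1, seq_len, dim]
    return hidden.mean(dim=1).squeeze(0).float().cpu().numpy()  # mean-pooled last layer

X = [embed(p) for p in prompts]
clf = LogisticRegression(max_iter=1000, class_weight="balanced").fit(X, labels)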


5. If the model still outputs the same tokens after these changes

If, after:

  • verifying that masks are correct,
  • reducing epochs,
  • lowering LR,
  • tightening LoRA,

you still see “always the same answer”:

  1. Check your evaluation code:

    • Are you always sending exactly the same prompt by mistake?
    • Are you truncating the SMILES so they are identical?
    • Are you decoding only the first token (which might be (A) or (B) because of priors)?
  2. Check generation settings:

    • Use greedy decoding: do_sample=False (temperature and top_p are then ignored).
    • Limit max_new_tokens to something small (e.g. 4–8).
    • Ensure you start generation from the same "Answer:" ending as during training (see the sketch after this list).
  3. Drop the adapter and compare:

    • Save the base model separately.
    • Remove/disable the LoRA adapter.
    • Compare base vs fine-tuned outputs on the same SMILES.
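
For points 2 and 3, a small evaluation sketch (assuming the fine-tuned model is a PEFT-wrapped model, so the adapter can be toggled off with its context manager):

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
n_prompt = inputs["input_ids"].shape[1]

out = model.generate(**inputs, do_sample=False, max_new_tokens=8)          # greedy, short answer
print("fine-tuned:", tokenizer.decode(out[0, n_prompt:], skip_special_tokens=True))

with model.disable_adapter():                                              # temporarily run the frozen base weights
    out_base = model.generate(**inputs, do_sample=False, max_new_tokens=8)
print("base:", tokenizer.decode(out_base[0, n_prompt:], skip_special_tokens=True))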

If the base model behaves reasonably (or at least differently per SMILES) and the fine-tuned one is constant, you have confirmed that the collapse is purely a fine-tuning effect, not a bug in your eval pipeline.


Short summary

  • With a proper prompt–completion dataset and completion_only_loss=True, SFTTrainer does compute loss only on the completion tokens, not the prompt. (Hugging Face)
  • Your symptom (loss decreases, model outputs always the same tokens) is classic overfitting + catastrophic forgetting on a tiny dataset with very many epochs and short binary labels, amplified by Tx Gemma’s known positional bias on multiple-choice prompts. (arXiv)
  • First, verify the masks by printing a batch and checking that prompt tokens have labels -100.
  • Then, fix the training regime: drastically reduce epochs (1–3), lower LR, use small LoRA, add early stopping and slight regularization.
  • Stabilize the classification formulation (fixed label format, log-prob scoring, check label balance) and, if needed, mix in some generic data or instead train a small classifier on top of frozen Tx Gemma embeddings.