Dequantize a 4-bit bitsandbytes (B&B) model to prepare for merging

Hello!

I trained a LoRA adapter for a 13B model in a quantized setup: a 16-bit LoRA adapter trained on top of a 4-bit quantized base model.

Now that training is done, I would like to put my model back together, i.e., merge the LoRA weights into the base model.

Unfortunately, if I do merge_and_unload right away, the model’s outputs become complete garbage. I suppose this is because the LoRA weights get converted to 4 bits and then added to the base weights.
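To illustrate how re-quantizing after a merge can wipe out a small LoRA update, here is a toy sketch. It uses a simplified symmetric 4-bit absmax quantizer over a single block (not bitsandbytes' NF4 with double quantization), and the numbers are made up; it only demonstrates the rounding effect and does not prove this is what happens in your model. If the base stays in 4-bit, a merge has to dequantize, add the update, and quantize again:

```python
def quantize_4bit_absmax(block):
    """Toy symmetric 4-bit quantizer: one absmax scale per block, integer codes in [-7, 7]."""
    scale = max(abs(x) for x in block) or 1.0
    codes = [round(x / scale * 7) for x in block]
    return codes, scale


def dequantize_4bit_absmax(codes, scale):
    """Map integer codes back to floats using the stored block scale."""
    return [c / 7 * scale for c in codes]


# Made-up base weights and a small, made-up LoRA delta (scaling * B @ A) for one block.
base = [0.4, -0.2, 0.1, -1.0]
delta = [0.01, -0.01, 0.005, 0.0]

base_codes, base_scale = quantize_4bit_absmax(base)

# "Merge" into the quantized weight: dequantize, add the delta, re-quantize.
merged_float = [w + d for w, d in zip(dequantize_4bit_absmax(base_codes, base_scale), delta)]
merged_codes, merged_scale = quantize_4bit_absmax(merged_float)

print(base_codes, merged_codes)  # [3, -1, 1, -7] [3, -1, 1, -7]: the delta was rounded away
```

Real NF4 uses non-uniform levels and 64-element blocks, so the exact threshold differs, but the principle is the same: an update smaller than half a quantization step disappears.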

Therefore, I thought it would be smarter to dequantize the model first and only then merge, but when I call .dequantize() on my 4-bit base + 16-bit LoRA model, I quickly get OOM errors. At the same time, dequantizing on the CPU is not possible, as this is only implemented for 8-bit quantization.

Is there any way out of this stalemate? To sum up:

  1. I can’t merge_and_unload → garbage output
  2. I can’t dequantize and merge_and_unload → OOM on dequantize
  3. I can’t dequantize on CPU → not supported for 4 bit

I have also tried loading the base model in fp16 and then applying my trained LoRA weights to it, but done this way, the LoRA doesn’t seem to affect the output at all. It seems the scale of the weights is not compatible when the original model is loaded in fp16.
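For reference, the standard LoRA formulation (assuming PEFT's default scaling, i.e. use_rslora disabled) adds a low-rank update whose scale comes from the adapter config, not from the base model's dtype:

```latex
h = W x + \frac{\alpha}{r} B A x,
\qquad
W_{\text{merged}} = W + \frac{\alpha}{r} B A,
\qquad
W \in \mathbb{R}^{d_{\text{out}} \times d_{\text{in}}},\;
B \in \mathbb{R}^{d_{\text{out}} \times r},\;
A \in \mathbb{R}^{r \times d_{\text{in}}}
```

With r = 8 and lora_alpha = 8, as in the training setup later in this thread, the factor alpha/r is 1, so the update is just BA. Since nothing in this term depends on whether W is stored in NF4 or fp16, a LoRA with no visible effect on an fp16 base more likely points to adapter tensors that were not actually loaded. PEFT initializes lora_B to zeros (including with init_lora_weights="gaussian"), so an adapter whose trained B never got loaded contributes exactly nothing.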

Is there any way to still salvage my model?


Yeah. I’m fairly certain that merging while keeping the weights in NF4 isn’t feasible.

I have tried loading the base model in fp16 and then applying my trained LoRA weights to it

That approach should work best. Even when a LoRA is applied to different weights of the same architecture, it usually still has a noticeable effect. So if the impact drops to zero, the cause most likely lies elsewhere.

from transformers import AutoModelForCausalLM
from peft import PeftModel
import torch

base_id    = "your/base-model"
adapter_id = "path-or-hub-id-of-your-lora"

# Load the base model unquantized (fp16), then attach the LoRA adapter on top of it
base = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.float16, device_map="auto")
peft_model = PeftModel.from_pretrained(base, adapter_id)
# safe_merge=True checks the merged weights for NaNs before applying the merge
merged = peft_model.merge_and_unload(safe_merge=True)
merged.save_pretrained("merged-fp16", safe_serialization=True)

That method should be the best.

The issue with this method is that I lose the benefit of fine-tuning: my model’s output is the same as if it had not been tuned at all.

I cannot use the PEFT classes to load the weights automatically. I train with DeepSpeed ZeRO-3, so after the fp32 conversion, all I have are the .safetensors files.

I set up my model like this:

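import torch
from transformers import AutoProcessor, BitsAndBytesConfig, LlavaNextForConditionalGeneration
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

def load_model_and_wrap(model_id: str, r=8, alpha=8):

    # 4-bit NF4 base weights with double quantization; computation in fp16
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_storage=torch.uint8,
        bnb_4bit_use_double_quant=True,
    )

    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    tokenizer = processor.tokenizer
    model = LlavaNextForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        quantization_config=quantization_config,
        low_cpu_mem_usage=True,
    )

    model = prepare_model_for_kbit_training(model)
    # find_all_linear_names: user-defined helper (not shown) returning the linear-layer names
    target_modules = find_all_linear_names(model)

    lora_config = LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=0.1,
        target_modules=target_modules,
        init_lora_weights="gaussian",
    )
    # Attach the LoRA adapters; the 4-bit base weights stay frozen
    model = get_peft_model(model, lora_config)

    return model, processor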

But then, to load the checkpoint safetensors after training, I need to do this manually:

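import os
from safetensors.torch import load_file

def load_sharded_safetensors(dir_path):
    # Merge every *.safetensors shard in dir_path into a single state dict
    state_dict = {}
    for filename in sorted(os.listdir(dir_path)):
        if filename.endswith(".safetensors"):
            filepath = os.path.join(dir_path, filename)
            shard = load_file(filepath)
            state_dict.update(shard)
    return state_dict

def load_weights(model, state_dict):
    # Strip the leading "model." prefix so the keys match model.state_dict()
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("model."):
            new_key = key[len("model.") :]
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value

    # strict=False silently skips unexpected keys, so report both lists
    missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    print(f"missing: {len(missing)}, unexpected: {len(unexpected)}")
    return model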
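One small pitfall in load_sharded_safetensors: dict.update silently overwrites a key that appears in more than one shard (for example, if a consolidated file sits next to the shards). A minimal sketch of a stricter merge over already-loaded shard dicts; the name merge_shards is hypothetical:

```python
def merge_shards(shards):
    """Merge a sequence of (filename, state_dict) pairs, refusing duplicate keys.

    Hypothetical helper; `shards` would typically be built as
    [(fn, load_file(os.path.join(dir_path, fn))) for fn in sorted(filenames)].
    """
    merged, origin = {}, {}
    for filename, shard in shards:
        for key, value in shard.items():
            if key in merged:
                raise ValueError(f"{key!r} appears in both {origin[key]} and {filename}")
            merged[key] = value
            origin[key] = filename
    return merged
```

It works on any mapping, so it can be checked with plain dicts before pointing it at real safetensors shards.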

Finally I do:

model_base, processor = load_model_and_wrap(model_id)
checkpoint = load_sharded_safetensors(os.path.join(weights_path, "checkpoint_fp32"))
model = load_weights(model_base, checkpoint)

What I find strange is that the model works perfectly fine if used directly with unmerged LoRA, but as soon as I try to load my LoRA on top of a fp16 base model, all effects of fine-tuning are instantly lost. Could this be a result of how I restore the LoRA weights?
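Since ZeRO-3 leaves only full-model .safetensors files, one possible route back to PeftModel.from_pretrained is to rewrite the LoRA tensors into the key layout PEFT uses in saved adapter files. A minimal sketch of the key renaming only, assuming the checkpoint keys look like model.base_model.model.<module path>.lora_A.default.weight (the model. prefix being the one stripped in load_weights above) and that saved adapters use base_model.model.<module path>.lora_A.weight without the adapter name. Verify both against your actual files; to_peft_adapter_keys is a hypothetical name:

```python
import re

# Matches a LoRA tensor name that still carries the adapter name, e.g.
# "...q_proj.lora_A.default.weight" -> groups ("lora_A", "default")
_LORA_KEY = re.compile(r"\.(lora_[AB])\.([^.]+)\.weight$")


def to_peft_adapter_keys(keys, adapter_name="default", engine_prefix="model."):
    """Hypothetical helper: map full-model checkpoint keys to saved-adapter key names.

    Returns {old_key: new_key} for the LoRA A/B tensors of `adapter_name` only;
    base weights, base_layer copies and other adapters are left out.
    """
    mapping = {}
    for key in keys:
        new_key = key[len(engine_prefix):] if key.startswith(engine_prefix) else key
        match = _LORA_KEY.search(new_key)
        if match is None or match.group(2) != adapter_name:
            continue
        # Drop the adapter name: ".lora_A.default.weight" -> ".lora_A.weight"
        mapping[key] = new_key[: match.start()] + f".{match.group(1)}.weight"
    return mapping
```

Untested idea for the rest: save the renamed tensors with safetensors.torch.save_file as adapter_model.safetensors, write adapter_config.json by calling save_pretrained on the same LoraConfig (r, lora_alpha, target_modules), and load the directory with PeftModel.from_pretrained on the fp16 base. The module paths must match the fp16 model's named_modules().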


What I find strange is that the model works perfectly fine if used directly with unmerged LoRA, but as soon as I try to load my LoRA on top of a fp16 base model, all effects of fine-tuning are instantly lost. Could this be a result of how I restore the LoRA weights?

I think PEFT implicitly adjusts the key names when loading a LoRA (for example, by inserting the adapter name, such as default). If that adjustment is missing and strict=False is set, some of the LoRA’s effect may be lost: the LoRA keys no longer match the base model’s keys, and with strict=False, the tensors under mismatched keys are silently discarded.
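A quick way to see what strict=False would discard is to diff the key sets before loading. A minimal sketch (the helper name is hypothetical) that mirrors what load_state_dict reports: missing are model keys absent from the checkpoint, unexpected are checkpoint keys with no counterpart in the model. The unexpected ones are exactly the tensors strict=False skips, so they are worth checking alongside missing:

```python
def diff_state_dict_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) as sorted lists, like nn.Module.load_state_dict reports.

    missing:    keys the model expects but the checkpoint does not provide
    unexpected: checkpoint keys with no matching model parameter/buffer
                (silently skipped when strict=False)
    """
    ckpt, model = set(checkpoint_keys), set(model_keys)
    return sorted(model - ckpt), sorted(ckpt - model)


# Example: the adapter name is missing from the checkpoint's LoRA key.
model_keys = ["layer.q_proj.base_layer.weight", "layer.q_proj.lora_A.default.weight"]
ckpt_keys = ["layer.q_proj.base_layer.weight", "layer.q_proj.lora_A.weight"]
missing, unexpected = diff_state_dict_keys(ckpt_keys, model_keys)
print(missing)     # ['layer.q_proj.lora_A.default.weight']
print(unexpected)  # ['layer.q_proj.lora_A.weight']
```

In practice you would pass new_state_dict.keys() and model.state_dict().keys(), then filter both lists for "lora_" to focus on the adapter tensors.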

# pip install -U peft transformers
import os, re
from safetensors.torch import load_file
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from transformers import LlavaNextForConditionalGeneration, AutoProcessor

_ADAPTER = "default"

# 1) read all shards
def load_sharded_safetensors(dir_path: str):
    sd = {}
    for fn in sorted(os.listdir(dir_path)):
        if fn.endswith(".safetensors"):
            sd.update(load_file(os.path.join(dir_path, fn)))
    return sd

# 2) normalize + keep: LoRA + modules_to_save; drop: base_layer
_LORA_PAT = re.compile(
    r"(lora_(?:A|B|up|down|embedding_A|embedding_B)\.weight$|lora_magnitude_vector$)"
)
def _needs_prefix(k: str) -> bool:
    return not (k.startswith("base_model.") or k.startswith("peft_prefix."))

def build_adapter_state_dict(raw_sd: dict, adapter_name: str = _ADAPTER):
    keep = {}
    for k, v in raw_sd.items():
        kk = k[6:] if k.startswith("model.") else k  # optional engine prefix

        # keep trained copies: modules_to_save.<adapter>.*
        if ".modules_to_save." in kk:
            if _needs_prefix(kk):
                kk = "base_model.model." + kk
            keep[kk] = v
            continue

        # keep LoRA tensors
        if _LORA_PAT.search(kk):
            if _needs_prefix(kk):
                kk = "base_model.model." + kk
            kk = re.sub(r"(lora_(?:A|B|up|down|embedding_A|embedding_B))\.weight$",
                        rf"\1.{adapter_name}.weight", kk)
            kk = re.sub(r"(lora_magnitude_vector)$", rf"\1.{adapter_name}", kk)
            keep[kk] = v
            continue

        # drop frozen mirrors of the base module
        if ".base_layer." in kk:
            continue

    return keep

# 3) base model in fp16 (no quantization, no k-bit prep)
def load_base_fp16(model_id: str):
    processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
    model = LlavaNextForConditionalGeneration.from_pretrained(
        model_id, torch_dtype="auto", device_map="auto"
    )
    return model, processor

# 4) inject from state_dict, then load weights
def load_lora_with_inject(model_id: str, ckpt_dir: str,
                          r: int = 8, alpha: int = 8, dropout: float = 0.1,
                          adapter_name: str = _ADAPTER):
    base, processor = load_base_fp16(model_id)
    raw = load_sharded_safetensors(ckpt_dir)
    adp_sd = build_adapter_state_dict(raw, adapter_name)
    if not adp_sd:
        raise ValueError("No LoRA/modules_to_save tensors found in checkpoint.")

    conf = LoraConfig(r=r, lora_alpha=alpha, lora_dropout=dropout, init_lora_weights=False)

    # create adapter modules, using the checkpoint to infer targets
    model = inject_adapter_in_model(conf, base, state_dict=adp_sd,
                                    adapter_name=adapter_name, low_cpu_mem_usage=True)

    # populate tensors; check mapping
    outcome = set_peft_model_state_dict(model, adp_sd, adapter_name=adapter_name,
                                        low_cpu_mem_usage=True)
    if outcome.missing_keys or outcome.unexpected_keys:
        raise RuntimeError(f"Adapter load mismatch. "
                           f"missing={outcome.missing_keys[:10]} "
                           f"unexpected={outcome.unexpected_keys[:10]}")

    model.eval()
    return model, processor

# usage:
# ckpt_dir = os.path.join(weights_path, "checkpoint_fp32")
# model, processor = load_lora_with_inject(model_id, ckpt_dir, r=8, alpha=8, dropout=0.1)

Edit:
Fixed to handle modules_to_save.

Thank you for your explanation, John. I was aware of this strict=False issue, which is why I verified each time that all weights were indeed loaded by printing the number of missing keys (which was 0).

I think I have gotten a bit further in understanding the problem. It seems that I need to include the “*.base_layer.weight” weights from the LoRA checkpoint in my state_dict as well; if I don’t, the model becomes unusable.

To put it simply:

def load_weights(model, state_dict):
    dct = dict(model.named_modules())
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith("model"):
            new_key = key[len("model.") :]
            new_state_dict[new_key] = value 
        else:
            new_state_dict[key] = value 
    # proceed with load_state_dict [...]

This works correctly and produces a valid model (of course, merge_and_unload still doesn’t work).

But changing the above to:

def load_weights(model, state_dict):
    dct = dict(model.named_modules())
    # Snapshot the model's current tensors once instead of once per key
    model_sd = model.state_dict()
    new_state_dict = {}
    for key, value in state_dict.items():
        if "base_layer.weight" in key:
            # Compute the stripped key first, then compare checkpoint vs. model tensor
            new_key = key[len("model.") :]
            model_tensor = model_sd[new_key].detach().clone()
            compare_tensors(value, model_tensor, key)
            # Use the model's own base-layer tensor instead of the checkpoint's
            new_state_dict[new_key] = model_tensor
        elif key.startswith("model."):
            new_key = key[len("model.") :]
            new_state_dict[new_key] = value
        else:
            new_state_dict[key] = value
    # proceed with load_state_dict [...]

This makes the model output faulty. compare_tensors never prints anything, so the tensors appear to be identical:

import torch

def compare_tensors(t1, t2, name):
    # Metadata checks; values are only comparable if the shapes agree
    if t1.shape != t2.shape:
        print(f"{name}: Shape mismatch {t1.shape} - {t2.shape}")
        return
    if type(t1) != type(t2):
        print(f"{name}: Type mismatch {type(t1)} - {type(t2)}")
    if t1.dtype != t2.dtype:
        print(f"{name}: Dtype mismatch {t1.dtype} - {t2.dtype}")

    # Bring both tensors onto the same device before element-wise comparison
    t2 = t2.to(t1.device)

    nan_mask_1 = torch.isnan(t1)
    nan_mask_2 = torch.isnan(t2)
    if (nan_mask_1 != nan_mask_2).any():
        n1, n2 = nan_mask_1.sum().item(), nan_mask_2.sum().item()
        print(f"{name}: Different NaN counts -> {n1} vs {n2}")
    mask = ~(nan_mask_1 | nan_mask_2)
    # Compare non-NaN elements exactly
    if mask.any():
        a, b = t1[mask], t2[mask]
        mismatched = a != b
        if mismatched.any():
            diff_count = mismatched.sum().item()
            # Cast to float64 so the difference is meaningful for integer/packed dtypes too
            max_diff = (a.double() - b.double()).abs().max().item()
            print(f"{name}: {diff_count} mismatched elements, max diff={max_diff}")
    else:
        print(f"{name}: Only NaNs (no numeric values)")
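One detail in the second load_weights that may be worth ruling out: "base_layer.weight" in key is a substring test, so it also matches any longer key that merely contains that text. For bitsandbytes 4-bit layers, serialized state dicts may carry quantization-state entries stored under the weight's own prefix (for example names ending in .weight.absmax or .weight.quant_state...; treat the exact names as an assumption and check your own checkpoint). A minimal sketch that separates exact matches from such companion keys; the helper name is hypothetical:

```python
def split_base_layer_keys(keys, suffix="base_layer.weight"):
    """Split keys into exact '<...>.base_layer.weight' names and longer keys that only contain it."""
    exact, companions = [], []
    for key in keys:
        if key.endswith(suffix):
            exact.append(key)
        elif suffix + "." in key:
            companions.append(key)
    return exact, companions


keys = [
    "model.layers.0.q_proj.base_layer.weight",
    "model.layers.0.q_proj.base_layer.weight.absmax",  # assumed quant-state style key
    "model.layers.0.q_proj.lora_A.default.weight",
]
exact, companions = split_base_layer_keys(keys)
print(exact)       # ['model.layers.0.q_proj.base_layer.weight']
print(companions)  # ['model.layers.0.q_proj.base_layer.weight.absmax']
```

If companions is non-empty for the real checkpoint, the substring branch also replaced those entries, so the experiment touched more than the weight tensors themselves.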

What I don’t understand is why sourcing the base-layer weights from the model itself produces a nonfunctional model, while sourcing them from the checkpoint works; in the end, they are exactly the same weights.

Anyway, I will keep investigating this issue. If I get to the bottom of it, I will update this thread.
