
Increased CUDA reserved memory in Transformers 5.x under int4 quantization leads to OOM #44387

@tangefly

Description

System Info

  • transformers version: 5.2.0
  • Platform: Linux-6.12.57+deb13-amd64-x86_64-with-glibc2.41
  • Python version: 3.12.12
  • Huggingface_hub version: 0.36.2
  • Safetensors version: 0.7.0
  • Accelerate version: 1.11.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (accelerator?): 2.10.0+cu128 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: No
  • Using GPU in script?: No
  • GPU type: NVIDIA GeForce RTX 5090

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Problem Description

I encountered a CUDA OOM error when using LlamaFactory to load the Qwen3.5-35B-A3B model with int4 quantization for inference.

My hardware setup consists of 2 × RTX 5090 32GB GPUs. Under normal circumstances, this configuration should be sufficient to load the model without OOM when using int4 quantization.

To investigate further, I ran the minimal reproduction script below and observed that under Transformers 5.2.0, CUDA reserved memory grows as if the model were not quantized. The actual allocated memory matches int4 quantization expectations, but the excess reserved memory wastes GPU memory and can trigger OOM.

This behavior was not observed in earlier versions, where the model could be loaded successfully under the same hardware conditions.

import torch
import gc
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

def print_mem(tag):
    # Note: these counters report the current CUDA device only.
    torch.cuda.synchronize()
    allocated = torch.cuda.memory_allocated() / 1024**3
    reserved = torch.cuda.memory_reserved() / 1024**3
    print(f"{tag}")
    print(f"  allocated: {allocated:.2f} GB")
    print(f"  reserved : {reserved:.2f} GB")


# Clean up GPU memory
gc.collect()
torch.cuda.empty_cache()

print_mem("Before loading")

model_name = "Qwen/Qwen2.5-7B-Instruct"  # Hub ID (or a local checkpoint path)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
)

print_mem("After loading")
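While the root cause is under investigation, one thing worth checking (an assumption on my side, not a confirmed fix) is PyTorch's caching-allocator configuration: the gap between allocated and reserved memory is exactly what the allocator's `expandable_segments` option is meant to reduce. It has to be set before the first CUDA allocation in the process:

```python
import os

# Must be set before the first CUDA tensor is allocated in the process;
# setdefault keeps any value already configured in the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
```

This does not explain the regression between versions, but it may narrow down whether the extra reserved memory comes from allocator fragmentation or from the new loading path itself.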

Output Result

  • Transformers 5.2.0
Before loading
  allocated: 0.00 GB
  reserved : 0.00 GB
Loading weights: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 339/339 [00:02<00:00, 155.08it/s, Materializing param=model.norm.weight]
After loading
  allocated: 5.46 GB
  reserved : 13.83 GB
  • Transformers 4.57.5
Before loading
  allocated: 0.00 GB
  reserved : 0.00 GB
Loading checkpoint shards: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:05<00:00,  1.43s/it]
After loading
  allocated: 5.46 GB
  reserved : 6.70 GB
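For reference, the allocated figure is consistent with a back-of-envelope int4 estimate. The parameter split below is an assumption, roughly what I would expect for a Qwen2.5-7B-class model: ~6.5B weights in nf4 Linear layers, ~1.1B kept in bf16 (embed_tokens and lm_head, which bitsandbytes does not quantize), plus one fp32 absmax scale per 64-weight nf4 block:

```python
GIB = 1024**3

# Assumed split for a ~7.6B-parameter model (illustrative, not exact):
quantized_params = 6.5e9   # Linear weights stored as nf4 (0.5 byte each)
bf16_params      = 1.1e9   # embed_tokens + lm_head kept in bf16 (2 bytes each)
block_size       = 64      # nf4 quantization block size

nf4_bytes    = quantized_params * 0.5
bf16_bytes   = bf16_params * 2
absmax_bytes = quantized_params / block_size * 4  # one fp32 scale per block

total_gib = (nf4_bytes + bf16_bytes + absmax_bytes) / GIB
print(f"estimated int4 footprint: {total_gib:.2f} GiB")
# prints: estimated int4 footprint: 5.45 GiB
```

This lands very close to the 5.46 GB allocated in both versions, which is why the allocated numbers look correct. The 13.83 GB reserved under 5.2.0, by contrast, is much closer to what a bf16-sized footprint would need, which is what makes the regression look like the loader is reserving memory for unquantized weights.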

Expected behavior

  • In an environment with 2 × RTX 5090 32GB GPUs, loading the Qwen3.5-35B-A3B model with int4 quantization using LlamaFactory should not result in an OOM error.
  • The CUDA reserved memory usage during model loading in Transformers 5.2.0 should be consistent with that of Transformers 4.57.5.
