[Model] Enable quantization support for transformers backend (vllm-…
Isotr0py authored and kerthcet committed Feb 21, 2025
1 parent b6f13ce commit 2fd84a0
Showing 3 changed files with 66 additions and 23 deletions.
10 changes: 7 additions & 3 deletions docs/source/models/supported_models.md
@@ -42,7 +42,7 @@ Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project

### Transformers fallback

After the merge of <gh-pr:11330>, `vllm` can fallback to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!
`vllm` can fall back to models that are available in `transformers`. This does not yet work for all models, but most decoder language models are supported, and vision language model support is planned!

To check if the backend is `transformers`, you can simply do this:
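A minimal sketch of the check, assuming vLLM's `LLM.apply_model` helper (the model name is only an example):

```python
from vllm import LLM

llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")  # example model name

# Prints the class of the loaded model implementation; seeing `TransformersModel`
# means the Transformers fallback is in use.
llm.apply_model(lambda model: print(type(model)))
```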

@@ -56,9 +56,13 @@ If it is `TransformersModel` then it means it's based on `transformers`!

#### Supported features

##### LORA and quantization
##### Quantization

Both are not supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
The Transformers fallback supports most of the quantization methods available in vLLM (except GGUF). See the [Quantization page](#quantization-index) for more information about supported quantization methods in vLLM.
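
A minimal sketch of loading a bitsandbytes-quantized model through the fallback; the model name and generation settings below are only examples, and the quantization kwargs mirror the test added in this commit:

```python
from vllm import LLM, SamplingParams

# Force the Transformers fallback and quantize the weights in-flight with bitsandbytes.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # example model
    model_impl="transformers",
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)

params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```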

##### LoRA

LoRA is not yet supported for the Transformers fallback! Make sure to open an issue and we'll work on this together with the `transformers` team!

Usually, `transformers` models load adapter weights via the `load_adapters` API, which depends on PEFT. We need to do a bit of work either to use this API (for now this would result in some weights not being marked as loaded) or to replace the relevant modules accordingly.

54 changes: 49 additions & 5 deletions tests/models/test_transformers.py
@@ -45,10 +45,14 @@ def check_implementation(
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("openai-community/gpt2", "transformers"),
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
("meta-llama/Llama-3.2-1B-Instruct", "auto"),
]) # trust_remote_code=True by default
def test_models(hf_runner, vllm_runner, example_prompts, model,
model_impl) -> None:
def test_models(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
example_prompts: list[str],
model: str,
model_impl: str,
) -> None:

maybe_raises = nullcontext()
if model == "openai-community/gpt2" and model_impl == "transformers":
@@ -67,10 +71,50 @@ def test_models(hf_runner, vllm_runner, example_prompts, model,

@multi_gpu_test(num_gpus=2)
def test_distributed(
hf_runner,
vllm_runner,
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
example_prompts,
):
kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
check_implementation(hf_runner, vllm_runner, example_prompts,
"meta-llama/Llama-3.2-1B-Instruct", **kwargs)


@pytest.mark.parametrize("model, quantization_kwargs", [
(
"meta-llama/Llama-3.2-1B-Instruct",
{
"quantization": "bitsandbytes",
"load_format": "bitsandbytes",
},
),
])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_quantization(
vllm_runner: Type[VllmRunner],
example_prompts: list[str],
model: str,
quantization_kwargs: dict[str, str],
max_tokens: int,
num_logprobs: int,
) -> None:
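    # Reference run: model_impl="auto" uses vLLM's native implementation for this model.
    # Its greedy logprobs are compared against the Transformers fallback run below.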
with vllm_runner(
model, model_impl="auto", enforce_eager=True,
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)

with vllm_runner(
model,
model_impl="transformers",
enforce_eager=True,
**quantization_kwargs) as vllm_model: # type: ignore[arg-type]
transformers_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)
check_logprobs_close(
outputs_0_lst=transformers_outputs,
outputs_1_lst=vllm_outputs,
name_0="transformers",
name_1="vllm",
)
25 changes: 10 additions & 15 deletions vllm/model_executor/models/transformers.py
@@ -28,6 +28,7 @@
from vllm.distributed.utils import divide
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -37,6 +38,7 @@
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsQuant
from .utils import maybe_prefix

logger = init_logger(__name__)
@@ -50,10 +52,10 @@ def vllm_flash_attention_forward(
value: torch.Tensor,
attention_mask: torch.Tensor,
# Transformers kwargs
scaling: float = None,
scaling: Optional[float] = None,
# vLLM kwargs
attn_metadata: AttentionMetadata = None,
attention_instances: list[Attention] = None,
attn_metadata: Optional[AttentionMetadata] = None,
attention_instances: Optional[list[Attention]] = None,
**kwargs):
self_attn = attention_instances[module.layer_idx]
if scaling is not None:
@@ -99,13 +101,7 @@ def replace_linear_class(
vllm_linear_cls = {
"colwise": ColumnParallelLinear,
"rowwise": RowParallelLinear,
}.get(style)

if vllm_linear_cls is None:
logger.warning(
"Unsupported parallel style value: %s. "
"This layer will not be tensor parallelized.", style)
return linear
}.get(style, ReplicatedLinear)
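    # Any other parallel style falls back to ReplicatedLinear (no tensor parallelism for that layer).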

class HFCompatibleLinear(vllm_linear_cls):
"""
@@ -119,10 +115,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
input_size=linear.in_features,
output_size=linear.out_features,
bias=linear.bias is not None,
quant_config=quant_config,
)


class TransformersModel(nn.Module):
class TransformersModel(nn.Module, SupportsQuant):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
@@ -133,10 +130,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config

self.config = config
self.quant_config = quant_config
self.vocab_size = config.vocab_size
self.unpadded_vocab_size = config.vocab_size

@@ -162,7 +157,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
scale=config.head_dim**-0.5,
num_kv_heads=divide(config.num_key_value_heads, tp_size),
cache_config=cache_config,
quant_config=None,
quant_config=self.quant_config,
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
]

@@ -172,7 +167,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
# ForCausalLM modifications
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=None,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight
