diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index b046ccfd15551..a1a28986b8a90 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -42,7 +42,7 @@ Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project
 
 ### Transformers fallback
 
-After the merge of , `vllm` can fallback to models that are available in `transformers`. This does not work for all models for now, but most decoder language models are supported, and vision language model support is planned!
+`vllm` can fall back to models that are available in `transformers`. This does not yet work for all models, but most decoder language models are supported, and vision language model support is planned!
 
 To check if the backend is `transformers`, you can simply do this:
 
@@ -56,9 +56,13 @@
 If it is `TransformersModel` then it means it's based on `transformers`!
 
 #### Supported features
 
-##### LORA and quantization
+##### Quantization
 
-Both are not supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
+The Transformers fallback supports most of the quantization methods available in vLLM (except GGUF). See the [Quantization page](#quantization-index) for more information about supported quantization in vLLM.
+
+##### LoRA
+
+LoRA is not supported on the Transformers fallback yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
 
 Usually `transformers` model load weights via the `load_adapters` API, that depends on PEFT. We need to work a bit to either use this api (for now this would result in some weights not being marked as loaded) or replace modules accordingly.
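The documentation change above is exercised by the test added below. As an illustrative sketch (not part of this patch), assuming a vLLM build that already includes these changes and has `bitsandbytes` installed, loading a quantized checkpoint through the Transformers fallback could look like this:

```python
# Illustrative sketch only -- not part of the patch. Assumes a vLLM build that
# contains this change and that bitsandbytes is installed.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # same model the new test uses
    model_impl="transformers",      # pin the Transformers fallback explicitly
    quantization="bitsandbytes",    # quantization kwargs mirrored from the test
    load_format="bitsandbytes",
    enforce_eager=True,
)

# If this prints TransformersModel, the fallback is in use (this mirrors the
# backend check described in the docs section above).
llm.apply_model(lambda model: print(model.__class__))

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```

The first half of `test_quantization` below runs the same kwargs with `model_impl="auto"`, so that vLLM's native implementation (when one exists) can be compared against the fallback output via `check_logprobs_close`.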
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
index 1d5d9729df85b..31e3c1f7b987f 100644
--- a/tests/models/test_transformers.py
+++ b/tests/models/test_transformers.py
@@ -45,10 +45,14 @@ def check_implementation(
     ("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
     ("openai-community/gpt2", "transformers"),
     ("ArthurZ/Ilama-3.2-1B", "auto"),  # CUSTOM CODE
-    ("meta-llama/Llama-3.2-1B-Instruct", "auto"),
 ])  # trust_remote_code=True by default
-def test_models(hf_runner, vllm_runner, example_prompts, model,
-                model_impl) -> None:
+def test_models(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    example_prompts: list[str],
+    model: str,
+    model_impl: str,
+) -> None:
 
     maybe_raises = nullcontext()
     if model == "openai-community/gpt2" and model_impl == "transformers":
@@ -67,10 +71,50 @@ def test_models(hf_runner, vllm_runner, example_prompts, model,
 
 @multi_gpu_test(num_gpus=2)
 def test_distributed(
-    hf_runner,
-    vllm_runner,
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
     example_prompts,
 ):
     kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
     check_implementation(hf_runner, vllm_runner, example_prompts,
                          "meta-llama/Llama-3.2-1B-Instruct", **kwargs)
+
+
+@pytest.mark.parametrize("model, quantization_kwargs", [
+    (
+        "meta-llama/Llama-3.2-1B-Instruct",
+        {
+            "quantization": "bitsandbytes",
+            "load_format": "bitsandbytes",
+        },
+    ),
+])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_quantization(
+    vllm_runner: Type[VllmRunner],
+    example_prompts: list[str],
+    model: str,
+    quantization_kwargs: dict[str, str],
+    max_tokens: int,
+    num_logprobs: int,
+) -> None:
+    with vllm_runner(
+            model, model_impl="auto", enforce_eager=True,
+            **quantization_kwargs) as vllm_model:  # type: ignore[arg-type]
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)
+
+    with vllm_runner(
+            model,
+            model_impl="transformers",
+            enforce_eager=True,
+            **quantization_kwargs) as vllm_model:  # type: ignore[arg-type]
+        transformers_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens=max_tokens, num_logprobs=num_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=transformers_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="transformers",
+        name_1="vllm",
+    )
diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index 1605467bc3dd6..9b456b2489525 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -28,6 +28,7 @@
 from vllm.distributed.utils import divide
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               ReplicatedLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -37,6 +38,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
+from .interfaces import SupportsQuant
 from .utils import maybe_prefix
 
 logger = init_logger(__name__)
@@ -50,10 +52,10 @@ def vllm_flash_attention_forward(
         value: torch.Tensor,
         attention_mask: torch.Tensor,
         # Transformers kwargs
-        scaling: float = None,
+        scaling: Optional[float] = None,
         # vLLM kwargs
-        attn_metadata: AttentionMetadata = None,
-        attention_instances: list[Attention] = None,
+        attn_metadata: Optional[AttentionMetadata] = None,
+        attention_instances: Optional[list[Attention]] = None,
         **kwargs):
     self_attn = attention_instances[module.layer_idx]
     if scaling is not None:
@@ -99,13 +101,7 @@ def replace_linear_class(
     vllm_linear_cls = {
         "colwise": ColumnParallelLinear,
         "rowwise": RowParallelLinear,
-    }.get(style)
-
-    if vllm_linear_cls is None:
-        logger.warning(
-            "Unsupported parallel style value: %s. "
-            "This layer will not be tensor parallelized.", style)
-        return linear
+    }.get(style, ReplicatedLinear)
 
     class HFCompatibleLinear(vllm_linear_cls):
         """
@@ -119,10 +115,11 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         input_size=linear.in_features,
         output_size=linear.out_features,
         bias=linear.bias is not None,
+        quant_config=quant_config,
     )
 
 
-class TransformersModel(nn.Module):
+class TransformersModel(nn.Module, SupportsQuant):
     embedding_padding_modules = ["lm_head"]
     embedding_modules = ["embed_tokens"
                          ]  # TODO transformers will have a util to get it
@@ -133,10 +130,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
 
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
 
         self.config = config
-        self.quant_config = quant_config
         self.vocab_size = config.vocab_size
         self.unpadded_vocab_size = config.vocab_size
 
@@ -162,7 +157,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
                 scale=config.head_dim**-0.5,
                 num_kv_heads=divide(config.num_key_value_heads, tp_size),
                 cache_config=cache_config,
-                quant_config=None,
+                quant_config=self.quant_config,
                 prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
         ]
 
@@ -172,7 +167,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
         # ForCausalLM modifications
         self.lm_head = ParallelLMHead(config.vocab_size,
                                       config.hidden_size,
-                                      quant_config=None,
+                                      quant_config=self.quant_config,
                                       prefix=maybe_prefix(prefix, "lm_head"))
         if config.tie_word_embeddings:
             self.lm_head.weight = self.model.get_input_embeddings().weight