diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py new file mode 100644 index 0000000000000..812d411bbad8e --- /dev/null +++ b/tests/models/test_granite.py @@ -0,0 +1,43 @@ +"""Compare the outputs of HF and vLLM for Granite models using greedy sampling. + +Run `pytest tests/models/test_granite.py`. +""" +import pytest + +from .utils import check_logprobs_close + +MODELS = [ + "mayank-mishra/granite-3b-mup", +] + + +@pytest.mark.parametrize("model", MODELS) +# @pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + # TODO(sang): Sliding window should be tested separately. + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + # print("hf_outputs ", hf_outputs) + # print("vllm_outputs", vllm_outputs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + )