Skip to content

Commit cd072ef

Browse files
mgoin authored and Matthias Vogler committed
[Bugfix] Fix CFGGuide and use outlines for grammars that can't convert to GBNF (vllm-project#11389)
Signed-off-by: mgoin <michael@neuralmagic.com>
1 parent ca4890c commit cd072ef

File tree

5 files changed: +103 -86 lines changed

tests/entrypoints/llm/test_guided_generate.py

-5
Original file line number | Diff line number | Diff line change
@@ -174,11 +174,6 @@ def test_guided_choice_completion(sample_guided_choice, llm,
174174
@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS)
175175
def test_guided_grammar(sample_sql_statements, llm,
176176
guided_decoding_backend: str):
177-
if guided_decoding_backend == "outlines":
178-
pytest.skip("Outlines backend fails in this test case with:\n"
179-
"AttributeError: Error in model execution: 'ParserConf' "
180-
"object has no attribute 'deterministic'")
181-
182177
sampling_params = SamplingParams(temperature=0.8,
183178
top_p=0.95,
184179
max_tokens=1000,

vllm/model_executor/guided_decoding/__init__.py

+17 -70
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@
33
from typing import TYPE_CHECKING
44

55
from vllm.logger import init_logger
6+
from vllm.model_executor.guided_decoding.utils import (
7+
convert_lark_to_gbnf, grammar_is_likely_lark,
8+
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
69
from vllm.platforms import CpuArchEnum, current_platform
710

811
if TYPE_CHECKING:
@@ -15,76 +18,6 @@
1518
logger = init_logger(__name__)
1619

1720

18-
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
19-
"""Check if JSON schema contains features unsupported by xgrammar."""
20-
21-
def check_object(obj: dict) -> bool:
22-
if not isinstance(obj, dict):
23-
return False
24-
25-
# Check for pattern restrictions
26-
if "pattern" in obj:
27-
return True
28-
29-
# Check for numeric ranges
30-
if obj.get("type") in ("integer", "number") and any(
31-
key in obj for key in [
32-
"minimum", "maximum", "exclusiveMinimum",
33-
"exclusiveMaximum", "multipleOf"
34-
]):
35-
return True
36-
37-
# Recursively check all nested objects and arrays
38-
for value in obj.values():
39-
if isinstance(value, dict):
40-
if check_object(value):
41-
return True
42-
elif isinstance(value, list):
43-
for item in value:
44-
if isinstance(item, dict) and check_object(item):
45-
return True
46-
47-
return False
48-
49-
return check_object(schema)
50-
51-
52-
def has_lmf_unsupported_json_features(schema: dict) -> bool:
53-
"""
54-
Check if JSON schema contains features unsupported
55-
by lm_format_enforcer.
56-
57-
Known issues:
58-
- Regex patterns:
59-
"grade": {
60-
"type": "string",
61-
"pattern": "^[A-D]$" # Regex pattern
62-
},
63-
"""
64-
65-
def check_object(obj: dict) -> bool:
66-
if not isinstance(obj, dict):
67-
return False
68-
69-
# Check for pattern restrictions
70-
if "pattern" in obj:
71-
return True
72-
73-
# Recursively check all nested objects and arrays
74-
for value in obj.values():
75-
if isinstance(value, dict):
76-
if check_object(value):
77-
return True
78-
elif isinstance(value, list):
79-
for item in value:
80-
if isinstance(item, dict) and check_object(item):
81-
return True
82-
83-
return False
84-
85-
return check_object(schema)
86-
87-
8821
def maybe_backend_fallback(
8922
guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
9023
# lm-format-enforcer doesn't support grammar, fallback to xgrammar
@@ -127,6 +60,20 @@ def maybe_backend_fallback(
12760
"Falling back to use outlines instead.")
12861
guided_params.backend = "outlines"
12962

63+
# xgrammar only supports GBNF grammars, so we must convert Lark.
64+
# We must check if the grammar is likely Lark and if that
65+
# grammar is convertible to GBNF
66+
elif (guided_params.grammar is not None
67+
and grammar_is_likely_lark(guided_params.grammar)):
68+
try:
69+
convert_lark_to_gbnf(guided_params.grammar)
70+
except Exception:
71+
logger.warning(
72+
"xgrammar does not support Lark grammars and the "
73+
"grammar failed to convert to GBNF. "
74+
"Falling back to use outlines instead.")
75+
guided_params.backend = "outlines"
76+
13077
if (guided_params.backend == "outlines"
13178
and guided_params.json_object is not None):
13279
# outlines doesn't support json_object, fallback to xgrammar

vllm/model_executor/guided_decoding/outlines_logits_processors.py

+14 -9
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,11 @@
2121

2222
import numpy as np
2323
import torch
24-
from lark import Lark
2524
from outlines import grammars
2625
from outlines.caching import cache
27-
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
26+
from outlines.fsm.guide import (CFGGuide, CFGState, Generate, Guide,
27+
RegexGuide, Write)
28+
from outlines.fsm.parsing import PartialLark
2829
from outlines_core.fsm.json_schema import build_regex_from_schema
2930
from pydantic import BaseModel
3031
from transformers import PreTrainedTokenizerBase
@@ -34,7 +35,9 @@ class BaseLogitsProcessor:
3435

3536
def __init__(self, guide: Guide):
3637
self._guide: Guide = guide
37-
self._fsm_state: DefaultDict[int, int] = defaultdict(int)
38+
# CFGState is used for the FSM state for CFGGuide
39+
self._fsm_state: DefaultDict[int, Union[int,
40+
CFGState]] = defaultdict(int)
3841

3942
def __call__(self, input_ids: List[int],
4043
scores: torch.Tensor) -> torch.Tensor:
@@ -54,15 +57,13 @@ def __call__(self, input_ids: List[int],
5457
# On the first time this is called, we simply re-create
5558
# the Lark object.
5659
if isinstance(self._guide, CFGGuide):
57-
self._guide.parser = Lark(
60+
self._guide.parser = PartialLark(
5861
self._guide.cfg_string,
5962
parser="lalr",
60-
lexer="contextual",
61-
propagate_positions=False,
62-
maybe_placeholders=False,
63-
regex=True,
6463
import_paths=[grammars.GRAMMAR_PATH],
6564
)
65+
self._fsm_state[seq_id] = CFGState(
66+
parser_state=self._guide.parser.parse(""), prev_token=None)
6667

6768
instruction = self._guide.get_next_instruction(
6869
state=self._fsm_state[seq_id])
@@ -200,7 +201,8 @@ def convert_token_to_string(token: str) -> str:
200201
string = tokenizer.convert_tokens_to_string([token])
201202

202203
# A hack to handle missing spaces to HF's Llama tokenizers
203-
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
204+
if (type(token) is str and token.startswith(SPIECE_UNDERLINE)
205+
or token == "<0x20>"):
204206
return " " + string
205207

206208
return string
@@ -211,6 +213,9 @@ def change_decoder(
211213
"""Sync vLLM's decoder with the outlines by returning list."""
212214

213215
def new_decoder(inp_tokens: List[int]) -> List[str]:
216+
if (isinstance(inp_tokens, list) and len(inp_tokens) == 1
217+
and isinstance(inp_tokens[0], list)):
218+
inp_tokens = inp_tokens[0]
214219
return [decoder(inp_tokens)]
215220

216221
return new_decoder

vllm/model_executor/guided_decoding/xgrammar_utils.py → vllm/model_executor/guided_decoding/utils.py (renamed)

+70
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,76 @@
11
import re
22

33

4+
def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
5+
"""Check if JSON schema contains features unsupported by xgrammar."""
6+
7+
def check_object(obj: dict) -> bool:
8+
if not isinstance(obj, dict):
9+
return False
10+
11+
# Check for pattern restrictions
12+
if "pattern" in obj:
13+
return True
14+
15+
# Check for numeric ranges
16+
if obj.get("type") in ("integer", "number") and any(
17+
key in obj for key in [
18+
"minimum", "maximum", "exclusiveMinimum",
19+
"exclusiveMaximum", "multipleOf"
20+
]):
21+
return True
22+
23+
# Recursively check all nested objects and arrays
24+
for value in obj.values():
25+
if isinstance(value, dict):
26+
if check_object(value):
27+
return True
28+
elif isinstance(value, list):
29+
for item in value:
30+
if isinstance(item, dict) and check_object(item):
31+
return True
32+
33+
return False
34+
35+
return check_object(schema)
36+
37+
38+
def has_lmf_unsupported_json_features(schema: dict) -> bool:
39+
"""
40+
Check if JSON schema contains features unsupported
41+
by lm_format_enforcer.
42+
43+
Known issues:
44+
- Regex patterns:
45+
"grade": {
46+
"type": "string",
47+
"pattern": "^[A-D]$" # Regex pattern
48+
},
49+
"""
50+
51+
def check_object(obj: dict) -> bool:
52+
if not isinstance(obj, dict):
53+
return False
54+
55+
# Check for pattern restrictions
56+
if "pattern" in obj:
57+
return True
58+
59+
# Recursively check all nested objects and arrays
60+
for value in obj.values():
61+
if isinstance(value, dict):
62+
if check_object(value):
63+
return True
64+
elif isinstance(value, list):
65+
for item in value:
66+
if isinstance(item, dict) and check_object(item):
67+
return True
68+
69+
return False
70+
71+
return check_object(schema)
72+
73+
474
def grammar_is_likely_lark(grammar_str: str) -> bool:
575
"""
676
Check if grammar appears to use Lark syntax.

vllm/model_executor/guided_decoding/xgrammar_decoding.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -14,8 +14,8 @@
1414
except ImportError:
1515
pass
1616

17-
from vllm.model_executor.guided_decoding.xgrammar_utils import (
18-
convert_lark_to_gbnf, grammar_is_likely_lark)
17+
from vllm.model_executor.guided_decoding.utils import (convert_lark_to_gbnf,
18+
grammar_is_likely_lark)
1919
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
2020

2121
if TYPE_CHECKING:

0 commit comments

Comments
 (0)