|
3 | 3 | from typing import TYPE_CHECKING
|
4 | 4 |
|
5 | 5 | from vllm.logger import init_logger
|
| 6 | +from vllm.model_executor.guided_decoding.utils import ( |
| 7 | + convert_lark_to_gbnf, grammar_is_likely_lark, |
| 8 | + has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) |
6 | 9 | from vllm.platforms import CpuArchEnum, current_platform
|
7 | 10 |
|
8 | 11 | if TYPE_CHECKING:
|
|
15 | 18 | logger = init_logger(__name__)
|
16 | 19 |
|
17 | 20 |
|
18 |
| -def has_xgrammar_unsupported_json_features(schema: dict) -> bool: |
19 |
| - """Check if JSON schema contains features unsupported by xgrammar.""" |
20 |
| - |
21 |
| - def check_object(obj: dict) -> bool: |
22 |
| - if not isinstance(obj, dict): |
23 |
| - return False |
24 |
| - |
25 |
| - # Check for pattern restrictions |
26 |
| - if "pattern" in obj: |
27 |
| - return True |
28 |
| - |
29 |
| - # Check for numeric ranges |
30 |
| - if obj.get("type") in ("integer", "number") and any( |
31 |
| - key in obj for key in [ |
32 |
| - "minimum", "maximum", "exclusiveMinimum", |
33 |
| - "exclusiveMaximum", "multipleOf" |
34 |
| - ]): |
35 |
| - return True |
36 |
| - |
37 |
| - # Recursively check all nested objects and arrays |
38 |
| - for value in obj.values(): |
39 |
| - if isinstance(value, dict): |
40 |
| - if check_object(value): |
41 |
| - return True |
42 |
| - elif isinstance(value, list): |
43 |
| - for item in value: |
44 |
| - if isinstance(item, dict) and check_object(item): |
45 |
| - return True |
46 |
| - |
47 |
| - return False |
48 |
| - |
49 |
| - return check_object(schema) |
50 |
| - |
51 |
| - |
52 |
| -def has_lmf_unsupported_json_features(schema: dict) -> bool: |
53 |
| - """ |
54 |
| - Check if JSON schema contains features unsupported |
55 |
| - by lm_format_enforcer. |
56 |
| -
|
57 |
| - Known issues: |
58 |
| - - Regex patterns: |
59 |
| - "grade": { |
60 |
| - "type": "string", |
61 |
| - "pattern": "^[A-D]$" # Regex pattern |
62 |
| - }, |
63 |
| - """ |
64 |
| - |
65 |
| - def check_object(obj: dict) -> bool: |
66 |
| - if not isinstance(obj, dict): |
67 |
| - return False |
68 |
| - |
69 |
| - # Check for pattern restrictions |
70 |
| - if "pattern" in obj: |
71 |
| - return True |
72 |
| - |
73 |
| - # Recursively check all nested objects and arrays |
74 |
| - for value in obj.values(): |
75 |
| - if isinstance(value, dict): |
76 |
| - if check_object(value): |
77 |
| - return True |
78 |
| - elif isinstance(value, list): |
79 |
| - for item in value: |
80 |
| - if isinstance(item, dict) and check_object(item): |
81 |
| - return True |
82 |
| - |
83 |
| - return False |
84 |
| - |
85 |
| - return check_object(schema) |
86 |
| - |
87 |
| - |
88 | 21 | def maybe_backend_fallback(
|
89 | 22 | guided_params: GuidedDecodingParams) -> GuidedDecodingParams:
|
90 | 23 | # lm-format-enforcer doesn't support grammar, so fall back to xgrammar
|
@@ -127,6 +60,20 @@ def maybe_backend_fallback(
|
127 | 60 | "Falling back to use outlines instead.")
|
128 | 61 | guided_params.backend = "outlines"
|
129 | 62 |
|
| 63 | + # xgrammar only supports GBNF grammars, so Lark grammars must be |
| 64 | + # converted. If the grammar is likely Lark, check that it can be |
| 65 | + # converted to GBNF; if conversion fails, fall back to outlines. |
| 66 | + elif (guided_params.grammar is not None |
| 67 | + and grammar_is_likely_lark(guided_params.grammar)): |
| 68 | + try: |
| 69 | + convert_lark_to_gbnf(guided_params.grammar) |
| 70 | + except Exception: |
| 71 | + logger.warning( |
| 72 | + "xgrammar does not support Lark grammars and the " |
| 73 | + "grammar failed to convert to GBNF. " |
| 74 | + "Falling back to use outlines instead.") |
| 75 | + guided_params.backend = "outlines" |
| 76 | + |
130 | 77 | if (guided_params.backend == "outlines"
|
131 | 78 | and guided_params.json_object is not None):
|
132 | 79 | # outlines doesn't support json_object, so fall back to xgrammar
|
|
0 commit comments