Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update GuidanceLogitsProcessor to use llguidance bitmask functions #5

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions vllm/model_executor/guided_decoding/guidance_logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@

import llguidance # type: ignore[import-untyped]
import llguidance.hf
import numpy as np
import llguidance.torch
import torch
from llguidance.gbnf_to_lark import any_to_lark # type: ignore[import-untyped]
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.model_executor.guided_decoding.guidance_utils import (
LLInterpreterResponse)


class GuidanceLogitsProcessor:
"""Base Guidance Logits Processor"""
Expand Down Expand Up @@ -47,15 +45,11 @@ def __init__(
self.tokenizer_name = tokenizer.name_or_path
self.whitespace_pattern = whitespace_pattern

self.is_stopped = False
self.pending_ff_tokens: list[int] = []
self.new_sampling = False
self.initialized = False

def _initialize(self):
if self.initialized:
return

def _get_serialized_grammar(self):
if self.mode.lower() == "json":
if isinstance(self.guide, dict):
schema = json.dumps(self.guide)
Expand All @@ -72,21 +66,28 @@ def _initialize(self):
"whitespace_flexible", False)
compiler = llguidance.JsonCompiler(
whitespace_flexible=whitespace_flexible)
self.serialized_grammar = compiler.compile(schema)
return compiler.compile(schema)
elif self.mode.lower() in ["regex", "choice"]:
compiler = llguidance.RegexCompiler()
self.serialized_grammar = compiler.compile(regex=self.guide)
return compiler.compile(regex=self.guide)
elif self.mode.lower() == "grammar":
serialized_grammar = self.guide
if isinstance(self.guide, dict):
serialized_grammar = json.dumps(self.guide)
self.serialized_grammar = serialized_grammar
# grammar can be in EBNF or LARK syntax
compiler = llguidance.LarkCompiler()
return compiler.compile(any_to_lark(self.guide))

raise ValueError(f"Invalid mode: {self.mode}")

def _initialize(self):
if self.initialized:
return

self.serialized_grammar = self._get_serialized_grammar()
ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
None)
if ll_tokenizer is None:
ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer

self.ll_tokenizer = ll_tokenizer
self.ll_interpreter = llguidance.LLInterpreter(
self.ll_tokenizer,
Expand All @@ -96,6 +97,10 @@ def _initialize(self):
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)

# create reusable bitmask
self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size)

self.initialized = True

def __call__(
Expand All @@ -107,7 +112,11 @@ def __call__(
# to avoid pickling ll_tokenizer and ll_interpreter
self._initialize()

if self.is_stopped:
if self.ll_interpreter.has_pending_stop():
if self.ll_tokenizer.eos_token is not None:
scores.add_(-scores)
scores[self.ll_tokenizer.eos_token] = 200.0

return scores

if self.new_sampling and len(input_ids) > 0:
Expand All @@ -127,30 +136,10 @@ def __call__(
scores[ff_token] = 200.0
return scores

mask, resp = self.ll_interpreter.compute_mask()
r = LLInterpreterResponse.model_validate_json(resp)

if r.stop:
mask = np.zeros(scores.shape[-1], dtype=np.uint8)
if self.ll_tokenizer.eos_token is not None:
mask[self.ll_tokenizer.eos_token] = 200
self.is_stopped = True
elif mask is None:
# NOTE: mask should not be None unless r.stop is True
# However, we are handling this case just in case
# llguidance allows free-style generation
mask = np.zeros(scores.shape[-1], dtype=np.uint8)
else:
mask = np.frombuffer(mask, dtype=np.uint8)

# Force all invalid tokens to have 0 value
scores.add_(-torch.min(scores))
zero_indices = np.where(mask == 0)[0]
scores[zero_indices] = 0.0
non_zero_indices = np.nonzero(mask)[0]
scores[non_zero_indices] += 200.0
# set special tokens not in vocab to 0
scores[mask.shape[0]:] = 0.0
llguidance.torch.fill_next_token_bitmask(self.ll_interpreter,
self.bitmask, 0)
llguidance.torch.apply_token_bitmask_inplace(
scores, self.bitmask.to(scores.device))
self.new_sampling = True

return scores