Adds guidance extension #2554

Closed · wants to merge 5 commits
109 changes: 109 additions & 0 deletions extensions/guidance/guidance_server.py
@@ -0,0 +1,109 @@
"""Loads model into the guidance library (https://github.com/microsoft/guidance).
It aims to reduce the entry barrier of using the guidance library with quantized models

The easiest way to get started with this extension is by using the Python client wrapper:

https://github.com/ChuloAI/andromeda-chain

Example:

```
from andromeda_chain import AndromedaChain, AndromedaPrompt, AndromedaResponse
chain = AndromedaChain("http://0.0.0.0:9000/guidance_api/v1/generate")

prompt = AndromedaPrompt(
    name="hello",
    prompt_template="Howdy: {{gen 'expert_names' temperature=0 max_tokens=300}}",
    input_vars=[],
    output_vars=["expert_names"]
)

response: AndromedaResponse = chain.run_guidance_prompt(prompt)
```

"""
import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread

# Import at module level with a friendly error message; a lazy in-function guard
# would never trigger, since an eager top-level import would fail first.
try:
    import guidance
except ImportError:
    raise ImportError("Please run 'pip install guidance' before using the guidance extension.")

from modules import shared

# This extension depends on having the model already fully loaded, including LoRA

guidance_model = None

def load_guidance_model_singleton():
    global guidance_model

    if guidance_model:
        return guidance_model

    if not shared.model or not shared.tokenizer:
        raise ValueError("Cannot use the guidance extension without a pre-initialized model!")

    # For now, only HF Transformers models are supported.
    # As far as I know, this includes:
    # - 8-bit and 4-bit quantizations loaded through bitsandbytes
    # - GPTQ variants
    # - Models with LoRAs
    guidance_model = guidance.llms.Transformers(
        model=shared.model,
        tokenizer=shared.tokenizer,
        device=shared.args.guidance_device
    )
    guidance.llm = guidance_model
    return guidance_model


class Handler(BaseHTTPRequestHandler):
    def do_POST(self):
        content_length = int(self.headers['Content-Length'])
        body = json.loads(self.rfile.read(content_length).decode('utf-8'))

        if self.path == '/guidance_api/v1/generate':
            # TODO: add request validation.
            # Disabled for now to avoid an extra dependency on a validation library like Pydantic.
            prompt_template = body["prompt_template"]
            input_vars = body["input_vars"]
            kwargs = body["guidance_kwargs"]
            output_vars = body["output_vars"]

            llm = load_guidance_model_singleton()
            guidance_program = guidance(prompt_template)
            program_result = guidance_program(
                **kwargs,
                stream=False,
                async_mode=False,
                caching=False,
                **input_vars,
                llm=llm,
            )
            output = {"__main__": str(program_result)}
            for output_var in output_vars:
                output[output_var] = program_result[output_var]

            response = json.dumps(output)
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.end_headers()
            self.wfile.write(response.encode('utf-8'))
        else:
            # Unknown paths would otherwise leave the connection hanging with no response.
            self.send_error(404)


def _run_server(port: int):
    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'

    server = ThreadingHTTPServer((address, port), Handler)
    print(f'Starting Guidance API at http://{address}:{port}/guidance_api')
    server.serve_forever()


def start_server(port: int):
    Thread(target=_run_server, args=[port], daemon=True).start()
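
For reference, the handler above fixes the request contract: a POST with a JSON body containing `prompt_template`, `input_vars` (a mapping of template variables), `guidance_kwargs`, and `output_vars`. Below is a minimal sketch of calling the endpoint directly with `requests`, without the andromeda-chain wrapper; the host and port are assumptions matching a local server on the default `--guidance-port`:

```
# Minimal sketch: call the guidance API directly (no andromeda-chain wrapper).
# Assumes the server was started with --guidance and listens on 127.0.0.1:9000.
import requests

payload = {
    "prompt_template": "Howdy: {{gen 'expert_names' temperature=0 max_tokens=300}}",
    "input_vars": {},       # expanded into the template as variables
    "guidance_kwargs": {},  # forwarded to the guidance program call
    "output_vars": ["expert_names"],
}

resp = requests.post("http://127.0.0.1:9000/guidance_api/v1/generate", json=payload)
result = resp.json()
print(result["__main__"])      # the fully rendered program
print(result["expert_names"])  # the captured output variable
```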
1 change: 1 addition & 0 deletions extensions/guidance/requirements.txt
@@ -0,0 +1 @@
guidance
5 changes: 5 additions & 0 deletions extensions/guidance/script.py
@@ -0,0 +1,5 @@
from extensions.guidance import guidance_server
from modules import shared

def setup():
    # Hook called by the web UI when the extension is loaded; starts the API server in a daemon thread.
    guidance_server.start_server(shared.args.guidance_port)
11 changes: 10 additions & 1 deletion modules/shared.py
@@ -11,6 +11,7 @@
tokenizer = None
model_name = "None"
model_type = None
guidance_model = None
lora_names = []

# Chat variables
@@ -174,6 +175,12 @@ def str2bool(v):
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudflare.')

# Guidance Server
parser.add_argument('--guidance', action='store_true', help='Enable the guidance API extension.')
parser.add_argument('--guidance-port', type=int, default=9000, help='The listening port for the guidance API.')
parser.add_argument('--guidance-device', type=str, default='cuda', help='The device to load the model on.')


# Multimodal
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')

@@ -206,11 +213,12 @@ def add_extension(name):
if args.multimodal_pipeline is not None:
    add_extension('multimodal')

if args.guidance:
    add_extension('guidance')

def is_chat():
    return args.chat


# Loading model-specific settings
with Path(f'{args.model_dir}/config.yaml') as p:
    if p.exists():
@@ -229,3 +237,4 @@ def is_chat():
        model_config[k] = user_config[k]

model_config = OrderedDict(model_config)
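
With these flags in place, the extension can be enabled at launch, e.g. `python server.py --guidance --guidance-port 9000` (a hypothetical invocation; any other flags you normally pass stay the same). `--guidance-device` controls where guidance places the model and defaults to `cuda`.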

2 changes: 1 addition & 1 deletion requirements.txt
@@ -22,4 +22,4 @@ https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39
llama-cpp-python==0.1.57; platform_system != "Windows"
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.57/llama_cpp_python-0.1.57-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux"
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux"
1 change: 1 addition & 0 deletions server.py
@@ -1088,6 +1088,7 @@ def create_interface():
'instruction_template': shared.settings['instruction_template']
})


shared.generation_lock = Lock()
# Launch the web UI
create_interface()