diff --git a/extensions/guidance/guidance_server.py b/extensions/guidance/guidance_server.py
new file mode 100644
index 0000000000..0f10082fd4
--- /dev/null
+++ b/extensions/guidance/guidance_server.py
@@ -0,0 +1,109 @@
+"""Loads the model into the guidance library (https://github.com/microsoft/guidance).
+It aims to lower the barrier to entry for using the guidance library with quantized models.
+
+The easiest way to get started with this extension is by using the Python client wrapper:
+
+https://github.com/ChuloAI/andromeda-chain
+
+Example:
+
+```
+from andromeda_chain import AndromedaChain, AndromedaPrompt, AndromedaResponse
+chain = AndromedaChain("http://0.0.0.0:9000/guidance_api/v1/generate")
+
+prompt = AndromedaPrompt(
+    name="hello",
+    prompt_template="Howdy: {{gen 'expert_names' temperature=0 max_tokens=300}}",
+    input_vars=[],
+    output_vars=["expert_names"]
+)
+
+response: AndromedaResponse = chain.run_guidance_prompt(prompt)
+```
+
+"""
+import json
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+from threading import Thread
+
+import guidance
+from modules import shared
+
+# This extension depends on having the model already fully loaded, including LoRA
+
+guidance_model = None
+
+def load_guidance_model_singleton():
+    global guidance_model
+
+    if guidance_model:
+        return guidance_model
+    try:
+        import guidance
+    except ImportError:
+        raise ImportError("Please run 'pip install guidance' before using the guidance extension.")
+
+    if not shared.model or not shared.tokenizer:
+        raise ValueError("Cannot use guidance extension without a pre-initialized model!")
+
+    # For now, only HF Transformers models are supported.
+    # As far as I know, this includes:
+    # - 8-bit and 4-bit quantizations loaded through bitsandbytes
+    # - GPTQ variants
+    # - Models with LoRAs
+
+    guidance_model = guidance.llms.Transformers(
+        model=shared.model,
+        tokenizer=shared.tokenizer,
+        device=shared.args.guidance_device
+    )
+    guidance.llm = guidance_model
+    return guidance_model
+
+
+class Handler(BaseHTTPRequestHandler):
+    def do_POST(self):
+        content_length = int(self.headers['Content-Length'])
+        body = json.loads(self.rfile.read(content_length).decode('utf-8'))
+
+        if self.path == '/guidance_api/v1/generate':
+            # TODO: add request validation
+            # For now disabled to avoid an extra dependency on validation libraries, like Pydantic
+
+            prompt_template = body["prompt_template"]
+            input_vars = body["input_vars"]
+            kwargs = body["guidance_kwargs"]
+            output_vars = body["output_vars"]
+
+            llm = load_guidance_model_singleton()
+            guidance_program = guidance(prompt_template)
+            program_result = guidance_program(
+                **kwargs,
+                stream=False,
+                async_mode=False,
+                caching=False,
+                **input_vars,
+                llm=llm,
+            )
+            output = {"__main__": str(program_result)}
+            for output_var in output_vars:
+                output[output_var] = program_result[output_var]
+
+            response = json.dumps(output)
+            self.send_response(200)
+            self.send_header('Content-Type', 'application/json')
+            self.end_headers()
+            self.wfile.write(response.encode('utf-8'))
+
+
+def _run_server(port: int):
+    address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
+
+    server = ThreadingHTTPServer((address, port), Handler)
+    print(f'Starting Guidance API at http://{address}:{port}/guidance_api')
+
+    server.serve_forever()
+
+
+def start_server(port: int):
+    Thread(target=_run_server, args=[port], daemon=True).start()
diff --git a/extensions/guidance/requirements.txt b/extensions/guidance/requirements.txt
new file mode 100644
index 0000000000..07d0c54f45
--- /dev/null
+++ b/extensions/guidance/requirements.txt
@@ -0,0 +1 @@
+guidance
\ No newline at end of file
diff --git a/extensions/guidance/script.py b/extensions/guidance/script.py
new file mode 100644
index 0000000000..404d01da4b
--- /dev/null
+++ b/extensions/guidance/script.py
@@ -0,0 +1,5 @@
+from extensions.guidance import guidance_server
+from modules import shared
+
+def setup():
+    guidance_server.start_server(shared.args.guidance_port)
diff --git a/modules/shared.py b/modules/shared.py
index 9f4f720c68..10b69d41f3 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -11,6 +11,7 @@
 tokenizer = None
 model_name = "None"
 model_type = None
+guidance_model = None
 lora_names = []
 
 # Chat variables
@@ -174,6 +175,12 @@ def str2bool(v):
 parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
 parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
 
+# Guidance Server
+parser.add_argument('--guidance', action='store_true', help='Enable the guidance API extension.')
+parser.add_argument('--guidance-port', type=int, default=9000, help='The listening port for the guidance API.')
+parser.add_argument('--guidance-device', type=str, default='cuda', help='The device on which the model is loaded.')
+
+
 # Multimodal
 parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
 
@@ -206,11 +213,12 @@ def add_extension(name):
 if args.multimodal_pipeline is not None:
     add_extension('multimodal')
 
+if args.guidance:
+    add_extension("guidance")
 
 def is_chat():
     return args.chat
 
-
 # Loading model-specific settings
 with Path(f'{args.model_dir}/config.yaml') as p:
     if p.exists():
@@ -229,3 +237,4 @@ def is_chat():
             model_config[k] = user_config[k]
 
 model_config = OrderedDict(model_config)
+
diff --git a/requirements.txt b/requirements.txt
index 0a5adce4de..b62ae96cfa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,4 +22,4 @@ https://github.com/jllllll/bitsandbytes-windows-webui/raw/main/bitsandbytes-0.39
 llama-cpp-python==0.1.57; platform_system != "Windows"
 https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.57/llama_cpp_python-0.1.57-cp310-cp310-win_amd64.whl; platform_system == "Windows"
 https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
-https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux"
+https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-linux_x86_64.whl; platform_system == "Linux"
\ No newline at end of file
diff --git a/server.py b/server.py
index 25dbae8e9a..be10d48800 100644
--- a/server.py
+++ b/server.py
@@ -1088,6 +1088,7 @@ def create_interface():
         'instruction_template': shared.settings['instruction_template']
     })
 
+    shared.generation_lock = Lock()
     # Launch the web UI
     create_interface()
 
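Usage sketch (not part of the patch): the endpoint can also be exercised directly with `requests`, assuming the extension is running locally on the default `--guidance-port` of 9000. The payload keys mirror what `Handler.do_POST` reads; the template, question, and output variable names below are made up for illustration.

```
import requests

# Hypothetical example payload; the keys match what Handler.do_POST expects.
payload = {
    # guidance template: 'answer' is captured by the {{gen}} block
    "prompt_template": "Question: {{question}}\nAnswer: {{gen 'answer' max_tokens=50}}",
    # values substituted into the template before generation
    "input_vars": {"question": "What is the capital of France?"},
    # extra keyword arguments forwarded to the guidance program
    "guidance_kwargs": {},
    # variables to read back from the executed program
    "output_vars": ["answer"],
}

response = requests.post("http://127.0.0.1:9000/guidance_api/v1/generate", json=payload)
result = response.json()
print(result["__main__"])  # full rendered program text
print(result["answer"])    # just the captured 'answer' variable
```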