breaking: Add compatibility for OpenAI API
clemlesne committed Jun 15, 2024
1 parent e5f1d28 commit d0bf213
Showing 4 changed files with 129 additions and 47 deletions.
26 changes: 16 additions & 10 deletions README.md
@@ -230,17 +230,23 @@ Place a file called `config.yaml` in the root of the project with the following
 # config.yaml
 llm:
   fast:
-    api_key: xxx
-    context: 16385
-    deployment: gpt-35-turbo-0125
-    endpoint: https://xxx.openai.azure.com
-    model: gpt-35-turbo
+    mode: azure_openai
+    azure_openai:
+      api_key: xxx
+      context: 16385
+      deployment: gpt-35-turbo-0125
+      endpoint: https://xxx.openai.azure.com
+      model: gpt-35-turbo
     streaming: true
   slow:
-    api_key: xxx
-    context: 128000
-    deployment: gpt-4o-2024-05-13
-    endpoint: https://xxx.openai.azure.com
-    model: gpt-4o
+    mode: azure_openai
+    azure_openai:
+      api_key: xxx
+      context: 128000
+      deployment: gpt-4o-2024-05-13
+      endpoint: https://xxx.openai.azure.com
+      model: gpt-4o
     streaming: true

 destination:
   mode: ai_search
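The commit's headline feature is a second, non-Azure mode. A hypothetical `fast` block using it, sketched from the `OpenaiPlatformModel` fields introduced in `helpers/config_models/llm.py` below (note there is no `deployment` key; the key, endpoint and model values are placeholders, not taken from the repository README):

# config.yaml (sketch only)
llm:
  fast:
    mode: openai
    openai:
      api_key: sk-xxx
      context: 16385
      endpoint: https://api.openai.com/v1
      model: gpt-3.5-turbo
      streaming: true

The `slow` block would follow the same shape.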
20 changes: 10 additions & 10 deletions function_app.py
@@ -192,9 +192,9 @@ async def extract_to_chunck(input: BlobClientTrigger) -> None:
     # Free up memory
     del content
     # Prepare chunks for LLM
-    llm_client = CONFIG.llm.instance(
+    llm_client = CONFIG.llm.selected(
         is_fast=False,  # We will use the slow model next step
-    )
+    ).instance()
     chuncks = llm_client.chunck(text=extracted_model.document_content)
     logger.info(f"Splited to {len(chuncks)} chuncks ({blob_name})")
     # Store
@@ -250,9 +250,9 @@ def _validate(req: Optional[str]) -> tuple[bool, Optional[str], Optional[str]]:
         if len(req) < 10:  # Arbitrary minimum length
             return False, "Response too short", None
         return True, None, req
-    llm_client = CONFIG.llm.instance(
+    llm_client = CONFIG.llm.selected(
         is_fast=False,  # We want high quality summaries because they are used to avoid hallucinations in the next steps
-    )
+    ).instance()
     synthesis_str = await llm_client.generate(
         max_tokens=500,  # 500 tokens ~= 375 words
         res_object=str,
@@ -338,9 +338,9 @@ async def synthesis_to_page(input: BlobClientTrigger) -> None:
     # Free up memory
     del content
     # Prepare chunks for LLM
-    llm_client = CONFIG.llm.instance(
+    llm_client = CONFIG.llm.selected(
         is_fast=True,  # We will use the fast model
-    )
+    ).instance()
     pages = llm_client.chunck(
         max_tokens=CONFIG.features.page_split_size,
         text=synthesis_model.chunk_content,
@@ -399,9 +399,9 @@ async def page_to_fact(input: BlobClientTrigger) -> None:
     # Free up memory
     del content
     # LLM does its magic
-    llm_client = CONFIG.llm.instance(
+    llm_client = CONFIG.llm.selected(
         is_fast=True,  # We will use the fast model
-    )
+    ).instance()
     facts: list[FactModel] = []
     for _ in range(CONFIG.features.fact_iterations):  # We will generate facts 10 times
         def _validate(req: Optional[str]) -> tuple[bool, Optional[str], Optional[FactedLlmModel]]:
@@ -516,9 +516,9 @@ def _validate(req: Optional[str]) -> tuple[bool, Optional[str], Optional[float]]
         if group:
             return True, None, float(group.group())
         return False, "Score not detected", None
-    llm_client = CONFIG.llm.instance(
+    llm_client = CONFIG.llm.selected(
         is_fast=False,  # We want high quality to avoid using human validation which is even more costly and slower
-    )
+    ).instance()
     fact_scores = await asyncio.gather(
         *[
             llm_client.generate(
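Every touched handler in function_app.py changes the same way: the platform is selected first, then the client is instantiated from it. A minimal sketch of the new call pattern, assuming `CONFIG` is the application configuration object already used by function_app.py (the import path shown here is a guess, as it is not part of this diff):

from helpers.config import CONFIG  # assumed module path, not shown in this diff


def split_document(text: str) -> list[str]:
    # selected() returns the platform config (Azure OpenAI or OpenAI) chosen for
    # the fast model; instance() then builds and caches the matching ILlm client.
    llm_client = CONFIG.llm.selected(is_fast=True).instance()
    # chunck() is the same helper the handlers above already call on the client.
    return llm_client.chunck(text=text)

The previous one-step form, `CONFIG.llm.instance(is_fast=True)`, no longer exists after this commit.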
81 changes: 61 additions & 20 deletions helpers/config_models/llm.py
@@ -1,21 +1,19 @@
 from enum import Enum
 from functools import cache
 from typing import Optional, Union
 from persistence.illm import ILlm
-from pydantic import Field, SecretStr, BaseModel
+from pydantic import Field, SecretStr, BaseModel, field_validator, ValidationInfo


-class TypeEnum(Enum):
-    FAST = "fast"
-    SLOW = "slow"
+class ModeEnum(str, Enum):
+    AZURE_OPENAI = "azure_openai"
+    OPENAI = "openai"


-class BackendModel(BaseModel, frozen=True):
-    api_key: SecretStr
+class AbstractPlatformModel(BaseModel, frozen=True):
     context: int
-    deployment: str
-    endpoint: str
     model: str
-    type: TypeEnum
     streaming: bool
     validation_retry_max: int = Field(
         default=3,
         ge=0,
@@ -25,26 +23,69 @@ class BackendModel(BaseModel, frozen=True):
         ge=0,
     )

+
+class AzureOpenaiPlatformModel(AbstractPlatformModel, frozen=True):
+    api_key: Optional[SecretStr] = None
+    deployment: str
+    endpoint: str

     @cache
     def instance(self) -> ILlm:
-        from persistence.azure_openai import AzureOpenaiLlm
+        from persistence.openai import AzureOpenaiLlm

         return AzureOpenaiLlm(self)


-class FastModel(BackendModel):
-    type: TypeEnum = TypeEnum.FAST
+class OpenaiPlatformModel(AbstractPlatformModel, frozen=True):
+    api_key: SecretStr
+    endpoint: str
+
+    @cache
+    def instance(self) -> ILlm:
+        from persistence.openai import OpenaiLlm
+
+        return OpenaiLlm(self)
+
+
+class SelectedPlatformModel(BaseModel):
+    azure_openai: Optional[AzureOpenaiPlatformModel] = None
+    mode: ModeEnum
+    openai: Optional[OpenaiPlatformModel] = None
+
+    @field_validator("azure_openai")
+    def _validate_azure_openai(
+        cls,
+        azure_openai: Optional[AzureOpenaiPlatformModel],
+        info: ValidationInfo,
+    ) -> Optional[AzureOpenaiPlatformModel]:
+        if not azure_openai and info.data.get("mode", None) == ModeEnum.AZURE_OPENAI:
+            raise ValueError("Azure OpenAI config required")
+        return azure_openai
+
+    @field_validator("openai")
+    def _validate_openai(
+        cls,
+        openai: Optional[OpenaiPlatformModel],
+        info: ValidationInfo,
+    ) -> Optional[OpenaiPlatformModel]:
+        if not openai and info.data.get("mode", None) == ModeEnum.OPENAI:
+            raise ValueError("OpenAI config required")
+        return openai

-class SlowModel(BackendModel):
-    type: TypeEnum = TypeEnum.SLOW
+    def selected(self) -> Union[AzureOpenaiPlatformModel, OpenaiPlatformModel]:
+        platform = (
+            self.azure_openai if self.mode == ModeEnum.AZURE_OPENAI else self.openai
+        )
+        assert platform
+        return platform


 class LlmModel(BaseModel):
-    fast: FastModel
-    slow: SlowModel
+    fast: SelectedPlatformModel
+    slow: SelectedPlatformModel

-    def instance(self, is_fast: bool) -> ILlm:
-        if is_fast:
-            return self.fast.instance()
-        return self.slow.instance()
+    def selected(
+        self, is_fast: bool
+    ) -> Union[AzureOpenaiPlatformModel, OpenaiPlatformModel]:
+        platform = self.fast if is_fast else self.slow
+        return platform.selected()
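`SelectedPlatformModel` only checks that the block matching `mode` is present; the other platform block can stay unset. A small sketch of how the new models could be exercised on their own, with placeholder values and assuming the `AbstractPlatformModel` fields hidden by the diff keep their defaults:

from helpers.config_models.llm import ModeEnum, SelectedPlatformModel

fast = SelectedPlatformModel(
    mode=ModeEnum.OPENAI,
    openai={
        "api_key": "sk-xxx",  # placeholder
        "context": 16385,
        "endpoint": "https://api.openai.com/v1",  # placeholder
        "model": "gpt-3.5-turbo",  # placeholder
        "streaming": True,
    },
)
platform = fast.selected()        # the OpenaiPlatformModel instance
llm_client = platform.instance()  # cached OpenaiLlm built from it

# Passing an explicit empty block for the selected mode is rejected:
# SelectedPlatformModel(mode=ModeEnum.OPENAI, openai=None) raises "OpenAI config required"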
49 changes: 42 additions & 7 deletions persistence/azure_openai.py → persistence/openai.py
@@ -1,21 +1,22 @@
-from helpers.config_models.llm import BackendModel as LlmBackendModel
+from abc import abstractmethod
+from azure.identity import ManagedIdentityCredential, get_bearer_token_provider
+from helpers.config_models.llm import AzureOpenaiPlatformModel, OpenaiPlatformModel
 from helpers.logging import logger
-from openai import AsyncAzureOpenAI
+from openai import AsyncAzureOpenAI, AsyncOpenAI
 from openai.types.chat import ChatCompletionSystemMessageParam
 from persistence.illm import ILlm
-from typing import Optional, TypeVar, Callable
 import math
+from typing import Optional, TypeVar, Callable, Union
 import tiktoken


 T = TypeVar("T")


-class AzureOpenaiLlm(ILlm):
+class AbstractOpenaiLlm(ILlm):
     _client: Optional[AsyncAzureOpenAI] = None
-    _config: LlmBackendModel
+    _config: Union[AzureOpenaiPlatformModel, OpenaiPlatformModel]

-    def __init__(self, config: LlmBackendModel):
+    def __init__(self, config: Union[AzureOpenaiPlatformModel, OpenaiPlatformModel]):
         self._config = config

     async def generate(
@@ -189,14 +190,48 @@ def chunck(
         # Return the chunks
         return contents

+    @abstractmethod
+    def _use_client(self) -> AsyncOpenAI:
+        pass
+
+
+class AzureOpenaiLlm(AbstractOpenaiLlm):
+    def __init__(self, config: AzureOpenaiPlatformModel):
+        super().__init__(config)
+
     def _use_client(self) -> AsyncAzureOpenAI:
         if not self._client:
+            api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
+            token_func = (
+                get_bearer_token_provider(
+                    ManagedIdentityCredential(),
+                    "https://cognitiveservices.azure.com/.default",
+                )
+                if not self._config.api_key
+                else None
+            )
             self._client = AsyncAzureOpenAI(
                 # Deployment
                 api_version="2023-12-01-preview",
                 azure_deployment=self._config.deployment,
                 azure_endpoint=self._config.endpoint,
                 # Reliability
                 max_retries=30,  # We are patient, this is a background job :)
                 timeout=180,  # 3 minutes
+                # Authentication
+                api_key=api_key,
+                azure_ad_token_provider=token_func,
             )
         return self._client
+
+
+class OpenaiLlm(AbstractOpenaiLlm):
+    def __init__(self, config: OpenaiPlatformModel):
+        super().__init__(config)
+
+    def _use_client(self) -> AsyncOpenAI:
+        if not self._client:
+            self._client = AsyncOpenAI(
+                # Reliability
+                max_retries=30,  # We are patient, this is a background job :)
+                timeout=180,  # 3 minutes
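On the Azure side, the API key is now optional: when it is unset, `_use_client()` falls back to the Function App's managed identity through `get_bearer_token_provider`. A sketch of building the client wrapper directly, with placeholder values and the same assumption about defaulted fields as above:

from helpers.config_models.llm import AzureOpenaiPlatformModel
from persistence.openai import AzureOpenaiLlm

config = AzureOpenaiPlatformModel(
    # api_key intentionally unset: the client authenticates with ManagedIdentityCredential
    context=16385,
    deployment="gpt-35-turbo-0125",           # placeholder
    endpoint="https://xxx.openai.azure.com",  # placeholder
    model="gpt-35-turbo",
    streaming=True,
)
llm = AzureOpenaiLlm(config)
# The underlying AsyncAzureOpenAI client is only created on the first call that
# needs it (generate() or chunck()), via the lazy _use_client() shown above.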
