-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
80 lines (55 loc) · 1.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from fastapi import FastAPI
import time
from pydantic import BaseModel, BaseSettings
# Settings
class Settings(BaseSettings):
    # Runtime configuration for this service. As a pydantic BaseSettings
    # subclass, each field can presumably be overridden by an environment
    # variable of the same name — confirm against the installed pydantic version.
    OPENAI_API_KEY: str  # required: no default, so it must be supplied or construction fails
    DB_URL: str = "sqlite:///./database.db"  # passed as `sql_url` to gptcache's SQLite cache store
    INDEX_PATH: str = "./faiss.index"  # passed as `index_path` to gptcache's FAISS vector store
    MODEL: str = "gpt-3.5-turbo" # Can also be set to "gpt-4"


# Single shared settings instance, built once at import time.
settings = Settings()
# Models
class PromptQuery(BaseModel):
    # Request body for POST /query. A single field: the text forwarded to ask().
    # (Documented with comments rather than a docstring: pydantic would publish a
    # class docstring as the schema description.)
    prompt: str
# OpenAI API utilities
from gptcache import cache, adapter, manager, embedding
from gptcache.adapter import openai
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
def response_text(openai_resp):
    """Pull the reply text out of a chat-completion response.

    Only the first element of ``choices`` is read; any additional choices
    are ignored.

    Args:
        openai_resp: A chat-completion response that supports item access
            along the path ``['choices'][0]['message']['content']``.

    Returns:
        The value stored under the first choice's ``message['content']``.

    Raises:
        KeyError / IndexError: if the response does not have that shape.
    """
    first_choice = openai_resp['choices'][0]
    message = first_choice['message']
    return message['content']
# Cache initialization
print("Cache loading.....")
# ONNX embedding model; its vectors are what the FAISS index stores.
onnx = embedding.Onnx()


def _build_data_manager():
    """Return gptcache's data manager: SQLite cache store plus FAISS vector index."""
    sqlite_store = manager.CacheBase("sqlite", sql_url=settings.DB_URL)
    faiss_store = manager.VectorBase(
        "faiss",
        dimension=onnx.dimension,  # index width must match the embedding size
        index_path=settings.INDEX_PATH,
    )
    return manager.get_data_manager(sqlite_store, faiss_store)


cache.init(
    embedding_func=onnx.to_embeddings,
    data_manager=_build_data_manager(),
    similarity_evaluation=SearchDistanceEvaluation(),
)
# gptcache helper that configures the OpenAI key. It presumably reads the
# OPENAI_API_KEY environment variable itself (settings.OPENAI_API_KEY is not
# passed in) — confirm against the gptcache version in use.
cache.set_openai_key()
# Main function to get response from OpenAI
def ask(prompt: str) -> str:
    """Send *prompt* as a single user message and return the reply text.

    ``openai`` here is gptcache's adapter (imported from
    ``gptcache.adapter``), so the request goes through the cache
    initialized at module level. A cache hit may be answered without
    contacting the OpenAI API.

    Args:
        prompt: The user's message.

    Returns:
        The content of the first choice's message, as extracted by
        ``response_text``.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # system clock, which can be adjusted (e.g. NTP) mid-call and skew the
    # reported duration.
    started = time.perf_counter()
    response = openai.ChatCompletion.create(
        model=settings.MODEL,
        messages=[
            {
                'role': 'user',
                'content': prompt,
            },
        ],
    )
    print("Time consuming: {:.2f}s".format(time.perf_counter() - started))
    return response_text(response)
# FastAPI app
app = FastAPI()


# GET / always answers with the fixed string "Hello World" (FastAPI serializes
# it as a JSON string). Comment rather than docstring: FastAPI would publish a
# docstring as this endpoint's OpenAPI description.
@app.get("/")
def read_root():
    greeting = "Hello World"
    return greeting
# GET /query?prompt=...: answer a prompt supplied as a query-string parameter.
# (Comment rather than docstring so the OpenAPI description is unchanged.)
@app.get("/query")
def get_prompt(prompt: str):
    answer = ask(prompt)
    return answer
# POST /query with a JSON body shaped like PromptQuery, e.g. {"prompt": "..."}.
# (Comment rather than docstring so the OpenAPI description is unchanged.)
@app.post("/query")
def post_prompt(query: PromptQuery):
    user_prompt = query.prompt
    return ask(user_prompt)