-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathmanager.py
150 lines (115 loc) · 3.98 KB
/
manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# coding: utf-8
import sys
import asyncio
from pprint import pprint
from dotenv import load_dotenv
from law_ai.callback import OutCallbackHandler
from law_ai.loader import LawLoader
from law_ai.splitter import LawSplitter
from law_ai.utils import law_index, clear_vectorstore, get_record_manager
from law_ai.chain import get_law_chain, get_check_law_chain
from config import config
# Load variables from a local .env file into the process environment.
# NOTE(review): this call runs after `config` and the `law_ai` modules have
# already been imported above; if any of them read environment variables at
# import time, values defined only in .env will not be visible to them —
# confirm, or move this call above those imports.
load_dotenv()
# Optional debugging aids, disabled by default: uncomment to cache LLM
# responses in a local SQLite file (.langchain.db) and/or turn on LangChain's
# global debug output.
# import langchain
# from langchain.cache import SQLiteCache
# from langchain.globals import set_llm_cache
# set_llm_cache(SQLiteCache(database_path=".langchain.db"))
# langchain.debug = True
def init_vectorstore() -> None:
    """Initialise the "law" vector store from the law book on disk.

    Steps, in order:
      1. create the schema of the "law" record manager;
      2. clear the "law" vector store;
      3. load ``config.LAW_BOOK_PATH`` with ``LawLoader`` and split it with a
         tiktoken-sized ``LawSplitter`` (chunk size / overlap from config);
      4. index the chunks via ``law_index`` and pretty-print its result.
    """
    manager = get_record_manager("law")
    manager.create_schema()
    clear_vectorstore("law")

    splitter = LawSplitter.from_tiktoken_encoder(
        chunk_size=config.LAW_BOOK_CHUNK_SIZE,
        chunk_overlap=config.LAW_BOOK_CHUNK_OVERLAP,
    )
    loader = LawLoader(config.LAW_BOOK_PATH)
    chunks = loader.load_and_split(text_splitter=splitter)

    pprint(law_index(chunks))
async def run_shell() -> None:
    """Run an interactive console chat with the law assistant.

    Every question is first passed to the check-law chain; when that returns
    a falsy result, a fixed refusal is printed instead of an answer.
    Otherwise the law chain is started as an asyncio task while the tokens
    delivered through ``out_callback`` are echoed as they arrive, and the
    ``law_context`` and ``web_context`` fields of the final result are
    printed afterwards.

    The session ends when the user types ``stop`` or stdin reaches EOF;
    blank lines are ignored.
    """
    check_law_chain = get_check_law_chain(config)
    out_callback = OutCallbackHandler()
    # One handler is bound to the chain here and reused for every question.
    chain = get_law_chain(config, out_callback=out_callback)

    while True:
        try:
            question = input("\n用户:")
        except EOFError:
            # stdin closed (Ctrl-D, or piped input exhausted): end the session
            # cleanly instead of crashing with a traceback.
            break
        stripped = question.strip()
        if stripped == "stop":
            break
        if not stripped:
            # Nothing was asked; don't spend LLM calls on an empty prompt.
            continue

        print("\n法律小助手:", end="")
        # A synchronous call is acceptable here: no other task is running on
        # the event loop while we wait for the classifier.
        is_law = check_law_chain.invoke({"question": question})
        if not is_law:
            print("不好意思,我是法律AI助手,请提问和法律有关的问题。")
            continue

        task = asyncio.create_task(chain.ainvoke({"question": question}))
        # Stream tokens as the chain produces them.
        async for new_token in out_callback.aiter():
            print(new_token, end="", flush=True)
        res = await task
        print(res["law_context"] + "\n" + res["web_context"])
        # Reset the shared handler's `done` flag so the next question's
        # `aiter()` waits for fresh tokens (presumably an asyncio.Event — confirm).
        out_callback.done.clear()
def run_web() -> None:
    """Serve the law assistant as a Gradio chat web UI.

    Builds the check-law and law chains once, wraps them in an async
    streaming ``chat`` handler for ``gr.ChatInterface``, enables the request
    queue and launches the server on ``config.WEB_HOST``/``config.WEB_PORT``
    behind basic auth (``config.WEB_USERNAME`` / ``config.WEB_PASSWORD``).
    Blocks until the server stops.
    """
    import gradio as gr

    check_law_chain = get_check_law_chain(config)
    # No handler is bound at build time; each request passes its own below.
    chain = get_law_chain(config, out_callback=None)

    async def chat(message, history):
        """Stream an answer to *message*; *history* is required by Gradio but unused."""
        # A fresh handler per call, so each request consumes its own token stream.
        out_callback = OutCallbackHandler()

        # Use the async API: a blocking invoke() inside this coroutine would
        # stall Gradio's event loop (and every other user's stream) for the
        # whole duration of the classifier's LLM call.
        is_law = await check_law_chain.ainvoke({"question": message})
        if not is_law:
            yield "不好意思,我是法律AI助手,请提问和法律有关的问题。"
            return

        task = asyncio.create_task(
            chain.ainvoke({"question": message}, config={"callbacks": [out_callback]}))

        # Discard the first streamed run. Because the handler is passed via
        # `config`, it presumably receives tokens from an earlier LLM step of
        # the chain before the answering step — TODO confirm. Note this
        # assumes at least two streamed runs; with only one, the second loop
        # below would have nothing to wait for.
        async for _ in out_callback.aiter():
            pass
        out_callback.done.clear()

        response = ""
        async for new_token in out_callback.aiter():
            response += new_token
            yield response

        # Append the retrieved sources once the chain has finished.
        res = await task
        for new_token in ["\n\n", res["law_context"], "\n", res["web_context"]]:
            response += new_token
            yield response

    demo = gr.ChatInterface(
        fn=chat, examples=["故意杀了一个人,会判几年?", "杀人自首会减刑吗?"], title="法律AI小助手")
    demo.queue()
    demo.launch(
        server_name=config.WEB_HOST, server_port=config.WEB_PORT,
        auth=(config.WEB_USERNAME, config.WEB_PASSWORD),
        # NOTE(review): this login-page hint is hard-coded to
        # "username / password" regardless of the configured credentials, so
        # it is either misleading or advertises working defaults — confirm.
        auth_message="默认用户名密码: username / password")
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description="please specify only one operate method once time.")
    # argparse collapses whitespace in help text, so these plain strings
    # render exactly like the previous indented triple-quoted ones.
    parser.add_argument("-i", "--init", action="store_true", help="init vectorstore")
    parser.add_argument("-s", "--shell", action="store_true", help="run shell")
    parser.add_argument("-w", "--web", action="store_true", help="run web")

    # With no arguments at all, show usage instead of silently doing nothing.
    if len(sys.argv) <= 1:
        parser.print_help()
        # sys.exit rather than the bare exit() builtin, which is injected by
        # the `site` module and is missing under `python -S` / some embeddings.
        sys.exit()

    args = parser.parse_args()
    # The flags are not mutually exclusive: every selected mode runs, in
    # this order (init, then shell, then web).
    if args.init:
        init_vectorstore()
    if args.shell:
        asyncio.run(run_shell())
    if args.web:
        run_web()