-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathbot.py
229 lines (193 loc) · 8.06 KB
/
bot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import json
from werobot import WeRoBot
import config as cfg
import openai
import itertools
#import re
#import prettytable
#from transformers import GPT2TokenizerFast
import asyncio
import time
import aiohttp
# Create the WeRoBot instance with the token from config, then set its
# APP_ID and ENCODING_AES_KEY configuration entries.
mybot = WeRoBot(token=cfg.token)
mybot.config["APP_ID"] = cfg.appid
mybot.config['ENCODING_AES_KEY'] = cfg.aeskey
# Configure the openai client from config: endpoint base URL and API key.
openai.api_base = cfg.api_base
openai.api_key = cfg.azure_openai_key # alternative key source: cfg.openai_key
# Module-level copy of the deployment name read from config.
deployment_name= cfg.deployment_name
# Only the Azure API needs the two settings below.
openai.api_type = cfg.api_type
openai.api_version = cfg.api_version
@mybot.image
def image_repeat(message, session):
    """Reply to an image message with the value of its ``img`` attribute.

    Registered through ``@mybot.image``. ``session`` is accepted but unused.
    """
    received_image = message.img
    return received_image
@mybot.subscribe
def intro(message):
    """Return the fixed welcome text; registered through ``@mybot.subscribe``.

    ``message`` is accepted but not inspected.
    """
    welcome_text = "欢迎加入系统之美,ChatGPT上线为您服务"
    return welcome_text
# @mybot.text  -- registration is commented out, so this handler is not active
def echo(message, session):
    """Echo the user's input back unchanged (kept for tests).

    Returns ``message.content`` as-is; ``session`` is accepted but unused.
    """
    echoed_text = message.content
    return echoed_text
# Reply used whenever no usable answer could be produced.
_FALLBACK_ANSWER = '抱歉,这个我不会,试试别的话题。'

# Size budget for "stored history + new input" before the history is dropped.
# The original note says the model's maximum context length is 4097 tokens.
_HISTORY_BUDGET = 2048


def _parse_model_reply(raw_output):
    """Split the raw completion text into an ``(intention, answer)`` pair.

    If the cleaned text starts with ``{"i":`` it is parsed as a JSON object
    with keys ``i`` (intention) and ``a`` (answer). Anything else is treated
    as a plain-text answer with intention ``'na'``. The plain text is used
    directly instead of being pasted into a JSON string, so quotes,
    backslashes or tabs in the reply can no longer break parsing.

    Raises:
        json.JSONDecodeError: the text looks like JSON but is malformed.
    """
    cleaned = raw_output.lstrip('\n').replace(' .', '.').strip()
    if not cleaned.lower().startswith('{"i":'):
        return 'na', cleaned
    # Raw newlines are not allowed inside JSON strings, so escape them first.
    parsed = json.loads(cleaned.replace('\n', '\\n'))
    intention = parsed.get('i', 'na')
    answer = parsed.get('a') or ''
    return intention, str(answer)


@mybot.text
def text_response(message, session):
    """Answer a text message through the completion model, keeping history.

    The conversation history is a list of ``[user_input, answer]`` pairs stored
    under ``session['state']``. When the session has no history yet, it is
    seeded from ``fewshot.json``. The history plus the new input is flattened
    into one prompt and sent to ``openai_create``.

    Args:
        message: incoming message; only ``message.content`` is read.
        session: mapping-like store (used via ``in``/``get``/``pop``/item
            assignment) holding the history under the ``'state'`` key.

    Returns:
        The reply text. A fixed apology is returned when no answer could be
        produced, so an empty string is never returned.

    Raises:
        OSError, json.JSONDecodeError: ``fewshot.json`` is missing or invalid
            (only read when the session has no stored history).
    """
    userinput = message.content.strip().lower()

    if 'state' in session:
        history = session.get('state', [])
        history_text = str(history)
        print("sessionState:", history_text)
        # GPT2TokenizerFast is not imported in this module (its import is
        # commented out), so calling it raised NameError on every follow-up
        # message. The character count is used instead as a rough,
        # dependency-free proxy for the token count.
        approx_tokens = len(history_text + userinput)
        print("Approx. number of tokens:", approx_tokens)
        if approx_tokens >= _HISTORY_BUDGET:
            # Too much context: forget the history and ask the user to restate.
            session.pop('state', None)
            return '不好意思,我的短期记忆不够用了。您重新提示前文一下吧?'
    else:
        # No history yet: start from the few-shot example pairs.
        with open('fewshot.json', 'r', encoding='utf-8') as f:
            history = json.load(f)

    prompt_parts = list(itertools.chain.from_iterable(history))
    # The trailing line break implicitly ends the user prompt.
    prompt_parts.append(userinput + '\n')
    prompt = ('extract the intention and object from the message and answer based on it. '
              + ' '.join(prompt_parts))
    print('prompt: ' + userinput)

    answer = ''
    try:
        output = openai_create(prompt)
        print('raw response: ' + output)
        intention, answer = _parse_model_reply(output)

        if intention == 'greeting':
            answer = answer or 'Hi'
        elif intention == 'reset':
            answer = answer or 'Ok, 那聊啥呢?'
            session.pop('state', None)
        elif intention == 'archive':
            answer = "您查询的'往期文章'功能正在建设中🚧预计明天上线"
        elif intention == 'relevant':
            answer = "您查询的'相关文章'功能正在建设中🚧预计明天上线"
        elif not answer:
            answer = _FALLBACK_ANSWER
        else:
            # Drop any literal backslash-n sequences left in the text.
            answer = answer.replace('\\n', '')
            history.append([userinput, answer])
            session['state'] = history
            # Render Markdown tables as text tables; per the original note,
            # HTML is not allowed by the platform.
            if answer.count('|') >= 2:
                answer = markdown2ascii_table(answer)
            else:
                print("no markdown table")
    except Exception as e:
        # Top-level boundary of the handler: log and fall through so the user
        # still gets a reply (whatever answer was produced so far, if any).
        print(f"Opps: {e}")

    if not answer:
        answer = _FALLBACK_ANSWER
    print('answer: ', answer)
    return answer
# Convert a Markdown table to a plain-text table.
def _split_markdown_row(line):
    """Return the stripped cell texts of one Markdown table line.

    One leading and one trailing pipe are optional and ignored.
    """
    text = line.strip()
    if text.startswith('|'):
        text = text[1:]
    if text.endswith('|'):
        text = text[:-1]
    return [cell.strip() for cell in text.split('|')]


def _is_alignment_row(cells):
    """Tell whether every cell is a Markdown delimiter such as ``---`` or ``:-:``."""
    cores = [cell.strip(':') for cell in cells]
    return bool(cores) and all(core and set(core) == {'-'} for core in cores)


def _cell_width(text):
    """Approximate display width: East Asian wide/fullwidth characters count as 2."""
    import unicodedata  # local import keeps this block self-contained

    return sum(2 if unicodedata.east_asian_width(ch) in ('W', 'F') else 1
               for ch in text)


def _pad_cell(text, width, alignment):
    """Pad *text* with spaces to *width* columns ('l' left, 'r' right, else centered)."""
    gap = width - _cell_width(text)
    if alignment == 'l':
        return text + ' ' * gap
    if alignment == 'r':
        return ' ' * gap + text
    left = gap // 2
    return ' ' * left + text + ' ' * (gap - left)


def markdown2ascii_table(markdown_str: str):
    """Render the Markdown table found in *markdown_str* as a box-drawn text table.

    The first line containing ``|`` is the header. If the next line is a
    delimiter row (``---``, ``:---``, ``---:``, ``:-:``), it sets the column
    alignment: ``:---`` left, ``---:`` right, anything else centered. Each
    later line containing ``|`` becomes a data row, padded with empty cells
    or truncated to the header's column count. Lines without ``|`` are
    skipped, so text around the table is not part of the result.

    This version uses only the standard library: the previous one referenced
    ``prettytable``, which this module does not import, so every call raised
    NameError.

    Args:
        markdown_str: text containing a Markdown table, e.g.
            "| a | b |\\n| --- | --- |\\n| 1 | 2 |".

    Returns:
        The rendered table as one string, or *markdown_str* unchanged when no
        line contains ``|``.
    """
    print('markdown: ' + markdown_str)
    lines = markdown_str.split("\n")

    header_index = next((i for i, line in enumerate(lines) if '|' in line), None)
    if header_index is None:
        return markdown_str

    headers = _split_markdown_row(lines[header_index])
    num_columns = len(headers)

    # Centered unless the delimiter row says otherwise.
    alignments = ['c'] * num_columns
    data_start = header_index + 1
    if data_start < len(lines) and '|' in lines[data_start]:
        markers = _split_markdown_row(lines[data_start])
        if _is_alignment_row(markers):
            for col, marker in enumerate(markers[:num_columns]):
                starts, ends = marker.startswith(':'), marker.endswith(':')
                if starts and not ends:
                    alignments[col] = 'l'
                elif ends and not starts:
                    alignments[col] = 'r'
            data_start += 1

    body = []
    for line in lines[data_start:]:
        if '|' not in line:
            continue
        cells = _split_markdown_row(line)[:num_columns]
        cells += [''] * (num_columns - len(cells))
        body.append(cells)

    widths = [
        max(_cell_width(row[col]) for row in [headers] + body)
        for col in range(num_columns)
    ]

    def render(cells):
        padded = (' ' + _pad_cell(cell, width, align) + ' '
                  for cell, width, align in zip(cells, widths, alignments))
        return '│' + '│'.join(padded) + '│'

    def rule(left, middle, right):
        return left + middle.join('─' * (width + 2) for width in widths) + right

    out = [rule('┌', '┬', '┐'), render(headers), rule('├', '┼', '┤')]
    out.extend(render(cells) for cells in body)
    out.append(rule('└', '┴', '┘'))

    table_string = '\n'.join(out)
    print('table_string', table_string)
    return table_string
# Conversation function: one completion request per prompt.
def openai_create(prompt):
    """Send *prompt* to the configured completion deployment and return its text.

    Calls ``openai.Completion.create`` against ``cfg.deployment_name`` and
    returns the ``text`` of the first choice, unmodified.
    """
    request_options = {
        "engine": cfg.deployment_name,
        "prompt": prompt,
        # A lower temperature gives less random, less "creative" text.
        "temperature": 0.3,
        "max_tokens": 350,
        # Sample the next token from the top 90% of the probability mass.
        "top_p": 0.9,
        # Penalizes tokens by how often they have already appeared in the text.
        "frequency_penalty": 0.3,
        # Penalizes tokens that have appeared in the text at all.
        "presence_penalty": 0.5,
        # No stop sequence is set: stopping on '\n' results in a missing reply
        # when the completion leads with '\n'.
        "n": 1,
    }
    completion = openai.Completion.create(**request_options)
    first_choice = completion.choices[0]
    return first_choice.text
async def call_api_with_timeout(timeout=5):
    """Await ``make_api_call()`` but give up after *timeout* seconds.

    Args:
        timeout: maximum number of seconds to wait; defaults to 5.

    Returns:
        The result of ``make_api_call()``, or the string ``'success'`` when the
        call does not finish within *timeout*.
    """
    try:
        return await asyncio.wait_for(make_api_call(), timeout=timeout)
    except asyncio.TimeoutError:
        # A timeout is deliberately reported as 'success' rather than raised.
        return 'success'
async def make_api_call():
    """GET http://www.abc.com through aiohttp and return the response body as text."""
    async with aiohttp.ClientSession() as http, http.get('http://www.abc.com') as reply:
        return await reply.text()