-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathserver.py
111 lines (99 loc) · 3.69 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import logging
# logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)
import argparse
import json
from os.path import join
import numpy as np
from components.nw import Nw
from components.stt import Stt
from components.llm_server import Llm
from components.tts_server import Tts
from components.utils import remove_emojis
from components.utils import remove_nonverbal_cues
from components.utils import remove_multiple_dots
from components.utils import remove_code_blocks
from components.utils import check_delete_messages
from components.utils import check_skip_message
# import scipy.io.wavfile as wf
def load_config(config_file):
    """Load and return a JSON configuration file as a dict.

    Args:
        config_file: Path to the JSON config file.

    Returns:
        The parsed JSON content (typically a dict of component sections).

    Raises:
        FileNotFoundError: If the file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit encoding: JSON is UTF-8 by spec; without it, open() falls
    # back to the platform default and can mis-decode on e.g. Windows.
    with open(config_file, "r", encoding="utf-8") as file:
        return json.load(file)
if __name__ == "__main__":
    # Entry point: load config, build the STT/LLM/TTS pipeline, then serve
    # one client at a time over the Nw network layer in a command loop.
    parser = argparse.ArgumentParser(description="Aria.")
    parser.add_argument(
        "--config",
        default="default.json",
        help="Path to JSON config file in the configs folder",
    )
    args = parser.parse_args()

    config_path = join("configs", args.config)
    config = load_config(config_path)

    # Per-component parameter sections; missing sections fall back to {}.
    nw_params = config.get("Nw", {}).get("params", {})
    stt_params = config.get("Stt", {}).get("params", {})
    llm_params = config.get("Llm", {}).get("params", {})
    tts_params = config.get("Tts", {}).get("params", {})
    mic_params = config.get("Mic", {}).get("params", {})
    ap_params = config.get("Ap", {}).get("params", {})

    print("Loading...")
    nw = Nw(params=nw_params)
    stt = Stt(params=stt_params)
    llm = Llm(params=llm_params)
    tts = Tts(params=tts_params)

    nw.server_init()
    if nw.audio_compression:
        # Encoder for audio sent to the client (Ap params), decoder for the
        # client's microphone stream (Mic params).
        nw.init_audio_encoder(
            ap_params.get("samplerate"),
            ap_params.get("channels"),
            ap_params.get("buffer_size"),
        )
        nw.init_audio_decoder(
            mic_params.get("samplerate"),
            mic_params.get("channels"),
            mic_params.get("buffer_size"),
        )
    client_address, username = nw.server_listening()

    # Last transcription result. Initialized here so a client that sends
    # "llm_get_answer" before any "stt_transcribe" cannot trigger a
    # NameError at the llm.get_answer() call below (bug fix).
    stt_data = ""

    while True:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; any receive failure is treated as a disconnect.
        try:
            client_data = nw.receive_msg()
        except Exception:
            client_data = False
        if not client_data:
            print("Client disconnected...")
            client_address, username = nw.server_listening()
            # Restart the loop for the new client instead of falling
            # through to the dispatch chain with a stale/False message.
            continue
        if client_data == "stt_transcribe":
            nw.send_ack()
            mic_recording = nw.receive_audio_recording()
            # int16 PCM -> float32 in [-1.0, 1.0) as expected by the STT model.
            stt_data = stt.transcribe_translate(
                np.frombuffer(mic_recording, np.int16)
                .flatten()
                .astype(np.float32, order="C")
                / 32768.0
            )
            # Map control phrases to single-letter sentinels the client
            # understands: "d" = delete history, "s" = skip this message.
            if check_delete_messages(stt_data):
                llm.user_aware_messages.pop(username, None)
                stt_data = "d"
            elif check_skip_message(stt_data):
                stt_data = "s"
            nw.send_msg(stt_data)
        elif client_data == "llm_get_answer":
            llm_data = llm.get_answer(nw, tts, stt_data, username)
            if not llm.streaming_output:
                # Non-streaming path: ship the full answer, then synthesize it.
                nw.send_msg(llm_data)
                if tts.tts_type == "coqui":
                    tts.text_splitting = True
                # TODO handle emphasis
                # TODO add remove_nonverbal_cues when not streaming llm
                txt_for_tts = remove_emojis(
                    remove_multiple_dots(remove_code_blocks(llm_data))
                )
                tts.run_tts(nw, txt_for_tts)
        elif client_data == "fixed_answer":
            # Canned fallback prompt spoken back to the client.
            tts.run_tts(nw, "Did you say something?")