-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
159 lines (129 loc) · 6.56 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import html
import secrets
from pathlib import Path

import pandas as pd
import psycopg2
import requests
from dotenv import dotenv_values
from flask import Flask, render_template, redirect, request, url_for
from tabulate import tabulate
# Settings are read from the .env file in the directory above this script. The path is
# resolved from __file__ (not the current working directory) so the app finds it no
# matter where it is launched from.
env_vars = dotenv_values(str(Path(__file__).resolve().parent.parent / '.env'))
API_TOKEN = env_vars.get("API_TOKEN")
USER = env_vars.get("DB_USER")
PASSWORD = env_vars.get("DB_PASSWORD")
HOST = env_vars.get("HOST_IP")

# Fail fast at start-up and name exactly which settings are absent. As before, only a
# missing key counts; a key present with an empty value is accepted.
_missing_env = [
    key
    for key, value in (
        ("API_TOKEN", API_TOKEN),
        ("DB_USER", USER),
        ("DB_PASSWORD", PASSWORD),
        ("HOST_IP", HOST),
    )
    if value is None
]
if _missing_env:
    # SystemExit(<str>) prints the message to stderr and exits with status 1; unlike the
    # bare exit() helper injected by the `site` module, it is always available.
    raise SystemExit(
        "Error: One or more environment variables are not set: " + ", ".join(_missing_env)
    )

# Hugging Face Inference API endpoint of the `flant5-nltosql-final-model` model.
API_URL = "https://api-inference.huggingface.co/models/barunparua/flant5-nltosql-final-model"
headers = {"Authorization": f"Bearer {API_TOKEN}"}
def query(payload, timeout=120):
    """POST *payload* to the model endpoint ``API_URL`` and return the decoded JSON body.

    Args:
        payload: JSON-serialisable request body (e.g. ``{"inputs": ..., "parameters": ...}``).
        timeout: Seconds to wait for the HTTP response. Without a timeout, ``requests``
            waits indefinitely and can hang the calling request handler forever.

    Returns:
        The parsed JSON response. If the body is not valid JSON (e.g. an HTML error
        page from a gateway), returns ``{"error": <raw body text>, "status_code": <int>}``
        so callers that index the result fall into their own error handling instead
        of crashing here.

    Raises:
        requests.RequestException: on connection errors or when *timeout* expires.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    try:
        return response.json()
    except ValueError:  # requests' JSONDecodeError subclasses ValueError
        return {"error": response.text, "status_code": response.status_code}
# Connection/selection state, mutated by connect_to_db() and close_connection().
isDBselected = False  # True while `conn` is an open connection to a preset database
db_id = -1            # "id" of the selected preset database; -1 means none selected
schema = ""           # schema description prepended to the user's question in the prompt
conn = None           # psycopg2 connection to the selected database, or None


def _schema_prompt(tables):
    """Render a table list as the model's schema string.

    The format is ``TABLE: attr; attr//TABLE: attr; ...``.
    """
    return "//".join(f"{t['Table']}: {'; '.join(t['Attributes'])}" for t in tables)


# Tables of the fest-management database. This single list feeds both the HTML
# schema table (via tabulate) and the textual schema used in the model prompt,
# so the two cannot drift apart.
_FEST_TABLES = [
    {'Table': 'ADMIN', 'Attributes': ['USERNAME (PRIMARY KEY) (text)', 'PASS (text)']},
    {'Table': 'STUDENT', 'Attributes': ['FEST_ID (PRIMARY KEY) (numeric)', 'NAME (text)', 'ROLL (text)', 'DEPT (text)', 'PASS (text)']},
    {'Table': 'EVENT', 'Attributes': ['EVENT_ID (PRIMARY KEY) (numeric)', 'EVENT_NAME (text)', 'EVENT_DATE (date)', 'EVENT_TIME (time)', 'EVENT_VENUE (text)', 'EVENT_TYPE (text)', 'EVENT_DESCRIPTION (text)', 'EVENT_WINNER (numeric)']},
    {'Table': 'ACCOMODATION', 'Attributes': ['ACC_ID (numeric) (PRIMARY KEY)', 'NAME (text)', 'CAPACITY (numeric)']},
    {'Table': 'EXT_PARTICIPANT', 'Attributes': ['FEST_ID (numeric) (PRIMARY KEY)', 'NAME (text)', 'COLLEGE (text)', 'ACC_ID (numeric)', 'PASS (text)']},
    {'Table': 'ORGANISING', 'Attributes': ['FEST_ID (numeric)', 'EVENT_ID (numeric)', 'PRIMARY KEY (FEST_ID, EVENT_ID)']},
    {'Table': 'VOLUNTEERING', 'Attributes': ['FEST_ID (numeric)', 'EVENT_ID (numeric)', 'PRIMARY KEY (FEST_ID, EVENT_ID)']},
    {'Table': 'PARTICIPATING_EXT', 'Attributes': ['FEST_ID (numeric)', 'EVENT_ID (numeric)', 'PRIMARY KEY (FEST_ID, EVENT_ID)']},
    {'Table': 'PARTICIPATING_INT', 'Attributes': ['FEST_ID (numeric)', 'EVENT_ID (numeric)', 'PRIMARY KEY (FEST_ID, EVENT_ID)']},
]

# Databases the user can pick in the UI. Each entry has:
#   "id"      - value the UI sends back to /update_db_id
#   "db_name" - PostgreSQL database name passed to psycopg2.connect
#   "name"    - human-readable label
#   "schema"  - schema text prepended to the question in the model prompt
#   "df"      - table/attribute list rendered as the HTML schema table
preset_schemas = [
    {
        "id": 0,
        "db_name": "21CS10014",
        "name": "Fest Management System",
        "schema": _schema_prompt(_FEST_TABLES),
        "df": _FEST_TABLES,
    },
]
def connect_to_db(id):
    """Connect to the preset database whose ``"id"`` field equals *id* and select its schema.

    Any connection left open by a previous selection is closed first, so switching
    databases no longer leaks connections. On success, the globals ``conn``,
    ``isDBselected``, ``schema`` and ``schema_struct`` describe the new database.
    On any failure (unknown id, connection error) they are left in the
    "no database selected" state and the error is printed; nothing is raised.

    Args:
        id: Value of the ``"id"`` field of an entry in ``preset_schemas`` (the value
            the UI sends). The parameter name shadows the builtin but is kept so
            existing callers keep working.

    Returns:
        None.
    """
    global isDBselected, schema, conn, schema_struct

    # Release whatever the previous selection opened before opening a new connection.
    if conn is not None:
        try:
            conn.close()
        except Exception as e:
            print(f"Error: Unable to close the connection. {e}")
    conn = None
    isDBselected = False
    schema = ""
    schema_struct = "<p>Please provide your own schema in prompt to get relevant SQL query.</p>"

    # Look up by the "id" field rather than by list position, matching what the UI sends.
    preset = next((p for p in preset_schemas if p["id"] == id), None)
    if preset is None:
        print(f"Error: Unable to connect to the database. Unknown database id {id!r}.")
        return None

    try:
        conn = psycopg2.connect(
            dbname=preset["db_name"],
            user=USER,
            password=PASSWORD,
            host=HOST,
        )
    except psycopg2.Error as e:
        print(f"Error: Unable to connect to the database. {e}")
        return None

    isDBselected = True
    schema = preset["schema"]
    schema_struct = tabulate(preset["df"], headers='keys', tablefmt='html')
    print("Connected to the database. Schema selected.")
    print(f"Database: {preset['db_name']}")
    print(f"Name: {preset['name']}")
    print(f"Schema: {schema[:50]}...")
    return None
def close_connection():
    """Close the current database connection (if any) and deselect its schema.

    The selection state (``conn``, ``isDBselected``, ``schema``, ``schema_struct``) is
    reset even when there is no connection or closing it fails, so the app never
    keeps a schema selected whose connection is gone. Errors are printed, not raised.
    """
    global isDBselected, schema, conn, schema_struct
    try:
        if conn is not None:
            conn.close()
            print("Connection to the database closed. Schema deselected.")
    except Exception as e:
        print(f"Error: Unable to close the connection. {e}")
    finally:
        conn = None
        isDBselected = False
        schema = ""
        # Same placeholder the app shows at start-up.
        schema_struct = "<p>Please provide your own schema in prompt to get relevant SQL query.</p>"
# Flask app serving the user interface: basic options for the user to interact
# with the LLM NL-to-SQL model.
app = Flask(__name__)
# The key that signs Flask session cookies. It is read from SECRET_KEY in the .env file.
# Otherwise a random per-process key is used, so no forgeable constant lives in source control.
app.secret_key = env_vars.get("SECRET_KEY") or secrets.token_hex(32)
# Chat exchanges ({'user', 'bot', 'result'} dicts) shown on the home page; the POST
# handler trims it to the most recent entry.
chat_history = []
# HTML shown in place of the schema table while no preset database is selected.
schema_struct = "<p>Please provide your own schema in prompt to get relevant SQL query.</p>"
def _generate_sql(nl_query):
    """Ask the model for SQL answering *nl_query*, prefixed by the selected schema text.

    Returns:
        ``(text, ok)``: the generated SQL with ``ok=True``, or an error message
        (including the raw model response when there is one) with ``ok=False``.
    """
    prompt = f"{schema} {nl_query}"
    try:
        output = query({"inputs": prompt, "parameters": {"max_new_tokens": 200}, "options": {"wait_for_model": True}})
    except requests.RequestException as e:
        return f"Error: Unable to reach the model API. {e}", False
    try:
        return output[0]['generated_text'], True
    except (KeyError, IndexError, TypeError) as e:
        # e.g. {"error": ...} dict, empty list, or non-list payload from the API
        return f"Error: Unable to generate SQL query. {e}\n\nResponse: {output}", False


def _run_sql(sql):
    """Execute *sql* on the open connection ``conn`` and return the result as an HTML table.

    Errors are returned as an HTML paragraph, with the message HTML-escaped, instead
    of being raised. The transaction is always rolled back afterwards. psycopg2
    starts a transaction implicitly and nothing here commits. Without the rollback
    the connection would sit idle in transaction, and any changes made by generated
    SQL would stay visible to later queries without ever being committed.
    """
    try:
        df = pd.read_sql_query(sql, conn)
        return tabulate(df, headers='keys', tablefmt='html')
    except Exception as e:
        return f"<p>Error: Unable to execute the query in database. {html.escape(str(e))}</p>"
    finally:
        try:
            conn.rollback()
        except Exception as e:
            print(f"Error: Unable to roll back the transaction. {e}")


@app.route('/', methods=['GET', 'POST'])
def home():
    """Render the chat page, or handle a submitted natural-language question.

    GET renders ``home.html`` with the chat history, the preset database list, the
    selected ``db_id`` and the schema table HTML.

    POST reads ``nl_query`` from the form and asks the model for SQL. If a database
    is selected and SQL was generated, it runs the SQL and stores the exchange as
    the latest chat entry. It then redirects to a plain GET of this page
    (Post/Redirect/Get). All page state is global, so the redirect carries no
    parameters.
    """
    if request.method != 'POST':
        return render_template(
            'home.html',
            chat_history=chat_history,
            databases=[{"id": db["id"], "name": db["name"]} for db in preset_schemas],
            db_id=db_id,
            schema=schema_struct,
        )

    nl_query = request.form['nl_query']
    sql, generated = _generate_sql(nl_query)
    # Only run real SQL: executing the model's error text would just produce a DB error.
    result = _run_sql(sql) if (generated and isDBselected) else "<p></p>"

    chat_history.append({'user': nl_query, 'bot': sql, 'result': result})
    # Keep only the most recent exchange.
    del chat_history[:-1]
    return redirect(url_for('home'))
@app.route('/update_db_id', methods=['POST'])
def update_db_id():
    """Select a preset database, or deselect with id ``-1``, for running generated SQL.

    Expects a JSON body ``{"id": <int or numeric string>}``. A missing or invalid
    body gets a 400 response instead of an unhandled 500. If connecting fails,
    ``db_id`` is reset to -1 so the page does not show a database as selected
    when it is not.

    Returns:
        A redirect to the home page, or ``(message, 400)`` for a malformed request.
    """
    global db_id
    data = request.get_json(silent=True)
    try:
        new_db_id = int(data['id'])
    except (TypeError, KeyError, ValueError):
        return "Error: expected a JSON body with an integer 'id'.", 400

    db_id = new_db_id
    if db_id != -1:
        connect_to_db(db_id)
        # Keep the UI selection in sync with the actual connection state.
        if not isDBselected:
            db_id = -1
    else:
        close_connection()
    return redirect(url_for('home'))
if __name__ == '__main__':
    # Development server. Debug mode enables the Werkzeug interactive debugger,
    # which allows arbitrary code execution from the browser, so never expose it
    # publicly. It stays on by default (as before); set FLASK_DEBUG=0 in the .env
    # file to turn it off.
    _debug_setting = (env_vars.get("FLASK_DEBUG") or "1").strip().lower()
    app.run(debug=_debug_setting not in ("0", "false", "no", "off"))