-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
250 lines (205 loc) · 7.56 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import base64
import io
import json
import time
from urllib.parse import quote

import boto3
from botocore.exceptions import ClientError
from PIL import Image

import streamlit as st
# Initialize the Bedrock runtime client (shared by image generation and LLM calls)
client = boto3.client(service_name='bedrock-runtime')

# DynamoDB table caching prompt -> S3 image-key mappings.
__TableName__ = "promptCache"
# NOTE: the original also called boto3.client('dynamodb') and discarded the
# result — that dead call is removed here.
dynamoClient = boto3.client(service_name="dynamodb", region_name="us-west-2")
db = boto3.resource('dynamodb')
promptTable = db.Table(__TableName__)

# Initialize the S3 client and the bucket holding generated images.
s3_client = boto3.client('s3')
bucket_name = 'cichackathon2024'
def generate_image(prompt):
    """Generate a 512x512 image from a text prompt via Amazon Titan on Bedrock.

    Args:
        prompt: Text description of the image to create.

    Returns:
        Raw PNG bytes decoded from the model's base64-encoded response.

    Raises:
        botocore.exceptions.ClientError: If the Bedrock invocation fails.
    """
    MODEL_ID = "amazon.titan-image-generator-v2:0"
    # Define the request payload.
    payload = json.dumps({
        "taskType": "TEXT_IMAGE",
        "textToImageParams": {
            "text": prompt
        },
        "imageGenerationConfig": {
            "numberOfImages": 1,
            "height": 512,
            "width": 512,
            "cfgScale": 8.0,
            # Fixed seed keeps output deterministic for identical prompts.
            "seed": 0
        }
    })
    # Reuse the module-level Bedrock client; the original redundantly
    # constructed a fresh client on every call.
    response = client.invoke_model(
        body=payload,
        modelId=MODEL_ID,
        contentType='application/json',
        accept="application/json"
    )
    # The model returns a list of base64-encoded images; take the first.
    result = json.loads(response.get("body").read())
    base64_image = result.get("images")[0]
    return base64.b64decode(base64_image.encode('ascii'))
def upload_image_to_s3(image_bytes, prompt):
    """Upload PNG bytes to S3 under a key derived from the prompt.

    The prompt is percent-encoded so characters such as '/' cannot create
    unintended key hierarchies or collide with unrelated objects (the
    original interpolated the raw user prompt straight into the key).

    Args:
        image_bytes: Raw PNG data to store.
        prompt: Prompt text used to build the object key.

    Returns:
        The S3 object key the image was stored under.
    """
    s3_key = f'images/{quote(prompt, safe="")}.png'  # unique key per prompt
    s3_client.put_object(Bucket=bucket_name, Key=s3_key,
                         Body=image_bytes, ContentType='image/png')
    return s3_key
def generate_image_and_store(prompt):
    """Generate an image for *prompt*, persist it to S3, and cache the
    prompt -> S3-key mapping in DynamoDB.

    Returns the S3 key so the caller can fetch and display the image.
    """
    png_bytes = generate_image(prompt)
    stored_key = upload_image_to_s3(png_bytes, prompt)
    # Record the mapping so future equivalent prompts reuse this image.
    promptTable.put_item(
        Item={'prompt': prompt, 'image_url': stored_key}
    )
    return stored_key
def getOrGenerate(prompt):
    """Return the S3 key for *prompt*, reusing a cached image when an LLM
    judges an existing cached prompt semantically equivalent.

    Falls back to generating (and caching) a new image when no match is
    found — or when the LLM names a prompt that is not actually in the
    table (the original raised KeyError on a missing "Item").
    """
    existingPrompt = findExistingPrompt(prompt).strip()
    if existingPrompt != "None":
        entry = promptTable.get_item(Key={'prompt': existingPrompt})
        # The LLM may hallucinate a prompt absent from the table; guard
        # against a missing Item instead of crashing.
        item = entry.get("Item")
        if item is not None:
            return item["image_url"]
    return generate_image_and_store(prompt)
def extractPrompts(item):
    """Return the 'prompt' attribute of a DynamoDB item dict."""
    prompt_text = item["prompt"]
    return prompt_text
def findExistingPrompt(prompt):
    """Ask the LLM whether any cached prompt is semantically equivalent
    to *prompt*.

    Concatenates every cached prompt into one '|'-separated string and
    asks the LLM to return the matching prompt verbatim, or 'None'.

    Returns:
        The LLM's raw answer string.
    """
    # A single scan() returns at most 1 MB of items; follow
    # LastEvaluatedKey so a large cache is fully considered
    # (the original read only the first page).
    items = []
    scan_kwargs = {}
    while True:
        page = promptTable.scan(**scan_kwargs)
        items.extend(page['Items'])
        last_key = page.get('LastEvaluatedKey')
        if not last_key:
            break
        scan_kwargs['ExclusiveStartKey'] = last_key
    # Same string shape as before: a leading " | " before every prompt,
    # built with join instead of quadratic += concatenation.
    cached_prompts = "".join(" | " + item["prompt"] for item in items)
    cached_prompts = cached_prompts + " | Is there anything in this list of prompts (separated by the character '|') that is sufficiently and semantically the same to the prompt '" + prompt + "'? If there is, return it exactly without saying anything else. If not, if there's even a bit of a difference, return 'None'."
    return contactLLM(cached_prompts)
def contactLLM(prompt):
    """Send *prompt* to Llama 3 70B Instruct on Bedrock and return the
    generated text, accumulated from the streaming response.

    Raises:
        RuntimeError: If the model invocation fails. The original called
            exit(1) here, which would terminate the entire Streamlit
            server process on any transient Bedrock error.
    """
    # Create a Bedrock Runtime client in the region hosting the model.
    bedrock = boto3.client("bedrock-runtime", region_name="us-west-2")
    model_id = "meta.llama3-70b-instruct-v1:0"
    # Embed the prompt in Llama 3's instruction format.
    formatted_prompt = f"""
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{prompt}
<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
"""
    # Request payload in the model's native structure.
    request = json.dumps({
        "prompt": formatted_prompt,
        "max_gen_len": 512,
        "temperature": 0.5,
    })
    out = ""
    try:
        streaming_response = bedrock.invoke_model_with_response_stream(
            modelId=model_id, body=request
        )
        # Accumulate the generation text as chunks stream in.
        for event in streaming_response["body"]:
            chunk = json.loads(event["chunk"]["bytes"])
            if "generation" in chunk:
                out += chunk["generation"]
    except ClientError as e:
        # Surface the failure to the caller instead of killing the process.
        raise RuntimeError(f"ERROR: Can't invoke '{model_id}'. Reason: {e}") from e
    return out
#---------------------------------------------------------------------
# STREAMLIT APP
col1, col2, col3 = st.columns([1, 9, 1])

# Trim Streamlit's default vertical padding around the main container.
REMOVE_PADDING_FROM_SIDES = """
<style>
.block-container {
    padding-top: 3rem;
    padding-bottom: 0rem;
}
</style>
"""
st.markdown(REMOVE_PADDING_FROM_SIDES, unsafe_allow_html=True)

# Hide the "View fullscreen" button Streamlit overlays on images.
hide_img_fs = '''
<style>
button[title="View fullscreen"]{
    visibility: hidden;}
</style>
'''

# NOTE(review): defined but never rendered in the original; kept for parity.
picture_rounded = """
<style>
.container1 {
    border-radius: 8px;
}
.container2 {
    /* Add styles for Container 2 if needed */
}
</style>
"""
# The original injected this stylesheet twice; once is sufficient.
st.markdown(hide_img_fs, unsafe_allow_html=True)

with col2:
    st.image("./media/logodark.svg")

with col2:
    with st.container():
        col2_1, col2_2 = st.columns([3, 1])
        # Center all content inside the column.
        st.markdown("""
        <style>
        div {
            text-align:center;
            align-items: center;
            justify-content: center;
        }
        </style>""", unsafe_allow_html=True)
        # Fix the generate button's width.
        st.markdown(
            """
            <style>
            div.stButton > button {
                width: 150px;
            }
            </style>
            """,
            unsafe_allow_html=True,
        )
        with col2_1:
            prompt = st.text_input(label="", placeholder="Enter your text prompt")
        with col2_2:
            st.markdown("<div style='margin-top: 27px;'></div>", unsafe_allow_html=True)
            gen_button_clicked = st.button("Generate Image")

# Use logical `and` — the original used bitwise `&`, which only worked by
# accident of bool being an int subclass.
if gen_button_clicked and prompt != "":
    with col2:
        with st.spinner("Generating image..."):
            start_time = time.time()
            image_key = getOrGenerate(prompt)
            # Fetch the (new or cached) image from S3.
            image_response = s3_client.get_object(Bucket=bucket_name, Key=image_key)
            image_bytes = image_response['Body'].read()
            # Embed the PNG as base64 in HTML so it renders with rounded corners.
            image_data = base64.b64encode(image_bytes).decode('utf-8')
            image_html = f'<img src="data:image/png;base64,{image_data}" style="border-radius: 10px; width: 557px;"/>'
            st.markdown(image_html, unsafe_allow_html=True)
            end_time = time.time()
            elapsed_time = end_time - start_time
            st.write(f"This image took {elapsed_time:.2f} seconds to generate!")