Skip to content

Commit

Permalink
Merge pull request #404 from NexaAI/perry/server-dev
Browse files Browse the repository at this point in the history
further fixed some problems related to flux support and the output pa…
  • Loading branch information
zhycheng614 authored Feb 25, 2025
2 parents 0896454 + 3e42ebd commit 31abc0e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
29 changes: 28 additions & 1 deletion nexa/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,7 @@ def list_models():

def remove_model(model_path):
model_path = NEXA_RUN_MODEL_MAP.get(model_path, model_path)
model_name = model_path.split(":")[0] if ":" in model_path else model_path

if not NEXA_MODEL_LIST_PATH.exists():
print("No models found.")
Expand Down Expand Up @@ -841,7 +842,13 @@ def remove_model(model_path):
for part in model_path_parts:
match = re.search(r'q\d_|fp16', part)
if match:
tag_name = part[match.start():]
# Get the substring from match start
tag_substring = part[match.start():]
# Remove .gguf if present at the end
if tag_substring.endswith('.gguf'):
tag_name = tag_substring[:-5] # remove '.gguf'
else:
tag_name = tag_substring
break
else:
raise ValueError(
Expand All @@ -857,6 +864,16 @@ def remove_model(model_path):
shutil.rmtree(item)
print(f"Deleted flux-related file: {item}")

# Remove the t5xxl entry from model_list
t5xxl_key = f"{model_name}:t5xxl-{tag_name}"
print(f"Removing t5xxl entry: {t5xxl_key}")
removed_t5xxl = model_list.pop(t5xxl_key, None)
if removed_t5xxl:
print(f"Removed from model_list: {t5xxl_key}")
else:
raise ValueError(
"Failed to remove the t5xxl entry from model_list.")

# Check remaining files: ae- and clip_l- files
remaining_files = list(parent_dir.glob("*"))
if len(remaining_files) == 2:
Expand All @@ -872,6 +889,16 @@ def remove_model(model_path):
else:
shutil.rmtree(item)
print(f"Deleted additional file: {item}")

# Remove corresponding entries from model_list
prefix = item.name.split('-')[0].lower()
key_to_remove = f"{model_name}:{prefix}-fp16"
removed_item = model_list.pop(key_to_remove, None)
if removed_item:
print(f"Removed from model_list: {key_to_remove}")
else:
raise ValueError(
f"Failed to remove the {prefix} entry from model_list.")

# Update the model list file
with open(NEXA_MODEL_LIST_PATH, "w") as f:
Expand Down
26 changes: 17 additions & 9 deletions nexa/gguf/server/nexa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ class TextToSpeechRequest(BaseModel):
seed: int = 42
sampling_rate: int = 24000
language: Optional[str] = "en" # Only for 'outetts'
output_dir: Optional[str] = "nexa_server_output"


class FunctionCallRequest(BaseModel):
Expand Down Expand Up @@ -1541,6 +1542,7 @@ async def txt2img(
height: Optional[int] = Form(512),
sample_steps: Optional[int] = Form(20, description="set to 4 when using Flux for optimal results"),
seed: Optional[int] = Form(42),
output_dir: Optional[str] = Form("nexa_server_output", description="Directory to save generated images"),
):
try:
if model_type != "Computer Vision":
Expand All @@ -1554,12 +1556,13 @@ async def txt2img(

resp = {"created": time.time(), "data": []}

# Create output directory if it doesn't exist
if not os.path.exists(output_dir):
os.makedirs(output_dir)

for image in generated_images:
id = int(time.time())
if not os.path.exists("nexa_server_output"):
os.makedirs("nexa_server_output")
image_path = os.path.join(
"nexa_server_output", f"txt2img_{id}.png")
image_path = os.path.join(output_dir, f"txt2img_{id}.png")
image.save(image_path)
img = ImageResponse(base64=base64_encode_image(
image_path), url=os.path.abspath(image_path))
Expand All @@ -1583,6 +1586,7 @@ async def img2img(
height: Optional[int] = Form(512),
sample_steps: Optional[int] = Form(20, description="set to 4 when using Flux for optimal results"),
seed: Optional[int] = Form(42),
output_dir: Optional[str] = Form("nexa_server_output", description="Directory to save generated images"),
):
try:
if model_type != "Computer Vision":
Expand All @@ -1603,12 +1607,13 @@ async def img2img(

resp = {"created": time.time(), "data": []}

# Create output directory if it doesn't exist
if not os.path.exists(output_dir):
os.makedirs(output_dir)

for image in generated_images:
id = int(time.time())
if not os.path.exists("nexa_server_output"):
os.makedirs("nexa_server_output")
image_path = os.path.join(
"nexa_server_output", f"img2img_{id}.png")
image_path = os.path.join(output_dir, f"img2img_{id}.png")
image.save(image_path)
img = ImageResponse(base64=base64_encode_image(
image_path), url=os.path.abspath(image_path))
Expand Down Expand Up @@ -1650,8 +1655,11 @@ async def txt2speech(request: TextToSpeechRequest):
)

audio_data = model.audio_generation(request.text)
output_dir = "nexa_server_output"

# Create output directory if it doesn't exist
output_dir = request.output_dir if hasattr(request, 'output_dir') else "nexa_server_output"
os.makedirs(output_dir, exist_ok=True)

file_path = model._save_audio(
audio_data, request.sampling_rate, output_dir)

Expand Down

0 comments on commit 31abc0e

Please # to comment.