Skip to content

Commit 0896454

Browse files
authored
Merge pull request #403 from NexaAI/perry/server-dev
Perry/server dev
2 parents 2c759da + 8d6f41f commit 0896454

File tree

3 files changed

+58
-6
lines changed

3 files changed

+58
-6
lines changed

nexa/constants.py

+12
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,15 @@ class ModelType(Enum):
244244
NEXA_RUN_T5XXL_MAP = {
245245
"flux": "FLUX.1-schnell:t5xxl-q4_0",
246246
"FLUX.1-schnell:q4_0": "FLUX.1-schnell:t5xxl-q4_0",
247+
"FLUX.1-schnell:flux1-schnell-q4_0": "FLUX.1-schnell:t5xxl-q4_0",
247248
"FLUX.1-schnell:q5_0": "FLUX.1-schnell:t5xxl-q5_0",
249+
"FLUX.1-schnell:flux1-schnell-q5_0": "FLUX.1-schnell:t5xxl-q5_0",
248250
"FLUX.1-schnell:q5_1": "FLUX.1-schnell:t5xxl-q5_1",
251+
"FLUX.1-schnell:flux1-schnell-q5_1": "FLUX.1-schnell:t5xxl-q5_1",
249252
"FLUX.1-schnell:q8_0": "FLUX.1-schnell:t5xxl-q8_0",
253+
"FLUX.1-schnell:flux1-schnell-q8_0": "FLUX.1-schnell:t5xxl-q8_0",
250254
"FLUX.1-schnell:fp16": "FLUX.1-schnell:t5xxl-fp16",
255+
"FLUX.1-schnell:flux1-schnell-fp16": "FLUX.1-schnell:t5xxl-fp16",
251256
}
252257

253258
NEXA_RUN_MODEL_MAP_IMAGE = {
@@ -546,3 +551,10 @@ class ModelType(Enum):
546551
"all-MiniLM-L6-v2": ModelType.TEXT_EMBEDDING,
547552
"all-MiniLM-L12-v2": ModelType.TEXT_EMBEDDING,
548553
}
554+
555+
NEXA_LIST_FILTERED_MODEL_PREFIXES = [
556+
'projector',
557+
't5xxl-',
558+
'ae-',
559+
'clip_l-'
560+
]

nexa/general.py

+45-6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import concurrent.futures
88
import time
99
import os
10+
import re
1011
from tqdm import tqdm
1112
import platform
1213
import tempfile
@@ -23,6 +24,7 @@
2324
NEXA_RUN_MODEL_MAP,
2425
NEXA_TOKEN_PATH,
2526
NEXA_OFFICIAL_MODELS_TYPE,
27+
NEXA_LIST_FILTERED_MODEL_PREFIXES
2628
)
2729
from nexa.constants import ModelType
2830

@@ -679,17 +681,13 @@ def add_model_to_list(model_name, model_location, model_type, run_type):
679681
if tag_name.startswith("model-"):
680682
tag_name = tag_name[6:]
681683
model_name = f"{model_name.split(':')[0]}:{tag_name}"
682-
else:
683-
return
684684

685685
# For Computer Vision Flux model, should remove the "flux1-schnell-" prefix from the tag name
686686
if run_type == "Computer Vision":
687687
tag_name = model_name.split(":")[1]
688688
if tag_name.startswith("flux1-schnell-"):
689689
tag_name = tag_name[14:]
690690
model_name = f"{model_name.split(':')[0]}:{tag_name}"
691-
else:
692-
return
693691

694692
model_list[model_name] = {
695693
"type": model_type,
@@ -737,7 +735,8 @@ def list_models():
737735
filtered_list = {
738736
model_name: model_info
739737
for model_name, model_info in model_list.items()
740-
if ':' not in model_name or not model_name.split(':')[1].startswith('projector')
738+
if ':' not in model_name or
739+
not any(model_name.split(':')[1].startswith(prefix) for prefix in NEXA_LIST_FILTERED_MODEL_PREFIXES)
741740
}
742741

743742
table = [
@@ -812,7 +811,7 @@ def remove_model(model_path):
812811
else:
813812
print(f"Warning: Model location not found: {model_path}")
814813

815-
# Delete projectors only if model was successfully deleted
814+
# Delete projectors or flux related files only if model was successfully deleted
816815
if model_deleted:
817816
parent_dir = model_path.parent
818817
gguf_files = list(parent_dir.glob("*.gguf"))
@@ -834,6 +833,46 @@ def remove_model(model_path):
834833
shutil.rmtree(projector_location)
835834
print(f"Deleted projector: {projector_location}")
836835

836+
# Check if the model path contains "flux"
837+
if 'flux' in str(model_path).lower():
838+
model_path_parts = str(model_path).split(":")
839+
tag_name = None
840+
841+
for part in model_path_parts:
842+
match = re.search(r'q\d_|fp16', part)
843+
if match:
844+
tag_name = part[match.start():]
845+
break
846+
else:
847+
raise ValueError(
848+
"Invalid model path. Expected a tag name in the model path.")
849+
850+
if tag_name:
851+
# First delete files matching tag_name
852+
for item in parent_dir.glob(f"*{tag_name}*"):
853+
if item.exists():
854+
if item.is_file():
855+
item.unlink()
856+
else:
857+
shutil.rmtree(item)
858+
print(f"Deleted flux-related file: {item}")
859+
860+
# Check remaining files: ae- and clip_l- files
861+
remaining_files = list(parent_dir.glob("*"))
862+
if len(remaining_files) == 2:
863+
file_names = [f.name.lower() for f in remaining_files]
864+
has_ae = any(name.startswith("ae-") for name in file_names)
865+
has_clip = any(name.startswith("clip_l-") for name in file_names)
866+
867+
if has_ae and has_clip:
868+
for item in remaining_files:
869+
if item.exists():
870+
if item.is_file():
871+
item.unlink()
872+
else:
873+
shutil.rmtree(item)
874+
print(f"Deleted additional file: {item}")
875+
837876
# Update the model list file
838877
with open(NEXA_MODEL_LIST_PATH, "w") as f:
839878
json.dump(model_list, f, indent=2)

nexa/gguf/nexa_inference_image.py

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def __init__(self, model_path: str = None, local_path: str = None, **kwargs):
9999
if self.clip_l_path:
100100
self.clip_l_downloaded_path, _ = pull_model(
101101
self.clip_l_path, **kwargs)
102+
102103
if "lcm-dreamshaper" in self.model_path:
103104
# print('Loading lcm default arguments')
104105
self.params = DEFAULT_IMG_GEN_PARAMS_LCM.copy()

0 commit comments

Comments
 (0)