Skip to content

Commit c538dea

Browse files
chiral-carbona-r-r-o-w
authored andcommitted
[Model Card] standardize advanced diffusion training sd15 lora (#7613)
* modelcard generation edit * add missed tag * fix param name * fix var * change str to dict * add use_dora check * use correct tags for lora * make style && make quality --------- Co-authored-by: Aryan <aryan@huggingface.co>
1 parent 1ef46d9 commit c538dea

File tree

1 file changed

+38
-37
lines changed

1 file changed

+38
-37
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_sd15_advanced.py

+38-37
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
convert_state_dict_to_kohya,
6868
is_wandb_available,
6969
)
70+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
7071
from diffusers.utils.import_utils import is_xformers_available
7172

7273

@@ -79,30 +80,27 @@
7980
def save_model_card(
8081
repo_id: str,
8182
use_dora: bool,
82-
images=None,
83-
base_model=str,
83+
images: list = None,
84+
base_model: str = None,
8485
train_text_encoder=False,
8586
train_text_encoder_ti=False,
8687
token_abstraction_dict=None,
87-
instance_prompt=str,
88-
validation_prompt=str,
88+
instance_prompt=None,
89+
validation_prompt=None,
8990
repo_folder=None,
9091
vae_path=None,
9192
):
92-
img_str = "widget:\n"
9393
lora = "lora" if not use_dora else "dora"
94-
for i, image in enumerate(images):
95-
image.save(os.path.join(repo_folder, f"image_{i}.png"))
96-
img_str += f"""
97-
- text: '{validation_prompt if validation_prompt else ' ' }'
98-
output:
99-
url:
100-
"image_{i}.png"
101-
"""
102-
if not images:
103-
img_str += f"""
104-
- text: '{instance_prompt}'
105-
"""
94+
95+
widget_dict = []
96+
if images is not None:
97+
for i, image in enumerate(images):
98+
image.save(os.path.join(repo_folder, f"image_{i}.png"))
99+
widget_dict.append(
100+
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
101+
)
102+
else:
103+
widget_dict.append({"text": instance_prompt})
106104
embeddings_filename = f"{repo_folder}_emb"
107105
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
108106
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
@@ -137,24 +135,7 @@ def save_model_card(
137135
trigger_str += f"""
138136
to trigger concept `{key}` → use `{tokens}` in your prompt \n
139137
"""
140-
141-
yaml = f"""---
142-
tags:
143-
- stable-diffusion
144-
- stable-diffusion-diffusers
145-
- diffusers-training
146-
- text-to-image
147-
- diffusers
148-
- {lora}
149-
- template:sd-lora
150-
{img_str}
151-
base_model: {base_model}
152-
instance_prompt: {instance_prompt}
153-
license: openrail++
154-
---
155-
"""
156-
157-
model_card = f"""
138+
model_description = f"""
158139
# SD1.5 LoRA DreamBooth - {repo_id}
159140
160141
<Gallery />
@@ -202,8 +183,28 @@ def save_model_card(
202183
Special VAE used for training: {vae_path}.
203184
204185
"""
205-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
206-
f.write(yaml + model_card)
186+
model_card = load_or_create_model_card(
187+
repo_id_or_path=repo_id,
188+
from_training=True,
189+
license="openrail++",
190+
base_model=base_model,
191+
prompt=instance_prompt,
192+
model_description=model_description,
193+
inference=True,
194+
widget=widget_dict,
195+
)
196+
197+
tags = [
198+
"text-to-image",
199+
"diffusers",
200+
"diffusers-training",
201+
lora,
202+
"template:sd-lora" "stable-diffusion",
203+
"stable-diffusion-diffusers",
204+
]
205+
model_card = populate_model_card(model_card, tags=tags)
206+
207+
model_card.save(os.path.join(repo_folder, "README.md"))
207208

208209

209210
def import_model_class_from_model_name_or_path(

0 commit comments

Comments
 (0)