|
67 | 67 | convert_state_dict_to_kohya,
|
68 | 68 | is_wandb_available,
|
69 | 69 | )
|
| 70 | +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card |
70 | 71 | from diffusers.utils.import_utils import is_xformers_available
|
71 | 72 |
|
72 | 73 |
|
|
79 | 80 | def save_model_card(
|
80 | 81 | repo_id: str,
|
81 | 82 | use_dora: bool,
|
82 |
| - images=None, |
83 |
| - base_model=str, |
| 83 | + images: list = None, |
| 84 | + base_model: str = None, |
84 | 85 | train_text_encoder=False,
|
85 | 86 | train_text_encoder_ti=False,
|
86 | 87 | token_abstraction_dict=None,
|
87 |
| - instance_prompt=str, |
88 |
| - validation_prompt=str, |
| 88 | + instance_prompt=None, |
| 89 | + validation_prompt=None, |
89 | 90 | repo_folder=None,
|
90 | 91 | vae_path=None,
|
91 | 92 | ):
|
92 |
| - img_str = "widget:\n" |
93 | 93 | lora = "lora" if not use_dora else "dora"
|
94 |
| - for i, image in enumerate(images): |
95 |
| - image.save(os.path.join(repo_folder, f"image_{i}.png")) |
96 |
| - img_str += f""" |
97 |
| - - text: '{validation_prompt if validation_prompt else ' ' }' |
98 |
| - output: |
99 |
| - url: |
100 |
| - "image_{i}.png" |
101 |
| - """ |
102 |
| - if not images: |
103 |
| - img_str += f""" |
104 |
| - - text: '{instance_prompt}' |
105 |
| - """ |
| 94 | + |
| 95 | + widget_dict = [] |
| 96 | + if images is not None: |
| 97 | + for i, image in enumerate(images): |
| 98 | + image.save(os.path.join(repo_folder, f"image_{i}.png")) |
| 99 | + widget_dict.append( |
| 100 | + {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}} |
| 101 | + ) |
| 102 | + else: |
| 103 | + widget_dict.append({"text": instance_prompt}) |
106 | 104 | embeddings_filename = f"{repo_folder}_emb"
|
107 | 105 | instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
|
108 | 106 | ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
|
@@ -137,24 +135,7 @@ def save_model_card(
|
137 | 135 | trigger_str += f"""
|
138 | 136 | to trigger concept `{key}` → use `{tokens}` in your prompt \n
|
139 | 137 | """
|
140 |
| - |
141 |
| - yaml = f"""--- |
142 |
| -tags: |
143 |
| -- stable-diffusion |
144 |
| -- stable-diffusion-diffusers |
145 |
| -- diffusers-training |
146 |
| -- text-to-image |
147 |
| -- diffusers |
148 |
| -- {lora} |
149 |
| -- template:sd-lora |
150 |
| -{img_str} |
151 |
| -base_model: {base_model} |
152 |
| -instance_prompt: {instance_prompt} |
153 |
| -license: openrail++ |
154 |
| ---- |
155 |
| -""" |
156 |
| - |
157 |
| - model_card = f""" |
| 138 | + model_description = f""" |
158 | 139 | # SD1.5 LoRA DreamBooth - {repo_id}
|
159 | 140 |
|
160 | 141 | <Gallery />
|
@@ -202,8 +183,28 @@ def save_model_card(
|
202 | 183 | Special VAE used for training: {vae_path}.
|
203 | 184 |
|
204 | 185 | """
|
205 |
| - with open(os.path.join(repo_folder, "README.md"), "w") as f: |
206 |
| - f.write(yaml + model_card) |
| 186 | + model_card = load_or_create_model_card( |
| 187 | + repo_id_or_path=repo_id, |
| 188 | + from_training=True, |
| 189 | + license="openrail++", |
| 190 | + base_model=base_model, |
| 191 | + prompt=instance_prompt, |
| 192 | + model_description=model_description, |
| 193 | + inference=True, |
| 194 | + widget=widget_dict, |
| 195 | + ) |
| 196 | + |
| 197 | + tags = [ |
| 198 | + "text-to-image", |
| 199 | + "diffusers", |
| 200 | + "diffusers-training", |
| 201 | + lora, |
| 202 | + "template:sd-lora" "stable-diffusion", |
| 203 | + "stable-diffusion-diffusers", |
| 204 | + ] |
| 205 | + model_card = populate_model_card(model_card, tags=tags) |
| 206 | + |
| 207 | + model_card.save(os.path.join(repo_folder, "README.md")) |
207 | 208 |
|
208 | 209 |
|
209 | 210 | def import_model_class_from_model_name_or_path(
|
|
0 commit comments