Skip to content

Commit 7fada49

Browse files
DN6 yiyixuxu
authored and committed
Expand Single File support in SD3 Pipeline (#8517)
* update * update
1 parent 46418bd commit 7fada49

File tree

4 files changed

+66
-17
lines changed

4 files changed

+66
-17
lines changed

docs/source/en/api/loaders/single_file.md

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
3535
- [`StableDiffusionXLInstructPix2PixPipeline`]
3636
- [`StableDiffusionXLControlNetPipeline`]
3737
- [`StableDiffusionXLKDiffusionPipeline`]
38+
- [`StableDiffusion3Pipeline`]
3839
- [`LatentConsistencyModelPipeline`]
3940
- [`LatentConsistencyModelImg2ImgPipeline`]
4041
- [`StableDiffusionControlNetXSPipeline`]
@@ -49,6 +50,7 @@ The [`~loaders.FromSingleFileMixin.from_single_file`] method allows you to load:
4950
- [`StableCascadeUNet`]
5051
- [`AutoencoderKL`]
5152
- [`ControlNetModel`]
53+
- [`SD3Transformer2DModel`]
5254

5355
## FromSingleFileMixin
5456

docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_3.md

+29-8
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ The abstract from the paper is:
2121

2222
## Usage Example
2323

24-
_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
24+
_As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to login so that your system knows you’ve accepted the gate._
2525

26-
Use the command below to log in:
26+
Use the command below to log in:
2727

2828
```bash
2929
huggingface-cli login
@@ -211,17 +211,38 @@ model = SD3Transformer2DModel.from_single_file("https://huggingface.co/stability
211211

212212
## Loading the single checkpoint for the `StableDiffusion3Pipeline`
213213

214+
### Loading the single file checkpoint without T5
215+
214216
```python
217+
import torch
215218
from diffusers import StableDiffusion3Pipeline
216-
from transformers import T5EncoderModel
217219

218-
text_encoder_3 = T5EncoderModel.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", subfolder="text_encoder_3", torch_dtype=torch.float16)
219-
pipe = StableDiffusion3Pipeline.from_single_file("https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16, text_encoder_3=text_encoder_3)
220+
pipe = StableDiffusion3Pipeline.from_single_file(
221+
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips.safetensors",
222+
torch_dtype=torch.float16,
223+
text_encoder_3=None
224+
)
225+
pipe.enable_model_cpu_offload()
226+
227+
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
228+
image.save('sd3-single-file.png')
220229
```
221230

222-
<Tip>
223-
`from_single_file` support for the `fp8` version of the checkpoints is coming soon. Watch this space.
224-
</Tip>
231+
### Loading the single file checkpoint with T5
232+
233+
```python
234+
import torch
235+
from diffusers import StableDiffusion3Pipeline
236+
237+
pipe = StableDiffusion3Pipeline.from_single_file(
238+
"https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/sd3_medium_incl_clips_t5xxlfp8.safetensors",
239+
torch_dtype=torch.float16,
240+
)
241+
pipe.enable_model_cpu_offload()
242+
243+
image = pipe("a picture of a cat holding a sign that says hello world").images[0]
244+
image.save('sd3-single-file-t5-fp8.png')
245+
```
225246

226247
## StableDiffusion3Pipeline
227248

src/diffusers/loaders/single_file.py

+12
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828
_legacy_load_safety_checker,
2929
_legacy_load_scheduler,
3030
create_diffusers_clip_model_from_ldm,
31+
create_diffusers_t5_model_from_checkpoint,
3132
fetch_diffusers_config,
3233
fetch_original_config,
3334
is_clip_model_in_single_file,
35+
is_t5_in_single_file,
3436
load_single_file_checkpoint,
3537
)
3638

@@ -118,6 +120,16 @@ def load_single_file_sub_model(
118120
is_legacy_loading=is_legacy_loading,
119121
)
120122

123+
elif is_transformers_model and is_t5_in_single_file(checkpoint):
124+
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
125+
class_obj,
126+
checkpoint=checkpoint,
127+
config=cached_model_config_path,
128+
subfolder=name,
129+
torch_dtype=torch_dtype,
130+
local_files_only=local_files_only,
131+
)
132+
121133
elif is_tokenizer and is_legacy_loading:
122134
loaded_sub_model = _legacy_load_clip_tokenizer(
123135
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only

src/diffusers/loaders/single_file_utils.py

+23-9
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,6 @@
252252
LDM_CLIP_PREFIX_TO_REMOVE = [
253253
"cond_stage_model.transformer.",
254254
"conditioner.embedders.0.transformer.",
255-
"text_encoders.clip_l.transformer.",
256255
]
257256
OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
258257
LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
399398

400399

401400
def is_open_clip_sd3_model(checkpoint):
402-
is_open_clip_sdxl_refiner_model(checkpoint)
401+
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
402+
return True
403+
404+
return False
403405

404406

405407
def is_open_clip_sdxl_refiner_model(checkpoint):
406-
if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
408+
if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
407409
return True
408410

409411
return False
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
12331235
return new_checkpoint
12341236

12351237

1236-
def convert_ldm_clip_checkpoint(checkpoint):
1238+
def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
12371239
keys = list(checkpoint.keys())
12381240
text_model_dict = {}
12391241

1240-
remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
1242+
remove_prefixes = []
1243+
remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
1244+
if remove_prefix:
1245+
remove_prefixes.append(remove_prefix)
12411246

12421247
for key in keys:
12431248
for prefix in remove_prefixes:
@@ -1376,6 +1381,13 @@ def create_diffusers_clip_model_from_ldm(
13761381
):
13771382
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
13781383

1384+
elif (
1385+
is_clip_sd3_model(checkpoint)
1386+
and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
1387+
):
1388+
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
1389+
diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
1390+
13791391
elif is_open_clip_model(checkpoint):
13801392
prefix = "cond_stage_model.model."
13811393
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1391,9 +1403,11 @@ def create_diffusers_clip_model_from_ldm(
13911403
prefix = "conditioner.embedders.0.model."
13921404
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
13931405

1394-
elif is_open_clip_sd3_model(checkpoint):
1395-
prefix = "text_encoders.clip_g.transformer."
1396-
diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
1406+
elif (
1407+
is_open_clip_sd3_model(checkpoint)
1408+
and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
1409+
):
1410+
diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
13971411

13981412
else:
13991413
raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
@@ -1755,7 +1769,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
17551769
keys = list(checkpoint.keys())
17561770
text_model_dict = {}
17571771

1758-
remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
1772+
remove_prefixes = ["text_encoders.t5xxl.transformer."]
17591773

17601774
for key in keys:
17611775
for prefix in remove_prefixes:

0 commit comments

Comments
 (0)