Skip to content

Commit dbfb8f1

Browse files
committed
update
1 parent 0746cf9 commit dbfb8f1

File tree

4 files changed

+18
-15
lines changed

4 files changed

+18
-15
lines changed

src/diffusers/loaders/single_file.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,13 @@ def build_sub_model_components(
4343
checkpoint,
4444
local_files_only=False,
4545
load_safety_checker=False,
46-
**kwargs,
46+
model_type=None,
47+
image_size=None,
48+
**kwargs
4749
):
4850
if component_name in pipeline_components:
4951
return {}
5052

51-
model_type = kwargs.pop("model_type", None)
52-
image_size = kwargs.pop("image_size", None)
53-
5453
if component_name == "unet":
5554
num_in_channels = kwargs.pop("num_in_channels", None)
5655
unet_components = create_diffusers_unet_model_from_ldm(
@@ -112,10 +111,9 @@ def build_sub_model_components(
112111
def set_additional_components(
113112
pipeline_class_name,
114113
original_config,
115-
**kwargs,
114+
model_type=None,
116115
):
117116
components = {}
118-
model_type = kwargs.get("model_type", None)
119117
if pipeline_class_name in REFINER_PIPELINES:
120118
model_type = infer_model_type(original_config, model_type=model_type)
121119
is_refiner = model_type == "SDXL-Refiner"
@@ -235,6 +233,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
235233
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
236234
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
237235

236+
model_type = kwargs.pop("model_type", None)
237+
image_size = kwargs.pop("image_size", None)
238+
238239
init_kwargs = {}
239240
for name in expected_modules:
240241
if name in passed_class_obj:
@@ -247,13 +248,15 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
247248
original_config,
248249
checkpoint,
249250
pretrained_model_link_or_path,
251+
model_type=model_type,
252+
image_size=image_size,
250253
**kwargs,
251254
)
252255
if not components:
253256
continue
254257
init_kwargs.update(components)
255258

256-
additional_components = set_additional_components(class_name, original_config, **kwargs)
259+
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
257260
if additional_components:
258261
init_kwargs.update(additional_components)
259262

src/diffusers/loaders/single_file_utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,14 @@
191191
]
192192

193193

194-
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
194+
VALID_HF_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
195195

196196

197197
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
198198
pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
199199
weights_name = None
200200
repo_id = (None,)
201-
for prefix in VALID_URL_PREFIXES:
201+
for prefix in VALID_HF_URL_PREFIXES:
202202
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
203203
match = re.match(pattern, pretrained_model_name_or_path)
204204
if not match:

src/diffusers/models/unet_3d_condition.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def forward(
533533
534534
Args:
535535
sample (`torch.FloatTensor`):
536-
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
536+
The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
537537
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
538538
encoder_hidden_states (`torch.FloatTensor`):
539539
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -753,14 +753,14 @@ def test_download_local(self):
753753
def test_download_ckpt_diff_format_is_same(self):
754754
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
755755

756-
pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
757-
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
758-
pipe.unet.set_attn_processor(AttnProcessor())
759-
pipe.to("cuda")
756+
sf_pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
757+
sf_pipe.scheduler = DDIMScheduler.from_config(sf_pipe.scheduler.config)
758+
sf_pipe.unet.set_attn_processor(AttnProcessor())
759+
sf_pipe.to("cuda")
760760

761761
inputs = self.get_inputs(torch_device)
762762
inputs["num_inference_steps"] = 5
763-
image_ckpt = pipe(**inputs).images[0]
763+
image_ckpt = sf_pipe(**inputs).images[0]
764764

765765
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
766766
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

0 commit comments

Comments
 (0)