
Commit 1835510

Remove torch_dtype in to() to end deprecation (#6886)
* remove torch_dtype from to()
* remove torch_dtype from usage scripts.
* remove old lora backend
* Revert "remove old lora backend". This reverts commit adcddf6.
1 parent 4a3d528 commit 1835510

12 files changed (+15, -37 lines)
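For downstream code the migration is a keyword rename: the `torch_dtype`/`torch_device` keywords, scheduled for removal in 0.27.0 by the `deprecate` calls dropped in the pipeline_utils.py hunk below, give way to `dtype`/`device` (or plain positional arguments). A minimal before/after sketch; the checkpoint id and the CUDA device are illustrative, not part of this commit:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Removed by this commit (previously only emitted a deprecation warning):
#   pipe.to(torch_dtype=torch.float16)
#   pipe.to(torch_device="cuda")

# Equivalent calls after the removal:
pipe.to(dtype=torch.float16)       # cast the pipeline in place
pipe.to("cuda")                    # move it to a device
pipe.to("cuda", torch.float16)     # device first, then dtype, positionally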

scripts/convert_gligen_to_diffusers.py

Lines changed: 1 addition & 1 deletion

@@ -576,6 +576,6 @@ def convert_gligen_to_diffusers(
     )

     if args.half:
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)

     pipe.save_pretrained(args.dump_path)

scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 1 addition & 1 deletion

@@ -179,7 +179,7 @@
     )

     if args.half:
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)

     if args.controlnet:
         # only save the controlnet model

scripts/convert_zero123_to_diffusers.py

Lines changed: 1 addition & 1 deletion

@@ -801,6 +801,6 @@ def convert_from_original_zero123_ckpt(checkpoint_path, original_config_file, ex
     )

     if args.half:
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)

     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
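Note that the removal only concerns `DiffusionPipeline.to()`; `torch_dtype` remains the load-time keyword of `from_pretrained`, which this commit does not touch, so only post-construction casts need the new spelling. A hedged sketch of that distinction (the checkpoint id and output path are illustrative):

import torch
from diffusers import DiffusionPipeline

# Load-time precision: `torch_dtype` here is unchanged by this commit.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

# Post-construction cast, as the conversion scripts above now do:
pipe.to(dtype=torch.float16)
pipe.save_pretrained("./converted-fp16")  # stands in for args.dump_path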

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 4 additions & 26 deletions

@@ -775,32 +775,10 @@ def to(self, *args, **kwargs):
         Returns:
             [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`.
         """
-
-        torch_dtype = kwargs.pop("torch_dtype", None)
-        if torch_dtype is not None:
-            deprecate("torch_dtype", "0.27.0", "")
-        torch_device = kwargs.pop("torch_device", None)
-        if torch_device is not None:
-            deprecate("torch_device", "0.27.0", "")
-
-        dtype_kwarg = kwargs.pop("dtype", None)
-        device_kwarg = kwargs.pop("device", None)
+        dtype = kwargs.pop("dtype", None)
+        device = kwargs.pop("device", None)
         silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)

-        if torch_dtype is not None and dtype_kwarg is not None:
-            raise ValueError(
-                "You have passed both `torch_dtype` and `dtype` as a keyword argument. Please make sure to only pass `dtype`."
-            )
-
-        dtype = torch_dtype or dtype_kwarg
-
-        if torch_device is not None and device_kwarg is not None:
-            raise ValueError(
-                "You have passed both `torch_device` and `device` as a keyword argument. Please make sure to only pass `device`."
-            )
-
-        device = torch_device or device_kwarg
-
         dtype_arg = None
         device_arg = None
         if len(args) == 1:

@@ -873,12 +851,12 @@ def module_is_offloaded(module):

             if is_loaded_in_8bit and dtype is not None:
                 logger.warning(
-                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
                 )

             if is_loaded_in_8bit and device is not None:
                 logger.warning(
-                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
                 )
             else:
                 module.to(device, dtype)
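The hunks above leave `to()` with a single resolution path: `dtype` and `device` come either from keyword arguments or from up to two positional arguments. Below is a self-contained sketch of that resolution logic, written as a free function for illustration; it mirrors the shape of the code above but is not the library implementation.

import torch


def resolve_to_args(*args, **kwargs):
    """Mirror of how a pipeline `.to(...)` call can sort out device and dtype."""
    dtype = kwargs.pop("dtype", None)
    device = kwargs.pop("device", None)

    dtype_arg = None
    device_arg = None
    if len(args) == 1:
        # One positional argument is either a dtype or a device.
        if isinstance(args[0], torch.dtype):
            dtype_arg = args[0]
        else:
            device_arg = torch.device(args[0]) if args[0] is not None else None
    elif len(args) == 2:
        # Two positional arguments follow the (device, dtype) order.
        if isinstance(args[0], torch.dtype):
            raise ValueError("When passing two arguments, the first must be `device` and the second `dtype`.")
        device_arg = torch.device(args[0]) if args[0] is not None else None
        dtype_arg = args[1]
    elif len(args) > 2:
        raise ValueError("Please pass at most two positional arguments: `device` and `dtype`.")

    if dtype is not None and dtype_arg is not None:
        raise ValueError("Pass `dtype` either positionally or as a keyword, not both.")
    if device is not None and device_arg is not None:
        raise ValueError("Pass `device` either positionally or as a keyword, not both.")

    return device or device_arg, dtype or dtype_arg


# The calling conventions exercised by test_pipe_to below all resolve cleanly:
print(resolve_to_args(torch.float16))               # (None, torch.float16)
print(resolve_to_args(None, torch.float16))         # (None, torch.float16)
print(resolve_to_args("cpu", dtype=torch.float16))  # (device('cpu'), torch.float16)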

tests/pipelines/animatediff/test_animatediff.py

Lines changed: 1 addition & 1 deletion

@@ -218,7 +218,7 @@ def test_to_dtype(self):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))


tests/pipelines/animatediff/test_animatediff_video2video.py

Lines changed: 1 addition & 1 deletion

@@ -224,7 +224,7 @@ def test_to_dtype(self):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))


tests/pipelines/audioldm2/test_audioldm2.py

Lines changed: 1 addition & 1 deletion

@@ -483,7 +483,7 @@ def test_to_dtype(self):
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

         # Once we send to fp16, all params are in half-precision, including the logit scale
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))


tests/pipelines/musicldm/test_musicldm.py

Lines changed: 1 addition & 1 deletion

@@ -400,7 +400,7 @@ def test_to_dtype(self):
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

         # Once we send to fp16, all params are in half-precision, including the logit scale
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))


tests/pipelines/pia/test_pia.py

Lines changed: 1 addition & 1 deletion

@@ -231,7 +231,7 @@ def test_to_dtype(self):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))


tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py

Lines changed: 1 addition & 1 deletion

@@ -396,7 +396,7 @@ def test_to_dtype(self):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))


tests/pipelines/test_pipelines.py

Lines changed: 1 addition & 1 deletion

@@ -1623,7 +1623,7 @@ def test_pipe_to(self):
         sd1 = sd.to(torch.float16)
         sd2 = sd.to(None, torch.float16)
         sd3 = sd.to(dtype=torch.float16)
-        sd4 = sd.to(torch_dtype=torch.float16)
+        sd4 = sd.to(dtype=torch.float16)
         sd5 = sd.to(None, dtype=torch.float16)
         sd6 = sd.to(None, torch_dtype=torch.float16)


tests/pipelines/test_pipelines_common.py

Lines changed: 1 addition & 1 deletion

@@ -716,7 +716,7 @@ def test_to_dtype(self):
         model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))

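Most of the test updates above are the same one-line substitution inside a `test_to_dtype` test: cast the pipeline, then assert that every component exposing a `.dtype` attribute is in half precision. That shared assertion can be expressed as a small standalone helper, sketched here with a toy stand-in component instead of a real pipeline (all names are illustrative, not library code):

import torch


class TinyComponent(torch.nn.Module):
    """Stand-in for a pipeline component; exposes `.dtype` like diffusers models do."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)

    @property
    def dtype(self):
        return self.proj.weight.dtype


def assert_components_dtype(components, expected):
    """The check shared by the test_to_dtype tests above, as a standalone helper."""
    dtypes = [c.dtype for c in components.values() if hasattr(c, "dtype")]
    assert dtypes and all(d == expected for d in dtypes), f"expected {expected}, got {dtypes}"


components = {"unet": TinyComponent(), "text_encoder": TinyComponent()}
assert_components_dtype(components, torch.float32)

for module in components.values():
    module.to(dtype=torch.float16)  # the same keyword the pipelines themselves now accept
assert_components_dtype(components, torch.float16)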
