
Commit 8e7d6c0

[chore] fix: retain memory utility. (#9543)
* fix: retain memory utility.
* fix
* quality
* free_memory.
1 parent b28675c commit 8e7d6c0

6 files changed: +33 −35 lines changed
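
This commit removes the clear_objs_and_retain_memory helper from diffusers.training_utils and replaces it with a plain free_memory() utility: call sites now del their own references explicitly and then call free_memory(), which runs garbage collection and clears the accelerator cache. A minimal sketch of the new calling pattern (the checkpoint and pipeline setup are illustrative, not part of this commit):

    import torch

    from diffusers import DiffusionPipeline
    from diffusers.training_utils import free_memory  # added by this commit

    # Load a large throwaway object, e.g. a validation pipeline (example checkpoint only).
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    # ... run validation with `pipe` ...

    # New pattern: drop the reference explicitly, then free caches.
    # Previously this was a single call: clear_objs_and_retain_memory([pipe]).
    del pipe
    free_memory()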

examples/cogvideo/train_cogvideox_lora.py

+3 −5
@@ -38,10 +38,7 @@
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -726,7 +723,8 @@ def log_validation(
                 }
             )

-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    free_memory()

     return videos

examples/controlnet/train_controlnet_flux.py

+5 −3
@@ -54,7 +54,7 @@
 from diffusers.models.controlnet_flux import FluxControlNetModel
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.flux.pipeline_flux_controlnet import FluxControlNetPipeline
-from diffusers.training_utils import clear_objs_and_retain_memory, compute_density_for_timestep_sampling
+from diffusers.training_utils import compute_density_for_timestep_sampling, free_memory
 from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
@@ -193,7 +193,8 @@ def log_validation(
         else:
             logger.warning(f"image logging not implemented for {tracker.name}")

-    clear_objs_and_retain_memory([pipeline])
+    del pipeline
+    free_memory()
     return image_logs


@@ -1103,7 +1104,8 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
             compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint, batch_size=50
         )

-    clear_objs_and_retain_memory([text_encoders, tokenizers])
+    del text_encoders, tokenizers, text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+    free_memory()

     # Then get the training dataset ready to be passed to the dataloader.
     train_dataset = prepare_train_dataset(train_dataset, accelerator)

examples/controlnet/train_controlnet_sd3.py

+6 −7
@@ -49,11 +49,7 @@
     StableDiffusion3ControlNetPipeline,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import (
-    clear_objs_and_retain_memory,
-    compute_density_for_timestep_sampling,
-    compute_loss_weighting_for_sd3,
-)
+from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3, free_memory
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -174,7 +170,8 @@ def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_v
         else:
             logger.warning(f"image logging not implemented for {tracker.name}")

-    clear_objs_and_retain_memory(pipeline)
+    del pipeline
+    free_memory()

     if not is_final_validation:
         controlnet.to(accelerator.device)
@@ -1131,7 +1128,9 @@ def compute_text_embeddings(batch, text_encoders, tokenizers):
         new_fingerprint = Hasher.hash(args)
         train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)

-    clear_objs_and_retain_memory(text_encoders + tokenizers)
+    del text_encoder_one, text_encoder_two, text_encoder_three
+    del tokenizer_one, tokenizer_two, tokenizer_three
+    free_memory()

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,

examples/dreambooth/train_dreambooth_lora_flux.py

+7 −4
@@ -55,9 +55,9 @@
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
-    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    free_memory,
 )
 from diffusers.utils import (
     check_min_version,
@@ -1437,7 +1437,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
-        clear_objs_and_retain_memory([tokenizers, text_encoders, text_encoder_one, text_encoder_two])
+        del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
+        free_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1480,7 +1481,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)

         if args.validation_prompt is None:
-            clear_objs_and_retain_memory([vae])
+            del vae
+            free_memory()

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -1817,7 +1819,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     torch_dtype=weight_dtype,
                 )
                 if not args.train_text_encoder:
-                    clear_objs_and_retain_memory([text_encoder_one, text_encoder_two])
+                    del text_encoder_one, text_encoder_two
+                    free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()

examples/dreambooth/train_dreambooth_lora_sd3.py

+10 −10
@@ -55,9 +55,9 @@
 from diffusers.training_utils import (
     _set_state_dict_into_text_encoder,
     cast_training_params,
-    clear_objs_and_retain_memory,
     compute_density_for_timestep_sampling,
     compute_loss_weighting_for_sd3,
+    free_memory,
 )
 from diffusers.utils import (
     check_min_version,
@@ -211,7 +211,8 @@ def log_validation(
                 }
             )

-    clear_objs_and_retain_memory(objs=[pipeline])
+    del pipeline
+    free_memory()

     return images

@@ -1106,7 +1107,8 @@ def main(args):
                 image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                 image.save(image_filename)

-        clear_objs_and_retain_memory(objs=[pipeline])
+        del pipeline
+        free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1453,9 +1455,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     # Clear the memory here
     if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         # Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
-        clear_objs_and_retain_memory(
-            objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
-        )
+        del tokenizers, text_encoders
+        del text_encoder_one, text_encoder_two, text_encoder_three
+        free_memory()

     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1791,11 +1793,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                     epoch=epoch,
                     torch_dtype=weight_dtype,
                 )
-                objs = []
-                if not args.train_text_encoder:
-                    objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])

-                clear_objs_and_retain_memory(objs=objs)
+                del text_encoder_one, text_encoder_two, text_encoder_three
+                free_memory()

     # Save the lora layers
     accelerator.wait_for_everyone()
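
The "Explicitly delete the objects as well" comment in the hunk above is the reason the scripts now spell out every del: deleting a container only drops the container, and any object still bound to another name stays alive and keeps its memory. A toy sketch of that behaviour (the Big class and variable names are illustrative, not from the training script):

    import gc
    import weakref


    class Big:
        """Stand-in for a large object such as a text encoder."""


    text_encoder_one = Big()
    text_encoders = [text_encoder_one]    # a second reference via the list
    probe = weakref.ref(text_encoder_one)

    del text_encoders                     # only the list goes away
    gc.collect()
    print(probe() is not None)            # True: the encoder is still reachable

    del text_encoder_one                  # drop the remaining reference as well
    gc.collect()
    print(probe() is None)                # True: now it can be collected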

src/diffusers/training_utils.py

+2 −6
@@ -260,12 +260,8 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
     return weighting


-def clear_objs_and_retain_memory(objs: List[Any]):
-    """Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
-    if len(objs) >= 1:
-        for obj in objs:
-            del obj
-
+def free_memory():
+    """Runs garbage collection. Then clears the cache of the available accelerator."""
     gc.collect()

     if torch.cuda.is_available():
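
The hunk above is truncated inside the new helper. A sketch of how the complete free_memory plausibly reads, matching its docstring; the non-CUDA branch is an assumption about other accelerator backends and is not shown in this diff:

    import gc

    import torch


    def free_memory():
        """Runs garbage collection. Then clears the cache of the available accelerator."""
        gc.collect()

        if torch.cuda.is_available():
            # Release cached CUDA allocations back to the driver.
            torch.cuda.empty_cache()
        elif torch.backends.mps.is_available():
            # Assumed equivalent for the Apple MPS backend (requires torch >= 2.0).
            torch.mps.empty_cache()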

0 commit comments
