Skip to content

Commit 837ba90

Browse files
committed Jan 17, 2024
Modularize InstructPix2Pix SDXL inferencing during and after training in examples
1 parent 619e3ab commit 837ba90

File tree

1 file changed

+77
-74
lines changed

1 file changed

+77
-74
lines changed
 

Diff for: examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

+77-74
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
from diffusers.utils.torch_utils import is_compiled_module
5656

5757

58+
if is_wandb_available():
59+
import wandb
60+
5861
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
5962
check_min_version("0.26.0.dev0")
6063

@@ -67,6 +70,59 @@
6770
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
6871

6972

73+
def log_validation(
    pipeline,
    args,
    accelerator,
    generator,
    global_step,
    is_final_validation=False,
):
    """Run validation inference with `pipeline` and log the results.

    Generates ``args.num_validation_images`` edited images from
    ``args.val_image_url_or_path`` using ``args.validation_prompt``, saves them
    to disk during training, and logs a comparison table to any wandb tracker.

    Args:
        pipeline: InstructPix2Pix SDXL pipeline used for inference.
        args: Parsed training arguments (validation prompt/image, output dir, ...).
        accelerator: `accelerate.Accelerator` providing device, mixed-precision
            setting, and trackers.
        generator: Seeded `torch.Generator` for reproducible sampling.
        global_step: Current training step; used only to name saved images
            during intermediate validation (may be None for final validation).
        is_final_validation: When True, skip saving images to disk and log the
            wandb table under "test" instead of "validation".
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if not is_final_validation:
        val_save_dir = os.path.join(args.output_dir, "validation_images")
        # exist_ok avoids a check-then-create race if several runs share the dir.
        os.makedirs(val_save_dir, exist_ok=True)

    # A URL (has a scheme) is fetched via load_image; otherwise treat it as a
    # local file path. Plain if/else instead of an assigned-lambda IIFE.
    if urlparse(args.val_image_url_or_path).scheme:
        original_image = load_image(args.val_image_url_or_path)
    else:
        original_image = Image.open(args.val_image_url_or_path).convert("RGB")

    # ":0" is stripped so autocast gets a bare device type (e.g. "cuda").
    with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
        edited_images = []
        # Run inference
        for val_img_idx in range(args.num_validation_images):
            a_val_img = pipeline(
                args.validation_prompt,
                image=original_image,
                num_inference_steps=20,
                image_guidance_scale=1.5,
                guidance_scale=7,
                generator=generator,
            ).images[0]
            edited_images.append(a_val_img)
            # Save validation images only during training-time validation.
            if not is_final_validation:
                a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
            for edited_image in edited_images:
                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
            logger_name = "test" if is_final_validation else "validation"
            tracker.log({logger_name: wandb_table})
125+
70126
def import_model_class_from_model_name_or_path(
71127
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
72128
):
@@ -447,11 +503,6 @@ def main():
447503

448504
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
449505

450-
if args.report_to == "wandb":
451-
if not is_wandb_available():
452-
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
453-
import wandb
454-
455506
# Make one log on every process with the configuration for debugging.
456507
logging.basicConfig(
457508
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -1111,11 +1162,6 @@ def collate_fn(examples):
11111162
### BEGIN: Perform validation every `validation_epochs` steps
11121163
if global_step % args.validation_steps == 0:
11131164
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
1114-
logger.info(
1115-
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1116-
f" {args.validation_prompt}."
1117-
)
1118-
11191165
# create pipeline
11201166
if args.use_ema:
11211167
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1181,16 @@ def collate_fn(examples):
11351181
variant=args.variant,
11361182
torch_dtype=weight_dtype,
11371183
)
1138-
pipeline = pipeline.to(accelerator.device)
1139-
pipeline.set_progress_bar_config(disable=True)
1140-
1141-
# run inference
1142-
# Save validation images
1143-
val_save_dir = os.path.join(args.output_dir, "validation_images")
1144-
if not os.path.exists(val_save_dir):
1145-
os.makedirs(val_save_dir)
1146-
1147-
original_image = (
1148-
lambda image_url_or_path: load_image(image_url_or_path)
1149-
if urlparse(image_url_or_path).scheme
1150-
else Image.open(image_url_or_path).convert("RGB")
1151-
)(args.val_image_url_or_path)
1152-
with torch.autocast(
1153-
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
1154-
):
1155-
edited_images = []
1156-
for val_img_idx in range(args.num_validation_images):
1157-
a_val_img = pipeline(
1158-
args.validation_prompt,
1159-
image=original_image,
1160-
num_inference_steps=20,
1161-
image_guidance_scale=1.5,
1162-
guidance_scale=7,
1163-
generator=generator,
1164-
).images[0]
1165-
edited_images.append(a_val_img)
1166-
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
1167-
1168-
for tracker in accelerator.trackers:
1169-
if tracker.name == "wandb":
1170-
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
1171-
for edited_image in edited_images:
1172-
wandb_table.add_data(
1173-
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
1174-
)
1175-
tracker.log({"validation": wandb_table})
1184+
1185+
log_validation(
1186+
pipeline,
1187+
args,
1188+
accelerator,
1189+
generator,
1190+
global_step,
1191+
is_final_validation=False,
1192+
)
1193+
11761194
if args.use_ema:
11771195
# Switch back to the original UNet parameters.
11781196
ema_unet.restore(unet.parameters())
@@ -1187,7 +1205,6 @@ def collate_fn(examples):
11871205
# Create the pipeline using the trained modules and save it.
11881206
accelerator.wait_for_everyone()
11891207
if accelerator.is_main_process:
1190-
unet = unwrap_model(unet)
11911208
if args.use_ema:
11921209
ema_unet.copy_to(unet.parameters())
11931210

@@ -1198,10 +1215,11 @@ def collate_fn(examples):
11981215
tokenizer=tokenizer_1,
11991216
tokenizer_2=tokenizer_2,
12001217
vae=vae,
1201-
unet=unet,
1218+
unet=unwrap_model(unet),
12021219
revision=args.revision,
12031220
variant=args.variant,
12041221
)
1222+
12051223
pipeline.save_pretrained(args.output_dir)
12061224

12071225
if args.push_to_hub:
@@ -1212,30 +1230,15 @@ def collate_fn(examples):
12121230
ignore_patterns=["step_*", "epoch_*"],
12131231
)
12141232

1215-
if args.validation_prompt is not None:
1216-
edited_images = []
1217-
pipeline = pipeline.to(accelerator.device)
1218-
with torch.autocast(str(accelerator.device).replace(":0", "")):
1219-
for _ in range(args.num_validation_images):
1220-
edited_images.append(
1221-
pipeline(
1222-
args.validation_prompt,
1223-
image=original_image,
1224-
num_inference_steps=20,
1225-
image_guidance_scale=1.5,
1226-
guidance_scale=7,
1227-
generator=generator,
1228-
).images[0]
1229-
)
1230-
1231-
for tracker in accelerator.trackers:
1232-
if tracker.name == "wandb":
1233-
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
1234-
for edited_image in edited_images:
1235-
wandb_table.add_data(
1236-
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
1237-
)
1238-
tracker.log({"test": wandb_table})
1233+
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
1234+
log_validation(
1235+
pipeline,
1236+
args,
1237+
accelerator,
1238+
generator,
1239+
global_step=None,
1240+
is_final_validation=True,
1241+
)
12391242

12401243
accelerator.end_training()
12411244

0 commit comments

Comments
 (0)