Skip to content

Commit 0b7748e

Browse files
committed
Modularize InstructPix2Pix SDXL inferencing during and after training in examples
1 parent 619e3ab commit 0b7748e

File tree

1 file changed

+79
-74
lines changed

1 file changed

+79
-74
lines changed

examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py

+79-74
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@
5555
from diffusers.utils.torch_utils import is_compiled_module
5656

5757

58+
if is_wandb_available():
59+
import wandb
60+
5861
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
5962
check_min_version("0.26.0.dev0")
6063

@@ -67,6 +70,61 @@
6770
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
6871

6972

73+
def log_validation(
    pipeline,
    args,
    accelerator,
    generator,
    global_step,
    is_final_validation=False,
):
    """Run validation inference with `pipeline` and log the edited images.

    Generates ``args.num_validation_images`` edited versions of the image at
    ``args.val_image_url_or_path`` using ``args.validation_prompt``. During
    training (``is_final_validation=False``) each image is also saved under
    ``<output_dir>/validation_images`` keyed by ``global_step``. Results are
    logged to any active W&B tracker under "validation" (during training) or
    "test" (final validation).
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if not is_final_validation:
        val_save_dir = os.path.join(args.output_dir, "validation_images")
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(val_save_dir, exist_ok=True)

    # Load from a URL when the path carries a scheme, otherwise from local disk.
    if urlparse(args.val_image_url_or_path).scheme:
        original_image = load_image(args.val_image_url_or_path)
    else:
        original_image = Image.open(args.val_image_url_or_path).convert("RGB")

    # Strip a trailing ":0" so autocast receives the bare device type (e.g. "cuda").
    with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
        edited_images = []
        # Run inference
        for val_img_idx in range(args.num_validation_images):
            a_val_img = pipeline(
                args.validation_prompt,
                image=original_image,
                num_inference_steps=20,
                image_guidance_scale=1.5,
                guidance_scale=7,
                generator=generator,
            ).images[0]
            edited_images.append(a_val_img)
            # Save validation images (training-time validation only).
            if not is_final_validation:
                a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
            for edited_image in edited_images:
                wandb_table.add_data(
                    wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
                )
            logger_name = "test" if is_final_validation else "validation"
            tracker.log({logger_name: wandb_table})
126+
127+
70128
def import_model_class_from_model_name_or_path(
71129
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
72130
):
@@ -447,11 +505,6 @@ def main():
447505

448506
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
449507

450-
if args.report_to == "wandb":
451-
if not is_wandb_available():
452-
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
453-
import wandb
454-
455508
# Make one log on every process with the configuration for debugging.
456509
logging.basicConfig(
457510
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -1111,11 +1164,6 @@ def collate_fn(examples):
11111164
### BEGIN: Perform validation every `validation_epochs` steps
11121165
if global_step % args.validation_steps == 0:
11131166
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
1114-
logger.info(
1115-
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1116-
f" {args.validation_prompt}."
1117-
)
1118-
11191167
# create pipeline
11201168
if args.use_ema:
11211169
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1183,16 @@ def collate_fn(examples):
11351183
variant=args.variant,
11361184
torch_dtype=weight_dtype,
11371185
)
1138-
pipeline = pipeline.to(accelerator.device)
1139-
pipeline.set_progress_bar_config(disable=True)
1140-
1141-
# run inference
1142-
# Save validation images
1143-
val_save_dir = os.path.join(args.output_dir, "validation_images")
1144-
if not os.path.exists(val_save_dir):
1145-
os.makedirs(val_save_dir)
1146-
1147-
original_image = (
1148-
lambda image_url_or_path: load_image(image_url_or_path)
1149-
if urlparse(image_url_or_path).scheme
1150-
else Image.open(image_url_or_path).convert("RGB")
1151-
)(args.val_image_url_or_path)
1152-
with torch.autocast(
1153-
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
1154-
):
1155-
edited_images = []
1156-
for val_img_idx in range(args.num_validation_images):
1157-
a_val_img = pipeline(
1158-
args.validation_prompt,
1159-
image=original_image,
1160-
num_inference_steps=20,
1161-
image_guidance_scale=1.5,
1162-
guidance_scale=7,
1163-
generator=generator,
1164-
).images[0]
1165-
edited_images.append(a_val_img)
1166-
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
1167-
1168-
for tracker in accelerator.trackers:
1169-
if tracker.name == "wandb":
1170-
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
1171-
for edited_image in edited_images:
1172-
wandb_table.add_data(
1173-
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
1174-
)
1175-
tracker.log({"validation": wandb_table})
1186+
1187+
log_validation(
1188+
pipeline,
1189+
args,
1190+
accelerator,
1191+
generator,
1192+
global_step,
1193+
is_final_validation=False,
1194+
)
1195+
11761196
if args.use_ema:
11771197
# Switch back to the original UNet parameters.
11781198
ema_unet.restore(unet.parameters())
@@ -1187,7 +1207,6 @@ def collate_fn(examples):
11871207
# Create the pipeline using the trained modules and save it.
11881208
accelerator.wait_for_everyone()
11891209
if accelerator.is_main_process:
1190-
unet = unwrap_model(unet)
11911210
if args.use_ema:
11921211
ema_unet.copy_to(unet.parameters())
11931212

@@ -1198,10 +1217,11 @@ def collate_fn(examples):
11981217
tokenizer=tokenizer_1,
11991218
tokenizer_2=tokenizer_2,
12001219
vae=vae,
1201-
unet=unet,
1220+
unet=unwrap_model(unet),
12021221
revision=args.revision,
12031222
variant=args.variant,
12041223
)
1224+
12051225
pipeline.save_pretrained(args.output_dir)
12061226

12071227
if args.push_to_hub:
@@ -1212,30 +1232,15 @@ def collate_fn(examples):
12121232
ignore_patterns=["step_*", "epoch_*"],
12131233
)
12141234

1215-
if args.validation_prompt is not None:
1216-
edited_images = []
1217-
pipeline = pipeline.to(accelerator.device)
1218-
with torch.autocast(str(accelerator.device).replace(":0", "")):
1219-
for _ in range(args.num_validation_images):
1220-
edited_images.append(
1221-
pipeline(
1222-
args.validation_prompt,
1223-
image=original_image,
1224-
num_inference_steps=20,
1225-
image_guidance_scale=1.5,
1226-
guidance_scale=7,
1227-
generator=generator,
1228-
).images[0]
1229-
)
1230-
1231-
for tracker in accelerator.trackers:
1232-
if tracker.name == "wandb":
1233-
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
1234-
for edited_image in edited_images:
1235-
wandb_table.add_data(
1236-
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
1237-
)
1238-
tracker.log({"test": wandb_table})
1235+
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
1236+
log_validation(
1237+
pipeline,
1238+
args,
1239+
accelerator,
1240+
generator,
1241+
global_step=None,
1242+
is_final_validation=True,
1243+
)
12391244

12401245
accelerator.end_training()
12411246

0 commit comments

Comments (0)