55
55
from diffusers .utils .torch_utils import is_compiled_module
56
56
57
57
58
+ if is_wandb_available ():
59
+ import wandb
60
+
58
61
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
59
62
check_min_version ("0.26.0.dev0" )
60
63
67
70
# Maps the CLI precision flag ("fp32"/"fp16"/"bf16") to the corresponding torch dtype.
TORCH_DTYPE_MAPPING = {"fp32" : torch .float32 , "fp16" : torch .float16 , "bf16" : torch .bfloat16 }
68
71
69
72
73
def log_validation(
    pipeline,
    args,
    accelerator,
    generator,
    global_step,
    is_final_validation=False,
):
    """Generate validation images with the current pipeline and log them.

    Runs `args.num_validation_images` edits of the validation image with
    `args.validation_prompt`, optionally saves them to disk, and logs a
    table of (input, output, prompt) rows to any wandb tracker.

    Args:
        pipeline: Diffusers image-editing pipeline (InstructPix2Pix-style;
            accepts prompt, image, guidance args — presumably an SDXL
            pix2pix pipeline, confirm against caller).
        args: Parsed training arguments; reads num_validation_images,
            validation_prompt, val_image_url_or_path and output_dir.
        accelerator: `accelerate.Accelerator` — provides device,
            mixed_precision and the configured trackers.
        generator: `torch.Generator` used for reproducible sampling.
        global_step: Current training step; used to name saved images.
            Ignored when `is_final_validation` is True.
        is_final_validation: When True, skip saving images to disk and log
            the wandb table under "test" instead of "validation".
    """
    logger.info(
        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
        f" {args.validation_prompt}."
    )

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    if not is_final_validation:
        val_save_dir = os.path.join(args.output_dir, "validation_images")
        # exist_ok avoids the check-then-create race of the previous
        # os.path.exists() guard (safe under multi-process launch).
        os.makedirs(val_save_dir, exist_ok=True)

    # Resolve the validation image: download when given a URL (non-empty
    # scheme), otherwise open a local file.
    if urlparse(args.val_image_url_or_path).scheme:
        original_image = load_image(args.val_image_url_or_path)
    else:
        original_image = Image.open(args.val_image_url_or_path).convert("RGB")

    # ":0" is stripped because torch.autocast expects a device *type*
    # ("cuda"), not a device index ("cuda:0").
    with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
        edited_images = []
        # Run inference
        for val_img_idx in range(args.num_validation_images):
            a_val_img = pipeline(
                args.validation_prompt,
                image=original_image,
                num_inference_steps=20,
                image_guidance_scale=1.5,
                guidance_scale=7,
                generator=generator,
            ).images[0]
            edited_images.append(a_val_img)
            # Save validation images only during intermediate validation.
            if not is_final_validation:
                a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
            for edited_image in edited_images:
                wandb_table.add_data(
                    wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
                )
            logger_name = "test" if is_final_validation else "validation"
            tracker.log({logger_name: wandb_table})
126
+
127
+
70
128
def import_model_class_from_model_name_or_path (
71
129
pretrained_model_name_or_path : str , revision : str , subfolder : str = "text_encoder"
72
130
):
@@ -447,11 +505,6 @@ def main():
447
505
448
506
generator = torch .Generator (device = accelerator .device ).manual_seed (args .seed )
449
507
450
- if args .report_to == "wandb" :
451
- if not is_wandb_available ():
452
- raise ImportError ("Make sure to install wandb if you want to use it for logging during training." )
453
- import wandb
454
-
455
508
# Make one log on every process with the configuration for debugging.
456
509
logging .basicConfig (
457
510
format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
@@ -1111,11 +1164,6 @@ def collate_fn(examples):
1111
1164
### BEGIN: Perform validation every `validation_epochs` steps
1112
1165
if global_step % args .validation_steps == 0 :
1113
1166
if (args .val_image_url_or_path is not None ) and (args .validation_prompt is not None ):
1114
- logger .info (
1115
- f"Running validation... \n Generating { args .num_validation_images } images with prompt:"
1116
- f" { args .validation_prompt } ."
1117
- )
1118
-
1119
1167
# create pipeline
1120
1168
if args .use_ema :
1121
1169
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1183,16 @@ def collate_fn(examples):
1135
1183
variant = args .variant ,
1136
1184
torch_dtype = weight_dtype ,
1137
1185
)
1138
- pipeline = pipeline .to (accelerator .device )
1139
- pipeline .set_progress_bar_config (disable = True )
1140
-
1141
- # run inference
1142
- # Save validation images
1143
- val_save_dir = os .path .join (args .output_dir , "validation_images" )
1144
- if not os .path .exists (val_save_dir ):
1145
- os .makedirs (val_save_dir )
1146
-
1147
- original_image = (
1148
- lambda image_url_or_path : load_image (image_url_or_path )
1149
- if urlparse (image_url_or_path ).scheme
1150
- else Image .open (image_url_or_path ).convert ("RGB" )
1151
- )(args .val_image_url_or_path )
1152
- with torch .autocast (
1153
- str (accelerator .device ).replace (":0" , "" ), enabled = accelerator .mixed_precision == "fp16"
1154
- ):
1155
- edited_images = []
1156
- for val_img_idx in range (args .num_validation_images ):
1157
- a_val_img = pipeline (
1158
- args .validation_prompt ,
1159
- image = original_image ,
1160
- num_inference_steps = 20 ,
1161
- image_guidance_scale = 1.5 ,
1162
- guidance_scale = 7 ,
1163
- generator = generator ,
1164
- ).images [0 ]
1165
- edited_images .append (a_val_img )
1166
- a_val_img .save (os .path .join (val_save_dir , f"step_{ global_step } _val_img_{ val_img_idx } .png" ))
1167
-
1168
- for tracker in accelerator .trackers :
1169
- if tracker .name == "wandb" :
1170
- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1171
- for edited_image in edited_images :
1172
- wandb_table .add_data (
1173
- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
1174
- )
1175
- tracker .log ({"validation" : wandb_table })
1186
+
1187
+ log_validation (
1188
+ pipeline ,
1189
+ args ,
1190
+ accelerator ,
1191
+ generator ,
1192
+ global_step ,
1193
+ is_final_validation = False ,
1194
+ )
1195
+
1176
1196
if args .use_ema :
1177
1197
# Switch back to the original UNet parameters.
1178
1198
ema_unet .restore (unet .parameters ())
@@ -1187,7 +1207,6 @@ def collate_fn(examples):
1187
1207
# Create the pipeline using the trained modules and save it.
1188
1208
accelerator .wait_for_everyone ()
1189
1209
if accelerator .is_main_process :
1190
- unet = unwrap_model (unet )
1191
1210
if args .use_ema :
1192
1211
ema_unet .copy_to (unet .parameters ())
1193
1212
@@ -1198,10 +1217,11 @@ def collate_fn(examples):
1198
1217
tokenizer = tokenizer_1 ,
1199
1218
tokenizer_2 = tokenizer_2 ,
1200
1219
vae = vae ,
1201
- unet = unet ,
1220
+ unet = unwrap_model ( unet ) ,
1202
1221
revision = args .revision ,
1203
1222
variant = args .variant ,
1204
1223
)
1224
+
1205
1225
pipeline .save_pretrained (args .output_dir )
1206
1226
1207
1227
if args .push_to_hub :
@@ -1212,30 +1232,15 @@ def collate_fn(examples):
1212
1232
ignore_patterns = ["step_*" , "epoch_*" ],
1213
1233
)
1214
1234
1215
- if args .validation_prompt is not None :
1216
- edited_images = []
1217
- pipeline = pipeline .to (accelerator .device )
1218
- with torch .autocast (str (accelerator .device ).replace (":0" , "" )):
1219
- for _ in range (args .num_validation_images ):
1220
- edited_images .append (
1221
- pipeline (
1222
- args .validation_prompt ,
1223
- image = original_image ,
1224
- num_inference_steps = 20 ,
1225
- image_guidance_scale = 1.5 ,
1226
- guidance_scale = 7 ,
1227
- generator = generator ,
1228
- ).images [0 ]
1229
- )
1230
-
1231
- for tracker in accelerator .trackers :
1232
- if tracker .name == "wandb" :
1233
- wandb_table = wandb .Table (columns = WANDB_TABLE_COL_NAMES )
1234
- for edited_image in edited_images :
1235
- wandb_table .add_data (
1236
- wandb .Image (original_image ), wandb .Image (edited_image ), args .validation_prompt
1237
- )
1238
- tracker .log ({"test" : wandb_table })
1235
+ if (args .val_image_url_or_path is not None ) and (args .validation_prompt is not None ):
1236
+ log_validation (
1237
+ pipeline ,
1238
+ args ,
1239
+ accelerator ,
1240
+ generator ,
1241
+ global_step = None ,
1242
+ is_final_validation = True ,
1243
+ )
1239
1244
1240
1245
accelerator .end_training ()
1241
1246
0 commit comments