@@ -55,6 +55,9 @@
 from diffusers.utils.torch_utils import is_compiled_module
 
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.26.0.dev0")
 
@@ -67,6 +70,59 @@
 TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
 
 
+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    generator,
+    global_step,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    if not is_final_validation:
+        val_save_dir = os.path.join(args.output_dir, "validation_images")
+        if not os.path.exists(val_save_dir):
+            os.makedirs(val_save_dir)
+
+    original_image = (
+        lambda image_url_or_path: load_image(image_url_or_path)
+        if urlparse(image_url_or_path).scheme
+        else Image.open(image_url_or_path).convert("RGB")
+    )(args.val_image_url_or_path)
+
+    with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
+        edited_images = []
+        # Run inference
+        for val_img_idx in range(args.num_validation_images):
+            a_val_img = pipeline(
+                args.validation_prompt,
+                image=original_image,
+                num_inference_steps=20,
+                image_guidance_scale=1.5,
+                guidance_scale=7,
+                generator=generator,
+            ).images[0]
+            edited_images.append(a_val_img)
+            # Save validation images
+            if not is_final_validation:
+                a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
+
+    for tracker in accelerator.trackers:
+        if tracker.name == "wandb":
+            wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
+            for edited_image in edited_images:
+                wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
+            logger_name = "test" if is_final_validation else "validation"
+            tracker.log({logger_name: wandb_table})
+
+
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
@@ -447,11 +503,6 @@ def main():
 
     generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
 
-    if args.report_to == "wandb":
-        if not is_wandb_available():
-            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
-
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -1111,11 +1162,6 @@ def collate_fn(examples):
                 ### BEGIN: Perform validation every `validation_epochs` steps
                 if global_step % args.validation_steps == 0:
                     if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
-                        logger.info(
-                            f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                            f" {args.validation_prompt}."
-                        )
-
                         # create pipeline
                         if args.use_ema:
                             # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
@@ -1135,44 +1181,16 @@ def collate_fn(examples):
                             variant=args.variant,
                             torch_dtype=weight_dtype,
                         )
-                        pipeline = pipeline.to(accelerator.device)
-                        pipeline.set_progress_bar_config(disable=True)
-
-                        # run inference
-                        # Save validation images
-                        val_save_dir = os.path.join(args.output_dir, "validation_images")
-                        if not os.path.exists(val_save_dir):
-                            os.makedirs(val_save_dir)
-
-                        original_image = (
-                            lambda image_url_or_path: load_image(image_url_or_path)
-                            if urlparse(image_url_or_path).scheme
-                            else Image.open(image_url_or_path).convert("RGB")
-                        )(args.val_image_url_or_path)
-                        with torch.autocast(
-                            str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
-                        ):
-                            edited_images = []
-                            for val_img_idx in range(args.num_validation_images):
-                                a_val_img = pipeline(
-                                    args.validation_prompt,
-                                    image=original_image,
-                                    num_inference_steps=20,
-                                    image_guidance_scale=1.5,
-                                    guidance_scale=7,
-                                    generator=generator,
-                                ).images[0]
-                                edited_images.append(a_val_img)
-                                a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
-
-                        for tracker in accelerator.trackers:
-                            if tracker.name == "wandb":
-                                wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                                for edited_image in edited_images:
-                                    wandb_table.add_data(
-                                        wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                                    )
-                                tracker.log({"validation": wandb_table})
+
+                        log_validation(
+                            pipeline,
+                            args,
+                            accelerator,
+                            generator,
+                            global_step,
+                            is_final_validation=False,
+                        )
+
                         if args.use_ema:
                             # Switch back to the original UNet parameters.
                             ema_unet.restore(unet.parameters())
@@ -1187,7 +1205,6 @@ def collate_fn(examples):
     # Create the pipeline using the trained modules and save it.
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
-        unet = unwrap_model(unet)
         if args.use_ema:
             ema_unet.copy_to(unet.parameters())
 
@@ -1198,10 +1215,11 @@ def collate_fn(examples):
             tokenizer=tokenizer_1,
             tokenizer_2=tokenizer_2,
             vae=vae,
-            unet=unet,
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
         )
+
         pipeline.save_pretrained(args.output_dir)
 
         if args.push_to_hub:
@@ -1212,30 +1230,15 @@ def collate_fn(examples):
                 ignore_patterns=["step_*", "epoch_*"],
             )
 
-        if args.validation_prompt is not None:
-            edited_images = []
-            pipeline = pipeline.to(accelerator.device)
-            with torch.autocast(str(accelerator.device).replace(":0", "")):
-                for _ in range(args.num_validation_images):
-                    edited_images.append(
-                        pipeline(
-                            args.validation_prompt,
-                            image=original_image,
-                            num_inference_steps=20,
-                            image_guidance_scale=1.5,
-                            guidance_scale=7,
-                            generator=generator,
-                        ).images[0]
-                    )
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "wandb":
-                    wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
-                    for edited_image in edited_images:
-                        wandb_table.add_data(
-                            wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
-                        )
-                    tracker.log({"test": wandb_table})
+        if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
+            log_validation(
+                pipeline,
+                args,
+                accelerator,
+                generator,
+                global_step=None,
+                is_final_validation=True,
+            )
 
     accelerator.end_training()
 