Skip to content

Commit da95a28

Browse files
authored
[Diffusion DPO] apply fixes from #6547 (#6668)
apply fixes from #6547
1 parent d66d554 commit da95a28

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

examples/research_projects/diffusion_dpo/train_diffusion_dpo_sdxl.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,10 @@ def preprocess_train(examples):
740740
# Resize.
741741
combined_im = train_resize(combined_im)
742742

743+
# Flipping.
744+
if not args.no_flip and random.random() < 0.5:
745+
combined_im = train_flip(combined_im)
746+
743747
# Cropping.
744748
if not args.random_crop:
745749
y1 = max(0, int(round((combined_im.shape[1] - args.resolution) / 2.0)))
@@ -749,11 +753,6 @@ def preprocess_train(examples):
749753
y1, x1, h, w = train_crop.get_params(combined_im, (args.resolution, args.resolution))
750754
combined_im = crop(combined_im, y1, x1, h, w)
751755

752-
# Flipping.
753-
if random.random() < 0.5:
754-
x1 = combined_im.shape[2] - x1
755-
combined_im = train_flip(combined_im)
756-
757756
crop_top_left = (y1, x1)
758757
crop_top_lefts.append(crop_top_left)
759758
combined_im = normalize(combined_im)

0 commit comments

Comments
 (0)