@@ -44,6 +44,7 @@ def __call__(
44
44
noise : torch .Tensor ,
45
45
timesteps : torch .Tensor ,
46
46
condition : torch .Tensor | None = None ,
47
+ mode : str = "crossattn" ,
47
48
) -> torch .Tensor :
48
49
"""
49
50
Implements the forward pass for a supervised training iteration.
@@ -54,8 +55,15 @@ def __call__(
54
55
noise: random noise, of the same shape as the input.
55
56
timesteps: random timesteps.
56
57
condition: Conditioning for network input.
58
+ mode: Conditioning mode for the network.
57
59
"""
60
+ if mode not in ["crossattn" , "concat" ]:
61
+ raise NotImplementedError (f"{ mode } condition is not supported" )
62
+
58
63
noisy_image = self .scheduler .add_noise (original_samples = inputs , noise = noise , timesteps = timesteps )
64
+ if mode == "concat" :
65
+ noisy_image = torch .cat ([noisy_image , condition ], dim = 1 )
66
+ condition = None
59
67
prediction = diffusion_model (x = noisy_image , timesteps = timesteps , context = condition )
60
68
61
69
return prediction
@@ -69,6 +77,7 @@ def sample(
69
77
save_intermediates : bool | None = False ,
70
78
intermediate_steps : int | None = 100 ,
71
79
conditioning : torch .Tensor | None = None ,
80
+ mode : str = "crossattn" ,
72
81
verbose : bool = True ,
73
82
) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
74
83
"""
@@ -79,8 +88,12 @@ def sample(
79
88
save_intermediates: whether to return intermediates along the sampling chain
80
89
intermediate_steps: if save_intermediates is True, saves every n steps
81
90
conditioning: Conditioning for network input.
91
+ mode: Conditioning mode for the network.
82
92
verbose: if true, prints the progression bar of the sampling process.
83
93
"""
94
+ if mode not in ["crossattn" , "concat" ]:
95
+ raise NotImplementedError (f"{ mode } condition is not supported" )
96
+
84
97
if not scheduler :
85
98
scheduler = self .scheduler
86
99
image = input_noise
@@ -91,9 +104,15 @@ def sample(
91
104
intermediates = []
92
105
for t in progress_bar :
93
106
# 1. predict noise model_output
94
- model_output = diffusion_model (
95
- image , timesteps = torch .Tensor ((t ,)).to (input_noise .device ), context = conditioning
96
- )
107
+ if mode == "concat" :
108
+ model_input = torch .cat ([image , conditioning ], dim = 1 )
109
+ model_output = diffusion_model (
110
+ model_input , timesteps = torch .Tensor ((t ,)).to (input_noise .device ), context = None
111
+ )
112
+ else :
113
+ model_output = diffusion_model (
114
+ image , timesteps = torch .Tensor ((t ,)).to (input_noise .device ), context = conditioning
115
+ )
97
116
98
117
# 2. compute previous image: x_t -> x_t-1
99
118
image , _ = scheduler .step (model_output , t , image )
@@ -112,6 +131,7 @@ def get_likelihood(
112
131
scheduler : Callable [..., torch .Tensor ] | None = None ,
113
132
save_intermediates : bool | None = False ,
114
133
conditioning : torch .Tensor | None = None ,
134
+ mode : str = "crossattn" ,
115
135
original_input_range : tuple | None = (0 , 255 ),
116
136
scaled_input_range : tuple | None = (0 , 1 ),
117
137
verbose : bool = True ,
@@ -125,6 +145,7 @@ def get_likelihood(
125
145
scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
126
146
save_intermediates: save the intermediate spatial KL maps
127
147
conditioning: Conditioning for network input.
148
+ mode: Conditioning mode for the network.
128
149
original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
129
150
scaled_input_range: the [min,max] intensity range of the input data after scaling.
130
151
verbose: if true, prints the progression bar of the sampling process.
@@ -137,6 +158,8 @@ def get_likelihood(
137
158
f"Likelihood computation is only compatible with DDPMScheduler,"
138
159
f" you are using { scheduler ._get_name ()} "
139
160
)
161
+ if mode not in ["crossattn" , "concat" ]:
162
+ raise NotImplementedError (f"{ mode } condition is not supported" )
140
163
if verbose and has_tqdm :
141
164
progress_bar = tqdm (scheduler .timesteps )
142
165
else :
@@ -147,7 +170,11 @@ def get_likelihood(
147
170
for t in progress_bar :
148
171
timesteps = torch .full (inputs .shape [:1 ], t , device = inputs .device ).long ()
149
172
noisy_image = self .scheduler .add_noise (original_samples = inputs , noise = noise , timesteps = timesteps )
150
- model_output = diffusion_model (x = noisy_image , timesteps = timesteps , context = conditioning )
173
+ if mode == "concat" :
174
+ noisy_image = torch .cat ([noisy_image , conditioning ], dim = 1 )
175
+ model_output = diffusion_model (noisy_image , timesteps = timesteps , context = None )
176
+ else :
177
+ model_output = diffusion_model (x = noisy_image , timesteps = timesteps , context = conditioning )
151
178
# get the model's predicted mean, and variance if it is predicted
152
179
if model_output .shape [1 ] == inputs .shape [1 ] * 2 and scheduler .variance_type in ["learned" , "learned_range" ]:
153
180
model_output , predicted_variance = torch .split (model_output , inputs .shape [1 ], dim = 1 )
@@ -290,6 +317,7 @@ def __call__(
290
317
noise : torch .Tensor ,
291
318
timesteps : torch .Tensor ,
292
319
condition : torch .Tensor | None = None ,
320
+ mode : str = "crossattn" ,
293
321
) -> torch .Tensor :
294
322
"""
295
323
Implements the forward pass for a supervised training iteration.
@@ -301,12 +329,18 @@ def __call__(
301
329
noise: random noise, of the same shape as the latent representation.
302
330
timesteps: random timesteps.
303
331
condition: conditioning for network input.
332
+ mode: Conditioning mode for the network.
304
333
"""
305
334
with torch .no_grad ():
306
335
latent = autoencoder_model .encode_stage_2_inputs (inputs ) * self .scale_factor
307
336
308
337
prediction = super ().__call__ (
309
- inputs = latent , diffusion_model = diffusion_model , noise = noise , timesteps = timesteps , condition = condition
338
+ inputs = latent ,
339
+ diffusion_model = diffusion_model ,
340
+ noise = noise ,
341
+ timesteps = timesteps ,
342
+ condition = condition ,
343
+ mode = mode ,
310
344
)
311
345
312
346
return prediction
@@ -321,6 +355,7 @@ def sample(
321
355
save_intermediates : bool | None = False ,
322
356
intermediate_steps : int | None = 100 ,
323
357
conditioning : torch .Tensor | None = None ,
358
+ mode : str = "crossattn" ,
324
359
verbose : bool = True ,
325
360
) -> torch .Tensor | tuple [torch .Tensor , list [torch .Tensor ]]:
326
361
"""
@@ -332,6 +367,7 @@ def sample(
332
367
save_intermediates: whether to return intermediates along the sampling chain
333
368
intermediate_steps: if save_intermediates is True, saves every n steps
334
369
conditioning: Conditioning for network input.
370
+ mode: Conditioning mode for the network.
335
371
verbose: if true, prints the progression bar of the sampling process.
336
372
"""
337
373
outputs = super ().sample (
@@ -341,6 +377,7 @@ def sample(
341
377
save_intermediates = save_intermediates ,
342
378
intermediate_steps = intermediate_steps ,
343
379
conditioning = conditioning ,
380
+ mode = mode ,
344
381
verbose = verbose ,
345
382
)
346
383
@@ -369,6 +406,7 @@ def get_likelihood(
369
406
scheduler : Callable [..., torch .Tensor ] | None = None ,
370
407
save_intermediates : bool | None = False ,
371
408
conditioning : torch .Tensor | None = None ,
409
+ mode : str = "crossattn" ,
372
410
original_input_range : tuple | None = (0 , 255 ),
373
411
scaled_input_range : tuple | None = (0 , 1 ),
374
412
verbose : bool = True ,
@@ -385,6 +423,7 @@ def get_likelihood(
385
423
scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
386
424
save_intermediates: save the intermediate spatial KL maps
387
425
conditioning: Conditioning for network input.
426
+ mode: Conditioning mode for the network.
388
427
original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
389
428
scaled_input_range: the [min,max] intensity range of the input data after scaling.
390
429
verbose: if true, prints the progression bar of the sampling process.
@@ -404,6 +443,7 @@ def get_likelihood(
404
443
scheduler = scheduler ,
405
444
save_intermediates = save_intermediates ,
406
445
conditioning = conditioning ,
446
+ mode = mode ,
407
447
verbose = verbose ,
408
448
)
409
449
if save_intermediates and resample_latent_likelihoods :
0 commit comments