Skip to content
This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit c33be6e

Browse files
authored
Support concat-based conditioning in inferer (#416)
* add concat condition support
* add concat condition support in get_likelihood
* fix format
* fix format
* add unittests and fix bug
1 parent 72f0e5d commit c33be6e

File tree

3 files changed

+181
-5
lines changed

3 files changed

+181
-5
lines changed

generative/inferers/inferer.py

+45-5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __call__(
4444
noise: torch.Tensor,
4545
timesteps: torch.Tensor,
4646
condition: torch.Tensor | None = None,
47+
mode: str = "crossattn",
4748
) -> torch.Tensor:
4849
"""
4950
Implements the forward pass for a supervised training iteration.
@@ -54,8 +55,15 @@ def __call__(
5455
noise: random noise, of the same shape as the input.
5556
timesteps: random timesteps.
5657
condition: Conditioning for network input.
58+
mode: Conditioning mode for the network.
5759
"""
60+
if mode not in ["crossattn", "concat"]:
61+
raise NotImplementedError(f"{mode} condition is not supported")
62+
5863
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
64+
if mode == "concat":
65+
noisy_image = torch.cat([noisy_image, condition], dim=1)
66+
condition = None
5967
prediction = diffusion_model(x=noisy_image, timesteps=timesteps, context=condition)
6068

6169
return prediction
@@ -69,6 +77,7 @@ def sample(
6977
save_intermediates: bool | None = False,
7078
intermediate_steps: int | None = 100,
7179
conditioning: torch.Tensor | None = None,
80+
mode: str = "crossattn",
7281
verbose: bool = True,
7382
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
7483
"""
@@ -79,8 +88,12 @@ def sample(
7988
save_intermediates: whether to return intermediates along the sampling change
8089
intermediate_steps: if save_intermediates is True, saves every n steps
8190
conditioning: Conditioning for network input.
91+
mode: Conditioning mode for the network.
8292
verbose: if true, prints the progression bar of the sampling process.
8393
"""
94+
if mode not in ["crossattn", "concat"]:
95+
raise NotImplementedError(f"{mode} condition is not supported")
96+
8497
if not scheduler:
8598
scheduler = self.scheduler
8699
image = input_noise
@@ -91,9 +104,15 @@ def sample(
91104
intermediates = []
92105
for t in progress_bar:
93106
# 1. predict noise model_output
94-
model_output = diffusion_model(
95-
image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
96-
)
107+
if mode == "concat":
108+
model_input = torch.cat([image, conditioning], dim=1)
109+
model_output = diffusion_model(
110+
model_input, timesteps=torch.Tensor((t,)).to(input_noise.device), context=None
111+
)
112+
else:
113+
model_output = diffusion_model(
114+
image, timesteps=torch.Tensor((t,)).to(input_noise.device), context=conditioning
115+
)
97116

98117
# 2. compute previous image: x_t -> x_t-1
99118
image, _ = scheduler.step(model_output, t, image)
@@ -112,6 +131,7 @@ def get_likelihood(
112131
scheduler: Callable[..., torch.Tensor] | None = None,
113132
save_intermediates: bool | None = False,
114133
conditioning: torch.Tensor | None = None,
134+
mode: str = "crossattn",
115135
original_input_range: tuple | None = (0, 255),
116136
scaled_input_range: tuple | None = (0, 1),
117137
verbose: bool = True,
@@ -125,6 +145,7 @@ def get_likelihood(
125145
scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
126146
save_intermediates: save the intermediate spatial KL maps
127147
conditioning: Conditioning for network input.
148+
mode: Conditioning mode for the network.
128149
original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
129150
scaled_input_range: the [min,max] intensity range of the input data after scaling.
130151
verbose: if true, prints the progression bar of the sampling process.
@@ -137,6 +158,8 @@ def get_likelihood(
137158
f"Likelihood computation is only compatible with DDPMScheduler,"
138159
f" you are using {scheduler._get_name()}"
139160
)
161+
if mode not in ["crossattn", "concat"]:
162+
raise NotImplementedError(f"{mode} condition is not supported")
140163
if verbose and has_tqdm:
141164
progress_bar = tqdm(scheduler.timesteps)
142165
else:
@@ -147,7 +170,11 @@ def get_likelihood(
147170
for t in progress_bar:
148171
timesteps = torch.full(inputs.shape[:1], t, device=inputs.device).long()
149172
noisy_image = self.scheduler.add_noise(original_samples=inputs, noise=noise, timesteps=timesteps)
150-
model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
173+
if mode == "concat":
174+
noisy_image = torch.cat([noisy_image, conditioning], dim=1)
175+
model_output = diffusion_model(noisy_image, timesteps=timesteps, context=None)
176+
else:
177+
model_output = diffusion_model(x=noisy_image, timesteps=timesteps, context=conditioning)
151178
# get the model's predicted mean, and variance if it is predicted
152179
if model_output.shape[1] == inputs.shape[1] * 2 and scheduler.variance_type in ["learned", "learned_range"]:
153180
model_output, predicted_variance = torch.split(model_output, inputs.shape[1], dim=1)
@@ -290,6 +317,7 @@ def __call__(
290317
noise: torch.Tensor,
291318
timesteps: torch.Tensor,
292319
condition: torch.Tensor | None = None,
320+
mode: str = "crossattn",
293321
) -> torch.Tensor:
294322
"""
295323
Implements the forward pass for a supervised training iteration.
@@ -301,12 +329,18 @@ def __call__(
301329
noise: random noise, of the same shape as the latent representation.
302330
timesteps: random timesteps.
303331
condition: conditioning for network input.
332+
mode: Conditioning mode for the network.
304333
"""
305334
with torch.no_grad():
306335
latent = autoencoder_model.encode_stage_2_inputs(inputs) * self.scale_factor
307336

308337
prediction = super().__call__(
309-
inputs=latent, diffusion_model=diffusion_model, noise=noise, timesteps=timesteps, condition=condition
338+
inputs=latent,
339+
diffusion_model=diffusion_model,
340+
noise=noise,
341+
timesteps=timesteps,
342+
condition=condition,
343+
mode=mode,
310344
)
311345

312346
return prediction
@@ -321,6 +355,7 @@ def sample(
321355
save_intermediates: bool | None = False,
322356
intermediate_steps: int | None = 100,
323357
conditioning: torch.Tensor | None = None,
358+
mode: str = "crossattn",
324359
verbose: bool = True,
325360
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
326361
"""
@@ -332,6 +367,7 @@ def sample(
332367
save_intermediates: whether to return intermediates along the sampling change
333368
intermediate_steps: if save_intermediates is True, saves every n steps
334369
conditioning: Conditioning for network input.
370+
mode: Conditioning mode for the network.
335371
verbose: if true, prints the progression bar of the sampling process.
336372
"""
337373
outputs = super().sample(
@@ -341,6 +377,7 @@ def sample(
341377
save_intermediates=save_intermediates,
342378
intermediate_steps=intermediate_steps,
343379
conditioning=conditioning,
380+
mode=mode,
344381
verbose=verbose,
345382
)
346383

@@ -369,6 +406,7 @@ def get_likelihood(
369406
scheduler: Callable[..., torch.Tensor] | None = None,
370407
save_intermediates: bool | None = False,
371408
conditioning: torch.Tensor | None = None,
409+
mode: str = "crossattn",
372410
original_input_range: tuple | None = (0, 255),
373411
scaled_input_range: tuple | None = (0, 1),
374412
verbose: bool = True,
@@ -385,6 +423,7 @@ def get_likelihood(
385423
scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
386424
save_intermediates: save the intermediate spatial KL maps
387425
conditioning: Conditioning for network input.
426+
mode: Conditioning mode for the network.
388427
original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
389428
scaled_input_range: the [min,max] intensity range of the input data after scaling.
390429
verbose: if true, prints the progression bar of the sampling process.
@@ -404,6 +443,7 @@ def get_likelihood(
404443
scheduler=scheduler,
405444
save_intermediates=save_intermediates,
406445
conditioning=conditioning,
446+
mode=mode,
407447
verbose=verbose,
408448
)
409449
if save_intermediates and resample_latent_likelihoods:

tests/test_diffusion_inferer.py

+56
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,62 @@ def test_normal_cdf(self):
161161
cdf_true = norm.cdf(x)
162162
torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)
163163

164+
@parameterized.expand(TEST_CASES)
165+
def test_sampler_conditioned_concat(self, model_params, input_shape):
166+
# copy the model_params dict to prevent from modifying test cases
167+
model_params = model_params.copy()
168+
n_concat_channel = 2
169+
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
170+
model_params["cross_attention_dim"] = None
171+
model_params["with_conditioning"] = False
172+
model = DiffusionModelUNet(**model_params)
173+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
174+
model.to(device)
175+
model.eval()
176+
noise = torch.randn(input_shape).to(device)
177+
conditioning_shape = list(input_shape)
178+
conditioning_shape[1] = n_concat_channel
179+
conditioning = torch.randn(conditioning_shape).to(device)
180+
scheduler = DDIMScheduler(num_train_timesteps=1000)
181+
inferer = DiffusionInferer(scheduler=scheduler)
182+
scheduler.set_timesteps(num_inference_steps=10)
183+
sample, intermediates = inferer.sample(
184+
input_noise=noise,
185+
diffusion_model=model,
186+
scheduler=scheduler,
187+
save_intermediates=True,
188+
intermediate_steps=1,
189+
conditioning=conditioning,
190+
mode="concat",
191+
)
192+
self.assertEqual(len(intermediates), 10)
193+
194+
@parameterized.expand(TEST_CASES)
195+
def test_call_conditioned_concat(self, model_params, input_shape):
196+
# copy the model_params dict to prevent from modifying test cases
197+
model_params = model_params.copy()
198+
n_concat_channel = 2
199+
model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
200+
model_params["cross_attention_dim"] = None
201+
model_params["with_conditioning"] = False
202+
model = DiffusionModelUNet(**model_params)
203+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
204+
model.to(device)
205+
model.eval()
206+
input = torch.randn(input_shape).to(device)
207+
noise = torch.randn(input_shape).to(device)
208+
conditioning_shape = list(input_shape)
209+
conditioning_shape[1] = n_concat_channel
210+
conditioning = torch.randn(conditioning_shape).to(device)
211+
scheduler = DDPMScheduler(num_train_timesteps=10)
212+
inferer = DiffusionInferer(scheduler=scheduler)
213+
scheduler.set_timesteps(num_inference_steps=10)
214+
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
215+
sample = inferer(
216+
inputs=input, noise=noise, diffusion_model=model, timesteps=timesteps, condition=conditioning, mode="concat"
217+
)
218+
self.assertEqual(sample.shape, input_shape)
219+
164220

165221
if __name__ == "__main__":
166222
unittest.main()

tests/test_latent_diffusion_inferer.py

+80
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,86 @@ def test_resample_likelihoods(self, model_type, autoencoder_params, stage_2_para
245245
self.assertEqual(len(intermediates), 10)
246246
self.assertEqual(intermediates[0].shape[2:], input_shape[2:])
247247

248+
@parameterized.expand(TEST_CASES)
249+
def test_prediction_shape_conditioned_concat(
250+
self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape
251+
):
252+
if model_type == "AutoencoderKL":
253+
stage_1 = AutoencoderKL(**autoencoder_params)
254+
if model_type == "VQVAE":
255+
stage_1 = VQVAE(**autoencoder_params)
256+
257+
stage_2_params = stage_2_params.copy()
258+
n_concat_channel = 3
259+
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
260+
stage_2 = DiffusionModelUNet(**stage_2_params)
261+
262+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
263+
stage_1.to(device)
264+
stage_2.to(device)
265+
stage_1.eval()
266+
stage_2.eval()
267+
268+
input = torch.randn(input_shape).to(device)
269+
noise = torch.randn(latent_shape).to(device)
270+
conditioning_shape = list(latent_shape)
271+
conditioning_shape[1] = n_concat_channel
272+
conditioning = torch.randn(conditioning_shape).to(device)
273+
274+
scheduler = DDPMScheduler(num_train_timesteps=10)
275+
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
276+
scheduler.set_timesteps(num_inference_steps=10)
277+
278+
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
279+
prediction = inferer(
280+
inputs=input,
281+
autoencoder_model=stage_1,
282+
diffusion_model=stage_2,
283+
noise=noise,
284+
timesteps=timesteps,
285+
condition=conditioning,
286+
mode="concat",
287+
)
288+
self.assertEqual(prediction.shape, latent_shape)
289+
290+
@parameterized.expand(TEST_CASES)
291+
def test_sample_shape_conditioned_concat(
292+
self, model_type, autoencoder_params, stage_2_params, input_shape, latent_shape
293+
):
294+
if model_type == "AutoencoderKL":
295+
stage_1 = AutoencoderKL(**autoencoder_params)
296+
if model_type == "VQVAE":
297+
stage_1 = VQVAE(**autoencoder_params)
298+
stage_2_params = stage_2_params.copy()
299+
n_concat_channel = 3
300+
stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
301+
stage_2 = DiffusionModelUNet(**stage_2_params)
302+
303+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
304+
stage_1.to(device)
305+
stage_2.to(device)
306+
stage_1.eval()
307+
stage_2.eval()
308+
309+
noise = torch.randn(latent_shape).to(device)
310+
conditioning_shape = list(latent_shape)
311+
conditioning_shape[1] = n_concat_channel
312+
conditioning = torch.randn(conditioning_shape).to(device)
313+
314+
scheduler = DDPMScheduler(num_train_timesteps=10)
315+
inferer = LatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
316+
scheduler.set_timesteps(num_inference_steps=10)
317+
318+
sample = inferer.sample(
319+
input_noise=noise,
320+
autoencoder_model=stage_1,
321+
diffusion_model=stage_2,
322+
scheduler=scheduler,
323+
conditioning=conditioning,
324+
mode="concat",
325+
)
326+
self.assertEqual(sample.shape, input_shape)
327+
248328

249329
if __name__ == "__main__":
250330
unittest.main()

0 commit comments

Comments
 (0)