Skip to content

Commit 84e94d7

Browse files
committed
match model forward api
1 parent 8d1a17c commit 84e94d7

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/diffusers/models/unet_rl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,15 +130,15 @@ def __init__(
130130
nn.Conv1d(dim, transition_dim, 1),
131131
)
132132

133-
def forward(self, sample, timesteps):
133+
def forward(self, sample, timestep):
134134
"""
135135
x : [ batch x horizon x transition ]
136136
"""
137137
x = sample
138138

139139
x = x.permute(0, 2, 1)
140140

141-
t = self.time_mlp(timesteps)
141+
t = self.time_mlp(timestep)
142142
h = []
143143

144144
for resnet, resnet2, downsample in self.downs:

tests/test_modeling_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ def dummy_input(self):
585585
noise = floats_tensor((batch_size, seq_len, num_features)).to(torch_device)
586586
time_step = torch.tensor([10] * batch_size).to(torch_device)
587587

588-
return {"sample": noise, "timesteps": time_step}
588+
return {"sample": noise, "timestep": time_step}
589589

590590
@property
591591
def input_shape(self):

0 commit comments

Comments
 (0)