Skip to content

Commit 61d96c3

Browse files
authored
refactor rotary embedding 3: so it is not on cpu (#9307)
change get_1d_rotary to accept pos as torch tensors
1 parent 4f495b0 commit 61d96c3

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

src/diffusers/models/embeddings.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,14 @@ def get_1d_rotary_pos_embed(
545545
assert dim % 2 == 0
546546

547547
if isinstance(pos, int):
548-
pos = np.arange(pos)
548+
pos = torch.arange(pos)
549+
if isinstance(pos, np.ndarray):
550+
pos = torch.from_numpy(pos) # type: ignore # [S]
551+
549552
theta = theta * ntk_factor
550553
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2]
551-
t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S]
552-
freqs = torch.outer(t, freqs) # type: ignore # [S, D/2]
554+
freqs = freqs.to(pos.device)
555+
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
553556
if use_real and repeat_interleave_real:
554557
# flux, hunyuan-dit, cogvideox
555558
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
@@ -626,7 +629,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
626629
n_axes = ids.shape[-1]
627630
cos_out = []
628631
sin_out = []
629-
pos = ids.squeeze().float().cpu().numpy()
632+
pos = ids.squeeze().float()
630633
is_mps = ids.device.type == "mps"
631634
freqs_dtype = torch.float32 if is_mps else torch.float64
632635
for i in range(n_axes):

0 commit comments

Comments
 (0)