diff --git a/lib/bumblebee/text/m2m100.ex b/lib/bumblebee/text/m2m100.ex index 98503abf..5f064741 100644 --- a/lib/bumblebee/text/m2m100.ex +++ b/lib/bumblebee/text/m2m100.ex @@ -437,6 +437,8 @@ defmodule Bumblebee.Text.M2m100 do end defnp sinusoidal_position_embedding_impl(position_ids, opts \\ []) do + position_ids = Nx.vectorize(position_ids, :batch) + size = opts[:size] half_size = div(size, 2) @@ -444,7 +446,9 @@ defmodule Bumblebee.Text.M2m100 do range = Nx.iota({half_size}) / (half_size - 1) inv_frequency = 1 / Nx.pow(base, range) angle = Nx.outer(position_ids, inv_frequency) - Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1) + sin_cos = Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1) + + Nx.devectorize(sin_cos) end defp decoder(