From 29bdfa48cfa6911ae24bbeaeeb6e9f892eb01851 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Tue, 19 Nov 2024 17:15:03 +0800 Subject: [PATCH] Fix M2M100 with batched input --- lib/bumblebee/text/m2m100.ex | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/bumblebee/text/m2m100.ex b/lib/bumblebee/text/m2m100.ex index 98503abf..5f064741 100644 --- a/lib/bumblebee/text/m2m100.ex +++ b/lib/bumblebee/text/m2m100.ex @@ -437,6 +437,8 @@ defmodule Bumblebee.Text.M2m100 do end defnp sinusoidal_position_embedding_impl(position_ids, opts \\ []) do + position_ids = Nx.vectorize(position_ids, :batch) + size = opts[:size] half_size = div(size, 2) @@ -444,7 +446,9 @@ defmodule Bumblebee.Text.M2m100 do range = Nx.iota({half_size}) / (half_size - 1) inv_frequency = 1 / Nx.pow(base, range) angle = Nx.outer(position_ids, inv_frequency) - Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1) + sin_cos = Nx.concatenate([Nx.sin(angle), Nx.cos(angle)], axis: -1) + + Nx.devectorize(sin_cos) end defp decoder(