From 8c489cb82d7bd236a8603965fa043cb64e972de1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Mon, 2 Dec 2024 15:17:39 +0800 Subject: [PATCH] Avoid out-of-bound position ids with left padded generation input --- lib/bumblebee/text/generation.ex | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 274b56ad..42d50e62 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -290,7 +290,11 @@ defmodule Bumblebee.Text.Generation do position_ids = attention_mask |> Nx.cumulative_sum(axis: 1) - |> Nx.subtract(1) + # Position ids are zero-indexed, so we want to subtract 1. + # However, attention mask may have zeros on the left, in which + # case cumulative sum leaves zeros there as well and we don't + # want to subtract from these. + |> Nx.subtract(Nx.select(attention_mask, 1, 0)) inputs = inputs