diff --git a/lib/bumblebee/text/generation.ex b/lib/bumblebee/text/generation.ex index 274b56ad..42d50e62 100644 --- a/lib/bumblebee/text/generation.ex +++ b/lib/bumblebee/text/generation.ex @@ -290,7 +290,11 @@ defmodule Bumblebee.Text.Generation do position_ids = attention_mask |> Nx.cumulative_sum(axis: 1) - |> Nx.subtract(1) + # Position ids are zero-indexed, so we want to subtract 1. + # However, attention mask may have zeros on the left, in which + # case cumulative sum leaves zeros there as well and we don't + # want to subtract from these. + |> Nx.subtract(Nx.select(attention_mask, 1, 0)) inputs = inputs