Avoid out-of-bound position ids with left padded generation input
jonatanklosko committed Dec 2, 2024
1 parent b30ad82 commit 8c489cb
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion lib/bumblebee/text/generation.ex
@@ -290,7 +290,11 @@ defmodule Bumblebee.Text.Generation do
       position_ids =
         attention_mask
         |> Nx.cumulative_sum(axis: 1)
-        |> Nx.subtract(1)
+        # Position ids are zero-indexed, so we want to subtract 1.
+        # However, attention mask may have zeros on the left, in which
+        # case cumulative sum leaves zeros there as well and we don't
+        # want to subtract from these.
+        |> Nx.subtract(Nx.select(attention_mask, 1, 0))

       inputs =
         inputs
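
To illustrate the effect of the change, here is a minimal sketch that applies both versions of the computation to a small, made-up attention mask. The mask values, the variable names `cumulative`, `old_ids` and `new_ids`, and the Nx version requirement passed to `Mix.install` are assumptions for illustration and do not come from the commit; only the `Nx.cumulative_sum`, `Nx.subtract` and `Nx.select` calls mirror the diff.

```elixir
# Assumption: Nx is fetched via Mix.install; the version requirement is illustrative.
Mix.install([{:nx, "~> 0.9"}])

# Hypothetical batch of two sequences of length 5.
# The first sequence is left-padded with two tokens (mask value 0).
attention_mask = Nx.tensor([[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]])

# The cumulative sum along the sequence axis counts attended tokens so far.
# Left-padded positions stay at 0 because nothing has been attended yet.
cumulative = Nx.cumulative_sum(attention_mask, axis: 1)

# Before the commit: subtract 1 everywhere, so padded positions become -1.
old_ids = Nx.subtract(cumulative, 1)

# After the commit: subtract 1 only where the mask is non-zero,
# so padded positions keep position id 0.
new_ids = Nx.subtract(cumulative, Nx.select(attention_mask, 1, 0))

IO.inspect(Nx.to_list(old_ids), label: "before", charlists: :as_lists)
# before: [[-1, -1, 0, 1, 2], [0, 1, 2, 3, 4]]

IO.inspect(Nx.to_list(new_ids), label: "after", charlists: :as_lists)
# after: [[0, 0, 0, 1, 2], [0, 1, 2, 3, 4]]
```

In this sketch the unpadded second row is identical under both versions. Only the left-padded positions differ: they move from -1, which is out of bounds as a position id, to 0.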