Skip to content

Commit

Permalink
Add doc about attention_mask on gpt2 (#16829)
Browse files Browse the repository at this point in the history
* Add doc about `attention_mask` on gpt2

Add a simple sentence describing how `attention_mask` needs to be constructed when ``past_key_values` is used.

* Add doc about attention_mask on gpt2_tf

* clean up style

* remove empty line white spaces

* remove whitespace in empty line
  • Loading branch information
wiio12 authored Apr 19, 2022
1 parent b96e82c commit 7481457
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/transformers/models/gpt2/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,10 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
`past_key_values`. In other words, the `attention_mask` always has to have the length:
`len(past_key_values) + len(input_ids)`
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/models/gpt2/modeling_tf_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,10 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
`past_key_values`. In other words, the `attention_mask` always has to have the length:
`len(past_key_values) + len(input_ids)`
[What are attention masks?](../glossary#attention-mask)
token_type_ids (`tf.Tensor` or `Numpy array` of shape `(batch_size, sequence_length)`, *optional*):
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
Expand Down

0 comments on commit 7481457

Please # to comment.