Tokenizers setter of ids of special tokens don't work #16660

Closed
2 of 4 tasks
davidleonfdez opened this issue Apr 8, 2022 · 1 comment · Fixed by #16661
Comments

@davidleonfdez
Contributor

Environment info

  • transformers version:
  • Platform:
  • Python version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?):
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help

Information

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The tasks I am working on are:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Create an instance of a pretrained tokenizer
  2. Try to set the pad_token_id

For instance:

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token_id = tokenizer.eos_token_id

Output:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_33/1516894257.py in <module>
      1 tokenizer = AutoTokenizer.from_pretrained('gpt2')
----> 2 tokenizer.pad_token_id = tokenizer.eos_token_id

/opt/conda/lib/python3.7/site-packages/transformers/tokenization_utils_base.py in pad_token_id(self, value)
   1173     @pad_token_id.setter
   1174     def pad_token_id(self, value):
-> 1175         self._pad_token = self.convert_tokens_to_ids(value)
   1176 
   1177     @cls_token_id.setter

/opt/conda/lib/python3.7/site-packages/transformers/tokenization_utils_fast.py in convert_tokens_to_ids(self, tokens)
    248 
    249         ids = []
--> 250         for token in tokens:
    251             ids.append(self._convert_token_to_id_with_added_voc(token))
    252         return ids

TypeError: 'int' object is not iterable

Expected behavior

Set the pad_token appropriately.

I've fixed this in a branch and I'm submitting a PR.
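For reference, here is a minimal, self-contained sketch of the failure mode (a hypothetical `TokenizerSketch` class, not the actual transformers code): `convert_tokens_to_ids` iterates over its argument, so passing a bare integer id raises the `TypeError` shown in the traceback. The "fixed" setter below is only one possible guard, not necessarily the approach taken in PR #16661.

```python
# Hypothetical sketch of the bug; not the real transformers implementation.

class TokenizerSketch:
    def __init__(self, vocab):
        self.vocab = vocab                      # token string -> id
        self._pad_token = None                  # stored as a token string

    def convert_tokens_to_ids(self, tokens):
        # Mirrors the fast-tokenizer behaviour from the traceback:
        # a single token string is handled, anything else is iterated.
        if isinstance(tokens, str):
            return self.vocab[tokens]
        return [self.vocab[t] for t in tokens]  # int input fails here

    def set_pad_token_id_buggy(self, value):
        # Pre-fix setter: assumes `value` is a token (string), so an
        # integer id triggers "TypeError: 'int' object is not iterable".
        self._pad_token = self.convert_tokens_to_ids(value)

    def set_pad_token_id_fixed(self, value):
        # One possible guard: map the id back to its token string.
        id_to_token = {i: t for t, i in self.vocab.items()}
        self._pad_token = id_to_token[value]
```

With a toy vocab such as `{"<eos>": 50256}`, calling `set_pad_token_id_buggy(50256)` raises `TypeError`, while `set_pad_token_id_fixed(50256)` stores the token string `"<eos>"`.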

@SaulLu
Contributor

SaulLu commented Apr 8, 2022

Thank you very much for sharing this problem (and a solution)! 🤗
You are right, this behaviour is not desirable.

sgugger pushed a commit that referenced this issue Apr 13, 2022
* Fix setters of *_token_id properties of SpecialTokensMixin

* Test setters of common tokens ids

* Move to a separate test checks of setters of tokens ids

* Add independent test for ByT5

* Add Canine test

* Test speech to text
elusenji pushed a commit to elusenji/transformers that referenced this issue Jun 12, 2022

* Fix setters of *_token_id properties of SpecialTokensMixin (huggingface#16661)