
AssertionError assert(len(tmp_ids) == 1) when using RoBERTa #5

Open
steveazzolin opened this issue Jun 9, 2023 · 0 comments

@steveazzolin

Hi,

I was trying to run `scripts/run_optiprompt.sh` with the model `roberta-base` when I encountered the following error:

```
Traceback (most recent call last):
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/run_optiprompt.py", line 176, in <module>
    best_result, result_rel = evaluate(model, valid_samples_batches, valid_sentences_batches, filter_indices, index_list)
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/utils.py", line 110, in evaluate
    log_probs, cor_b, tot_b, pred_b, topk_preds, loss, common_vocab_loss = model.run_batch(sentences_b, samples_b, training=False, filter_indices=filter_indices, index_list=index_list, vocab_to_common_vocab=vocab_to_common_vocab)
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/models.py", line 345, in run_batch
    tokens_tensor, segments_tensor, attention_mask_tensor, masked_indices_list, tokenized_text_list, mlm_labels_tensor, mlm_label_ids = self._get_input_tensors_batch_train(sentences_list, samples_list)
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/models.py", line 151, in _get_input_tensors_batch_train
    tokens_tensor, segments_tensor, masked_indices, tokenized_text, mlm_labels_tensor, mlm_label_id = self.__get_input_tensors(sentences, mlm_label=samples['obj_label'])
  File "/nfs/data_chaos/sazzolin/OptiPrompt/code/models.py", line 298, in __get_input_tensors
    assert(len(tmp_ids) == 1)
AssertionError
```

The same code, however, works fine when using `bert-base-cased`.
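The failing line asserts that the object label maps to exactly one token id. One possible (unconfirmed) reason it passes with `bert-base-cased` but fails with `roberta-base` is that RoBERTa uses a byte-level BPE vocabulary in which word-initial tokens carry a leading-space marker (`Ġ`), so a label tokenized without a preceding space can split into several pieces. The sketch below illustrates only that effect, assuming a tiny hypothetical vocabulary and a greedy longest-match tokenizer as a stand-in for real BPE merges. It is not the real RoBERTa tokenizer or OptiPrompt's code; `TOY_VOCAB`, `toy_tokenize`, and `single_token_id` are made-up names.

```python
# Hypothetical toy vocabulary mimicking the byte-level BPE convention where
# "Ġ" marks a token that begins with a space. NOT the real roberta-base vocab.
TOY_VOCAB = {"ĠParis": 0, "P": 1, "aris": 2, "ĠRome": 3, "R": 4, "ome": 5}


def toy_tokenize(text: str) -> list[str]:
    """Greedy longest-match tokenizer over TOY_VOCAB (a stand-in, not real BPE)."""
    text = text.replace(" ", "Ġ")  # encode spaces the way byte-level BPE displays them
    tokens = []
    i = 0
    while i < len(text):
        # Try the longest substring starting at i that exists in the vocabulary.
        for j in range(len(text), i, -1):
            if text[i:j] in TOY_VOCAB:
                tokens.append(text[i:j])
                i = j
                break
        else:
            raise ValueError(f"cannot tokenize {text[i:]!r}")
    return tokens


def single_token_id(label: str) -> int:
    """Mirror the failing check: the label must map to exactly one token id."""
    tmp_ids = [TOY_VOCAB[t] for t in toy_tokenize(label)]
    assert len(tmp_ids) == 1, f"{label!r} -> {toy_tokenize(label)}"
    return tmp_ids[0]
```

With this toy vocabulary, `single_token_id(" Paris")` returns `0`, while `single_token_id("Paris")` raises `AssertionError` because the label splits into `["P", "aris"]`.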

This is the complete Python command:

```
python code/run_optiprompt.py \
            --relation_profile relation_metainfo/LAMA_relations.jsonl \
            --relation ${REL} \
            --common_vocab_filename common_vocabs/common_vocab_cased.txt \
            --model_name roberta_base \
            --do_train \
            --train_data data/autoprompt_data/${REL}/train.jsonl \
            --dev_data data/autoprompt_data/${REL}/dev.jsonl \
            --do_eval \
            --test_data data/LAMA-TREx/${REL}.jsonl \
            --output_dir ${DIR} \
            --random_init none \
            --seed ${SEED} \
            --output_predictions
```