diff --git a/tokenlearn/pretrain.py b/tokenlearn/pretrain.py index 4d70d12..0871aaf 100644 --- a/tokenlearn/pretrain.py +++ b/tokenlearn/pretrain.py @@ -179,10 +179,11 @@ def train_supervised( # noqa: C901 train_dataloader = train_dataset.to_dataloader(shuffle=True, batch_size=batch_size) # Initialize the model + pad_id = model.tokenizer.token_to_id("[PAD]") or model.tokenizer.token_to_id("") trainable_model = StaticModelFineTuner( torch.from_numpy(model.embedding), out_dim=train_dataset.targets.shape[1], - pad_id=model.tokenizer.token_to_id("[PAD]"), + pad_id=pad_id, ) trainable_model.to(device)