Skip to content

Commit

Permalink
🐛 fix device discrepance in evaluator
Browse files Browse the repository at this point in the history
  • Loading branch information
ferzcam committed Oct 16, 2024
1 parent 4c3a31f commit 3bb953b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions mowl/evaluation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def __init__(self, dataset, device="cpu", batch_size=16):
eval_heads, eval_tails = self.dataset.evaluation_classes

print(f"Number of evaluation classes: {len(eval_heads)}")
self.evaluation_heads = th.tensor([self.class_to_id[c] for c in eval_heads.as_str], dtype=th.long)
self.evaluation_tails = th.tensor([self.class_to_id[c] for c in eval_tails.as_str], dtype=th.long)
self.evaluation_heads = th.tensor([self.class_to_id[c] for c in eval_heads.as_str], dtype=th.long).to(self.device)
self.evaluation_tails = th.tensor([self.class_to_id[c] for c in eval_tails.as_str], dtype=th.long).to(self.device)


@property
Expand All @@ -71,6 +71,7 @@ def evaluate_base(self, model, eval_tuples, mode="test",
filter_deductive_closure=False,
**kwargs):

model = model.to(self.device)
num_heads, num_tails = len(self.evaluation_heads), len(self.evaluation_tails)
model.eval()
if not mode in ["valid", "test"]:
Expand Down

0 comments on commit 3bb953b

Please # to comment.