Sorting requirement for __include_gt__ #21

Open
Nifury opened this issue Oct 13, 2023 · 0 comments

Nifury commented Oct 13, 2023

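Hello, I noticed that if the ground-truth source indices `y[0]` are not sorted, `__include_gt__` does not behave properly: the ground-truth targets `y[1]` are not reliably included among the candidate indices `S_idx` of their source nodes.
It might be worth mentioning this requirement in the documentation.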
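One plausible explanation, assuming `__include_gt__` fills the missing ground-truth targets into the last column of `S_idx` with `torch.Tensor.masked_scatter` (an assumption about the implementation, not something stated above): `masked_scatter` copies source values into the `True` positions of the mask in row-major order, so the targets would be assigned in ascending order of source node rather than in the order given by `y[0]`. The hypothetical `fill_last_column` below shows the effect on a plain tensor.

```python
import torch


def fill_last_column(S_idx: torch.Tensor, row: torch.Tensor, col: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: intended to write target col[i] into the last
    column of row row[i] of an [N, k] index tensor, via masked_scatter."""
    mask = torch.zeros_like(S_idx, dtype=torch.bool)
    mask[row, -1] = True
    # masked_scatter consumes `col` in row-major order of `mask`
    # (ascending row index), not in the order given by `row`.
    return S_idx.masked_scatter(mask, col)


S_idx = torch.zeros(6, 2, dtype=torch.long)
row = torch.tensor([2, 0, 1])  # unsorted source indices, like y[0]
col = torch.tensor([3, 4, 5])  # matching targets, like y[1]

out = fill_last_column(S_idx, row, col)
print(out[row, -1])  # tensor([5, 3, 4]): targets land in the wrong rows

# Plain index assignment pairs each row with its own target.
direct = S_idx.clone()
direct[row, -1] = col
print(direct[row, -1])  # tensor([3, 4, 5])
```

With sorted `row`, the two approaches agree, which is consistent with the behavior depending only on whether `y[0]` is sorted.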

Code to reproduce:

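```python
import torch

# Method on an object that exposes `__top_k__` and `__include_gt__`,
# with `self.k` controlling the number of top-k candidates.
def test(self):
    h_s = torch.randn(1, 10, 20)  # source node features
    h_t = torch.randn(1, 10, 20)  # target node features
    s_mask = torch.ones(1, 10, dtype=torch.bool)
    # Ground truth: source y[0][i] matches target y[1][i]; y[0] is NOT sorted.
    y = torch.as_tensor([[2, 0, 1], [3, 4, 5]])
    # Make sure top-k doesn't include the ground truth: ground-truth source
    # and target rows get large features of opposite sign.
    h_s[0, y[0]] = 100
    h_t[0, y[1]] = -100
    self.k = 1
    S_idx = self.__top_k__(h_s, h_t)  # top-1 candidate target per source node
    # Append one extra candidate column (target 0 for every source node).
    S_rnd_idx = torch.zeros(1, 10, 1, dtype=torch.long)
    S_idx = torch.cat([S_idx, S_rnd_idx], dim=-1)
    S_idx = self.__include_gt__(S_idx, s_mask, y)
    # For each ground-truth source node, check whether its target is among
    # its candidates; this is expected to print all True.
    mask = S_idx[0, y[0]] == y[1].view(-1, 1)
    print(mask.any(dim=-1))
```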
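A possible workaround until this is documented, sketched as a hypothetical helper `sort_gt_by_source` (not part of the project): reorder the ground-truth pairs so that `y[0]` is ascending, applying the same permutation to both rows so that each (source, target) pair stays together.

```python
import torch


def sort_gt_by_source(y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reorder the columns of a [2, M] ground-truth
    tensor so that y[0] (source indices) is ascending, keeping pairs intact."""
    perm = y[0].argsort()
    return y[:, perm]


y = torch.as_tensor([[2, 0, 1], [3, 4, 5]])
y_sorted = sort_gt_by_source(y)
print(y_sorted)
# tensor([[0, 1, 2],
#         [4, 5, 3]])
```

Passing `sort_gt_by_source(y)` instead of `y` to `self.__include_gt__` in the reproduction should then make every printed entry `True`, provided the unsorted `y[0]` is the only cause.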