
Label-wise metrics (Accuracy etc.) for multi-label problems #513

Open
jphdotam opened this issue May 2, 2019 · 4 comments · May be fixed by #542

Comments


jphdotam commented May 2, 2019

Hi,

I've made a multi-label classifier using BCEWithLogitsLoss. In summary, each data sample can belong to any of 3 binary classes, which aren't mutually exclusive, so y_pred and y can each look something like [0, 1, 1].

My metrics include Accuracy(output_transform=thresholded_output_transform, is_multilabel=True) and Precision(output_transform=thresholded_output_transform, is_multilabel=True, average=True).
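
(thresholded_output_transform is not shown here; a minimal sketch of such a transform, assuming the engine's output is a (y_pred, y) pair where y_pred holds raw logits for BCEWithLogitsLoss, might look like this:)

import torch

# Turn raw logits into hard {0, 1} predictions before the metric sees them.
def thresholded_output_transform(output):
    y_pred, y = output
    y_pred = torch.round(torch.sigmoid(y_pred))
    return y_pred, y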

However, I'm interested in having label-specific metrics (i.e. 3 accuracies, etc.). This is important because it allows me to see which labels are compromising my overall accuracy the most (a 70% accuracy could be a 30% error concentrated in a single label, or a more modest error scattered across the 3 labels).

There is no option to disable averaging for Accuracy() as there is for the others, and setting average=False for Precision() does not do what I expected (it yields a binary result per datum, not per label, so I end up with a tensor of size 500, not 3, if my dataset has n=500).

Is there a way to get label-wise metrics in multi-label problems? Or a plan to introduce one?

P.S. I'd love to get an invite to the Slack workspace if possible. How do I go about doing that?

jphdotam changed the title from "Class-wise metrics (Accuracy etc.) for multi-label problems" to "Label-wise metrics (Accuracy etc.) for multi-label problems" on May 2, 2019
vfdev-5 (Collaborator) commented May 2, 2019

@jphdotam thanks for the feedback! You are correct, the multi-label case is always averaged for now for Accuracy, Precision, and Recall.

Is there a way to get label-wise metrics in multi-label problems? Or a plan to introduce one?

There is an issue with a similar requirement: #467
At the moment we do not have much bandwidth to work on that. If you can send a PR for it, that would be awesome.

P.S. I'd love to get an invite to the Slack workspace if possible. How do I go about doing that?

You can find a link for that here: https://pytorch.org/resources

jphdotam commented May 2, 2019

Many thanks, I've made a pull request here: #516

I'm quite new to working on large projects so apologies if I have gone about this inappropriately.

jphdotam commented May 3, 2019

In the meantime, while the core team decides how best to implement this, here is a custom class I've made for the task, which inherits from Accuracy:

import torch

from ignite.exceptions import NotComputableError
from ignite.metrics import Accuracy


class LabelwiseAccuracy(Accuracy):
    def __init__(self, output_transform=lambda x: x):
        self._num_correct = None
        self._num_examples = None
        super(LabelwiseAccuracy, self).__init__(output_transform=output_transform)

    def reset(self):
        # Per-label correct counts start as None so the first update() can size them.
        self._num_correct = None
        self._num_examples = 0
        super(LabelwiseAccuracy, self).reset()

    def update(self, output):

        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        # Flatten everything except the label dimension so both tensors are (n_samples, n_labels)
        num_classes = y_pred.size(1)
        last_dim = y_pred.ndimension()
        y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
        y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)
        correct_exact = torch.all(y == y_pred.type_as(y), dim=-1)  # sample-wise exact match
        correct_elementwise = torch.sum(y == y_pred.type_as(y), dim=0)  # per-label correct counts

        if self._num_correct is not None:
            self._num_correct = torch.add(self._num_correct, correct_elementwise)
        else:
            self._num_correct = correct_elementwise
        self._num_examples += correct_exact.shape[0]

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('Accuracy must have at least one example before it can be computed.')
        return self._num_correct.type(torch.float) / self._num_examples
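
If it helps, here is a minimal usage sketch (model, val_loader and the metric names are illustrative; it reuses the imports and the thresholded_output_transform from earlier):

from ignite.engine import create_supervised_evaluator

# Attach the label-wise metric alongside the built-in averaged Accuracy.
evaluator = create_supervised_evaluator(
    model,
    metrics={
        "avg_accuracy": Accuracy(output_transform=thresholded_output_transform, is_multilabel=True),
        "labelwise_accuracy": LabelwiseAccuracy(output_transform=thresholded_output_transform),
    },
)
evaluator.run(val_loader)
print(evaluator.state.metrics["labelwise_accuracy"])  # one accuracy per label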

crypdick commented Mar 7, 2020

For anyone trying to use @jphdotam's code in #513 (comment):

y_pred, y = self._check_shape(output)

throws an exception because that function now returns nothing. Instead, use

self._check_shape(output)
y_pred, y = output

However, there's something wrong with it, because I'm getting 'labelwise_accuracy': [0.9070000648498535, 0.8530000448226929, 0.8370000123977661, 0.7450000643730164, 0.8720000386238098, 0.7570000290870667, 0.9860000610351562, 0.9190000295639038, 0.8740000128746033] when 'avg_accuracy' is 0.285.

Edit: never mind, I stepped through the code and it was fine. The bug was on my end. Cheers!
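
Putting that fix in context, the start of update() in the class above would read roughly as follows under the newer ignite API (a sketch; the rest of the method stays as posted):

    def update(self, output):
        # Newer ignite: _check_shape validates in place and returns nothing,
        # so unpack the (y_pred, y) pair from the output directly.
        self._check_shape(output)
        y_pred, y = output
        self._check_type((y_pred, y))
        # ... the remainder is unchanged from the class above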

vfdev-5 added the "PyDataGlobal" (PyData Global 2020 Sprint) label and removed the "Hacktoberfest" label on Oct 31, 2020
vfdev-5 removed the "PyDataGlobal" (PyData Global 2020 Sprint) label on Dec 14, 2020
vfdev-5 added the "module: metrics" (Metrics module) label on Jan 18, 2021