Label-wise metrics (Accuracy etc.) for multi-label problems #513
@jphdotam thanks for the feedback! You are correct: for now, the multi-label case is always averaged for Accuracy, Precision and Recall.
There is an issue with a similar requirement: #467
For the Slack workspace, you can find a link here: https://pytorch.org/resources
Many thanks, I've made a pull request here: #516. I'm quite new to working on large projects, so apologies if I have gone about this inappropriately.
In the meantime, whilst the core team decide how best to implement this, here is a custom class I've made for the task, which inherits from `Accuracy`:

```python
import torch
from ignite.exceptions import NotComputableError
from ignite.metrics import Accuracy


class LabelwiseAccuracy(Accuracy):
    """Multi-label accuracy reported separately for each label."""

    def __init__(self, output_transform=lambda x: x):
        self._num_correct = None
        self._num_examples = None
        super(LabelwiseAccuracy, self).__init__(output_transform=output_transform)

    def reset(self):
        self._num_correct = None
        self._num_examples = 0
        super(LabelwiseAccuracy, self).reset()

    def update(self, output):
        # Relies on _check_shape returning (y_pred, y), as it did in ignite at the time of posting.
        y_pred, y = self._check_shape(output)
        self._check_type((y_pred, y))

        # Move the label dimension (dim 1) to the last position and flatten the rest,
        # giving tensors of shape (num_samples, num_classes).
        num_classes = y_pred.size(1)
        last_dim = y_pred.ndimension()
        y_pred = torch.transpose(y_pred, 1, last_dim - 1).reshape(-1, num_classes)
        y = torch.transpose(y, 1, last_dim - 1).reshape(-1, num_classes)

        # Per sample: True only if every label is correct (used here just to count samples).
        correct_exact = torch.all(y == y_pred.type_as(y), dim=-1)
        # Per label: number of samples in this batch whose prediction matches the target.
        correct_elementwise = torch.sum(y == y_pred.type_as(y), dim=0)

        # Accumulate the per-label counts across batches.
        if self._num_correct is not None:
            self._num_correct = torch.add(self._num_correct, correct_elementwise)
        else:
            self._num_correct = correct_elementwise
        self._num_examples += correct_exact.shape[0]

    def compute(self):
        if self._num_examples == 0:
            raise NotComputableError('Accuracy must have at least one example before it can be computed.')
        # One accuracy value per label.
        return self._num_correct.type(torch.float) / self._num_examples
```
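The class above accumulates raw per-label counts and divides by the total number of examples only in `compute`. The following is a framework-free sketch of that reset/update/compute cycle, assuming plain Python lists instead of tensors and using `ValueError` in place of ignite's `NotComputableError`. The class name `LabelwiseAccuracySketch` and the data are hypothetical. It shows why summing counts matters when the last batch is smaller than the others:

```python
class LabelwiseAccuracySketch:
    """Hypothetical plain-Python mirror of the reset/update/compute cycle."""

    def __init__(self):
        self.reset()

    def reset(self):
        self._num_correct = None  # running per-label count of correct predictions
        self._num_examples = 0    # total number of samples seen so far

    def update(self, y_pred, y):
        # Count, for each label column, how many samples in this batch match the target.
        num_labels = len(y[0])
        batch_correct = [
            sum(1 for p, t in zip(y_pred, y) if p[j] == t[j]) for j in range(num_labels)
        ]
        if self._num_correct is None:
            self._num_correct = batch_correct
        else:
            self._num_correct = [a + b for a, b in zip(self._num_correct, batch_correct)]
        self._num_examples += len(y)

    def compute(self):
        if self._num_examples == 0:
            raise ValueError("At least one example is needed before computing accuracy.")
        # Divide summed counts by the total sample count, not by the number of batches.
        return [c / self._num_examples for c in self._num_correct]


metric = LabelwiseAccuracySketch()
metric.update([[1, 0], [0, 0], [1, 1]], [[1, 0], [0, 1], [1, 1]])  # batch of 3
metric.update([[1, 0]], [[0, 0]])                                   # batch of 1
print(metric.compute())  # [0.75, 0.75]
```

Averaging the two per-batch accuracies instead would give 0.5 for the first label and about 0.83 for the second, because the single-sample batch would weigh as much as the three-sample one.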
For anyone trying to use @jphdotam's code in #513 (comment),
throws an exception because that function now returns nothing. Instead, use
However, there's something wrong with it because I'm getting
Edit: nvm, I stepped thru the code and it was fine. The bug was on my end. Cheers!
Hi,
I've made a multi-label classifier using `BCEWithLogitsLoss`. In summary, each data sample has 3 binary labels, which aren't mutually exclusive, so `y_pred` and `y` can look something like `[0, 1, 1]`. My metrics include `Accuracy(output_transform=thresholded_output_transform, is_multilabel=True)` and `Precision(output_transform=thresholded_output_transform, is_multilabel=True, average=True)`.
However, I'm interested in having label-specific metrics (i.e. 3 accuracies, etc.). This is important because it lets me see which labels are compromising my overall accuracy the most (a 70% accuracy could be a 30% error in a single label, or a more modest error scattered across all 3 labels).
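To make the 70% example concrete, here is a plain-Python sketch (no PyTorch or ignite) on hypothetical data: 10 samples whose target is `[0, 1, 1]`. It assumes the averaged multi-label score is exact-match accuracy, i.e. a sample counts as correct only if all 3 labels are right, which is what the `correct_exact` line in the class above computes. The function names are illustrative only:

```python
def exact_match_accuracy(y_pred, y):
    """Fraction of samples whose whole predicted label vector equals the target."""
    return sum(1 for p, t in zip(y_pred, y) if p == t) / len(y)


def labelwise_accuracy(y_pred, y):
    """For each label column, the fraction of samples predicted correctly."""
    num_labels = len(y[0])
    correct = [0] * num_labels
    for p, t in zip(y_pred, y):
        for j in range(num_labels):
            if p[j] == t[j]:
                correct[j] += 1
    return [c / len(y) for c in correct]


targets = [[0, 1, 1]] * 10
# Scenario A: three samples get label 0 wrong.
pred_a = [[1, 1, 1]] * 3 + [[0, 1, 1]] * 7
# Scenario B: three samples each get a different label wrong.
pred_b = [[1, 1, 1], [0, 0, 1], [0, 1, 0]] + [[0, 1, 1]] * 7

print(exact_match_accuracy(pred_a, targets), labelwise_accuracy(pred_a, targets))  # 0.7 [0.7, 1.0, 1.0]
print(exact_match_accuracy(pred_b, targets), labelwise_accuracy(pred_b, targets))  # 0.7 [0.9, 0.9, 0.9]
```

Both scenarios score 70% overall, and only the label-wise breakdown distinguishes one weak label from errors spread evenly.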
There is no option to disable averaging for `Accuracy()` as there is for the others, and setting `average=False` for `Precision()` does not do what I expected: it yields a binary result per datum, not per label, so I end up with a tensor of size 500, not 3, if my dataset has n=500.
Is there a way to get label-wise metrics in multilabel problems? Or a plan to introduce it?
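The size-500 versus size-3 difference comes down to which axis the counts are reduced over. A plain-Python sketch on hypothetical data (4 samples, 3 labels; the function names are illustrative, not ignite API): per-sample precision reduces over the label axis and gives one value per datum, while per-label precision reduces over the sample axis and gives one value per label. Returning 0.0 when nothing was predicted positive is an assumption made here to avoid 0/0:

```python
def precision_per_sample(y_pred, y):
    """One precision value per sample: reduce over the label axis."""
    scores = []
    for p, t in zip(y_pred, y):
        predicted_pos = sum(p)
        true_pos = sum(1 for pj, tj in zip(p, t) if pj == 1 and tj == 1)
        # Assumption: report 0.0 when the sample has no predicted positives (avoids 0/0).
        scores.append(true_pos / predicted_pos if predicted_pos else 0.0)
    return scores


def precision_per_label(y_pred, y):
    """One precision value per label: reduce over the sample axis."""
    num_labels = len(y[0])
    true_pos = [0] * num_labels
    predicted_pos = [0] * num_labels
    for p, t in zip(y_pred, y):
        for j in range(num_labels):
            if p[j] == 1:
                predicted_pos[j] += 1
                if t[j] == 1:
                    true_pos[j] += 1
    # Same 0/0 assumption for labels that were never predicted positive.
    return [tp / pp if pp else 0.0 for tp, pp in zip(true_pos, predicted_pos)]


# Hypothetical thresholded predictions and targets: 4 samples, 3 labels.
y_pred = [[0, 1, 0], [1, 1, 1], [0, 1, 1], [1, 0, 0]]
y = [[0, 1, 1], [1, 0, 1], [0, 1, 0], [1, 1, 0]]

print(precision_per_sample(y_pred, y))  # [1.0, 0.6666666666666666, 0.5, 1.0]
print(precision_per_label(y_pred, y))   # [1.0, 0.6666666666666666, 0.5]
```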
P.S. I'd love to get an invite to the Slack workspace if possible. How do I go about doing that?