Skip to content

Commit 785de0b

Browse files
authored
Adding to categorical function
Normally the PyTorch labels for classification are in integer numbers associated with classes. For the MixMatch algorithm, the labels need to be in one hot format. The to_categorical function was taken from Keras code, is able to handle batches. num_clases=10 for future CIFAR test. Inside of the algorithm y is converted to one hot encode format. At the end of the MixMatch phase, the labels p and q are converted back to the natural classes format in PyTorch
1 parent dfe5c3e commit 785de0b

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

layer.py

+26
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,31 @@ def label_guessing(model, ub, K):
4444

4545
return (sum/K).cpu().detach().numpy()
4646

47+
def to_categorical(y, num_classes=None, dtype='float32'):
48+
"""Converts a class vector (integers) to binary class matrix.
49+
E.g. for use with categorical_crossentropy.
50+
Taken from Keras code
51+
https://github.com/keras-team/keras/blob/master/keras/utils/np_utils.py#L9
52+
"""
53+
54+
y = np.array(y, dtype='int')
55+
input_shape = y.shape
56+
if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
57+
input_shape = tuple(input_shape[:-1])
58+
y = y.ravel()
59+
if not num_classes:
60+
num_classes = np.max(y) + 1
61+
n = y.shape[0]
62+
categorical = np.zeros((n, num_classes), dtype=dtype)
63+
categorical[np.arange(n), y] = 1
64+
output_shape = input_shape + (num_classes,)
65+
categorical = np.reshape(categorical, output_shape)
66+
return categorical
67+
4768

4869
def mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75):
4970
xb = augment_fn(x)
71+
y = to_categorical(y, num_classes=10) # Converting to one hot encode, num_clases=10 for future CIFAR test
5072
ub = [augment_fn(u) for _ in range(K)]
5173
avg_probs = label_guessing(model, ub, K)
5274
qb = sharpen(avg_probs, T)
@@ -57,6 +79,10 @@ def mixmatch(x, y, u, model, augment_fn, T=0.5, K=2, alpha=0.75):
5779
Wy = np.concatenate([qb, y], axis=0)[indices]
5880
X, p = mixup_mod(xb, Wx[:len(xb)], y, Wy[:len(xb)], alpha)
5981
U, q = mixup_mod(Ux, Wx[len(xb):], Uy, Wy[len(xb):], alpha)
82+
83+
# One hot decode for PyTorch labels compability
84+
p = p.argmax(axis=1)
85+
q = q.argmax(axis=1)
6086
return X, p, U, q
6187

6288

0 commit comments

Comments
 (0)