Skip to content

Commit 793664a

Browse files
committed
[test] Add tests for get_output_properties
1 parent 91bcd37 commit 793664a

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed
+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import numpy as np
2+
3+
import pytest
4+
5+
from autoPyTorch.datasets.base_dataset import _get_output_properties
6+
7+
8+
@pytest.mark.parametrize(
9+
"target_labels,dim,task_type", (
10+
(np.arange(5), 5, "multiclass"),
11+
(np.linspace(0, 1, 3), 1, "continuous"),
12+
(np.linspace(0, 1, 3)[:, np.newaxis], 1, "continuous")
13+
)
14+
)
15+
def test_get_output_properties(target_labels, dim, task_type):
16+
train_tensors = np.vstack([np.empty_like(target_labels), target_labels])
17+
output_dim, output_type = _get_output_properties(train_tensors)
18+
assert output_dim == dim
19+
assert output_type == task_type

0 commit comments

Comments
 (0)