diff --git a/deepvariant/modeling.py b/deepvariant/modeling.py index 4473dafb..e69f1d71 100644 --- a/deepvariant/modeling.py +++ b/deepvariant/modeling.py @@ -117,18 +117,19 @@ class UnsupportedImageDimensionsError(Exception): def binarize(labels, target_class): """Binarize labels and predictions. - The labels that are not equal to target_class parameter are set to zero. + The labels that are equal to target_class parameter are set to 0, else + set to 1. Args: labels: the ground-truth labels for the examples. - target_class: index of the class that is left as non-zero. + target_class: index of the class that is left as zero. Returns: Tensor of the same shape as labels. """ labels_binary = tf.compat.v1.where( tf.equal(labels, tf.constant(target_class, dtype=tf.int64)), - tf.zeros_like(labels), labels) + tf.zeros_like(labels), tf.ones_like(labels)) return labels_binary diff --git a/deepvariant/modeling_test.py b/deepvariant/modeling_test.py index df83dbce..c8bac046 100644 --- a/deepvariant/modeling_test.py +++ b/deepvariant/modeling_test.py @@ -98,6 +98,17 @@ def _run(tensor_to_run): modeling.is_encoded_variant_type( tensor, tf_utils.EncodedVariantType.INDEL)), [False, True] * 4) + @parameterized.parameters( + dict(labels=[0, 2, 1, 0], target_class=0, expected=[0, 1, 1, 0]), + dict(labels=[0, 2, 1, 0], target_class=1, expected=[1, 1, 0, 1]), + dict(labels=[0, 2, 1, 0], target_class=2, expected=[1, 0, 1, 1]), + ) + def test_binarize(self, labels, target_class, expected): + with self.test_session() as sess: + result = sess.run( + modeling.binarize(np.array(labels), np.array(target_class))) + self.assertListEqual(result.tolist(), expected) + @parameterized.parameters([True, False]) def test_eval_metric_fn(self, include_variant_types): labels = tf.constant([1, 0], dtype=tf.int64)