Skip to content

Commit 2ff8e80

Browse files
committed
Torch eq and ne ops supports bool type.
1 parent 6adf95b commit 2ff8e80

File tree

2 files changed

+27
-64
lines changed

2 files changed

+27
-64
lines changed

coremltools/converters/mil/frontend/torch/ops.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,14 +472,26 @@ def listconstruct(context, node):
472472
@register_torch_op
473473
def eq(context, node):
474474
inputs = _get_inputs(context, node, expected=2)
475-
equal_to = mb.equal(x=inputs[0], y=inputs[1], name=node.name)
475+
x = inputs[0]
476+
y = inputs[1]
477+
if is_bool(x.dtype):
478+
x = mb.cast(x=x, dtype='int32')
479+
if is_bool(y.dtype):
480+
y = mb.cast(x=y, dtype='int32')
481+
equal_to = mb.equal(x=x, y=y, name=node.name)
476482
context.add(equal_to)
477483

478484

479485
@register_torch_op
480486
def ne(context, node):
481487
inputs = _get_inputs(context, node, expected=2)
482-
equal_to = mb.not_equal(x=inputs[0], y=inputs[1], name=node.name)
488+
x = inputs[0]
489+
y = inputs[1]
490+
if is_bool(x.dtype):
491+
x = mb.cast(x=x, dtype='int32')
492+
if is_bool(y.dtype):
493+
y = mb.cast(x=y, dtype='int32')
494+
equal_to = mb.not_equal(x=x, y=y, name=node.name)
483495
context.add(equal_to)
484496

485497

coremltools/converters/mil/frontend/torch/test/test_torch_ops.py

Lines changed: 13 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3152,86 +3152,37 @@ def forward(self, x):
31523152
)
31533153

31543154

3155-
class TestLogicalAnd(TorchBaseTest):
3155+
class TestBitWiseLogical(TorchBaseTest):
31563156
@pytest.mark.parametrize(
3157-
"backend, x_y",
3157+
"backend, x_y, op_string",
31583158
itertools.product(
31593159
backends,
31603160
[
31613161
([True, False, True, False], [True, True, False, False]),
31623162
([[True, False], [True, False]], [[True, True], [False, False]]),
3163+
([[True, False], [True, False]], [[1, 0], [2, 1]]),
31633164
([-1.5, 0.0, 1.0, 0.0], [0.1, 2.5, 0.0, 0.0]),
31643165
([2, 0, -1, 0, 5], [1, 1, 0, 0, -5]),
31653166
],
3166-
),
3167-
)
3168-
def test_logical_and(self, backend, x_y):
3169-
class TestNet(nn.Module):
3170-
def __init__(self):
3171-
super(TestNet, self).__init__()
3172-
3173-
def forward(self, x, y):
3174-
return torch.logical_and(x, y)
3175-
3176-
model = TestNet()
3177-
x = torch.tensor(x_y[0])
3178-
y = torch.tensor(x_y[1])
3179-
self.run_compare_torch([x, y], model, backend=backend, input_as_shape=False)
3180-
3181-
3182-
class TestLogicalOr(TorchBaseTest):
3183-
@pytest.mark.parametrize(
3184-
"backend, x_y",
3185-
itertools.product(
3186-
backends,
31873167
[
3188-
([True, False, True, False], [True, True, False, False]),
3189-
([[True, False], [True, False]], [[True, True], [False, False]]),
3190-
([-1.5, 0.0, 1.0, 0.0], [0.1, 2.5, 0.0, 0.0]),
3191-
([2, 0, -1, 0, 5], [1, 1, 0, 0, -5]),
3168+
"eq",
3169+
"ne",
3170+
"logical_and",
3171+
"logical_or",
3172+
"logical_xor",
31923173
],
31933174
),
31943175
)
3195-
def test_logical_or(self, backend, x_y):
3196-
class TestNet(nn.Module):
3197-
def __init__(self):
3198-
super(TestNet, self).__init__()
3199-
3200-
def forward(self, x, y):
3201-
return torch.logical_or(x, y)
3202-
3203-
model = TestNet()
3176+
def test_bitwise_logical(self, backend, x_y, op_string):
3177+
if not contains_op(torch, op_string):
3178+
return
3179+
op_func = getattr(torch, op_string)
3180+
model = ModuleWrapper(function=op_func)
32043181
x = torch.tensor(x_y[0])
32053182
y = torch.tensor(x_y[1])
32063183
self.run_compare_torch([x, y], model, backend=backend, input_as_shape=False)
32073184

32083185

3209-
class TestLogicalXor(TorchBaseTest):
3210-
@pytest.mark.parametrize(
3211-
"backend, x_y",
3212-
itertools.product(
3213-
backends,
3214-
[
3215-
([True, False, True, False], [True, True, False, False]),
3216-
([[True, False], [True, False]], [[True, True], [False, False]]),
3217-
([-1.5, 0.0, 1.0, 0.0], [0.1, 2.5, 0.0, 0.0]),
3218-
([2, 0, -1, 0, 5], [1, 1, 0, 0, -5]),
3219-
],
3220-
),
3221-
)
3222-
def test_logical_xor(self, backend, x_y):
3223-
class TestNet(nn.Module):
3224-
def __init__(self):
3225-
super(TestNet, self).__init__()
3226-
3227-
def forward(self, x, y):
3228-
return torch.logical_xor(x, y)
3229-
3230-
model = TestNet()
3231-
x = torch.tensor(x_y[0])
3232-
y = torch.tensor(x_y[1])
3233-
self.run_compare_torch([x, y], model, backend=backend, input_as_shape=False)
3234-
32353186

32363187
class TestWhere(TorchBaseTest):
32373188
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)