@@ -3152,86 +3152,37 @@ def forward(self, x):
3152
3152
)
3153
3153
3154
3154
3155
- class TestLogicalAnd (TorchBaseTest ):
3155
+ class TestBitWiseLogical (TorchBaseTest ):
3156
3156
@pytest .mark .parametrize (
3157
- "backend, x_y" ,
3157
+ "backend, x_y, op_string " ,
3158
3158
itertools .product (
3159
3159
backends ,
3160
3160
[
3161
3161
([True , False , True , False ], [True , True , False , False ]),
3162
3162
([[True , False ], [True , False ]], [[True , True ], [False , False ]]),
3163
+ ([[True , False ], [True , False ]], [[1 , 0 ], [2 , 1 ]]),
3163
3164
([- 1.5 , 0.0 , 1.0 , 0.0 ], [0.1 , 2.5 , 0.0 , 0.0 ]),
3164
3165
([2 , 0 , - 1 , 0 , 5 ], [1 , 1 , 0 , 0 , - 5 ]),
3165
3166
],
3166
- ),
3167
- )
3168
- def test_logical_and (self , backend , x_y ):
3169
- class TestNet (nn .Module ):
3170
- def __init__ (self ):
3171
- super (TestNet , self ).__init__ ()
3172
-
3173
- def forward (self , x , y ):
3174
- return torch .logical_and (x , y )
3175
-
3176
- model = TestNet ()
3177
- x = torch .tensor (x_y [0 ])
3178
- y = torch .tensor (x_y [1 ])
3179
- self .run_compare_torch ([x , y ], model , backend = backend , input_as_shape = False )
3180
-
3181
-
3182
- class TestLogicalOr (TorchBaseTest ):
3183
- @pytest .mark .parametrize (
3184
- "backend, x_y" ,
3185
- itertools .product (
3186
- backends ,
3187
3167
[
3188
- ([True , False , True , False ], [True , True , False , False ]),
3189
- ([[True , False ], [True , False ]], [[True , True ], [False , False ]]),
3190
- ([- 1.5 , 0.0 , 1.0 , 0.0 ], [0.1 , 2.5 , 0.0 , 0.0 ]),
3191
- ([2 , 0 , - 1 , 0 , 5 ], [1 , 1 , 0 , 0 , - 5 ]),
3168
+ "eq" ,
3169
+ "ne" ,
3170
+ "logical_and" ,
3171
+ "logical_or" ,
3172
+ "logical_xor" ,
3192
3173
],
3193
3174
),
3194
3175
)
3195
- def test_logical_or (self , backend , x_y ):
3196
- class TestNet (nn .Module ):
3197
- def __init__ (self ):
3198
- super (TestNet , self ).__init__ ()
3199
-
3200
- def forward (self , x , y ):
3201
- return torch .logical_or (x , y )
3202
-
3203
- model = TestNet ()
3176
+ def test_bitwise_logical (self , backend , x_y , op_string ):
3177
+ if not contains_op (torch , op_string ):
3178
+ return
3179
+ op_func = getattr (torch , op_string )
3180
+ model = ModuleWrapper (function = op_func )
3204
3181
x = torch .tensor (x_y [0 ])
3205
3182
y = torch .tensor (x_y [1 ])
3206
3183
self .run_compare_torch ([x , y ], model , backend = backend , input_as_shape = False )
3207
3184
3208
3185
3209
- class TestLogicalXor (TorchBaseTest ):
3210
- @pytest .mark .parametrize (
3211
- "backend, x_y" ,
3212
- itertools .product (
3213
- backends ,
3214
- [
3215
- ([True , False , True , False ], [True , True , False , False ]),
3216
- ([[True , False ], [True , False ]], [[True , True ], [False , False ]]),
3217
- ([- 1.5 , 0.0 , 1.0 , 0.0 ], [0.1 , 2.5 , 0.0 , 0.0 ]),
3218
- ([2 , 0 , - 1 , 0 , 5 ], [1 , 1 , 0 , 0 , - 5 ]),
3219
- ],
3220
- ),
3221
- )
3222
- def test_logical_xor (self , backend , x_y ):
3223
- class TestNet (nn .Module ):
3224
- def __init__ (self ):
3225
- super (TestNet , self ).__init__ ()
3226
-
3227
- def forward (self , x , y ):
3228
- return torch .logical_xor (x , y )
3229
-
3230
- model = TestNet ()
3231
- x = torch .tensor (x_y [0 ])
3232
- y = torch .tensor (x_y [1 ])
3233
- self .run_compare_torch ([x , y ], model , backend = backend , input_as_shape = False )
3234
-
3235
3186
3236
3187
class TestWhere (TorchBaseTest ):
3237
3188
@pytest .mark .parametrize (
0 commit comments