Skip to content

Commit 05bdaf7

Browse files
committed
Add tests for numpy expression compatibility
1 parent 345917e commit 05bdaf7

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# ___________________________________________________________________________
2+
#
3+
# Pyomo: Python Optimization Modeling Objects
4+
# Copyright (c) 2008-2024
5+
# National Technology and Engineering Solutions of Sandia, LLC
6+
# Under the terms of Contract DE-NA0003525 with National Technology and
7+
# Engineering Solutions of Sandia, LLC, the U.S. Government retains certain
8+
# rights in this software.
9+
# This software is distributed under the 3-clause BSD License.
10+
# ___________________________________________________________________________
11+
12+
import pyomo.common.unittest as unittest
13+
14+
from pyomo.common.dependencies import numpy as np, numpy_available
15+
from pyomo.environ import ConcreteModel, Var, Constraint
16+
17+
18+
@unittest.skipUnless(numpy_available, "tests require numpy")
19+
class TestNumpyExpr(unittest.TestCase):
20+
def test_scalar_operations(self):
21+
m = ConcreteModel()
22+
m.x = Var()
23+
24+
a = np.array(m.x)
25+
self.assertEqual(a.shape, ())
26+
27+
self.assertExpressionsEqual(5 * a, 5 * m.x)
28+
self.assertExpressionsEqual(np.array([2, 3]) * a, [2 * m.x, 3 * m.x])
29+
self.assertExpressionsEqual(np.array([5, 6]) * m.x, [5 * m.x, 6 * m.x])
30+
self.assertExpressionsEqual(np.array([8, m.x]) * m.x, [8 * m.x, m.x * m.x])
31+
32+
a = np.array([m.x])
33+
self.assertEqual(a.shape, (1,))
34+
35+
self.assertExpressionsEqual(5 * a, [5 * m.x])
36+
self.assertExpressionsEqual(np.array([2, 3]) * a, [2 * m.x, 3 * m.x])
37+
self.assertExpressionsEqual(np.array([5, 6]) * m.x, [5 * m.x, 6 * m.x])
38+
self.assertExpressionsEqual(np.array([8, m.x]) * m.x, [8 * m.x, m.x * m.x])
39+
40+
def test_vector_operations(self):
41+
m = ConcreteModel()
42+
m.x = Var()
43+
m.y = Var([0, 1, 2])
44+
45+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
46+
# TODO: when we finally support a true matrix expression
47+
# system, this test should work
48+
self.assertExpressionsEqual(5 * m.y, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
49+
50+
a = np.array(5)
51+
self.assertExpressionsEqual(a * m.y, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
52+
self.assertExpressionsEqual(m.y * a, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
53+
a = np.array([5])
54+
self.assertExpressionsEqual(a * m.y, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
55+
self.assertExpressionsEqual(m.y * a, [5 * m.y[0], 5 * m.y[1], 5 * m.y[2]])
56+
57+
a = np.array(5)
58+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
59+
# TODO: when we finally support a true matrix expression
60+
# system, this test should work
61+
self.assertExpressionsEqual(
62+
a * m.x * m.y, [5 * m.x * m.y[0], 5 * m.x * m.y[1], 5 * m.x * m.y[2]]
63+
)
64+
self.assertExpressionsEqual(
65+
a * m.y * m.x, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
66+
)
67+
self.assertExpressionsEqual(
68+
a * m.y * m.y,
69+
[5 * m.y[0] * m.y[0], 5 * m.y[1] * m.y[1], 5 * m.y[2] * m.y[2]],
70+
)
71+
self.assertExpressionsEqual(
72+
m.y * a * m.x, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
73+
)
74+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
75+
# TODO: when we finally support a true matrix expression
76+
# system, this test should work
77+
self.assertExpressionsEqual(
78+
m.y * m.x * a, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
79+
)
80+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
81+
# TODO: when we finally support a true matrix expression
82+
# system, this test should work
83+
self.assertExpressionsEqual(
84+
m.x * a * m.y, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
85+
)
86+
with self.assertRaisesRegex(TypeError, "unsupported operand"):
87+
# TODO: when we finally support a true matrix expression
88+
# system, this test should work
89+
self.assertExpressionsEqual(
90+
m.x * m.y * a, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
91+
)
92+
93+
94+
if __name__ == "__main__":
95+
unittest.main()

0 commit comments

Comments
 (0)