From 1252a7c30cfe8a69b8a9d7561256e9ede47e56fb Mon Sep 17 00:00:00 2001 From: Daniel Johnson Date: Wed, 20 Nov 2024 15:46:13 -0500 Subject: [PATCH] Add autoformatting check for test files. --- .github/workflows/unittests.yml | 2 +- .../v1/toolshed/model_rewiring_test.py | 33 ++++++++++--------- tests/toolshed/model_rewiring_test.py | 33 ++++++++++--------- 3 files changed, 37 insertions(+), 31 deletions(-) diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 0c8bc5c..2e065ef 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -42,7 +42,7 @@ jobs: # Check formatting - name: Check pyink formatting - run: uv run pyink penzai --check + run: uv run pyink penzai tests --check - name: Run pylint run: uv run pylint penzai diff --git a/tests/deprecated/v1/toolshed/model_rewiring_test.py b/tests/deprecated/v1/toolshed/model_rewiring_test.py index 94593ec..2707333 100644 --- a/tests/deprecated/v1/toolshed/model_rewiring_test.py +++ b/tests/deprecated/v1/toolshed/model_rewiring_test.py @@ -141,23 +141,26 @@ def target(stuff): result["b"].canonicalize(), ( pz.nx.wrap( - jnp.array([ - [ - 1**3 + 3 * 1**2 * (4 - 1), - 1**3 + 3 * 1**2 * (5 - 1), - 1**3 + 3 * 1**2 * (6 - 1), - ], - [ - 2**3 + 3 * 2**2 * (4 - 2), - 2**3 + 3 * 2**2 * (5 - 2), - 2**3 + 3 * 2**2 * (6 - 2), - ], + jnp.array( [ - 3**3 + 3 * 3**2 * (4 - 3), - 3**3 + 3 * 3**2 * (5 - 3), - 3**3 + 3 * 3**2 * (6 - 3), + [ + 1**3 + 3 * 1**2 * (4 - 1), + 1**3 + 3 * 1**2 * (5 - 1), + 1**3 + 3 * 1**2 * (6 - 1), + ], + [ + 2**3 + 3 * 2**2 * (4 - 2), + 2**3 + 3 * 2**2 * (5 - 2), + 2**3 + 3 * 2**2 * (6 - 2), + ], + [ + 3**3 + 3 * 3**2 * (4 - 3), + 3**3 + 3 * 3**2 * (5 - 3), + 3**3 + 3 * 3**2 * (6 - 3), + ], ], - ], jnp.float32) + jnp.float32, + ) ) .tag("foo", "bar") .canonicalize() diff --git a/tests/toolshed/model_rewiring_test.py b/tests/toolshed/model_rewiring_test.py index 35a0346..45ef9cd 100644 --- a/tests/toolshed/model_rewiring_test.py +++ b/tests/toolshed/model_rewiring_test.py @@ -141,23 +141,26 @@ def target(stuff): result["b"].canonicalize(), ( pz.nx.wrap( - jnp.array([ - [ - 1**3 + 3 * 1**2 * (4 - 1), - 1**3 + 3 * 1**2 * (5 - 1), - 1**3 + 3 * 1**2 * (6 - 1), - ], - [ - 2**3 + 3 * 2**2 * (4 - 2), - 2**3 + 3 * 2**2 * (5 - 2), - 2**3 + 3 * 2**2 * (6 - 2), - ], + jnp.array( [ - 3**3 + 3 * 3**2 * (4 - 3), - 3**3 + 3 * 3**2 * (5 - 3), - 3**3 + 3 * 3**2 * (6 - 3), + [ + 1**3 + 3 * 1**2 * (4 - 1), + 1**3 + 3 * 1**2 * (5 - 1), + 1**3 + 3 * 1**2 * (6 - 1), + ], + [ + 2**3 + 3 * 2**2 * (4 - 2), + 2**3 + 3 * 2**2 * (5 - 2), + 2**3 + 3 * 2**2 * (6 - 2), + ], + [ + 3**3 + 3 * 3**2 * (4 - 3), + 3**3 + 3 * 3**2 * (5 - 3), + 3**3 + 3 * 3**2 * (6 - 3), + ], ], - ], jnp.float32) + jnp.float32, + ) ) .tag("foo", "bar") .canonicalize()