Skip to content

Commit

Permalink
Compatibility with numpy 2
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 2, 2025
1 parent 0558ab3 commit 231b4f7
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 29 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Release
uses: patrick-kidger/action_update_python_project@v2
uses: patrick-kidger/action_update_python_project@v6
with:
python-version: "3.11"
test-script: |
Expand All @@ -20,7 +20,3 @@ jobs:
pypi-token: ${{ secrets.pypi_token }}
github-user: patrick-kidger
github-token: ${{ github.token }}
email-user: ${{ secrets.email_user }}
email-token: ${{ secrets.email_token }}
email-server: ${{ secrets.email_server }}
email-target: ${{ secrets.email_target }}
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ jobs:
run-tests:
strategy:
matrix:
python-version: [ 3.9 ]
python-version: [ "3.10", "3.12" ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
Expand Down
29 changes: 10 additions & 19 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
repos:
- repo: https://github.com/ambv/black
rev: 22.3.0
hooks:
- id: black
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: 'v0.0.255'
hooks:
- id: ruff
args: ["--fix"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
hooks:
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.315
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: ["equinox", "jax", "sympy"]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.6.3
hooks:
- id: nbqa-black
additional_dependencies: [ipython==8.12, black]
- id: nbqa-ruff
args: ["--ignore=I001"]
additional_dependencies: [ipython==8.12, ruff]
additional_dependencies: [equinox, jax, sympy]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "sympy2jax"
version = "0.0.5"
version = "0.0.6"
description = "Turn SymPy expressions into trainable JAX expressions."
readme = "README.md"
requires-python ="~=3.9"
Expand Down
2 changes: 1 addition & 1 deletion sympy2jax/sympy_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class _Symbol(_AbstractNode):
_name: str

def __init__(self, expr: sympy.Expr):
self._name = expr.name # pyright: ignore
self._name = str(expr.name) # pyright: ignore

def __call__(self, memodict: dict):
try:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_symbolic_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def assert_sympy_allclose(x, y):

def test_example():
x_sym = sympy.symbols("x_sym")
cosx = 1.0 * sympy.cos(x_sym)
sinx = 2.0 * sympy.sin(x_sym)
cosx = 1.0 * sympy.cos(x_sym) # pyright: ignore[reportOperatorIssue]
sinx = 2.0 * sympy.sin(x_sym) # pyright: ignore[reportOperatorIssue]
mod = sympy2jax.SymbolicModule([cosx, sinx])

x = jax.numpy.zeros(3)
Expand Down

0 comments on commit 231b4f7

Please # to comment.