From 63214279c90402711ece819a6eb8bc1751ea44a7 Mon Sep 17 00:00:00 2001 From: marvinfriede <51965259+marvinfriede@users.noreply.github.com> Date: Sun, 15 Dec 2024 10:20:49 -0600 Subject: [PATCH] Fix typing --- .pre-commit-config.yaml | 8 ++++++-- src/tad_dftd3/data/radii.py | 4 ++-- src/tad_dftd3/reference.py | 6 ++++-- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 690f372..323d74b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,7 +38,7 @@ repos: --min-py-version, "3.8", --max-py-version, - "3.11", + "3.12", ] - repo: https://github.com/asottile/pyupgrade @@ -60,10 +60,14 @@ repos: - id: black stages: [pre-commit] + - repo: https://github.com/woodruffw/zizmor-pre-commit + rev: v0.9.2 + hooks: + - id: zizmor + - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.13.0 hooks: - id: mypy - additional_dependencies: [types-all] pass_filenames: false args: [--config-file=pyproject.toml, --ignore-missing-imports, src] diff --git a/src/tad_dftd3/data/radii.py b/src/tad_dftd3/data/radii.py index 0766e07..ac729c3 100644 --- a/src/tad_dftd3/data/radii.py +++ b/src/tad_dftd3/data/radii.py @@ -29,7 +29,7 @@ from tad_mctc._version import __tversion__ from tad_mctc.data.radii import COV_D3 -from ..typing import Tensor +from ..typing import Any, Tensor __all__ = ["COV_D3", "VDW_D3"] @@ -52,7 +52,7 @@ def _load_vdw_rad_d3( Tensor VDW radii. """ - kwargs: dict = {"map_location": device} + kwargs: dict[str, Any] = {"map_location": device} if __tversion__ > (1, 12, 1): # pragma: no cover kwargs["weights_only"] = True diff --git a/src/tad_dftd3/reference.py b/src/tad_dftd3/reference.py index 56bcc3c..bf1fb2d 100644 --- a/src/tad_dftd3/reference.py +++ b/src/tad_dftd3/reference.py @@ -20,7 +20,7 @@ C6 dispersion coefficients. """ import os.path as op -from typing import Optional +from typing import Optional, Union import torch from tad_mctc._version import __tversion__ @@ -48,6 +48,7 @@ def _load_cn( Tensor Reference coordination numbers. """ + # fmt: off return torch.tensor( [ [-1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000, -1.0000], # None @@ -158,6 +159,7 @@ def _load_cn( device=device, dtype=dtype, ) + # fmt: on def _load_c6( @@ -178,7 +180,7 @@ def _load_c6( Tensor Reference C6 coefficients. """ - kwargs: dict = {"map_location": device} + kwargs: dict[str, Any] = {"map_location": device} if __tversion__ > (1, 12, 1): # pragma: no cover kwargs["weights_only"] = True