diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/scripts/tests/test_build_model.py b/scripts/tests/test_build_model.py index 278cb1fc..b9476cd3 100644 --- a/scripts/tests/test_build_model.py +++ b/scripts/tests/test_build_model.py @@ -21,7 +21,7 @@ LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..') sys.path.insert(0, os.path.abspath(LIB_PATH)) -from scripts import build_model # type: ignore # noqa (module hack) +from scripts import build_model # noqa (module hack) class TestAggregateScores(unittest.TestCase): diff --git a/scripts/tests/test_encode_data.py b/scripts/tests/test_encode_data.py index 25161ded..2e713664 100644 --- a/scripts/tests/test_encode_data.py +++ b/scripts/tests/test_encode_data.py @@ -24,7 +24,7 @@ LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..') sys.path.insert(0, os.path.abspath(LIB_PATH)) -from scripts import encode_data # type: ignore # noqa (module hack) +from scripts import encode_data # noqa (module hack) class TestGetFeature(unittest.TestCase): diff --git a/scripts/tests/test_prepare_knbc.py b/scripts/tests/test_prepare_knbc.py index c468b0d7..cdfbc38c 100644 --- a/scripts/tests/test_prepare_knbc.py +++ b/scripts/tests/test_prepare_knbc.py @@ -21,7 +21,7 @@ LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..') sys.path.insert(0, os.path.abspath(LIB_PATH)) -from scripts import prepare_knbc # type: ignore # noqa (module hack) +from scripts import prepare_knbc # noqa (module hack) class TestBreakBeforeSequence(unittest.TestCase): diff --git a/scripts/tests/test_train.py b/scripts/tests/test_train.py index 8046f62c..15b8c071 100644 --- a/scripts/tests/test_train.py +++ b/scripts/tests/test_train.py @@ -27,7 +27,7 @@ LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..') sys.path.insert(0, os.path.abspath(LIB_PATH)) -from scripts import train # type: ignore # noqa (module hack) +from scripts import train # noqa (module hack) class TestArgParse(unittest.TestCase): diff --git a/scripts/tests/test_translate_model.py b/scripts/tests/test_translate_model.py index 4c7a368b..1c35f5f3 100644 --- a/scripts/tests/test_translate_model.py +++ b/scripts/tests/test_translate_model.py @@ -21,7 +21,7 @@ LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..') sys.path.insert(0, os.path.abspath(LIB_PATH)) -from scripts import translate_model # type: ignore # noqa (module hack) +from scripts import translate_model # noqa (module hack) class TestNormalize(unittest.TestCase): diff --git a/scripts/train.py b/scripts/train.py index 3c548400..e2fca1b6 100644 --- a/scripts/train.py +++ b/scripts/train.py @@ -157,10 +157,10 @@ def get_metrics(pred: npt.NDArray[np.bool_], Returns: result (Result): A result. """ - tp = jnp.sum(jnp.logical_and(pred == 1, actual == 1)) - tn = jnp.sum(jnp.logical_and(pred == 0, actual == 0)) - fp = jnp.sum(jnp.logical_and(pred == 1, actual == 0)) - fn = jnp.sum(jnp.logical_and(pred == 0, actual == 1)) + tp: int = jnp.sum(jnp.logical_and(pred == 1, actual == 1)) # type: ignore + tn: int = jnp.sum(jnp.logical_and(pred == 0, actual == 0)) # type: ignore + fp: int = jnp.sum(jnp.logical_and(pred == 1, actual == 0)) # type: ignore + fn: int = jnp.sum(jnp.logical_and(pred == 0, actual == 1)) # type: ignore accuracy = (tp + tn) / (tp + tn + fp + fn) precision = tp / (tp + fp) recall = tp / (tp + fn) @@ -208,7 +208,7 @@ def update( best_feature_index: int = err.argmin() positivity: bool = res.at[best_feature_index].get() < 0.5 err_min = err.at[best_feature_index].get() - amount: float = jnp.log((1 - err_min) / (err_min + EPS)) + amount: float = jnp.log((1 - err_min) / (err_min + EPS)) # type: ignore # This is equivalent to X_best = X[:, best_feature_index] X_best = jnp.zeros( diff --git a/setup.cfg b/setup.cfg index c0c0a7e0..bf047b88 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,7 +28,7 @@ dev = flake8 isort numpy - mypy==0.971 + mypy pytest toml twine