Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fix mypy issue #135

Merged
merged 1 commit into from
Apr 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added scripts/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion scripts/tests/test_build_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.insert(0, os.path.abspath(LIB_PATH))

from scripts import build_model # type: ignore # noqa (module hack)
from scripts import build_model # noqa (module hack)


class TestAggregateScores(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion scripts/tests/test_encode_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.insert(0, os.path.abspath(LIB_PATH))

from scripts import encode_data # type: ignore # noqa (module hack)
from scripts import encode_data # noqa (module hack)


class TestGetFeature(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion scripts/tests/test_prepare_knbc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.insert(0, os.path.abspath(LIB_PATH))

from scripts import prepare_knbc # type: ignore # noqa (module hack)
from scripts import prepare_knbc # noqa (module hack)


class TestBreakBeforeSequence(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion scripts/tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.insert(0, os.path.abspath(LIB_PATH))

from scripts import train # type: ignore # noqa (module hack)
from scripts import train # noqa (module hack)


class TestArgParse(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion scripts/tests/test_translate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
LIB_PATH = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.insert(0, os.path.abspath(LIB_PATH))

from scripts import translate_model # type: ignore # noqa (module hack)
from scripts import translate_model # noqa (module hack)


class TestNormalize(unittest.TestCase):
Expand Down
10 changes: 5 additions & 5 deletions scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,10 @@ def get_metrics(pred: npt.NDArray[np.bool_],
Returns:
result (Result): A result.
"""
tp = jnp.sum(jnp.logical_and(pred == 1, actual == 1))
tn = jnp.sum(jnp.logical_and(pred == 0, actual == 0))
fp = jnp.sum(jnp.logical_and(pred == 1, actual == 0))
fn = jnp.sum(jnp.logical_and(pred == 0, actual == 1))
tp: int = jnp.sum(jnp.logical_and(pred == 1, actual == 1)) # type: ignore
tn: int = jnp.sum(jnp.logical_and(pred == 0, actual == 0)) # type: ignore
fp: int = jnp.sum(jnp.logical_and(pred == 1, actual == 0)) # type: ignore
fn: int = jnp.sum(jnp.logical_and(pred == 0, actual == 1)) # type: ignore
accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
Expand Down Expand Up @@ -208,7 +208,7 @@ def update(
best_feature_index: int = err.argmin()
positivity: bool = res.at[best_feature_index].get() < 0.5
err_min = err.at[best_feature_index].get()
amount: float = jnp.log((1 - err_min) / (err_min + EPS))
amount: float = jnp.log((1 - err_min) / (err_min + EPS)) # type: ignore

# This is equivalent to X_best = X[:, best_feature_index]
X_best = jnp.zeros(
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ dev =
flake8
isort
numpy
mypy==0.971
mypy
pytest
toml
twine
Expand Down