Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Enable mypy generic-related checks #338

Merged
merged 6 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,11 @@ files = [
"smartsim"
]
plugins = []
ignore_errors=false
ignore_errors = false

# Dynamic typing
disallow_any_generics = true
warn_return_any = true

# Strict fn defs
disallow_untyped_calls = true
Expand All @@ -84,7 +88,9 @@ disallow_untyped_decorators = true

# Safety/Upgrading Mypy
warn_unused_ignores = true
# warn_redundant_casts = true # not a per-module setting?
warn_redundant_casts = true
warn_unused_configs = true
show_error_codes = true

[[tool.mypy.overrides]]
# Ignore packages that are not used or not typed
Expand Down
10 changes: 5 additions & 5 deletions smartsim/_core/_install/buildenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,31 +105,31 @@ def patch(self) -> str:

def __gt__(self, cmp: t.Any) -> bool:
try:
return Version(self).__gt__(self._convert_to_version(cmp))
return bool(Version(self).__gt__(self._convert_to_version(cmp)))
except InvalidVersion:
return super().__gt__(cmp)

def __lt__(self, cmp: t.Any) -> bool:
try:
return Version(self).__lt__(self._convert_to_version(cmp))
return bool(Version(self).__lt__(self._convert_to_version(cmp)))
except InvalidVersion:
return super().__lt__(cmp)

def __eq__(self, cmp: t.Any) -> bool:
try:
return Version(self).__eq__(self._convert_to_version(cmp))
return bool(Version(self).__eq__(self._convert_to_version(cmp)))
except InvalidVersion:
return super().__eq__(cmp)

def __ge__(self, cmp: t.Any) -> bool:
try:
return Version(self).__ge__(self._convert_to_version(cmp))
return bool(Version(self).__ge__(self._convert_to_version(cmp)))
except InvalidVersion:
return super().__ge__(cmp)

def __le__(self, cmp: t.Any) -> bool:
try:
return Version(self).__le__(self._convert_to_version(cmp))
return bool(Version(self).__le__(self._convert_to_version(cmp)))
except InvalidVersion:
return super().__le__(cmp)

Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/control/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def reload_saved_db(self, checkpoint_file: str) -> Orchestrator:
raise SmartSimError(
err_message + "Could not find database job objects."
)
orc = db_config["db"]
orc: Orchestrator = db_config["db"]

# TODO check that each db_object is running

Expand Down
13 changes: 7 additions & 6 deletions smartsim/_core/entrypoints/colocated.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str:
if args.outputs:
outputs = list(args.outputs)

# devices_per_node being greater than one only applies
# to GPU devices
name = str(args.name)

# devices_per_node being greater than one only applies to GPU devices
if args.devices_per_node > 1 and args.device.lower() == "gpu":
client.set_model_from_file_multigpu(
args.name,
name,
args.file,
args.backend,
0,
Expand All @@ -111,7 +112,7 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str:
)
else:
client.set_model_from_file(
args.name,
name,
args.file,
args.backend,
args.device,
Expand All @@ -122,7 +123,7 @@ def launch_db_model(client: Client, db_model: t.List[str]) -> str:
outputs,
)

return args.name
return name


def launch_db_script(client: Client, db_script: t.List[str]) -> str:
Expand Down Expand Up @@ -163,7 +164,7 @@ def launch_db_script(client: Client, db_script: t.List[str]) -> str:
else:
raise ValueError("No file or func provided.")

return args.name
return str(args.name)


def main(
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/launcher/pbs/pbsParser.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def parse_step_id_from_qstat(output: str, step_name: str) -> t.Optional[str]:
:return: the step_id
:rtype: str
"""
step_id = None
step_id: t.Optional[str] = None
out_json = load_and_clean_json(output)

if "Jobs" not in out_json:
Expand Down
9 changes: 7 additions & 2 deletions smartsim/_core/launcher/taskManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,10 @@ def check_status(self) -> t.Optional[int]:
:rtype: int
"""
if self.owned and isinstance(self.process, psutil.Popen):
return self.process.poll()
poll_result = self.process.poll()
if poll_result is not None:
return int(poll_result)
return None
# we can't manage Processed we don't own
# have to rely on .kill() to stop.
return self.returncode
Expand Down Expand Up @@ -363,7 +366,9 @@ def wait(self) -> None:
@property
def returncode(self) -> t.Optional[int]:
if self.owned and isinstance(self.process, psutil.Popen):
return self.process.returncode
if self.process.returncode is not None:
return int(self.process.returncode)
return None
if self.is_alive:
return None
return 0
Expand Down
2 changes: 1 addition & 1 deletion smartsim/_core/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_base_36_repr(positive_int: int) -> str:
def init_default(
default: t.Any,
init_value: t.Any,
expected_type: t.Optional[t.Union[t.Type, t.Tuple]] = None,
expected_type: t.Union[t.Type[t.Any], t.Tuple[t.Type[t.Any], ...], None] = None,
) -> t.Any:
if init_value is None:
return default
Expand Down
4 changes: 3 additions & 1 deletion smartsim/_core/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ def check_cluster_status(
# wait for cluster to spin up
time.sleep(5)
try:
redis_tester: RedisCluster = RedisCluster(startup_nodes=cluster_nodes)
redis_tester: "RedisCluster[t.Any]" = RedisCluster(
startup_nodes=cluster_nodes
)
redis_tester.set("__test__", "__test__")
redis_tester.delete("__test__") # type: ignore
logger.debug("Cluster status verified")
Expand Down
2 changes: 1 addition & 1 deletion smartsim/database/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def _build_run_settings_lsf(
return erf_rs

def _initialize_entities(self, **kwargs: t.Any) -> None:
self.db_nodes = kwargs.get("db_nodes", 1)
self.db_nodes = int(kwargs.get("db_nodes", 1))
single_cmd = kwargs.get("single_cmd", True)

if int(self.db_nodes) == 2:
Expand Down
21 changes: 14 additions & 7 deletions smartsim/ml/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ..error import SSInternalError
from ..log import get_logger


logger = get_logger(__name__)


Expand Down Expand Up @@ -205,7 +206,9 @@ def publish_info(self) -> None:
self._info.publish(self.client)

def put_batch(
self, samples: np.ndarray, targets: t.Optional[np.ndarray] = None
self,
samples: np.ndarray, # type: ignore[type-arg]
targets: t.Optional[np.ndarray] = None, # type: ignore[type-arg]
) -> None:
batch_ds_name = form_name("training_samples", self.rank, self.batch_idx)
batch_ds = Dataset(batch_ds_name)
Expand Down Expand Up @@ -381,10 +384,12 @@ def __len__(self) -> int:
length = int(np.floor(self.num_samples / self.batch_size))
return length

def _calc_indices(self, index: int) -> np.ndarray:
def _calc_indices(self, index: int) -> np.ndarray: # type: ignore[type-arg]
return self.indices[index * self.batch_size : (index + 1) * self.batch_size]

def __iter__(self) -> t.Iterator[t.Tuple[np.ndarray, np.ndarray]]:
def __iter__(
self,
) -> t.Iterator[t.Tuple[np.ndarray, np.ndarray]]: # type: ignore[type-arg]
self.update_data()
# Generate data
if len(self) < 1:
Expand Down Expand Up @@ -426,11 +431,11 @@ def init_samples(self, init_trials: int = -1) -> None:

def _data_exists(self, batch_name: str, target_name: str) -> bool:
if self.need_targets:
return self.client.tensor_exists(batch_name) and self.client.tensor_exists(
target_name
return all(
self.client.tensor_exists(datum) for datum in [batch_name, target_name]
)

return self.client.tensor_exists(batch_name)
return bool(self.client.tensor_exists(batch_name))

def _add_samples(self, indices: t.List[int]) -> None:
datasets: t.List[Dataset] = []
Expand Down Expand Up @@ -491,7 +496,9 @@ def update_data(self) -> None:
if self.shuffle:
np.random.shuffle(self.indices)

def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]:
def _data_generation(
self, indices: np.ndarray # type: ignore[type-arg]
) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg]
# Initialization
if self.samples is None:
raise ValueError("Samples have not been initialized")
Expand Down
6 changes: 4 additions & 2 deletions smartsim/ml/tf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@


class _TFDataGenerationCommon(DataDownloader, keras.utils.Sequence):
def __getitem__(self, index: int) -> t.Tuple[np.ndarray, np.ndarray]:
def __getitem__(
self, index: int
) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg]
if len(self) < 1:
raise ValueError(
"Not enough samples in generator for one batch. Please "
Expand All @@ -57,7 +59,7 @@ def on_epoch_end(self) -> None:
if self.shuffle:
np.random.shuffle(self.indices)

def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]:
def _data_generation(self, indices: np.ndarray) -> t.Tuple[np.ndarray, np.ndarray]: # type: ignore[type-arg]
# Initialization
if self.samples is None:
raise ValueError("No samples loaded for data generation")
Expand Down
4 changes: 2 additions & 2 deletions smartsim/settings/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def create_batch_settings(
:raises SmartSimError: if batch creation fails
"""
# all supported batch class implementations
by_launcher = {
by_launcher: t.Dict[str, t.Callable[..., base.BatchSettings]] = {
"cobalt": CobaltBatchSettings,
"pbs": QsubBatchSettings,
"slurm": SbatchSettings,
Expand Down Expand Up @@ -144,7 +144,7 @@ def create_run_settings(
:raises SmartSimError: if run_command=="auto" and detection fails
"""
# all supported RunSettings child classes
supported = {
supported: t.Dict[str, t.Callable[..., RunSettings]] = {
"aprun": AprunSettings,
"srun": SrunSettings,
"mpirun": MpirunSettings,
Expand Down