diff --git a/Makefile b/Makefile index 500fc2a..8deea92 100644 --- a/Makefile +++ b/Makefile @@ -13,4 +13,4 @@ test: check: poetry run pre-commit run --all-files mypy: - poetry run mypy . --config-file pyproject.toml + poetry run mypy ml_orchestrator --config-file pyproject.toml diff --git a/dummy_components/__init__.py b/dummy_components/__init__.py index aed3ad7..b73cd85 100644 --- a/dummy_components/__init__.py +++ b/dummy_components/__init__.py @@ -6,6 +6,7 @@ ) valid_classes_meta = [ + dcomps.MetaComponentTest, dcomps.ComponentTestB, dcomps.ComponentTestA, dcomps.ComponentTestC, @@ -15,6 +16,7 @@ dcomps.ComponentTestD2, ] valid_classes_meta_v2 = [ + dv2comps.MetaComponentTest, dv2comps.ComponentTestB, dv2comps.ComponentTestA, dv2comps.ComponentTestC, diff --git a/ml_orchestrator/comp_parser.py b/ml_orchestrator/comp_parser.py index fe641fb..d888169 100644 --- a/ml_orchestrator/comp_parser.py +++ b/ml_orchestrator/comp_parser.py @@ -1,7 +1,6 @@ import dataclasses from typing import List -from ml_orchestrator.comp_protocol.comp_protocol import ComponentProtocol from ml_orchestrator.comp_protocol.func_parser import FunctionParser from ml_orchestrator.env_params import EnvironmentParams from ml_orchestrator.meta_comp import MetaComponent, _MetaComponent @@ -49,10 +48,6 @@ def create_kfp_str(self, component: _MetaComponent) -> str: # type: ignore kfp_component_str = kfp_component_str.replace("\t", " ") return kfp_component_str + "\n" - def parse_components_to_file(self, components: List[ComponentProtocol], filename: str) -> None: - kfp_str = self.create_kfp_file_str(components) - self.write_to_file(filename, kfp_str) - def write_to_file(self, filename: str, file_content: str) -> None: for imp in self.add_imports: file_content = f"{imp}\n{file_content}" diff --git a/ml_orchestrator/comp_protocol/func_parser.py b/ml_orchestrator/comp_protocol/func_parser.py index 7b9f49c..80322e6 100644 --- a/ml_orchestrator/comp_protocol/func_parser.py +++ b/ml_orchestrator/comp_protocol/func_parser.py @@ -28,8 +28,10 @@ def create_function(self, component: ComponentProtocol) -> str: def get_function_parts(self, comp_class: Type[ComponentProtocol]) -> Tuple[str, str]: component_variables = self.comp_vars(comp_class) kfp_func_name = comp_class.__name__.lower() - func_scope = "(\n\t" + ",\n\t".join(self.get_func_params(component_variables)) + ",\n)" - comp_scope = "(\n\t\t" + ",\n\t\t".join(self.get_comp_params(component_variables)) + ",\n\t)" + func_params = self.get_func_params(component_variables) + comp_params = self.get_comp_params(component_variables) + func_scope = "(\n\t" + ",\n\t".join(func_params) + (",\n)" if func_params else ")") + comp_scope = "(\n\t\t" + ",\n\t\t".join(comp_params) + (",\n\t)" if comp_params else ")") return_type = "" if self.exe_return(comp_class) is not None: return_type = f" -> {self.exe_return(comp_class).__name__}" @@ -107,6 +109,10 @@ def create_kfp_file_str( file_content = f"{IMPORT_COMPOUND}\n\n\n{kfp_str}" return file_content + def parse_components_to_file(self, components: List[ComponentProtocol], filename: str) -> None: + kfp_str = self.create_kfp_file_str(components) + self.write_to_file(filename, kfp_str) + def write_to_file(self, filename: str, file_content: str) -> None: file_content = f"# flake8: noqa: F403, F405, B006\n{file_content}" with open(filename, "w", encoding="utf-8") as f: diff --git a/ml_orchestrator/comp_protocol/t_suites.py b/ml_orchestrator/comp_protocol/t_suites.py index fe8db65..70c23a9 100644 --- a/ml_orchestrator/comp_protocol/t_suites.py +++ b/ml_orchestrator/comp_protocol/t_suites.py @@ -18,8 +18,12 @@ def test_flows_protocol(self, comp_fixture: ComponentProtocol) -> None: def test_comp_protocol_attrs(self, comp_fixture: ComponentProtocol) -> None: fp = FunctionParser() comp_vars = fp.comp_vars(comp_fixture) # type: ignore - assert comp_vars - assert fp.get_func_params(comp_vars) + num_init_params = len(comp_fixture.__init__.__code__.co_varnames) - 1 # type: ignore + func_params = fp.get_func_params(comp_vars) + + assert comp_vars if num_init_params > 0 else not comp_vars + assert len(comp_vars) == num_init_params + assert len(func_params) == num_init_params def test_comp_protocol_e2e(self, comp_fixture: ComponentProtocol) -> None: fp = FunctionParser() diff --git a/pyproject.toml b/pyproject.toml index 5689c2e..662c8bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,8 +47,3 @@ check_untyped_defs = true warn_redundant_casts = true warn_unused_ignores = true strict_optional = false - - -[[tool.mypy.overrides]] -module = "tests.*" -ignore_errors = true