Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Type annotations for main classes #705

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions supervised/algorithms/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import uuid
from typing import Union

import numpy as np
import pandas as pd

from supervised.utils.common import construct_learner_name
from supervised.utils.importance import PermutationImportance
Expand All @@ -16,15 +18,15 @@ class BaseAlgorithm:
algorithm_name = "Unknown"
algorithm_short_name = "Unknown"

def __init__(self, params):
self.params = params
self.stop_training = False
self.library_version = None
self.model = None
self.uid = params.get("uid", str(uuid.uuid4()))
self.ml_task = params.get("ml_task")
self.model_file_path = None
self.name = "amazing_learner"
def __init__(self, params: dict):
self.params: dict = params
self.stop_training: bool = False
self.library_version: str = None
self.model: object = None
self.uid: str = params.get("uid", str(uuid.uuid4()))
self.ml_task: str = params.get("ml_task")
self.model_file_path: str = None
self.name: str = "amazing_learner"

def set_learner_name(self, fold, repeat, repeats):
self.name = construct_learner_name(fold, repeat, repeats)
Expand All @@ -39,8 +41,8 @@ def reload(self):

def fit(
self,
X,
y,
X: Union[np.ndarray, pd.DataFrame],
y: Union[np.ndarray, pd.Series],
sample_weight=None,
X_validation=None,
y_validation=None,
Expand Down
6 changes: 3 additions & 3 deletions supervised/algorithms/baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class BaselineClassifierAlgorithm(SklearnAlgorithm, ClassifierMixin):
algorithm_name = "Baseline Classifier"
algorithm_short_name = "Baseline"

def __init__(self, params):
def __init__(self, params: dict):
super(BaselineClassifierAlgorithm, self).__init__(params)
logger.debug("BaselineClassifierAlgorithm.__init__")

self.library_version = sklearn.__version__
self.max_iters = additional.get("max_steps", 1)
self.library_version: str = sklearn.__version__
self.max_iters: int = additional.get("max_steps", 1)
self.model = DummyClassifier(
strategy="prior", random_state=params.get("seed", 1)
)
Expand Down
5 changes: 3 additions & 2 deletions supervised/algorithms/factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging

from supervised.algorithms.algorithm import BaseAlgorithm
from supervised.algorithms.registry import BINARY_CLASSIFICATION, AlgorithmsRegistry

logger = logging.getLogger(__name__)
Expand All @@ -9,7 +10,7 @@

class AlgorithmFactory(object):
@classmethod
def get_algorithm(cls, params):
def get_algorithm(cls, params: dict) -> BaseAlgorithm:
alg_type = params.get("model_type", "Xgboost")
ml_task = params.get("ml_task", BINARY_CLASSIFICATION)

Expand All @@ -20,7 +21,7 @@ def get_algorithm(cls, params):
raise AutoMLException(f"Cannot get algorithm class. {str(e)}")

@classmethod
def load(cls, json_desc, learner_path, lazy_load):
def load(cls, json_desc: dict, learner_path: str, lazy_load: bool) -> BaseAlgorithm:
learner = AlgorithmFactory.get_algorithm(json_desc.get("params"))
learner.set_params(json_desc, learner_path)
if not lazy_load:
Expand Down
30 changes: 17 additions & 13 deletions supervised/algorithms/registry.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# tasks that can be handled by the package

from typing import List, Type

BINARY_CLASSIFICATION = "binary_classification"
MULTICLASS_CLASSIFICATION = "multiclass_classification"
REGRESSION = "regression"

class AlgorithmsRegistry:
from supervised.algorithms.algorithm import BaseAlgorithm
registry = {
BINARY_CLASSIFICATION: {},
MULTICLASS_CLASSIFICATION: {},
Expand All @@ -12,13 +16,13 @@ class AlgorithmsRegistry:

@staticmethod
def add(
task_name,
model_class,
model_params,
required_preprocessing,
additional,
default_params,
):
task_name: str,
model_class: Type[BaseAlgorithm],
model_params: dict,
required_preprocessing: list,
additional: dict,
default_params: dict,
) -> None:
model_information = {
"class": model_class,
"params": model_params,
Expand All @@ -31,33 +35,33 @@ def add(
] = model_information

@staticmethod
def get_supported_ml_tasks():
def get_supported_ml_tasks() -> List[str]:
return AlgorithmsRegistry.registry.keys()

@staticmethod
def get_algorithm_class(ml_task, algorithm_name):
def get_algorithm_class(ml_task: str, algorithm_name: str) -> Type[BaseAlgorithm]:
return AlgorithmsRegistry.registry[ml_task][algorithm_name]["class"]

@staticmethod
def get_long_name(ml_task, algorithm_name):
def get_long_name(ml_task: str, algorithm_name: str) -> str:
return AlgorithmsRegistry.registry[ml_task][algorithm_name][
"class"
].algorithm_name

@staticmethod
def get_max_rows_limit(ml_task, algorithm_name):
def get_max_rows_limit(ml_task: str, algorithm_name: str) -> int:
return AlgorithmsRegistry.registry[ml_task][algorithm_name]["additional"][
"max_rows_limit"
]

@staticmethod
def get_max_cols_limit(ml_task, algorithm_name):
def get_max_cols_limit(ml_task: str, algorithm_name: str) -> int:
return AlgorithmsRegistry.registry[ml_task][algorithm_name]["additional"][
"max_cols_limit"
]

@staticmethod
def get_eval_metric(algorithm_name, ml_task, automl_eval_metric):
def get_eval_metric(ml_task: str, algorithm_name: str, automl_eval_metric: str):
if algorithm_name == "Xgboost":
return xgboost_eval_metric(ml_task, automl_eval_metric)

Expand Down
28 changes: 17 additions & 11 deletions supervised/callbacks/callback.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
from typing import List

from supervised.algorithms.algorithm import BaseAlgorithm


class Callback(object):
def __init__(self, params):
self.params = params
self.learners = []
self.learner = None # current learner
self.name = "callback"

def add_and_set_learner(self, learner):
def __init__(self, params: dict):
self.params: dict = params
self.learners: List[BaseAlgorithm] = []
self.learner: BaseAlgorithm = None # current learner
self.name: str = "callback"

def add_and_set_learner(self, learner: BaseAlgorithm):
self.learners += [learner]
self.learner = learner

def on_learner_train_start(self, logs):
def on_learner_train_start(self, logs: dict) -> None:
pass

def on_learner_train_end(self, logs):
def on_learner_train_end(self, logs: dict) -> None:
pass

def on_iteration_start(self, logs):
def on_iteration_start(self, logs: dict) -> None:
pass

def on_iteration_end(self, logs, predictions):
def on_iteration_end(self, logs: dict, predictions: dict) -> None:
pass

def on_framework_train_end(self, logs):
def on_framework_train_end(self, logs: dict) -> None:
pass
25 changes: 16 additions & 9 deletions supervised/callbacks/callback_list.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,39 @@
from typing import List

from supervised.algorithms.algorithm import BaseAlgorithm
from supervised.callbacks.callback import Callback


class CallbackList(object):
def __init__(self, callbacks=[]):
self.callbacks = callbacks

def add_and_set_learner(self, learner):
def __init__(self, callbacks: List[Callback] = []):
self.callbacks: List[Callback] = callbacks

def add_and_set_learner(self, learner: BaseAlgorithm) -> None:
for cb in self.callbacks:
cb.add_and_set_learner(learner)

def on_learner_train_start(self, logs=None):
def on_learner_train_start(self, logs: dict = None) -> None:
for cb in self.callbacks:
cb.on_learner_train_start(logs)

def on_learner_train_end(self, logs=None):
def on_learner_train_end(self, logs: dict = None) -> None:
for cb in self.callbacks:
cb.on_learner_train_end(logs)

def on_iteration_start(self, logs=None):
def on_iteration_start(self, logs: dict = None) -> None:
for cb in self.callbacks:
cb.on_iteration_start(logs)

def on_iteration_end(self, logs=None, predictions=None):
def on_iteration_end(self, logs: dict = None, predictions: dict = None) -> None:
for cb in self.callbacks:
cb.on_iteration_end(logs, predictions)

def on_framework_train_end(self, logs=None):
def on_framework_train_end(self, logs: dict = None) -> None:
for cb in self.callbacks:
cb.on_framework_train_end(logs)

def get(self, callback_name):
def get(self, callback_name: str) -> Callback:
for cb in self.callbacks:
if cb.name == callback_name:
return cb
Expand Down
6 changes: 3 additions & 3 deletions supervised/model_framework.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import copy
import gc
import json
import logging
import os
import time
import uuid

import gc
import numpy as np
import pandas as pd
import time

from supervised.algorithms.factory import AlgorithmFactory
from supervised.algorithms.registry import (
Expand Down Expand Up @@ -100,7 +100,7 @@ def predictions(
y_validation,
sample_weight_validation,
sensitive_features_validation,
):
) -> dict:
y_train_true = y_train
y_train_predicted = learner.predict(X_train)
y_validation_true = y_validation
Expand Down
9 changes: 6 additions & 3 deletions supervised/validation/validation_step.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import logging
from typing import Tuple

import pandas as pd

log = logging.getLogger(__name__)

Expand All @@ -24,14 +27,14 @@ def __init__(self, params):
f"The validation type ({self.validation_type}) is not implemented."
)

def get_split(self, k, repeat=0):
def get_split(self, k: int, repeat: int = 0) -> Tuple[pd.DataFrame, pd.DataFrame]:
return self.validator.get_split(k, repeat)

def split(self):
return self.validator.split()

def get_n_splits(self):
def get_n_splits(self) -> int:
return self.validator.get_n_splits()

def get_repeats(self):
def get_repeats(self) -> int:
return self.validator.get_repeats()
8 changes: 4 additions & 4 deletions supervised/validation/validator_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@


class BaseValidator(object):
def __init__(self, params):
self.params = params
def __init__(self, params: dict):
self.params: dict = params

def split(self):
pass

def get_n_splits(self):
def get_n_splits(self) -> int:
pass

def get_repeats(self):
def get_repeats(self) -> int:
return 1
Loading