Skip to content

ENG-1235 Utility functions has now special treatment as they need dyn… #344

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 5 commits into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions aixplain/factories/model_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def create_model_from_response(response: Dict) -> Model:
for param in response["params"]:
if "language" in param["name"]:
parameters[param["name"]] = [w["value"] for w in param["values"]]
else:
values = [w["value"] for w in param["defaultValues"]]
if len(values) > 0:
parameters[param["name"]] = values

function_id = response["function"]["id"]
function = Function(function_id)
Expand Down
82 changes: 59 additions & 23 deletions aixplain/modules/pipeline/designer/nodes.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
from typing import List, Union, Type, TYPE_CHECKING, Optional

from aixplain.modules import Model
from aixplain.enums import DataType

from .enums import (
NodeType,
FunctionType,
RouteType,
Operation,
AssetType,
)
from aixplain.enums import DataType, Function

from .enums import NodeType, FunctionType, RouteType, Operation, AssetType
from .base import (
Node,
Link,
Expand Down Expand Up @@ -85,7 +79,15 @@ def populate_asset(self):

if self.function:
if self.asset.function.value != self.function:
raise ValueError(f"Function {self.function} is not supported by asset {self.asset_id}") # noqa
raise ValueError(
f"Function {self.function} is not supported by asset {self.asset_id}"
)

# Despite function field has been set, we should still dynamically
# populate parameters for Utility functions
if self.function == Function.UTILITIES:
self._auto_populate_params()

else:
self.function = self.asset.function.value
self._auto_populate_params()
Expand All @@ -95,13 +97,24 @@ def populate_asset(self):
def _auto_populate_params(self):
from aixplain.enums.function import FunctionInputOutput

spec = FunctionInputOutput[self.asset.function.value]["spec"]
for item in spec["params"]:
self.inputs.create_param(
code=item["code"],
data_type=item["dataType"],
is_required=item["required"],
)
spec = FunctionInputOutput[self.function]["spec"]

# When the node is a utility, we need to create it's input parameters
# dynamically by referring the node data.
if self.function == Function.UTILITIES:
for param in self.asset.input_params.values():
self.inputs.create_param(
code=param["name"],
data_type=param["dataType"],
is_required=param["required"],
)
else:
for item in spec["params"]:
self.inputs.create_param(
code=item["code"],
data_type=item["dataType"],
is_required=item["required"],
)

for item in spec["output"]:
self.outputs.create_param(
Expand All @@ -111,6 +124,9 @@ def _auto_populate_params(self):

def _auto_set_params(self):
for k, v in self.asset.additional_info["parameters"].items():
if k not in self.inputs:
continue

if isinstance(v, list):
self.inputs[k] = v[0]
else:
Expand Down Expand Up @@ -140,6 +156,11 @@ class BareAsset(AssetNode[BareAssetInputs, BareAssetOutputs]):
pass


class Utility(AssetNode[BareAssetInputs, BareAssetOutputs]):

function = "utilities"


class InputInputs(Inputs):
pass

Expand Down Expand Up @@ -217,7 +238,12 @@ class Output(Node[OutputInputs, OutputOutputs]):
inputs_class: Type[TI] = OutputInputs
outputs_class: Type[TO] = OutputOutputs

def __init__(self, data_types: Optional[List[DataType]] = None, pipeline: "DesignerPipeline" = None, **kwargs):
def __init__(
self,
data_types: Optional[List[DataType]] = None,
pipeline: "DesignerPipeline" = None,
**kwargs
):
super().__init__(pipeline=pipeline, **kwargs)
self.data_types = data_types or []

Expand Down Expand Up @@ -278,7 +304,14 @@ class Route(Serializable):
operation: Operation
type: RouteType

def __init__(self, value: DataType, path: List[Union[Node, int]], operation: Operation, type: RouteType, **kwargs):
def __init__(
self,
value: DataType,
path: List[Union[Node, int]],
operation: Operation,
type: RouteType,
**kwargs
):
"""
Post init method to convert the nodes to node numbers if they are
nodes.
Expand All @@ -294,8 +327,7 @@ def __init__(self, value: DataType, path: List[Union[Node, int]], operation: Ope

# convert nodes to node numbers if they are nodes
self.path = [
node.number if isinstance(node, Node) else node
for node in self.path
node.number if isinstance(node, Node) else node for node in self.path
]

def serialize(self) -> dict:
Expand Down Expand Up @@ -334,7 +366,9 @@ class Router(Node[RouterInputs, RouterOutputs], LinkableMixin):
inputs_class: Type[TI] = RouterInputs
outputs_class: Type[TO] = RouterOutputs

def __init__(self, routes: List[Route], pipeline: "DesignerPipeline" = None, **kwargs):
def __init__(
self, routes: List[Route], pipeline: "DesignerPipeline" = None, **kwargs
):
super().__init__(pipeline=pipeline, **kwargs)
self.routes = routes

Expand Down Expand Up @@ -373,7 +407,9 @@ class Decision(Node[DecisionInputs, DecisionOutputs], LinkableMixin):
inputs_class: Type[TI] = DecisionInputs
outputs_class: Type[TO] = DecisionOutputs

def __init__(self, routes: List[Route], pipeline: "DesignerPipeline" = None, **kwargs):
def __init__(
self, routes: List[Route], pipeline: "DesignerPipeline" = None, **kwargs
):
super().__init__(pipeline=pipeline, **kwargs)
self.routes = routes

Expand Down
48 changes: 43 additions & 5 deletions aixplain/modules/pipeline/designer/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,19 @@
from aixplain.enums import DataType

from .base import Serializable, Node, Link
from .nodes import AssetNode, Decision, Script, Input, Output, Router, Route, BareReconstructor, BareSegmentor, BareMetric
from .nodes import (
AssetNode,
Utility,
Decision,
Script,
Input,
Output,
Router,
Route,
BareReconstructor,
BareSegmentor,
BareMetric,
)
from .enums import NodeType, RouteType, Operation
from .mixins import OutputableMixin
from .utils import find_prompt_params
Expand Down Expand Up @@ -141,7 +153,9 @@ def special_prompt_validation(self, node: Node):
node.inputs.text.is_required = False
for match in matches:
if match not in node.inputs:
raise ValueError(f"Param {match} of node {node.label} should be defined and set")
raise ValueError(
f"Param {match} of node {node.label} should be defined and set"
)

def validate_params(self):
"""
Expand All @@ -153,7 +167,9 @@ def validate_params(self):
self.special_prompt_validation(node)
for param in node.inputs:
if param.is_required and not self.is_param_set(node, param):
raise ValueError(f"Param {param.code} of node {node.label} is required")
raise ValueError(
f"Param {param.code} of node {node.label} is required"
)

def validate(self):
"""
Expand All @@ -179,7 +195,11 @@ def get_link(self, from_node: int, to_node: int) -> Link:
:return: the link
"""
return next(
(link for link in self.links if link.from_node == from_node and link.to_node == to_node),
(
link
for link in self.links
if link.from_node == from_node and link.to_node == to_node
),
None,
)

Expand Down Expand Up @@ -225,7 +245,9 @@ def infer_data_type(node):
infer_data_type(self)
infer_data_type(to_node)

def asset(self, asset_id: str, *args, asset_class: Type[T] = AssetNode, **kwargs) -> T:
def asset(
self, asset_id: str, *args, asset_class: Type[T] = AssetNode, **kwargs
) -> T:
"""
Shortcut to create an asset node for the current pipeline.
All params will be passed as keyword arguments to the node
Expand All @@ -236,6 +258,22 @@ def asset(self, asset_id: str, *args, asset_class: Type[T] = AssetNode, **kwargs
"""
return asset_class(asset_id, *args, pipeline=self, **kwargs)

def utility(
self, asset_id: str, *args, asset_class: Type[T] = Utility, **kwargs
) -> T:
"""
Shortcut to create an utility nodes for the current pipeline.
All params will be passed as keyword arguments to the node
constructor.

:param kwargs: keyword arguments
:return: the node
"""
if not issubclass(asset_class, Utility):
raise ValueError("`asset_class` should be a subclass of `Utility` class")

return asset_class(asset_id, *args, pipeline=self, **kwargs)

def decision(self, *args, **kwargs) -> Decision:
"""
Shortcut to create an decision node for the current pipeline.
Expand Down
14 changes: 12 additions & 2 deletions aixplain/modules/pipeline/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from jinja2 import Environment, BaseLoader

from aixplain.utils import config
from aixplain.enums import Function

SEGMENTOR_FUNCTIONS = [
"split-on-linebreak",
Expand Down Expand Up @@ -143,17 +144,26 @@ def populate_specs(functions: list):
"""
function_class_specs = []
for function in functions:
# Utility functions has dynamic input parameters so they are not
# subject to static class generation
if function["id"] == Function.UTILITIES:
continue

# slugify function name by trimming some special chars and
# transforming it to snake case
function_name = function["id"].replace("-", "_").replace("(", "_").replace(")", "_")
function_name = (
function["id"].replace("-", "_").replace("(", "_").replace(")", "_")
)
base_class = "AssetNode"
is_segmentor = function["id"] in SEGMENTOR_FUNCTIONS
is_reconstructor = function["id"] in RECONSTRUCTOR_FUNCTIONS
if is_segmentor:
base_class = "BaseSegmentor"
elif is_reconstructor:
base_class = "BaseReconstructor"
elif "metric" in function_name.split("_"): # noqa: Advise a better distinguisher please
elif "metric" in function_name.split(
"_"
): # noqa: Advise a better distinguisher please
base_class = "BaseMetric"

spec = {
Expand Down
Loading