Skip to content

Commit

Permalink
Allow Artifacts in argument dictionaries
Browse files Browse the repository at this point in the history
Signed-off-by: Elliot Gunton <elliotgunton@gmail.com>
  • Loading branch information
elliotgunton committed Feb 28, 2025
1 parent 3e5df77 commit 3705be3
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 24 deletions.
54 changes: 30 additions & 24 deletions src/hera/workflows/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,34 +579,40 @@ def _build_arguments(self) -> Optional[ModelArguments]:
elif isinstance(self.arguments, ModelArguments):
return self.arguments

def add_argument(k: str, v: Any, result: ModelArguments):
if isinstance(v, Parameter):
value = v.with_name(k).as_argument()
result.parameters = [value] if result.parameters is None else result.parameters + [value]
elif isinstance(v, ModelParameter):
value = Parameter.from_model(v).as_argument()
value.name = k
result.parameters = [value] if result.parameters is None else result.parameters + [value]
elif isinstance(v, ModelArtifact):
copy_art = v.copy(deep=True)
copy_art.name = k
result.artifacts = [copy_art] if result.artifacts is None else result.artifacts + [copy_art]
elif isinstance(v, Artifact):
v = v.with_name(k)
result.artifacts = (
[v._build_artifact()] if result.artifacts is None else result.artifacts + [v._build_artifact()]
)
else:
# POD types are assumed to be parameters, which will be serialised upon creation
value = Parameter(name=k, value=v).as_argument()
result.parameters = [value] if result.parameters is None else result.parameters + [value]

result = ModelArguments()
for arg in self.arguments:
if isinstance(arg, dict):
for k, v in arg.items():
if isinstance(v, Parameter):
value = v.with_name(k).as_argument()
elif isinstance(v, ModelParameter):
value = Parameter.from_model(v).as_argument()
value.name = k
else:
value = Parameter(name=k, value=v).as_argument()

if result.parameters is None:
result.parameters = [value]
else:
result.parameters.append(value)
elif isinstance(arg, ModelArtifact):
result.artifacts = [arg] if result.artifacts is None else result.artifacts + [arg]
elif isinstance(arg, Artifact):
result.artifacts = (
[arg._build_artifact()] if result.artifacts is None else result.artifacts + [arg._build_artifact()]
)
elif isinstance(arg, Parameter):
result.parameters = (
[arg.as_argument()] if result.parameters is None else result.parameters + [arg.as_argument()]
)
elif isinstance(arg, ModelParameter):
result.parameters = [arg] if result.parameters is None else result.parameters + [arg]
add_argument(k, v, result)
elif isinstance(arg, (Parameter, ModelParameter, Artifact, ModelArtifact)):
# name can only be None for Parameters/Artifacts if they have not been
# "built" yet (see the `_check_name` function)
add_argument(arg.name or "", arg, result)
else:
raise ValueError(f"Invalid argument type {type(arg)}")

Check warning on line 614 in src/hera/workflows/_mixins.py

View check run for this annotation

Codecov / codecov/patch

src/hera/workflows/_mixins.py#L614

Added line #L614 was not covered by tests

# returning `None` for `Arguments` means the submission to the server will not even have the
# `arguments` field set, which saves some payload
if result.parameters is None and result.artifacts is None:
Expand Down
206 changes: 206 additions & 0 deletions tests/test_unit/test_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import pytest

from hera.workflows._mixins import ArgumentsMixin
from hera.workflows.artifact import Artifact
from hera.workflows.models import (
Arguments as ModelArguments,
Artifact as ModelArtifact,
Parameter as ModelParameter,
)
from hera.workflows.parameter import Parameter
Expand Down Expand Up @@ -98,3 +100,207 @@ def test_argument_parameter_build(arguments, expected_built_arguments):
)._build_arguments()
== expected_built_arguments
)


model_artifact = ModelArtifact(name="artifact-name", from_="somewhere")


@pytest.mark.parametrize(
"arguments,expected_built_arguments",
(
pytest.param(
ModelArtifact(
name="artifact-name",
from_="somewhere",
),
ModelArguments(
artifacts=[
ModelArtifact(
name="artifact-name",
from_="somewhere",
)
],
),
id="single-model-artifact",
),
pytest.param(
[
ModelArtifact(
name="artifact-name-1",
from_="somewhere",
),
ModelArtifact(
name="artifact-name-2",
from_="somewhere",
),
],
ModelArguments(
artifacts=[
ModelArtifact(
name="artifact-name-1",
from_="somewhere",
),
ModelArtifact(
name="artifact-name-2",
from_="somewhere",
),
],
),
id="list-model-artifacts",
),
pytest.param(
{"a-key": ModelArtifact(name="artifact-name", from_="somewhere")},
ModelArguments(
artifacts=[
ModelArtifact(
name="a-key",
from_="somewhere",
)
],
),
id="model-artifact-in-dict",
),
pytest.param(
[
{
"a-key": model_artifact,
},
model_artifact,
],
ModelArguments(
artifacts=[
ModelArtifact(
name="a-key",
from_="somewhere",
),
ModelArtifact(
name="artifact-name",
from_="somewhere",
),
],
),
id="do-not-rename-original-artifact-object",
),
pytest.param(
{"a-key": Artifact(name="artifact-name", from_="somewhere").with_name("ignore-me")},
ModelArguments(
artifacts=[
ModelArtifact(
name="a-key",
from_="somewhere",
)
],
),
id="hera-artifact-ignore-alt-name",
),
),
)
def test_argument_artifact_build(arguments, expected_built_arguments):
assert (
ArgumentsMixin(
arguments=arguments,
)._build_arguments()
== expected_built_arguments
)


@pytest.mark.parametrize(
"arguments,expected_built_arguments",
(
pytest.param(
None,
None,
id="no-arguments",
),
pytest.param(
ModelArguments(),
ModelArguments(),
id="model-arguments",
),
pytest.param(
[Parameter(name="param-name", value="a-value"), Artifact(name="artifact-name", from_="somewhere")],
ModelArguments(
parameters=[
ModelParameter(
name="param-name",
value="a-value",
)
],
artifacts=[
ModelArtifact(
name="artifact-name",
from_="somewhere",
)
],
),
id="mixed-list",
),
pytest.param(
{"param-name": "a-value", "a-key": ModelArtifact(name="artifact-name", from_="somewhere")},
ModelArguments(
parameters=[
ModelParameter(
name="param-name",
value="a-value",
)
],
artifacts=[
ModelArtifact(
name="a-key",
from_="somewhere",
)
],
),
id="mixed-dict",
),
pytest.param(
[
{"param-name": "a-value"},
ModelArtifact(name="artifact-name", from_="somewhere"),
],
ModelArguments(
parameters=[
ModelParameter(
name="param-name",
value="a-value",
)
],
artifacts=[
ModelArtifact(
name="artifact-name",
from_="somewhere",
)
],
),
id="param-dict-in-list",
),
pytest.param(
[
{"param-name": "a-value"},
{"a-key": Artifact(name="artifact-name", from_="somewhere").with_name("ignore-me")},
],
ModelArguments(
parameters=[
ModelParameter(
name="param-name",
value="a-value",
)
],
artifacts=[
ModelArtifact(
name="a-key",
from_="somewhere",
)
],
),
id="multiple-dicts-in-list",
),
),
)
def test_mixed_arguments_build(arguments, expected_built_arguments):
assert (
ArgumentsMixin(
arguments=arguments,
)._build_arguments()
== expected_built_arguments
)

0 comments on commit 3705be3

Please # to comment.