diff --git a/.gitignore b/.gitignore index a8b996c0..3aa9dd8a 100644 --- a/.gitignore +++ b/.gitignore @@ -135,4 +135,4 @@ dmypy.json tests/work /tests/**/*.png /tests/**/*txt -.vscode +.vscode/ diff --git a/aiida_workgraph/task.py b/aiida_workgraph/task.py index cc4e421b..326f2b50 100644 --- a/aiida_workgraph/task.py +++ b/aiida_workgraph/task.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from node_graph.node import Node as GraphNode from aiida_workgraph import USE_WIDGET from aiida_workgraph.properties import property_pool @@ -56,6 +58,7 @@ def __init__( self._widget = None self.state = "PLANNED" self.action = "" + self.show_socket_depth = 0 def to_dict(self, short: bool = False) -> Dict[str, Any]: from aiida.orm.utils.serialize import serialize @@ -174,18 +177,22 @@ def _repr_mimebundle_(self, *args: Any, **kwargs: Any) -> any: print(WIDGET_INSTALLATION_MESSAGE) return # if ipywdigets > 8.0.0, use _repr_mimebundle_ instead of _ipython_display_ - self._widget.from_node(self) + self._widget.from_node(self, show_socket_depth=self.show_socket_depth) if hasattr(self._widget, "_repr_mimebundle_"): return self._widget._repr_mimebundle_(*args, **kwargs) else: return self._widget._ipython_display_(*args, **kwargs) - def to_html(self, output: str = None, **kwargs): + def to_html( + self, output: str = None, show_socket_depth: Optional[int] = None, **kwargs + ): """Write a standalone html file to visualize the task.""" + if show_socket_depth is None: + show_socket_depth = self.show_socket_depth if self._widget is None: print(WIDGET_INSTALLATION_MESSAGE) return - self._widget.from_node(self) + self._widget.from_node(node=self, show_socket_depth=show_socket_depth) return self._widget.to_html(output=output, **kwargs) diff --git a/aiida_workgraph/utils/__init__.py b/aiida_workgraph/utils/__init__.py index f57d4c2f..bee8cf6a 100644 --- a/aiida_workgraph/utils/__init__.py +++ b/aiida_workgraph/utils/__init__.py @@ -651,3 +651,25 @@ def validate_task_inout(inout_list: list[str | dict], list_type: str) -> list[di processed_inout_list.append(item) return processed_inout_list + + +def filter_keys_namespace_depth( + dict_: dict[Any, Any], max_depth: int = 0 +) -> dict[Any, Any]: + """ + Filter top-level keys of a dictionary based on the namespace nesting level (number of periods) in the key. + + :param dict dict_: The dictionary to filter. + :param int max_depth: Maximum depth of namespaces to retain (number of periods). + :return: The filtered dictionary with only keys satisfying the depth condition. + :rtype: dict + """ + result: dict[Any, Any] = {} + + for key, value in dict_.items(): + depth = key.count(".") + + if depth <= max_depth: + result[key] = value + + return result diff --git a/aiida_workgraph/widget/src/widget/__init__.py b/aiida_workgraph/widget/src/widget/__init__.py index 65a1405b..a5908534 100644 --- a/aiida_workgraph/widget/src/widget/__init__.py +++ b/aiida_workgraph/widget/src/widget/__init__.py @@ -5,6 +5,7 @@ import anywidget import traitlets from .utils import wait_to_link +from aiida_workgraph.utils import filter_keys_namespace_depth try: __version__ = importlib.metadata.version("widget") @@ -37,16 +38,22 @@ def from_workgraph(self, workgraph: Any) -> None: wgdata = workgraph_to_short_json(wgdata) self.value = wgdata - def from_node(self, node: Any) -> None: + def from_node(self, node: Any, show_socket_depth: int = 0) -> None: + tdata = node.to_dict() - tdata.pop("properties", None) - tdata.pop("executor", None) - tdata.pop("node_class", None) - tdata.pop("process", None) - tdata["label"] = tdata["identifier"] + + # Remove certain elements of the dict-representation of the Node that we don't want to show + for key in ("properties", "executor", "node_class", "process"): + tdata.pop(key, None) for input in tdata["inputs"].values(): input.pop("property") - tdata["inputs"] = list(tdata["inputs"].values()) + + tdata["label"] = tdata["identifier"] + + filtered_inputs = filter_keys_namespace_depth( + dict_=tdata["inputs"], max_depth=show_socket_depth + ) + tdata["inputs"] = list(filtered_inputs.values()) tdata["outputs"] = list(tdata["outputs"].values()) wgdata = {"name": node.name, "nodes": {node.name: tdata}, "links": []} self.value = wgdata