Skip to content

Shorten text repr for DataTree #10139

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
44 changes: 40 additions & 4 deletions xarray/core/datatree_render.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from collections import namedtuple
from collections.abc import Iterable, Iterator
from math import ceil
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(
style=None,
childiter: type = list,
maxlevel: int | None = None,
maxchildren: int | None = None,
):
"""
Render tree starting at `node`.
Expand All @@ -88,6 +90,7 @@ def __init__(
Iterables that change the order of children cannot be used
(e.g., `reversed`).
maxlevel: Limit rendering to this depth.
maxchildren: Limit number of children at each node.
:any:`RenderDataTree` is an iterator, returning a tuple with 3 items:
`pre`
tree prefix.
Expand Down Expand Up @@ -160,6 +163,16 @@ def __init__(
root
├── sub0
└── sub1

# `maxchildren` limits the number of children per node

>>> print(RenderDataTree(root, maxchildren=1).by_attr("name"))
root
├── sub0
│ ├── sub0B
│ ...
...

"""
if style is None:
style = ContStyle()
Expand All @@ -169,24 +182,44 @@ def __init__(
self.style = style
self.childiter = childiter
self.maxlevel = maxlevel
self.maxchildren = maxchildren

def __iter__(self) -> Iterator[Row]:
return self.__next(self.node, tuple())

def __next(
self, node: DataTree, continues: tuple[bool, ...], level: int = 0
self,
node: DataTree,
continues: tuple[bool, ...],
level: int = 0,
) -> Iterator[Row]:
yield RenderDataTree.__item(node, continues, self.style)
children = node.children.values()
level += 1
if children and (self.maxlevel is None or level < self.maxlevel):
nchildren = len(children)
children = self.childiter(children)
for child, is_last in _is_last(children):
yield from self.__next(child, continues + (not is_last,), level=level)
for i, (child, is_last) in enumerate(_is_last(children)):
if (
self.maxchildren is None
or i < ceil(self.maxchildren / 2)
or i >= ceil(nchildren - self.maxchildren / 2)
):
yield from self.__next(
child,
continues + (not is_last,),
level=level,
)
if (
self.maxchildren is not None
and nchildren > self.maxchildren
and i == ceil(self.maxchildren / 2)
):
yield RenderDataTree.__item("...", continues, self.style)

@staticmethod
def __item(
node: DataTree, continues: tuple[bool, ...], style: AbstractStyle
node: DataTree | str, continues: tuple[bool, ...], style: AbstractStyle
) -> Row:
if not continues:
return Row("", "", node)
Expand Down Expand Up @@ -244,6 +277,9 @@ def by_attr(self, attrname: str = "name") -> str:

def get() -> Iterator[str]:
for pre, fill, node in self:
if isinstance(node, str):
yield f"{fill}{node}"
continue
attr = (
attrname(node)
if callable(attrname)
Expand Down
9 changes: 8 additions & 1 deletion xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,14 +1137,21 @@ def _datatree_node_repr(node: DataTree, show_inherited: bool) -> str:

def datatree_repr(dt: DataTree) -> str:
"""A printable representation of the structure of this entire tree."""
renderer = RenderDataTree(dt)
max_children = OPTIONS["display_max_children"]

renderer = RenderDataTree(dt, maxchildren=max_children)

name_info = "" if dt.name is None else f" {dt.name!r}"
header = f"<xarray.DataTree{name_info}>"

lines = [header]
show_inherited = True

for pre, fill, node in renderer:
if isinstance(node, str):
lines.append(f"{fill}{node}")
continue

node_repr = _datatree_node_repr(node, show_inherited=show_inherited)
show_inherited = False # only show inherited coords on the root

Expand Down
50 changes: 31 additions & 19 deletions xarray/core/formatting_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@
from functools import lru_cache, partial
from html import escape
from importlib.resources import files
from typing import TYPE_CHECKING
from math import ceil
from typing import TYPE_CHECKING, Literal

from xarray.core.formatting import (
inherited_vars,
inline_index_repr,
inline_variable_array_repr,
short_data_repr,
)
from xarray.core.options import _get_boolean_with_default
from xarray.core.options import OPTIONS, _get_boolean_with_default

STATIC_FILES = (
("xarray.static.html", "icons-svg-inline.html"),
Expand Down Expand Up @@ -192,16 +193,29 @@ def collapsible_section(


def _mapping_section(
mapping, name, details_func, max_items_collapse, expand_option_name, enabled=True
mapping,
name,
details_func,
max_items_collapse,
expand_option_name,
enabled=True,
max_option_name: Literal["display_max_children"] | None = None,
) -> str:
n_items = len(mapping)
expanded = _get_boolean_with_default(
expand_option_name, n_items < max_items_collapse
)
collapsed = not expanded

inline_details = ""
if max_option_name and max_option_name in OPTIONS:
max_items = int(OPTIONS[max_option_name])
if n_items > max_items:
inline_details = f"({max_items}/{n_items})"

return collapsible_section(
name,
inline_details=inline_details,
details=details_func(mapping),
n_items=n_items,
enabled=enabled,
Expand Down Expand Up @@ -348,26 +362,23 @@ def dataset_repr(ds) -> str:


def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
N_CHILDREN = len(children) - 1

# Get result from datatree_node_repr and wrap it
lines_callback = lambda n, c, end: _wrap_datatree_repr(
datatree_node_repr(n, c), end=end
)

children_html = "".join(
(
lines_callback(n, c, end=False) # Long lines
if i < N_CHILDREN
else lines_callback(n, c, end=True)
) # Short lines
for i, (n, c) in enumerate(children.items())
)
MAX_CHILDREN = OPTIONS["display_max_children"]
n_children = len(children)

children_html = []
for i, (n, c) in enumerate(children.items()):
if i < ceil(MAX_CHILDREN / 2) or i >= ceil(n_children - MAX_CHILDREN / 2):
is_last = i == (n_children - 1)
children_html.append(
_wrap_datatree_repr(datatree_node_repr(n, c), end=is_last)
)
elif n_children > MAX_CHILDREN and i == ceil(MAX_CHILDREN / 2):
children_html.append("<div>...</div>")

return "".join(
[
"<div style='display: inline-grid; grid-template-columns: 100%; grid-column: 1 / -1'>",
children_html,
"".join(children_html),
"</div>",
]
)
Expand All @@ -378,6 +389,7 @@ def summarize_datatree_children(children: Mapping[str, DataTree]) -> str:
name="Groups",
details_func=summarize_datatree_children,
max_items_collapse=1,
max_option_name="display_max_children",
expand_option_name="display_expand_groups",
)

Expand Down
6 changes: 6 additions & 0 deletions xarray/core/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"chunk_manager",
"cmap_divergent",
"cmap_sequential",
"display_max_children",
"display_max_rows",
"display_values_threshold",
"display_style",
Expand Down Expand Up @@ -40,6 +41,7 @@ class T_Options(TypedDict):
chunk_manager: str
cmap_divergent: str | Colormap
cmap_sequential: str | Colormap
display_max_children: int
display_max_rows: int
display_values_threshold: int
display_style: Literal["text", "html"]
Expand Down Expand Up @@ -67,6 +69,7 @@ class T_Options(TypedDict):
"chunk_manager": "dask",
"cmap_divergent": "RdBu_r",
"cmap_sequential": "viridis",
"display_max_children": 6,
"display_max_rows": 12,
"display_values_threshold": 200,
"display_style": "html",
Expand Down Expand Up @@ -99,6 +102,7 @@ def _positive_integer(value: Any) -> bool:
_VALIDATORS = {
"arithmetic_broadcast": lambda value: isinstance(value, bool),
"arithmetic_join": _JOIN_OPTIONS.__contains__,
"display_max_children": _positive_integer,
"display_max_rows": _positive_integer,
"display_values_threshold": _positive_integer,
"display_style": _DISPLAY_OPTIONS.__contains__,
Expand Down Expand Up @@ -222,6 +226,8 @@ class set_options:
* ``True`` : to always expand indexes
* ``False`` : to always collapse indexes
* ``default`` : to expand unless over a pre-defined limit (always collapse for html style)
display_max_children : int, default: 6
Maximum number of children to display for each node in a DataTree.
display_max_rows : int, default: 12
Maximum display rows.
display_values_threshold : int, default: 200
Expand Down
70 changes: 70 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,76 @@ def test_repr_two_children(self) -> None:
).strip()
assert result == expected

def test_repr_truncates_nodes(self) -> None:
# construct a datatree with 50 nodes
number_of_files = 10
number_of_groups = 5
tree_dict = {}
for f in range(number_of_files):
for g in range(number_of_groups):
tree_dict[f"file_{f}/group_{g}"] = Dataset({"g": f * g})

tree = DataTree.from_dict(tree_dict)
with xr.set_options(display_max_children=3):
result = repr(tree)

expected = dedent(
"""
<xarray.DataTree>
Group: /
├── Group: /file_0
│ ├── Group: /file_0/group_0
│ │ Dimensions: ()
│ │ Data variables:
│ │ g int64 8B 0
│ ├── Group: /file_0/group_1
│ │ Dimensions: ()
│ │ Data variables:
│ │ g int64 8B 0
│ ...
│ └── Group: /file_0/group_4
│ Dimensions: ()
│ Data variables:
│ g int64 8B 0
├── Group: /file_1
│ ├── Group: /file_1/group_0
│ │ Dimensions: ()
│ │ Data variables:
│ │ g int64 8B 0
│ ├── Group: /file_1/group_1
│ │ Dimensions: ()
│ │ Data variables:
│ │ g int64 8B 1
│ ...
│ └── Group: /file_1/group_4
│ Dimensions: ()
│ Data variables:
│ g int64 8B 4
...
└── Group: /file_9
├── Group: /file_9/group_0
│ Dimensions: ()
│ Data variables:
│ g int64 8B 0
├── Group: /file_9/group_1
│ Dimensions: ()
│ Data variables:
│ g int64 8B 9
...
└── Group: /file_9/group_4
Dimensions: ()
Data variables:
g int64 8B 36
"""
).strip()
assert expected == result

with xr.set_options(display_max_children=10):
result = repr(tree)

for key in tree_dict:
assert key in result

def test_repr_inherited_dims(self) -> None:
tree = DataTree.from_dict(
{
Expand Down
46 changes: 46 additions & 0 deletions xarray/tests/test_formatting_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,52 @@ def test_two_children(
)


class TestDataTreeTruncatesNodes:
def test_many_nodes(self) -> None:
# construct a datatree with 500 nodes
number_of_files = 20
number_of_groups = 25
tree_dict = {}
for f in range(number_of_files):
for g in range(number_of_groups):
tree_dict[f"file_{f}/group_{g}"] = xr.Dataset({"g": f * g})

tree = xr.DataTree.from_dict(tree_dict)
with xr.set_options(display_style="html"):
result = tree._repr_html_()

assert "6/20" in result
for i in range(number_of_files):
if i < 3 or i >= (number_of_files - 3):
assert f"file_{i}</div>" in result
else:
assert f"file_{i}</div>" not in result

assert "6/25" in result
for i in range(number_of_groups):
if i < 3 or i >= (number_of_groups - 3):
assert f"group_{i}</div>" in result
else:
assert f"group_{i}</div>" not in result

with xr.set_options(display_style="html", display_max_children=3):
result = tree._repr_html_()

assert "3/20" in result
for i in range(number_of_files):
if i < 2 or i >= (number_of_files - 1):
assert f"file_{i}</div>" in result
else:
assert f"file_{i}</div>" not in result

assert "3/25" in result
for i in range(number_of_groups):
if i < 2 or i >= (number_of_groups - 1):
assert f"group_{i}</div>" in result
else:
assert f"group_{i}</div>" not in result


class TestDataTreeInheritance:
def test_inherited_section_present(self) -> None:
dt = xr.DataTree.from_dict(
Expand Down
Loading