Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Shorten text repr for DataTree #10139

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

Conversation

jsignell
Copy link
Contributor

@jsignell jsignell commented Mar 17, 2025

Demo

import numpy as np
import xarray as xr

number_of_files = 700
number_of_groups = 5
number_of_variables= 10

datasets = {}
for f in range(number_of_files):
    for g in range(number_of_groups):
        # Create random data
        time = np.linspace(0, 50 + f, 1 + 1000 * g)
        y = f * time + g

        # Create dataset:
        ds = xr.Dataset(
            data_vars={
                f"temperature_{g}{i}": ("time", y)
                for i in range(number_of_variables // number_of_groups)
            },
            coords={"time": ("time", time)},
        ).chunk()

        # Prepare for xr.DataTree:
        name = f"file_{f}/group_{g}"
        datasets[name] = ds


dt = xr.DataTree.from_dict(datasets)

%timeit dt._repr_html_()
# 62.6 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit dt.__repr__()
# 3.23 ms ± 66.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Here is what the repr actually looks like:

In [2]: dt
Out[2]: 
<xarray.DataTree>
Group: /
├── Group: /file_0
│   ├── Group: /file_0/group_0
│   │       Dimensions:         (time: 1)
│   │       Coordinates:
│   │         * time            (time) float64 8B 0.0
│   │       Data variables:
│   │           temperature_00  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
│   │           temperature_01  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
│   ├── Group: /file_0/group_1
│   │       Dimensions:         (time: 1001)
│   │       Coordinates:
│   │         * time            (time) float64 8kB 0.0 0.05 0.1 0.15 ... 49.9 49.95 50.0
│   │       Data variables:
│   │           temperature_10  (time) float64 8kB dask.array<chunksize=(1001,), meta=np.ndarray>
│   │           temperature_11  (time) float64 8kB dask.array<chunksize=(1001,), meta=np.ndarray>
│   ├── Group: /file_0/group_2
│   │       Dimensions:         (time: 2001)
│   │       Coordinates:
│   │         * time            (time) float64 16kB 0.0 0.025 0.05 ... 49.95 49.98 50.0
│   │       Data variables:
│   │           temperature_20  (time) float64 16kB dask.array<chunksize=(2001,), meta=np.ndarray>
│   │           temperature_21  (time) float64 16kB dask.array<chunksize=(2001,), meta=np.ndarray>
│   ├── Group: /file_0/group_3
│   │       Dimensions:         (time: 3001)
│   │       Coordinates:
│   │         * time            (time) float64 24kB 0.0 0.01667 0.03333 ... 49.97 49.98 50.0
│   │       Data variables:
│   │           temperature_30  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
│   │           temperature_31  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
│   └── Group: /file_0/group_4Dimensions:         (time: 4001)
│           Coordinates:
│             * time            (time) float64 32kB 0.0 0.0125 0.025 ... 49.98 49.99 50.0Data variables:
│               temperature_40  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>temperature_41  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>
├── Group: /file_1
│   ├── Group: /file_1/group_0
│   │       Dimensions:         (time: 1)
│   │       Coordinates:
│   │         * time            (time) float64 8B 0.0
│   │       Data variables:
│   │           temperature_00  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
│   │           temperature_01  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
│   ├── Group: /file_1/group_1
│   │       Dimensions:         (time: 1001)
│   │       Coordinates:
│   │         * time            (time) float64 8kB 0.0 0.051 0.102 0.153 ... 50.9 50.95 51.0
│   │       Data variables:
│   │           temperature_10  (time) float64 8kB dask.array<chunksize=(1001,), meta=np.ndarray>
│   │           temperature_11  (time) float64 8kB dask.array<chunksize=(1001,), meta=np.ndarray>
│   ├── Group: /file_1/group_2
│   │       Dimensions:         (time: 2001)
│   │       Coordinates:
│   │         * time            (time) float64 16kB 0.0 0.0255 0.051 ... 50.95 50.97 51.0
│   │       Data variables:
│   │           temperature_20  (time) float64 16kB dask.array<chunksize=(2001,), meta=np.ndarray>
│   │           temperature_21  (time) float64 16kB dask.array<chunksize=(2001,), meta=np.ndarray>
│   ├── Group: /file_1/group_3
│   │       Dimensions:         (time: 3001)
│   │       Coordinates:
│   │         * time            (time) float64 24kB 0.0 0.017 0.034 ... 50.97 50.98 51.0
│   │       Data variables:
│   │           temperature_30  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
│   │           temperature_31  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
...
    └── Group: /file_699/group_4
            Dimensions:         (time: 4001)
            Coordinates:
              * time            (time) float64 32kB 0.0 0.1872 0.3745 ... 748.6 748.8 749.0
            Data variables:
                temperature_40  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>
                temperature_41  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>

Here is the html repr:

image

Notes

This logic is configurable via xr.set_options(display_max_rows=) I decide to reuse that option rather than making a new one since it felt functionally kind of similar to limiting the number of variables or coords shown on a dataset (which is what display_max_rows is otherwise used for). I can imagine making a new display_max_children setting though and I'd likely toggle the default 12 down to 6 or something.

I wasn't quite sure what the right approach was so I tried two different ideas. It might make sense to consolidate on one approach though

In the text repr:

  1. The max is applied to the total number of nodes, so the overall size of the output repr is fixed at n nodes.
  2. The un-repred nodes are replaced with a "..."
  3. The final node is always displayed

In the html repr:

  1. The max is applied to the number of children, so any given node can have up to n children
  2. There is a note about how not all nodes are displayed

Checklist

  • Closes #10052
  • Tests added
  • User visible changes (including notable bug fixes) are documented in whats-new.rst

* only show max 12 children in text ``DataTree`` repr
* use ``set_options(display_max_rows=12)`` to configure this setting
* always include last item in ``DataTree``
* insert "..." before last item in ``DataTree``
@jsignell jsignell marked this pull request as ready for review March 18, 2025 14:35
@TomNicholas TomNicholas added the topic-DataTree Related to the implementation of a DataTree class label Mar 18, 2025
@TomNicholas TomNicholas requested a review from benbovy March 18, 2025 19:32
@TomNicholas
Copy link
Member

TomNicholas commented Mar 18, 2025

Thank you for working on this @jsignell !

%timeit dt._repr_html_()
# 62.6 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit dt.__repr__()
# 3.23 ms ± 66.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)

This seems good! But how much better is it than before?

There is a note about how not all nodes are displayed

This seems a little ugly from a UI perspective.

@benbovy
Copy link
Member

benbovy commented Mar 19, 2025

Looks nice @jsignell! Some thoughts:

I decide to reuse that option rather than making a new one since it felt functionally kind of similar to limiting the number of variables or coords shown on a dataset (which is what display_max_rows is otherwise used for). I can imagine making a new display_max_children setting though and I'd likely toggle the default 12 down to 6 or something.

Same feelings and no strong opinion on this. I'm slightly leaning towards a separate option, though, in order to avoid constantly toggling the value of display_max_rows when working with both dataset and datatree objects.

The final node is always displayed

Could we instead compute the location of the "..." placeholder so that it is more in the "middle", likely showing more than one trailing node? It is probably better from a UX perspective (and more consistent with other reprs such as Dataset and pandas.DataFrame)

It might make sense to consolidate on one approach though

I agree. Maybe adding "..." placeholder rows in the html repr too and reuse the same logic for computing their location?

There is a note about how not all nodes are displayed

Would it be easy to customize the number of items in the section title rather than adding a full note? E.g., something like:

 ▶ Groups:  (700 / 12 displayed)

or just

 ▶ Groups:  (12/700)

@benbovy
Copy link
Member

benbovy commented Mar 19, 2025

I find that in the text repr the "..." placeholder could be a little more visible. E.g, something like:

│   ├── Group: /file_1/group_3
│   │       Dimensions:         (time: 3001)
│   │       Coordinates:
│   │         * time            (time) float64 24kB 0.0 0.017 0.034 ... 50.97 50.98 51.0
│   │       Data variables:
│   │           temperature_30  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
│   │           temperature_31  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
│   │ 
... ...
    └── Group: /file_699/group_4
            Dimensions:         (time: 4001)
            Coordinates:
              * time            (time) float64 32kB 0.0 0.1872 0.3745 ... 748.6 748.8 749.0
            Data variables:
                temperature_40  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>
                temperature_41  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>

What I also find disturbing in this example is that /file_699/group_4 looks like it is in the same parent group than
/file_1/group_3.

I don't know how easy / hard it can be implemented, but it would be nice if we keep showing all hierarchical levels in the truncated repr, e.g.,

│   ├── Group: /file_1/group_3
│   │       Dimensions:         (time: 3001)
│   │       Coordinates:
│   │         * time            (time) float64 24kB 0.0 0.017 0.034 ... 50.97 50.98 51.0
│   │       Data variables:
│   │           temperature_30  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
│   │           temperature_31  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
│   │ 
... ...
└── Group: /file_699
    ...
    └── Group: /file_699/group_4
            Dimensions:         (time: 4001)
            Coordinates:
              * time            (time) float64 32kB 0.0 0.1872 0.3745 ... 748.6 748.8 749.0
            Data variables:
                temperature_40  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>
                temperature_41  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>

Or if possible truncate the tree at smart locations such that it doesn't "break" the hierarchy:

<xarray.DataTree>
Group: /
├── Group: /file_0
│   ├── Group: /file_0/group_0
│   │       Dimensions:         (time: 1)
│   │       Coordinates:
│   │         * time            (time) float64 8B 0.0
│   │       Data variables:
│   │           temperature_00  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
│   │           temperature_01  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
│   ├── Group: /file_0/group_1
│   │       Dimensions:         (time: 1001)
│   │       Coordinates:
│   │         * time            (time) float64 8kB 0.0 0.05 0.1 0.15 ... 49.9 49.95 50.0
│   │       Data variables:
│   │           temperature_10  (time) float64 8kB dask.array<chunksize=(1001,), meta=np.ndarray>
│   │           temperature_11  (time) float64 8kB dask.array<chunksize=(1001,), meta=np.ndarray>
│   ├── Group: /file_0/group_2
│   │       Dimensions:         (time: 2001)
│   │       Coordinates:
│   │         * time            (time) float64 16kB 0.0 0.025 0.05 ... 49.95 49.98 50.0
│   │       Data variables:
│   │           temperature_20  (time) float64 16kB dask.array<chunksize=(2001,), meta=np.ndarray>
│   │           temperature_21  (time) float64 16kB dask.array<chunksize=(2001,), meta=np.ndarray>
│   ├── Group: /file_0/group_3
│   │       Dimensions:         (time: 3001)
│   │       Coordinates:
│   │         * time            (time) float64 24kB 0.0 0.01667 0.03333 ... 49.97 49.98 50.0
│   │       Data variables:
│   │           temperature_30  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
│   │           temperature_31  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
│   └── Group: /file_0/group_4
│           Dimensions:         (time: 4001)
│           Coordinates:
│             * time            (time) float64 32kB 0.0 0.0125 0.025 ... 49.98 49.99 50.0
│           Data variables:
│               temperature_40  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>
│               temperature_41  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>
│
...
│
└── Group: /file_699
    ├── Group: /file_699/group_0
    │       Dimensions:         (time: 1)
    │       Coordinates:
    │         * time            (time) float64 8B 0.0
    │       Data variables:
    │           temperature_00  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
    │           temperature_01  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
    ├── Group: /file_699/group_1
    │       Dimensions:         (time: 1001)
    │       Coordinates:
    │         * time            (time) float64 8kB 0.0 0.051 0.102 0.153 ... 50.9 50.95 51.0
    │       Data variables:
    │           temperature_10  (time) float64 8kB dask.array<chunksize=(1001,), meta=np.ndarray>
    │           temperature_11  (time) float64 8kB dask.array<chunksize=(1001,), meta=np.ndarray>
    ├── Group: /file_699/group_2
    │       Dimensions:         (time: 2001)
    │       Coordinates:
    │         * time            (time) float64 16kB 0.0 0.0255 0.051 ... 50.95 50.97 51.0
    │       Data variables:
    │           temperature_20  (time) float64 16kB dask.array<chunksize=(2001,), meta=np.ndarray>
    │           temperature_21  (time) float64 16kB dask.array<chunksize=(2001,), meta=np.ndarray>
    ├── Group: /file_699/group_3
    │       Dimensions:         (time: 3001)
    │       Coordinates:
    │         * time            (time) float64 24kB 0.0 0.017 0.034 ... 50.97 50.98 51.0
    │       Data variables:
    │           temperature_30  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
    │           temperature_31  (time) float64 24kB dask.array<chunksize=(3001,), meta=np.ndarray>
    └── Group: /file_699/group_4
            Dimensions:         (time: 4001)
            Coordinates:
              * time            (time) float64 32kB 0.0 0.1872 0.3745 ... 748.6 748.8 749.0
            Data variables:
                temperature_40  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>
                temperature_41  (time) float64 32kB dask.array<chunksize=(4001,), meta=np.ndarray>

@Illviljan
Copy link
Contributor

Illviljan commented Mar 19, 2025

import numpy as np
import xarray as xr

number_of_files = 700
number_of_groups = 5
number_of_variables= 10

datasets = {}
for f in range(number_of_files):
    for g in range(number_of_groups):
        # Create random data
        time = np.linspace(0, 50 + f, 1 + 1000 * g)
        y = f * time + g

        # Create dataset:
        ds = xr.Dataset(
            data_vars={
                f"temperature_{g}{i}": ("time", y)
                for i in range(number_of_variables // number_of_groups)
            },
            coords={"time": ("time", time)},
        ).chunk()

        # Prepare for xr.DataTree:
        name = f"file_{f}/group_{g}"
        datasets[name] = ds


dt = xr.DataTree.from_dict(datasets)
%timeit dt._repr_html_()
# 32.9 s ± 797 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) (main)
# 540 ms ± 3.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) (PR)

%timeit dt.__repr__()
# 2.47 s ± 24 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) (main)
# 7.55 ms ± 29.7 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) (PR)

Quite the improvement for me, thank you @jsignell !

@jsignell
Copy link
Contributor Author

Thank you all for looking this over! Yeah I should have made it more clear what the improvement is over main. Like @Illviljan mentioned it's ~300x faster for the text repr and ~60x faster for the html repr for this case. The difference in speed up might be partially due to the different logic. In the text repr I am limiting the display to 12 nodes and in the html I am limiting each node to 12 children . The computation time will depend on which of those choices we converge on. The node limit should be more of a fixed time and the children limit will depend more on the shape of the tree.

slightly leaning towards a separate option, though, in order to avoid constantly toggling the value of display_max_rows when working with both dataset and datatree objects.

Yeah I agree that feels like the right call.

The final node is always displayed

Could we instead compute the location of the "..." placeholder so that it is more in the "middle", likely showing more than one trailing node? It is probably better from a UX perspective (and more consistent with other reprs such as Dataset and pandas.DataFrame)

Yeah I like that idea.

It might make sense to consolidate on one approach though

I agree. Maybe adding "..." placeholder rows in the html repr too and reuse the same logic for computing their location?

👍

There is a note about how not all nodes are displayed

Would it be easy to customize the number of items in the section title rather than adding a full note?

Yep I can change that.

Or if possible truncate the tree at smart locations such that it doesn't "break" the hierarchy

I like that idea a lot. I'll try to get that idea working.

jsignell added 5 commits April 7, 2025 17:03
* create new option to control max number of children
* consolidate logic to counting number of children per node
* show first n/2 children then ... then last n/2 children
@jsignell
Copy link
Contributor Author

jsignell commented Apr 7, 2025

Ok! I converged on limiting the number of children per-node rather than limiting the total number of nodes. It felt like that would make it simpler to always show the starting and ending children in a given node. I also made a new option with a default of 6 xr.set_option(display_max_children=6)

I also added a "..." to the html repr:
Screenshot from 2025-04-07 16-59-01

So this now corresponds much more closely with the text repr:

<xarray.DataTree>
Group: /
├── Group: /file_0
│   ├── Group: /file_0/group_0
│   │       Dimensions:         (time: 1)
│   │       Coordinates:
│   │         * time            (time) float64 8B 0.0
│   │       Data variables:
│   │           temperature_00  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
│   ...
│   └── Group: /file_0/group_6
│           Dimensions:         (time: 6001)
│           Coordinates:
│             * time            (time) float64 48kB 0.0 0.008333 0.01667 ... 49.99 50.0
│           Data variables:
│               temperature_60  (time) float64 48kB dask.array<chunksize=(6001,), meta=np.ndarray>
...
└── Group: /file_699
    ├── Group: /file_699/group_0
    │       Dimensions:         (time: 1)
    │       Coordinates:
    │         * time            (time) float64 8B 0.0
    │       Data variables:
    │           temperature_00  (time) float64 8B dask.array<chunksize=(1,), meta=np.ndarray>
    ...
    └── Group: /file_699/group_6
            Dimensions:         (time: 6001)
            Coordinates:
              * time            (time) float64 48kB 0.0 0.1248 0.2497 ... 748.8 748.9 749.0
            Data variables:
                temperature_60  (time) float64 48kB dask.array<chunksize=(6001,), meta=np.ndarray>

With this new version both versions are ~100x faster with this dummy datatree:

%timeit dt._repr_html_()
# 3.65 s ± 20.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) (main)
# 30.8 ms ± 493 μs per loop (mean ± std. dev. of 7 runs, 10 loops each) (PR)

%timeit dt.__repr__()
# 1.05 s ± 30.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) (main)
# 9.34 ms ± 79 μs per loop (mean ± std. dev. of 7 runs, 100 loops each) (PR)

@Illviljan Illviljan mentioned this pull request Apr 9, 2025
@Illviljan Illviljan added the run-benchmark Run the ASV benchmark workflow label Apr 9, 2025
@@ -222,6 +226,8 @@ class set_options:
* ``True`` : to always expand indexes
* ``False`` : to always collapse indexes
* ``default`` : to expand unless over a pre-defined limit (always collapse for html style)
display_max_children : int, default: 6
Maximum number of children to display for each node in a DataTree
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Maximum number of children to display for each node in a DataTree
Maximum number of children to display for each node in a DataTree.

children_html = []
for i, (n, c) in enumerate(children.items()):
if (
MAX_CHILDREN is None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't MAX_CHILDREN always an int? It's defined like that in OPTIONS.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I was thinking originally that it should be nullable but I didn't see any other nullable options so then I was worried that that might be an antipattern. The case for null is to allow arbitrarily large trees, which obviously has performance implications that triggered this work in the first place. All that is to say: yes. It is always an int. I'll change the code to make that explicit.

@@ -320,6 +320,52 @@ def test_two_children(
)


class TestDataTreeTruncatesNodes:
def test_many_nodes(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def test_many_nodes(self):
def test_many_nodes(self) -> None:

# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
run-benchmark Run the ASV benchmark workflow topic-DataTree Related to the implementation of a DataTree class topic-html-repr
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants