Skip to content

DRAFT: Optimize Grid.isel() #1175

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions test/test_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,30 +149,22 @@ def test_inverse_indices():

# Test nearest neighbor subsetting
coord = [0, 0]
subset = grid.subset.nearest_neighbor(coord, k=1, element="face centers", inverse_indices=True)
subset = grid.subset.nearest_neighbor(coord, k=1, element="face centers")

assert subset.inverse_indices is not None
assert subset._ds['source_face_indices'] is not None

# Test bounding box subsetting
box = [(-10, 10), (-10, 10)]
subset = grid.subset.bounding_box(box[0], box[1], inverse_indices=True)
subset = grid.subset.bounding_box(box[0], box[1])

assert subset.inverse_indices is not None
assert subset._ds['source_face_indices'] is not None

# Test bounding circle subsetting
center_coord = [0, 0]
subset = grid.subset.bounding_circle(center_coord, r=10, element="face centers", inverse_indices=True)
subset = grid.subset.bounding_circle(center_coord, r=10, element="face centers")

assert subset.inverse_indices is not None
assert subset._ds['source_face_indices'] is not None

# Ensure code raises exceptions when the element is edges or nodes or inverse_indices is incorrect
assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="edge centers", inverse_indices=True)
assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="nodes", inverse_indices=True)
assert pytest.raises(ValueError, grid.subset.bounding_circle, center_coord, r=10, element="face center", inverse_indices=(['not right'], True))

# Test isel directly
subset = grid.isel(n_face=[1], inverse_indices=True)
assert subset.inverse_indices.face.values == 1


def test_da_subset():
Expand Down
21 changes: 8 additions & 13 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,7 @@ def _edge_centered(self) -> bool:
"n_edge" dimension)"""
return "n_edge" in self.dims

def isel(self, ignore_grid=False, inverse_indices=False, *args, **kwargs):
def isel(self, ignore_grid=False, *args, **kwargs):
"""Grid-informed implementation of xarray's ``isel`` method, which
enables indexing across grid dimensions.

Expand Down Expand Up @@ -1210,17 +1210,11 @@ def isel(self, ignore_grid=False, inverse_indices=False, *args, **kwargs):
raise ValueError("Only one grid dimension can be sliced at a time")

if "n_node" in kwargs:
sliced_grid = self.uxgrid.isel(
n_node=kwargs["n_node"], inverse_indices=inverse_indices
)
sliced_grid = self.uxgrid.isel(n_node=kwargs["n_node"])
elif "n_edge" in kwargs:
sliced_grid = self.uxgrid.isel(
n_edge=kwargs["n_edge"], inverse_indices=inverse_indices
)
sliced_grid = self.uxgrid.isel(n_edge=kwargs["n_edge"])
else:
sliced_grid = self.uxgrid.isel(
n_face=kwargs["n_face"], inverse_indices=inverse_indices
)
sliced_grid = self.uxgrid.isel(n_face=kwargs["n_face"])

return self._slice_from_grid(sliced_grid)

Expand Down Expand Up @@ -1255,23 +1249,24 @@ def from_xarray(cls, da: xr.DataArray, uxgrid: Grid, ugrid_dims: dict = None):

return cls(ds, uxgrid=uxgrid)

# TODO:
def _slice_from_grid(self, sliced_grid):
"""Slices a ``UxDataArray`` from a sliced ``Grid``, using cached
indices to correctly slice the data variable."""

if self._face_centered():
da_sliced = self.isel(
n_face=sliced_grid._ds["subgrid_face_indices"], ignore_grid=True
n_face=sliced_grid._ds["source_face_indices"], ignore_grid=True
)

elif self._edge_centered():
da_sliced = self.isel(
n_edge=sliced_grid._ds["subgrid_edge_indices"], ignore_grid=True
n_edge=sliced_grid._ds["source_edge_indices"], ignore_grid=True
)

elif self._node_centered():
da_sliced = self.isel(
n_node=sliced_grid._ds["subgrid_node_indices"], ignore_grid=True
n_node=sliced_grid._ds["source_node_indices"], ignore_grid=True
)

else:
Expand Down
40 changes: 7 additions & 33 deletions uxarray/cross_sections/dataarray_accessor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations


from typing import TYPE_CHECKING, Union, List, Set, Tuple
from typing import TYPE_CHECKING, Tuple

if TYPE_CHECKING:
pass
Expand All @@ -24,9 +24,7 @@ def __repr__(self):

return prefix + methods_heading

def constant_latitude(
self, lat: float, inverse_indices: Union[List[str], Set[str], bool] = False
):
def constant_latitude(self, lat: float):
"""Extracts a cross-section of the data array by selecting all faces that
intersect with a specified line of constant latitude.

Expand All @@ -35,11 +33,6 @@ def constant_latitude(
lat : float
The latitude at which to extract the cross-section, in degrees.
Must be between -90.0 and 90.0
inverse_indices : Union[List[str], Set[str], bool], optional
Controls storage of original grid indices. Options:
- True: Stores original face indices
- List/Set of strings: Stores specified index types (valid values: "face", "edge", "node")
- False: No index storage (default)

Returns
-------
Expand Down Expand Up @@ -69,11 +62,9 @@ def constant_latitude(

faces = self.uxda.uxgrid.get_faces_at_constant_latitude(lat)

return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices)
return self.uxda.isel(n_face=faces)

def constant_longitude(
self, lon: float, inverse_indices: Union[List[str], Set[str], bool] = False
):
def constant_longitude(self, lon: float):
"""Extracts a cross-section of the data array by selecting all faces that
intersect with a specified line of constant longitude.

Expand All @@ -82,11 +73,6 @@ def constant_longitude(
lon : float
The latitude at which to extract the cross-section, in degrees.
Must be between -180.0 and 180.0
inverse_indices : Union[List[str], Set[str], bool], optional
Controls storage of original grid indices. Options:
- True: Stores original face indices
- List/Set of strings: Stores specified index types (valid values: "face", "edge", "node")
- False: No index storage (default)

Returns
-------
Expand Down Expand Up @@ -118,12 +104,11 @@ def constant_longitude(
lon,
)

return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices)
return self.uxda.isel(n_face=faces)

def constant_latitude_interval(
self,
lats: Tuple[float, float],
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of data by selecting all faces that
are within a specified latitude interval.
Expand All @@ -133,11 +118,6 @@ def constant_latitude_interval(
lats : Tuple[float, float]
The latitude interval (min_lat, max_lat) at which to extract the cross-section,
in degrees. Values must be between -90.0 and 90.0
inverse_indices : Union[List[str], Set[str], bool], optional
Controls storage of original grid indices. Options:
- True: Stores original face indices
- List/Set of strings: Stores specified index types (valid values: "face", "edge", "node")
- False: No index storage (default)

Returns
-------
Expand All @@ -164,12 +144,11 @@ def constant_latitude_interval(
"""
faces = self.uxda.uxgrid.get_faces_between_latitudes(lats)

return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices)
return self.uxda.isel(n_face=faces)

def constant_longitude_interval(
self,
lons: Tuple[float, float],
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of data by selecting all faces are within a specifed longitude interval.

Expand All @@ -178,11 +157,6 @@ def constant_longitude_interval(
lons : Tuple[float, float]
The longitude interval (min_lon, max_lon) at which to extract the cross-section,
in degrees. Values must be between -180.0 and 180.0
inverse_indices : Union[List[str], Set[str], bool], optional
Controls storage of original grid indices. Options:
- True: Stores original face indices
- List/Set of strings: Stores specified index types (valid values: "face", "edge", "node")
- False: No index storage (default)

Returns
-------
Expand All @@ -209,4 +183,4 @@ def constant_longitude_interval(
"""
faces = self.uxda.uxgrid.get_faces_between_longitudes(lons)

return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices)
return self.uxda.isel(n_face=faces)
42 changes: 5 additions & 37 deletions uxarray/cross_sections/grid_accessor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Union, List, Set, Tuple
from typing import TYPE_CHECKING, Tuple

if TYPE_CHECKING:
from uxarray.grid import Grid
Expand Down Expand Up @@ -30,7 +30,6 @@ def constant_latitude(
self,
lat: float,
return_face_indices: bool = False,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of the grid by selecting all faces that
intersect with a specified line of constant latitude.
Expand All @@ -43,11 +42,6 @@ def constant_latitude(
return_face_indices : bool, optional
If True, also returns the indices of the faces that intersect with the
line of constant latitude.
inverse_indices : Union[List[str], Set[str], bool], optional
Controls storage of original grid indices. Options:
- True: Stores original face indices
- List/Set of strings: Stores specified index types (valid values: "face", "edge", "node")
- False: No index storage (default)

Returns
-------
Expand Down Expand Up @@ -84,9 +78,7 @@ def constant_latitude(
if len(faces) == 0:
raise ValueError(f"No intersections found at lat={lat}.")

grid_at_constant_lat = self.uxgrid.isel(
n_face=faces, inverse_indices=inverse_indices
)
grid_at_constant_lat = self.uxgrid.isel(n_face=faces)

if return_face_indices:
return grid_at_constant_lat, faces
Expand All @@ -97,7 +89,6 @@ def constant_longitude(
self,
lon: float,
return_face_indices: bool = False,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of the grid by selecting all faces that
intersect with a specified line of constant longitude.
Expand All @@ -110,11 +101,6 @@ def constant_longitude(
return_face_indices : bool, optional
If True, also returns the indices of the faces that intersect with the
line of constant longitude.
inverse_indices : Union[List[str], Set[str], bool], optional
Controls storage of original grid indices. Options:
- True: Stores original face indices
- List/Set of strings: Stores specified index types (valid values: "face", "edge", "node")
- False: No index storage (default)

Returns
-------
Expand Down Expand Up @@ -150,9 +136,7 @@ def constant_longitude(
if len(faces) == 0:
raise ValueError(f"No intersections found at lon={lon}")

grid_at_constant_lon = self.uxgrid.isel(
n_face=faces, inverse_indices=inverse_indices
)
grid_at_constant_lon = self.uxgrid.isel(n_face=faces)

if return_face_indices:
return grid_at_constant_lon, faces
Expand All @@ -163,7 +147,6 @@ def constant_latitude_interval(
self,
lats: Tuple[float, float],
return_face_indices: bool = False,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of the grid by selecting all faces that
are within a specified latitude interval.
Expand All @@ -176,11 +159,6 @@ def constant_latitude_interval(
return_face_indices : bool, optional
If True, also returns the indices of the faces that intersect with the
latitude interval.
inverse_indices : Union[List[str], Set[str], bool], optional
Controls storage of original grid indices. Options:
- True: Stores original face indices
- List/Set of strings: Stores specified index types (valid values: "face", "edge", "node")
- False: No index storage (default)

Returns
-------
Expand Down Expand Up @@ -212,9 +190,7 @@ def constant_latitude_interval(
"""
faces = self.uxgrid.get_faces_between_latitudes(lats)

grid_between_lats = self.uxgrid.isel(
n_face=faces, inverse_indices=inverse_indices
)
grid_between_lats = self.uxgrid.isel(n_face=faces)

if return_face_indices:
return grid_between_lats, faces
Expand All @@ -225,7 +201,6 @@ def constant_longitude_interval(
self,
lons: Tuple[float, float],
return_face_indices: bool = False,
inverse_indices: Union[List[str], Set[str], bool] = False,
):
"""Extracts a cross-section of the grid by selecting all faces are within a specifed longitude interval.

Expand All @@ -236,11 +211,6 @@ def constant_longitude_interval(
in degrees. Values must be between -180.0 and 180.0
return_face_indices : bool, optional
If True, also returns the indices of the faces that intersect are within a specifed longitude interval.
inverse_indices : Union[List[str], Set[str], bool], optional
Controls storage of original grid indices. Options:
- True: Stores original face indices
- List/Set of strings: Stores specified index types (valid values: "face", "edge", "node")
- False: No index storage (default)

Returns
-------
Expand Down Expand Up @@ -273,9 +243,7 @@ def constant_longitude_interval(
"""
faces = self.uxgrid.get_faces_between_longitudes(lons)

grid_between_lons = self.uxgrid.isel(
n_face=faces, inverse_indices=inverse_indices
)
grid_between_lons = self.uxgrid.isel(n_face=faces)

if return_face_indices:
return grid_between_lons, faces
Expand Down
Loading
Loading