diff --git a/test/test_subset.py b/test/test_subset.py index 26bc79b55..bfef378da 100644 --- a/test/test_subset.py +++ b/test/test_subset.py @@ -149,30 +149,22 @@ def test_inverse_indices(): # Test nearest neighbor subsetting coord = [0, 0] - subset = grid.subset.nearest_neighbor(coord, k=1, element="face centers", inverse_indices=True) + subset = grid.subset.nearest_neighbor(coord, k=1, element="face centers") - assert subset.inverse_indices is not None + assert subset._ds['source_face_indices'] is not None # Test bounding box subsetting box = [(-10, 10), (-10, 10)] - subset = grid.subset.bounding_box(box[0], box[1], inverse_indices=True) + subset = grid.subset.bounding_box(box[0], box[1]) - assert subset.inverse_indices is not None + assert subset._ds['source_face_indices'] is not None # Test bounding circle subsetting center_coord = [0, 0] - subset = grid.subset.bounding_circle(center_coord, r=10, element="face centers", inverse_indices=True) + subset = grid.subset.bounding_circle(center_coord, r=10, element="face centers") - assert subset.inverse_indices is not None + assert subset._ds['source_face_indices'] is not None - # Ensure code raises exceptions when the element is edges or nodes or inverse_indices is incorrect - assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="edge centers", inverse_indices=True) - assert pytest.raises(Exception, grid.subset.bounding_circle, center_coord, r=10, element="nodes", inverse_indices=True) - assert pytest.raises(ValueError, grid.subset.bounding_circle, center_coord, r=10, element="face center", inverse_indices=(['not right'], True)) - - # Test isel directly - subset = grid.isel(n_face=[1], inverse_indices=True) - assert subset.inverse_indices.face.values == 1 def test_da_subset(): diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index 9ec371075..00c372528 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -1176,7 +1176,7 @@ def _edge_centered(self) -> bool: "n_edge" dimension)""" return "n_edge" in self.dims - def isel(self, ignore_grid=False, inverse_indices=False, *args, **kwargs): + def isel(self, ignore_grid=False, *args, **kwargs): """Grid-informed implementation of xarray's ``isel`` method, which enables indexing across grid dimensions. @@ -1210,17 +1210,11 @@ def isel(self, ignore_grid=False, inverse_indices=False, *args, **kwargs): raise ValueError("Only one grid dimension can be sliced at a time") if "n_node" in kwargs: - sliced_grid = self.uxgrid.isel( - n_node=kwargs["n_node"], inverse_indices=inverse_indices - ) + sliced_grid = self.uxgrid.isel(n_node=kwargs["n_node"]) elif "n_edge" in kwargs: - sliced_grid = self.uxgrid.isel( - n_edge=kwargs["n_edge"], inverse_indices=inverse_indices - ) + sliced_grid = self.uxgrid.isel(n_edge=kwargs["n_edge"]) else: - sliced_grid = self.uxgrid.isel( - n_face=kwargs["n_face"], inverse_indices=inverse_indices - ) + sliced_grid = self.uxgrid.isel(n_face=kwargs["n_face"]) return self._slice_from_grid(sliced_grid) @@ -1255,23 +1249,24 @@ def from_xarray(cls, da: xr.DataArray, uxgrid: Grid, ugrid_dims: dict = None): return cls(ds, uxgrid=uxgrid) + # TODO: def _slice_from_grid(self, sliced_grid): """Slices a ``UxDataArray`` from a sliced ``Grid``, using cached indices to correctly slice the data variable.""" if self._face_centered(): da_sliced = self.isel( - n_face=sliced_grid._ds["subgrid_face_indices"], ignore_grid=True + n_face=sliced_grid._ds["source_face_indices"], ignore_grid=True ) elif self._edge_centered(): da_sliced = self.isel( - n_edge=sliced_grid._ds["subgrid_edge_indices"], ignore_grid=True + n_edge=sliced_grid._ds["source_edge_indices"], ignore_grid=True ) elif self._node_centered(): da_sliced = self.isel( - n_node=sliced_grid._ds["subgrid_node_indices"], ignore_grid=True + n_node=sliced_grid._ds["source_node_indices"], ignore_grid=True ) else: diff --git a/uxarray/cross_sections/dataarray_accessor.py b/uxarray/cross_sections/dataarray_accessor.py index 7d7610db7..3fd682b30 100644 --- a/uxarray/cross_sections/dataarray_accessor.py +++ b/uxarray/cross_sections/dataarray_accessor.py @@ -1,7 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, List, Set, Tuple +from typing import TYPE_CHECKING, Tuple if TYPE_CHECKING: pass @@ -24,9 +24,7 @@ def __repr__(self): return prefix + methods_heading - def constant_latitude( - self, lat: float, inverse_indices: Union[List[str], Set[str], bool] = False - ): + def constant_latitude(self, lat: float): """Extracts a cross-section of the data array by selecting all faces that intersect with a specified line of constant latitude. @@ -35,11 +33,6 @@ def constant_latitude( lat : float The latitude at which to extract the cross-section, in degrees. Must be between -90.0 and 90.0 - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) Returns ------- @@ -69,11 +62,9 @@ def constant_latitude( faces = self.uxda.uxgrid.get_faces_at_constant_latitude(lat) - return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices) + return self.uxda.isel(n_face=faces) - def constant_longitude( - self, lon: float, inverse_indices: Union[List[str], Set[str], bool] = False - ): + def constant_longitude(self, lon: float): """Extracts a cross-section of the data array by selecting all faces that intersect with a specified line of constant longitude. @@ -82,11 +73,6 @@ def constant_longitude( lon : float The latitude at which to extract the cross-section, in degrees. Must be between -180.0 and 180.0 - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) Returns ------- @@ -118,12 +104,11 @@ def constant_longitude( lon, ) - return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices) + return self.uxda.isel(n_face=faces) def constant_latitude_interval( self, lats: Tuple[float, float], - inverse_indices: Union[List[str], Set[str], bool] = False, ): """Extracts a cross-section of data by selecting all faces that are within a specified latitude interval. @@ -133,11 +118,6 @@ def constant_latitude_interval( lats : Tuple[float, float] The latitude interval (min_lat, max_lat) at which to extract the cross-section, in degrees. Values must be between -90.0 and 90.0 - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) Returns ------- @@ -164,12 +144,11 @@ def constant_latitude_interval( """ faces = self.uxda.uxgrid.get_faces_between_latitudes(lats) - return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices) + return self.uxda.isel(n_face=faces) def constant_longitude_interval( self, lons: Tuple[float, float], - inverse_indices: Union[List[str], Set[str], bool] = False, ): """Extracts a cross-section of data by selecting all faces are within a specifed longitude interval. @@ -178,11 +157,6 @@ def constant_longitude_interval( lons : Tuple[float, float] The longitude interval (min_lon, max_lon) at which to extract the cross-section, in degrees. Values must be between -180.0 and 180.0 - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) Returns ------- @@ -209,4 +183,4 @@ def constant_longitude_interval( """ faces = self.uxda.uxgrid.get_faces_between_longitudes(lons) - return self.uxda.isel(n_face=faces, inverse_indices=inverse_indices) + return self.uxda.isel(n_face=faces) diff --git a/uxarray/cross_sections/grid_accessor.py b/uxarray/cross_sections/grid_accessor.py index b4fed76db..1e5c776d0 100644 --- a/uxarray/cross_sections/grid_accessor.py +++ b/uxarray/cross_sections/grid_accessor.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Union, List, Set, Tuple +from typing import TYPE_CHECKING, Tuple if TYPE_CHECKING: from uxarray.grid import Grid @@ -30,7 +30,6 @@ def constant_latitude( self, lat: float, return_face_indices: bool = False, - inverse_indices: Union[List[str], Set[str], bool] = False, ): """Extracts a cross-section of the grid by selecting all faces that intersect with a specified line of constant latitude. @@ -43,11 +42,6 @@ def constant_latitude( return_face_indices : bool, optional If True, also returns the indices of the faces that intersect with the line of constant latitude. - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) Returns ------- @@ -84,9 +78,7 @@ def constant_latitude( if len(faces) == 0: raise ValueError(f"No intersections found at lat={lat}.") - grid_at_constant_lat = self.uxgrid.isel( - n_face=faces, inverse_indices=inverse_indices - ) + grid_at_constant_lat = self.uxgrid.isel(n_face=faces) if return_face_indices: return grid_at_constant_lat, faces @@ -97,7 +89,6 @@ def constant_longitude( self, lon: float, return_face_indices: bool = False, - inverse_indices: Union[List[str], Set[str], bool] = False, ): """Extracts a cross-section of the grid by selecting all faces that intersect with a specified line of constant longitude. @@ -110,11 +101,6 @@ def constant_longitude( return_face_indices : bool, optional If True, also returns the indices of the faces that intersect with the line of constant longitude. - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) Returns ------- @@ -150,9 +136,7 @@ def constant_longitude( if len(faces) == 0: raise ValueError(f"No intersections found at lon={lon}") - grid_at_constant_lon = self.uxgrid.isel( - n_face=faces, inverse_indices=inverse_indices - ) + grid_at_constant_lon = self.uxgrid.isel(n_face=faces) if return_face_indices: return grid_at_constant_lon, faces @@ -163,7 +147,6 @@ def constant_latitude_interval( self, lats: Tuple[float, float], return_face_indices: bool = False, - inverse_indices: Union[List[str], Set[str], bool] = False, ): """Extracts a cross-section of the grid by selecting all faces that are within a specified latitude interval. @@ -176,11 +159,6 @@ def constant_latitude_interval( return_face_indices : bool, optional If True, also returns the indices of the faces that intersect with the latitude interval. - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) Returns ------- @@ -212,9 +190,7 @@ def constant_latitude_interval( """ faces = self.uxgrid.get_faces_between_latitudes(lats) - grid_between_lats = self.uxgrid.isel( - n_face=faces, inverse_indices=inverse_indices - ) + grid_between_lats = self.uxgrid.isel(n_face=faces) if return_face_indices: return grid_between_lats, faces @@ -225,7 +201,6 @@ def constant_longitude_interval( self, lons: Tuple[float, float], return_face_indices: bool = False, - inverse_indices: Union[List[str], Set[str], bool] = False, ): """Extracts a cross-section of the grid by selecting all faces are within a specifed longitude interval. @@ -236,11 +211,6 @@ def constant_longitude_interval( in degrees. Values must be between -180.0 and 180.0 return_face_indices : bool, optional If True, also returns the indices of the faces that intersect are within a specifed longitude interval. - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) Returns ------- @@ -273,9 +243,7 @@ def constant_longitude_interval( """ faces = self.uxgrid.get_faces_between_longitudes(lons) - grid_between_lons = self.uxgrid.isel( - n_face=faces, inverse_indices=inverse_indices - ) + grid_between_lons = self.uxgrid.isel(n_face=faces) if return_face_indices: return grid_between_lons, faces diff --git a/uxarray/grid/grid.py b/uxarray/grid/grid.py index 9cae36bac..4c4035454 100644 --- a/uxarray/grid/grid.py +++ b/uxarray/grid/grid.py @@ -9,8 +9,6 @@ from typing import ( Optional, Union, - List, - Set, Tuple, ) @@ -154,8 +152,6 @@ class Grid: is_subset : bool, default=False Flag to mark if the grid is a subset or not - inverse_indices: xr.Dataset, default=None - A dataset of indices that correspond to the original grid, if the grid being constructed is a subset Examples ---------- @@ -181,7 +177,6 @@ def __init__( source_grid_spec: Optional[str] = None, source_dims_dict: Optional[dict] = {}, is_subset: bool = False, - inverse_indices: Optional[xr.Dataset] = None, ): # check if inputted dataset is a minimum representable 2D UGRID unstructured grid # TODO: @@ -221,8 +216,6 @@ def __init__( self._ds.assign_attrs({"source_grid_spec": self.source_grid_spec}) self._is_subset = is_subset - self._inverse_indices = inverse_indices - # cached parameters for GeoDataFrame conversions self._gdf_cached_parameters = { "gdf": None, @@ -338,7 +331,6 @@ def from_dataset(cls, dataset, use_dual: Optional[bool] = False, **kwargs): source_grid_spec, source_dims_dict, is_subset=kwargs.get("is_subset", False), - inverse_indices=kwargs.get("inverse_indices"), ) @classmethod @@ -1575,21 +1567,29 @@ def global_sphere_coverage(self): (i.e. contains no holes)""" return not self.partial_sphere_coverage - @property - def inverse_indices(self) -> xr.Dataset: - """Indices for a subset that map each face in the subset back to the original grid""" - if self.is_subset: - return self._inverse_indices - else: - raise Exception( - "Grid is not a subset, therefore no inverse face indices exist" - ) - @property def is_subset(self): """Returns `True` if the Grid is a subset, 'False' otherwise.""" return self._is_subset + @property + def source_face_indices(self): + """TODO:""" + if "source_face_indices" in self._ds: + return self._ds["source_face_indices"] + + @property + def source_edge_indices(self): + """TODO:""" + if "source_edge_indices" in self._ds: + return self._ds["source_edge_indices"] + + @property + def source_node_indices(self): + """TODO:""" + if "source_node_indices" in self._ds: + return self._ds["source_node_indices"] + @property def max_face_radius(self): """Maximum face radius of the grid in degrees""" @@ -2347,9 +2347,7 @@ def get_dual(self): return dual - def isel( - self, inverse_indices: Union[List[str], Set[str], bool] = False, **dim_kwargs - ): + def isel(self, **dim_kwargs): """Indexes an unstructured grid along a given dimension (``n_node``, ``n_edge``, or ``n_face``) and returns a new grid. @@ -2359,9 +2357,6 @@ def isel( exclusive and clipped indexing is in the works. Parameters - inverse_indices : Union[List[str], Set[str], bool], default=False - Indicates whether to store the original grids indices. Passing `True` stores the original face indices, - other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) **dims_kwargs: kwargs Dimension to index, one of ['n_node', 'n_edge', 'n_face'] @@ -2377,22 +2372,15 @@ def isel( raise ValueError("Indexing must be along a single dimension.") if "n_node" in dim_kwargs: - if inverse_indices: - raise Exception( - "Inverse indices are not yet supported for node selection, please use face centers" - ) return _slice_node_indices(self, dim_kwargs["n_node"]) elif "n_edge" in dim_kwargs: - if inverse_indices: - raise Exception( - "Inverse indices are not yet supported for edge selection, please use face centers" - ) return _slice_edge_indices(self, dim_kwargs["n_edge"]) elif "n_face" in dim_kwargs: return _slice_face_indices( - self, dim_kwargs["n_face"], inverse_indices=inverse_indices + self, + dim_kwargs["n_face"], ) else: @@ -2624,7 +2612,6 @@ def get_faces_containing_point( r=max_face_radius, center_coord=point_lonlat, element="face centers", - inverse_indices=True, ) # If no subset is found, warn the user except ValueError: @@ -2648,7 +2635,7 @@ def get_faces_containing_point( ) # Get the original face indices from the subset - inverse_indices = subset.inverse_indices.face.values + inverse_indices = subset["source_face_indices"].values # Check to see if the point is on the nodes of any face lies_on_node = np.isclose( diff --git a/uxarray/grid/slice.py b/uxarray/grid/slice.py index 8cce19d15..d36473b7e 100644 --- a/uxarray/grid/slice.py +++ b/uxarray/grid/slice.py @@ -2,14 +2,48 @@ import numpy as np import xarray as xr -from uxarray.constants import INT_FILL_VALUE, INT_DTYPE +from uxarray.constants import INT_FILL_VALUE -from typing import TYPE_CHECKING, Union, List, Set +from uxarray.grid import Grid +import polars as pl +from numba import njit, types, prange +from numba.typed import Dict + +from typing import TYPE_CHECKING if TYPE_CHECKING: pass +@njit( + "int64[:,:](int64[:,:], DictType(int64, int64), int64)", cache=True, parallel=True +) +def update_connectivity(conn, indices_dict, fill_value): + dim_a, dim_b = conn.shape + conn_flat = conn.flatten() + result = np.empty_like(conn_flat) + + for i in prange(len(conn_flat)): + if conn_flat[i] == fill_value: + result[i] = fill_value + else: + # Use the dictionary to find the new index + result[i] = indices_dict.get(conn_flat[i], fill_value) + + return result.reshape(dim_a, dim_b) + + +@njit( + "DictType(int64, int64)(int64[:])", + cache=True, +) +def create_indices_dict(indices): + indices_dict = Dict.empty(key_type=types.int64, value_type=types.int64) + for new_idx, old_idx in enumerate(indices): + indices_dict[old_idx] = new_idx + return indices_dict + + def _slice_node_indices( grid, indices, @@ -33,9 +67,16 @@ def _slice_node_indices( if inclusive is False: raise ValueError("Exclusive slicing is not yet supported.") - # faces that saddle nodes given in 'indices' - face_indices = np.unique(grid.node_face_connectivity.values[indices].ravel()) - face_indices = face_indices[face_indices != INT_FILL_VALUE] + node_indices = _validate_indices(indices) + + # Identify face indices from edge_face_connectivity + node_face_connectivity = grid.node_face_connectivity.isel( + n_node=node_indices + ).values + face_indices = ( + pl.DataFrame(node_face_connectivity.flatten()).unique().to_numpy().squeeze() + ) + face_indices = face_indices[face_indices >= 0] return _slice_face_indices(grid, face_indices) @@ -63,123 +104,102 @@ def _slice_edge_indices( if inclusive is False: raise ValueError("Exclusive slicing is not yet supported.") - # faces that saddle nodes given in 'indices' - face_indices = np.unique(grid.edge_face_connectivity.values[indices].ravel()) - face_indices = face_indices[face_indices != INT_FILL_VALUE] - - return _slice_face_indices(grid, face_indices) + edge_indices = _validate_indices(indices) + # Identify face indices from edge_face_connectivity + edge_face_connectivity = grid.edge_face_connectivity.isel( + n_edge=edge_indices + ).values + face_indices = ( + pl.DataFrame(edge_face_connectivity.flatten()).unique().to_numpy().squeeze() + ) + face_indices = face_indices[face_indices >= 0] -def _slice_face_indices( - grid, - indices, - inclusive=True, - inverse_indices: Union[List[str], Set[str], bool] = False, -): - """Slices (indexes) an unstructured grid given a list/array of face - indices, returning a new Grid composed of elements that contain the faces - specified in the indices. - - Parameters - ---------- - grid : ux.Grid - Source unstructured grid - indices: array-like - A list or 1-D array of face indices - inclusive: bool - Whether to perform inclusive (i.e. elements must contain at least one desired feature from a slice) as opposed - to exclusive (i.e elements be made up all desired features from a slice) - inverse_indices : Union[List[str], Set[str], bool], optional - Indicates whether to store the original grids indices. Passing `True` stores the original face centers, - other reverse indices can be stored by passing any or all of the following: (["face", "edge", "node"], True) - """ - if inclusive is False: - raise ValueError("Exclusive slicing is not yet supported.") + return _slice_face_indices(grid, face_indices) - from uxarray.grid import Grid +def _slice_face_indices(grid, indices): ds = grid._ds - indices = np.asarray(indices, dtype=INT_DTYPE) + face_indices = _validate_indices(indices) - if indices.ndim == 0: - indices = np.expand_dims(indices, axis=0) + # Identify node indices from face_node_connectivity + face_node_connectivity = grid.face_node_connectivity.isel( + n_face=face_indices + ).values + node_indices = ( + pl.DataFrame(face_node_connectivity.flatten()).unique().to_numpy().squeeze() + ) + node_indices = node_indices[node_indices >= 0] - face_indices = indices + # Prepare indexers and source indices + indexers = {"n_node": node_indices, "n_face": face_indices} - # nodes of each face (inclusive) - node_indices = np.unique(grid.face_node_connectivity.values[face_indices].ravel()) - node_indices = node_indices[node_indices != INT_FILL_VALUE] - - # index original dataset to obtain a 'subgrid' - ds = ds.isel(n_node=node_indices) - ds = ds.isel(n_face=face_indices) + source_indices = { + "source_node_indices": xr.DataArray(node_indices, dims=["n_node"]), + "source_face_indices": xr.DataArray(face_indices, dims=["n_face"]), + } - # Only slice edge dimension if we already have the connectivity - if "face_edge_connectivity" in grid._ds: - edge_indices = np.unique( - grid.face_edge_connectivity.values[face_indices].ravel() + # Identify edge indices from face_edge_connectivity and prepare indexers and source indices + if "n_edge" in ds.dims: + face_edge_connectivity = grid.face_edge_connectivity.isel( + n_face=face_indices + ).values + edge_indices = ( + pl.DataFrame(face_edge_connectivity.flatten()).unique().to_numpy() + ) + edge_indices = edge_indices[edge_indices >= 0] + indexers["n_edge"] = edge_indices + source_indices["source_edge_indices"] = xr.DataArray( + edge_indices, dims=["n_edge"] ) - edge_indices = edge_indices[edge_indices != INT_FILL_VALUE] - ds = ds.isel(n_edge=edge_indices) - ds["subgrid_edge_indices"] = xr.DataArray(edge_indices, dims=["n_edge"]) - else: - edge_indices = None - - ds["subgrid_node_indices"] = xr.DataArray(node_indices, dims=["n_node"]) - ds["subgrid_face_indices"] = xr.DataArray(face_indices, dims=["n_face"]) - # mapping to update existing connectivity - node_indices_dict = { - key: val for key, val in zip(node_indices, np.arange(0, len(node_indices))) - } - node_indices_dict[INT_FILL_VALUE] = INT_FILL_VALUE - - for conn_name in grid._ds.data_vars: - # update or drop connectivity variables to correctly point to the new index of each element - - if "_node_connectivity" in conn_name: - # update connectivity vars that index into nodes - ds[conn_name] = xr.DataArray( - np.vectorize(node_indices_dict.__getitem__, otypes=[INT_DTYPE])( - ds[conn_name].values - ), - dims=ds[conn_name].dims, - attrs=ds[conn_name].attrs, + ds = ds.isel(indexers) + ds = ds.assign(source_indices) + + # Update existing connectivity to match valid indices + node_conn_names = [ + conn_name for conn_name in ds.data_vars if "_node_connectivity" in conn_name + ] + edge_conn_names = [ + conn_name for conn_name in ds.data_vars if "_edge_connectivity" in conn_name + ] + face_conn_names = [ + conn_name for conn_name in ds.data_vars if "_face_connectivity" in conn_name + ] + + # Update Node Connectivity Variables + if node_conn_names: + node_indices_dict = create_indices_dict(node_indices) + + for conn_name in node_conn_names: + ds[conn_name].data = update_connectivity( + ds[conn_name].values, node_indices_dict, INT_FILL_VALUE ) - elif "_connectivity" in conn_name: - # drop any conn that would require re-computation - ds = ds.drop_vars(conn_name) + # Update Edge Connectivity Variables + if edge_conn_names: + edge_indices_dict = create_indices_dict(edge_indices) - if inverse_indices: - inverse_indices_ds = xr.Dataset() + for conn_name in edge_conn_names: + ds[conn_name].data = update_connectivity( + ds[conn_name].values, edge_indices_dict, INT_FILL_VALUE + ) - index_types = { - "face": face_indices, - "node": node_indices, - } + # Update Face Connectivity Variables (TODO) + if face_conn_names: + ds = ds.drop_vars(face_conn_names) - if edge_indices is not None: - index_types["edge"] = edge_indices + return Grid.from_dataset(ds, source_grid_spec=grid.source_grid_spec, is_subset=True) - if isinstance(inverse_indices, bool): - inverse_indices_ds["face"] = face_indices - else: - for index_type in inverse_indices[0]: - if index_type in index_types: - inverse_indices_ds[index_type] = index_types[index_type] - else: - raise ValueError( - "Incorrect type of index for `inverse_indices`. Try passing one of the following " - "instead: 'face', 'edge', 'node'" - ) - - return Grid.from_dataset( - ds, - source_grid_spec=grid.source_grid_spec, - is_subset=True, - inverse_indices=inverse_indices_ds, - ) - return Grid.from_dataset(ds, source_grid_spec=grid.source_grid_spec, is_subset=True) +def _validate_indices(indices): + if hasattr(indices, "ndim") and indices.ndim == 0: + # Handle scalar numpy array case + return [indices.item()] + elif np.isscalar(indices): + # Handle Python scalar case + return [indices] + else: + # Already array-like + return indices diff --git a/uxarray/subset/dataarray_accessor.py b/uxarray/subset/dataarray_accessor.py index 0c2279e48..fde4367f3 100644 --- a/uxarray/subset/dataarray_accessor.py +++ b/uxarray/subset/dataarray_accessor.py @@ -2,7 +2,7 @@ import numpy as np -from typing import TYPE_CHECKING, Union, Tuple, List, Optional, Set +from typing import TYPE_CHECKING, Union, Tuple, List, Optional if TYPE_CHECKING: pass @@ -33,7 +33,6 @@ def bounding_box( self, lon_bounds: Union[Tuple, List, np.ndarray], lat_bounds: Union[Tuple, List, np.ndarray], - inverse_indices: Union[List[str], Set[str], bool] = False, ): """Subsets an unstructured grid between two latitude and longitude points which form a bounding box. @@ -53,15 +52,8 @@ def bounding_box( face centers, or edge centers lie within the bounds. element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) """ - grid = self.uxda.uxgrid.subset.bounding_box( - lon_bounds, lat_bounds, inverse_indices=inverse_indices - ) + grid = self.uxda.uxgrid.subset.bounding_box(lon_bounds, lat_bounds) return self.uxda._slice_from_grid(grid) @@ -70,7 +62,6 @@ def bounding_circle( center_coord: Union[Tuple, List, np.ndarray], r: Union[float, int], element: Optional[str] = "face centers", - inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid by returning all elements within some @@ -84,14 +75,9 @@ def bounding_circle( Radius of bounding circle (in degrees) element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) """ grid = self.uxda.uxgrid.subset.bounding_circle( - center_coord, r, element, inverse_indices=inverse_indices, **kwargs + center_coord, r, element, **kwargs ) return self.uxda._slice_from_grid(grid) @@ -100,7 +86,6 @@ def nearest_neighbor( center_coord: Union[Tuple, List, np.ndarray], k: int, element: Optional[str] = "face centers", - inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid by returning the ``k`` closest @@ -114,15 +99,10 @@ def nearest_neighbor( Number of neighbors to query element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) """ grid = self.uxda.uxgrid.subset.nearest_neighbor( - center_coord, k, element, inverse_indices=inverse_indices, **kwargs + center_coord, k, element, **kwargs ) return self.uxda._slice_from_grid(grid) diff --git a/uxarray/subset/grid_accessor.py b/uxarray/subset/grid_accessor.py index 971ab477f..0e0439323 100644 --- a/uxarray/subset/grid_accessor.py +++ b/uxarray/subset/grid_accessor.py @@ -2,7 +2,7 @@ import numpy as np -from typing import TYPE_CHECKING, Union, Tuple, List, Optional, Set +from typing import TYPE_CHECKING, Union, Tuple, List, Optional if TYPE_CHECKING: from uxarray.grid import Grid @@ -31,7 +31,6 @@ def bounding_box( self, lon_bounds: Tuple[float, float], lat_bounds: Tuple[float, float], - inverse_indices: Union[List[str], Set[str], bool] = False, ): """Subsets an unstructured grid between two latitude and longitude points which form a bounding box. @@ -51,11 +50,6 @@ def bounding_box( face centers, or edge centers lie within the bounds. element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) """ faces_between_lons = self.uxgrid.get_faces_between_longitudes(lon_bounds) @@ -63,14 +57,13 @@ def bounding_box( faces = np.intersect1d(faces_between_lons, face_between_lats) - return self.uxgrid.isel(n_face=faces, inverse_indices=inverse_indices) + return self.uxgrid.isel(n_face=faces) def bounding_circle( self, center_coord: Union[Tuple, List, np.ndarray], r: Union[float, int], element: Optional[str] = "face centers", - inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid by returning all elements within some @@ -84,11 +77,6 @@ def bounding_circle( Radius of bounding circle (in degrees) element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) """ coords = np.asarray(center_coord) @@ -102,14 +90,13 @@ def bounding_circle( f"No elements founding within the bounding circle with radius {r} when querying {element}" ) - return self._index_grid(ind, element, inverse_indices) + return self._index_grid(ind, element) def nearest_neighbor( self, center_coord: Union[Tuple, List, np.ndarray], k: int, element: Optional[str] = "face centers", - inverse_indices: Union[List[str], Set[str], bool] = False, **kwargs, ): """Subsets an unstructured grid by returning the ``k`` closest @@ -123,11 +110,6 @@ def nearest_neighbor( Number of neighbors to query element: str Element for use with `coords` comparison, one of `nodes`, `face centers`, or `edge centers` - inverse_indices : Union[List[str], Set[str], bool], optional - Controls storage of original grid indices. Options: - - True: Stores original face indices - - List/Set of strings: Stores specified index types (valid values: "face", "edge", "node") - - False: No index storage (default) """ coords = np.asarray(center_coord) @@ -136,7 +118,7 @@ def nearest_neighbor( _, ind = tree.query(coords, k) - return self._index_grid(ind, element, inverse_indices=inverse_indices) + return self._index_grid(ind, element) def _get_tree(self, coords, tree_type): """Internal helper for obtaining the desired KDTree or BallTree.""" @@ -154,12 +136,12 @@ def _get_tree(self, coords, tree_type): return tree - def _index_grid(self, ind, tree_type, inverse_indices=False): + def _index_grid(self, ind, tree_type): """Internal helper for indexing a grid with indices based off the provided tree type.""" if tree_type == "nodes": - return self.uxgrid.isel(inverse_indices, n_node=ind) + return self.uxgrid.isel(n_node=ind) elif tree_type == "edge centers": - return self.uxgrid.isel(inverse_indices, n_edge=ind) + return self.uxgrid.isel(n_edge=ind) else: - return self.uxgrid.isel(inverse_indices, n_face=ind) + return self.uxgrid.isel(n_face=ind)