From a7430f4d1e91e8e1af85c9b9723dc6f88053102e Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 9 Jan 2025 00:18:53 +0100 Subject: [PATCH 1/4] revert to integer indexing --- src/spatialdata/_core/query/relational_query.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 5c377771..c395fbca 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -382,6 +382,7 @@ def _inner_join_spatialelement_table( regions, region_column_name, instance_key = get_table_keys(table) groups_df = table.obs.groupby(by=region_column_name, observed=False) joined_indices = None + element_indices_mapping = {} for element_type, name_element in element_dict.items(): for name, element in name_element.items(): if name in regions: @@ -399,6 +400,7 @@ def _inner_join_spatialelement_table( masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows) element_dict[element_type][name] = masked_element + element_indices_mapping[name] = masked_element.index joined_indices = _get_joined_table_indices( joined_indices, element_indices, table_instance_key_column, match_rows @@ -413,7 +415,18 @@ def _inner_join_spatialelement_table( if joined_indices is not None: joined_indices = joined_indices.dropna() if any(joined_indices.isna()) else joined_indices - joined_table = table[joined_indices, :].copy() if joined_indices is not None else None + try: + joined_table = table[joined_indices, :].copy() if joined_indices is not None else None + # happens when having duplicate indices in obs. Need to revert to integer indexing. + # TODO: benchmark to check whether this by default is just as quick as obtaining joined_indices. + except pd.errors.InvalidIndexError: + indices = [] + obs = table.obs.reset_index() + _, region_col, index_col = get_table_keys(table) + for name_key, index_values in element_indices_mapping.items(): + indices.extend(obs[(obs[region_col] == name_key) & (obs[index_col].isin(index_values))].index) + joined_table = table[indices, :].copy() + _inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table) return element_dict, joined_table From c705649f6c660084aa5dd3d6a8792a96beef3fd0 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 9 Jan 2025 00:47:36 +0100 Subject: [PATCH 2/4] add test --- tests/core/query/test_relational_query.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index 8c5fa4a2..a83588be 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -376,6 +376,22 @@ def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation): assert all(indices == reversed_instance_id) +def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation): + sdata = sdata_query_aggregation + sdata["table"].obs.index = ["a"] * sdata["table"].n_obs + sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] + sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5] + + element_dict, table = join_spatialelement_table( + sdata=sdata, + spatial_element_names=["values_circles", "values_polygons"], + table_name="table", + how="inner", + ) + + assert table.n_obs == 10 + + # TODO: there is a lot of dublicate code, simplify with a function that tests both the case sdata=None and sdata=sdata def test_match_rows_join(sdata_query_aggregation): sdata = sdata_query_aggregation From 64c5ffc30b4b86fae14a523aa096e6cf421412ac Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Thu, 9 Jan 2025 09:39:54 +0100 Subject: [PATCH 3/4] fully change to use integer indexing --- .../_core/query/relational_query.py | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index c395fbca..1041f4ba 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -380,9 +380,9 @@ def _inner_join_spatialelement_table( element_dict: dict[str, dict[str, Any]], table: AnnData, match_rows: Literal["left", "no", "right"] ) -> tuple[dict[str, Any], AnnData]: regions, region_column_name, instance_key = get_table_keys(table) - groups_df = table.obs.groupby(by=region_column_name, observed=False) + obs = table.obs.reset_index() + groups_df = obs.groupby(by=region_column_name, observed=False) joined_indices = None - element_indices_mapping = {} for element_type, name_element in element_dict.items(): for name, element in name_element.items(): if name in regions: @@ -400,10 +400,9 @@ def _inner_join_spatialelement_table( masked_element = _get_masked_element(element_indices, element, table_instance_key_column, match_rows) element_dict[element_type][name] = masked_element - element_indices_mapping[name] = masked_element.index joined_indices = _get_joined_table_indices( - joined_indices, element_indices, table_instance_key_column, match_rows + joined_indices, masked_element.index, table_instance_key_column, match_rows ) else: warnings.warn( @@ -415,17 +414,7 @@ def _inner_join_spatialelement_table( if joined_indices is not None: joined_indices = joined_indices.dropna() if any(joined_indices.isna()) else joined_indices - try: - joined_table = table[joined_indices, :].copy() if joined_indices is not None else None - # happens when having duplicate indices in obs. Need to revert to integer indexing. - # TODO: benchmark to check whether this by default is just as quick as obtaining joined_indices. - except pd.errors.InvalidIndexError: - indices = [] - obs = table.obs.reset_index() - _, region_col, index_col = get_table_keys(table) - for name_key, index_values in element_indices_mapping.items(): - indices.extend(obs[(obs[region_col] == name_key) & (obs[index_col].isin(index_values))].index) - joined_table = table[indices, :].copy() + joined_table = table[joined_indices, :].copy() if joined_indices is not None else None _inplace_fix_subset_categorical_obs(subset_adata=joined_table, original_adata=table) return element_dict, joined_table @@ -468,7 +457,8 @@ def _left_join_spatialelement_table( if match_rows == "right": warnings.warn("Matching rows 'right' is not supported for 'left' join.", UserWarning, stacklevel=2) regions, region_column_name, instance_key = get_table_keys(table) - groups_df = table.obs.groupby(by=region_column_name, observed=False) + obs = table.obs.reset_index() + groups_df = obs.groupby(by=region_column_name, observed=False) joined_indices = None for element_type, name_element in element_dict.items(): for name, element in name_element.items(): From 4ca78e5207b03e4dbcc6c10a6b9ab0693f0e47e4 Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Mon, 13 Jan 2025 14:03:51 +0100 Subject: [PATCH 4/4] added tests --- tests/core/query/test_relational_query.py | 42 ++++++++++++++++++----- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index a83588be..877349d4 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -3,7 +3,7 @@ import pytest from anndata import AnnData -from spatialdata import get_values, match_table_to_element +from spatialdata import SpatialData, get_values, match_table_to_element from spatialdata._core.query.relational_query import ( _locate_value, _ValueOrigin, @@ -11,6 +11,7 @@ join_spatialelement_table, ) from spatialdata.models.models import TableModel +from spatialdata.testing import assert_anndata_equal, assert_geodataframe_equal def test_match_table_to_element(sdata_query_aggregation): @@ -376,20 +377,45 @@ def test_match_rows_inner_join_non_matching_table(sdata_query_aggregation): assert all(indices == reversed_instance_id) -def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation): +# TODO: 'left_exclusive' is currently not working, reported in this issue: +@pytest.mark.parametrize("join_type", ["left", "right", "inner", "right_exclusive"]) +def test_inner_join_match_rows_duplicate_obs_indices(sdata_query_aggregation: SpatialData, join_type: str) -> None: sdata = sdata_query_aggregation sdata["table"].obs.index = ["a"] * sdata["table"].n_obs - sdata_query_aggregation["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] - sdata_query_aggregation["values_circles"] = sdata_query_aggregation["values_circles"][:5] + sdata["values_circles"] = sdata_query_aggregation["values_circles"][:4] + sdata["values_polygons"] = sdata_query_aggregation["values_polygons"][:5] element_dict, table = join_spatialelement_table( sdata=sdata, spatial_element_names=["values_circles", "values_polygons"], table_name="table", - how="inner", - ) - - assert table.n_obs == 10 + how=join_type, + ) + + if join_type in ["left", "inner"]: + # table check + assert table.n_obs == 9 + assert np.array_equal(table.obs["instance_id"][:4], sdata["values_circles"].index) + assert np.array_equal(table.obs["instance_id"][4:], sdata["values_polygons"].index) + # shapes check + assert_geodataframe_equal(element_dict["values_circles"], sdata["values_circles"]) + assert_geodataframe_equal(element_dict["values_polygons"], sdata["values_polygons"]) + elif join_type == "right": + # table check + assert_anndata_equal(table.obs, sdata["table"].obs) + # shapes check + assert_geodataframe_equal(element_dict["values_circles"], sdata["values_circles"]) + assert_geodataframe_equal(element_dict["values_polygons"], sdata["values_polygons"]) + elif join_type == "left_exclusive": + # TODO: currently not working, reported in this issue + pass + else: + assert join_type == "right_exclusive" + # table check + assert table.n_obs == sdata["table"].n_obs - len(sdata["values_circles"]) - len(sdata["values_polygons"]) + # shapes check + assert element_dict["values_circles"] is None + assert element_dict["values_polygons"] is None # TODO: there is a lot of dublicate code, simplify with a function that tests both the case sdata=None and sdata=sdata