diff --git a/README.md b/README.md
index 0dbd2bc9c..aa5db8fe3 100644
--- a/README.md
+++ b/README.md
@@ -248,10 +248,12 @@ edit `.docker/settings_docker.py` and setup application config. Then `docker-com
 
 ## STAC API Extensions
 
-The STAC endpoint implements the [query](https://github.com/stac-api-extensions/query), [filter](https://github.com/stac-api-extensions/filter), [fields](https://github.com/stac-api-extensions/fields), and [sort](https://github.com/stac-api-extensions/sort) extensions, all of which are bound to the `search` endpoint as used with POST requests, with fields and sort additionally bound to the features endpoint.
+The STAC endpoint implements the [filter](https://github.com/stac-api-extensions/filter), [fields](https://github.com/stac-api-extensions/fields), and [sort](https://github.com/stac-api-extensions/sort) extensions, all of which are bound to the STAC API - Item Search (`/search`) endpoint. All support both GET and POST request syntax.
 
 Fields contained in the item properties must be prefixed with `properties.`, ex `properties.dea:dataset_maturity`.
 
-The implementation of `fields` differs somewhat from the suggested include/exclude semantics in that it does not permit for invalid STAC entities, so the `id`, `type`, `geometry`, `bbox`, `links`, `assets`, `properties.datetime`, and `stac_version` fields will always be included, regardless of user input.
+The implementation of `fields` differs somewhat from the suggested include/exclude semantics in that it does not permit invalid STAC entities, so the `id`, `type`, `geometry`, `bbox`, `links`, `assets`, `properties.datetime`, `collection`, and `stac_version` fields will always be included, regardless of user input.
 
-The implementation of `filter` is limited, and currently only supports CQL2 JSON syntax with the following basic CQL2 operators: `AND`, `OR`, `=`, `>`, `>=`, `<`, `<=`, `<>`, `IS NULL`.
+The `sort` and `filter` implementations will recognise any syntactically valid version of a property name, which is to say, the STAC, eo3, and search field (as defined by the metadata type) variants of the name, with or without the `item.` or `properties.` prefixes. If a property does not exist for an item, `sort` will ignore it while `filter` will treat it as `NULL`.
+
+The `filter` extension supports both `cql2-text` and `cql2-json` for both GET and POST requests, and uses [pygeofilter](https://github.com/geopython/pygeofilter) to parse the CQL and convert it to a SQLAlchemy filter expression. `filter-crs` only accepts http://www.opengis.net/def/crs/OGC/1.3/CRS84 as a valid value.
diff --git a/cubedash/_stac.py b/cubedash/_stac.py
index 0c94c84bf..1ba28c2b3 100644
--- a/cubedash/_stac.py
+++ b/cubedash/_stac.py
@@ -47,10 +47,28 @@ STAC_VERSION = "1.0.0"
 
+ItemLike = Union[pystac.Item, dict]
+
 ############################
 # Helpers
 ############################
 
+
+def dissoc_in(d: dict, key: str):
+    # like dicttoolz.dissoc but with support for nested keys
+    split = key.split(".")
+
+    if len(split) > 1:  # if nested
+        if dicttoolz.get_in(split, d) is not None:
+            outer = dicttoolz.get_in(split[:-1], d)
+            return dicttoolz.update_in(
+                d=d,
+                keys=split[:-1],
+                func=lambda _: dicttoolz.dissoc(outer, split[-1]),  # noqa: B023
+            )
+    return dicttoolz.dissoc(d, key)
+
+
 # Time-related
 
@@ -340,6 +358,13 @@ def _build_properties(d: DocReader):
 # Search arguments
 
 
+def _remove_prefixes(arg: str):
+    # remove potential 'item.', 'properties.', or 'item.properties.'
prefixes for ease of handling + arg = arg.replace("item.", "") + arg = arg.replace("properties.", "") + return arg + + def _array_arg( arg: Union[str, List[Union[str, float]]], expect_type=str, expect_size=None ) -> List: @@ -390,7 +415,7 @@ def _geojson_arg(arg: dict) -> BaseGeometry: raise BadRequest("The 'intersects' argument must be valid GeoJSON geometry.") -def _bool_argument(s: str): +def _bool_argument(s: Union[str, bool]): """ Parse an argument that should be a bool """ @@ -401,7 +426,7 @@ def _bool_argument(s: str): return s.strip().lower() in ("1", "true", "on", "yes") -def _dict_arg(arg: dict): +def _dict_arg(arg: Union[str, dict]): """ Parse stac extension arguments as dicts """ @@ -410,15 +435,73 @@ def _dict_arg(arg: dict): return arg -def _list_arg(arg: list): +def _field_arg(arg: Union[str, list, dict]) -> dict[str, list[str]]: """ - Parse sortby argument as a list of dicts + Parse field argument into a dict """ + if isinstance(arg, dict): + return _dict_arg(arg) if isinstance(arg, str): - arg = list(arg) - return list( - map(lambda a: json.loads(a.replace("'", '"')) if isinstance(a, str) else a, arg) - ) + if arg.startswith("{"): + return _dict_arg(arg) + arg = arg.split(",") + if isinstance(arg, list): + include = [] + exclude = [] + for a in arg: + if a.startswith("-"): + exclude.append(a[1:]) + else: + # account for '+' showing up as a space if not encoded + include.append(a[1:] if a.startswith("+") else a.strip()) + return {"include": include, "exclude": exclude} + + +def _sort_arg(arg: Union[str, list]) -> list[dict]: + """ + Parse sortby argument into a list of dicts + """ + + def _format(val: str) -> dict[str, str]: + val = _remove_prefixes(val) + if val.startswith("-"): + return {"field": val[1:], "direction": "desc"} + if val.startswith("+"): + return {"field": val[1:], "direction": "asc"} + # default is ascending + return {"field": val.strip(), "direction": "asc"} + + if isinstance(arg, str): + arg = arg.split(",") + if len(arg): + if isinstance(arg[0], str): + return [_format(a) for a in arg] + if isinstance(arg[0], dict): + for a in arg: + a["field"] = _remove_prefixes(a["field"]) + + return arg + + +def _filter_arg(arg: Union[str, dict]) -> str: + # convert dict to arg to more easily remove prefixes + if isinstance(arg, dict): + arg = json.dumps(arg) + return _remove_prefixes(arg) + + +def _validate_filter(filter_lang: str, cql: str): + # check filter-lang and actual cql format are aligned + is_json = True + try: + json.loads(cql) + except json.decoder.JSONDecodeError: + is_json = False + + if filter_lang == "cql2-text" and is_json: + abort(400, "Expected filter to be cql2-text, but received cql2-json") + if filter_lang == "cql2-json" and not is_json: + abort(400, "Expected filter to be cql2-json, but received cql2-text") # Search @@ -452,13 +535,46 @@ def _handle_search_request( intersects = request_args.get("intersects", default=None, type=_geojson_arg) - query = request_args.get("query", default=None, type=_dict_arg) - - fields = request_args.get("fields", default=None, type=_dict_arg) - - sortby = request_args.get("sortby", default=None, type=_list_arg) + fields = request_args.get("fields", default=None, type=_field_arg) + + sortby = request_args.get("sortby", default=None, type=_sort_arg) + # not sure if there's a neater way to check sortable attribute type in _stores + # but the handling logic (i.e. 
400 status code) would still need to live in here + if sortby: + for s in sortby: + field = s.get("field") + if field in [ + "type", + "stac_version", + "properties", + "geometry", + "links", + "assets", + "bbox", + "stac_extensions", + ]: + abort( + 400, + f"Cannot sort by {field}. " + "Only 'id', 'collection', and Item properties can be used to sort results.", + ) - filter_cql = request_args.get("filter", default=None, type=_dict_arg) + filter_lang = request_args.get("filter-lang", default=None, type=str) + filter_cql = request_args.get("filter", default=None, type=_filter_arg) + filter_crs = request_args.get("filter-crs", default=None) + if filter_crs and filter_crs != "https://www.opengis.net/def/crs/OGC/1.3/CRS84": + abort( + 400, + "filter-crs only accepts 'https://www.opengis.net/def/crs/OGC/1.3/CRS84' as a valid value.", + ) + if filter_lang is None and filter_cql is not None: + # If undefined, defaults to cql2-text for a GET request and cql2-json for a POST request. + if method == "GET": + filter_lang = "cql2-text" + else: + filter_lang = "cql2-json" + if filter_cql: + _validate_filter(filter_lang, filter_cql) if limit > PAGE_SIZE_LIMIT: abort( @@ -483,9 +599,11 @@ def next_page_url(next_offset): limit=limit, _o=next_offset, _full=full_information, - query=query, + intersects=intersects, fields=fields, sortby=sortby, + # so that it doesn't get named 'filter_lang' + **{"filter-lang": filter_lang}, filter=filter_cql, ) @@ -502,9 +620,9 @@ def next_page_url(next_offset): get_next_url=next_page_url, full_information=full_information, include_total_count=include_total_count, - query=query, fields=fields, sortby=sortby, + filter_lang=filter_lang, filter_cql=filter_cql, ) @@ -532,180 +650,88 @@ def next_page_url(next_offset): # Item search extensions -def _get_property(prop: str, item: pystac.Item, no_default=False): +def _get_property(prop: str, item: ItemLike, no_default=False): """So that we don't have to keep using this bulky expression""" - return dicttoolz.get_in(prop.split("."), item.to_dict(), no_default=no_default) - - -def _predicate_helper(items: List[pystac.Item], prop: str, op: str, val) -> filter: - """Common comparison predicates used in both query and filter""" - if op == "eq" or op == "=": - return filter(lambda item: _get_property(prop, item) == val, items) - if op == "gte" or op == ">=": - return filter(lambda item: _get_property(prop, item) >= val, items) - if op == "lte" or op == "<=": - return filter(lambda item: _get_property(prop, item) <= val, items) - elif op == "gt" or op == ">": - return filter(lambda item: _get_property(prop, item) > val, items) - elif op == "lt" or op == "<": - return filter(lambda item: _get_property(prop, item) < val, items) - elif op == "neq" or op == "<>": - return filter(lambda item: _get_property(prop, item) != val, items) - - -def _handle_query_extension(items: List[pystac.Item], query: dict) -> List[pystac.Item]: - """ - Implementation of item search query extension (https://github.com/stac-api-extensions/query/blob/main/README.md) - The documentation doesn't specify whether multiple properties should be treated as logical AND or OR; this - implementation has assumed AND. + if isinstance(item, pystac.Item): + item = item.to_dict() + return dicttoolz.get_in(prop.split("."), item, no_default=no_default) - query = {'property': {'op': 'value'}, 'property': {'op': 'value', 'op': 'value'}} - """ - filtered = items - # split on '.' 
to use dicttoolz for nested items - for prop in query.keys(): - # Retrieve nested dict values - for op, val in query[prop].items(): - if op == "startsWith": - matched = filter( - lambda item: _get_property(prop, item).startswith(val), items - ) - elif op == "endsWith": - matched = filter( - lambda item: _get_property(prop, item).endswith(val), items - ) - elif op == "contains": - matched = filter(lambda item: val in _get_property(prop, item), items) - elif op == "in": - matched = filter(lambda item: _get_property(prop, item) in val, items) - else: - matched = _predicate_helper(items, prop, op, val) - - # achieve logical and between queries with set intersection - filtered = list(set(filtered).intersection(set(matched))) - - return filtered - -def _handle_fields_extension( - items: List[pystac.Item], fields: dict -) -> List[pystac.Item]: +def _handle_fields_extension(items: List[ItemLike], fields: dict) -> List[ItemLike]: """ Implementation of fields extension (https://github.com/stac-api-extensions/fields/blob/main/README.md) - This implementation differs slightly from the documented semantics in that if only `exclude` is specified, those - attributes will be subtracted from the complete set of the item's attributes, not just the default. `exclude` will - also not remove any of the default attributes so as to prevent errors due to invalid stac items. + This implementation differs slightly from the documented semantics in that the default fields will always + be included regardless of `include` or `exclude` values so as to ensure valid stac items. fields = {'include': [...], 'exclude': [...]} """ res = [] - # minimum fields needed for a valid stac item - default_fields = [ - "id", - "type", - "geometry", - "bbox", - "links", - "assets", - "properties.datetime", - "stac_version", - ] for item in items: - include = fields.get("include") or [] - # if 'include' is provided we build up from an empty slate; - # but if only 'exclude' is provided we remove from all existing fields - filtered_item = {} if fields.get("include") else item.to_dict() - # union of 'include' and default fields to ensure a valid stac item - include = list(set(include + default_fields)) - - for inc in include: - filtered_item = dicttoolz.update_in( - d=filtered_item, - keys=inc.split("."), - # get corresponding field from item - # disallow default to avoid None values being inserted - func=lambda _: _get_property(inc, item, no_default=True), - ) - - for exc in fields.get("exclude") or []: - # don't remove a field if it will make for an invalid stac item - if exc not in default_fields: - # what about a field that isn't there? 
- split = exc.split(".") - # have to manually take care of nested case because dicttoolz doesn't have a dissoc_in - if len(split): - filtered_item[split[0]] = dicttoolz.dissoc( - filtered_item[split[0]], split[1] - ) - else: - filtered_item = dicttoolz.dissoc(filtered_item, exc) - - res.append(pystac.Item.from_dict(filtered_item)) - - return res - - -def _handle_sortby_extension( - items: List[pystac.Item], sortby: List[dict] -) -> List[pystac.Item]: - """ - Implementation of sort extension (https://github.com/stac-api-extensions/sort/blob/main/README.md) + # minimum fields needed for a valid stac item + default_fields = [ + "id", + "type", + "geometry", + "bbox", + "links", + "assets", + "stac_version", + # while not necessary for a valid stac item, we still want them included + "stac_extensions", + "collection", + ] - sortby = [ {'field': 'field_name', 'direction': <'asc' or 'desc'>} ] - """ - sorted_items = items - - for s in sortby: - field = s.get("field") - reverse = s.get("direction") == "desc" - # should we enforce correct names and raise error if not? - sorted_items = sorted( - sorted_items, key=lambda i: _get_property(field, i), reverse=reverse - ) + # datetime is one of the default fields, but might be included as start_datetime/end_datetime instead + if _get_property("properties.start_datetime", item) is None: + dt_field = ["properties.start_datetime", "properties.end_datetime"] + else: + dt_field = ["properties.datetime"] + + try: + # if 'include' is present at all, start with default fields to add to or extract from + include = fields["include"] + if include is None: + include = [] + + filtered_item = {k: _get_property(k, item) for k in default_fields} + # handle datetime separately due to nested keys + for f in dt_field: + filtered_item = dicttoolz.assoc_in( + filtered_item, f.split("."), _get_property(f, item) + ) + except KeyError: + # if 'include' wasn't provided, remove 'exclude' fields from set of all available fields + filtered_item = item.to_dict() + include = [] - return list(sorted_items) + # add datetime field names to list of defaults for easy access + default_fields.extend(dt_field) + include = list(set(include + default_fields)) + for exc in fields.get("exclude", []): + if exc not in default_fields: + filtered_item = dissoc_in(filtered_item, exc) -def _handle_filter_extension( - items: List[pystac.Item], filter_cql: dict -) -> List[pystac.Item]: - """ - Implementation of filter extension (https://github.com/stac-api-extensions/filter/blob/main/README.md) - Currently only supporting logical expression (and/or), null and binary comparisons, provided in cql-json - Assumes comparisons to be done between a property value and a literal + # include takes precedence over exclude, plus account for a nested field of an excluded field + for inc in include: + # we don't want to insert None values if a field doesn't exist, but we also don't want to error + try: + filtered_item = dicttoolz.update_in( + d=filtered_item, + keys=inc.split("."), + func=lambda _: _get_property( + inc, + item, + no_default=True, # noqa: B023 + ), + ) + except KeyError: + continue - filter = {'op': 'and','args': - [{'op': '=', 'args': [{'property': 'prop_name'}, val]}, {'op': 'isNull', 'args': {'property': 'prop_name'}}] - } - """ - results = [] - op = filter_cql.get("op") - args = filter_cql.get("args") - # if there is a nested operation in the args, recur to resolve those, creating - # a list of lists that we can then apply the top level operator to - for arg in [a for a in args if isinstance(a, 
dict) and a.get("op")]: - results.append(_handle_filter_extension(items, arg)) - - if op == "and": - # set intersection between each result - # need to pass results as a list of sets to intersection - results = list(set.intersection(*map(set, results))) - elif op == "or": - # set union between each result - results = list(set.union(*map(set, results))) - elif op == "isNull": - # args is a single property rather than a list - prop = args.get("property") - results = filter( - lambda item: _get_property(prop, item) in [None, "None"], items - ) - else: - prop = args[0].get("property") - val = args[1] - results = _predicate_helper(items, prop, op, val) + res.append(filtered_item) - return list(results) + return res def search_stac_items( @@ -721,10 +747,10 @@ def search_stac_items( order: ItemSort = ItemSort.DEFAULT_SORT, include_total_count: bool = False, use_post_request: bool = False, - query: Optional[dict] = None, fields: Optional[dict] = None, sortby: Optional[List[dict]] = None, - filter_cql: Optional[dict] = None, + filter_lang: Optional[str] = None, + filter_cql: Optional[str | dict] = None, ) -> ItemCollection: """ Perform a search, returning a FeatureCollection of stac Item results. @@ -732,6 +758,8 @@ def search_stac_items( :param get_next_url: A function that calculates a page url for the given offset. """ offset = offset or 0 + if sortby is not None: + order = sortby items = list( _model.STORE.search_items( product_names=product_names, @@ -742,6 +770,8 @@ def search_stac_items( intersects=intersects, offset=offset, full_dataset=full_information, + filter_lang=filter_lang, + filter_cql=filter_cql, order=order, ) ) @@ -765,15 +795,18 @@ def search_stac_items( ) if include_total_count: count_matching = _model.STORE.get_count( - product_names=product_names, time=time, bbox=bbox, dataset_ids=dataset_ids + product_names=product_names, + time=time, + bbox=bbox, + intersects=intersects, + dataset_ids=dataset_ids, + filter_lang=filter_lang, + filter_cql=filter_cql, ) extra_properties["numberMatched"] = count_matching extra_properties["context"]["matched"] = count_matching items = [as_stac_item(f) for f in returned] - items = _handle_query_extension(items, query) if query else items - items = _handle_filter_extension(items, filter_cql) if filter_cql else items - items = _handle_sortby_extension(items, sortby) if sortby else items items = _handle_fields_extension(items, fields) if fields else items result = ItemCollection(items, extra_fields=extra_properties) @@ -988,14 +1021,14 @@ def root(): "https://api.stacspec.org/v1.0.0-rc.1/core", "https://api.stacspec.org/v1.0.0-rc.1/item-search", "https://api.stacspec.org/v1.0.0-rc.1/ogcapi-features", - "https://api.stacspec.org/v1.0.0-rc.1/item-search#query", "https://api.stacspec.org/v1.0.0-rc.1/item-search#fields", - "https://api.stacspec.org/v1.0.0-rc.1/ogcapi-features#fields", "https://api.stacspec.org/v1.0.0-rc.1/item-search#sort", - "https://api.stacspec.org/v1.0.0-rc.1/ogcapi-features#sort", "https://api.stacspec.org/v1.0.0-rc.1/item-search#filter", + "http://www.opengis.net/spec/cql2/1.0/conf/cql2-text", "http://www.opengis.net/spec/cql2/1.0/conf/cql2-json", "http://www.opengis.net/spec/cql2/1.0/conf/basic-cql2", + "http://www.opengis.net/spec/cql2/1.0/conf/advanced-comparison-operators", + "http://www.opengis.net/spec/cql2/1.0/conf/spatial-operators", "http://www.opengis.net/spec/ogcapi-features-3/1.0/conf/filter", "https://api.stacspec.org/v1.0.0-rc.1/collections", ] diff --git a/cubedash/_utils.py b/cubedash/_utils.py index 
5efe333dd..1643edf98 100644 --- a/cubedash/_utils.py +++ b/cubedash/_utils.py @@ -897,6 +897,7 @@ def alchemy_engine(index: Index) -> Engine: return index.datasets._db._engine +# somewhat misleading name def make_dataset_from_select_fields(index, row): # pylint: disable=protected-access return index.datasets._make(row, full_info=True) diff --git a/cubedash/summary/_stores.py b/cubedash/summary/_stores.py index 0315191c6..eb8eb42ca 100644 --- a/cubedash/summary/_stores.py +++ b/cubedash/summary/_stores.py @@ -7,6 +7,7 @@ from enum import Enum, auto from itertools import groupby from typing import ( + Any, Dict, Generator, Iterable, @@ -25,11 +26,28 @@ import structlog from cachetools.func import lru_cache, ttl_cache from dateutil import tz +from eodatasets3.stac import MAPPING_EO3_TO_STAC from geoalchemy2 import WKBElement from geoalchemy2 import shape as geo_shape from geoalchemy2.shape import from_shape, to_shape +from pygeofilter import ast +from pygeofilter.backends.evaluator import handle +from pygeofilter.backends.sqlalchemy.evaluate import SQLAlchemyFilterEvaluator +from pygeofilter.parsers.cql2_json import parse as parse_cql2_json +from pygeofilter.parsers.cql2_text import parse as parse_cql2_text from shapely.geometry.base import BaseGeometry -from sqlalchemy import DDL, String, and_, exists, func, literal, or_, select, union_all +from sqlalchemy import ( + DDL, + String, + and_, + exists, + func, + literal, + null, + or_, + select, + union_all, +) from sqlalchemy.dialects import postgresql as postgres from sqlalchemy.dialects.postgresql import TSTZRANGE from sqlalchemy.engine import Engine @@ -1192,6 +1210,89 @@ def _add_fields_to_query( return query + def _get_field_exprs( + self, + product_names: Optional[List[str]] = None, + ) -> dict[str, Any]: + """ + Map properties to their sqlalchemy expressions. + Allow for properties to be provided as their STAC property name (ex: created), + their eo3 property name (ex: odc:processing_datetime), + or their searchable field name as defined by the metadata type (ex: creation_time). 
+ """ + if product_names: + products = {self.index.products.get_by_name(name) for name in product_names} + else: + products = set(self.index.products.get_all()) + field_exprs = {} + for product in products: + for value in _utils.get_mutable_dataset_search_fields( + self.index, product.metadata_type + ).values(): + expr = value.alchemy_expression + if hasattr(value, "offset"): + field_exprs[value.offset[-1]] = expr + field_exprs[value.name] = expr + + # add stac property names as well + for k, v in MAPPING_EO3_TO_STAC.items(): + try: + # map to same alchemy expression as the eo3 counterparts + field_exprs[v] = field_exprs[k] + except KeyError: + continue + # manually add fields that aren't included in the metadata search fields + field_exprs["collection"] = ( + select([ODC_DATASET_TYPE.c.name]) + .where(ODC_DATASET_TYPE.c.id == DATASET_SPATIAL.c.dataset_type_ref) + .scalar_subquery() + ) + field_exprs["datetime"] = DATASET_SPATIAL.c.center_time + geom = func.ST_Transform(DATASET_SPATIAL.c.footprint, 4326) + field_exprs["geometry"] = geom + field_exprs["bbox"] = func.Box2D(geom).cast(String) + + return field_exprs + + def _add_filter_to_query( + self, + query: Select, + field_exprs: dict[str, Any], + filter_lang: str, + filter_cql: dict, + ) -> Select: + # use pygeofilter's SQLAlchemy integration to construct the filter query + filter_cql = ( + parse_cql2_text(filter_cql) + if filter_lang == "cql2-text" + else parse_cql2_json(filter_cql) + ) + query = query.filter(FilterEvaluator(field_exprs).evaluate(filter_cql)) + + return query + + def _add_order_to_query( + self, + query: Select, + field_exprs: dict[str, Any], + sortby: list[dict[str, str]], + ) -> Select: + order_clauses = [] + for s in sortby: + field = field_exprs.get(s.get("field")) + # is there any way to check if sortable? + if field is not None: + asc = s.get("direction") == "asc" + if asc: + order_clauses.append(field.asc()) + else: + order_clauses.append(field.desc()) + # there is no field by that name, ignore + # the spec does not specify a handling directive for unspecified fields, + # so we've chosen to ignore them to be in line with the other extensions + query = query.order_by(*order_clauses) + return query + @ttl_cache(ttl=DEFAULT_TTL) def get_arrivals( self, period_length: timedelta @@ -1243,21 +1344,40 @@ def get_count( product_names: Optional[List[str]] = None, time: Optional[Tuple[datetime, datetime]] = None, bbox: Tuple[float, float, float, float] = None, + intersects: BaseGeometry = None, dataset_ids: Sequence[UUID] = None, + filter_lang: str | None = None, + filter_cql: str | dict | None = None, ) -> int: """ - Do the most simple select query to get the count of matching datasets. + Do the base select query to get the count of matching datasets. 
""" - query: Select = select([func.count()]).select_from(DATASET_SPATIAL) + if filter_cql: # to account the possibiity of 'collection' in the filter + query: Select = select([func.count()]).select_from( + DATASET_SPATIAL.join( + ODC_DATASET, onclause=ODC_DATASET.c.id == DATASET_SPATIAL.c.id + ) + ) + else: + query: Select = select([func.count()]).select_from(DATASET_SPATIAL) query = self._add_fields_to_query( query, product_names=product_names, time=time, bbox=bbox, + intersects=intersects, dataset_ids=dataset_ids, ) + if filter_cql: + query = self._add_filter_to_query( + query, + self._get_field_exprs(product_names), + filter_lang, + filter_cql, + ) + result = self._engine.execute(query).fetchall() if len(result) != 0: @@ -1276,7 +1396,9 @@ def search_items( offset: int = 0, full_dataset: bool = False, dataset_ids: Sequence[UUID] = None, - order: ItemSort = ItemSort.DEFAULT_SORT, + filter_lang: str | None = None, + filter_cql: str | dict | None = None, + order: ItemSort | list[dict[str, str]] = ItemSort.DEFAULT_SORT, ) -> Generator[DatasetItem, None, None]: """ Search datasets using Explorer's spatial table @@ -1322,6 +1444,13 @@ def search_items( dataset_ids=dataset_ids, ) + field_exprs = self._get_field_exprs(product_names) + + if filter_cql: + query = self._add_filter_to_query( + query, field_exprs, filter_lang, filter_cql + ) + # Maybe sort if order == ItemSort.DEFAULT_SORT: query = query.order_by(DATASET_SPATIAL.c.center_time, DATASET_SPATIAL.c.id) @@ -1333,10 +1462,8 @@ def search_items( "Only full-dataset searches can be sorted by recently added" ) query = query.order_by(ODC_DATASET.c.added.desc()) - else: - raise RuntimeError( - f"Unknown item sort order {order!r} (perhaps this is a bug?)" - ) + elif order: # order was provided as a sortby query + query = self._add_order_to_query(query, field_exprs, order) query = query.limit(limit).offset( # TODO: Offset/limit isn't particularly efficient for paging... @@ -1754,6 +1881,19 @@ def get_dataset_footprint_region(self, dataset_id): ) +class FilterEvaluator(SQLAlchemyFilterEvaluator): + """ + Since pygeofilter's SQLAlchemyFilterEvaluator doesn't support treating + invalid/undefined attributes as NULL as per the STAC API Filter spec, + this class overwrites the Evaluator's handling of attributes to return NULL + as the default value if a field is not present in the mapping of sqlalchemy expressions. + """ + + @handle(ast.Attribute) + def attribute(self, node: ast.Attribute): + return self.field_mapping.get(node.name, null()) + + def _refresh_data(please_refresh: Set[PleaseRefresh], store: SummaryStore): """ Refresh product information after a schema update, plus the given kind of data. 
diff --git a/docs/rtd-requirements.txt b/docs/rtd-requirements.txt index 3de1a536a..546a5f5f0 100644 --- a/docs/rtd-requirements.txt +++ b/docs/rtd-requirements.txt @@ -83,10 +83,12 @@ datacube==1.8.18 # via # datacube-explorer (setup.py) # eodatasets3 +dateparser==1.2.0 + # via pygeofilter defusedxml==0.7.1 - # via eodatasets3 -deprecat==2.1.1 # via datacube +deprecat==2.1.1 + # via eodatasets3 distributed==2023.1.1 # via datacube eodatasets3==0.30.1 @@ -142,7 +144,9 @@ jsonschema==4.20.0 jsonschema-specifications==2023.12.1 # via jsonschema lark==0.12.0 - # via datacube + # via + # datacube + # pygeofilter locket==1.0.0 # via # distributed @@ -192,6 +196,10 @@ psutil==5.9.4 # via distributed psycopg2==2.9.5 # via datacube +pygeofilter==0.2.1 + # via datacube-explorer (setup.py) +pygeoif==1.4.0 + # via pygeofilter pyorbital==1.7.3 # via datacube-explorer (setup.py) pyparsing==3.0.9 @@ -210,6 +218,7 @@ python-dateutil==2.8.2 # botocore # datacube # datacube-explorer (setup.py) + # dateparser # pandas # pystac python-rapidjson==1.9 @@ -217,6 +226,7 @@ python-rapidjson==1.9 pytz==2022.7.1 # via # datacube-explorer (setup.py) + # dateparser # pandas pyyaml==6.0 # via @@ -231,6 +241,8 @@ referencing==0.32.0 # via # jsonschema # jsonschema-specifications +regex==2023.12.25 + # via dateparser requests==2.28.2 # via pyorbital rpds-py==0.16.2 @@ -284,6 +296,10 @@ toolz==0.12.0 # partd tornado==6.2 # via distributed +typing-extensions==4.11.0 + # via pygeoif +tzlocal==5.2 + # via dateparser urllib3==1.26.14 # via # botocore diff --git a/integration_tests/test_stac.py b/integration_tests/test_stac.py index 2a96fdc54..daaa36e0f 100644 --- a/integration_tests/test_stac.py +++ b/integration_tests/test_stac.py @@ -1274,27 +1274,6 @@ def test_stac_search_by_post(stac_client: FlaskClient): validate_item(feature) -def test_stac_query_extension(stac_client: FlaskClient): - query = {"properties.dea:dataset_maturity": dict(eq="nrt")} - rv: Response = stac_client.post( - "/stac/search", - data=json.dumps( - { - "product": "ga_ls8c_ard_3", - "time": "2022-01-01T00:00:00/2022-12-31T00:00:00", - "limit": OUR_DATASET_LIMIT, - "_full": True, - "query": query, - } - ), - headers={"Content-Type": "application/json", "Accept": "application/json"}, - ) - assert rv.status_code == 200 - doc = rv.json - assert len(doc.get("features")) == 1 - assert doc["features"][0]["properties"]["dea:dataset_maturity"] == "nrt" - - def test_stac_fields_extension(stac_client: FlaskClient): fields = {"include": ["properties.dea:dataset_maturity"]} rv: Response = stac_client.post( @@ -1324,11 +1303,13 @@ def test_stac_fields_extension(stac_client: FlaskClient): "properties", "stac_version", "stac_extensions", + "collection", } == keys properties = doc["features"][0]["properties"] assert {"datetime", "dea:dataset_maturity"} == set(properties.keys()) - fields = {"exclude": ["assets.thumbnail:nbart"]} + # exclude without include should remove from full set of properties + fields = {"exclude": ["properties.title"]} rv: Response = stac_client.post( "/stac/search", data=json.dumps( @@ -1346,10 +1327,49 @@ def test_stac_fields_extension(stac_client: FlaskClient): doc = rv.json keys = set(doc["features"][0].keys()) assert "collection" in keys - properties = doc["features"][0]["assets"] - assert "thumbnail:nbart" not in set(properties.keys()) + properties = doc["features"][0]["properties"] + assert "title" not in set(properties.keys()) + assert "dea:dataset_maturity" in set(properties.keys()) + + # with get + rv: Response = stac_client.get( + 
"/stac/search?collection=ga_ls8c_ard_3&limit=5&fields=+properties.title" + ) + assert rv.status_code == 200 + doc = rv.json + assert doc.get("features") + properties = doc["features"][0]["properties"] + assert {"datetime", "title"} == set(properties.keys()) - # should we do an invalid field as well? + # invalid field + rv: Response = stac_client.get( + "/stac/search?collection=ga_ls8c_ard_3&limit=5&fields=properties.foo" + ) + assert rv.status_code == 200 + doc = rv.json + assert doc.get("features") + properties = doc["features"][0]["properties"] + assert {"datetime"} == set(properties.keys()) + + # exclude properties, but nested field properties.datetime is included by default + rv: Response = stac_client.get( + "/stac/search?collection=ga_ls8c_ard_3&limit=5&fields=-properties" + ) + assert rv.status_code == 200 + doc = rv.json + assert doc.get("features") + properties = doc["features"][0]["properties"] + assert {"datetime"} == set(properties.keys()) + + # empty include and exclude should return just default fields + rv: Response = stac_client.get( + "/stac/search?collection=ga_ls8c_ard_3&limit=5&fields=" + ) + assert rv.status_code == 200 + doc = rv.json + assert doc.get("features") + properties = doc["features"][0]["properties"] + assert {"datetime"} == set(properties.keys()) def test_stac_sortby_extension(stac_client: FlaskClient): @@ -1398,16 +1418,51 @@ def test_stac_sortby_extension(stac_client: FlaskClient): > doc["features"][i]["properties"]["datetime"] ) + rv: Response = stac_client.get( + "/stac/search?collection=ga_ls8c_ard_3&limit=5&sortby=assets" + ) + assert rv.status_code == 400 + + rv: Response = stac_client.get( + "/stac/search?collection=ga_ls8c_ard_3&limit=5&sortby=id,-properties.datetime" + ) + doc = rv.json + for i in range(1, len(doc["features"])): + assert doc["features"][i - 1]["id"] < doc["features"][i]["id"] + + # use of property prefixes shouldn't impact result + rv: Response = stac_client.get( + "/stac/search?collection=ga_ls8c_ard_3&limit=5&sortby=item.id,-datetime" + ) + assert rv.json == doc + + # ignore undefined field + rv: Response = stac_client.get( + "/stac/search?collection=ga_ls8c_ard_3&limit=5&sortby=id,-datetime,foo" + ) + assert rv.json["features"] == doc["features"] + + # sorting across pages + next_link = _get_next_href(doc) + next_link = next_link.replace("http://localhost", "") + rv: Response = stac_client.get(next_link) + last_item = doc["features"][-1] + next_item = rv.json["features"][0] + assert last_item["id"] < next_item["id"] + def test_stac_filter_extension(stac_client: FlaskClient): - filter_cql = { + filter_json = { "op": "and", "args": [ { "op": "<>", - "args": [{"property": "properties.dea:dataset_maturity"}, "final"], + "args": [{"property": "dea:dataset_maturity"}, "final"], + }, + { + "op": ">=", + "args": [{"property": "eo:cloud_cover"}, float(2)], }, - {"op": ">=", "args": [{"property": "properties.eo:cloud_cover"}, float(2)]}, ], } rv: Response = stac_client.post( @@ -1418,16 +1473,70 @@ def test_stac_filter_extension(stac_client: FlaskClient): "time": "2022-01-01T00:00:00/2022-12-31T00:00:00", "limit": OUR_DATASET_LIMIT, "_full": True, - "filter": filter_cql, + "filter": filter_json, } ), headers={"Content-Type": "application/json", "Accept": "application/json"}, ) assert rv.status_code == 200 - doc = rv.json - features = doc.get("features") - assert len(features) == 2 + features = rv.json.get("features") + assert len(features) == rv.json.get("numberMatched") == 2 ids = [f["id"] for f in features] - assert 
"fc792b3b-a685-4c0f-9cf6-f5257f042c64", ( - "192276c6-8fa4-46a9-8bc6-e04e157974b9" in ids + assert "fc792b3b-a685-4c0f-9cf6-f5257f042c64" in ids + assert "192276c6-8fa4-46a9-8bc6-e04e157974b9" in ids + + # test cql2-text + filter_text = "collection='ga_ls8c_ard_3' AND dataset_maturity <> 'final' AND cloud_cover >= 2" + rv: Response = stac_client.get(f"/stac/search?filter={filter_text}") + assert rv.json.get("numberMatched") == 2 + + filter_text = "view:sun_azimuth < 40 AND dataset_maturity = 'final'" + rv: Response = stac_client.get( + f"/stac/search?collections=ga_ls8c_ard_3&filter={filter_text}" + ) + assert rv.json.get("numberMatched") == 4 + + # test invalid property name treated as null + rv: Response = stac_client.get( + "/stac/search?filter=item.collection='ga_ls8c_ard_3' AND properties.foo > 2" ) + assert rv.json.get("numberMatched") == 0 + + rv: Response = stac_client.get( + "/stac/search?filter=collection='ga_ls8c_ard_3' AND foo IS NULL" + ) + assert rv.json.get("numberMatched") == 21 + + # test lang mismatch + rv: Response = stac_client.post( + "/stac/search", + data=json.dumps( + { + "product": "ga_ls8c_ard_3", + "time": "2022-01-01T00:00:00/2022-12-31T00:00:00", + "limit": OUR_DATASET_LIMIT, + "_full": True, + "filter-lang": "cql2-text", + "filter": filter_json, + } + ), + headers={"Content-Type": "application/json", "Accept": "application/json"}, + ) + assert rv.status_code == 400 + + # filter-crs invalid value + rv: Response = stac_client.post( + "/stac/search", + data=json.dumps( + { + "product": "ga_ls8c_ard_3", + "time": "2022-01-01T00:00:00/2022-12-31T00:00:00", + "limit": OUR_DATASET_LIMIT, + "_full": True, + "filter-crs": "http://www.opengis.net/def/crs/OGC/1.3/CRS83", + "filter": filter_json, + } + ), + headers={"Content-Type": "application/json", "Accept": "application/json"}, + ) + assert rv.status_code == 400 diff --git a/setup.py b/setup.py index 308b79898..da83e3355 100755 --- a/setup.py +++ b/setup.py @@ -96,6 +96,7 @@ "sqlalchemy>=1.4", "structlog>=20.2.0", "pytz", + "pygeofilter", ], tests_require=tests_require, extras_require=extras_require,