From 3c082122f61f4595df14dc6921f1f65ae310fd19 Mon Sep 17 00:00:00 2001 From: Mwad22 <51929507+Mwad22@users.noreply.github.com> Date: Fri, 11 Jun 2021 21:15:55 -0400 Subject: [PATCH] Made simple feature names default on data retrieval, provides option for names prefixed with featureviews Signed-off-by: Mwad22 <51929507+Mwad22@users.noreply.github.com> --- README.md | 1 + docs/quickstart.md | 3 +- sdk/python/feast/errors.py | 7 + sdk/python/feast/feature_store.py | 56 ++++- sdk/python/feast/infra/gcp.py | 2 + sdk/python/feast/infra/local.py | 2 + .../feast/infra/offline_stores/bigquery.py | 13 +- sdk/python/feast/infra/offline_stores/file.py | 14 +- .../infra/offline_stores/offline_store.py | 1 + sdk/python/feast/infra/provider.py | 21 +- sdk/python/tests/foo_provider.py | 1 + sdk/python/tests/test_e2e_local.py | 1 + sdk/python/tests/test_historical_retrieval.py | 191 +++++++++++++++++- .../test_offline_online_store_consistency.py | 3 +- sdk/python/tests/test_online_retrieval.py | 34 +++- 15 files changed, 316 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index f64f91c27dd..00604e2f788 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ training_df = store.get_historical_features( 'driver_hourly_stats:acc_rate', 'driver_hourly_stats:avg_daily_trips' ], + full_feature_names=True ).to_df() print(training_df.head()) diff --git a/docs/quickstart.md b/docs/quickstart.md index 6b94b04a812..b98bfe8acac 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -234,7 +234,8 @@ feature_vector = store.get_online_features( 'driver_hourly_stats:acc_rate', 'driver_hourly_stats:avg_daily_trips' ], - entity_rows=[{"driver_id": 1001}] + entity_rows=[{"driver_id": 1001}], + full_feature_names=True ).to_dict() pprint(feature_vector) diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index a0b3e2bf49a..0c026af31a0 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -71,6 +71,13 @@ def __init__(self, 
offline_store_name: str, data_source_name: str): ) +class FeatureNameCollisionError(Exception): + def __init__(self, feature_name_collisions: str): + super().__init__( + f"The following feature name(s) have collisions: {feature_name_collisions}. Set 'full_feature_names' argument in the data retrieval function to True to use the full feature name which is prefixed by the feature view name." + ) + + class FeastOnlineStoreUnsupportedDataSource(Exception): def __init__(self, online_store_name: str, data_source_name: str): super().__init__( diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 8503632d5c2..b30ce5f4399 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -24,7 +24,11 @@ from feast import utils from feast.entity import Entity -from feast.errors import FeastProviderLoginError, FeatureViewNotFoundException +from feast.errors import ( + FeastProviderLoginError, + FeatureNameCollisionError, + FeatureViewNotFoundException, +) from feast.feature_view import FeatureView from feast.inference import infer_entity_value_type_from_feature_views from feast.infra.provider import Provider, RetrievalJob, get_provider @@ -244,7 +248,10 @@ def apply( @log_exceptions_and_usage def get_historical_features( - self, entity_df: Union[pd.DataFrame, str], feature_refs: List[str], + self, + entity_df: Union[pd.DataFrame, str], + feature_refs: List[str], + full_feature_names: bool = False, ) -> RetrievalJob: """Enrich an entity dataframe with historical feature values for either training or batch scoring. @@ -266,6 +273,10 @@ def get_historical_features( SQL query. The query must be of a format supported by the configured offline store (e.g., BigQuery) feature_refs: A list of features that should be retrieved from the offline store. Feature references are of the format "feature_view:feature", e.g., "customer_fv:daily_transactions". + full_feature_names: By default, this value is set to False. 
This strips the feature view prefixes from the data + and returns only the feature name, changing them from the format "feature_view__feature" to "feature" + (e.g., "customer_fv__daily_transactions" changes to "daily_transactions"). Set the value to True for + the feature names to be prefixed by the feature view name in the format "feature_view__feature". Returns: RetrievalJob which can be used to materialize the results. @@ -278,12 +289,12 @@ def get_historical_features( >>> fs = FeatureStore(config=RepoConfig(provider="gcp")) >>> retrieval_job = fs.get_historical_features( >>> entity_df="SELECT event_timestamp, order_id, customer_id from gcp_project.my_ds.customer_orders", - >>> feature_refs=["customer:age", "customer:avg_orders_1d", "customer:avg_orders_7d"] - >>> ) + >>> feature_refs=["customer:age", "customer:avg_orders_1d", "customer:avg_orders_7d"], + >>> full_feature_names=False + >>> ) >>> feature_data = retrieval_job.to_df() >>> model.fit(feature_data) # insert your modeling framework here. """ - all_feature_views = self._registry.list_feature_views(project=self.project) try: feature_views = _get_requested_feature_views( @@ -301,6 +312,7 @@ def get_historical_features( entity_df, self._registry, self.project, + full_feature_names, ) except FeastProviderLoginError as e: sys.exit(e) @@ -467,7 +479,10 @@ def tqdm_builder(length): @log_exceptions_and_usage def get_online_features( - self, feature_refs: List[str], entity_rows: List[Dict[str, Any]], + self, + feature_refs: List[str], + entity_rows: List[Dict[str, Any]], + full_feature_names: bool = False, ) -> OnlineResponse: """ Retrieves the latest online feature data. 
@@ -535,7 +550,7 @@ def get_online_features( project=self.project, allow_cache=True ) - grouped_refs = _group_refs(feature_refs, all_feature_views) + grouped_refs = _group_refs(feature_refs, all_feature_views, full_feature_names) for table, requested_features in grouped_refs: entity_keys = _get_table_entity_keys( table, union_of_entity_keys, entity_name_to_join_key_map @@ -552,13 +567,21 @@ def get_online_features( if feature_data is None: for feature_name in requested_features: - feature_ref = f"{table.name}__{feature_name}" + feature_ref = ( + f"{table.name}__{feature_name}" + if full_feature_names + else feature_name + ) result_row.statuses[ feature_ref ] = GetOnlineFeaturesResponse.FieldStatus.NOT_FOUND else: for feature_name in feature_data: - feature_ref = f"{table.name}__{feature_name}" + feature_ref = ( + f"{table.name}__{feature_name}" + if full_feature_names + else feature_name + ) if feature_name in requested_features: result_row.fields[feature_ref].CopyFrom( feature_data[feature_name] @@ -587,7 +610,9 @@ def _entity_row_to_field_values( def _group_refs( - feature_refs: List[str], all_feature_views: List[FeatureView] + feature_refs: List[str], + all_feature_views: List[FeatureView], + full_feature_names: bool = False, ) -> List[Tuple[FeatureView, List[str]]]: """ Get list of feature views and corresponding feature names based on feature references""" @@ -597,12 +622,23 @@ def _group_refs( # view name to feature names views_features = defaultdict(list) + feature_set = set() + feature_collision_set = set() + for ref in feature_refs: view_name, feat_name = ref.split(":") + if feat_name in feature_set: + feature_collision_set.add(feat_name) + else: + feature_set.add(feat_name) if view_name not in view_index: raise FeatureViewNotFoundException(view_name) views_features[view_name].append(feat_name) + if not full_feature_names and len(feature_collision_set) > 0: + err = ", ".join(x for x in feature_collision_set) + raise FeatureNameCollisionError(err) + result 
= [] for view_name, feature_names in views_features.items(): result.append((view_index[view_name], feature_names)) diff --git a/sdk/python/feast/infra/gcp.py b/sdk/python/feast/infra/gcp.py index f33b501d62e..9e307d761b4 100644 --- a/sdk/python/feast/infra/gcp.py +++ b/sdk/python/feast/infra/gcp.py @@ -128,6 +128,7 @@ def get_historical_features( entity_df: Union[pandas.DataFrame, str], registry: Registry, project: str, + full_feature_names: bool = False, ) -> RetrievalJob: job = self.offline_store.get_historical_features( config=config, @@ -136,5 +137,6 @@ def get_historical_features( entity_df=entity_df, registry=registry, project=project, + full_feature_names=full_feature_names, ) return job diff --git a/sdk/python/feast/infra/local.py b/sdk/python/feast/infra/local.py index a76f49b2c4d..23c813e6083 100644 --- a/sdk/python/feast/infra/local.py +++ b/sdk/python/feast/infra/local.py @@ -127,6 +127,7 @@ def get_historical_features( entity_df: Union[pd.DataFrame, str], registry: Registry, project: str, + full_feature_names: bool = False, ) -> RetrievalJob: return self.offline_store.get_historical_features( config=config, @@ -135,6 +136,7 @@ def get_historical_features( entity_df=entity_df, registry=registry, project=project, + full_feature_names=full_feature_names, ) diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index 2c0b115e0cd..e07f4c8fc4c 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -84,6 +84,7 @@ def get_historical_features( entity_df: Union[pandas.DataFrame, str], registry: Registry, project: str, + full_feature_names: bool = False, ) -> RetrievalJob: # TODO: Add entity_df validation in order to fail before interacting with BigQuery @@ -105,7 +106,7 @@ def get_historical_features( # Build a query context containing all information required to template the BigQuery SQL query query_context = 
get_feature_view_query_context( - feature_refs, feature_views, registry, project + feature_refs, feature_views, registry, project, full_feature_names ) # TODO: Infer min_timestamp and max_timestamp from entity_df @@ -116,6 +117,7 @@ def get_historical_features( max_timestamp=datetime.now() + timedelta(days=1), left_table_query_string=str(table.reference), entity_df_event_timestamp_col=entity_df_event_timestamp_col, + full_feature_names=full_feature_names, ) job = BigQueryRetrievalJob(query=query, client=client) @@ -292,11 +294,12 @@ def get_feature_view_query_context( feature_views: List[FeatureView], registry: Registry, project: str, + full_feature_names: bool = False, ) -> List[FeatureViewQueryContext]: """Build a query context containing all information required to template a BigQuery point-in-time SQL query""" feature_views_to_feature_map = _get_requested_feature_views_to_features_dict( - feature_refs, feature_views + feature_refs, feature_views, full_feature_names ) query_context = [] @@ -351,6 +354,7 @@ def build_point_in_time_query( max_timestamp: datetime, left_table_query_string: str, entity_df_event_timestamp_col: str, + full_feature_names: bool = False, ): """Build point-in-time query between each feature view table and the entity dataframe""" template = Environment(loader=BaseLoader()).from_string( @@ -367,6 +371,7 @@ def build_point_in_time_query( [entity for fv in feature_view_query_contexts for entity in fv.entities] ), "featureviews": [asdict(context) for context in feature_view_query_contexts], + "full_feature_names": full_feature_names, } query = template.render(template_context) @@ -440,7 +445,7 @@ def _get_bigquery_client(): {{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }} {{ featureview.entity_selections | join(', ')}}, {% for feature in featureview.features %} - {{ feature }} as {{ featureview.name }}__{{ feature }}{% if loop.last %}{% else %}, {% endif %} + {{ feature }} as 
{% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %}{% if loop.last %}{% else %}, {% endif %} {% endfor %} FROM {{ featureview.table_subquery }} ), @@ -533,7 +538,7 @@ def _get_bigquery_client(): SELECT entity_row_unique_id, {% for feature in featureview.features %} - {{ featureview.name }}__{{ feature }}, + {% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %}, {% endfor %} FROM {{ featureview.name }}__cleaned ) USING (entity_row_unique_id) diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index acd12ff9003..0513dc884a4 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -40,6 +40,7 @@ def get_historical_features( entity_df: Union[pd.DataFrame, str], registry: Registry, project: str, + full_feature_names: bool = False, ) -> FileRetrievalJob: if not isinstance(entity_df, pd.DataFrame): raise ValueError( @@ -59,9 +60,8 @@ def get_historical_features( raise ValueError( f"Please provide an entity_df with a column named {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL} representing the time of events." ) - feature_views_to_features = _get_requested_feature_views_to_features_dict( - feature_refs, feature_views + feature_refs, feature_views, full_feature_names ) # Create lazy function that is only called from the RetrievalJob object @@ -125,14 +125,16 @@ def evaluate_historical_retrieval(): # Modify the separator for feature refs in column names to double underscore. 
We are using # double underscore as separator for consistency with other databases like BigQuery, # where there are very few characters available for use as separators - prefixed_feature_name = f"{feature_view.name}__{feature}" - + if full_feature_names: + formatted_feature_name = f"{feature_view.name}__{feature}" + else: + formatted_feature_name = feature # Add the feature name to the list of columns - feature_names.append(prefixed_feature_name) + feature_names.append(formatted_feature_name) # Ensure that the source dataframe feature column includes the feature view name as a prefix df_to_join.rename( - columns={feature: prefixed_feature_name}, inplace=True, + columns={feature: formatted_feature_name}, inplace=True, ) # Build a list of entity columns to join on (from the right table) diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index d31d11aae2a..c1c2279dc61 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -66,5 +66,6 @@ def get_historical_features( entity_df: Union[pd.DataFrame, str], registry: Registry, project: str, + full_feature_names: bool = False, ) -> RetrievalJob: pass diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index 5d4f8d6cf0c..353be43b766 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -2,7 +2,7 @@ import importlib from datetime import datetime from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union import pandas import pyarrow @@ -116,6 +116,7 @@ def get_historical_features( entity_df: Union[pandas.DataFrame, str], registry: Registry, project: str, + full_feature_names: bool = False, ) -> RetrievalJob: pass @@ -179,15 +180,24 @@ def get_provider(config: RepoConfig, repo_path: 
Path) -> Provider: def _get_requested_feature_views_to_features_dict( - feature_refs: List[str], feature_views: List[FeatureView] + feature_refs: List[str], feature_views: List[FeatureView], full_feature_names: bool ) -> Dict[FeatureView, List[str]]: - """Create a dict of FeatureView -> List[Feature] for all requested features""" + """Create a dict of FeatureView -> List[Feature] for all requested features. + Set full_feature_names to True to prefix each feature name with its feature view name; if False, bare feature names are used and a FeatureNameCollisionError is raised when names collide.""" feature_views_to_feature_map = {} # type: Dict[FeatureView, List[str]] + feature_set = set() # type: Set[str] + feature_collision_set = set() # type: Set[str] + for ref in feature_refs: ref_parts = ref.split(":") feature_view_from_ref = ref_parts[0] feature_from_ref = ref_parts[1] + if feature_from_ref in feature_set: + feature_collision_set.add(feature_from_ref) + else: + feature_set.add(feature_from_ref) + found = False for feature_view_from_registry in feature_views: if feature_view_from_registry.name == feature_view_from_ref: @@ -203,6 +213,11 @@ def _get_requested_feature_views_to_features_dict( if not found: raise ValueError(f"Could not find feature view from reference {ref}") + + if not full_feature_names and len(feature_collision_set) > 0: + err = ", ".join(x for x in feature_collision_set) + raise errors.FeatureNameCollisionError(err) + return feature_views_to_feature_map diff --git a/sdk/python/tests/foo_provider.py b/sdk/python/tests/foo_provider.py index 8b7e5f4d368..ac902376f52 100644 --- a/sdk/python/tests/foo_provider.py +++ b/sdk/python/tests/foo_provider.py @@ -62,6 +62,7 @@ def get_historical_features( entity_df: Union[pandas.DataFrame, str], registry: Registry, project: str, + full_feature_names: bool = False, ) -> RetrievalJob: pass diff --git a/sdk/python/tests/test_e2e_local.py b/sdk/python/tests/test_e2e_local.py index d61d8caa7b1..6057226b738 100644 --- a/sdk/python/tests/test_e2e_local.py +++ b/sdk/python/tests/test_e2e_local.py @@ 
-32,6 +32,7 @@ def _assert_online_features( "driver_hourly_stats:avg_daily_trips", ], entity_rows=[{"driver_id": 1001}], + full_feature_names=True, ) assert "driver_hourly_stats__avg_daily_trips" in result.to_dict() diff --git a/sdk/python/tests/test_historical_retrieval.py b/sdk/python/tests/test_historical_retrieval.py index 83f48ccd961..ff2093307c8 100644 --- a/sdk/python/tests/test_historical_retrieval.py +++ b/sdk/python/tests/test_historical_retrieval.py @@ -17,6 +17,7 @@ from feast import errors, utils from feast.data_source import BigQuerySource, FileSource from feast.entity import Entity +from feast.errors import FeatureNameCollisionError from feast.feature import Feature from feast.feature_store import FeatureStore from feast.feature_view import FeatureView @@ -111,6 +112,7 @@ def create_customer_daily_profile_feature_view(source): Feature(name="current_balance", dtype=ValueType.FLOAT), Feature(name="avg_passenger_count", dtype=ValueType.FLOAT), Feature(name="lifetime_trip_count", dtype=ValueType.INT32), + Feature(name="avg_daily_trips", dtype=ValueType.INT32), ], input=source, ttl=timedelta(days=2), @@ -142,6 +144,7 @@ def get_expected_training_df( driver_fv: FeatureView, orders_df: pd.DataFrame, event_timestamp: str, + full_feature_names: bool = False, ): # Convert all pandas dataframes into records with UTC timestamps order_records = convert_timestamp_records_to_utc( @@ -177,6 +180,10 @@ def get_expected_training_df( f"driver_stats__{k}": driver_record.get(k, None) for k in ("conv_rate", "avg_daily_trips") } + if full_feature_names + else { + k: driver_record.get(k, None) for k in ("conv_rate", "avg_daily_trips") + } ) order_record.update( { @@ -187,6 +194,15 @@ def get_expected_training_df( "lifetime_trip_count", ) } + if full_feature_names + else { + k: customer_record.get(k, None) + for k in ( + "current_balance", + "avg_passenger_count", + "lifetime_trip_count", + ) + } ) # Convert records back to pandas dataframe @@ -199,12 +215,21 @@ def 
get_expected_training_df( # Cast some columns to expected types, since we lose information when converting pandas DFs into Python objects. expected_df["order_is_success"] = expected_df["order_is_success"].astype("int32") - expected_df["customer_profile__current_balance"] = expected_df[ - "customer_profile__current_balance" - ].astype("float32") - expected_df["customer_profile__avg_passenger_count"] = expected_df[ - "customer_profile__avg_passenger_count" - ].astype("float32") + + if full_feature_names: + expected_df["customer_profile__current_balance"] = expected_df[ + "customer_profile__current_balance" + ].astype("float32") + expected_df["customer_profile__avg_passenger_count"] = expected_df[ + "customer_profile__avg_passenger_count" + ].astype("float32") + else: + expected_df["current_balance"] = expected_df["current_balance"].astype( + "float32" + ) + expected_df["avg_passenger_count"] = expected_df["avg_passenger_count"].astype( + "float32" + ) return expected_df @@ -294,6 +319,7 @@ def test_historical_features_from_parquet_sources(infer_event_timestamp_col): "customer_profile:avg_passenger_count", "customer_profile:lifetime_trip_count", ], + full_feature_names=True, ) actual_df = job.to_df() @@ -303,7 +329,13 @@ def test_historical_features_from_parquet_sources(infer_event_timestamp_col): else "e_ts" ) expected_df = get_expected_training_df( - customer_df, customer_fv, driver_df, driver_fv, orders_df, event_timestamp, + customer_df, + customer_fv, + driver_df, + driver_fv, + orders_df, + event_timestamp, + full_feature_names=True, ) assert_frame_equal( expected_df.sort_values( @@ -314,6 +346,59 @@ def test_historical_features_from_parquet_sources(infer_event_timestamp_col): ).reset_index(drop=True), ) + event_timestamp = ( + DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL + if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns + else "e_ts" + ) + expected_df_fno = get_expected_training_df( + customer_df, + customer_fv, + driver_df, + driver_fv, + orders_df, + 
event_timestamp, + full_feature_names=False, + ) + + # Test parquet sources when using feature names only (strip prefixed feature views) + job = store.get_historical_features( + entity_df=orders_df, + feature_refs=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ], + full_feature_names=False, + ) + + actual_df_fno = job.to_df() + assert_frame_equal( + expected_df_fno.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ).reset_index(drop=True), + actual_df_fno.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ).reset_index(drop=True), + ) + + # Test for colliding feature names when featureview prefixes are stripped + with pytest.raises(FeatureNameCollisionError): + store.get_historical_features( + entity_df=orders_df, + feature_refs=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + "customer_profile:avg_daily_trips", + ], + full_feature_names=False, + ) + @pytest.mark.integration @pytest.mark.parametrize( @@ -427,7 +512,13 @@ def test_historical_features_from_bigquery_sources( else "e_ts" ) expected_df = get_expected_training_df( - customer_df, customer_fv, driver_df, driver_fv, orders_df, event_timestamp, + customer_df, + customer_fv, + driver_df, + driver_fv, + orders_df, + event_timestamp, + full_feature_names=True, ) job_from_sql = store.get_historical_features( @@ -439,6 +530,7 @@ def test_historical_features_from_bigquery_sources( "customer_profile:avg_passenger_count", "customer_profile:lifetime_trip_count", ], + full_feature_names=True, ) start_time = datetime.utcnow() @@ -497,6 +589,7 @@ def test_historical_features_from_bigquery_sources( "customer_profile:avg_passenger_count", "customer_profile:lifetime_trip_count", ], + 
full_feature_names=True, ) # Rename the join key; this should now raise an error. @@ -544,3 +637,85 @@ def test_historical_features_from_bigquery_sources( .reset_index(drop=True), check_dtype=False, ) + + # Test BigQuery sources when using feature names only (strip prefixed feature views) + + expected_df_fno = get_expected_training_df( + customer_df, + customer_fv, + driver_df, + driver_fv, + orders_df, + event_timestamp, + full_feature_names=False, + ) + + job_from_sql = store.get_historical_features( + entity_df=entity_df_query, + feature_refs=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ], + full_feature_names=False, + ) + + actual_df_from_sql_entities_fno = job_from_sql.to_df() + + assert_frame_equal( + expected_df_fno.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ).reset_index(drop=True), + actual_df_from_sql_entities_fno.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ).reset_index(drop=True), + check_dtype=False, + ) + + job_from_df = store.get_historical_features( + entity_df=orders_df, + feature_refs=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ], + full_feature_names=False, + ) + + if provider_type == "gcp_custom_offline_config": + # Make sure that custom dataset name is being used from the offline_store config + assertpy.assert_that(job_from_df.query).contains("foo.entity_df") + else: + # If the custom dataset name isn't provided in the config, use default `feast` name + assertpy.assert_that(job_from_df.query).contains("feast.entity_df") + + actual_df_from_df_entities_fno = job_from_df.to_df() + + assert_frame_equal( + expected_df_fno.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] 
).reset_index(drop=True), + actual_df_from_df_entities_fno.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ).reset_index(drop=True), + check_dtype=False, + ) + + # Test for colliding feature names when featureview prefixes are stripped + with pytest.raises(FeatureNameCollisionError): + store.get_historical_features( + entity_df=orders_df, + feature_refs=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + "customer_profile:avg_daily_trips", + ], + full_feature_names=False, + ) diff --git a/sdk/python/tests/test_offline_online_store_consistency.py b/sdk/python/tests/test_offline_online_store_consistency.py index 02943fd2eb8..d7a58692817 100644 --- a/sdk/python/tests/test_offline_online_store_consistency.py +++ b/sdk/python/tests/test_offline_online_store_consistency.py @@ -196,7 +196,7 @@ def check_offline_and_online_features( ) -> None: # Check online store response_dict = fs.get_online_features( - [f"{fv.name}:value"], [{"driver": driver_id}] + [f"{fv.name}:value"], [{"driver": driver_id}], full_feature_names=True ).to_dict() if expected_value: @@ -210,6 +210,7 @@ def check_offline_and_online_features( {"driver_id": [driver_id], "event_timestamp": [event_timestamp]} ), feature_refs=[f"{fv.name}:value"], + full_feature_names=True, ).to_df() if expected_value: diff --git a/sdk/python/tests/test_online_retrieval.py b/sdk/python/tests/test_online_retrieval.py index b29b0c946fa..a443efd2439 100644 --- a/sdk/python/tests/test_online_retrieval.py +++ b/sdk/python/tests/test_online_retrieval.py @@ -98,6 +98,7 @@ def test_online() -> None: "customer_driver_combined:trips", ], entity_rows=[{"driver": 1, "customer": 5}, {"driver": 1, "customer": 5}], + full_feature_names=True, ).to_dict() assert "driver_locations__lon" in result @@ -110,10 +111,34 @@ def test_online() -> None: assert 
result["customer_profile__name"] == ["John", "John"] assert result["customer_driver_combined__trips"] == [7, 7] + # Ensure setting full_feature_names to False strips featureview prefixes + # from feature names + result = store.get_online_features( + feature_refs=[ + "driver_locations:lon", + "customer_profile:avg_orders_day", + "customer_profile:name", + "customer_driver_combined:trips", + ], + entity_rows=[{"driver": 1, "customer": 5}, {"driver": 1, "customer": 5}], + full_feature_names=False, + ).to_dict() + + assert "lon" in result + assert "avg_orders_day" in result + assert "name" in result + assert result["driver"] == [1, 1] + assert result["customer"] == [5, 5] + assert result["lon"] == ["1.0", "1.0"] + assert result["avg_orders_day"] == [1.0, 1.0] + assert result["name"] == ["John", "John"] + assert result["trips"] == [7, 7] + # Ensure features are still in result when keys not found result = store.get_online_features( feature_refs=["customer_driver_combined:trips"], entity_rows=[{"driver": 0, "customer": 0}], + full_feature_names=True, ).to_dict() assert "customer_driver_combined__trips" in result @@ -121,7 +146,9 @@ def test_online() -> None: # invalid table reference with pytest.raises(FeatureViewNotFoundException): store.get_online_features( - feature_refs=["driver_locations_bad:lon"], entity_rows=[{"driver": 1}], + feature_refs=["driver_locations_bad:lon"], + entity_rows=[{"driver": 1}], + full_feature_names=True, ) # Create new FeatureStore object with fast cache invalidation @@ -146,6 +173,7 @@ def test_online() -> None: "customer_driver_combined:trips", ], entity_rows=[{"driver": 1, "customer": 5}], + full_feature_names=True, ).to_dict() assert result["driver_locations__lon"] == ["1.0"] assert result["customer_driver_combined__trips"] == [7] @@ -166,6 +194,7 @@ def test_online() -> None: "customer_driver_combined:trips", ], entity_rows=[{"driver": 1, "customer": 5}], + full_feature_names=True, ).to_dict() # Restore registry.db so that we can see if 
it actually reloads registry @@ -180,6 +209,7 @@ def test_online() -> None: "customer_driver_combined:trips", ], entity_rows=[{"driver": 1, "customer": 5}], + full_feature_names=True, ).to_dict() assert result["driver_locations__lon"] == ["1.0"] assert result["customer_driver_combined__trips"] == [7] @@ -205,6 +235,7 @@ def test_online() -> None: "customer_driver_combined:trips", ], entity_rows=[{"driver": 1, "customer": 5}], + full_feature_names=True, ).to_dict() assert result["driver_locations__lon"] == ["1.0"] assert result["customer_driver_combined__trips"] == [7] @@ -224,6 +255,7 @@ def test_online() -> None: "customer_driver_combined:trips", ], entity_rows=[{"driver": 1, "customer": 5}], + full_feature_names=True, ).to_dict() assert result["driver_locations__lon"] == ["1.0"] assert result["customer_driver_combined__trips"] == [7]