Add e2e tests for historical feature retrieval

Signed-off-by: Khor Shu Heng <khor.heng@gojek.com>
khorshuheng committed Oct 22, 2020
1 parent 4141cd0 commit 342c49d
Showing 4 changed files with 181 additions and 51 deletions.
55 changes: 55 additions & 0 deletions tests/e2e/conftest.py
@@ -1,5 +1,11 @@
import os
from pathlib import Path

import pyspark
import pytest

from feast import Client


def pytest_addoption(parser):
parser.addoption("--core_url", action="store", default="localhost:6565")
@@ -34,3 +40,52 @@ def pytest_runtest_setup(item):
previousfailed = getattr(item.parent, "_previousfailed", None)
if previousfailed is not None:
pytest.xfail("previous test failed (%s)" % previousfailed.name)


@pytest.fixture(scope="session")
def feast_version():
return "0.8-SNAPSHOT"


@pytest.fixture(scope="session")
def ingestion_job_jar(pytestconfig, feast_version):
default_path = (
Path(__file__).parent.parent.parent
/ "spark"
/ "ingestion"
/ "target"
/ f"feast-ingestion-spark-{feast_version}.jar"
)

return pytestconfig.getoption("ingestion_jar") or f"file://{default_path}"


@pytest.fixture(scope="session")
def feast_client(pytestconfig, ingestion_job_jar):
redis_host, redis_port = pytestconfig.getoption("redis_url").split(":")

if pytestconfig.getoption("env") == "local":
return Client(
core_url=pytestconfig.getoption("core_url"),
serving_url=pytestconfig.getoption("serving_url"),
spark_launcher="standalone",
spark_standalone_master="local",
spark_home=os.getenv("SPARK_HOME") or os.path.dirname(pyspark.__file__),
spark_ingestion_jar=ingestion_job_jar,
redis_host=redis_host,
redis_port=redis_port,
)

if pytestconfig.getoption("env") == "gcloud":
return Client(
core_url=pytestconfig.getoption("core_url"),
serving_url=pytestconfig.getoption("serving_url"),
spark_launcher="dataproc",
dataproc_cluster_name=pytestconfig.getoption("dataproc_cluster_name"),
dataproc_project=pytestconfig.getoption("dataproc_project"),
dataproc_region=pytestconfig.getoption("dataproc_region"),
dataproc_staging_location=os.path.join(
pytestconfig.getoption("staging_path"), "dataproc"
),
spark_ingestion_jar=ingestion_job_jar,
)
1 change: 1 addition & 0 deletions tests/e2e/requirements.txt
@@ -2,6 +2,7 @@ mock==2.0.0
numpy==1.16.4
pandas~=1.0.0
pandavro==1.5.*
pyspark==2.4.2
pytest==6.0.0
pytest-benchmark==3.2.2
pytest-mock==1.10.4
125 changes: 125 additions & 0 deletions tests/e2e/test_historical_features.py
@@ -0,0 +1,125 @@
import os
import tempfile
import uuid
from datetime import datetime, timedelta
from urllib.parse import urlparse

import numpy as np
import pandas as pd
import pytest
from google.protobuf.duration_pb2 import Duration
from pandas._testing import assert_frame_equal

from feast import Client, Entity, Feature, FeatureTable, FileSource, ValueType
from feast.data_format import ParquetFormat
from feast.staging.storage_client import get_staging_client


@pytest.fixture(scope="function")
def staging_path(pytestconfig, tmp_path):
if pytestconfig.getoption("env") == "local":
return f"file://{tmp_path}"

staging_path = pytestconfig.getoption("staging_path")
return os.path.join(staging_path, str(uuid.uuid4()))


def test_historical_features(feast_client: Client, staging_path: str):
customer_entity = Entity(
name="customer_id", description="Customer", value_type=ValueType.INT64
)
feast_client.apply_entity(customer_entity)

max_age = Duration()
max_age.FromSeconds(2 * 86400)

transactions_feature_table = FeatureTable(
name="transactions",
entities=["customer_id"],
features=[
Feature("daily_transactions", ValueType.DOUBLE),
Feature("total_transactions", ValueType.DOUBLE),
],
batch_source=FileSource(
"event_timestamp",
"created_timestamp",
ParquetFormat(),
os.path.join(staging_path, "transactions"),
),
max_age=max_age,
)

feast_client.apply_feature_table(transactions_feature_table)

retrieval_date = (
datetime.utcnow()
.replace(hour=0, minute=0, second=0, microsecond=0)
.replace(tzinfo=None)
)
retrieval_outside_max_age_date = retrieval_date + timedelta(1)
event_date = retrieval_date - timedelta(2)
creation_date = retrieval_date - timedelta(1)

customers = [1001, 1002, 1003, 1004, 1005]
    daily_transactions = [np.random.rand() * 10 for _ in customers]
    total_transactions = [np.random.rand() * 100 for _ in customers]

transactions_df = pd.DataFrame(
{
"event_timestamp": [event_date for _ in customers],
"created_timestamp": [creation_date for _ in customers],
"customer_id": customers,
"daily_transactions": daily_transactions,
"total_transactions": total_transactions,
}
)

feast_client.ingest(transactions_feature_table, transactions_df)

feature_refs = ["transactions:daily_transactions"]

customer_df = pd.DataFrame(
{
"event_timestamp": [retrieval_date for _ in customers]
+ [retrieval_outside_max_age_date for _ in customers],
"customer_id": customers + customers,
}
)

with tempfile.TemporaryDirectory() as tempdir:
        df_export_path = os.path.join(tempdir, "customers.parquet")
customer_df.to_parquet(df_export_path)
scheme, _, remote_path, _, _, _ = urlparse(staging_path)
staging_client = get_staging_client(scheme)
staging_client.upload_file(df_export_path, None, remote_path)
customer_source = FileSource(
"event_timestamp",
"event_timestamp",
ParquetFormat(),
os.path.join(staging_path, os.path.basename(df_export_path)),
)

job = feast_client.get_historical_features(feature_refs, customer_source)
output_dir = job.get_output_file_uri()

_, _, joined_df_destination_path, _, _, _ = urlparse(output_dir)
joined_df = pd.read_parquet(joined_df_destination_path)

expected_joined_df = pd.DataFrame(
{
"event_timestamp": [retrieval_date for _ in customers]
+ [retrieval_outside_max_age_date for _ in customers],
"customer_id": customers + customers,
"transactions__daily_transactions": daily_transactions
+ [None] * len(customers),
}
)

assert_frame_equal(
joined_df.sort_values(by=["customer_id", "event_timestamp"]).reset_index(
drop=True
),
expected_joined_df.sort_values(
by=["customer_id", "event_timestamp"]
).reset_index(drop=True),
)
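
The expected frame above encodes the max_age semantics: an entity row whose event_timestamp falls within two days of the feature row gets the stored value, while the row one day later lands outside max_age and comes back null. A minimal pandas sketch of the same point-in-time join, using merge_asof with a tolerance as a stand-in for the Spark job (an assumption, not Feast's actual implementation):

import pandas as pd

# Feature data: one observation per customer, two days before retrieval.
features = pd.DataFrame(
    {
        "event_timestamp": [pd.Timestamp("2020-10-20")],
        "customer_id": [1001],
        "daily_transactions": [4.2],
    }
)

# Entity rows: one within max_age (2 days), one outside it.
entities = pd.DataFrame(
    {
        "event_timestamp": [pd.Timestamp("2020-10-22"), pd.Timestamp("2020-10-23")],
        "customer_id": [1001, 1001],
    }
)

# Backward as-of join: match the latest feature row at or before each
# entity timestamp, but only within the 2-day tolerance (max_age).
joined = pd.merge_asof(
    entities.sort_values("event_timestamp"),
    features.sort_values("event_timestamp"),
    on="event_timestamp",
    by="customer_id",
    tolerance=pd.Timedelta(days=2),
)
# joined["daily_transactions"] -> [4.2, NaN]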
51 changes: 0 additions & 51 deletions tests/e2e/test_online_features.py
@@ -4,12 +4,10 @@
import time
import uuid
from datetime import datetime, timedelta
from pathlib import Path

import avro.schema
import numpy as np
import pandas as pd
import pyspark
import pytest
import pytz
from avro.io import BinaryEncoder, DatumWriter
@@ -41,55 +39,6 @@ def generate_data():
return df


@pytest.fixture(scope="session")
def feast_version():
return "0.8-SNAPSHOT"


@pytest.fixture(scope="session")
def ingestion_job_jar(pytestconfig, feast_version):
default_path = (
Path(__file__).parent.parent.parent
/ "spark"
/ "ingestion"
/ "target"
/ f"feast-ingestion-spark-{feast_version}.jar"
)

return pytestconfig.getoption("ingestion_jar") or f"file://{default_path}"


@pytest.fixture(scope="session")
def feast_client(pytestconfig, ingestion_job_jar):
redis_host, redis_port = pytestconfig.getoption("redis_url").split(":")

if pytestconfig.getoption("env") == "local":
return Client(
core_url=pytestconfig.getoption("core_url"),
serving_url=pytestconfig.getoption("serving_url"),
spark_launcher="standalone",
spark_standalone_master="local",
spark_home=os.getenv("SPARK_HOME") or os.path.dirname(pyspark.__file__),
spark_ingestion_jar=ingestion_job_jar,
redis_host=redis_host,
redis_port=redis_port,
)

if pytestconfig.getoption("env") == "gcloud":
return Client(
core_url=pytestconfig.getoption("core_url"),
serving_url=pytestconfig.getoption("serving_url"),
spark_launcher="dataproc",
dataproc_cluster_name=pytestconfig.getoption("dataproc_cluster_name"),
dataproc_project=pytestconfig.getoption("dataproc_project"),
dataproc_region=pytestconfig.getoption("dataproc_region"),
dataproc_staging_location=os.path.join(
pytestconfig.getoption("staging_path"), "dataproc"
),
spark_ingestion_jar=ingestion_job_jar,
)


@pytest.fixture(scope="function")
def staging_path(pytestconfig, tmp_path):
if pytestconfig.getoption("env") == "local":
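
The fixtures deleted here are not lost: they moved verbatim into tests/e2e/conftest.py (see above), and pytest discovers conftest.py fixtures automatically, so this module keeps using them without any import. A minimal sketch of the pattern, with a hypothetical test name:

from feast import Client

# No import of the fixture is needed; pytest injects the session-scoped
# feast_client defined in tests/e2e/conftest.py by parameter name.
def test_online_features_smoke(feast_client: Client):
    assert feast_client is not None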
