Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add YAML export to Python SDK #782

Merged
merged 4 commits into from
Jun 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sdk/go/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (fc *GrpcClient) GetOnlineFeatures(ctx context.Context, req *OnlineFeatures
// collect unqiue entity refs from entity rows
entityRefs := make(map[string]struct{})
for _, entityRows := range req.Entities {
for ref, _ := range entityRows {
for ref := range entityRows {
entityRefs[ref] = struct{}{}
}
}
Expand Down
41 changes: 38 additions & 3 deletions sdk/python/feast/feature_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@

import pandas as pd
import pyarrow as pa
import yaml
from google.protobuf import json_format
from google.protobuf.duration_pb2 import Duration
from google.protobuf.json_format import MessageToJson
from google.protobuf.json_format import MessageToDict, MessageToJson
from google.protobuf.message import Message
from pandas.api.types import is_datetime64_ns_dtype
from pyarrow.lib import TimestampType
Expand Down Expand Up @@ -79,12 +80,18 @@ def __eq__(self, other):
if key not in other.fields.keys() or self.fields[key] != other.fields[key]:
return False

if self.fields[key] != other.fields[key]:
return False

if (
self.name != other.name
or self.project != other.project
or self.max_age != other.max_age
):
return False

if self.source != other.source:
return False
return True

def __str__(self):
Expand Down Expand Up @@ -783,13 +790,18 @@ def from_proto(cls, feature_set_proto: FeatureSetProto):
entities=[
Entity.from_proto(entity) for entity in feature_set_proto.spec.entities
],
max_age=feature_set_proto.spec.max_age,
max_age=(
None
if feature_set_proto.spec.max_age.seconds == 0
and feature_set_proto.spec.max_age.nanos == 0
else feature_set_proto.spec.max_age
),
source=(
None
if feature_set_proto.spec.source.type == 0
else Source.from_proto(feature_set_proto.spec.source)
),
project=feature_set_proto.spec.project
project=None
if len(feature_set_proto.spec.project) == 0
else feature_set_proto.spec.project,
)
Expand Down Expand Up @@ -828,6 +840,29 @@ def to_proto(self) -> FeatureSetProto:

return FeatureSetProto(spec=spec, meta=meta)

def to_dict(self) -> Dict:
"""
Converts feature set to dict

:return: Dictionary object representation of feature set
"""
feature_set_dict = MessageToDict(self.to_proto())

# Remove meta when empty for more readable exports
if feature_set_dict["meta"] == {}:
del feature_set_dict["meta"]

return feature_set_dict

def to_yaml(self):
"""
Converts a feature set to a YAML string.

:return: Feature set string returned in YAML format
"""
feature_set_dict = self.to_dict()
return yaml.dump(feature_set_dict, allow_unicode=True, sort_keys=False)


class FeatureSetRef:
"""
Expand Down
22 changes: 22 additions & 0 deletions sdk/python/tests/test_feature_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pathlib
from concurrent import futures
from datetime import datetime
Expand Down Expand Up @@ -244,6 +245,27 @@ def test_export_tfx_schema(self):
for actual, expected in zip(actual_schema.feature, expected_schema.feature):
assert actual.SerializeToString() == expected.SerializeToString()

def test_feature_set_import_export_yaml(self):

test_feature_set = FeatureSet(
name="bikeshare",
entities=[Entity(name="station_id", dtype=ValueType.INT64)],
features=[
Feature(name="name", dtype=ValueType.STRING),
Feature(name="longitude", dtype=ValueType.FLOAT),
Feature(name="location", dtype=ValueType.STRING),
],
)

# Create a string YAML representation of the feature set
string_yaml = test_feature_set.to_yaml()

# Create a new feature set object from the YAML string
actual_feature_set_from_string = FeatureSet.from_yaml(string_yaml)

# Ensure equality is upheld to original feature set
assert test_feature_set == actual_feature_set_from_string


def make_tfx_schema_domain_info_inline(schema):
# Copy top-level domain info defined in the schema to inline definition.
Expand Down