diff --git a/sdk/go/client.go b/sdk/go/client.go index 1f0cbaa3b96..30b3dc4bf09 100644 --- a/sdk/go/client.go +++ b/sdk/go/client.go @@ -54,7 +54,7 @@ func (fc *GrpcClient) GetOnlineFeatures(ctx context.Context, req *OnlineFeatures // collect unqiue entity refs from entity rows entityRefs := make(map[string]struct{}) for _, entityRows := range req.Entities { - for ref, _ := range entityRows { + for ref := range entityRows { entityRefs[ref] = struct{}{} } } diff --git a/sdk/python/feast/feature_set.py b/sdk/python/feast/feature_set.py index aebee52ca42..a181c9378bd 100644 --- a/sdk/python/feast/feature_set.py +++ b/sdk/python/feast/feature_set.py @@ -17,9 +17,10 @@ import pandas as pd import pyarrow as pa +import yaml from google.protobuf import json_format from google.protobuf.duration_pb2 import Duration -from google.protobuf.json_format import MessageToJson +from google.protobuf.json_format import MessageToDict, MessageToJson from google.protobuf.message import Message from pandas.api.types import is_datetime64_ns_dtype from pyarrow.lib import TimestampType @@ -79,12 +80,18 @@ def __eq__(self, other): if key not in other.fields.keys() or self.fields[key] != other.fields[key]: return False + if self.fields[key] != other.fields[key]: + return False + if ( self.name != other.name or self.project != other.project or self.max_age != other.max_age ): return False + + if self.source != other.source: + return False return True def __str__(self): @@ -783,13 +790,18 @@ def from_proto(cls, feature_set_proto: FeatureSetProto): entities=[ Entity.from_proto(entity) for entity in feature_set_proto.spec.entities ], - max_age=feature_set_proto.spec.max_age, + max_age=( + None + if feature_set_proto.spec.max_age.seconds == 0 + and feature_set_proto.spec.max_age.nanos == 0 + else feature_set_proto.spec.max_age + ), source=( None if feature_set_proto.spec.source.type == 0 else Source.from_proto(feature_set_proto.spec.source) ), - project=feature_set_proto.spec.project + project=None if len(feature_set_proto.spec.project) == 0 else feature_set_proto.spec.project, ) @@ -828,6 +840,29 @@ def to_proto(self) -> FeatureSetProto: return FeatureSetProto(spec=spec, meta=meta) + def to_dict(self) -> Dict: + """ + Converts feature set to dict + + :return: Dictionary object representation of feature set + """ + feature_set_dict = MessageToDict(self.to_proto()) + + # Remove meta when empty for more readable exports + if feature_set_dict["meta"] == {}: + del feature_set_dict["meta"] + + return feature_set_dict + + def to_yaml(self): + """ + Converts a feature set to a YAML string. + + :return: Feature set string returned in YAML format + """ + feature_set_dict = self.to_dict() + return yaml.dump(feature_set_dict, allow_unicode=True, sort_keys=False) + class FeatureSetRef: """ diff --git a/sdk/python/tests/test_feature_set.py b/sdk/python/tests/test_feature_set.py index 04e75c9e76d..4052623e699 100644 --- a/sdk/python/tests/test_feature_set.py +++ b/sdk/python/tests/test_feature_set.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import pathlib from concurrent import futures from datetime import datetime @@ -244,6 +245,27 @@ def test_export_tfx_schema(self): for actual, expected in zip(actual_schema.feature, expected_schema.feature): assert actual.SerializeToString() == expected.SerializeToString() + def test_feature_set_import_export_yaml(self): + + test_feature_set = FeatureSet( + name="bikeshare", + entities=[Entity(name="station_id", dtype=ValueType.INT64)], + features=[ + Feature(name="name", dtype=ValueType.STRING), + Feature(name="longitude", dtype=ValueType.FLOAT), + Feature(name="location", dtype=ValueType.STRING), + ], + ) + + # Create a string YAML representation of the feature set + string_yaml = test_feature_set.to_yaml() + + # Create a new feature set object from the YAML string + actual_feature_set_from_string = FeatureSet.from_yaml(string_yaml) + + # Ensure equality is upheld to original feature set + assert test_feature_set == actual_feature_set_from_string + def make_tfx_schema_domain_info_inline(schema): # Copy top-level domain info defined in the schema to inline definition.