From e9a551eda207186c6693b36fc1ab442b67df3430 Mon Sep 17 00:00:00 2001 From: Baha Aiman Date: Wed, 11 Sep 2024 13:02:11 -0700 Subject: [PATCH] feat(firestore): Adding distance threshold and result field (#10802) * feat(firestore): Adding distance threshold and result field * refactor(firestore): Renaming method names * refactor(firestore): Move threshold and result field to options. Rename FindNearestOptions * refactor(firestore): Rename to FindNearestOptions * refactor(firestore): Refactoring code --- firestore/examples_test.go | 5 +- firestore/fieldpath.go | 5 +- firestore/integration_test.go | 97 ++++++++++++++++------- firestore/query.go | 37 ++++++++- firestore/query_test.go | 145 ++++++++++++++++++++++++++-------- 5 files changed, 223 insertions(+), 66 deletions(-) diff --git a/firestore/examples_test.go b/firestore/examples_test.go index 4ab7eb9f1c59..70bec366016d 100644 --- a/firestore/examples_test.go +++ b/firestore/examples_test.go @@ -502,7 +502,10 @@ func ExampleQuery_FindNearest() { // q := client.Collection("descriptions"). - FindNearest("Embedding", []float32{1, 2, 3}, 5, firestore.DistanceMeasureDotProduct, nil) + FindNearest("Embedding", []float32{1, 2, 3}, 5, firestore.DistanceMeasureDotProduct, &firestore.FindNearestOptions{ + DistanceThreshold: firestore.Ptr(20.0), + DistanceResultField: "vector_distance", + }) iter1 := q.Documents(ctx) _ = iter1 // TODO: Use iter1. } diff --git a/firestore/fieldpath.go b/firestore/fieldpath.go index a50ead70d6dc..3bb8ac666a6a 100644 --- a/firestore/fieldpath.go +++ b/firestore/fieldpath.go @@ -28,6 +28,8 @@ import ( "cloud.google.com/go/internal/fields" ) +const invalidRunes = "~*/[]" + // A FieldPath is a non-empty sequence of non-empty fields that reference a value. // // A FieldPath value should only be necessary if one of the field names contains @@ -54,9 +56,8 @@ type FieldPath []string // including attempts to quote field path compontents. 
So "a.`b.c`.d" is parsed into // four parts, "a", "`b", "c`" and "d". func parseDotSeparatedString(s string) (FieldPath, error) { - const invalidRunes = "~*/[]" if strings.ContainsAny(s, invalidRunes) { - return nil, fmt.Errorf("firestore: %q contains an invalid rune (one of %s)", s, invalidRunes) + return nil, errInvalidRunesField(s) } fp := FieldPath(strings.Split(s, ".")) if err := fp.validate(); err != nil { diff --git a/firestore/integration_test.go b/firestore/integration_test.go index b7fe0ab6e4c4..2ada3830a8fd 100644 --- a/firestore/integration_test.go +++ b/firestore/integration_test.go @@ -3218,6 +3218,7 @@ func TestIntegration_FindNearest(t *testing.T) { cancel() }) queryField := "EmbeddedField64" + resultField := "vector_distance" indexNames := createVectorIndexes(adminCtx, t, wantDBPath, []vectorIndex{ { fieldPath: queryField, @@ -3229,34 +3230,46 @@ func TestIntegration_FindNearest(t *testing.T) { }) type coffeeBean struct { - ID string + ID int EmbeddedField64 Vector64 EmbeddedField32 Vector32 Float32s []float32 // When querying, saving and retrieving, this should be retrieved as []float32 and not Vector32 } beans := []coffeeBean{ - { - ID: "Robusta", + { // Euclidean Distance from {1, 2, 3} = 0 + ID: 0, EmbeddedField64: []float64{1, 2, 3}, EmbeddedField32: []float32{1, 2, 3}, Float32s: []float32{1, 2, 3}, }, - { - ID: "Excelsa", + { // Euclidean Distance from {1, 2, 3} = 5.19 + ID: 1, EmbeddedField64: []float64{4, 5, 6}, EmbeddedField32: []float32{4, 5, 6}, Float32s: []float32{4, 5, 6}, }, + { // Euclidean Distance from {1, 2, 3} = 10.39 + ID: 2, + EmbeddedField64: []float64{7, 8, 9}, + EmbeddedField32: []float32{7, 8, 9}, + Float32s: []float32{7, 8, 9}, + }, + { // Euclidean Distance from {1, 2, 3} = 15.58 + ID: 3, + EmbeddedField64: []float64{10, 11, 12}, + EmbeddedField32: []float32{10, 11, 12}, + Float32s: []float32{10, 11, 12}, + }, { - ID: "Arabica", + // Euclidean Distance from {1, 2, 3} = 370.42 + ID: 4, EmbeddedField64: []float64{100, 200, 
300}, // too far from query vector. not within findNearest limit EmbeddedField32: []float32{100, 200, 300}, Float32s: []float32{100, 200, 300}, }, - { - ID: "Liberica", + ID: 5, EmbeddedField64: []float64{1, 2}, // Not enough dimensions as compared to query vector. EmbeddedField32: []float32{1, 2}, Float32s: []float32{1, 2}, @@ -3277,27 +3290,55 @@ func TestIntegration_FindNearest(t *testing.T) { h.mustCreate(doc, beans[i]) } - // Query documents with a vector field - vectorQuery := collRef.FindNearest(queryField, []float64{1, 2, 3}, 2, DistanceMeasureEuclidean, nil) - - iter := vectorQuery.Documents(ctx) - gotDocs, err := iter.GetAll() - if err != nil { - t.Fatalf("GetAll: %+v", err) - } + for _, tc := range []struct { + desc string + vq VectorQuery + wantBeans []coffeeBean + wantResField string + }{ + { + desc: "FindNearest without threshold without resultField", + vq: collRef.FindNearest(queryField, []float64{1, 2, 3}, 2, DistanceMeasureEuclidean, nil), + wantBeans: beans[:2], + }, + { + desc: "FindNearest threshold and resultField", + vq: collRef.FindNearest(queryField, []float64{1, 2, 3}, 3, DistanceMeasureEuclidean, &FindNearestOptions{ + DistanceThreshold: Ptr(20.0), + DistanceResultField: resultField, + }), + wantBeans: beans[:3], + wantResField: resultField, + }, + } { + t.Run(tc.desc, func(t *testing.T) { + iter := tc.vq.Documents(ctx) + gotDocs, err := iter.GetAll() + if err != nil { + t.Fatalf("GetAll: %+v", err) + } - if len(gotDocs) != 2 { - t.Fatalf("Expected 2 results, got %d", len(gotDocs)) - } + if len(gotDocs) != len(tc.wantBeans) { + t.Fatalf("Expected %v results, got %d", len(tc.wantBeans), len(gotDocs)) + } - for i, doc := range gotDocs { - gotBean := coffeeBean{} - err := doc.DataTo(&gotBean) - if err != nil { - t.Errorf("#%v: DataTo: %+v", doc.Ref.ID, err) - } - if beans[i].ID != gotBean.ID { - t.Errorf("#%v: want: %v, got: %v", i, beans[i].ID, gotBean.ID) - } + for i, doc := range gotDocs { + var gotBean coffeeBean + if len(tc.wantResField) 
!= 0 { + _, ok := doc.Data()[tc.wantResField] + if !ok { + t.Errorf("Expected %v field to exist in %v", tc.wantResField, doc.Data()) + } + } + err := doc.DataTo(&gotBean) + if err != nil { + t.Errorf("#%v: DataTo: %+v", doc.Ref.ID, err) + continue + } + if tc.wantBeans[i].ID != gotBean.ID { + t.Errorf("#%v: want: %v, got: %v", i, tc.wantBeans[i].ID, gotBean.ID) + } + } + }) } } diff --git a/firestore/query.go b/firestore/query.go index 9e77e2a4f6ae..021747e46723 100644 --- a/firestore/query.go +++ b/firestore/query.go @@ -33,9 +33,15 @@ import ( ) var ( - errMetricsBeforeEnd = errors.New("firestore: ExplainMetrics are available only after the iterator reaches the end") + errMetricsBeforeEnd = errors.New("firestore: ExplainMetrics are available only after the iterator reaches the end") + errInvalidVector = errors.New("firestore: queryVector must be Vector32 or Vector64") + errMalformedVectorQuery = errors.New("firestore: Malformed VectorQuery. Use FindNearest or FindNearestPath to create VectorQuery") ) +func errInvalidRunesField(field string) error { + return fmt.Errorf("firestore: %q contains an invalid rune (one of %s)", field, invalidRunes) +} + // Query represents a Firestore query. // // Query values are immutable. Each Query method creates @@ -517,9 +523,27 @@ const ( DistanceMeasureDotProduct DistanceMeasure = DistanceMeasure(pb.StructuredQuery_FindNearest_DOT_PRODUCT) ) +// Ptr returns a pointer to its argument. +// It can be used to initialize pointer fields: +// +// findNearestOptions.DistanceThreshold = firestore.Ptr[float64](0.1) +func Ptr[T any](t T) *T { return &t } + // FindNearestOptions are options for a FindNearest vector query. -// At present, there are no options. type FindNearestOptions struct { + // DistanceThreshold specifies a threshold for which no less similar documents + // will be returned. The behavior of the specified [DistanceMeasure] will + // affect the meaning of the distance threshold.
Since [DistanceMeasureDotProduct] + // distances increase when the vectors are more similar, the comparison is inverted. + // For [DistanceMeasureEuclidean], [DistanceMeasureCosine]: WHERE distance <= distance_threshold + // For [DistanceMeasureDotProduct]: WHERE distance >= distance_threshold + DistanceThreshold *float64 + + // DistanceResultField specifies the name of the document field to output the result of + // the vector distance calculation. + // If the field already exists in the document, its value gets overwritten with the distance calculation. + // Otherwise, a new field gets added to the document. + DistanceResultField string } // VectorQuery represents a query that uses [Query.FindNearest] or [Query.FindNearestPath]. @@ -582,7 +606,7 @@ func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector any, limit case []float64: fnvq = vectorToProtoValue(v) default: - vq.q.err = errors.New("firestore: queryVector must be Vector32 or Vector64") + vq.q.err = errInvalidVector return vq } @@ -592,6 +616,13 @@ func (q Query) FindNearestPath(vectorFieldPath FieldPath, queryVector any, limit Limit: &wrapperspb.Int32Value{Value: trunc32(limit)}, DistanceMeasure: pb.StructuredQuery_FindNearest_DistanceMeasure(measure), } + + if options != nil { + if options.DistanceThreshold != nil { + vq.q.findNearest.DistanceThreshold = &wrapperspb.DoubleValue{Value: *options.DistanceThreshold} + } + vq.q.findNearest.DistanceResultField = options.DistanceResultField + } return vq } diff --git a/firestore/query_test.go b/firestore/query_test.go index 5d369198546c..2eb2119c790c 100644 --- a/firestore/query_test.go +++ b/firestore/query_test.go @@ -25,6 +25,7 @@ import ( pb "cloud.google.com/go/firestore/apiv1/firestorepb" "cloud.google.com/go/internal/pretty" "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/reflect/protoreflect" "google.golang.org/protobuf/testing/protocmp" tspb "google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb" @@ -1747,12 +1748,16 @@ func TestQueryRunOptionsAndGetAllWithOptions(t *testing.T) { t.Fatal(err) } } - func TestFindNearest(t *testing.T) { ctx := context.Background() c, srv, cleanup := newMock(t) - defer cleanup() + t.Cleanup(func() { cleanup() }) + collName := "C" + limit := 2 + threshold := float64(24) + resultField := "res" + vectorField := "path" const dbPath = "projects/projectID/databases/(default)" mapFields := map[string]*pb.Value{ typeKey: {ValueType: &pb.Value_StringValue{StringValue: typeValVector}}, @@ -1770,55 +1775,131 @@ func TestFindNearest(t *testing.T) { } wantPBDocs := []*pb.Document{ { - Name: dbPath + "/documents/C/a", + Name: dbPath + "/documents/" + collName + "/a", CreateTime: aTimestamp, UpdateTime: aTimestamp, Fields: map[string]*pb.Value{"EmbeddedField": mapval(mapFields)}, }, } + wantQueryVector := &pb.Value{ + ValueType: &pb.Value_MapValue{ + MapValue: &pb.MapValue{ + Fields: map[string]*pb.Value{ + "__type__": { + ValueType: &pb.Value_StringValue{StringValue: "__vector__"}, + }, + "value": { + ValueType: &pb.Value_ArrayValue{ + ArrayValue: &pb.ArrayValue{ + Values: []*pb.Value{ + { + ValueType: &pb.Value_DoubleValue{DoubleValue: 5}, + }, + { + ValueType: &pb.Value_DoubleValue{DoubleValue: 6}, + }, + { + ValueType: &pb.Value_DoubleValue{DoubleValue: 7}, + }, + }, + }, + }, + }, + }, + }, + }, + } + wantReq := pb.RunQueryRequest{ + Parent: fmt.Sprintf("%v/documents", dbPath), + QueryType: &pb.RunQueryRequest_StructuredQuery{ + StructuredQuery: &pb.StructuredQuery{ + From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: collName}}, + FindNearest: &pb.StructuredQuery_FindNearest{ + VectorField: &pb.StructuredQuery_FieldReference{FieldPath: vectorField}, + QueryVector: wantQueryVector, + Limit: &wrapperspb.Int32Value{Value: int32(limit)}, + DistanceMeasure: pb.StructuredQuery_FindNearest_EUCLIDEAN, + }, + }, + }, + } + + wantReqThresholdField := pb.RunQueryRequest{ + Parent: 
fmt.Sprintf("%v/documents", dbPath), + QueryType: &pb.RunQueryRequest_StructuredQuery{ + StructuredQuery: &pb.StructuredQuery{ + From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: collName}}, + FindNearest: &pb.StructuredQuery_FindNearest{ + VectorField: &pb.StructuredQuery_FieldReference{FieldPath: vectorField}, + QueryVector: wantQueryVector, + Limit: &wrapperspb.Int32Value{Value: int32(limit)}, + DistanceMeasure: pb.StructuredQuery_FindNearest_EUCLIDEAN, + DistanceThreshold: &wrapperspb.DoubleValue{Value: float64(threshold)}, + DistanceResultField: resultField, + }, + }, + }, + } testcases := []struct { - desc string - path string - queryVector interface{} - wantErr bool + desc string + vQuery VectorQuery + wantReq protoreflect.ProtoMessage + wantErr error }{ { - desc: "Invalid path", - path: "path*", - wantErr: true, + desc: "Invalid path", + vQuery: c.Collection(collName). + FindNearest("path*", nil, limit, DistanceMeasureEuclidean, nil), + wantErr: errInvalidRunesField("path*"), }, { - desc: "Valid path", - path: "path", - queryVector: []float64{5, 6, 7}, - wantErr: false, + desc: "Invalid vector type", + vQuery: c.Collection(collName). + FindNearest("path", "abcd", limit, DistanceMeasureEuclidean, nil), + wantErr: errInvalidVector, }, { - desc: "Invalid vector type", - path: "path", - queryVector: "abcd", - wantErr: true, + desc: "Valid path with valid vector type []float64", + vQuery: c.Collection(collName). + FindNearest("path", []float64{5, 6, 7}, limit, DistanceMeasureEuclidean, nil), + wantReq: &wantReq, }, { - desc: "Valid vector type", - path: "path", - queryVector: []float32{5, 6, 7}, - wantErr: false, + desc: "Valid path with valid vector type []float32", + vQuery: c.Collection(collName). + FindNearest("path", []float32{5, 6, 7}, limit, DistanceMeasureEuclidean, nil), + wantReq: &wantReq, + }, + { + desc: "Valid path with valid vector type WithDistanceResultField and WithDistanceThreshold ", + vQuery: c.Collection(collName). 
+ FindNearest("path", []float32{5, 6, 7}, limit, DistanceMeasureEuclidean, &FindNearestOptions{ + DistanceThreshold: Ptr[float64](threshold), + DistanceResultField: resultField, + }), + wantReq: &wantReqThresholdField, }, } for _, tc := range testcases { - srv.reset() - srv.addRPC(nil, []interface{}{ - &pb.RunQueryResponse{Document: wantPBDocs[0]}, + t.Run(tc.desc, func(t *testing.T) { + srv.reset() + if tc.wantErr == nil { + srv.addRPC(tc.wantReq, []interface{}{ + &pb.RunQueryResponse{Document: wantPBDocs[0]}, + }) + } + _, gotErr := tc.vQuery.Documents(ctx).GetAll() + if !errorsMatch(gotErr, tc.wantErr) { + t.Fatalf("got %v, want %v", gotErr, tc.wantErr) + } }) - vQuery := c.Collection("C").FindNearest(tc.path, tc.queryVector, 2, DistanceMeasureEuclidean, nil) + } +} - _, err := vQuery.Documents(ctx).GetAll() - if err == nil && tc.wantErr { - t.Fatalf("%s: got nil wanted error", tc.desc) - } else if err != nil && !tc.wantErr { - t.Fatalf("%s: got %v, want nil", tc.desc, err) - } +func errorsMatch(got, want error) bool { + if got == nil || want == nil { + return got == want } + return strings.Contains(got.Error(), want.Error()) }