From fa20f11bcac8ac9116df2098d9ecd10d5f7af1d8 Mon Sep 17 00:00:00 2001 From: Lokesh Kumar Date: Wed, 11 Jun 2025 22:50:21 +0200 Subject: [PATCH 1/2] GODRIVER-3472: Add support for unmarshaling BSON Vector binary values into slices --- bson/slice_codec.go | 48 +++++++++- bson/vector.go | 66 ++++++++++++++ bson/vector_unmarshal_test.go | 159 ++++++++++++++++++++++++++++++++++ 3 files changed, 271 insertions(+), 2 deletions(-) create mode 100644 bson/vector_unmarshal_test.go diff --git a/bson/slice_codec.go b/bson/slice_codec.go index c8719dcc18..ffe17f01db 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -19,6 +19,43 @@ type sliceCodec struct { encodeNilAsEmpty bool } +// decodeVectorBinary handles decoding of BSON Vector binary (subtype 9) into slices. +// It returns errNotAVectorBinary if the binary data is not a Vector binary. +// The method supports decoding into []int8 and []float32 slices. +func (sc *sliceCodec) decodeVectorBinary(vr ValueReader, val reflect.Value) error { + elemType := val.Type().Elem() + + if elemType != TInt8 && elemType != TFloat32 { + return errNotAVectorBinary + } + + data, subtype, err := vr.ReadBinary() + if err != nil { + return err + } + + if subtype != TypeBinaryVector { + return errNotAVectorBinary + } + + switch elemType { + case TInt8: + int8Slice, err := DecodeVectorInt8(data) + if err != nil { + return err + } + val.Set(reflect.ValueOf(int8Slice)) + case TFloat32: + float32Slice, err := DecodeVectorFloat32(data) + if err != nil { + return err + } + val.Set(reflect.ValueOf(float32Slice)) + } + + return nil +} + // EncodeValue is the ValueEncoder for slice types. func (sc *sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect.Value) error { if !val.IsValid() || val.Kind() != reflect.Slice { @@ -29,8 +66,9 @@ func (sc *sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect. return vw.WriteNull() } - // If we have a []byte we want to treat it as a binary instead of as an array. 
- if val.Type().Elem() == tByte { + // Treat []byte as binary data, but skip for []int8 since it's a different type + // even though byte is an alias for uint8 which has the same underlying type as int8 + if val.Type().Elem() == tByte && val.Type() != reflect.TypeOf([]int8{}) { byteSlice := make([]byte, val.Len()) reflect.Copy(reflect.ValueOf(byteSlice), val) return vw.WriteBinary(byteSlice) @@ -99,6 +137,12 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect. return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } + if vr.Type() == TypeBinary { + if err := sc.decodeVectorBinary(vr, val); err != errNotAVectorBinary { + return err + } + } + switch vrType := vr.Type(); vrType { case TypeArray: case TypeNull: diff --git a/bson/vector.go b/bson/vector.go index 31a10bd5be..d61ef02076 100644 --- a/bson/vector.go +++ b/bson/vector.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "math" + "reflect" ) // BSON binary vector types as described in https://bsonspec.org/spec.html. @@ -25,6 +26,14 @@ var ( errInsufficientVectorData = errors.New("insufficient data") errNonZeroVectorPadding = errors.New("padding must be 0") errVectorPaddingTooLarge = errors.New("padding cannot be larger than 7") + errNotAVectorBinary = errors.New("not a vector binary") +) + +var ( + // TInt8 is the reflect.Type for int8 + TInt8 = reflect.TypeOf(int8(0)) + // TFloat32 is the reflect.Type for float32 + TFloat32 = reflect.TypeOf(float32(0)) ) type vectorTypeError struct { @@ -266,3 +275,60 @@ func newBitVector(b []byte) (Vector, error) { } return NewPackedBitVector(b[1:], b[0]) } + +// DecodeVectorInt8 decodes a BSON Vector binary value (subtype 9) into a []int8 slice. +// The binary data should be in the format: [ ] +// For int8 vectors, the vector type is 0x01. 
+func DecodeVectorInt8(data []byte) ([]int8, error) { + if len(data) < 2 { + return nil, errors.New("insufficient bytes to decode vector: expected at least 2 bytes") + } + + vectorType := data[0] + if vectorType != 0x01 { // Int8Vector + return nil, errors.New("invalid vector type: expected int8 vector (0x01)") + } + + if padding := data[1]; padding != 0 { + return nil, errors.New("invalid vector: padding byte must be 0") + } + values := make([]int8, 0, len(data)-2) + for i := 2; i < len(data); i++ { + values = append(values, int8(data[i])) + } + + return values, nil +} + +// DecodeVectorFloat32 decodes a BSON Vector binary value (subtype 9) into a []float32 slice. +// The binary data should be in the format: [ ] +// For float32 vectors, the vector type is 0x02 and data must be a multiple of 4 bytes. +func DecodeVectorFloat32(data []byte) ([]float32, error) { + if len(data) < 2 { + return nil, errors.New("insufficient bytes to decode vector: expected at least 2 bytes") + } + + vectorType := data[0] + if vectorType != 0x02 { // Float32Vector + return nil, errors.New("invalid vector type: expected float32 vector (0x02)") + } + + if padding := data[1]; padding != 0 { + return nil, errors.New("invalid vector: padding byte must be 0") + } + floatData := data[2:] + if len(floatData)%4 != 0 { + return nil, errors.New("invalid float32 vector: data length must be a multiple of 4") + } + + values := make([]float32, 0, len(floatData)/4) + for i := 0; i < len(floatData); i += 4 { + if i+4 > len(floatData) { + return nil, errors.New("invalid float32 vector: truncated data") + } + bits := binary.LittleEndian.Uint32(floatData[i : i+4]) + values = append(values, math.Float32frombits(bits)) + } + + return values, nil +} diff --git a/bson/vector_unmarshal_test.go b/bson/vector_unmarshal_test.go new file mode 100644 index 0000000000..cf2658a606 --- /dev/null +++ b/bson/vector_unmarshal_test.go @@ -0,0 +1,159 @@ +// Copyright (C) MongoDB, Inc. 2025-present. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package bson + +import ( + "encoding/binary" + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +// Helper function to create a BSON document with a vector binary field (subtype 0x09) +func createBSONWithBinary(data []byte) []byte { + // Document format: {"v": BinData(subtype, data)} + buf := make([]byte, 0, 32+len(data)) + + buf = append(buf, 0x00, 0x00, 0x00, 0x00) // Length placeholder + buf = append(buf, 0x05) // Binary type + buf = append(buf, 'v', 0x00) // Field name "v" + + buf = append(buf, + byte(len(data)), // Length of binary data + 0x00, 0x00, 0x00, // 4-byte length (little endian) + 0x09, // Binary subtype for Vector + ) + buf = append(buf, data...) + buf = append(buf, 0x00) + + docLen := len(buf) + buf[0] = byte(docLen) + buf[1] = byte(docLen >> 8) + buf[2] = byte(docLen >> 16) + buf[3] = byte(docLen >> 24) + + return buf +} + +func TestVectorBackwardCompatibility(t *testing.T) { + t.Parallel() + + t.Run("unmarshal to Vector field", func(t *testing.T) { + t.Parallel() + + vectorData := []byte{ + 0x03, // int8 vector type (0x03 is Int8Vector) + 0x00, // padding + 0x01, 0x02, 0x03, 0x04, // int8 values + } + + doc := createBSONWithBinary(vectorData) + + var result struct { + V Vector + } + err := Unmarshal(doc, &result) + require.NoError(t, err) + + require.Equal(t, Int8Vector, result.V.Type()) + int8Data, ok := result.V.Int8OK() + require.True(t, ok, "expected int8 vector") + require.Equal(t, []int8{1, 2, 3, 4}, int8Data) + }) +} + +func TestUnmarshalVectorToSlices(t *testing.T) { + t.Parallel() + + t.Run("int8 vector to []int8", func(t *testing.T) { + t.Parallel() + + vectorData := []byte{ + 0x01, // int8 vector type + 0x00, // padding + 0x01, 0x02, 0x03, 0x04, // int8 values + } + + bsonData := 
createBSONWithBinary(vectorData) + + var result struct{ V []int8 } + err := Unmarshal(bsonData, &result) + require.NoError(t, err) + require.Equal(t, []int8{1, 2, 3, 4}, result.V) + }) + + t.Run("float32 vector to []float32", func(t *testing.T) { + t.Parallel() + + vectorData := make([]byte, 2+4*4) // type + padding + 4 float32s + vectorData[0] = 0x02 // float32 vector type + vectorData[1] = 0x00 // padding + + binary.LittleEndian.PutUint32(vectorData[2:], math.Float32bits(1.0)) + binary.LittleEndian.PutUint32(vectorData[6:], math.Float32bits(2.0)) + binary.LittleEndian.PutUint32(vectorData[10:], math.Float32bits(3.0)) + binary.LittleEndian.PutUint32(vectorData[14:], math.Float32bits(4.0)) + + bsonData := createBSONWithBinary(vectorData) + + var result struct{ V []float32 } + err := Unmarshal(bsonData, &result) + require.NoError(t, err) + require.InEpsilonSlice(t, []float32{1.0, 2.0, 3.0, 4.0}, result.V, 0.0001) + }) + + t.Run("invalid vector type to slice", func(t *testing.T) { + t.Parallel() + + vectorData := []byte{ + 0x10, // packed bit vector type (unsupported for direct unmarshaling) + 0x00, // padding + 0x01, 0x02, // some data + } + bsonData := createBSONWithBinary(vectorData) + + t.Run("to []int8", func(t *testing.T) { + t.Parallel() + var result struct{ V []int8 } + err := Unmarshal(bsonData, &result) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid vector type: expected int8 vector (0x01)") + }) + + t.Run("to []float32", func(t *testing.T) { + t.Parallel() + var result struct{ V []float32 } + err := Unmarshal(bsonData, &result) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid vector type: expected float32 vector (0x02)") + }) + }) + + t.Run("invalid binary data", func(t *testing.T) { + t.Parallel() + + invalidData := []byte{0x01, 0x01, 0x02, 0x03} + bsonData := createBSONWithBinary(invalidData) + + t.Run("to []int8", func(t *testing.T) { + t.Parallel() + var result struct{ V []int8 } + err := Unmarshal(bsonData, 
&result) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid vector: padding byte must be 0") + }) + + t.Run("to []float32", func(t *testing.T) { + t.Parallel() + var result struct{ V []float32 } + err := Unmarshal(bsonData, &result) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid vector type: expected float32 vector (0x02)") + }) + }) +} From 656f6124a99c507861f6ba3804222094339e382a Mon Sep 17 00:00:00 2001 From: Lokesh Kumar Date: Sun, 15 Jun 2025 20:51:47 +0200 Subject: [PATCH 2/2] GODRIVER-3472: Address code review improvements --- bson/primitive_codecs_test.go | 48 +++++++++ bson/slice_codec.go | 90 +++++++++++++--- bson/types.go | 2 + bson/vector.go | 65 ------------ bson/vector_unmarshal_test.go | 57 +++++------ mongo/mongo_test.go | 186 ++++++++++++++++++++++++++++++++++ 6 files changed, 339 insertions(+), 109 deletions(-) diff --git a/bson/primitive_codecs_test.go b/bson/primitive_codecs_test.go index 6071ea02f9..5ff03d302c 100644 --- a/bson/primitive_codecs_test.go +++ b/bson/primitive_codecs_test.go @@ -8,6 +8,7 @@ package bson import ( "bytes" + "encoding/binary" "encoding/json" "errors" "fmt" @@ -1116,3 +1117,50 @@ func compareDecimal128(d1, d2 Decimal128) bool { return true } + +func TestSliceCodec(t *testing.T) { + t.Run("[]byte is treated as binary data", func(t *testing.T) { + type testStruct struct { + B []byte `bson:"b"` + } + + testData := testStruct{B: []byte{0x01, 0x02, 0x03}} + data, err := Marshal(testData) + assert.Nil(t, err, "Marshal error: %v", err) + var doc D + err = Unmarshal(data, &doc) + assert.Nil(t, err, "Unmarshal error: %v", err) + + offset := 4 + 1 + 2 + length := int32(binary.LittleEndian.Uint32(data[offset:])) + offset += 4 // Skip length + subtype := data[offset] + offset++ // Skip subtype + dataBytes := data[offset : offset+int(length)] + + assert.Equal(t, byte(0x00), subtype, "Expected binary subtype 0x00") + assert.Equal(t, []byte{0x01, 0x02, 0x03}, dataBytes, "Binary data mismatch") + 
}) + + t.Run("[]int8 is not treated as binary data", func(t *testing.T) { + type testStruct struct { + I []int8 `bson:"i"` + } + testData := testStruct{I: []int8{1, 2, 3}} + data, err := Marshal(testData) + assert.Nil(t, err, "Marshal error: %v", err) + + offset := 4 // Skip document length + assert.Equal(t, byte(0x04), data[offset], "Expected array type (0x04), got: 0x%02x", data[offset]) + + var result struct { + I []int32 `bson:"i"` + } + err = Unmarshal(data, &result) + assert.Nil(t, err, "Unmarshal result error: %v", err) + assert.Equal(t, 3, len(result.I), "Expected array length 3") + assert.Equal(t, int32(1), result.I[0], "Array element 0 mismatch") + assert.Equal(t, int32(2), result.I[1], "Array element 1 mismatch") + assert.Equal(t, int32(3), result.I[2], "Array element 2 mismatch") + }) +} diff --git a/bson/slice_codec.go b/bson/slice_codec.go index ffe17f01db..b834762533 100644 --- a/bson/slice_codec.go +++ b/bson/slice_codec.go @@ -7,8 +7,10 @@ package bson import ( + "encoding/binary" "errors" "fmt" + "math" "reflect" ) @@ -25,7 +27,7 @@ type sliceCodec struct { func (sc *sliceCodec) decodeVectorBinary(vr ValueReader, val reflect.Value) error { elemType := val.Type().Elem() - if elemType != TInt8 && elemType != TFloat32 { + if elemType != tInt8 && elemType != tFloat32 { return errNotAVectorBinary } @@ -39,14 +41,14 @@ func (sc *sliceCodec) decodeVectorBinary(vr ValueReader, val reflect.Value) erro } switch elemType { - case TInt8: - int8Slice, err := DecodeVectorInt8(data) + case tInt8: + int8Slice, err := decodeVectorInt8(data) if err != nil { return err } val.Set(reflect.ValueOf(int8Slice)) - case TFloat32: - float32Slice, err := DecodeVectorFloat32(data) + case tFloat32: + float32Slice, err := decodeVectorFloat32(data) if err != nil { return err } @@ -66,8 +68,9 @@ func (sc *sliceCodec) EncodeValue(ec EncodeContext, vw ValueWriter, val reflect. 
return vw.WriteNull() } - // Treat []byte as binary data, but skip for []int8 since it's a different type - // even though byte is an alias for uint8 which has the same underlying type as int8 + // Treat []byte as binary data, but skip for []int8 since it's a different type. + // Even though byte is an alias for uint8, which has the same size as int8 but is a distinct type, + // we want to maintain the semantic difference between []byte (binary data) and []int8 (array of integers). if val.Type().Elem() == tByte && val.Type() != reflect.TypeOf([]int8{}) { byteSlice := make([]byte, val.Len()) reflect.Copy(reflect.ValueOf(byteSlice), val) @@ -137,12 +140,6 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect. return ValueDecoderError{Name: "SliceDecodeValue", Kinds: []reflect.Kind{reflect.Slice}, Received: val} } - if vr.Type() == TypeBinary { - if err := sc.decodeVectorBinary(vr, val); err != errNotAVectorBinary { - return err - } - } - switch vrType := vr.Type(); vrType { case TypeArray: case TypeNull: @@ -156,6 +153,14 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect. return fmt.Errorf("cannot decode document into %s", val.Type()) } case TypeBinary: + err := sc.decodeVectorBinary(vr, val) + if err == nil { + return nil + } + if err != errNotAVectorBinary { + return err + } + if val.Type().Elem() != tByte { return fmt.Errorf("SliceDecodeValue can only decode a binary into a byte array, got %v", vrType) } @@ -215,3 +220,62 @@ func (sc *sliceCodec) DecodeValue(dc DecodeContext, vr ValueReader, val reflect. return nil } + +// decodeVectorInt8 decodes a BSON Vector binary value (subtype 9) into a []int8 slice. +// The binary data should be in the format: [<vector type> <padding> <data>] +// For int8 vectors, the vector type is Int8Vector (0x03). 
+func decodeVectorInt8(data []byte) ([]int8, error) { + if len(data) < 2 { + return nil, fmt.Errorf("insufficient bytes to decode vector: expected at least 2 bytes") + } + + vectorType := data[0] + if vectorType != Int8Vector { + return nil, fmt.Errorf("invalid vector type: expected int8 vector (0x%02x), got 0x%02x", Int8Vector, vectorType) + } + + if padding := data[1]; padding != 0 { + return nil, fmt.Errorf("invalid vector: padding byte must be 0") + } + + values := make([]int8, 0, len(data)-2) + for i := 2; i < len(data); i++ { + values = append(values, int8(data[i])) + } + + return values, nil +} + +// decodeVectorFloat32 decodes a BSON Vector binary value (subtype 9) into a []float32 slice. +// The binary data should be in the format: [<vector type> <padding> <data>] +// For float32 vectors, the vector type is Float32Vector (0x27) and data must be a multiple of 4 bytes. +func decodeVectorFloat32(data []byte) ([]float32, error) { + if len(data) < 2 { + return nil, fmt.Errorf("insufficient bytes to decode vector: expected at least 2 bytes") + } + + vectorType := data[0] + if vectorType != Float32Vector { + return nil, fmt.Errorf("invalid vector type: expected float32 vector (0x%02x), got 0x%02x", Float32Vector, vectorType) + } + + if padding := data[1]; padding != 0 { + return nil, fmt.Errorf("invalid vector: padding byte must be 0") + } + + floatData := data[2:] + if len(floatData)%4 != 0 { + return nil, fmt.Errorf("invalid float32 vector: data length must be a multiple of 4") + } + + values := make([]float32, 0, len(floatData)/4) + for i := 0; i < len(floatData); i += 4 { + if i+4 > len(floatData) { + return nil, fmt.Errorf("invalid float32 vector: truncated data") + } + bits := binary.LittleEndian.Uint32(floatData[i : i+4]) + values = append(values, math.Float32frombits(bits)) + } + + return values, nil +} diff --git a/bson/types.go b/bson/types.go index c2883aa4ef..91bd9e32fb 100644 --- a/bson/types.go +++ b/bson/types.go @@ -77,7 +77,9 @@ const ( ) var tBool = reflect.TypeOf(false) 
+var tFloat32 = reflect.TypeOf(float32(0)) var tFloat64 = reflect.TypeOf(float64(0)) +var tInt8 = reflect.TypeOf(int8(0)) var tInt32 = reflect.TypeOf(int32(0)) var tInt64 = reflect.TypeOf(int64(0)) var tString = reflect.TypeOf("") diff --git a/bson/vector.go b/bson/vector.go index d61ef02076..f0735f806f 100644 --- a/bson/vector.go +++ b/bson/vector.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "math" - "reflect" ) // BSON binary vector types as described in https://bsonspec.org/spec.html. @@ -29,13 +28,6 @@ var ( errNotAVectorBinary = errors.New("not a vector binary") ) -var ( - // TInt8 is the reflect.Type for int8 - TInt8 = reflect.TypeOf(int8(0)) - // TFloat32 is the reflect.Type for float32 - TFloat32 = reflect.TypeOf(float32(0)) -) - type vectorTypeError struct { Method string Type byte @@ -275,60 +267,3 @@ func newBitVector(b []byte) (Vector, error) { } return NewPackedBitVector(b[1:], b[0]) } - -// DecodeVectorInt8 decodes a BSON Vector binary value (subtype 9) into a []int8 slice. -// The binary data should be in the format: [ ] -// For int8 vectors, the vector type is 0x01. -func DecodeVectorInt8(data []byte) ([]int8, error) { - if len(data) < 2 { - return nil, errors.New("insufficient bytes to decode vector: expected at least 2 bytes") - } - - vectorType := data[0] - if vectorType != 0x01 { // Int8Vector - return nil, errors.New("invalid vector type: expected int8 vector (0x01)") - } - - if padding := data[1]; padding != 0 { - return nil, errors.New("invalid vector: padding byte must be 0") - } - values := make([]int8, 0, len(data)-2) - for i := 2; i < len(data); i++ { - values = append(values, int8(data[i])) - } - - return values, nil -} - -// DecodeVectorFloat32 decodes a BSON Vector binary value (subtype 9) into a []float32 slice. -// The binary data should be in the format: [ ] -// For float32 vectors, the vector type is 0x02 and data must be a multiple of 4 bytes. 
-func DecodeVectorFloat32(data []byte) ([]float32, error) { - if len(data) < 2 { - return nil, errors.New("insufficient bytes to decode vector: expected at least 2 bytes") - } - - vectorType := data[0] - if vectorType != 0x02 { // Float32Vector - return nil, errors.New("invalid vector type: expected float32 vector (0x02)") - } - - if padding := data[1]; padding != 0 { - return nil, errors.New("invalid vector: padding byte must be 0") - } - floatData := data[2:] - if len(floatData)%4 != 0 { - return nil, errors.New("invalid float32 vector: data length must be a multiple of 4") - } - - values := make([]float32, 0, len(floatData)/4) - for i := 0; i < len(floatData); i += 4 { - if i+4 > len(floatData) { - return nil, errors.New("invalid float32 vector: truncated data") - } - bits := binary.LittleEndian.Uint32(floatData[i : i+4]) - values = append(values, math.Float32frombits(bits)) - } - - return values, nil -} diff --git a/bson/vector_unmarshal_test.go b/bson/vector_unmarshal_test.go index cf2658a606..01ca9448ac 100644 --- a/bson/vector_unmarshal_test.go +++ b/bson/vector_unmarshal_test.go @@ -7,8 +7,7 @@ package bson import ( - "encoding/binary" - "math" + "fmt" "testing" "github.com/stretchr/testify/require" @@ -73,38 +72,25 @@ func TestUnmarshalVectorToSlices(t *testing.T) { t.Run("int8 vector to []int8", func(t *testing.T) { t.Parallel() - vectorData := []byte{ - 0x01, // int8 vector type - 0x00, // padding - 0x01, 0x02, 0x03, 0x04, // int8 values - } - - bsonData := createBSONWithBinary(vectorData) - + doc := D{{"v", NewVector([]int8{-2, 1, 2, 3, 4})}} + bsonData, err := Marshal(doc) + require.NoError(t, err) var result struct{ V []int8 } - err := Unmarshal(bsonData, &result) + err = Unmarshal(bsonData, &result) require.NoError(t, err) - require.Equal(t, []int8{1, 2, 3, 4}, result.V) + require.Equal(t, []int8{-2, 1, 2, 3, 4}, result.V) }) t.Run("float32 vector to []float32", func(t *testing.T) { t.Parallel() - vectorData := make([]byte, 2+4*4) // type + padding + 
4 float32s - vectorData[0] = 0x02 // float32 vector type - vectorData[1] = 0x00 // padding - - binary.LittleEndian.PutUint32(vectorData[2:], math.Float32bits(1.0)) - binary.LittleEndian.PutUint32(vectorData[6:], math.Float32bits(2.0)) - binary.LittleEndian.PutUint32(vectorData[10:], math.Float32bits(3.0)) - binary.LittleEndian.PutUint32(vectorData[14:], math.Float32bits(4.0)) - - bsonData := createBSONWithBinary(vectorData) - + doc := D{{"v", NewVector([]float32{1.1, 2.2, 3.3, 4.4})}} + bsonData, err := Marshal(doc) + require.NoError(t, err) var result struct{ V []float32 } - err := Unmarshal(bsonData, &result) + err = Unmarshal(bsonData, &result) require.NoError(t, err) - require.InEpsilonSlice(t, []float32{1.0, 2.0, 3.0, 4.0}, result.V, 0.0001) + require.InDeltaSlice(t, []float32{1.1, 2.2, 3.3, 4.4}, result.V, 0.001) }) t.Run("invalid vector type to slice", func(t *testing.T) { @@ -119,10 +105,14 @@ func TestUnmarshalVectorToSlices(t *testing.T) { t.Run("to []int8", func(t *testing.T) { t.Parallel() + + vectorData := []byte{0x10, 0x00} // Invalid vector type + bsonData := createBSONWithBinary(vectorData) + var result struct{ V []int8 } err := Unmarshal(bsonData, &result) require.Error(t, err) - require.Contains(t, err.Error(), "invalid vector type: expected int8 vector (0x01)") + require.Contains(t, err.Error(), fmt.Sprintf("invalid vector type: expected int8 vector (0x%02x)", Int8Vector)) }) t.Run("to []float32", func(t *testing.T) { @@ -130,30 +120,35 @@ func TestUnmarshalVectorToSlices(t *testing.T) { var result struct{ V []float32 } err := Unmarshal(bsonData, &result) require.Error(t, err) - require.Contains(t, err.Error(), "invalid vector type: expected float32 vector (0x02)") + require.Contains(t, err.Error(), fmt.Sprintf("invalid vector type: expected float32 vector (0x%02x)", Float32Vector)) }) }) t.Run("invalid binary data", func(t *testing.T) { t.Parallel() - invalidData := []byte{0x01, 0x01, 0x02, 0x03} - bsonData := createBSONWithBinary(invalidData) + 
vectorData := []byte{0x01, 0x00, 0x01, 0x02, 0x03, 0x04} + bsonData := createBSONWithBinary(vectorData) t.Run("to []int8", func(t *testing.T) { t.Parallel() + var result struct{ V []int8 } err := Unmarshal(bsonData, &result) require.Error(t, err) - require.Contains(t, err.Error(), "invalid vector: padding byte must be 0") + require.Contains(t, err.Error(), fmt.Sprintf("invalid vector type: expected int8 vector (0x%02x)", Int8Vector)) }) t.Run("to []float32", func(t *testing.T) { t.Parallel() + + vectorData := []byte{0x01, 0x00} + bsonData := createBSONWithBinary(vectorData) + var result struct{ V []float32 } err := Unmarshal(bsonData, &result) require.Error(t, err) - require.Contains(t, err.Error(), "invalid vector type: expected float32 vector (0x02)") + require.Contains(t, err.Error(), fmt.Sprintf("invalid vector type: expected float32 vector (0x%02x)", Float32Vector)) }) }) } diff --git a/mongo/mongo_test.go b/mongo/mongo_test.go index 96be905cb5..3270ee638e 100644 --- a/mongo/mongo_test.go +++ b/mongo/mongo_test.go @@ -7,6 +7,7 @@ package mongo import ( + "context" "errors" "fmt" "reflect" @@ -652,3 +653,188 @@ type bvMarsh struct { func (b bvMarsh) MarshalBSONValue() (byte, []byte, error) { return byte(b.t), b.data, b.err } + +func TestVectorIntegration(t *testing.T) { + t.Run("roundtrip int8 vector", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + type vectorDoc struct { + ID string `bson:"_id"` + Vec []int8 `bson:"v"` + } + + ctx := context.Background() + client := setupClient() + defer func() { + _ = client.Disconnect(ctx) + }() + + db := client.Database("test") + coll := db.Collection("vector_test") + + _, _ = coll.DeleteMany(ctx, bson.M{"$or": []bson.M{ + {"_id": "test_int8"}, + }}) + + expected := vectorDoc{ + ID: "test_int8", + Vec: []int8{-2, -1, 0, 1, 2}, + } + + _, err := coll.InsertOne(ctx, expected) + if err != nil { + t.Fatalf("InsertOne error: %v", err) + } + + var result vectorDoc + err = 
coll.FindOne(ctx, bson.M{"_id": "test_int8"}).Decode(&result) + if err != nil { + t.Fatalf("FindOne error: %v", err) + } + + if !reflect.DeepEqual(expected.Vec, result.Vec) { + t.Errorf("vector data does not match. Expected %v, got %v", expected.Vec, result.Vec) + } + }) + + t.Run("roundtrip float32 vector", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + type vectorDoc struct { + ID string `bson:"_id"` + Vec []float32 `bson:"v"` + } + + ctx := context.Background() + client := setupClient() + defer func() { + _ = client.Disconnect(ctx) + }() + + db := client.Database("test") + coll := db.Collection("vector_test") + + _, _ = coll.DeleteMany(ctx, bson.M{"$or": []bson.M{ + {"_id": "test_float32"}, + }}) + expected := vectorDoc{ + ID: "test_float32", + Vec: []float32{-1.1, 0.0, 0.5, 1.1, 2.2}, + } + + _, err := coll.InsertOne(ctx, expected) + if err != nil { + t.Fatalf("InsertOne error: %v", err) + } + + var result vectorDoc + err = coll.FindOne(ctx, bson.M{"_id": "test_float32"}).Decode(&result) + if err != nil { + t.Fatalf("FindOne error: %v", err) + } + + if len(expected.Vec) != len(result.Vec) { + t.Fatalf("vector length mismatch: expected %d, got %d", len(expected.Vec), len(result.Vec)) + } + for i := range expected.Vec { + if diff := expected.Vec[i] - result.Vec[i]; diff < -0.0001 || diff > 0.0001 { + t.Errorf("vector element %d mismatch: expected %v, got %v", i, expected.Vec[i], result.Vec[i]) + } + } + }) + + t.Run("bson.NewVector with int8", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + type vectorDoc struct { + ID string `bson:"_id"` + Vec []int8 `bson:"v"` + } + + ctx := context.Background() + client := setupClient() + defer func() { + _ = client.Disconnect(ctx) + }() + + db := client.Database("test") + coll := db.Collection("vector_test") + + testID := "test_new_vector_int8" + _, _ = coll.DeleteMany(ctx, bson.M{"$or": []bson.M{ + {"_id": testID}, + 
}}) + + expected := []int8{-2, -1, 0, 1, 2} + + _, err := coll.InsertOne(ctx, bson.D{ + {Key: "_id", Value: testID}, + {Key: "v", Value: bson.NewVector(expected)}, + }) + if err != nil { + t.Fatalf("InsertOne error: %v", err) + } + + var result vectorDoc + err = coll.FindOne(ctx, bson.M{"_id": testID}).Decode(&result) + if err != nil { + t.Fatalf("FindOne error: %v", err) + } + + if !reflect.DeepEqual(expected, result.Vec) { + t.Errorf("vector data does not match. Expected %v, got %v", expected, result.Vec) + } + }) + + t.Run("bson.NewVector with float32", func(t *testing.T) { + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + type vectorDoc struct { + ID string `bson:"_id"` + Vec []float32 `bson:"v"` + } + + ctx := context.Background() + client := setupClient() + defer func() { + _ = client.Disconnect(ctx) + }() + + db := client.Database("test") + coll := db.Collection("vector_test") + + testID := "test_new_vector_float32" + _, _ = coll.DeleteMany(ctx, bson.M{"$or": []bson.M{ + {"_id": testID}, + }}) + + expected := []float32{-1.1, 0.0, 0.5, 1.1, 2.2} + + _, err := coll.InsertOne(ctx, bson.D{ + {Key: "_id", Value: testID}, + {Key: "v", Value: bson.NewVector(expected)}, + }) + if err != nil { + t.Fatalf("InsertOne error: %v", err) + } + + var result vectorDoc + err = coll.FindOne(ctx, bson.M{"_id": testID}).Decode(&result) + if err != nil { + t.Fatalf("FindOne error: %v", err) + } + + if len(expected) != len(result.Vec) { + t.Fatalf("vector length mismatch: expected %d, got %d", len(expected), len(result.Vec)) + } + for i := range expected { + if diff := expected[i] - result.Vec[i]; diff < -0.0001 || diff > 0.0001 { + t.Errorf("vector element %d mismatch: expected %v, got %v", i, expected[i], result.Vec[i]) + } + } + }) +}