Skip to content

Commit 82746cc

Browse files
committed
Implement BSON Marshaler support
1 parent b710ba4 commit 82746cc

File tree

6 files changed

+130
-2
lines changed

6 files changed

+130
-2
lines changed

go.mod

+2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
module github.com/deckarep/golang-set/v2
22

33
go 1.18
4+
5+
require go.mongodb.org/mongo-driver v1.16.0

go.sum

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
2+
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
3+
go.mongodb.org/mongo-driver v1.16.0 h1:tpRsfBJMROVHKpdGyc1BBEzzjDUWjItxbVSZ8Ls4BQ4=
4+
go.mongodb.org/mongo-driver v1.16.0/go.mod h1:oB6AhJQvFQL4LEHyXi6aJzQJtBiTQHiAd83l0GdFaiw=

set.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,15 @@ type Set[T comparable] interface {
192192
MarshalJSON() ([]byte, error)
193193

194194
// UnmarshalJSON will unmarshal a JSON-based byte slice into a full Set datastructure.
195-
// For this to work, set subtypes must implemented the Marshal/Unmarshal interface.
195+
// For this to work, set subtypes must implement the Marshal/Unmarshal interface.
196196
UnmarshalJSON(b []byte) error
197+
198+
// MarshalBSON will marshal the set into a BSON-based representation.
199+
MarshalBSON() ([]byte, error)
200+
201+
// UnmarshalBSON will unmarshal a BSON-based byte slice into a full Set datastructure.
202+
// For this to work, set subtypes must implement the Marshal/Unmarshal interface.
203+
UnmarshalBSON(b []byte) error
197204
}
198205

199206
// NewSet creates and returns a new set with the given elements.

threadsafe.go

+17-1
Original file line numberDiff line numberDiff line change
@@ -291,9 +291,25 @@ func (t *threadSafeSet[T]) MarshalJSON() ([]byte, error) {
291291
}
292292

293293
func (t *threadSafeSet[T]) UnmarshalJSON(p []byte) error {
294-
t.RLock()
294+
t.Lock()
295295
err := t.uss.UnmarshalJSON(p)
296+
t.Unlock()
297+
298+
return err
299+
}
300+
301+
func (t *threadSafeSet[T]) MarshalBSON() ([]byte, error) {
302+
t.RLock()
303+
b, err := t.uss.MarshalBSON()
296304
t.RUnlock()
297305

306+
return b, err
307+
}
308+
309+
func (t *threadSafeSet[T]) UnmarshalBSON(p []byte) error {
310+
t.Lock()
311+
err := t.uss.UnmarshalBSON(p)
312+
t.Unlock()
313+
298314
return err
299315
}

threadsafe_test.go

+78
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import (
3232
"sync"
3333
"sync/atomic"
3434
"testing"
35+
36+
"go.mongodb.org/mongo-driver/bson"
3537
)
3638

3739
const N = 1000
@@ -625,3 +627,79 @@ func Test_MarshalJSON(t *testing.T) {
625627
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
626628
}
627629
}
630+
631+
func Test_UnmarshalBSON(t *testing.T) {
632+
tp, s, initErr := bson.MarshalValue(
633+
bson.A{"1", "2", "3", "test"},
634+
)
635+
636+
if initErr != nil {
637+
t.Errorf("Init Error should be nil: %v", initErr)
638+
639+
return
640+
}
641+
642+
if tp != bson.TypeArray {
643+
t.Errorf("Encoded Type should be bson.Array, got: %v", tp)
644+
645+
return
646+
}
647+
648+
expected := NewSet("1", "2", "3", "test")
649+
actual := NewSet[string]()
650+
err := bson.UnmarshalValue(bson.TypeArray, s, actual)
651+
if err != nil {
652+
t.Errorf("Error should be nil: %v", err)
653+
}
654+
655+
if !expected.Equal(actual) {
656+
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
657+
}
658+
}
659+
func TestThreadUnsafeSet_UnmarshalBSON(t *testing.T) {
660+
tp, s, initErr := bson.MarshalValue(
661+
bson.A{int64(1), int64(2), int64(3)},
662+
)
663+
664+
if initErr != nil {
665+
t.Errorf("Init Error should be nil: %v", initErr)
666+
667+
return
668+
}
669+
670+
if tp != bson.TypeArray {
671+
t.Errorf("Encoded Type should be bson.Array, got: %v", tp)
672+
673+
return
674+
}
675+
676+
expected := NewThreadUnsafeSet[int64](1, 2, 3)
677+
actual := NewThreadUnsafeSet[int64]()
678+
err := actual.UnmarshalBSON([]byte(s))
679+
if err != nil {
680+
t.Errorf("Error should be nil: %v", err)
681+
}
682+
if !expected.Equal(actual) {
683+
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
684+
}
685+
}
686+
func Test_MarshalBSON(t *testing.T) {
687+
expected := NewSet("1", "test")
688+
689+
_, b, err := bson.MarshalValue(
690+
NewSet("1", "test"),
691+
)
692+
if err != nil {
693+
t.Errorf("Error should be nil: %v", err)
694+
}
695+
696+
actual := NewSet[string]()
697+
err = bson.UnmarshalValue(bson.TypeArray, b, actual)
698+
if err != nil {
699+
t.Errorf("Error should be nil: %v", err)
700+
}
701+
702+
if !expected.Equal(actual) {
703+
t.Errorf("Expected no difference, got: %v", expected.Difference(actual))
704+
}
705+
}

threadunsafe.go

+21
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ import (
2929
"encoding/json"
3030
"fmt"
3131
"strings"
32+
33+
"go.mongodb.org/mongo-driver/bson"
3234
)
3335

3436
type threadUnsafeSet[T comparable] map[T]struct{}
@@ -328,3 +330,22 @@ func (s threadUnsafeSet[T]) UnmarshalJSON(b []byte) error {
328330

329331
return nil
330332
}
333+
334+
// MarshalBSON creates a BSON array from the set.
335+
func (s threadUnsafeSet[T]) MarshalBSON() ([]byte, error) {
336+
_, data, err := bson.MarshalValue(s.ToSlice())
337+
338+
return data, err
339+
}
340+
341+
// UnmarshalBSON recreates a set from a BSON array.
342+
func (s threadUnsafeSet[T]) UnmarshalBSON(b []byte) error {
343+
var i []T
344+
err := bson.UnmarshalValue(bson.TypeArray, b, &i)
345+
if err != nil {
346+
return err
347+
}
348+
s.Append(i...)
349+
350+
return nil
351+
}

0 commit comments

Comments
 (0)