Skip to content

Commit c72742c

Browse files
pymqvmihailenco
authored andcommittedMar 21, 2020
add possibility to override json library for jsonb fields
1 parent b0e9c54 commit c72742c

8 files changed

+103
-13
lines changed
 

‎conv_test.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,9 @@ import (
1010
"testing"
1111
"time"
1212

13-
"github.com/segmentio/encoding/json"
14-
1513
"github.com/go-pg/pg/v9"
1614
"github.com/go-pg/pg/v9/orm"
15+
"github.com/go-pg/pg/v9/pgjson"
1716
"github.com/go-pg/pg/v9/types"
1817
)
1918

@@ -24,11 +23,11 @@ func (m *JSONMap) Scan(b interface{}) error {
2423
*m = nil
2524
return nil
2625
}
27-
return json.Unmarshal(b.([]byte), m)
26+
return pgjson.Unmarshal(b.([]byte), m)
2827
}
2928

3029
func (m JSONMap) Value() (driver.Value, error) {
31-
b, err := json.Marshal(m)
30+
b, err := pgjson.Marshal(m)
3231
if err != nil {
3332
return nil, err
3433
}

‎db_test.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,7 @@ func TestAnynomousStructField(t *testing.T) {
148148

149149
var st MyStruct
150150
_, err := db.Query(&st, "SELECT ARRAY[1,2,3,4] AS ints")
151-
wanted := `json: cannot unmarshal "1" into Go value of type pg_test.MyInt`
152-
if err.Error() != wanted {
151+
if !strings.Contains(err.Error(), "json: cannot unmarshal") {
153152
t.Fatal(err)
154153
}
155154
}

‎orm/table.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/vmihailenco/tagparser"
1717

1818
"github.com/go-pg/pg/v9/internal"
19+
"github.com/go-pg/pg/v9/pgjson"
1920
"github.com/go-pg/pg/v9/types"
2021
"github.com/go-pg/zerochecker"
2122
)
@@ -1005,7 +1006,7 @@ func scanJSONValue(v reflect.Value, rd types.Reader, n int) error {
10051006
return nil
10061007
}
10071008

1008-
dec := json.NewDecoder(rd)
1009+
dec := pgjson.NewDecoder(rd)
10091010
dec.UseNumber()
10101011
return dec.Decode(v.Addr().Interface())
10111012
}

‎pgjson/json.go

+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package pgjson
2+
3+
import (
4+
"encoding/json"
5+
"io"
6+
7+
json2 "github.com/segmentio/encoding/json"
8+
)
9+
10+
var _ Provider = (*StdProvider)(nil)
11+
12+
type StdProvider struct{}
13+
14+
func (StdProvider) Marshal(v interface{}) ([]byte, error) {
15+
return json.Marshal(v)
16+
}
17+
18+
func (StdProvider) Unmarshal(data []byte, v interface{}) error {
19+
return json.Unmarshal(data, v)
20+
}
21+
22+
func (StdProvider) NewEncoder(w io.Writer) Encoder {
23+
return json.NewEncoder(w)
24+
}
25+
26+
func (StdProvider) NewDecoder(r io.Reader) Decoder {
27+
return json.NewDecoder(r)
28+
}
29+
30+
var _ Provider = (*SegmentioProvider)(nil)
31+
32+
type SegmentioProvider struct{}
33+
34+
func (SegmentioProvider) Marshal(v interface{}) ([]byte, error) {
35+
return json2.Marshal(v)
36+
}
37+
38+
func (SegmentioProvider) Unmarshal(data []byte, v interface{}) error {
39+
return json2.Unmarshal(data, v)
40+
}
41+
42+
func (SegmentioProvider) NewEncoder(w io.Writer) Encoder {
43+
return json2.NewEncoder(w)
44+
}
45+
46+
func (SegmentioProvider) NewDecoder(r io.Reader) Decoder {
47+
return json2.NewDecoder(r)
48+
}

‎pgjson/provider.go

+43
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package pgjson
2+
3+
import (
4+
"io"
5+
)
6+
7+
var provider Provider = StdProvider{}
8+
9+
func SetProvider(p Provider) {
10+
provider = p
11+
}
12+
13+
type Provider interface {
14+
Marshal(v interface{}) ([]byte, error)
15+
Unmarshal(data []byte, v interface{}) error
16+
NewEncoder(w io.Writer) Encoder
17+
NewDecoder(r io.Reader) Decoder
18+
}
19+
20+
type Decoder interface {
21+
Decode(v interface{}) error
22+
UseNumber()
23+
}
24+
25+
type Encoder interface {
26+
Encode(v interface{}) error
27+
}
28+
29+
func Marshal(v interface{}) ([]byte, error) {
30+
return provider.Marshal(v)
31+
}
32+
33+
func Unmarshal(data []byte, v interface{}) error {
34+
return provider.Unmarshal(data, v)
35+
}
36+
37+
func NewEncoder(w io.Writer) Encoder {
38+
return provider.NewEncoder(w)
39+
}
40+
41+
func NewDecoder(r io.Reader) Decoder {
42+
return provider.NewDecoder(r)
43+
}

‎types/append_jsonb_test.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ import (
44
"bytes"
55
"testing"
66

7-
"github.com/segmentio/encoding/json"
8-
7+
"github.com/go-pg/pg/v9/pgjson"
98
"github.com/go-pg/pg/v9/types"
109
)
1110

@@ -30,7 +29,7 @@ func TestAppendJSONB(t *testing.T) {
3029
}
3130

3231
func BenchmarkAppendJSONB(b *testing.B) {
33-
bytes, err := json.Marshal(jsonbTests)
32+
bytes, err := pgjson.Marshal(jsonbTests)
3433
if err != nil {
3534
b.Fatal(err)
3635
}

‎types/append_value.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ import (
99
"sync"
1010
"time"
1111

12-
"github.com/segmentio/encoding/json"
1312
"github.com/vmihailenco/bufpool"
1413

1514
"github.com/go-pg/pg/v9/internal"
15+
"github.com/go-pg/pg/v9/pgjson"
1616
)
1717

1818
var driverValuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
@@ -196,7 +196,7 @@ func appendJSONValue(b []byte, v reflect.Value, flags int) []byte {
196196
buf := internal.GetBuffer()
197197
defer internal.PutBuffer(buf)
198198

199-
if err := json.NewEncoder(buf).Encode(v.Interface()); err != nil {
199+
if err := pgjson.NewEncoder(buf).Encode(v.Interface()); err != nil {
200200
return AppendError(b, err)
201201
}
202202

‎types/scan_value.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"github.com/segmentio/encoding/json"
1313

1414
"github.com/go-pg/pg/v9/internal"
15+
"github.com/go-pg/pg/v9/pgjson"
1516
)
1617

1718
var valueScannerType = reflect.TypeOf((*ValueScanner)(nil)).Elem()
@@ -283,7 +284,7 @@ func scanJSONValue(v reflect.Value, rd Reader, n int) error {
283284
return nil
284285
}
285286

286-
dec := json.NewDecoder(rd)
287+
dec := pgjson.NewDecoder(rd)
287288
return dec.Decode(v.Addr().Interface())
288289
}
289290

0 commit comments

Comments
 (0)