diff --git a/dialect/pgdialect/append.go b/dialect/pgdialect/append.go index a60bf5de2..75798b385 100644 --- a/dialect/pgdialect/append.go +++ b/dialect/pgdialect/append.go @@ -27,6 +27,9 @@ var ( float64Type = reflect.TypeOf((*float64)(nil)).Elem() sliceFloat64Type = reflect.TypeOf([]float64(nil)) + + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + sliceTimeType = reflect.TypeOf([]time.Time(nil)) ) func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte { @@ -93,6 +96,8 @@ func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc { return appendInt64SliceValue case float64Type: return appendFloat64SliceValue + case timeType: + return appendTimeSliceValue } } @@ -308,6 +313,32 @@ func arrayAppendString(b []byte, s string) []byte { return b } +func appendTimeSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte { + ts := v.Convert(sliceTimeType).Interface().([]time.Time) + return appendTimeSlice(fmter, b, ts) +} + +func appendTimeSlice(fmter schema.Formatter, b []byte, ts []time.Time) []byte { + if ts == nil { + return dialect.AppendNull(b) + } + b = append(b, '\'') + b = append(b, '{') + for _, t := range ts { + b = append(b, '"') + b = t.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00") + b = append(b, '"') + b = append(b, ',') + } + if len(ts) > 0 { + b[len(b)-1] = '}' // Replace trailing comma. + } else { + b = append(b, '}') + } + b = append(b, '\'') + return b +} + //------------------------------------------------------------------------------ var mapStringStringType = reflect.TypeOf(map[string]string(nil)) diff --git a/go.work.sum b/go.work.sum index 05eb5544f..1e0326e1e 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,7 +1,12 @@ +github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4= github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/grpc-ecosystem/grpc-gateway v1.16.0 h1:gmcG1KaJ57LophUzW0Hy8NmPhnMZb4M0+kPpLofRdBo= +github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= +github.com/jackc/pgproto3 v1.1.0 h1:FYYE4yRw+AgI8wXIinMlNjBbp/UitDJwfj5LqqewP1A= github.com/mattn/go-sqlite3 v1.14.15/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/urfave/cli v1.22.1 h1:+mkCCcOFKPnCmVYVcURKps1Xe+3zP90gSYGNfRkjoIY= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= diff --git a/internal/dbtest/pg_test.go b/internal/dbtest/pg_test.go index fbcb1100d..4521f11b3 100644 --- a/internal/dbtest/pg_test.go +++ b/internal/dbtest/pg_test.go @@ -463,6 +463,59 @@ func TestPostgresTimetz(t *testing.T) { require.NotZero(t, tm) } +func TestPostgresTimeArray(t *testing.T) { + type Model struct { + ID int64 `bun:",pk,autoincrement"` + Array1 []time.Time `bun:",array"` + Array2 *[]time.Time `bun:",array"` + Array3 *[]time.Time `bun:",array"` + } + db := pg(t) + defer db.Close() + + _, err := db.NewDropTable().Model((*Model)(nil)).IfExists().Exec(ctx) + require.NoError(t, err) + + _, err = db.NewCreateTable().Model((*Model)(nil)).Exec(ctx) + require.NoError(t, err) + + time1 := time.Now() + time2 := time.Now().Add(time.Hour) + time3 := time.Now().AddDate(0, 0, 1) + + model1 := &Model{ + ID: 123, + Array1: []time.Time{time1, time2, time3}, + Array2: &[]time.Time{time1, time2, time3}, + } + _, err = db.NewInsert().Model(model1).Exec(ctx) + require.NoError(t, err) + + model2 := new(Model) + err = db.NewSelect().Model(model2).Scan(ctx) + require.NoError(t, err) + require.Equal(t, len(model1.Array1), len(model2.Array1)) + + var times []time.Time + err = db.NewSelect().Model((*Model)(nil)). + Column("array1"). + Scan(ctx, pgdialect.Array(×)) + require.NoError(t, err) + require.Equal(t, len(times), len(model1.Array1)) + + err = db.NewSelect().Model((*Model)(nil)). + Column("array2"). + Scan(ctx, pgdialect.Array(×)) + require.NoError(t, err) + require.Equal(t, 3, len(*model1.Array2)) + + err = db.NewSelect().Model((*Model)(nil)). + Column("array3"). + Scan(ctx, pgdialect.Array(×)) + require.NoError(t, err) + require.Nil(t, times) +} + func TestPostgresOnConflictDoUpdate(t *testing.T) { type Model struct { ID int64 `bun:",pk,autoincrement"`