Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

James.lawrence/duckdbtypes #365

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions appender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ import (
_ "time/tzdata"

"github.com/go-viper/mapstructure/v2"
"github.com/google/uuid"
"github.com/marcboeker/go-duckdb/duckdbtypes"
"github.com/marcboeker/go-duckdb/internal/uuidx"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -56,7 +57,7 @@ type mixedStruct struct {
L []int32
}
C struct {
L Map
L duckdbtypes.Map
}
}

Expand Down Expand Up @@ -228,7 +229,7 @@ func TestAppenderArray(t *testing.T) {
c, con, a := prepareAppender(t, `CREATE TABLE test (string_array VARCHAR[3])`)

count := 10
expected := Composite[[3]string]{[3]string{"a", "b", "c"}}
expected := duckdbtypes.NewComposite([3]string{"a", "b", "c"})
for i := 0; i < count; i++ {
require.NoError(t, a.AppendRow([]string{"a", "b", "c"}))
require.NoError(t, a.AppendRow(expected.Get()))
Expand All @@ -241,7 +242,7 @@ func TestAppenderArray(t *testing.T) {

i := 0
for res.Next() {
var r Composite[[3]string]
var r duckdbtypes.Composite[[3]string]
require.NoError(t, res.Scan(&r))
require.Equal(t, expected, r)
i++
Expand Down Expand Up @@ -481,11 +482,11 @@ func TestAppenderUUID(t *testing.T) {
t.Parallel()
c, con, a := prepareAppender(t, `CREATE TABLE test (id UUID)`)

id := UUID(uuid.New())
otherId := UUID(uuid.New())
id := duckdbtypes.UUID(uuidx.Random())
otherId := duckdbtypes.UUID(uuidx.Random())
require.NoError(t, a.AppendRow(id))
require.NoError(t, a.AppendRow(&otherId))
require.NoError(t, a.AppendRow((*UUID)(nil)))
require.NoError(t, a.AppendRow((*duckdbtypes.UUID)(nil)))
require.NoError(t, a.AppendRow(nil))
require.NoError(t, a.Flush())

Expand All @@ -496,11 +497,11 @@ func TestAppenderUUID(t *testing.T) {
i := 0
for res.Next() {
if i == 0 {
var r UUID
var r duckdbtypes.UUID
require.NoError(t, res.Scan(&r))
require.Equal(t, id, r)
} else {
var r *UUID
var r *duckdbtypes.UUID
require.NoError(t, res.Scan(&r))
if i == 1 {
require.Equal(t, otherId, *r)
Expand Down Expand Up @@ -733,8 +734,8 @@ func TestAppenderDecimal(t *testing.T) {
)`)

require.NoError(t, a.AppendRow(nil))
require.NoError(t, a.AppendRow(Decimal{Width: uint8(4), Value: big.NewInt(1), Scale: 3}))
require.NoError(t, a.AppendRow(Decimal{Width: uint8(4), Value: big.NewInt(2), Scale: 3}))
require.NoError(t, a.AppendRow(duckdbtypes.Decimal{Width: uint8(4), Value: big.NewInt(1), Scale: 3}))
require.NoError(t, a.AppendRow(duckdbtypes.Decimal{Width: uint8(4), Value: big.NewInt(2), Scale: 3}))
require.NoError(t, a.Flush())

// Verify results.
Expand Down Expand Up @@ -886,8 +887,8 @@ func prepareNestedData(rowCount int) []nestedDataRow {
{[]int32{1, 2, 3}},
},
C: struct {
L Map
}{L: Map{"foo": int32(1), "bar": int32(2)}},
L duckdbtypes.Map
}{L: duckdbtypes.Map{"foo": int32(1), "bar": int32(2)}},
}

rowsToAppend := make([]nestedDataRow, rowCount)
Expand Down
4 changes: 3 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"errors"
"math/big"
"unsafe"

"github.com/marcboeker/go-duckdb/duckdbtypes"
)

// Conn holds a connection to a DuckDB database.
Expand All @@ -25,7 +27,7 @@ type Conn struct {
// CheckNamedValue implements the driver.NamedValueChecker interface.
func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
switch nv.Value.(type) {
case *big.Int, Interval:
case *big.Int, duckdbtypes.Interval:
return nil
}
return driver.ErrSkip
Expand Down
11 changes: 6 additions & 5 deletions duckdb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"testing"
"time"

"github.com/marcboeker/go-duckdb/duckdbtypes"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -281,7 +282,7 @@ func TestJSON(t *testing.T) {
db := openDB(t)

t.Run("SELECT an empty JSON", func(t *testing.T) {
var res Composite[map[string]any]
var res duckdbtypes.Composite[map[string]any]
require.NoError(t, db.QueryRow(`SELECT '{}'::JSON`).Scan(&res))
require.Empty(t, res.Get())
})
Expand All @@ -300,7 +301,7 @@ func TestJSON(t *testing.T) {
})

t.Run("SELECT a JSON array", func(t *testing.T) {
var res Composite[[]any]
var res duckdbtypes.Composite[[]any]
require.NoError(t, db.QueryRow(`SELECT json_array('foo', 'bar')`).Scan(&res))
require.Len(t, res.Get(), 2)
require.Equal(t, "foo", res.Get()[0])
Expand Down Expand Up @@ -415,7 +416,7 @@ func TestTypeNamesAndScanTypes(t *testing.T) {
// DUCKDB_TYPE_INTERVAL
{
sql: "SELECT INTERVAL 15 MINUTES AS col",
value: Interval{Micros: 15 * 60 * 1000000},
value: duckdbtypes.Interval{Micros: 15 * 60 * 1000000},
typeName: "INTERVAL",
},
// DUCKDB_TYPE_HUGEINT
Expand All @@ -439,7 +440,7 @@ func TestTypeNamesAndScanTypes(t *testing.T) {
// DUCKDB_TYPE_DECIMAL
{
sql: "SELECT 31::DECIMAL(30,17) AS col",
value: Decimal{Value: big.NewInt(3100000000000000000), Width: 30, Scale: 17},
value: duckdbtypes.Decimal{Value: big.NewInt(3100000000000000000), Width: 30, Scale: 17},
typeName: "DECIMAL(30,17)",
},
// DUCKDB_TYPE_TIMESTAMP_S
Expand Down Expand Up @@ -480,7 +481,7 @@ func TestTypeNamesAndScanTypes(t *testing.T) {
// DUCKDB_TYPE_MAP
{
sql: "SELECT map([1, 5], ['a', 'e']) AS col",
value: Map{int32(1): "a", int32(5): "e"},
value: duckdbtypes.Map{int32(1): "a", int32(5): "e"},
typeName: "MAP(INTEGER, VARCHAR)",
},
// DUCKDB_TYPE_ARRAY
Expand Down
20 changes: 20 additions & 0 deletions duckdbtypes/composite.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package duckdbtypes

import "github.com/go-viper/mapstructure/v2"

func NewComposite[T any](v T) Composite[T] {
return Composite[T]{v}
}

// Use as the `Scanner` type for any composite types (maps, lists, structs)
type Composite[T any] struct {
t T
}

func (s Composite[T]) Get() T {
return s.t
}

func (s *Composite[T]) Scan(v any) error {
return mapstructure.Decode(v, &s.t)
}
52 changes: 52 additions & 0 deletions duckdbtypes/decimal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package duckdbtypes

import (
"fmt"
"math/big"
"strings"
)

type Decimal struct {
Width uint8
Scale uint8
Value *big.Int
}

func (d *Decimal) Float64() float64 {
scale := big.NewInt(int64(d.Scale))
factor := new(big.Float).SetInt(new(big.Int).Exp(big.NewInt(10), scale, nil))
value := new(big.Float).SetInt(d.Value)
value.Quo(value, factor)
f, _ := value.Float64()
return f
}

func (d *Decimal) String() string {
// Get the sign, and return early if zero
if d.Value.Sign() == 0 {
return "0"
}

// Remove the sign from the string integer value
var signStr string
scaleless := d.Value.String()
if d.Value.Sign() < 0 {
signStr = "-"
scaleless = scaleless[1:]
}

// Remove all zeros from the right side
zeroTrimmed := strings.TrimRightFunc(scaleless, func(r rune) bool { return r == '0' })
scale := int(d.Scale) - (len(scaleless) - len(zeroTrimmed))

// If the string is still bigger than the scale factor, output it without a decimal point
if scale <= 0 {
return signStr + zeroTrimmed + strings.Repeat("0", -1*scale)
}

// Pad a number with 0.0's if needed
if len(zeroTrimmed) <= scale {
return fmt.Sprintf("%s0.%s%s", signStr, strings.Repeat("0", scale-len(zeroTrimmed)), zeroTrimmed)
}
return signStr + zeroTrimmed[:len(zeroTrimmed)-scale] + "." + zeroTrimmed[len(zeroTrimmed)-scale:]
}
68 changes: 68 additions & 0 deletions duckdbtypes/duckdbtypes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package duckdbtypes

import "reflect"

var kindTypes map[reflect.Kind]reflect.Type

func toInterface(dst reflect.Value, t reflect.Type) (interface{}, bool) {
nextDst := dst.Convert(t)
return nextDst.Interface(), dst.Type() != nextDst.Type()
}

// GetAssignToDstType attempts to convert dst to something AssignTo can assign
// to. If dst is a pointer to pointer it allocates a value and returns the
// dereferences pointer. If dst is a named type such as *Foo where Foo is type
// Foo int16, it converts dst to *int16.
//
// GetAssignToDstType returns the converted dst and a bool representing if any
// change was made.
func GetAssignToDstType(dst interface{}) (interface{}, bool) {
dstPtr := reflect.ValueOf(dst)

// AssignTo dst must always be a pointer
if dstPtr.Kind() != reflect.Ptr {
return nil, false
}

dstVal := dstPtr.Elem()

// if dst is a pointer to pointer, allocate space try again with the dereferenced pointer
if dstVal.Kind() == reflect.Ptr {
dstVal.Set(reflect.New(dstVal.Type().Elem()))
return dstVal.Interface(), true
}

// if dst is pointer to a base type that has been renamed
if baseValType, ok := kindTypes[dstVal.Kind()]; ok {
return toInterface(dstPtr, reflect.PtrTo(baseValType))
}

if dstVal.Kind() == reflect.Slice {
if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
return toInterface(dstPtr, reflect.PtrTo(reflect.SliceOf(baseElemType)))
}
}

if dstVal.Kind() == reflect.Array {
if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)))
}
}

if dstVal.Kind() == reflect.Struct {
if dstVal.Type().NumField() == 1 && dstVal.Type().Field(0).Anonymous {
dstPtr = dstVal.Field(0).Addr()
nested := dstVal.Type().Field(0).Type
if nested.Kind() == reflect.Array {
if baseElemType, ok := kindTypes[nested.Elem().Kind()]; ok {
return toInterface(dstPtr, reflect.PtrTo(reflect.ArrayOf(nested.Len(), baseElemType)))
}
}
if _, ok := kindTypes[nested.Kind()]; ok && dstPtr.CanInterface() {
return dstPtr.Interface(), true
}
}
}

return nil, false
}
7 changes: 7 additions & 0 deletions duckdbtypes/interval.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package duckdbtypes

type Interval struct {
Days int32 `json:"days"`
Months int32 `json:"months"`
Micros int64 `json:"micros"`
}
15 changes: 15 additions & 0 deletions duckdbtypes/map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package duckdbtypes

import "fmt"

type Map map[any]any

func (m *Map) Scan(v any) error {
data, ok := v.(Map)
if !ok {
return fmt.Errorf("invalid type `%T` for scanning `Map`, expected `Map`", data)
}

*m = data
return nil
}
67 changes: 67 additions & 0 deletions duckdbtypes/uuid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package duckdbtypes

import (
"encoding/hex"
"fmt"

"github.com/marcboeker/go-duckdb/internal/uuidx"
)

type UUID [uuidx.ByteLength]byte

func (t *UUID) AssignTo(dst any) error {
switch v := dst.(type) {
case *[16]byte:
*v = *t
return nil
case *[]byte:
*v = make([]byte, 16)
copy(*v, t[:])
return nil
case *string:
*v = t.String()
return nil
default:
if nextDst, retry := GetAssignToDstType(v); retry {
return t.AssignTo(nextDst)
}
}

return fmt.Errorf("cannot assign %v into %T", t, dst)
}

func (u *UUID) Scan(v any) error {
switch val := v.(type) {
case []byte:
if len(val) != uuidx.ByteLength {
return u.Scan(string(val))
}
copy(u[:], val[:])
return nil
case string:
id, err := uuidx.Parse(val)
if err != nil {
return err
}
copy(u[:], id[:])
return nil
default:
return fmt.Errorf("invalid UUID value type: %T", val)
}
}

func (u *UUID) String() string {
buf := make([]byte, 36)

hex.Encode(buf, u[:4])
buf[8] = '-'
hex.Encode(buf[9:13], u[4:6])
buf[13] = '-'
hex.Encode(buf[14:18], u[6:8])
buf[18] = '-'
hex.Encode(buf[19:23], u[8:10])
buf[23] = '-'
hex.Encode(buf[24:], u[10:])

return string(buf)
}
Loading