Skip to content

Commit

Permalink
[PECO-1962] Support positional query parameters (#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
kravets-levko authored Oct 22, 2024
1 parent 909d73f commit 1e9d6ac
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 49 deletions.
7 changes: 6 additions & 1 deletion connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,11 @@ func invalidOperationState(ctx context.Context, opStatus *cli_service.TGetOperat
func (c *conn) executeStatement(ctx context.Context, query string, args []driver.NamedValue) (*cli_service.TExecuteStatementResp, error) {
ctx = driverctx.NewContextWithConnId(ctx, c.id)

parameters, err := convertNamedValuesToSparkParams(args)
if err != nil {
return nil, err
}

req := cli_service.TExecuteStatementReq{
SessionHandle: c.session.SessionHandle,
Statement: query,
Expand All @@ -284,7 +289,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
MaxRows: int64(c.cfg.MaxRows),
},
CanDecompressLZ4Result_: &c.cfg.UseLz4Compression,
Parameters: convertNamedValuesToSparkParams(args),
Parameters: parameters,
}

if c.cfg.UseArrowBatches {
Expand Down
8 changes: 7 additions & 1 deletion doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,17 @@ Use the driverctx package under driverctx/ctx.go to add callbacks to the query c
Passing parameters to a query is supported when run against servers with version DBR 14.1.
// Named parameters:
p := dbsql.Parameter{Name: "p_bool", Value: true},
rows, err1 := db.QueryContext(ctx, `select * from sometable where condition=:p_bool`,dbsql.Parameter{Name: "p_bool", Value: true})
rows, err := db.QueryContext(ctx, `select * from sometable where condition=:p_bool`,dbsql.Parameter{Name: "p_bool", Value: true})
// Positional parameters - both `dbsql.Parameter` and plain values can be used:
rows, err := db.Query(`select *, ? from sometable where field=?`,dbsql.Parameter{Value: "123.456"}, "another parameter")
For complex types, you can specify the SQL type using the dbsql.Parameter type field. If this field is set, the value field MUST be set to a string.
Please note that named and positional parameters cannot be used together in the single query.
# Staging Ingestion
The Go driver now supports staging operations. In order to use a staging operation, you first must update the context with a list of folders that you are allowing the driver to access.
Expand Down
11 changes: 6 additions & 5 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ import (
// Error messages
const (
// Driver errors
ErrNotImplemented = "not implemented"
ErrTransactionsNotSupported = "transactions are not supported"
ErrReadQueryStatus = "could not read query status"
ErrSentinelTimeout = "sentinel timed out waiting for operation to complete"
ErrParametersNotSupported = "query parameters are not supported by this server"
ErrNotImplemented = "not implemented"
ErrTransactionsNotSupported = "transactions are not supported"
ErrReadQueryStatus = "could not read query status"
ErrSentinelTimeout = "sentinel timed out waiting for operation to complete"
ErrParametersNotSupported = "query parameters are not supported by this server"
ErrMixedNamedAndPositionalParameters = "named and positional parameters cannot be used simultaneously"

// Request error messages (connection, authentication, network error)
ErrCloseConnection = "failed to close connection"
Expand Down
99 changes: 70 additions & 29 deletions examples/parameters/main.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package main

import (
"context"
"database/sql"
"fmt"
"log"
Expand All @@ -12,6 +11,74 @@ import (
"github.com/joho/godotenv"
)

func queryWithNamedParameters(db *sql.DB) {
var p_bool bool
var p_int int
var p_double float64
var p_float float32
var p_date string

err := db.QueryRow(`
SELECT
:p_bool AS col_bool,
:p_int AS col_int,
:p_double AS col_double,
:p_float AS col_float,
:p_date AS col_date
`,
dbsql.Parameter{Name: "p_bool", Value: true},
dbsql.Parameter{Name: "p_int", Value: int(1234)},
dbsql.Parameter{Name: "p_double", Type: dbsql.SqlDouble, Value: "3.14"},
dbsql.Parameter{Name: "p_float", Type: dbsql.SqlFloat, Value: "3.14"},
dbsql.Parameter{Name: "p_date", Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"},
).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)

if err != nil {
if err == sql.ErrNoRows {
fmt.Println("not found")
return
} else {
fmt.Printf("err: %v\n", err)
}
} else {
fmt.Println(p_bool, p_int, p_double, p_float, p_date)
}
}

func queryWithPositionalParameters(db *sql.DB) {
var p_bool bool
var p_int int
var p_double float64
var p_float float32
var p_date string

err := db.QueryRow(`
SELECT
? AS col_bool,
? AS col_int,
? AS col_double,
? AS col_float,
? AS col_date
`,
true,
int(1234),
"3.14",
dbsql.Parameter{Type: dbsql.SqlFloat, Value: "3.14"},
dbsql.Parameter{Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"},
).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)

if err != nil {
if err == sql.ErrNoRows {
fmt.Println("not found")
return
} else {
fmt.Printf("err: %v\n", err)
}
} else {
fmt.Println(p_bool, p_int, p_double, p_float, p_date)
}
}

func main() {
// Opening a driver typically will not attempt to connect to the database.
err := godotenv.Load()
Expand All @@ -36,33 +103,7 @@ func main() {
}
db := sql.OpenDB(connector)
defer db.Close()
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
// defer cancel()
ctx := context.Background()
var p_bool bool
var p_int int
var p_double float64
var p_float float32
var p_date string
err1 := db.QueryRowContext(ctx, `SELECT
:p_bool AS col_bool,
:p_int AS col_int,
:p_double AS col_double,
:p_float AS col_float,
:p_date AS col_date`,
dbsql.Parameter{Name: "p_bool", Value: true},
dbsql.Parameter{Name: "p_int", Value: int(1234)},
dbsql.Parameter{Name: "p_double", Type: dbsql.SqlDouble, Value: "3.14"},
dbsql.Parameter{Name: "p_float", Type: dbsql.SqlFloat, Value: "3.14"},
dbsql.Parameter{Name: "p_date", Type: dbsql.SqlDate, Value: "2017-07-23 00:00:00"}).Scan(&p_bool, &p_int, &p_double, &p_float, &p_date)

if err1 != nil {
if err1 == sql.ErrNoRows {
fmt.Println("not found")
return
} else {
fmt.Printf("err: %v\n", err1)
}
}

queryWithNamedParameters(db)
queryWithPositionalParameters(db)
}
53 changes: 43 additions & 10 deletions parameter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package dbsql

import (
"database/sql/driver"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/stretchr/testify/require"
"strconv"
"testing"
"time"
Expand All @@ -21,7 +23,7 @@ func TestParameter_Inference(t *testing.T) {
{Name: "", Value: nil},
{Name: "", Value: Parameter{Value: float64Ptr(6.2), Type: SqlUnkown}},
}
parameters := convertNamedValuesToSparkParams(values[:])
parameters, _ := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, strconv.FormatFloat(float64(5.1), 'f', -1, 64), *parameters[0].Value.StringValue)
assert.NotNil(t, parameters[1].Value.StringValue)
assert.Equal(t, string("TIMESTAMP"), *parameters[1].Type)
Expand All @@ -34,14 +36,45 @@ func TestParameter_Inference(t *testing.T) {
assert.Equal(t, &cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, parameters[6].Value)
})
}
func TestParameters_Names(t *testing.T) {
t.Run("Should infer types correctly", func(t *testing.T) {
values := [2]driver.NamedValue{{Name: "1", Value: int(26)}, {Name: "", Value: Parameter{Name: "2", Type: SqlDecimal, Value: "6.2"}}}
parameters := convertNamedValuesToSparkParams(values[:])
assert.Equal(t, string("1"), *parameters[0].Name)
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
assert.Equal(t, string("2"), *parameters[1].Name)
assert.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
assert.Equal(t, string("DECIMAL(2,1)"), *parameters[1].Type)

func TestParameters_ConvertToSpark(t *testing.T) {
t.Run("Should convert names parameters", func(t *testing.T) {
values := [2]driver.NamedValue{
{Name: "1", Value: int(26)},
{Name: "", Value: Parameter{Name: "2", Type: SqlDecimal, Value: "6.2"}},
}
parameters, err := convertNamedValuesToSparkParams(values[:])
require.NoError(t, err)
require.Equal(t, string("1"), *parameters[0].Name)
require.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
require.Equal(t, string("2"), *parameters[1].Name)
require.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
require.Equal(t, string("DECIMAL(2,1)"), *parameters[1].Type)
})

t.Run("Should convert positional parameters", func(t *testing.T) {
values := [2]driver.NamedValue{
{Value: int(26)},
{Name: "", Value: Parameter{Type: SqlDecimal, Value: "6.2"}},
}
parameters, err := convertNamedValuesToSparkParams(values[:])
require.NoError(t, err)
require.Nil(t, parameters[0].Name)
require.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("26")}, *parameters[0].Value)
require.Nil(t, parameters[1].Name)
require.Equal(t, cli_service.TSparkParameterValue{StringValue: strPtr("6.2")}, *parameters[1].Value)
require.Equal(t, string("DECIMAL(2,1)"), *parameters[1].Type)
})

t.Run("Should error out when named and positional parameters are mixed", func(t *testing.T) {
values := [4]driver.NamedValue{
{Name: "a", Value: int(26)},
{Name: "", Value: Parameter{Type: SqlDecimal, Value: "6.2"}},
{Value: "test"},
{Name: "b", Value: Parameter{Type: SqlDouble, Value: 123.456}},
}
_, err := convertNamedValuesToSparkParams(values[:])
require.Error(t, err)
require.Equal(t, err.Error(), dbsqlerr.ErrMixedNamedAndPositionalParameters)
})
}
26 changes: 23 additions & 3 deletions parameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"strings"
"time"

dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/pkg/errors"
)

type Parameter struct {
Expand Down Expand Up @@ -162,10 +164,14 @@ func inferType(param *Parameter) {
}
}

func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.TSparkParameter {
func convertNamedValuesToSparkParams(values []driver.NamedValue) ([]*cli_service.TSparkParameter, error) {
var sparkParams []*cli_service.TSparkParameter

sqlParams := valuesToParameters(values)

hasNamedParams := false
hasPositionalParams := false

inferTypes(sqlParams)
for i := range sqlParams {
sqlParam := sqlParams[i]
Expand All @@ -183,10 +189,24 @@ func convertNamedValuesToSparkParams(values []driver.NamedValue) []*cli_service.
} else {
sparkParamType = sqlParam.Type.String()
}
sparkParam := cli_service.TSparkParameter{Name: &sqlParam.Name, Type: &sparkParamType, Value: sparkValue}

var sparkParamName *string
if sqlParam.Name != "" {
sparkParamName = &sqlParam.Name
hasNamedParams = true
} else {
sparkParamName = nil
hasPositionalParams = true
}

if hasNamedParams && hasPositionalParams {
return nil, errors.New(dbsqlerr.ErrMixedNamedAndPositionalParameters)
}

sparkParam := cli_service.TSparkParameter{Name: sparkParamName, Type: &sparkParamType, Value: sparkValue}
sparkParams = append(sparkParams, &sparkParam)
}
return sparkParams
return sparkParams, nil
}

func inferDecimalType(d string) (t string) {
Expand Down

0 comments on commit 1e9d6ac

Please # to comment.