diff --git a/go/adbc/driver/snowflake/connection.go b/go/adbc/driver/snowflake/connection.go index 5abf90476c..973fdaa291 100644 --- a/go/adbc/driver/snowflake/connection.go +++ b/go/adbc/driver/snowflake/connection.go @@ -314,22 +314,22 @@ func (c *connectionImpl) GetObjects(ctx context.Context, depth adbc.ObjectDepth, // just as done by the snowflake JDBC driver. In those cases we don't need to propagate // the current session database/schema. if depth == adbc.ObjectDepthColumns || depth == adbc.ObjectDepthAll { - // the connection that is used is not the same connection context where the database may have been set - // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate - if !isNilOrEmpty(catalog) { - _, e := conn.ExecContext(context.Background(), fmt.Sprintf("USE DATABASE %s;", quoteTblName(*catalog)), nil) - if e != nil { - return nil, errToAdbcErr(adbc.StatusIO, e) - } + dbname, err := c.GetCurrentCatalog() + if err != nil { + return nil, errToAdbcErr(adbc.StatusIO, err) } - // the connection that is used is not the same connection context where the schema may have been set - // if the caller called SetCurrentDbSchema() so need to ensure the schema context is appropriate - if !isNilOrEmpty(dbSchema) { - _, e2 := conn.ExecContext(context.Background(), fmt.Sprintf("USE SCHEMA %s;", quoteTblName(*dbSchema)), nil) - if e2 != nil { - return nil, errToAdbcErr(adbc.StatusIO, e2) - } + schemaname, err := c.GetCurrentDbSchema() + if err != nil { + return nil, errToAdbcErr(adbc.StatusIO, err) + } + + // the connection that is used is not the same connection context where the database may have been set + // if the caller called SetCurrentCatalog() so need to ensure the database context is appropriate + multiCtx, _ := gosnowflake.WithMultiStatement(ctx, 2) + _, err = conn.ExecContext(multiCtx, fmt.Sprintf("USE DATABASE %s; USE SCHEMA %s;", quoteTblName(dbname), quoteTblName(schemaname))) + if err != nil { + return nil, errToAdbcErr(adbc.StatusIO, err) } } diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index 895015ffd7..c67389ca14 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -1215,15 +1215,15 @@ func (suite *SnowflakeTests) TestSqlIngestMapType() { [ { "col_int64": 1, - "col_map": "{\n \"key_value\": [\n {\n \"key\": \"key1\",\n \"value\": 1\n }\n ]\n}" + "col_map": "{\n \"key1\": 1\n}" }, { "col_int64": 2, - "col_map": "{\n \"key_value\": [\n {\n \"key\": \"key2\",\n \"value\": 2\n }\n ]\n}" + "col_map": "{\n \"key2\": 2\n}" }, { "col_int64": 3, - "col_map": "{\n \"key_value\": [\n {\n \"key\": \"key3\",\n \"value\": 3\n }\n ]\n}" + "col_map": "{\n \"key3\": 3\n}" } ] `))) @@ -2161,6 +2161,9 @@ func (suite *SnowflakeTests) TestGetSetClientConfigFile() { func (suite *SnowflakeTests) TestGetObjectsWithNilCatalog() { // this test demonstrates calling GetObjects with the catalog depth and a nil catalog - _, err := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthCatalogs, nil, nil, nil, nil, nil) + rdr, err := suite.cnxn.GetObjects(suite.ctx, adbc.ObjectDepthCatalogs, nil, nil, nil, nil, nil) suite.NoError(err) + // test suite validates memory allocator so we need to make sure we call + // release on the result reader + rdr.Release() }