Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BulkLoad: Fix some types #239

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 76 additions & 3 deletions bulkcopy.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,13 @@ func (b *Bulk) sendBulkCommand(ctx context.Context) (err error) {
}
}
if bulkCol != nil {
// Note that for INSERT BULK operations, XMLTYPE is to be sent as NVARCHAR(N) or NVARCHAR(MAX) data type.
// An error is produced if XMLTYPE is specified.
//
// https://learn.microsoft.com/openspecs/windows_protocols/ms-tds/ab4a7d62-cd1f-4db1-b67d-ecae58f493e3
if bulkCol.ti.TypeId == typeXml {
bulkCol.ti.TypeId = typeNVarChar
}

if bulkCol.ti.TypeId == typeUdt {
//send udt as binary
Expand Down Expand Up @@ -535,7 +542,74 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
err = fmt.Errorf("mssql: invalid type for time column: %T %s", val, val)
return
}
// case typeMoney, typeMoney4, typeMoneyN:
case typeMoney, typeMoney4, typeMoneyN:
var intvalue int64

// string2Int64 converts a decimal money literal such as "1234.56" into the
// scaled integer used on the wire (value * 10^4), e.g. "1234.56" -> 12345600.
//
// The fractional part is right-padded with zeros to exactly four digits.
// Fractional digits beyond the fourth are truncated, not rounded.
//
// An error is returned when the input contains more than one decimal point,
// when the truncated fractional tail contains anything other than ASCII
// digits, or when the resulting digit string is rejected by strconv.ParseInt.
string2Int64 := func(str string) (int64, error) {
	// Split on the decimal point; more than one '.' is malformed.
	parts := strings.Split(str, ".")
	if len(parts) > 2 {
		return 0, fmt.Errorf("invalid money format")
	}

	// No decimal point: scale by 10^4 by appending four zeros.
	if len(parts) == 1 {
		return strconv.ParseInt(str+"0000", 10, 64)
	}

	decimal := parts[1]
	if len(decimal) > 4 {
		// The tail past the fourth digit is dropped before ParseInt sees it,
		// so it has to be validated here; otherwise input like "1.2345abc"
		// would be silently accepted as 1.2345.
		for _, c := range decimal[4:] {
			if c < '0' || c > '9' {
				return 0, fmt.Errorf("invalid money format")
			}
		}
		decimal = decimal[:4] // truncate to 4 decimal places
	} else {
		decimal += strings.Repeat("0", 4-len(decimal)) // pad with zeros
	}

	// ParseInt validates the sign, the integer part and the kept fraction.
	return strconv.ParseInt(parts[0]+decimal, 10, 64)
}

switch val := val.(type) {
case int:
intvalue = int64(val)
case int64:
intvalue = val
case []byte:
intvalue, err = string2Int64(string(val))
if err != nil {
return res, fmt.Errorf("mssql: invalid money string format: %s", string(val))
}
case string:
intvalue, err = string2Int64(val)
if err != nil {
return res, fmt.Errorf("mssql: invalid money string format: %s", val)
}
default:
err = fmt.Errorf("mssql: invalid type for money column: %T %s", val, val)
return
}

res.buffer = make([]byte, res.ti.Size)

// smallmoney is a 4-byte integer stored as value * 10^4.
// money is an 8-byte integer stored as value * 10^4.
//
// https://learn.microsoft.com/openspecs/windows_protocols/ms-tds/1266679d-cd6e-492a-b2b2-3a9ba004196d
switch col.ti.Size {
case 4:
binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue))
case 8:
// The 8-byte signed integer is represented in the following sequence:
// - One 4-byte integer that represents the more significant half.
// - One 4-byte integer that represents the less significant half.
//
// https://learn.microsoft.com/openspecs/windows_protocols/ms-tds/1266679d-cd6e-492a-b2b2-3a9ba004196d
binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(intvalue>>32))
binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(intvalue&0xFFFFFFFF))
default:
err = fmt.Errorf("mssql: invalid size of column %d", col.ti.Size)
}
case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
prec := col.ti.Prec
scale := col.ti.Scale
Expand Down Expand Up @@ -599,7 +673,7 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
buf[i] = ub[j]
}
res.buffer = buf
case typeBigVarBin, typeBigBinary:
case typeBigVarBin, typeBigBinary, typeImage:
switch val := val.(type) {
case []byte:
res.ti.Size = len(val)
Expand All @@ -617,7 +691,6 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
err = fmt.Errorf("mssql: invalid type for Guid column: %T %s", val, val)
return
}

default:
err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
}
Expand Down
39 changes: 32 additions & 7 deletions bulkcopy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ func testBulkcopy(t *testing.T, guidConversion bool) {
{"test_nchar", "abcdefg ", nil},
{"test_text", "abcdefg", nil},
{"test_ntext", "abcdefg", nil},
{"test_textn", nil, nil},
{"test_ntextn", nil, nil},
{"test_float", 1234.56, nil},
{"test_floatn", 1234.56, nil},
{"test_real", 1234.56, nil},
Expand Down Expand Up @@ -174,8 +176,10 @@ func testBulkcopy(t *testing.T, guidConversion bool) {
{"test_nullint32", sql.NullInt32{2147483647, true}, 2147483647},
{"test_nullint16", sql.NullInt16{32767, true}, 32767},
{"test_nulltime", sql.NullTime{time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC), true}, time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)},
// {"test_smallmoney", 1234.56, nil},
// {"test_money", 1234.56, nil},
{"test_smallmoney", []byte("1234.5600"), nil},
{"test_smallmoneyn", nil, nil},
{"test_money", []byte("1234.5600"), nil},
{"test_moneyn", nil, nil},
{"test_decimal_18_0", 1234.0001, "1234"},
{"test_decimal_9_2", -1234.560001, "-1234.56"},
{"test_decimal_20_0", 1234, "1234"},
Expand All @@ -187,12 +191,21 @@ func testBulkcopy(t *testing.T, guidConversion bool) {
{"test_varbinary_max", bin, nil},
{"test_binary", []byte("1"), nil},
{"test_binary_16", bin, nil},
{"test_varbinaryn", nil, nil},
{"test_varbinary_16n", nil, nil},
{"test_varbinary_maxn", nil, nil},
{"test_binaryn", nil, nil},
{"test_binary_16n", nil, nil},
{"test_intvarchar", 1234, "1234"},
{"test_int64nvarchar", int64(123456), "123456"},
{"test_int32nvarchar", int32(12345), "12345"},
{"test_int16nvarchar", int16(1234), "1234"},
{"test_int8nvarchar", int8(12), "12"},
{"test_intnvarchar", 1234, "1234"},
{"test_image", []byte("1"), nil},
{"test_imagen", nil, nil},
{"test_xml", "<root><child>value</child></root>", nil},
{"test_xmln", nil, nil},
}

columns := make([]string, len(testValues))
Expand Down Expand Up @@ -373,8 +386,10 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_varchar_max_nil] [varchar](max) NULL,
[test_char] [char](10) NULL,
[test_nchar] [nchar](10) NULL,
[test_text] [text] NULL,
[test_ntext] [ntext] NULL,
[test_text] [text] NOT NULL,
[test_ntext] [ntext] NOT NULL,
[test_textn] [text] NULL,
[test_ntextn] [ntext] NULL,
[test_float] [float] NOT NULL,
[test_floatn] [float] NULL,
[test_real] [real] NULL,
Expand All @@ -394,8 +409,10 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_date_2] [date] NULL,
[test_time] [time](7) NULL,
[test_time_2] [time](7) NULL,
[test_smallmoney] [smallmoney] NULL,
[test_money] [money] NULL,
[test_smallmoney] [smallmoney] NOT NULL,
[test_smallmoneyn] [smallmoney] NULL,
[test_money] [money] NOT NULL,
[test_moneyn] [money] NULL,
[test_tinyint] [tinyint] NULL,
[test_smallint] [smallint] NOT NULL,
[test_smallintn] [smallint] NULL,
Expand All @@ -406,7 +423,8 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_intf32] [int] NULL,
[test_geom] [geometry] NULL,
[test_geog] [geography] NULL,
[text_xml] [xml] NULL,
[test_xml] [xml] NOT NULL,
[test_xmln] [xml] NULL,
[test_uniqueidentifier] [uniqueidentifier] NULL,
[test_nulluniqueidentifier] [uniqueidentifier] NULL,
[test_decimal_18_0] [decimal](18, 0) NULL,
Expand All @@ -421,6 +439,11 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_varbinary_max] VARBINARY(max) NOT NULL,
[test_binary] BINARY NOT NULL,
[test_binary_16] BINARY(16) NOT NULL,
[test_varbinaryn] VARBINARY NULL,
[test_varbinary_16n] VARBINARY(16) NULL,
[test_varbinary_maxn] VARBINARY(max) NULL,
[test_binaryn] BINARY NULL,
[test_binary_16n] BINARY(16) NULL,
[test_intvarchar] [varchar](4) NULL,
[test_int64nvarchar] [varchar](6) NULL,
[test_int32nvarchar] [varchar](5) NULL,
Expand All @@ -435,6 +458,8 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
[test_nullint32] [int] NULL,
[test_nullint16] [smallint] NULL,
[test_nulltime] [datetime] NULL,
[test_image] [image] NOT NULL,
[test_imagen] [image] NULL,
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
(
[id] ASC
Expand Down
29 changes: 27 additions & 2 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,16 @@ func writeVarLen(w io.Writer, ti *typeInfo, out bool, encoding msdsn.EncodeParam
if err = binary.Write(w, binary.LittleEndian, uint32(ti.Size)); err != nil {
return
}
if err = writeCollation(w, ti.Collation); err != nil {
return

// COLLATION occurs only if the type is BIGCHARTYPE, BIGVARCHARTYPE, TEXTTYPE, NTEXTTYPE,
// NCHARTYPE, or NVARCHARTYPE.
//
// https://learn.microsoft.com/openspecs/windows_protocols/ms-tds/cbe9c510-eae6-4b1f-9893-a098944d430a
switch ti.TypeId {
case typeText, typeNText:
if err = writeCollation(w, ti.Collation); err != nil {
return
}
}
ti.Writer = writeLongLenType
default:
Expand Down Expand Up @@ -578,6 +586,21 @@ func readLongLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{}
panic("shoulnd't get here")
}
func writeLongLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
if buf == nil {
// According to the documentation, we MUST NOT specify the text pointer and timestamp when the value is NULL.
Copy link
Collaborator

@shueybubbles shueybubbles Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be worthwhile to have a non-bulkinsert test case for this? #Resolved

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shueybubbles I don’t see any insert tests anywhere…
Besides, this code won’t run in a non-bulk-insert scenario because:

func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
	if val == nil {
		res.ti.TypeId = typeNull
		res.buffer = nil
		res.ti.Size = 0
		return
	}
        ...

And then

func writeTypeInfo(w io.Writer, ti *typeInfo, out bool, encoding msdsn.EncodeParameters) (err error) {
	err = binary.Write(w, binary.LittleEndian, ti.TypeId)
	if err != nil {
		return
	}
	switch ti.TypeId {
	case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4,
		typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8:
		// those are fixed length
		// https://msdn.microsoft.com/en-us/library/dd341171.aspx
		ti.Writer = writeFixedType
	...

//
// https://learn.microsoft.com/openspecs/windows_protocols/ms-tds/3840ef93-3b10-4aca-9fd1-a210b8bb6d0c
//
// However, this approach fails with the error:
// "Expected the text length in data stream for bulk copy of text, ntext, or image data."
//
// But we can insert NULL successfully by setting the text pointer length to zero
// (without writing any additional bytes).
// Since there's no clear way to follow the documentation exactly, let's use this solution.
err = binary.Write(w, binary.LittleEndian, byte(0x00))
return
}

//textptr
err = binary.Write(w, binary.LittleEndian, byte(0x10))
if err != nil {
Expand Down Expand Up @@ -1318,6 +1341,8 @@ func makeDecl(ti typeInfo) string {
return "ntext"
case typeUdt:
return ti.UdtInfo.TypeName
case typeImage:
return "image"
case typeGuid:
return "uniqueidentifier"
case typeTvp:
Expand Down
Loading