Skip to content

Commit b3a8513

Browse files
fix: support nullable types for bulkcopy (#192)
* fix: support nullable types for bulkcopy * Add test cases for all nullable types * Fix test cases * Add bulkcopy test for invalid nullable types * Add case in convertInputParameter to bypass uniqueidentifier type * Add test cases for invalid nullable test * Revert bypass change
1 parent ada30cb commit b3a8513

File tree

2 files changed

+151
-0
lines changed

2 files changed

+151
-0
lines changed

bulkcopy.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package mssql
33
import (
44
"bytes"
55
"context"
6+
"database/sql/driver"
67
"encoding/binary"
78
"fmt"
89
"math"
@@ -318,6 +319,19 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error)
318319
res.ti.Size = col.ti.Size
319320
res.ti.TypeId = col.ti.TypeId
320321

322+
switch valuer := val.(type) {
323+
case driver.Valuer:
324+
var e error
325+
val, e = driver.DefaultParameterConverter.ConvertValue(valuer)
326+
if e != nil {
327+
err = e
328+
return
329+
}
330+
if val != nil {
331+
return b.makeParam(val, col)
332+
}
333+
}
334+
321335
if val == nil {
322336
res.ti.Size = 0
323337
return

bulkcopy_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,103 @@ import (
1414
"time"
1515
)
1616

17+
func TestBulkcopyWithInvalidNullableType(t *testing.T) {
18+
// Arrange
19+
tableName := "#table_test"
20+
columns := []string{
21+
"test_nullfloat",
22+
"test_nullstring",
23+
"test_nullbyte",
24+
"test_nullbool",
25+
"test_nullint64",
26+
"test_nullint32",
27+
"test_nullint16",
28+
"test_nulltime",
29+
"test_nulluniqueidentifier",
30+
}
31+
values := []interface{}{
32+
sql.NullFloat64{Valid: false},
33+
sql.NullString{Valid: false},
34+
sql.NullByte{Valid: false},
35+
sql.NullBool{Valid: false},
36+
sql.NullInt64{Valid: false},
37+
sql.NullInt32{Valid: false},
38+
sql.NullInt16{Valid: false},
39+
sql.NullTime{Valid: false},
40+
NullUniqueIdentifier{Valid: false},
41+
}
42+
43+
pool, logger := open(t)
44+
defer pool.Close()
45+
defer logger.StopLogging()
46+
47+
ctx, cancel := context.WithCancel(context.Background())
48+
defer cancel()
49+
50+
conn, err := pool.Conn(ctx)
51+
if err != nil {
52+
t.Fatal("failed to pull connection from pool", err)
53+
}
54+
defer conn.Close()
55+
56+
err = setupNullableTypeTable(ctx, t, conn, tableName)
57+
if err != nil {
58+
t.Error("Setup table failed: ", err)
59+
return
60+
}
61+
62+
stmt, err := conn.PrepareContext(ctx, CopyIn(tableName, BulkOptions{}, columns...))
63+
if err != nil {
64+
t.Fatal(err)
65+
}
66+
defer stmt.Close()
67+
68+
_, err = stmt.Exec(values...)
69+
if err != nil {
70+
t.Fatal("AddRow failed: ", err.Error())
71+
}
72+
73+
result, err := stmt.Exec()
74+
if err != nil {
75+
t.Fatal("bulkcopy failed: ", err.Error())
76+
}
77+
78+
insertedRowCount, _ := result.RowsAffected()
79+
if insertedRowCount == 0 {
80+
t.Fatal("0 row inserted!")
81+
}
82+
83+
//data verification
84+
rows, err := conn.QueryContext(ctx, "select "+strings.Join(columns, ",")+" from "+tableName)
85+
if err != nil {
86+
t.Fatal(err)
87+
}
88+
defer rows.Close()
89+
for rows.Next() {
90+
91+
ptrs := make([]interface{}, len(columns))
92+
container := make([]interface{}, len(columns))
93+
for i := range ptrs {
94+
ptrs[i] = &container[i]
95+
}
96+
if err := rows.Scan(ptrs...); err != nil {
97+
t.Fatal(err)
98+
}
99+
for i, c := range columns {
100+
if !compareValue(container[i], nil) {
101+
v := container[i]
102+
if s, ok := v.([]uint8); ok {
103+
v = string(s)
104+
}
105+
t.Errorf("columns %s : expected: %T %v, got: %T %v\n", c, nil, nil, container[i], v)
106+
}
107+
}
108+
}
109+
if err := rows.Err(); err != nil {
110+
t.Error(err)
111+
}
112+
}
113+
17114
func TestBulkcopy(t *testing.T) {
18115
// TDS level Bulk Insert is not supported on Azure SQL Server.
19116
if dsn := makeConnStr(t); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
@@ -69,6 +166,14 @@ func TestBulkcopy(t *testing.T) {
69166
{"test_geom", geom, string(geom)},
70167
{"test_uniqueidentifier", uid, string(uid)},
71168
{"test_nulluniqueidentifier", nil, nil},
169+
{"test_nullfloat", sql.NullFloat64{64, true}, 64.0},
170+
{"test_nullstring", sql.NullString{"abcdefg", true}, "abcdefg"},
171+
{"test_nullbyte", sql.NullByte{0x01, true}, 1},
172+
{"test_nullbool", sql.NullBool{true, true}, true},
173+
{"test_nullint64", sql.NullInt64{9223372036854775807, true}, 9223372036854775807},
174+
{"test_nullint32", sql.NullInt32{2147483647, true}, 2147483647},
175+
{"test_nullint16", sql.NullInt16{32767, true}, 32767},
176+
{"test_nulltime", sql.NullTime{time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC), true}, time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)},
72177
// {"test_smallmoney", 1234.56, nil},
73178
// {"test_money", 1234.56, nil},
74179
{"test_decimal_18_0", 1234.0001, "1234"},
@@ -223,6 +328,30 @@ func compareValue(a interface{}, expected interface{}) bool {
223328
}
224329
}
225330

331+
func setupNullableTypeTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName string) (err error) {
332+
tablesql := `CREATE TABLE ` + tableName + ` (
333+
[id] [int] IDENTITY(1,1) NOT NULL,
334+
[test_nullfloat] [float] NULL,
335+
[test_nullstring] [nvarchar](50) NULL,
336+
[test_nullbyte] [tinyint] NULL,
337+
[test_nullbool] [bit] NULL,
338+
[test_nullint64] [bigint] NULL,
339+
[test_nullint32] [int] NULL,
340+
[test_nullint16] [smallint] NULL,
341+
[test_nulltime] [datetime] NULL,
342+
[test_nulluniqueidentifier] [uniqueidentifier] NULL,
343+
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
344+
(
345+
[id] ASC
346+
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
347+
) ON [PRIMARY];`
348+
_, err = conn.ExecContext(ctx, tablesql)
349+
if err != nil {
350+
t.Fatal("tablesql failed:", err)
351+
}
352+
return
353+
}
354+
226355
func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName string) (err error) {
227356
tablesql := `CREATE TABLE ` + tableName + ` (
228357
[id] [int] IDENTITY(1,1) NOT NULL,
@@ -290,6 +419,14 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str
290419
[test_int16nvarchar] [varchar](4) NULL,
291420
[test_int8nvarchar] [varchar](3) NULL,
292421
[test_intnvarchar] [varchar](4) NULL,
422+
[test_nullfloat] [float] NULL,
423+
[test_nullstring] [nvarchar](50) NULL,
424+
[test_nullbyte] [tinyint] NULL,
425+
[test_nullbool] [bit] NULL,
426+
[test_nullint64] [bigint] NULL,
427+
[test_nullint32] [int] NULL,
428+
[test_nullint16] [smallint] NULL,
429+
[test_nulltime] [datetime] NULL,
293430
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
294431
(
295432
[id] ASC

0 commit comments

Comments
 (0)