Skip to content

Commit a441ab1

Browse files
shueybubblesadrianosela
authored andcommitted
Fix: Support nullable types in Always Encrypted (microsoft#179)
* preserve type information for Valuer parameters * support uniqueidentifier in AE * update readme
1 parent dea2c61 commit a441ab1

File tree

6 files changed

+90
-6
lines changed

6 files changed

+90
-6
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,8 @@ If the correct key provider is included in your application, decryption of encry
427427

428428
Encryption of parameters passed to `Exec` and `Query` variants requires an extra round trip per query to fetch the encryption metadata. If the error returned by a query attempt indicates a type mismatch between the parameter and the destination table, most likely your input type is not a strict match for the SQL Server data type of the destination. You may be using a Go `string` when you need to use one of the driver-specific aliases like `VarChar` or `NVarCharMax`.
429429

430-
*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values. Also, using a nullable sql package type like `sql.NullableInt32` to pass a `NULL` value for an encrypted column will not work unless the encrypted column type is `nvarchar`.
430+
*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values.
431431
https://github.com/microsoft/go-mssqldb/issues/129
432-
https://github.com/microsoft/go-mssqldb/issues/130
433432
434433
435434
### Local certificate AE key provider

alwaysencrypted_test.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"crypto/rand"
66
"database/sql"
7+
"database/sql/driver"
78
"fmt"
89
"math/big"
910
"strings"
@@ -65,8 +66,10 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
6566
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionRandomized, dt},
6667
{"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)},
6768
{"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")},
68-
// TODO: The driver throws away type information about Valuer implementations and sends nil as nvarchar(1). Fix that.
69-
// {"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
69+
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
70+
{"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
71+
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}},
72+
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}},
7073
}
7174
for _, test := range providerTests {
7275
// turn off key caching
@@ -230,13 +233,19 @@ func comparisonValueFromObject(object interface{}) string {
230233
case time.Time:
231234
return civil.DateTimeOf(v).String()
232235
//return v.Format(time.RFC3339)
233-
case fmt.Stringer:
234-
return v.String()
235236
case bool:
236237
if v == true {
237238
return "1"
238239
}
239240
return "0"
241+
case driver.Valuer:
242+
val, _ := v.Value()
243+
if val == nil {
244+
return "<nil>"
245+
}
246+
return comparisonValueFromObject(val)
247+
case fmt.Stringer:
248+
return v.String()
240249
default:
241250
return fmt.Sprintf("%v", v)
242251
}

mssql.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -982,7 +982,37 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
982982
res.ti.Size = 0
983983
return
984984
}
985+
switch valuer := val.(type) {
986+
case UniqueIdentifier:
987+
case NullUniqueIdentifier:
988+
default:
989+
break
990+
case driver.Valuer:
991+
// If the value has a non-nil value, call MakeParam on its Value
992+
val, e := driver.DefaultParameterConverter.ConvertValue(valuer)
993+
if e != nil {
994+
err = e
995+
return
996+
}
997+
if val != nil {
998+
return s.makeParam(val)
999+
}
1000+
}
9851001
switch val := val.(type) {
1002+
case UniqueIdentifier:
1003+
res.ti.TypeId = typeGuid
1004+
res.ti.Size = 16
1005+
guid, _ := val.Value()
1006+
res.buffer = guid.([]byte)
1007+
case NullUniqueIdentifier:
1008+
res.ti.TypeId = typeGuid
1009+
res.ti.Size = 16
1010+
if val.Valid {
1011+
guid, _ := val.Value()
1012+
res.buffer = guid.([]byte)
1013+
} else {
1014+
res.buffer = []byte{}
1015+
}
9861016
case int:
9871017
res.ti.TypeId = typeIntN
9881018
// Rather than guess if the caller intends to pass a 32bit int from a 64bit app based on the
@@ -1021,6 +1051,10 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
10211051
res.ti.TypeId = typeIntN
10221052
res.ti.Size = 8
10231053
res.buffer = []byte{}
1054+
case sql.NullInt32:
1055+
res.ti.TypeId = typeIntN
1056+
res.ti.Size = 4
1057+
res.buffer = []byte{}
10241058
case byte:
10251059
res.ti.TypeId = typeIntN
10261060
res.buffer = []byte{val}

mssql_go19.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ func convertInputParameter(val interface{}) (interface{}, error) {
7575
// return nil
7676
case float32:
7777
return val, nil
78+
case driver.Valuer:
79+
return val, nil
7880
default:
7981
return driver.DefaultParameterConverter.ConvertValue(v)
8082
}

queries_go19_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"fmt"
1111
"reflect"
1212
"regexp"
13+
"strings"
1314
"testing"
1415
"time"
1516

@@ -31,6 +32,32 @@ func TestOutputParam(t *testing.T) {
3132
ctx, cancel := context.WithCancel(context.Background())
3233
defer cancel()
3334

35+
t.Run("varchar(max) to sql.NullString", func(t *testing.T) {
36+
sqltextcreate := `CREATE PROCEDURE [GetTask]
37+
@strparam varchar(max) = NULL OUTPUT
38+
AS
39+
SELECT @strparam = REPLICATE('a', 8000)
40+
RETURN 0`
41+
sqltextdrop := `drop procedure GetTask`
42+
sqltextrun := `GetTask`
43+
_, _ = db.ExecContext(ctx, sqltextdrop)
44+
_, err = db.ExecContext(ctx, sqltextcreate)
45+
if err != nil {
46+
t.Fatal(err)
47+
}
48+
defer db.ExecContext(ctx, sqltextdrop)
49+
nullstr := sql.NullString{}
50+
_, err := db.ExecContext(ctx, sqltextrun,
51+
sql.Named("strparam", sql.Out{Dest: &nullstr}),
52+
)
53+
if err != nil {
54+
t.Error(err)
55+
}
56+
defer db.ExecContext(ctx, sqltextdrop)
57+
if nullstr.String != strings.Repeat("a", 8000) {
58+
t.Error("Got incorrect NullString of length:", len(nullstr.String))
59+
}
60+
})
3461
t.Run("sp with rows", func(t *testing.T) {
3562
sqltextcreate := `
3663
CREATE PROCEDURE spwithrows

queries_test.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,19 @@ func TestSelect(t *testing.T) {
198198
}
199199
})
200200
})
201+
t.Run("scan into sql.NullString", func(t *testing.T) {
202+
row := conn.QueryRow("SELECT REPLICATE('a', 8000)")
203+
var out sql.NullString
204+
err := row.Scan(&out)
205+
if err != nil {
206+
t.Error("Scan to NullString failed", err.Error())
207+
return
208+
}
209+
210+
if out.String != strings.Repeat("a", 8000) {
211+
t.Error("got back a string with count:", len(out.String))
212+
}
213+
})
201214
}
202215

203216
func TestSelectDateTimeOffset(t *testing.T) {

0 commit comments

Comments
 (0)