Skip to content

Commit 9372c38

Browse files
authored
feat(interceptor): add BegixTx callback (#28)
1 parent 4bfec04 commit 9372c38

File tree

4 files changed

+53
-14
lines changed

4 files changed

+53
-14
lines changed

Makefile

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,37 @@
11
.POSIX:
22
.SUFFIXES:
33

4-
fmt:
5-
@golangci-lint fmt
4+
bench:
5+
@go test -run='^$$' -bench=. -cpuprofile=profile.cpu -memprofile=profile.mem
66

7-
gen:
8-
@go generate ./...
7+
clean:
8+
@rm -rf tests/coverdata tests/coverage.out tests/test.sqlite
99

1010
deps:
1111
@go mod tidy
1212
@cd tests && go mod tidy
1313

14+
fmt:
15+
@golangci-lint fmt
16+
17+
gen:
18+
@go generate ./...
19+
1420
lint:
1521
@golangci-lint run
1622

17-
test:
18-
@rm -rf tests/coverdata tests/coverage.out tests/test.sqlite && mkdir tests/coverdata
23+
test: test/unit test/integration
24+
25+
test/unit: clean
26+
@mkdir -p tests/coverdata
1927
@go test -race -shuffle=on -cover . -args -test.gocoverdir=$$PWD/tests/coverdata
28+
29+
test/integration: clean
30+
@mkdir -p tests/coverdata
2031
@$(CONTAINER_RUNNER) compose --file=tests/compose.yaml up --detach
2132
@go test -v -race -coverpkg=go-simpler.org/queries ./tests -args -test.gocoverdir=$$PWD/tests/coverdata
2233
@$(CONTAINER_RUNNER) compose --file=tests/compose.yaml down
2334
@go tool covdata textfmt -i=tests/coverdata -o=tests/coverage.out
2435

2536
test/cover: test
2637
@go tool cover -html=tests/coverage.out
27-
28-
bench:
29-
@go test -run='^$$' -bench=. -cpuprofile=profile.cpu -memprofile=profile.mem

interceptor.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ var (
2626
// Otherwise, it prepares a [driver.Stmt] using [driver.ConnPrepareContext], executes it, and closes it.
2727
// In such cases, you may want to implement both the PrepareContext and ExecContext/QueryContext callbacks,
2828
// even if you don't prepare statements manually via [sql.DB.PrepareContext].
29-
// TODO: provide an example of such an implementation.
3029
//
3130
// [go-sql-driver/mysql]: https://github.com/go-sql-driver/mysql
3231
type Interceptor struct {
@@ -49,6 +48,10 @@ type Interceptor struct {
4948
// PrepareContext is a callback for [sql.DB.PrepareContext] and [sql.Tx.PrepareContext].
5049
// The implementation must call preparer.ConnPrepareContext(ctx, query) and return the result.
5150
PrepareContext func(ctx context.Context, query string, preparer driver.ConnPrepareContext) (driver.Stmt, error)
51+
52+
// BeginTx is a callback for [sql.DB.BeginTx].
53+
// The implementation must call beginner.BeginTx(ctx, opts) and return the result.
54+
BeginTx func(ctx context.Context, opts driver.TxOptions, beginner driver.ConnBeginTx) (driver.Tx, error)
5255
}
5356

5457
// Open implements [driver.Driver].
@@ -134,6 +137,9 @@ func (c wrappedConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver
134137
if !ok {
135138
panic("queries: driver does not implement driver.ConnBeginTx")
136139
}
140+
if c.interceptor.BeginTx != nil {
141+
return c.interceptor.BeginTx(ctx, opts, beginner)
142+
}
137143
return beginner.BeginTx(ctx, opts)
138144
}
139145

interceptor_test.go

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ func TestInterceptor(t *testing.T) {
1818
var execCalled bool
1919
var queryCalled bool
2020
var prepareCalled bool
21+
var beginTxCalled bool
2122

2223
interceptor := queries.Interceptor{
2324
Driver: mockDriver{conn: spyConn{}},
@@ -33,9 +34,13 @@ func TestInterceptor(t *testing.T) {
3334
prepareCalled = true
3435
return preparer.PrepareContext(ctx, query)
3536
},
37+
BeginTx: func(ctx context.Context, opts driver.TxOptions, beginner driver.ConnBeginTx) (driver.Tx, error) {
38+
beginTxCalled = true
39+
return beginner.BeginTx(ctx, opts)
40+
},
3641
}
3742

38-
driverName := t.Name() + "_interceptor"
43+
driverName := t.Name()
3944
sql.Register(driverName, interceptor)
4045

4146
db, err := sql.Open(driverName, "")
@@ -53,6 +58,10 @@ func TestInterceptor(t *testing.T) {
5358
_, err = db.PrepareContext(ctx, "")
5459
assert.IsErr[E](t, err, errCalled)
5560
assert.Equal[E](t, prepareCalled, true)
61+
62+
_, err = db.BeginTx(ctx, nil)
63+
assert.IsErr[E](t, err, errCalled)
64+
assert.Equal[E](t, beginTxCalled, true)
5665
}
5766

5867
func TestInterceptor_passthrough(t *testing.T) {
@@ -62,7 +71,7 @@ func TestInterceptor_passthrough(t *testing.T) {
6271
Driver: mockDriver{conn: spyConn{}},
6372
}
6473

65-
driverName := t.Name() + "_interceptor"
74+
driverName := t.Name()
6675
sql.Register(driverName, interceptor)
6776

6877
db, err := sql.Open(driverName, "")
@@ -77,6 +86,9 @@ func TestInterceptor_passthrough(t *testing.T) {
7786

7887
_, err = db.PrepareContext(ctx, "")
7988
assert.IsErr[E](t, err, errCalled)
89+
90+
_, err = db.BeginTx(ctx, nil)
91+
assert.IsErr[E](t, err, errCalled)
8092
}
8193

8294
func TestInterceptor_unimplemented(t *testing.T) {
@@ -86,7 +98,7 @@ func TestInterceptor_unimplemented(t *testing.T) {
8698
Driver: mockDriver{conn: unimplementedConn{}},
8799
}
88100

89-
driverName := t.Name() + "_interceptor"
101+
driverName := t.Name()
90102
sql.Register(driverName, interceptor)
91103

92104
db, err := sql.Open(driverName, "")
@@ -113,7 +125,7 @@ func TestInterceptor_driver(t *testing.T) {
113125
mdriver := mockDriver{}
114126
interceptor := queries.Interceptor{Driver: mdriver}
115127

116-
driverName := t.Name() + "_interceptor"
128+
driverName := t.Name()
117129
sql.Register(driverName, interceptor)
118130

119131
db, err := sql.Open(driverName, "")
@@ -148,3 +160,7 @@ func (spyConn) QueryContext(context.Context, string, []driver.NamedValue) (drive
148160
func (spyConn) PrepareContext(context.Context, string) (driver.Stmt, error) {
149161
return nil, errCalled
150162
}
163+
164+
func (spyConn) BeginTx(context.Context, driver.TxOptions) (driver.Tx, error) {
165+
return nil, errCalled
166+
}

tests/integration_test.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ func TestIntegration(t *testing.T) {
9999
var execCalls int
100100
var queryCalls int
101101
var prepareCalls int
102+
var beginTxCalls int
102103

103104
interceptor := queries.Interceptor{
104105
Driver: driverIface,
@@ -117,6 +118,11 @@ func TestIntegration(t *testing.T) {
117118
t.Logf("PrepareContext: %s", query)
118119
return preparer.PrepareContext(ctx, query)
119120
},
121+
BeginTx: func(ctx context.Context, opts driver.TxOptions, beginner driver.ConnBeginTx) (driver.Tx, error) {
122+
beginTxCalls++
123+
t.Log("BeginTx")
124+
return beginner.BeginTx(ctx, opts)
125+
},
120126
}
121127

122128
driverName += "+interceptor"
@@ -187,14 +193,17 @@ func TestIntegration(t *testing.T) {
187193
assert.Equal[E](t, execCalls, 3)
188194
assert.Equal[E](t, queryCalls, 5*2)
189195
assert.Equal[E](t, prepareCalls, 1)
196+
assert.Equal[E](t, beginTxCalls, 1)
190197
case *mssqldb.Driver: // always uses PrepareContext.
191198
assert.Equal[E](t, execCalls, 0)
192199
assert.Equal[E](t, queryCalls, 0)
193200
assert.Equal[E](t, prepareCalls, 3+5*2)
201+
assert.Equal[E](t, beginTxCalls, 1)
194202
default:
195203
assert.Equal[E](t, execCalls, 3)
196204
assert.Equal[E](t, queryCalls, 5*2)
197205
assert.Equal[E](t, prepareCalls, 0)
206+
assert.Equal[E](t, beginTxCalls, 1)
198207
}
199208
})
200209
}

0 commit comments

Comments
 (0)