Skip to content

Commit a64b5da

Browse files
luky116springliao
authored andcommitted
optimize datasource usage of driver (#207)
optimize: optimize datasource usage of driver
1 parent 24e68b8 commit a64b5da

6 files changed

Lines changed: 186 additions & 175 deletions

File tree

pkg/datasource/sql/at.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import (
3636
)
3737

3838
func init() {
39-
datasource.RegisterResourceManager(branch.BranchTypeAT, &ATSourceManager{basic: &datasource.BasicSourceManager{}})
39+
datasource.RegisterResourceManager(branch.BranchTypeAT, &ATSourceManager{basic: datasource.NewBasicSourceManager()})
4040
}
4141

4242
type ATSourceManager struct {
@@ -267,7 +267,7 @@ func (w *asyncATWorker) branchCommit(ctx context.Context, req message.BranchComm
267267

268268
select {
269269
case w.commitQueue <- phaseCtx:
270-
case <- ctx.Done():
270+
case <-ctx.Done():
271271
}
272272

273273
return

pkg/datasource/sql/datasource/datasource_manager.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ type BasicSourceManager struct {
9797
tableMetaCache map[string]*entry
9898
}
9999

100+
func NewBasicSourceManager() *BasicSourceManager {
101+
return &BasicSourceManager{
102+
tableMetaCache: make(map[string]*entry, 0),
103+
}
104+
}
105+
100106
// Commit a branch transaction
101107
// TODO wait finish
102108
func (dm *BasicSourceManager) BranchCommit(ctx context.Context, req message.BranchCommitRequest) (branch.BranchStatus, error) {

pkg/datasource/sql/driver.go

Lines changed: 174 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,186 @@
1818
package sql
1919

2020
import (
21+
"context"
22+
"database/sql"
2123
"database/sql/driver"
24+
"errors"
25+
"fmt"
26+
"reflect"
27+
"strings"
28+
"unsafe"
29+
30+
"github.com/seata/seata-go/pkg/common/log"
31+
32+
"github.com/go-sql-driver/mysql"
33+
"github.com/seata/seata-go-datasource/sql/datasource"
34+
"github.com/seata/seata-go-datasource/sql/types"
35+
"github.com/seata/seata-go/pkg/protocol/branch"
36+
)
37+
38+
const (
39+
SeataMySQLDriver = "seata-mysql"
2240
)
2341

42+
func init() {
43+
sql.Register(SeataMySQLDriver, &SeataDriver{
44+
target: mysql.MySQLDriver{},
45+
})
46+
}
47+
2448
type SeataDriver struct {
2549
target driver.Driver
2650
}
2751

2852
func (d *SeataDriver) Open(name string) (driver.Conn, error) {
29-
return d.target.Open(name)
53+
conn, err := d.target.Open(name)
54+
if err != nil {
55+
log.Errorf("open connection: %w", err)
56+
return nil, err
57+
}
58+
59+
v := reflect.ValueOf(conn)
60+
if v.Kind() == reflect.Ptr {
61+
v = v.Elem()
62+
}
63+
64+
field := v.FieldByName("connector")
65+
66+
connector, _ := GetUnexportedField(field).(driver.Connector)
67+
68+
dbType := types.ParseDBType(d.getTargetDriverName())
69+
if dbType == types.DBType_Unknown {
70+
return nil, errors.New("unsupport conn type")
71+
}
72+
73+
c, err := d.OpenConnector(name)
74+
if err != nil {
75+
log.Errorf("open connector: %w", err)
76+
return nil, fmt.Errorf("open connector error: %v", err.Error())
77+
}
78+
79+
proxy, err := registerResource(connector, dbType, sql.OpenDB(c), name)
80+
if err != nil {
81+
log.Errorf("register resource: %w", err)
82+
return nil, err
83+
}
84+
85+
SetUnexportedField(field, proxy)
86+
return conn, nil
87+
}
88+
89+
func (d *SeataDriver) OpenConnector(dataSourceName string) (driver.Connector, error) {
90+
if driverCtx, ok := d.target.(driver.DriverContext); ok {
91+
return driverCtx.OpenConnector(dataSourceName)
92+
}
93+
return &dsnConnector{dsn: dataSourceName, driver: d.target}, nil
94+
}
95+
96+
func (d *SeataDriver) getTargetDriverName() string {
97+
return strings.ReplaceAll(SeataMySQLDriver, "seata-", "")
98+
}
99+
100+
type dsnConnector struct {
101+
dsn string
102+
driver driver.Driver
103+
}
104+
105+
func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
106+
return t.driver.Open(t.dsn)
107+
}
108+
109+
func (t *dsnConnector) Driver() driver.Driver {
110+
return t.driver
111+
}
112+
113+
func registerResource(connector driver.Connector, dbType types.DBType, db *sql.DB,
114+
dataSourceName string, opts ...seataOption) (driver.Connector, error) {
115+
116+
conf := loadConfig()
117+
for i := range opts {
118+
opts[i](conf)
119+
}
120+
121+
if err := conf.validate(); err != nil {
122+
log.Errorf("invalid conf: %w", err)
123+
return connector, err
124+
}
125+
126+
options := []dbOption{
127+
withGroupID(conf.GroupID),
128+
withResourceID(parseResourceID(dataSourceName)),
129+
withConf(conf),
130+
withTarget(db),
131+
withDBType(dbType),
132+
}
133+
134+
res, err := newResource(options...)
135+
if err != nil {
136+
log.Errorf("create new resource: %w", err)
137+
return nil, err
138+
}
139+
140+
if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil {
141+
log.Errorf("regisiter resource: %w", err)
142+
return nil, err
143+
}
144+
145+
return &seataConnector{
146+
res: res,
147+
target: connector,
148+
conf: conf,
149+
}, nil
150+
}
151+
152+
type (
153+
seataOption func(cfg *seataServerConfig)
154+
155+
// seataServerConfig
156+
seataServerConfig struct {
157+
// GroupID
158+
GroupID string `yaml:"groupID"`
159+
// BranchType
160+
BranchType branch.BranchType
161+
// Endpoints
162+
Endpoints []string `yaml:"endpoints" json:"endpoints"`
163+
}
164+
)
165+
166+
func (c *seataServerConfig) validate() error {
167+
return nil
168+
}
169+
170+
// loadConfig
171+
// TODO wait finish
172+
func loadConfig() *seataServerConfig {
173+
// 先设置默认配置
174+
175+
// 从默认文件获取
176+
return &seataServerConfig{
177+
GroupID: "DEFAULT_GROUP",
178+
BranchType: branch.BranchTypeAT,
179+
Endpoints: []string{"127.0.0.1:8888"},
180+
}
181+
}
182+
183+
func parseResourceID(dsn string) string {
184+
i := strings.Index(dsn, "?")
185+
186+
res := dsn
187+
188+
if i > 0 {
189+
res = dsn[:i]
190+
}
191+
192+
return strings.ReplaceAll(res, ",", "|")
193+
}
194+
195+
func GetUnexportedField(field reflect.Value) interface{} {
196+
return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface()
197+
}
198+
199+
func SetUnexportedField(field reflect.Value, value interface{}) {
200+
reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).
201+
Elem().
202+
Set(reflect.ValueOf(value))
30203
}

pkg/datasource/sql/sql.go

Lines changed: 0 additions & 169 deletions
This file was deleted.

0 commit comments

Comments
 (0)