|
18 | 18 | package sql |
19 | 19 |
|
20 | 20 | import ( |
| 21 | + "context" |
| 22 | + "database/sql" |
21 | 23 | "database/sql/driver" |
| 24 | + "errors" |
| 25 | + "fmt" |
| 26 | + "reflect" |
| 27 | + "strings" |
| 28 | + "unsafe" |
| 29 | + |
| 30 | + "github.com/seata/seata-go/pkg/common/log" |
| 31 | + |
| 32 | + "github.com/go-sql-driver/mysql" |
| 33 | + "github.com/seata/seata-go-datasource/sql/datasource" |
| 34 | + "github.com/seata/seata-go-datasource/sql/types" |
| 35 | + "github.com/seata/seata-go/pkg/protocol/branch" |
| 36 | +) |
| 37 | + |
| 38 | +const ( |
| 39 | + SeataMySQLDriver = "seata-mysql" |
22 | 40 | ) |
23 | 41 |
|
| 42 | +func init() { |
| 43 | + sql.Register(SeataMySQLDriver, &SeataDriver{ |
| 44 | + target: mysql.MySQLDriver{}, |
| 45 | + }) |
| 46 | +} |
| 47 | + |
24 | 48 | type SeataDriver struct { |
25 | 49 | target driver.Driver |
26 | 50 | } |
27 | 51 |
|
28 | 52 | func (d *SeataDriver) Open(name string) (driver.Conn, error) { |
29 | | - return d.target.Open(name) |
| 53 | + conn, err := d.target.Open(name) |
| 54 | + if err != nil { |
| 55 | + log.Errorf("open connection: %w", err) |
| 56 | + return nil, err |
| 57 | + } |
| 58 | + |
| 59 | + v := reflect.ValueOf(conn) |
| 60 | + if v.Kind() == reflect.Ptr { |
| 61 | + v = v.Elem() |
| 62 | + } |
| 63 | + |
| 64 | + field := v.FieldByName("connector") |
| 65 | + |
| 66 | + connector, _ := GetUnexportedField(field).(driver.Connector) |
| 67 | + |
| 68 | + dbType := types.ParseDBType(d.getTargetDriverName()) |
| 69 | + if dbType == types.DBType_Unknown { |
| 70 | + return nil, errors.New("unsupport conn type") |
| 71 | + } |
| 72 | + |
| 73 | + c, err := d.OpenConnector(name) |
| 74 | + if err != nil { |
| 75 | + log.Errorf("open connector: %w", err) |
| 76 | + return nil, fmt.Errorf("open connector error: %v", err.Error()) |
| 77 | + } |
| 78 | + |
| 79 | + proxy, err := registerResource(connector, dbType, sql.OpenDB(c), name) |
| 80 | + if err != nil { |
| 81 | + log.Errorf("register resource: %w", err) |
| 82 | + return nil, err |
| 83 | + } |
| 84 | + |
| 85 | + SetUnexportedField(field, proxy) |
| 86 | + return conn, nil |
| 87 | +} |
| 88 | + |
| 89 | +func (d *SeataDriver) OpenConnector(dataSourceName string) (driver.Connector, error) { |
| 90 | + if driverCtx, ok := d.target.(driver.DriverContext); ok { |
| 91 | + return driverCtx.OpenConnector(dataSourceName) |
| 92 | + } |
| 93 | + return &dsnConnector{dsn: dataSourceName, driver: d.target}, nil |
| 94 | +} |
| 95 | + |
| 96 | +func (d *SeataDriver) getTargetDriverName() string { |
| 97 | + return strings.ReplaceAll(SeataMySQLDriver, "seata-", "") |
| 98 | +} |
| 99 | + |
| 100 | +type dsnConnector struct { |
| 101 | + dsn string |
| 102 | + driver driver.Driver |
| 103 | +} |
| 104 | + |
| 105 | +func (t *dsnConnector) Connect(_ context.Context) (driver.Conn, error) { |
| 106 | + return t.driver.Open(t.dsn) |
| 107 | +} |
| 108 | + |
| 109 | +func (t *dsnConnector) Driver() driver.Driver { |
| 110 | + return t.driver |
| 111 | +} |
| 112 | + |
| 113 | +func registerResource(connector driver.Connector, dbType types.DBType, db *sql.DB, |
| 114 | + dataSourceName string, opts ...seataOption) (driver.Connector, error) { |
| 115 | + |
| 116 | + conf := loadConfig() |
| 117 | + for i := range opts { |
| 118 | + opts[i](conf) |
| 119 | + } |
| 120 | + |
| 121 | + if err := conf.validate(); err != nil { |
| 122 | + log.Errorf("invalid conf: %w", err) |
| 123 | + return connector, err |
| 124 | + } |
| 125 | + |
| 126 | + options := []dbOption{ |
| 127 | + withGroupID(conf.GroupID), |
| 128 | + withResourceID(parseResourceID(dataSourceName)), |
| 129 | + withConf(conf), |
| 130 | + withTarget(db), |
| 131 | + withDBType(dbType), |
| 132 | + } |
| 133 | + |
| 134 | + res, err := newResource(options...) |
| 135 | + if err != nil { |
| 136 | + log.Errorf("create new resource: %w", err) |
| 137 | + return nil, err |
| 138 | + } |
| 139 | + |
| 140 | + if err = datasource.GetDataSourceManager(conf.BranchType).RegisterResource(res); err != nil { |
| 141 | + log.Errorf("regisiter resource: %w", err) |
| 142 | + return nil, err |
| 143 | + } |
| 144 | + |
| 145 | + return &seataConnector{ |
| 146 | + res: res, |
| 147 | + target: connector, |
| 148 | + conf: conf, |
| 149 | + }, nil |
| 150 | +} |
| 151 | + |
| 152 | +type ( |
| 153 | + seataOption func(cfg *seataServerConfig) |
| 154 | + |
| 155 | + // seataServerConfig |
| 156 | + seataServerConfig struct { |
| 157 | + // GroupID |
| 158 | + GroupID string `yaml:"groupID"` |
| 159 | + // BranchType |
| 160 | + BranchType branch.BranchType |
| 161 | + // Endpoints |
| 162 | + Endpoints []string `yaml:"endpoints" json:"endpoints"` |
| 163 | + } |
| 164 | +) |
| 165 | + |
| 166 | +func (c *seataServerConfig) validate() error { |
| 167 | + return nil |
| 168 | +} |
| 169 | + |
| 170 | +// loadConfig |
| 171 | +// TODO wait finish |
| 172 | +func loadConfig() *seataServerConfig { |
| 173 | + // 先设置默认配置 |
| 174 | + |
| 175 | + // 从默认文件获取 |
| 176 | + return &seataServerConfig{ |
| 177 | + GroupID: "DEFAULT_GROUP", |
| 178 | + BranchType: branch.BranchTypeAT, |
| 179 | + Endpoints: []string{"127.0.0.1:8888"}, |
| 180 | + } |
| 181 | +} |
| 182 | + |
| 183 | +func parseResourceID(dsn string) string { |
| 184 | + i := strings.Index(dsn, "?") |
| 185 | + |
| 186 | + res := dsn |
| 187 | + |
| 188 | + if i > 0 { |
| 189 | + res = dsn[:i] |
| 190 | + } |
| 191 | + |
| 192 | + return strings.ReplaceAll(res, ",", "|") |
| 193 | +} |
| 194 | + |
| 195 | +func GetUnexportedField(field reflect.Value) interface{} { |
| 196 | + return reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface() |
| 197 | +} |
| 198 | + |
| 199 | +func SetUnexportedField(field reflect.Value, value interface{}) { |
| 200 | + reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())). |
| 201 | + Elem(). |
| 202 | + Set(reflect.ValueOf(value)) |
30 | 203 | } |
0 commit comments