diff --git a/buf.go b/buf.go index 68663421..be2df815 100644 --- a/buf.go +++ b/buf.go @@ -45,6 +45,7 @@ type tdsBuffer struct { wpos int wPacketSeq byte wPacketType packetType + wSpid uint16 // Read fields. rbuf []byte @@ -52,11 +53,14 @@ type tdsBuffer struct { rsize int final bool rPacketType packetType + rSpid uint16 // afterFirst is assigned to right after tdsBuffer is created and // before the first use. It is executed after the first packet is // written and then removed. afterFirst func() + + serverConn *tdsSession } func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { @@ -86,6 +90,7 @@ func (w *tdsBuffer) flush() (err error) { // Write packet size. w.wbuf[0] = byte(w.wPacketType) binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos)) + binary.BigEndian.PutUint16(w.wbuf[4:], w.wSpid) w.wbuf[6] = w.wPacketSeq // Write packet into underlying transport. @@ -169,12 +174,14 @@ func (r *tdsBuffer) readNextPacket() error { PacketNo: buf[6], Pad: buf[7], } + if int(h.Size) > r.packetSize { return errors.New("invalid packet size, it is longer than buffer size") } if headerSize > int(h.Size) { return errors.New("invalid packet size, it is shorter than header size") } + _, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size]) //s := base64.StdEncoding.EncodeToString(r.rbuf[headerSize:h.Size]) //fmt.Print(s) @@ -183,14 +190,58 @@ func (r *tdsBuffer) readNextPacket() error { } r.rpos = headerSize r.rsize = int(h.Size) - r.final = h.Status != 0 + r.final = h.Status&0x1 != 0 r.rPacketType = h.PacketType + r.rSpid = h.Spid + + if r.serverConn != nil { + _, err := r.serverConn.buf.Write(r.rbuf[r.rpos:r.rsize]) + if err != nil { + return err + } + + if r.final { + r.serverConn.buf.wSpid = h.Spid + if err := r.serverConn.buf.FinishPacket(); err != nil { + return err + } + } else { + if err := r.serverConn.buf.flush(); err != nil { + return err + } + } + } return nil } -func (r *tdsBuffer) BeginRead() (packetType, error) { +type beginReadConfig struct { + id string +} + +type beginRedOption func(beginReadConfig) + +func withFallbackID(id string) beginRedOption { return func(brc beginReadConfig) { brc.id = id } } + +func (r *tdsBuffer) BeginRead(opts ...beginRedOption) (packetType, error) { + conf := beginReadConfig{ + id: "UNSET", + } + for _, opt := range opts { + opt(conf) + } + + // var a string + if r.serverConn != nil { + conf.id = r.serverConn.id + // fmt.Printf("BeginRead with serverconn: %s %d\n", r.serverConn.id, r.rPacketType) + r.serverConn.buf.BeginPacket(r.rPacketType, false) + } else { + // fmt.Printf("BeginRead no serverconn: %s %d\n", conf.id, r.rPacketType) + } + err := r.readNextPacket() if err != nil { + // fmt.Printf("BeginRead error: %s: %s\n", conf.id, err) return 0, err } return r.rPacketType, nil diff --git a/go.sum b/go.sum index b9445a8d..5fd0bc5a 100644 --- a/go.sum +++ b/go.sum @@ -13,8 +13,6 @@ github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2/go.mod h1:wP83 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= -github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= 
github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= diff --git a/lastinsertid_example_test.go b/lastinsertid_example_test.go deleted file mode 100644 index 260b44ec..00000000 --- a/lastinsertid_example_test.go +++ /dev/null @@ -1,60 +0,0 @@ -//go:build go1.10 -// +build go1.10 - -package mssql_test - -import ( - "database/sql" - "log" -) - -// This example shows the usage of Connector type -func ExampleLastInsertId() { - - connString := makeConnURL().String() - - db, err := sql.Open("sqlserver", connString) - if err != nil { - log.Fatal("Open connection failed:", err.Error()) - } - defer db.Close() - - // Create table - _, err = db.Exec("create table foo (bar int identity, baz int unique);") - if err != nil { - log.Fatal(err) - } - defer db.Exec("if object_id('foo', 'U') is not null drop table foo;") - - // Attempt to retrieve scope identity using LastInsertId - res, err := db.Exec("insert into foo (baz) values (1)") - if err != nil { - log.Fatal(err) - } - n, err := res.LastInsertId() - if err != nil { - log.Print(err) - // Gets error: LastInsertId is not supported. Please use the OUTPUT clause or add `select ID = convert(bigint, SCOPE_IDENTITY())` to the end of your query. - } - log.Printf("LastInsertId: %d\n", n) - - // Retrieve scope identity by adding 'select ID = convert(bigint, SCOPE_IDENTITY())' to the end of the query - rows, err := db.Query("insert into foo (baz) values (10); select ID = convert(bigint, SCOPE_IDENTITY())") - if err != nil { - log.Fatal(err) - } - defer rows.Close() - var lastInsertId1 int64 - for rows.Next() { - rows.Scan(&lastInsertId1) - log.Printf("LastInsertId from SCOPE_IDENTITY(): %d\n", lastInsertId1) - } - - // Retrieve scope identity by 'output inserted`` - var lastInsertId2 int64 - err = db.QueryRow("insert into foo (baz) output inserted.bar values (100)").Scan(&lastInsertId2) - if err != nil { - log.Fatal(err) - } - log.Printf("LastInsertId from output inserted: %d\n", lastInsertId2) -} diff --git a/mssql.go b/mssql.go index 2c8d00e5..4d733ad1 100644 --- a/mssql.go +++ b/mssql.go @@ -239,6 +239,10 @@ type Conn struct { outs outputs } +func (c *Conn) Spid() uint16 { + return c.sess.buf.rSpid +} + type outputs struct { params map[string]interface{} returnStatus *ReturnStatus diff --git a/proxy.go b/proxy.go new file mode 100644 index 00000000..ced41035 --- /dev/null +++ b/proxy.go @@ -0,0 +1,1468 @@ +package mssql + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "os" + "strings" + "time" + + "github.com/microsoft/go-mssqldb/msdsn" +) + +const ( + defaultServerProgName = "GO MSSQL Server" + defaultServerVerion = "v15.0.0" +) + +type Client struct { + Conn *Conn + + border0DebugLogs bool +} + +type Server struct { + ConnTimeout time.Duration + PacketSize uint16 + Logger ContextLogger + Version uint32 + ProgName string + Encryption byte + + border0DebugLogs bool +} + +type ServerConfig struct { + ConnTimeout *time.Duration + PacketSize *uint16 + Logger ContextLogger + Version *string + Encryption *string + ProgName *string +} + +type ServerSession struct { + *tdsSession +} + +func NewServer(config ServerConfig) (*Server, error) { + border0DebugLogs := strings.ToLower(os.Getenv("BORDER0_MSSQL_PROXY_DEBUG")) == "true" + + 
server := &Server{border0DebugLogs: border0DebugLogs} + + if config.PacketSize == nil { + server.PacketSize = defaultPacketSize + } else { + server.PacketSize = *config.PacketSize + } + // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes + // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request + // a higher packet size, the server will respond with an ENVCHANGE request to + // alter the packet size to 16383 bytes. + if server.PacketSize < 512 { + server.PacketSize = 512 + } else if server.PacketSize > 32767 { + server.PacketSize = 32767 + } + + if config.ConnTimeout != nil { + server.ConnTimeout = *config.ConnTimeout + } + + if config.Logger != nil { + server.Logger = config.Logger + } + + if config.Version != nil { + server.Version = getDriverVersion(*config.Version) + } else { + server.Version = getDriverVersion(defaultServerVerion) + } + + if config.ProgName != nil { + server.ProgName = *config.ProgName + } else { + server.ProgName = defaultServerProgName + } + + if config.Encryption != nil { + switch *config.Encryption { + case "strict": + server.Encryption = encryptStrict + case "required": + server.Encryption = encryptReq + case "on": + server.Encryption = encryptOn + case "off": + server.Encryption = encryptOff + default: + return nil, errors.New("invalid encryption option") + } + } else { + server.Encryption = encryptNotSup + } + + return server, nil +} + +func (s *Server) ReadLogin(conn net.Conn) (*ServerSession, *login, map[uint8][]byte, error) { + toconn := newTimeoutConn(conn, s.ConnTimeout) + inbuf := newTdsBuffer(s.PacketSize, toconn) + + loginOptions, login, err := s.handshake(inbuf) + if err != nil { + return nil, nil, loginOptions, err + } + + sess := ServerSession{&tdsSession{ + buf: inbuf, + logger: s.Logger, + id: conn.RemoteAddr().String(), + + // FIXME: REMOVE + border0DebugLogs: s.border0DebugLogs, + }} + + return &sess, &login, loginOptions, nil +} + +func (s *tdsSession) ReadCommand() (packetType, error) { + var buf []byte + + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Start ReadCommand: %s rpos=%d, rsize=%d\n", s.id, s.buf.rpos, s.buf.rsize) + } + for { + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("ReadCommand: %s rpos=%d, rsize=%d\n", s.id, s.buf.rpos, s.buf.rsize) + } + _, err := s.buf.BeginRead(withFallbackID(s.id)) + if err != nil { + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("ReadCommand: %s error %v\n", s.id, err) + } + return 0, err + } + + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Got data: %s %d\n", s.id, s.buf.rsize) + } + + bytes := make([]byte, s.buf.rsize-s.buf.rpos) + s.buf.ReadFull(bytes) + buf = append(buf, bytes...) 
+
+		if s.buf.final {
+			copy(s.buf.rbuf, buf)
+			s.buf.rsize = len(buf)
+			s.buf.rpos = 0
+			return s.buf.rPacketType, nil
+		}
+	}
+}
+
+func (s *Server) handshake(r *tdsBuffer) (map[uint8][]byte, login, error) {
+	var login login
+
+	loginOptions, err := s.readPrelogin(r)
+	if err != nil {
+		return loginOptions, login, err
+	}
+
+	// FIXME: REMOVE
+	if s.border0DebugLogs {
+		fmt.Printf("Client -> Proxy: received prelogin options %+v\n", loginOptions)
+	}
+
+	err = s.writePrelogin(r)
+	if err != nil {
+		return loginOptions, login, err
+	}
+
+	login, err = s.readLogin(r)
+	if err != nil {
+		return loginOptions, login, err
+	}
+
+	// FIXME: REMOVE
+	if s.border0DebugLogs {
+		fmt.Printf("Client -> Proxy: received login packet %+v\n", login)
+	}
+
+	return loginOptions, login, nil
+}
+
+func (s *Server) readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
+	packet_type, err := r.BeginRead()
+	if err != nil {
+		return nil, err
+	}
+	struct_buf, err := io.ReadAll(r)
+	if err != nil {
+		return nil, err
+	}
+	if packet_type != packPrelogin {
+		return nil, errors.New("invalid request, expected pre-login packet")
+	}
+	if len(struct_buf) == 0 {
+		return nil, errors.New("invalid empty PRELOGIN request, it must contain at least one byte")
+	}
+
+	offset := 0
+	results := map[uint8][]byte{}
+	for {
+		// read prelogin option
+		plOption, err := readPreloginOption(struct_buf, offset)
+		if err != nil {
+			return results, err
+		}
+
+		if plOption.token == preloginTERMINATOR {
+			break
+		}
+
+		// read prelogin option data
+		value, err := readPreloginOptionData(plOption, struct_buf)
+		if err != nil {
+			return results, err
+		}
+		results[plOption.token] = value
+
+		offset += preloginOptionSize
+	}
+
+	return results, nil
+}
+
+func (s *Server) writePrelogin(r *tdsBuffer) error {
+	fields := s.preparePreloginResponseFields()
+
+	// FIXME: REMOVE
+	if s.border0DebugLogs {
+		fmt.Printf("Proxy -> Client: returned prelogin response %+v\n", fields)
+	}
+
+	if err := writePrelogin(packReply, r, fields); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (s *Server) preparePreloginResponseFields() map[uint8][]byte {
+	// Report a fixed SQL Server build number in the prelogin response, overriding any configured version.
+	s.Version = getDriverVersion("v15.0.4430.0")
+	fields := map[uint8][]byte{
+		// 4 bytes for version and 2 bytes for minor version
+		preloginVERSION:    {byte(s.Version >> 24), byte(s.Version >> 16), byte(s.Version >> 8), byte(s.Version), 0, 0},
+		preloginENCRYPTION: {s.Encryption},
+		preloginINSTOPT:    {0},
+		// preloginTHREADID: {0, 0, 0, 0},
+		preloginTHREADID: {},
+		preloginMARS:     {0}, // MARS disabled
+	}
+
+	return fields
+}
+
+func (s *Server) readLogin(r *tdsBuffer) (login, error) {
+	var login login
+	packet_type, err := r.BeginRead()
+	if err != nil {
+		return login, err
+	}
+
+	if packet_type != packLogin7 {
+		return login, errors.New("invalid request, expected login packet")
+	}
+
+	struct_buf, err := io.ReadAll(r)
+	if err != nil {
+		return login, err
+	}
+
+	if len(struct_buf) == 0 {
+		return login, errors.New("invalid empty login request, it must contain at least one byte")
+	}
+
+	var loginHeader loginHeader
+	if err := binary.Read(bytes.NewReader(struct_buf), binary.LittleEndian, &loginHeader); err != nil {
+		return login, fmt.Errorf("failed to read login packet: %w", err)
+	}
+
+	login.TDSVersion = loginHeader.TDSVersion
+	login.ClientProgVer = loginHeader.ClientProgVer
+	login.ClientPID = loginHeader.ClientPID
+	login.ConnectionID = loginHeader.ConnectionID
+	login.OptionFlags1 = loginHeader.OptionFlags1
+	login.OptionFlags2 = loginHeader.OptionFlags2
+	login.TypeFlags = loginHeader.TypeFlags
+	login.OptionFlags3 = loginHeader.OptionFlags3
+	login.ClientTimeZone = loginHeader.ClientTimeZone
+	login.ClientLCID = loginHeader.ClientLCID
+	login.ClientID = loginHeader.ClientID
+
+	login.HostName, err = readLoginFieldString(struct_buf, loginHeader.HostNameOffset, loginHeader.HostNameLength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read hostname: %w", err)
+	}
+	login.UserName, err = readLoginFieldString(struct_buf, loginHeader.UserNameOffset, loginHeader.UserNameLength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read username: %w", err)
+	}
+	login.AppName, err = readLoginFieldString(struct_buf, loginHeader.AppNameOffset, loginHeader.AppNameLength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read appname: %w", err)
+	}
+	login.ServerName, err = readLoginFieldString(struct_buf, loginHeader.ServerNameOffset, loginHeader.ServerNameLength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read servername: %w", err)
+	}
+	login.CtlIntName, err = readLoginFieldString(struct_buf, loginHeader.CtlIntNameOffset, loginHeader.CtlIntNameLength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read ctlintname: %w", err)
+	}
+	login.Language, err = readLoginFieldString(struct_buf, loginHeader.LanguageOffset, loginHeader.LanguageLength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read language: %w", err)
+	}
+	login.Database, err = readLoginFieldString(struct_buf, loginHeader.DatabaseOffset, loginHeader.DatabaseLength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read database: %w", err)
+	}
+	// FIXME: REMOVE
+	if s.border0DebugLogs {
+		fmt.Printf("Get dbname: %s\n", login.Database)
+	}
+	login.SSPI, err = readLoginFieldBytes(struct_buf, loginHeader.SSPIOffset, loginHeader.SSPILength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read sspi: %w", err)
+	}
+	login.AtchDBFile, err = readLoginFieldString(struct_buf, loginHeader.AtchDBFileOffset, loginHeader.AtchDBFileLength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read atchdbfile: %w", err)
+	}
+	login.ChangePassword, err = readLoginFieldString(struct_buf, loginHeader.ChangePasswordOffset, loginHeader.ChangePasswordLength)
+	if err != nil {
+		return login, fmt.Errorf("failed to read changepassword: %w", err)
+	}
+
+	// Read FeatureExt if present
+	if loginHeader.OptionFlags3&0x10 != 0 && loginHeader.ExtensionOffset != 0 {
+		if int(loginHeader.ExtensionOffset)+4 > len(struct_buf) {
+			return login, fmt.Errorf("cannot read ibFeatureExtLong at offset %d", loginHeader.ExtensionOffset)
+		}
+		extPtr := binary.LittleEndian.Uint32(struct_buf[loginHeader.ExtensionOffset : loginHeader.ExtensionOffset+4])
+		extStart := int(extPtr)
+		if extStart+1 > len(struct_buf) {
+			return login, fmt.Errorf("invalid FeatureExt pointer: not enough bytes at offset %d", extStart)
+		}
+
+		// FIXME: REMOVE
+		if s.border0DebugLogs {
+			fmt.Printf("Proxy-Client: starting FeatureExt parse from offset %d\n", extStart)
+		}
+
+		reader := bytes.NewReader(struct_buf[extStart:])
+		for {
+			featureID, err := reader.ReadByte()
+			if err != nil {
+				return login, fmt.Errorf("failed to read FeatureID: %w", err)
+			}
+			if featureID == 0xFF {
+				break
+			}
+
+			var featureDataLen uint32
+			if err := binary.Read(reader, binary.LittleEndian, &featureDataLen); err != nil {
+				return login, fmt.Errorf("failed to read FeatureDataLength for FeatureID 0x%X: %w", featureID, err)
+			}
+
+			if int(featureDataLen) > reader.Len() {
+				return login, fmt.Errorf("declared FeatureDataLength (%d) exceeds remaining buffer (%d) for 
FeatureID 0x%X", featureDataLen, reader.Len(), featureID) + } + + data := make([]byte, featureDataLen) + if _, err := io.ReadFull(reader, data); err != nil { + return login, fmt.Errorf("failed to read FeatureData for FeatureID 0x%X: %w", featureID, err) + } + + switch featureID { + case 0x01: + login.FeatureExt.Add(&featureExtSessionRecovery{}) + case 0x04: + login.FeatureExt.Add(&featureExtColumnEncryption{ + version: data[0], + }) + case 0x5: + login.FeatureExt.Add(&featureExtGlobalTransactions{}) + case 0x09: + login.FeatureExt.Add(&featureExtDataClassification{ + version: data[0], + }) + case 0x0A: + login.FeatureExt.Add(&featureExtUTF8Support{}) + case 0x0B: + login.FeatureExt.Add(&featureExtAzureSQLDNSCaching{}) + default: + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: unknown FeatureExt ID 0x%02X, length %d\n", featureID, featureDataLen) + } + return login, fmt.Errorf("unknown FeatureExt ID 0x%02X", featureID) + } + if s.border0DebugLogs { + fmt.Printf("FeatureExt: ID=0x%02X, Len=%d, Data=% X\n", featureID, featureDataLen, data) + } + } + } + + if s.border0DebugLogs { + fmt.Printf("\n--- TDS LOGIN PACKET SUMMARY ---\n") + fmt.Printf("TDS Version: 0x%08X\n", login.TDSVersion) + fmt.Printf("Client PID: %d\n", login.ClientPID) + fmt.Printf("OptionFlags1: 0x%02X\n", login.OptionFlags1) + fmt.Printf("OptionFlags2: 0x%02X\n", login.OptionFlags2) + fmt.Printf("OptionFlags3: 0x%02X\n", login.OptionFlags3) + fmt.Printf("Client Time Zone: %d\n", login.ClientTimeZone) + fmt.Printf("Client LCID: 0x%08X\n", login.ClientLCID) + fmt.Printf("Client ID: % X\n", login.ClientID[:]) + fmt.Printf("HostName: %s\n", login.HostName) + fmt.Printf("UserName: %s\n", login.UserName) + fmt.Printf("AppName: %s\n", login.AppName) + fmt.Printf("ServerName: %s\n", login.ServerName) + fmt.Printf("CtlIntName: %s\n", login.CtlIntName) + fmt.Printf("Language: %s\n", login.Language) + fmt.Printf("Database: %s\n", login.Database) + fmt.Printf("AttachDBFile: %s\n", login.AtchDBFile) + fmt.Printf("ChangePassword: %s\n", login.ChangePassword) + fmt.Printf("SSPI: % X\n", login.SSPI) + fmt.Printf("FeatureExt: % X\n", login.FeatureExt) + fmt.Printf("--- END LOGIN PACKET SUMMARY ---\n\n") + } + return login, nil +} + +func readLoginFieldString(b []byte, offset uint16, length uint16) (string, error) { + if len(b) < int(offset)+int(length)*2 { + return "", fmt.Errorf("invalid login packet, expected %d bytes, got %d", offset+length*2, len(b)) + } + + return ucs22str(b[offset : offset+length*2]) +} + +func readLoginFieldBytes(b []byte, offset uint16, length uint16) ([]byte, error) { + if len(b) < int(offset)+int(length) { + return nil, fmt.Errorf("invalid login packet, expected %d bytes, got %d", offset+length, len(b)) + } + + return b[offset : offset+length], nil +} + +func (s *Server) WriteLogin(session *ServerSession, loginTokens []tokenStruct, spid uint16) error { + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Writing login tokens %+v\n", loginTokens) + } + loginAck := loginAckStruct{ + Interface: 1, + TDSVersion: verTDS74, + ProgName: s.ProgName, + ProgVer: s.Version, + } + + done := doneStruct{ + Status: 0, + CurCmd: 0, + RowCount: 0, + errors: []Error{}, + } + + session.buf.wSpid = spid + session.buf.BeginPacket(packReply, false) + // session.buf.Write(loginEnvBytes) + // session.buf.Write(writeLoginAck(loginAckStruct)) + // session.buf.Write(writeDone(doneStruct)) + for _, token := range loginTokens { + switch t := token.(type) { + case loginAckStruct: + if _, err := 
session.buf.Write(writeLoginAck(t)); err != nil { + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Error writing loginAck: %v\n", err) + } + return err + } + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Writing loginAck: %+v\n", loginAck) + } + case doneStruct: + if _, err := session.buf.Write(writeDone(done)); err != nil { + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Error writing doneStruct: %v\n", err) + } + return err + } + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Writing doneStruct: %+v\n", done) + } + case envChange: + data := make([]byte, 0, len(t.data)+3) + data = append(data, byte(tokenEnvChange)) + // append length of data as uint16 in little-endian order + lenBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(lenBytes, uint16(len(t.data))) + data = append(data, lenBytes...) + data = append(data, t.data...) + + if _, err := session.buf.Write(data); err != nil { + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Error writing envChange: %v\n", err) + } + return err + } + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Writing envChange: %+v\n", t.data) + } + case loginToken: + data := make([]byte, 0, len(t.data)+3) + data = append(data, byte(t.token)) + // append length of data as uint16 in little-endian order + lenBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(lenBytes, uint16(len(t.data))) + data = append(data, lenBytes...) + + data = append(data, t.data...) + + if _, err := session.buf.Write(data); err != nil { + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Error writing loginToken: %v\n", err) + } + return err + } + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Writing loginToken: %+v\n", t.token) + } + case featureExtAck: + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Writing featureExtAck: %+v\n", t) + } + // Serialize the raw feature ACK data (feature entries + terminator) + var rawBuf bytes.Buffer + if err := writeFeatureExtAck(&rawBuf, t); err != nil { + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Error building featureExtAck: %v\n", err) + } + return err + } + raw := rawBuf.Bytes() + // Write token type + if err := session.buf.WriteByte(byte(tokenFeatureExtAck)); err != nil { + return err + } + // Write length of feature data (uint32 little-endian) + // lenBuf := make([]byte, 4) + // binary.LittleEndian.PutUint32(lenBuf, uint32(len(raw))) + // if _, err := session.buf.Write(lenBuf); err != nil { + // return err + // } + // Write the feature data itself + if _, err := session.buf.Write(raw); err != nil { + return err + } + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: featureExtAck raw bytes written (%d bytes)\n", len(raw)) + } + default: + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Unknown token type: %T\n", t) + } + } + } + + // FIXME: REMOVE + if s.border0DebugLogs { + fmt.Printf("Proxy-Client: Writing loginTokens %+v\n", loginTokens) + fmt.Printf("Proxy-Client: Writing loginAck %+v\n", loginAck) + fmt.Printf("Proxy-Client: Writing doneStruct %+v\n", done) + } + + return session.buf.FinishPacket() +} + +func UCS2String(s []byte) (string, error) { + return ucs22str(s) +} + +func (c *Conn) Transport() io.ReadWriteCloser { + if c.sess == nil || c.sess.buf == nil { + return nil + } + + return c.sess.buf.transport +} + +func (c *Conn) Buffer() *tdsBuffer 
{ + if c.sess == nil || c.sess.buf == nil { + return nil + } + + return c.sess.buf +} + +func (c *Conn) Session() *tdsSession { + return c.sess +} + +func (s *tdsSession) ParseHeader() (header, error) { + var h header + err := binary.Read(s.buf, binary.LittleEndian, &h) + if err != nil { + return header{}, err + } + + return h, nil +} + +func (s *tdsSession) ParseSQLBatch() ([]headerStruct, string, error) { + // fmt.Printf("ParseSQLBatch: %s payload total=%d, about to consume body at rpos=%d\n", s.id, s.buf.rsize, s.buf.rpos) + headers, err := readAllHeaders(s.buf) + if err != nil { + return nil, "", err + } + + // fmt.Printf("ParseSQLBatch: headers=%+v\n", headers) + query, err := readUcs2(s.buf, (s.buf.rsize-s.buf.rpos)/2) + if err != nil { + return nil, "", err + } + // fmt.Printf("ParseSQLBatch: query=%s\n", query) + + return headers, query, nil +} + +func (s *tdsSession) ParseTransMgrReq() ([]headerStruct, uint16, isoLevel, string, string, uint8, error) { + headers, err := readAllHeaders(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + + var rqtype uint16 + if err := binary.Read(s.buf, binary.LittleEndian, &rqtype); err != nil { + return nil, 0, 0, "", "", 0, err + } + + switch rqtype { + case tmBeginXact: + var isolationLevel isoLevel + if err := binary.Read(s.buf, binary.LittleEndian, &isolationLevel); err != nil { + return nil, 0, 0, "", "", 0, err + } + + name, err := readBVarChar(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + + return headers, rqtype, isolationLevel, name, "", 0, nil + case tmCommitXact, tmRollbackXact: + name, err := readBVarChar(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + + var flags uint8 + if err := binary.Read(s.buf, binary.LittleEndian, &flags); err != nil { + return nil, 0, 0, "", "", 0, err + } + + var newname string + if flags&fBeginXact != 0 { + var isolationLevel isoLevel + if err := binary.Read(s.buf, binary.LittleEndian, &isolationLevel); err != nil { + return nil, 0, 0, "", "", 0, err + } + + newname, err = readBVarChar(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + } + + return headers, rqtype, 0, name, newname, flags, nil + default: + return nil, 0, 0, "", "", 0, fmt.Errorf("invalid transaction manager request type: %d", rqtype) + } +} + +// // readParamMeta consumes the parameter's type metadata and returns a filled typeInfo. +// // For legacy TEXT and NTEXT, it skips max-length and collation before using the simple reader. +// // For IMAGE, it skips only max-length. All other types use the normal readTypeInfo path. 
+// func readParamMeta(b *tdsBuffer, enc msdsn.EncodeParameters) typeInfo {
+// 	// Read the type token
+// 	typeByte := b.byte()
+// 	fmt.Printf("readParamMeta: typeByte %d\n", typeByte)
+
+// 	switch typeByte {
+// 	case typeText, typeNText:
+// 		// Skip max length (2 bytes)
+// 		var maxLen uint16
+// 		if err := binary.Read(b, binary.LittleEndian, &maxLen); err != nil {
+// 			panic(err)
+// 		}
+// 		// Skip collation (5 bytes)
+// 		if _, err := io.ReadFull(b, make([]byte, 5)); err != nil {
+// 			panic(err)
+// 		}
+// 		return typeInfo{TypeId: typeByte, Reader: readSimpleParam}
+
+// 	case typeImage:
+// 		// Skip max length (2 bytes)
+// 		var maxLen uint16
+// 		if err := binary.Read(b, binary.LittleEndian, &maxLen); err != nil {
+// 			panic(err)
+// 		}
+// 		return typeInfo{TypeId: typeByte, Reader: readSimpleParam}
+
+// 	default:
+// 		// All other types: consume full metadata and field-specific reader via readTypeInfo
+// 		return readTypeInfo(b, typeByte, nil, enc)
+// 	}
+// }
+
+func (s *tdsSession) ParseRPC() ([]headerStruct, procId, uint16, []param, []interface{}, error) {
+	headers, err := readAllHeaders(s.buf)
+	if err != nil {
+		return nil, procId{}, 0, nil, nil, err
+	}
+	// fmt.Printf("ParseRPC: %s payload total=%d, about to consume body at rpos=%d\n", s.id, s.buf.rsize, s.buf.rpos)
+	var nameLength uint16
+	if err := binary.Read(s.buf, binary.LittleEndian, &nameLength); err != nil {
+		return nil, procId{}, 0, nil, nil, err
+	}
+
+	var proc procId
+	var idswitch uint16 = 0xffff
+	if nameLength == idswitch {
+		if err := binary.Read(s.buf, binary.LittleEndian, &proc.id); err != nil {
+			return nil, procId{}, 0, nil, nil, err
+		}
+	} else {
+		proc.name, err = readUcs2(s.buf, int(nameLength))
+		if err != nil {
+			return nil, procId{}, 0, nil, nil, err
+		}
+	}
+
+	var flags uint16
+	if err := binary.Read(s.buf, binary.LittleEndian, &flags); err != nil {
+		return nil, procId{}, 0, nil, nil, err
+	}
+
+	params, values, err := parseParams(s.buf, s.encoding)
+	if err != nil {
+		return nil, procId{}, 0, nil, nil, err
+	}
+	// fmt.Printf("ParseRPC: %s params done, rpos=%d, rsize=%d, final=%v\n", s.id, s.buf.rpos, s.buf.rsize, s.buf.final)
+	return headers, proc, flags, params, values, nil
+}
+
+func parseParams(b *tdsBuffer, encoding msdsn.EncodeParameters) ([]param, []interface{}, error) {
+	var (
+		params []param
+		values []interface{}
+	)
+
+	for {
+		// stop when buffer is exhausted
+		if b.rpos >= b.rsize {
+			break
+		}
+
+		// dump next bytes for context
+		nextEnd := b.rpos + 32
+		if nextEnd > b.rsize {
+			nextEnd = b.rsize
+		}
+
+		var p param
+
+		// name
+		name, err := readBVarChar(b)
+		if err != nil {
+			return nil, nil, err
+		}
+		p.Name = name
+
+		// flags
+		if err := binary.Read(b, binary.LittleEndian, &p.Flags); err != nil {
+			return nil, nil, err
+		}
+
+		// always parse type metadata to keep cursor aligned
+		p.ti = readParamTypeInfo(b, b.byte(), nil, encoding)
+
+		// // OUTPUT-only: skip without consuming data
+		// if p.Flags&paramOutput != 0 {
+		// 	fmt.Printf("parseParams: param %q is OUTPUT-only\n", p.Name)
+		// 	params = append(params, p)
+		// 	values = append(values, nil)
+		// 	continue
+		// }
+		// // BYREF: consume and drop client-sent value to stay aligned
+		// if p.Flags&paramByRef != 0 {
+		// 	fmt.Printf("parseParams: about to read value for param %q at pos=%d, rsize=%d\n", p.Name, b.rpos, b.rsize)
+		// 	fmt.Printf("parseParams: param %q is BYREF\n", p.Name)
+		// 	_ = p.ti.Reader(&p.ti, b, nil)
+		// 	params = append(params, p)
+		// 	values = append(values, nil)
+		// 	continue
+		// }
+		// normal IN parameter → read value
+		
val := p.ti.Reader(&p.ti, b, nil) + p.buffer = p.ti.Buffer + params = append(params, p) + values = append(values, val) + } + + return params, values, nil +} + +func readAllHeaders(r io.Reader) ([]headerStruct, error) { + var totalLength uint32 + err := binary.Read(r, binary.LittleEndian, &totalLength) + if err != nil { + return nil, err + } + + if totalLength < 4 { + return nil, errors.New("invalid total length") + } + + var headers []headerStruct + remainingLength := totalLength - 4 // Subtracting the length of the totalLength field + + for remainingLength > 0 { + var headerLength uint32 + err = binary.Read(r, binary.LittleEndian, &headerLength) + if err != nil { + return nil, err + } + + if headerLength < 6 || headerLength-6 > remainingLength { + return nil, errors.New("invalid header length") + } + + var hdrtype uint16 + err = binary.Read(r, binary.LittleEndian, &hdrtype) + if err != nil { + return nil, err + } + + dataLength := headerLength - 6 // Subtracting the length of the headerLength and hdrtype fields + data := make([]byte, dataLength) + _, err = io.ReadFull(r, data) + if err != nil { + return nil, err + } + + headers = append(headers, headerStruct{ + hdrtype: hdrtype, + data: data, + }) + + remainingLength -= headerLength + } + + if remainingLength != 0 { + return nil, errors.New("inconsistent header length") + } + + return headers, nil +} + +func (p *procId) Id() uint16 { + return p.id +} + +func (p *procId) Name() string { + return p.name +} + +func writeDone(d doneStruct) []byte { + data := make([]byte, 0, 12) + + // Append tokenDone and the calculated size + data = append(data, byte(tokenDone)) + + // Append Status + statusBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(statusBytes, d.Status) + data = append(data, statusBytes...) + + // Append CurCmd + curCmdBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(curCmdBytes, d.CurCmd) + data = append(data, curCmdBytes...) + + // Append RowCount + rowCountBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(rowCountBytes, d.RowCount) + data = append(data, rowCountBytes...) + + return data +} + +func writeLoginAck(l loginAckStruct) []byte { + progNameUCS2 := str2ucs2(l.ProgName) + + // Prepare the slice with preallocated size for efficiency + data := make([]byte, 0, 10+len(progNameUCS2)) + + // Append tokenLoginAck + data = append(data, byte(tokenLoginAck)) + + // Append calculated size + size := uint16(10 + len(progNameUCS2)) + sizeBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(sizeBytes, size) + data = append(data, sizeBytes...) + + // Append Interface + data = append(data, l.Interface) + + // Append TDSVersion + tdsVersionBytes := make([]byte, 4) + binary.BigEndian.PutUint32(tdsVersionBytes, l.TDSVersion) + data = append(data, tdsVersionBytes...) + + // Append ProgName Length and ProgName + data = append(data, byte(len(progNameUCS2)/2)) + data = append(data, progNameUCS2...) + + // Append ProgVer + progVerBytes := make([]byte, 4) + binary.BigEndian.PutUint32(progVerBytes, l.ProgVer) + data = append(data, progVerBytes...) 
+ + return data +} + +func NewConnectorFromConfig(config msdsn.Config) *Connector { + return newConnector(config, driverInstanceNoProcess) +} + +func NewClient(ctx context.Context, c *Connector, dialer Dialer, database string) (*Client, error) { + border0DebugLogs := strings.ToLower(os.Getenv("BORDER0_MSSQL_PROXY_DEBUG")) == "true" + + if dialer != nil { + c.Dialer = dialer + } + + params := c.params + params.Database = database + if border0DebugLogs { + fmt.Printf("NewClient: %+v\n", params) + } + + conn, err := c.driver.connect(ctx, c, params) + if err != nil { + return nil, err + } + + if err := conn.ResetSession(ctx); err != nil { + return nil, err + } + + return &Client{ + Conn: conn, + border0DebugLogs: border0DebugLogs, + }, nil +} + +func (c *Client) Close() error { + return c.Conn.Close() +} + +func (c *Client) SendSqlBatch(ctx context.Context, serverConn *ServerSession, query string, headers []headerStruct, resetSession bool) ([]doneStruct, error) { + if err := sendSqlBatch72(c.Conn.sess.buf, query, headers, resetSession); err != nil { + return nil, err + } + + return c.processResponse(ctx, serverConn) +} + +func (c *Client) SendRpc(ctx context.Context, serverConn *ServerSession, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) ([]doneStruct, error) { + if err := sendRpc(c.Conn.sess.buf, headers, proc, flags, params, resetSession, c.Conn.sess.encoding); err != nil { + return nil, err + } + + return c.processResponse(ctx, serverConn) +} + +func (c *Client) TransMgrReq(ctx context.Context, serverConn *ServerSession, headers []headerStruct, rqtype uint16, isolationLevel isoLevel, name, newname string, flags uint8, resetSession bool) ([]doneStruct, error) { + switch rqtype { + case tmBeginXact: + if err := sendBeginXact(c.Conn.sess.buf, headers, isolationLevel, name, resetSession); err != nil { + return nil, err + } + case tmCommitXact: + if err := sendCommitXact(c.Conn.sess.buf, headers, name, flags, uint8(isolationLevel), newname, resetSession); err != nil { + return nil, err + } + case tmRollbackXact: + if err := sendRollbackXact(c.Conn.sess.buf, headers, name, flags, uint8(isolationLevel), newname, resetSession); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("invalid transaction manager request type: %d", rqtype) + } + + return c.processResponse(ctx, serverConn) +} + +func (c *Client) processResponse(ctx context.Context, sess *ServerSession) ([]doneStruct, error) { + c.Conn.sess.buf.serverConn = sess.tdsSession + + packet_type, err := c.Conn.sess.buf.BeginRead() + if err != nil { + switch e := err.(type) { + case *net.OpError: + return nil, e + default: + return nil, &net.OpError{Op: "Read", Err: err} + } + } + + if packet_type != packReply { + return nil, StreamError{ + InnerError: fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply), + } + } + + var dones []doneStruct + var columns []columnStruct + var errs []Error + for { + token := token(c.Conn.sess.buf.byte()) + if c.border0DebugLogs { + fmt.Printf("processResponse: %s token %d\n", c.Conn.sess.id, token) + } + switch token { + case tokenReturnStatus: + parseReturnStatus(c.Conn.sess.buf) + case tokenOrder: + parseOrder(c.Conn.sess.buf) + case tokenDone, tokenDoneProc, tokenDoneInProc: + res := parseDone(c.Conn.sess.buf) + res.errors = errs + dones = append(dones, res) + if res.Status&doneSrvError != 0 { + return dones, ServerError{res.getError()} + } + + if res.Status&doneMore == 0 { + return dones, nil + } + case 
tokenColMetadata: + columns = parseColMetadata72(c.Conn.sess.buf, c.Conn.sess) + case tokenRow: + row := make([]interface{}, len(columns)) + err = parseRow(ctx, c.Conn.sess.buf, c.Conn.sess, columns, row) + if err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to parse row: %w", err), + } + } + case tokenNbcRow: + row := make([]interface{}, len(columns)) + err = parseNbcRow(ctx, c.Conn.sess.buf, c.Conn.sess, columns, row) + if err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to parse row: %w", err), + } + } + case tokenEnvChange: + processEnvChg(ctx, c.Conn.sess) + case tokenError: + err := parseError72(c.Conn.sess.buf) + errs = append(errs, err) + case tokenInfo: + parseInfo(c.Conn.sess.buf) + case tokenReturnValue: + parseReturnValue(c.Conn.sess.buf, c.Conn.sess) + case tokenSessionState: + // Read the total length of the SESSIONSTATE token (excluding TokenType and this Length field itself) + var totalLen uint32 + if err := binary.Read(c.Conn.sess.buf, binary.LittleEndian, &totalLen); err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to read SESSIONSTATE length: %w", err), + } + } + + // Read SeqNo (4 bytes) + var seqNo uint32 + if err := binary.Read(c.Conn.sess.buf, binary.LittleEndian, &seqNo); err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to read SESSIONSTATE SeqNo: %w", err), + } + } + + // Read Status (1 byte) + status, err := c.Conn.sess.buf.ReadByte() + if err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to read SESSIONSTATE Status: %w", err), + } + } + fRecoverable := status&0x01 == 0x01 + + // FIXME: REMOVE + if c.border0DebugLogs { + fmt.Printf("processResponse: SESSIONSTATE received - TotalLen=%d, SeqNo=%d, fRecoverable=%v\n", totalLen, seqNo, fRecoverable) + } + + bytesLeft := int(totalLen - 5) // minus SeqNo (4) + Status (1) + for bytesLeft > 0 { + stateID, err := c.Conn.sess.buf.ReadByte() + if err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to read StateId: %w", err), + } + } + + lenByte, err := c.Conn.sess.buf.ReadByte() + if err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to read StateLen byte: %w", err), + } + } + bytesLeft -= 2 // 1 for stateID, 1 for lenByte + + var stateLen int + if lenByte == 0xFF { + var longLen uint32 + if err := binary.Read(c.Conn.sess.buf, binary.LittleEndian, &longLen); err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to read extended StateLen: %w", err), + } + } + stateLen = int(longLen) + bytesLeft -= 4 + } else { + stateLen = int(lenByte) + } + bytesLeft -= stateLen + + stateValue := make([]byte, stateLen) + if _, err := io.ReadFull(c.Conn.sess.buf, stateValue); err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to read StateValue: %w", err), + } + } + + if c.border0DebugLogs { + fmt.Printf("SESSIONSTATE: StateID=0x%02X, Length=%d, Value=% X\n", stateID, stateLen, stateValue) + } + } + default: + return nil, StreamError{ + InnerError: fmt.Errorf("unknown token type returned: %v", token), + } + } + } +} + +func (d doneStruct) GetError() error { + n := len(d.errors) + if n == 0 { + return nil + } + + var err error + + for _, e := range d.errors { + err = errors.Join(err, e) + } + + return err +} + +func (c *Client) LoginTokens() []tokenStruct { + return c.Conn.sess.loginTokens +} + +func (c *Client) Database() string { + return c.Conn.sess.database +} + +func (c *Client) SendAttention(ctx context.Context, serverConn *ServerSession) 
([]doneStruct, error) { + if err := sendAttention(c.Conn.sess.buf); err != nil { + return nil, err + } + + return c.processResponse(ctx, serverConn) +} + +// func readSimpleParam(ti *typeInfo, r *tdsBuffer, _ *cryptoMetadata) interface{} { +// fmt.Printf("[readSimpleParam] rpos before read = %d, rsize = %d\n", r.rpos, r.rsize) + +// // Length (int32) +// if r.rpos+4 > r.rsize { +// fmt.Printf("[readSimpleParam] not enough bytes to read length (rpos=%d, rsize=%d)\n", r.rpos, r.rsize) +// return nil +// } +// length := int32(binary.LittleEndian.Uint32(r.rbuf[r.rpos:])) +// r.rpos += 4 +// fmt.Printf("[readSimpleParam] claimed length = %d\n", length) + +// if length < 0 || r.rpos+int(length) > r.rsize { +// fmt.Printf("[readSimpleParam] claimed length too big or negative: rpos=%d rsize=%d\n", r.rpos, r.rsize) +// return nil +// } + +// buf := r.rbuf[r.rpos : r.rpos+int(length)] +// r.rpos += int(length) + +// switch ti.TypeId { +// case typeNText: +// return decodeUcs2(buf) +// case typeText: +// return string(buf) +// default: // image +// return buf +// } +// } + +// func readParamTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata, encoding msdsn.EncodeParameters) (res typeInfo) { +// res.TypeId = typeId +// switch typeId { +// case typeText, typeImage, typeNText, typeVariant: +// if r.rpos+4 > r.rsize { +// fmt.Printf("[readParamTypeInfo] not enough bytes for length prefix: rpos=%d, rsize=%d\n", r.rpos, r.rsize) +// return +// } + +// // LONGLEN_TYPE +// res.Size = int(r.int32()) +// fmt.Printf("[readParamTypeInfo] length prefix: %d (from rpos %d to %d)\n", res.Size, r.rpos, r.rsize) + +// res.Collation = readCollation(r) +// fmt.Printf("[readParamTypeInfo] collation: %v\n", res.Collation) +// res.Reader = readLongLenTypeForRpcParam +// return +// default: +// return readTypeInfo(r, typeId, c, encoding) +// } +// } + +// func readLongLenTypeForRpcParam(ti *typeInfo, r *tdsBuffer, _ *cryptoMetadata) interface{} { +// fmt.Printf("[readLongLenTypeForRpcParam] rpos before read = %d, rsize = %d\n", r.rpos, r.rsize) + +// if r.rpos+4 > r.rsize { +// fmt.Printf("[readLongLenTypeForRpcParam] not enough bytes to read length (rpos=%d, rsize=%d)\n", r.rpos, r.rsize) +// return nil +// } + +// buf := make([]byte, ti.Size) +// r.ReadFull(buf) + +// switch ti.TypeId { +// case typeText: +// return decodeChar(ti.Collation, buf) +// case typeImage: +// return buf +// case typeNText: +// return decodeNChar(buf) +// default: +// badStreamPanicf("Invalid typeid") +// } +// panic("shoulnd't get here") +// } + +// readParamTypeInfo parses TYPE_INFO specifically for RPC parameters. +// It handles the omission of TableName for legacy LOB types and assigns +// specific readers for them. For all other types, it delegates +// to the original readTypeInfo function. +func readParamTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata, encoding msdsn.EncodeParameters) (res typeInfo) { + res.TypeId = typeId // Type ID was already read by the caller (parseParams) + + switch typeId { + case typeNText: + // Parse NTEXT metadata for parameters: Size (4b), Collation (5b). Skip TableName. 
+ // Check if buffer has enough space for Size + Collation + if r.rpos+4+5 > r.rsize { + badStreamPanicf("[readParamTypeInfo] NTEXT: not enough data for size/collation at pos %d (rsize %d)", r.rpos, r.rsize) + } + res.Size = int(r.int32()) // Read METADATA size (LONGLEN) + res.Collation = readCollation(r) // Read METADATA collation + // *** TableName is NOT read for parameters *** + res.Reader = readNTextParamValue // Assign the NEW reader for parameter instance data + return // Return directly + + case typeText: + // Parse TEXT metadata for parameters: Size (4b), Collation (5b). Skip TableName. + if r.rpos+4+5 > r.rsize { // Check for Size + Collation + badStreamPanicf("[readParamTypeInfo] TEXT: not enough data for size/collation at pos %d (rsize %d)", r.rpos, r.rsize) + } + res.Size = int(r.int32()) // Read METADATA size (LONGLEN) + res.Collation = readCollation(r) // Read METADATA collation + // *** TableName is NOT read for parameters *** + res.Reader = readTextParamValue // Assign the NEW reader for parameter instance data + return // Return directly + + case typeImage: + // Parse IMAGE metadata for parameters: Size (4b). Skip TableName. + if r.rpos+4 > r.rsize { // Check for Size + badStreamPanicf("[readParamTypeInfo] IMAGE: not enough data for size at pos %d (rsize %d)", r.rpos, r.rsize) + } + res.Size = int(r.int32()) // Read METADATA size (LONGLEN) + // IMAGE has no Collation + // *** TableName is NOT read for parameters *** + res.Reader = readImageParamValue // Assign the NEW reader for parameter instance data + return // Return directly + + default: + // For all other types, delegate to the original library function. + // This ensures correct handling for fixed types, numeric, dates, + // PLP types (varchar(max) etc.), variant, UDTs, TVPs etc. + + // Call the original function (ensure it's accessible) + // It will handle reading the rest of the metadata and assigning the correct original reader. + return readTypeInfo(r, typeId, c, encoding) // Requires original readTypeInfo + } +} + +// readNTextParamValue reads an NTEXT parameter value assuming no TextPtr/Timestamp. +// It expects the stream format: DataLength (4 bytes) | TextData (variable) +func readNTextParamValue(ti *typeInfo, r *tdsBuffer, _ *cryptoMetadata) interface{} { + // 1. Read Actual Data Length (LONGLEN / int32) + // Check if buffer has enough space for the length prefix itself + if r.rpos+4 > r.rsize { + // Try reading the next packet first to ensure the length bytes are available + _, err := r.BeginRead() + if err != nil { + // If BeginRead fails (e.g., connection closed), we can't proceed + badStreamPanicf("Reading NTEXT param data: error fetching packet for length prefix: %v", err) + } + // Re-check after potentially reading a new packet + if r.rpos+4 > r.rsize { + badStreamPanicf("Reading NTEXT param data: not enough data for length prefix even after next packet at pos %d (rsize %d)", r.rpos, r.rsize) + } + } + size := r.int32() // Read 4 bytes for the actual data length + + // 2. Handle NULL or Invalid Size + if size == -1 { + return nil // Return nil for NULL + } + if size < 0 { + badStreamPanicf("Invalid NTEXT param data size: %d", size) + } + if size == 0 { + return "" // Return empty string for zero-length NTEXT + } + + // 3. Allocate buffer + // Add a sanity check for extremely large sizes if needed + // if size > MAX_REASONABLE_LOB_SIZE { badStreamPanicf(...) } + ti.Buffer = make([]byte, size) // Store the buffer in typeInfo for later use + + // 4. 
Read the actual data using tdsBuffer.ReadFull (handles multi-packet) + r.ReadFull(ti.Buffer) // ReadFull internally calls BeginRead when needed + + // 5. Decode the buffer (assuming decodeNChar handles UCS-2/UTF-16LE) + // NOTE: Ensure `decodeNChar` function is accessible/correctly implemented. + return decodeNChar(ti.Buffer) +} + +// readTextParamValue reads a TEXT parameter value assuming no TextPtr/Timestamp. +// It expects the stream format: DataLength (4 bytes) | TextData (variable) +func readTextParamValue(ti *typeInfo, r *tdsBuffer, _ *cryptoMetadata) interface{} { + // 1. Read Actual Data Length (LONGLEN / int32) + if r.rpos+4 > r.rsize { + _, err := r.BeginRead() + if err != nil { + badStreamPanicf("Reading TEXT param data: error fetching packet for length prefix: %v", err) + } + if r.rpos+4 > r.rsize { + badStreamPanicf("Reading TEXT param data: not enough data for length prefix even after next packet at pos %d (rsize %d)", r.rpos, r.rsize) + } + } + size := r.int32() + + // 2. Handle NULL or Invalid Size + if size == -1 { + return nil + } + if size < 0 { + badStreamPanicf("Invalid TEXT param data size: %d", size) + } + if size == 0 { + return "" + } + + // 3. Allocate buffer + ti.Buffer = make([]byte, size) // Store the buffer in typeInfo for later use + + // 4. Read the actual data using tdsBuffer.ReadFull + r.ReadFull(ti.Buffer) + + // 5. Decode the buffer using collation stored in ti + // NOTE: Ensure `decodeChar` and `ti.Collation` are accessible/correct. + return decodeChar(ti.Collation, ti.Buffer) +} + +// readImageParamValue reads an IMAGE parameter value assuming no TextPtr/Timestamp. +// It expects the stream format: DataLength (4 bytes) | TextData (variable) +func readImageParamValue(ti *typeInfo, r *tdsBuffer, _ *cryptoMetadata) interface{} { + // 1. Read Actual Data Length (LONGLEN / int32) + if r.rpos+4 > r.rsize { + _, err := r.BeginRead() + if err != nil { + badStreamPanicf("Reading IMAGE param data: error fetching packet for length prefix: %v", err) + } + if r.rpos+4 > r.rsize { + badStreamPanicf("Reading IMAGE param data: not enough data for length prefix even after next packet at pos %d (rsize %d)", r.rpos, r.rsize) + } + } + size := r.int32() + + // 2. Handle NULL or Invalid Size + if size == -1 { + return nil + } + if size < 0 { + badStreamPanicf("Invalid IMAGE param data size: %d", size) + } + if size == 0 { + return []byte{} + } // Return empty byte slice + + // 3. Allocate buffer + buf := make([]byte, size) + + // 4. Read the actual data using tdsBuffer.ReadFull + r.ReadFull(buf) + + // 5. 
Return raw bytes + return buf +} diff --git a/session.go b/session.go index ff4839ad..bf4dcd37 100644 --- a/session.go +++ b/session.go @@ -44,14 +44,16 @@ func (s *tdsSession) preparePreloginFields(ctx context.Context, p msdsn.Config, case msdsn.EncryptionStrict: encrypt = encryptStrict } - v := getDriverVersion(driverVersion) + // v := getDriverVersion(driverVersion) fields := map[uint8][]byte{ // 4 bytes for version and 2 bytes for minor version - preloginVERSION: {byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24), 0, 0}, + // preloginVERSION: {byte(v), byte(v >> 8), byte(v >> 16), byte(v >> 24), 0, 0}, + preloginVERSION: {0x05, 0x0F, 0x5D, 0xDB, 0x02, 0x00}, preloginENCRYPTION: {encrypt}, preloginINSTOPT: instance_buf, - preloginTHREADID: {0, 0, 0, 0}, + preloginTHREADID: {0, 0, 0, 33}, preloginMARS: {0}, // MARS disabled + // preloginFEDAUTHREQUIRED: {0}, // FedAuth not required } if !p.NoTraceID { diff --git a/tds.go b/tds.go index 9ddc2ce7..bd5d11d1 100644 --- a/tds.go +++ b/tds.go @@ -176,6 +176,10 @@ type tdsSession struct { connid UniqueIdentifier activityid UniqueIdentifier encoding msdsn.EncodeParameters + loginTokens []tokenStruct + id string + + border0DebugLogs bool } type alwaysEncryptedSettings struct { @@ -364,8 +368,13 @@ func readPreloginOptionData(plOption *preloginOption, buffer []byte) ([]byte, er // OptionFlags1 // http://msdn.microsoft.com/en-us/library/dd304019.aspx const ( - fUseDB = 0x20 - fSetLang = 0x80 + fByteOrder = 1 << 0 // 0x01: use big-endian + fChar = 1 << 1 // 0x02: use EBCDIC + fFloat = 1 << 2 // 0x04: use VAX (or ND5000 if combined with 0x08) + fDumpLoad = 1 << 4 // 0x10: enable BCP + fUseDB = 1 << 5 // 0x20 + fDatabase = 1 << 6 // 0x40 + fSetLang = 1 << 7 // 0x80 ) // OptionFlags2 @@ -530,6 +539,36 @@ func (e *featureExtFedAuth) toBytes() []byte { return d } +// SESSIONRECOVERY feature extension +type featureExtSessionRecovery struct{} + +func (f *featureExtSessionRecovery) featureID() byte { return featExtSESSIONRECOVERY } +func (f *featureExtSessionRecovery) toBytes() []byte { return nil } // SESSIONRECOVERY with zero length indicates preference + +// GLOBALTRANSACTIONS feature extension +type featureExtGlobalTransactions struct{} + +func (f *featureExtGlobalTransactions) featureID() byte { return featExtGLOBALTRANSACTIONS } +func (f *featureExtGlobalTransactions) toBytes() []byte { return nil } + +// DATACLASSIFICATION feature extension +type featureExtDataClassification struct{ version uint8 } + +func (f *featureExtDataClassification) featureID() byte { return featExtDATACLASSIFICATION } +func (f *featureExtDataClassification) toBytes() []byte { return []byte{0x02} } + +// UTF8_SUPPORT feature extension +type featureExtUTF8Support struct{} + +func (f *featureExtUTF8Support) featureID() byte { return featExtUTF8SUPPORT } +func (f *featureExtUTF8Support) toBytes() []byte { return nil } // Can be []byte{0x01} to indicate support, but left nil as it's commonly zero-length + +// AZURESQLDNSCACHING feature extension +type featureExtAzureSQLDNSCaching struct{} + +func (f *featureExtAzureSQLDNSCaching) featureID() byte { return 0x0B } +func (f *featureExtAzureSQLDNSCaching) toBytes() []byte { return nil } + type loginHeader struct { Length uint32 TDSVersion uint32 @@ -1059,6 +1098,12 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont ChangePassword: p.ChangePassword, ClientPID: uint32(os.Getpid()), } + l.FeatureExt.Add(&featureExtSessionRecovery{}) + l.FeatureExt.Add(&featureExtColumnEncryption{version: 0x03}) + 
l.FeatureExt.Add(&featureExtGlobalTransactions{}) + l.FeatureExt.Add(&featureExtDataClassification{version: 0x02}) + l.FeatureExt.Add(&featureExtUTF8Support{}) + l.FeatureExt.Add(&featureExtAzureSQLDNSCaching{}) getClientId(&l.ClientID) if p.ColumnEncryption { _ = l.FeatureExt.Add(&featureExtColumnEncryption{}) @@ -1130,6 +1175,8 @@ func getTLSConn(conn *timeoutConn, p msdsn.Config, alpnSeq string) (tlsConn *tls } func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) { + border0DebugLogs := strings.ToLower(os.Getenv("BORDER0_MSSQL_PROXY_DEBUG")) == "true" + isTransportEncrypted := false // if instance is specified use instance resolution service if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 { @@ -1193,6 +1240,9 @@ initiate_connection: } fields := sess.preparePreloginFields(ctx, p, fedAuth) + if border0DebugLogs { + fmt.Printf("Proxy- > Server sending prelogin options %+v\n", fields) + } err = writePrelogin(packPrelogin, outbuf, fields) if err != nil { @@ -1203,6 +1253,9 @@ initiate_connection: if err != nil { return nil, err } + if border0DebugLogs { + fmt.Printf("Proxy <- Server received prelogin %+v\n", fields) + } encrypt, err := interpretPreloginResponse(p, fedAuth, fields) if err != nil { @@ -1269,7 +1322,9 @@ initiate_connection: if err != nil { return nil, err } - + if border0DebugLogs { + fmt.Printf("Proxy -> Server sending login %+v\n", login) + } err = sendLogin(outbuf, login) if err != nil { return nil, err @@ -1284,6 +1339,9 @@ initiate_connection: reader.noAttn = true for { + if border0DebugLogs { + fmt.Printf("Proxy <- Server waiting for token\n") + } tok, err := reader.nextToken() if err != nil { return nil, err @@ -1293,6 +1351,10 @@ initiate_connection: break } + if border0DebugLogs { + fmt.Printf("Proxy <- Server received token %+v\n", tok) + } + switch token := tok.(type) { case sspiMsg: sspi_msg, err := auth.NextBytes(token) @@ -1334,6 +1396,9 @@ initiate_connection: return nil, err } case loginAckStruct: + if border0DebugLogs { + fmt.Printf("Proxy <- Server received login ack %+v\n", token) + } sess.loginAck = token loginAck = true case featureExtAck: @@ -1380,6 +1445,7 @@ initiate_connection: } type featureExtColumnEncryption struct { + version uint8 } func (f *featureExtColumnEncryption) featureID() byte { diff --git a/token.go b/token.go index 8926ca58..3e5588c6 100644 --- a/token.go +++ b/token.go @@ -36,6 +36,7 @@ const ( tokenRow token = 209 // 0xd1 tokenNbcRow token = 210 // 0xd2 tokenEnvChange token = 227 // 0xE3 + tokenSessionState token = 228 // 0xE4 tokenSSPI token = 237 // 0xED tokenFedAuthInfo token = 238 // 0xEE tokenDone token = 253 // 0xFD @@ -142,17 +143,35 @@ func (d doneStruct) getError() Error { type doneInProcStruct doneStruct +type envChange struct { + data []byte +} + +type loginToken struct { + token token + data []byte +} + // ENVCHANGE stream // http://msdn.microsoft.com/en-us/library/dd303449.aspx -func processEnvChg(ctx context.Context, sess *tdsSession) { +func processEnvChg(ctx context.Context, sess *tdsSession) []byte { size := sess.buf.uint16() - r := &io.LimitedReader{R: sess.buf, N: int64(size)} + rb := &io.LimitedReader{R: sess.buf, N: int64(size)} + + buf := new(bytes.Buffer) + _, err := io.Copy(buf, rb) + if err != nil { + badStreamPanic(err) + } + + r := bytes.NewReader(buf.Bytes()) + for { var err error var envtype uint8 err = binary.Read(r, binary.LittleEndian, &envtype) if err == io.EOF { - return + return buf.Bytes() } if err != nil { 
badStreamPanic(err) @@ -393,7 +412,7 @@ func processEnvChg(ctx context.Context, sess *tdsSession) { default: // ignore rest of records because we don't know how to skip those sess.LogF(ctx, msdsn.LogDebug, "WARN: Unknown ENVCHANGE record detected with type id = %d", envtype) - return + return buf.Bytes() } } } @@ -543,15 +562,30 @@ type colAckStruct struct { EnclaveType string } +type sessionRecoveryAckStruct struct { + data []byte +} + type featureExtAck map[byte]interface{} func parseFeatureExtAck(r *tdsBuffer) featureExtAck { + // fmt.Printf("Parsing FeatureExtAck\n") + // fmt.Printf("Preview: %x\n", r.rbuf[r.rpos:r.rpos+16]) + ack := map[byte]interface{}{} for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { length := r.uint32() - + // fmt.Printf("Feature %d, length %d\n", feature, length) switch feature { + case featExtSESSIONRECOVERY: + data := make([]byte, length) + r.ReadFull(data) + sessionRecoveryAck := sessionRecoveryAckStruct{ + data: data, + } + ack[feature] = sessionRecoveryAck + length -= uint32(len(data)) case featExtFEDAUTH: // In theory we need to know the federated authentication library to // know how to parse, but the alternatives provide compatible structures. @@ -582,10 +616,30 @@ func parseFeatureExtAck(r *tdsBuffer) featureExtAck { } ack[feature] = colAck + case featExtDATACLASSIFICATION: + data := make([]byte, length) + r.ReadFull(data) + sessionRecoveryAck := sessionRecoveryAckStruct{ + data: data, + } + ack[feature] = sessionRecoveryAck + length -= uint32(len(data)) + case featExtUTF8SUPPORT: + data := make([]byte, length) + r.ReadFull(data) + sessionRecoveryAck := sessionRecoveryAckStruct{ + data: data, + } + ack[feature] = sessionRecoveryAck + length -= uint32(len(data)) + default: + // skip unknown feature + fmt.Printf("Unknown feature %d, length %d\n", feature, length) } // Skip unprocessed bytes if length > 0 { + fmt.Printf("Skipping %d bytes\n", length) io.CopyN(ioutil.Discard, r, int64(length)) } } @@ -593,6 +647,153 @@ func parseFeatureExtAck(r *tdsBuffer) featureExtAck { return ack } +// writeFeatureExtAck writes the FeatureExtAck structure to the given io.Writer. +// It reverses the logic of parseFeatureExtAck. 
+func writeFeatureExtAck(w io.Writer, ack featureExtAck) error {
+	// Helper to write a single byte
+	writeByte := func(b byte) error {
+		buf := []byte{b}
+		_, err := w.Write(buf)
+		return err
+	}
+	// Helper to write uint32 little endian
+	writeUint32 := func(v uint32) error {
+		var buf [4]byte
+		binary.LittleEndian.PutUint32(buf[:], v)
+		_, err := w.Write(buf[:])
+		return err
+	}
+	for feat, val := range ack {
+		switch feat {
+		case featExtSESSIONRECOVERY:
+			sessionRecoveryAck, _ := val.(sessionRecoveryAckStruct)
+			length := uint32(len(sessionRecoveryAck.data))
+			if err := writeByte(featExtSESSIONRECOVERY); err != nil {
+				return err
+			}
+			if err := writeUint32(length); err != nil {
+				return err
+			}
+			if _, err := w.Write(sessionRecoveryAck.data); err != nil {
+				return err
+			}
+		case featExtFEDAUTH:
+			fedAuth, _ := val.(fedAuthAckStruct)
+			// Calculate length: 32 for Nonce if present, 32 for Signature if present
+			var length uint32
+			if len(fedAuth.Nonce) > 0 {
+				length += 32
+			}
+			if len(fedAuth.Signature) > 0 {
+				length += 32
+			}
+			if err := writeByte(featExtFEDAUTH); err != nil {
+				return err
+			}
+			if err := writeUint32(length); err != nil {
+				return err
+			}
+			if len(fedAuth.Nonce) > 0 {
+				nonce := fedAuth.Nonce
+				if len(nonce) > 32 {
+					nonce = nonce[:32]
+				}
+				if len(nonce) < 32 {
+					tmp := make([]byte, 32)
+					copy(tmp, nonce)
+					nonce = tmp
+				}
+				if _, err := w.Write(nonce); err != nil {
+					return err
+				}
+			}
+			if len(fedAuth.Signature) > 0 {
+				sig := fedAuth.Signature
+				if len(sig) > 32 {
+					sig = sig[:32]
+				}
+				if len(sig) < 32 {
+					tmp := make([]byte, 32)
+					copy(tmp, sig)
+					sig = tmp
+				}
+				if _, err := w.Write(sig); err != nil {
+					return err
+				}
+			}
+		case featExtCOLUMNENCRYPTION:
+			// // COLUMNENCRYPTION feature
+			// if val, ok := ack[featExtCOLUMNENCRYPTION]; ok {
+			colAck, _ := val.(colAckStruct)
+			// Calculate length: 1 for version, 0 or more for enclave type
+			var enclaveBytes []byte
+			var enclaveLen byte
+			if colAck.EnclaveType != "" {
+				// Encode as UCS-2 (UTF-16LE, no BOM)
+				enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewEncoder()
+				ucs2, err := enc.Bytes([]byte(colAck.EnclaveType))
+				if err == nil {
+					enclaveBytes = ucs2
+					enclaveLen = byte(len(ucs2) / 2)
+				}
+			}
+			length := uint32(1) // version
+			if len(enclaveBytes) > 0 {
+				length += 1 + uint32(len(enclaveBytes))
+			}
+			if err := writeByte(featExtCOLUMNENCRYPTION); err != nil {
+				return err
+			}
+			if err := writeUint32(length); err != nil {
+				return err
+			}
+			if err := writeByte(byte(colAck.Version)); err != nil {
+				return err
+			}
+			if len(enclaveBytes) > 0 {
+				if err := writeByte(enclaveLen); err != nil {
+					return err
+				}
+				if _, err := w.Write(enclaveBytes); err != nil {
+					return err
+				}
+			}
+		case featExtDATACLASSIFICATION:
+			sessionRecoveryAck, _ := val.(sessionRecoveryAckStruct)
+			length := uint32(len(sessionRecoveryAck.data))
+			if err := writeByte(featExtDATACLASSIFICATION); err != nil {
+				return err
+			}
+			if err := writeUint32(length); err != nil {
+				return err
+			}
+			if _, err := w.Write(sessionRecoveryAck.data); err != nil {
+				return err
+			}
+		case featExtUTF8SUPPORT:
+			sessionRecoveryAck, _ := val.(sessionRecoveryAckStruct)
+			length := uint32(len(sessionRecoveryAck.data))
+			if err := writeByte(featExtUTF8SUPPORT); err != nil {
+				return err
+			}
+			if err := writeUint32(length); err != nil {
+				return err
+			}
+			if _, err := w.Write(sessionRecoveryAck.data); err != nil {
+				return err
+			}
+		default:
+			// Skip unknown feature
+			fmt.Printf("Unknown feature %d, length %d\n", feat, 0)
+		}
+	}
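+	// Note: ranging over the ack map above re-emits features in an unspecified
+	// order (Go randomizes map iteration), so the rewritten acknowledgement may
+	// not list features in the same order the server originally sent them.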
+	// Always terminate with featExtTERMINATOR (0xFF)
+	if err := writeByte(featExtTERMINATOR); err != nil {
+		return err
+	}
+	return nil
+}
+
 // http://msdn.microsoft.com/en-us/library/dd357363.aspx
 func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) {
 	count := r.uint16()
@@ -976,10 +1177,19 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
 		badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply))
 	}
 	var columns []columnStruct
+	var loginTokens []tokenStruct
 	errs := make([]Error, 0, 5)
 	for tokens := 0; ; tokens += 1 {
+		// FIXME: REMOVE
+		if sess.border0DebugLogs {
+			fmt.Println("reading next token")
+		}
 		token := token(sess.buf.byte())
 		sess.LogF(ctx, msdsn.LogDebug, "got token %v", token)
+		// FIXME: REMOVE
+		if sess.border0DebugLogs {
+			fmt.Println("got token", token)
+		}
 		switch token {
 		case tokenSSPI:
 			ch <- parseSSPIMsg(sess.buf)
@@ -989,15 +1199,19 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
 			return
 		case tokenReturnStatus:
 			returnStatus := parseReturnStatus(sess.buf)
+			loginTokens = append(loginTokens, tokenReturnStatus)
 			ch <- returnStatus
 		case tokenLoginAck:
 			loginAck := parseLoginAck(sess.buf)
+			loginTokens = append(loginTokens, loginAck)
 			ch <- loginAck
 		case tokenFeatureExtAck:
 			featureExtAck := parseFeatureExtAck(sess.buf)
+			loginTokens = append(loginTokens, featureExtAck)
 			ch <- featureExtAck
 		case tokenOrder:
 			order := parseOrder(sess.buf)
+			loginTokens = append(loginTokens, tokenOrder)
 			ch <- order
 		case tokenDoneInProc:
 			done := parseDoneInProc(sess.buf)
@@ -1027,6 +1241,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
 			}
 		case tokenDone, tokenDoneProc:
 			done := parseDone(sess.buf)
+			loginTokens = append(loginTokens, done)
 			done.errors = errs
 			if outs.msgq != nil {
 				errs = make([]Error, 0, 5)
@@ -1058,6 +1273,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
 				if outs.msgq != nil {
 					_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{})
 				}
+				sess.loginTokens = loginTokens
 				return
 			}
 		case tokenColMetadata:
@@ -1085,9 +1301,15 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
 			}
 			ch <- row
 		case tokenEnvChange:
-			processEnvChg(ctx, sess)
+			tokenBytes := processEnvChg(ctx, sess)
+			loginTokens = append(loginTokens, envChange{
+				data: tokenBytes,
+			})
+			// sess.loginEnvBytes = append(sess.loginEnvBytes, []byte{byte(tokenEnvChange), byte(len(tokenBytes) & 0xFF), byte(len(tokenBytes) >> 8)}...)
+			// sess.loginEnvBytes = append(sess.loginEnvBytes, tokenBytes...)
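+			// The raw ENVCHANGE bytes captured above are collected in loginTokens,
+			// which is stashed on the session once the response's final DONE token
+			// arrives; presumably the proxy replays these tokens to its own client.
+			// The commented-out loginEnvBytes lines appear to be an earlier take on
+			// the same idea.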
 		case tokenError:
 			err := parseError72(sess.buf)
+			loginTokens = append(loginTokens, err)
 			sess.LogF(ctx, msdsn.LogDebug, "got ERROR %d %s", err.Number, err.Message)
 			errs = append(errs, err)
 			sess.LogS(ctx, msdsn.LogErrors, err.Message)
@@ -1095,12 +1317,41 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
 				_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: err})
 			}
 		case tokenInfo:
-			info := parseInfo(sess.buf)
-			sess.LogF(ctx, msdsn.LogDebug, "got INFO %d %s", info.Number, info.Message)
-			sess.LogS(ctx, msdsn.LogMessages, info.Message)
-			if outs.msgq != nil {
-				_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info})
+			// FIXME: REMOVE
+			if sess.border0DebugLogs {
+				fmt.Printf("got INFO\n")
+			}
+			length := sess.buf.uint16()
+			// FIXME: REMOVE
+			if sess.border0DebugLogs {
+				fmt.Printf("got INFO length %d\n", length)
+			}
+			infoBytes := make([]byte, length)
+			_, err := sess.buf.Read(infoBytes)
+			tokenInfo := loginToken{
+				token: tokenInfo,
+				data:  infoBytes,
+			}
+			if err != nil {
+				// FIXME: REMOVE
+				if sess.border0DebugLogs {
+					fmt.Printf("got INFO read error %v\n", err)
+				}
+				badStreamPanic(err)
+			}
+
+			// fmt.Printf("got INFO bytes %v\n", infoBytes)
+
+			// // create a reader for the info bytes
+			// r := bytes.NewReader(infoBytes)
+			// tokenInfo := parseInfo(sess.buf)
+			// FIXME: REMOVE
+			if sess.border0DebugLogs {
+				fmt.Printf("got INFO token %v\n", tokenInfo)
 			}
+			loginTokens = append(loginTokens, tokenInfo)
+			// sess.loginEnvBytes = append(sess.loginEnvBytes, []byte{byte(tokenInfo), byte(length & 0xFF), byte(length >> 8)}...)
+			// sess.loginEnvBytes = append(sess.loginEnvBytes, infoBytes...)
 		case tokenReturnValue:
 			nv := parseReturnValue(sess.buf, sess)
 			if len(nv.Name) > 0 {
@@ -1108,7 +1359,10 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
 				if ov, has := outs.params[name]; has {
 					err = scanIntoOut(name, nv.Value, ov)
 					if err != nil {
-						fmt.Println("scan error", err)
+						// FIXME: REMOVE
+						if sess.border0DebugLogs {
+							fmt.Println("scan error", err)
+						}
 						ch <- err
 					}
 				}
diff --git a/types.go b/types.go
index 8e5c7dcc..96d204c7 100644
--- a/types.go
+++ b/types.go
@@ -516,6 +516,8 @@ func readShortLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{}
 	if size == 0xffff {
 		return nil
 	}
+	ti.Size = int(size)
+	ti.Buffer = make([]byte, ti.Size)
 	r.ReadFull(ti.Buffer[:size])
 	buf := ti.Buffer[:size]
 	switch ti.TypeId {
@@ -586,24 +588,24 @@ func readLongLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{}
 	panic("shoulnd't get here")
 }
 
 func writeLongLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
-	//textptr
-	err = binary.Write(w, binary.LittleEndian, byte(0x10))
-	if err != nil {
-		return
-	}
-	err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
-	if err != nil {
-		return
-	}
-	err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
-	if err != nil {
-		return
-	}
-	//timestamp?
-	err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
-	if err != nil {
-		return
-	}
+	// //textptr
+	// err = binary.Write(w, binary.LittleEndian, byte(0x10))
+	// if err != nil {
+	// 	return
+	// }
+	// err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
+	// if err != nil {
+	// 	return
+	// }
+	// err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
+	// if err != nil {
+	// 	return
+	// }
+	// //timestamp?
+	// err = binary.Write(w, binary.LittleEndian, uint64(0xFFFFFFFFFFFFFFFF))
+	// if err != nil {
+	// 	return
+	// }
 	err = binary.Write(w, binary.LittleEndian, uint32(ti.Size))
 	if err != nil {