diff --git a/buf.go b/buf.go index 68663421..ceecd2a7 100644 --- a/buf.go +++ b/buf.go @@ -11,7 +11,7 @@ type packetType uint8 type header struct { PacketType packetType - Status uint8 + Status byte Size uint16 Spid uint16 PacketNo uint8 @@ -45,6 +45,7 @@ type tdsBuffer struct { wpos int wPacketSeq byte wPacketType packetType + wSpid uint16 // Read fields. rbuf []byte @@ -57,6 +58,8 @@ type tdsBuffer struct { // before the first use. It is executed after the first packet is // written and then removed. afterFirst func() + + serverConn *tdsSession } func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { @@ -86,6 +89,7 @@ func (w *tdsBuffer) flush() (err error) { // Write packet size. w.wbuf[0] = byte(w.wPacketType) binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos)) + binary.BigEndian.PutUint16(w.wbuf[4:], w.wSpid) w.wbuf[6] = w.wPacketSeq // Write packet into underlying transport. @@ -169,12 +173,14 @@ func (r *tdsBuffer) readNextPacket() error { PacketNo: buf[6], Pad: buf[7], } + if int(h.Size) > r.packetSize { return errors.New("invalid packet size, it is longer than buffer size") } if headerSize > int(h.Size) { return errors.New("invalid packet size, it is shorter than header size") } + _, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size]) //s := base64.StdEncoding.EncodeToString(r.rbuf[headerSize:h.Size]) //fmt.Print(s) @@ -183,12 +189,34 @@ func (r *tdsBuffer) readNextPacket() error { } r.rpos = headerSize r.rsize = int(h.Size) - r.final = h.Status != 0 + r.final = h.Status&0x1 != 0 r.rPacketType = h.PacketType + + if r.serverConn != nil { + _, err := r.serverConn.buf.Write(r.rbuf[r.rpos:r.rsize]) + if err != nil { + return err + } + + if r.final { + r.serverConn.buf.wSpid = h.Spid + if err := r.serverConn.buf.FinishPacket(); err != nil { + return err + } + } else { + if err := r.serverConn.buf.flush(); err != nil { + return err + } + } + } return nil } func (r *tdsBuffer) BeginRead() (packetType, error) { + if r.serverConn != nil { + r.serverConn.buf.BeginPacket(r.rPacketType, false) + } + err := r.readNextPacket() if err != nil { return 0, err diff --git a/examples/simple/simple.go b/examples/simple/simple.go index 7cb5cfcc..faccef2d 100644 --- a/examples/simple/simple.go +++ b/examples/simple/simple.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "log" + "strings" _ "github.com/microsoft/go-mssqldb" ) @@ -37,21 +38,24 @@ func main() { } defer conn.Close() - stmt, err := conn.Prepare("select 1, 'abc'") + stmt, err := conn.Prepare("EXEC UpdateUserComments @UserID=?, @Comment=?") + if err != nil { + log.Fatal("Error preparing statement: ", err.Error()) + } + defer stmt.Close() + + //comment := "Your long comment text here..." + comment := strings.Repeat("A long comment text ", 500) // Simulate a long text + + _, err = stmt.Exec(10, comment) if err != nil { log.Fatal("Prepare failed:", err.Error()) } defer stmt.Close() - row := stmt.QueryRow() - var somenumber int64 - var somechars string - err = row.Scan(&somenumber, &somechars) if err != nil { log.Fatal("Scan failed:", err.Error()) } - fmt.Printf("somenumber:%d\n", somenumber) - fmt.Printf("somechars:%s\n", somechars) fmt.Printf("bye\n") } diff --git a/proxy.go b/proxy.go new file mode 100644 index 00000000..5c8d3cfa --- /dev/null +++ b/proxy.go @@ -0,0 +1,808 @@ +package mssql + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "time" + + "github.com/microsoft/go-mssqldb/msdsn" +) + +const ( + defaultServerProgName = "GO MSSQL Server" + defaultServerVerion = "v16.0.0" +) + +type Client struct { + Conn *Conn +} + +type Server struct { + ConnTimeout time.Duration + PacketSize uint16 + Logger ContextLogger + Version uint32 + ProgName string + Encryption byte +} + +type ServerConfig struct { + ConnTimeout *time.Duration + PacketSize *uint16 + Logger ContextLogger + Version *string + Encryption *string + ProgName *string +} + +type ServerSession struct { + *tdsSession +} + +func NewServer(config ServerConfig) (*Server, error) { + server := &Server{} + + if config.PacketSize == nil { + server.PacketSize = defaultPacketSize + } else { + server.PacketSize = *config.PacketSize + } + // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes + // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request + // a higher packet size, the server will respond with an ENVCHANGE request to + // alter the packet size to 16383 bytes. + if server.PacketSize < 512 { + server.PacketSize = 512 + } else if server.PacketSize > 32767 { + server.PacketSize = 32767 + } + + if config.ConnTimeout != nil { + server.ConnTimeout = *config.ConnTimeout + } + + if config.Logger != nil { + server.Logger = config.Logger + } + + if config.Version != nil { + server.Version = getDriverVersion(*config.Version) + } else { + server.Version = getDriverVersion(defaultServerVerion) + } + + if config.ProgName != nil { + server.ProgName = *config.ProgName + } else { + server.ProgName = defaultServerProgName + } + + if config.Encryption != nil { + switch *config.Encryption { + case "strict": + server.Encryption = encryptStrict + case "required": + server.Encryption = encryptReq + case "on": + server.Encryption = encryptOn + case "off": + server.Encryption = encryptOff + default: + return nil, errors.New("invalid encryption option") + } + } else { + server.Encryption = encryptNotSup + } + + return server, nil +} + +func (s *Server) ReadLogin(conn net.Conn) (*ServerSession, *login, error) { + toconn := newTimeoutConn(conn, s.ConnTimeout) + inbuf := newTdsBuffer(s.PacketSize, toconn) + + login, err := s.handshake(inbuf) + if err != nil { + return nil, nil, err + } + + sess := ServerSession{&tdsSession{ + buf: inbuf, + logger: s.Logger, + }} + + return &sess, &login, nil +} + +func (s *tdsSession) ReadCommand() (packetType, error) { + var buf []byte + for { + _, err := s.buf.BeginRead() + if err != nil { + return 0, err + } + + bytes := make([]byte, s.buf.rsize-s.buf.rpos) + s.buf.ReadFull(bytes) + buf = append(buf, bytes...) + + if s.buf.final { + copy(s.buf.rbuf, buf) + s.buf.rsize = len(buf) + s.buf.rpos = 0 + return s.buf.rPacketType, nil + } + } +} + +func (s *Server) handshake(r *tdsBuffer) (login, error) { + var login login + + err := s.readPrelogin(r) + if err != nil { + return login, err + } + + err = s.writePrelogin(r) + if err != nil { + return login, err + } + + login, err = s.readLogin(r) + if err != nil { + return login, err + } + + return login, nil +} + +func (s *Server) readPrelogin(r *tdsBuffer) error { + packet_type, err := r.BeginRead() + if err != nil { + return err + } + struct_buf, err := io.ReadAll(r) + if err != nil { + return err + } + if packet_type != packPrelogin { + return errors.New("invalid request, expected pre-login packet") + } + if len(struct_buf) == 0 { + return errors.New("invalid empty PRELOGIN request, it must contain at least one byte") + } + + offset := 0 + results := map[uint8][]byte{} + for { + // read prelogin option + plOption, err := readPreloginOption(struct_buf, offset) + if err != nil { + return err + } + + if plOption.token == preloginTERMINATOR { + break + } + + // read prelogin option data + value, err := readPreloginOptionData(plOption, struct_buf) + if err != nil { + return err + } + results[plOption.token] = value + + offset += preloginOptionSize + } + + return nil +} + +func (s *Server) writePrelogin(r *tdsBuffer) error { + if err := writePrelogin(packReply, r, s.preparePreloginResponseFields()); err != nil { + return err + } + + return nil +} + +func (s *Server) preparePreloginResponseFields() map[uint8][]byte { + fields := map[uint8][]byte{ + // 4 bytes for version and 2 bytes for minor version + preloginVERSION: {byte(s.Version >> 24), byte(s.Version >> 16), byte(s.Version >> 8), byte(s.Version), 0, 0}, + preloginENCRYPTION: {s.Encryption}, + preloginINSTOPT: {0}, + preloginTHREADID: {0, 0, 0, 0}, + preloginMARS: {0}, // MARS disabled + } + + return fields +} + +func (s *Server) readLogin(r *tdsBuffer) (login, error) { + var login login + packet_type, err := r.BeginRead() + if err != nil { + return login, err + } + + if packet_type != packLogin7 { + return login, errors.New("invalid request, expected login packet") + } + + struct_buf, err := io.ReadAll(r) + if err != nil { + return login, err + } + + if len(struct_buf) == 0 { + return login, errors.New("invalid empty login request, it must contain at least one byte") + } + + var loginHeader loginHeader + if err := binary.Read(bytes.NewReader(struct_buf), binary.LittleEndian, &loginHeader); err != nil { + return login, fmt.Errorf("failed to read login packet: %w", err) + } + + login.TDSVersion = loginHeader.TDSVersion + login.ClientProgVer = loginHeader.ClientProgVer + login.ClientPID = loginHeader.ClientPID + login.ConnectionID = loginHeader.ConnectionID + login.OptionFlags1 = loginHeader.OptionFlags1 + login.OptionFlags2 = loginHeader.OptionFlags2 + login.TypeFlags = loginHeader.TypeFlags + login.OptionFlags3 = loginHeader.OptionFlags3 + login.ClientTimeZone = loginHeader.ClientTimeZone + login.ClientLCID = loginHeader.ClientLCID + login.ClientID = loginHeader.ClientID + + login.HostName, err = readLoginFieldString(struct_buf, loginHeader.HostNameOffset, loginHeader.HostNameLength) + if err != nil { + return login, fmt.Errorf("failed to read hostname: %w", err) + } + login.UserName, err = readLoginFieldString(struct_buf, loginHeader.UserNameOffset, loginHeader.UserNameLength) + if err != nil { + return login, fmt.Errorf("failed to read username: %w", err) + } + login.AppName, err = readLoginFieldString(struct_buf, loginHeader.AppNameOffset, loginHeader.AppNameLength) + if err != nil { + return login, fmt.Errorf("failed to read username: %w", err) + } + login.ServerName, err = readLoginFieldString(struct_buf, loginHeader.ServerNameOffset, loginHeader.ServerNameLength) + if err != nil { + return login, fmt.Errorf("failed to read servername: %w", err) + } + login.CtlIntName, err = readLoginFieldString(struct_buf, loginHeader.CtlIntNameOffset, loginHeader.CtlIntNameLength) + if err != nil { + return login, fmt.Errorf("failed to read servername: %w", err) + } + login.Language, err = readLoginFieldString(struct_buf, loginHeader.LanguageOffset, loginHeader.LanguageLength) + if err != nil { + return login, fmt.Errorf("failed to read servername: %w", err) + } + login.Database, err = readLoginFieldString(struct_buf, loginHeader.DatabaseOffset, loginHeader.DatabaseLength) + if err != nil { + return login, fmt.Errorf("failed to read servername: %w", err) + } + login.SSPI, err = readLoginFieldBytes(struct_buf, loginHeader.SSPIOffset, loginHeader.SSPILength) + if err != nil { + return login, fmt.Errorf("failed to read sspi: %w", err) + } + login.AtchDBFile, err = readLoginFieldString(struct_buf, loginHeader.AtchDBFileOffset, loginHeader.AtchDBFileLength) + if err != nil { + return login, fmt.Errorf("failed to read sspi: %w", err) + } + login.ChangePassword, err = readLoginFieldString(struct_buf, loginHeader.ChangePasswordOffset, loginHeader.ChangePasswordLength) + if err != nil { + return login, fmt.Errorf("failed to read sspi: %w", err) + } + + return login, nil +} + +func readLoginFieldString(b []byte, offset uint16, length uint16) (string, error) { + if len(b) < int(offset)+int(length)*2 { + return "", fmt.Errorf("invalid login packet, expected %d bytes, got %d", offset+length*2, len(b)) + } + + return ucs22str(b[offset : offset+length*2]) +} + +func readLoginFieldBytes(b []byte, offset uint16, length uint16) ([]byte, error) { + if len(b) < int(offset)+int(length) { + return nil, fmt.Errorf("invalid login packet, expected %d bytes, got %d", offset+length, len(b)) + } + + return b[offset : offset+length], nil +} + +func (s *Server) WriteLogin(session *ServerSession, loginEnvBytes []byte) error { + loginAckStruct := loginAckStruct{ + Interface: 1, + TDSVersion: verTDS74, + ProgName: s.ProgName, + ProgVer: s.Version, + } + + doneStruct := doneStruct{ + Status: 0, + CurCmd: 0, + RowCount: 0, + errors: []Error{}, + } + + session.buf.BeginPacket(packReply, false) + session.buf.Write(loginEnvBytes) + session.buf.Write(writeLoginAck(loginAckStruct)) + session.buf.Write(writeDone(doneStruct)) + + return session.buf.FinishPacket() +} + +func UCS2String(s []byte) (string, error) { + return ucs22str(s) +} + +func (c *Conn) Transport() io.ReadWriteCloser { + if c.sess == nil || c.sess.buf == nil { + return nil + } + + return c.sess.buf.transport +} + +func (c *Conn) Buffer() *tdsBuffer { + if c.sess == nil || c.sess.buf == nil { + return nil + } + + return c.sess.buf +} + +func (c *Conn) Session() *tdsSession { + return c.sess +} + +func (s *tdsSession) ParseHeader() (header, error) { + var h header + err := binary.Read(s.buf, binary.LittleEndian, &h) + if err != nil { + return header{}, err + } + + return h, nil +} + +func (s *tdsSession) ParseSQLBatch() ([]headerStruct, string, error) { + headers, err := readAllHeaders(s.buf) + if err != nil { + return nil, "", err + } + + query, err := readUcs2(s.buf, (s.buf.rsize-s.buf.rpos)/2) + if err != nil { + return nil, "", err + } + + return headers, query, nil +} + +func (s *tdsSession) ParseTransMgrReq() ([]headerStruct, uint16, isoLevel, string, string, uint8, error) { + headers, err := readAllHeaders(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + + var rqtype uint16 + if err := binary.Read(s.buf, binary.LittleEndian, &rqtype); err != nil { + return nil, 0, 0, "", "", 0, err + } + + switch rqtype { + case tmBeginXact: + var isolationLevel isoLevel + if err := binary.Read(s.buf, binary.LittleEndian, &isolationLevel); err != nil { + return nil, 0, 0, "", "", 0, err + } + + name, err := readBVarChar(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + + return headers, rqtype, isolationLevel, name, "", 0, nil + case tmCommitXact, tmRollbackXact: + name, err := readBVarChar(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + + var flags uint8 + if err := binary.Read(s.buf, binary.LittleEndian, &flags); err != nil { + return nil, 0, 0, "", "", 0, err + } + + var newname string + if flags&fBeginXact != 0 { + var isolationLevel isoLevel + if err := binary.Read(s.buf, binary.LittleEndian, &isolationLevel); err != nil { + return nil, 0, 0, "", "", 0, err + } + + newname, err = readBVarChar(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + } + + return headers, rqtype, 0, name, newname, flags, nil + default: + return nil, 0, 0, "", "", 0, fmt.Errorf("invalid transaction manager request type: %d", rqtype) + } +} + +func (s *tdsSession) ParseRPC() ([]headerStruct, procId, uint16, []param, []interface{}, error) { + headers, err := readAllHeaders(s.buf) + if err != nil { + return nil, procId{}, 0, nil, nil, err + } + + var nameLength uint16 + if err := binary.Read(s.buf, binary.LittleEndian, &nameLength); err != nil { + return nil, procId{}, 0, nil, nil, err + } + + var proc procId + var idswitch uint16 = 0xffff + if nameLength == idswitch { + if err := binary.Read(s.buf, binary.LittleEndian, &proc.id); err != nil { + return nil, procId{}, 0, nil, nil, err + } + } else { + proc.name, err = readUcs2(s.buf, int(nameLength)) + if err != nil { + return nil, procId{}, 0, nil, nil, err + } + } + + var flags uint16 + if err := binary.Read(s.buf, binary.LittleEndian, &flags); err != nil { + return nil, procId{}, 0, nil, nil, err + } + + params, values, err := parseParams(s.buf, s.encoding) + if err != nil { + return nil, procId{}, 0, nil, nil, err + } + + return headers, proc, flags, params, values, nil +} + +func parseParams(b *tdsBuffer, encoding msdsn.EncodeParameters) ([]param, []interface{}, error) { + var params []param + var values []interface{} + for { + if b.rpos >= b.rsize { + break + } + + var p param + + name, err := readBVarChar(b) + if err != nil { + return nil, nil, err + } + p.Name = name + + var flags uint8 + if err := binary.Read(b, binary.LittleEndian, &flags); err != nil { + return nil, nil, err + } + + p.Flags = flags + p.ti = readTypeInfo(b, b.byte(), nil, encoding) + val := p.ti.Reader(&p.ti, b, nil) + p.buffer = p.ti.Buffer + params = append(params, p) + values = append(values, val) + } + return params, values, nil +} + +func readAllHeaders(r io.Reader) ([]headerStruct, error) { + var totalLength uint32 + err := binary.Read(r, binary.LittleEndian, &totalLength) + if err != nil { + return nil, err + } + + if totalLength < 4 { + return nil, errors.New("invalid total length") + } + + var headers []headerStruct + remainingLength := totalLength - 4 // Subtracting the length of the totalLength field + + for remainingLength > 0 { + var headerLength uint32 + err = binary.Read(r, binary.LittleEndian, &headerLength) + if err != nil { + return nil, err + } + + if headerLength < 6 || headerLength-6 > remainingLength { + return nil, errors.New("invalid header length") + } + + var hdrtype uint16 + err = binary.Read(r, binary.LittleEndian, &hdrtype) + if err != nil { + return nil, err + } + + dataLength := headerLength - 6 // Subtracting the length of the headerLength and hdrtype fields + data := make([]byte, dataLength) + _, err = io.ReadFull(r, data) + if err != nil { + return nil, err + } + + headers = append(headers, headerStruct{ + hdrtype: hdrtype, + data: data, + }) + + remainingLength -= headerLength + } + + if remainingLength != 0 { + return nil, errors.New("inconsistent header length") + } + + return headers, nil +} + +func (p *procId) Id() uint16 { + return p.id +} + +func (p *procId) Name() string { + return p.name +} + +func writeDone(d doneStruct) []byte { + data := make([]byte, 0, 12) + + // Append tokenDone and the calculated size + data = append(data, byte(tokenDone)) + + // Append Status + statusBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(statusBytes, d.Status) + data = append(data, statusBytes...) + + // Append CurCmd + curCmdBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(curCmdBytes, d.CurCmd) + data = append(data, curCmdBytes...) + + // Append RowCount + rowCountBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(rowCountBytes, d.RowCount) + data = append(data, rowCountBytes...) + + return data +} + +func writeLoginAck(l loginAckStruct) []byte { + progNameUCS2 := str2ucs2(l.ProgName) + + // Prepare the slice with preallocated size for efficiency + data := make([]byte, 0, 10+len(progNameUCS2)) + + // Append tokenLoginAck + data = append(data, byte(tokenLoginAck)) + + // Append calculated size + size := uint16(10 + len(progNameUCS2)) + sizeBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(sizeBytes, size) + data = append(data, sizeBytes...) + + // Append Interface + data = append(data, l.Interface) + + // Append TDSVersion + tdsVersionBytes := make([]byte, 4) + binary.BigEndian.PutUint32(tdsVersionBytes, l.TDSVersion) + data = append(data, tdsVersionBytes...) + + // Append ProgName Length and ProgName + data = append(data, byte(len(progNameUCS2)/2)) + data = append(data, progNameUCS2...) + + // Append ProgVer + progVerBytes := make([]byte, 4) + binary.BigEndian.PutUint32(progVerBytes, l.ProgVer) + data = append(data, progVerBytes...) + + return data +} + +func NewConnectorFromConfig(config msdsn.Config) *Connector { + return newConnector(config, driverInstanceNoProcess) +} + +func NewClient(ctx context.Context, c *Connector, dialer Dialer, database string) (*Client, error) { + if dialer != nil { + c.Dialer = dialer + } + + params := c.params + params.Database = database + + conn, err := c.driver.connect(ctx, c, c.params) + if err != nil { + return nil, err + } + + if err := conn.ResetSession(ctx); err != nil { + return nil, err + } + + return &Client{ + Conn: conn, + }, nil +} + +func (c *Client) Close() error { + return c.Conn.Close() +} + +func (c *Client) SendSqlBatch(ctx context.Context, serverConn *ServerSession, query string, headers []headerStruct, resetSession bool) ([]doneStruct, error) { + if err := sendSqlBatch72(c.Conn.sess.buf, query, headers, resetSession); err != nil { + return nil, err + } + + return c.processResponse(ctx, serverConn) +} + +func (c *Client) SendRpc(ctx context.Context, serverConn *ServerSession, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) ([]doneStruct, error) { + if err := sendRpc(c.Conn.sess.buf, headers, proc, flags, params, resetSession, c.Conn.sess.encoding); err != nil { + return nil, err + } + + return c.processResponse(ctx, serverConn) +} + +func (c *Client) TransMgrReq(ctx context.Context, serverConn *ServerSession, headers []headerStruct, rqtype uint16, isolationLevel isoLevel, name, newname string, flags uint8, resetSession bool) ([]doneStruct, error) { + switch rqtype { + case tmBeginXact: + if err := sendBeginXact(c.Conn.sess.buf, headers, isolationLevel, name, resetSession); err != nil { + return nil, err + } + case tmCommitXact: + if err := sendCommitXact(c.Conn.sess.buf, headers, name, flags, uint8(isolationLevel), newname, resetSession); err != nil { + return nil, err + } + case tmRollbackXact: + if err := sendRollbackXact(c.Conn.sess.buf, headers, name, flags, uint8(isolationLevel), newname, resetSession); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("invalid transaction manager request type: %d", rqtype) + } + + return c.processResponse(ctx, serverConn) +} + +func (c *Client) processResponse(ctx context.Context, sess *ServerSession) ([]doneStruct, error) { + c.Conn.sess.buf.serverConn = sess.tdsSession + + packet_type, err := c.Conn.sess.buf.BeginRead() + if err != nil { + switch e := err.(type) { + case *net.OpError: + return nil, e + default: + return nil, &net.OpError{Op: "Read", Err: err} + } + } + + if packet_type != packReply { + return nil, StreamError{ + InnerError: fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply), + } + } + + var dones []doneStruct + var columns []columnStruct + var errs []Error + for { + token := token(c.Conn.sess.buf.byte()) + switch token { + case tokenReturnStatus: + parseReturnStatus(c.Conn.sess.buf) + case tokenOrder: + parseOrder(c.Conn.sess.buf) + case tokenDone, tokenDoneProc, tokenDoneInProc: + res := parseDone(c.Conn.sess.buf) + res.errors = errs + dones = append(dones, res) + if res.Status&doneSrvError != 0 { + return dones, ServerError{res.getError()} + } + + if res.Status&doneMore == 0 { + return dones, nil + } + case tokenColMetadata: + columns = parseColMetadata72(c.Conn.sess.buf, c.Conn.sess) + case tokenRow: + row := make([]interface{}, len(columns)) + err = parseRow(ctx, c.Conn.sess.buf, c.Conn.sess, columns, row) + if err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to parse row: %w", err), + } + } + case tokenNbcRow: + row := make([]interface{}, len(columns)) + err = parseNbcRow(ctx, c.Conn.sess.buf, c.Conn.sess, columns, row) + if err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to parse row: %w", err), + } + } + case tokenEnvChange: + processEnvChg(ctx, c.Conn.sess) + case tokenError: + err := parseError72(c.Conn.sess.buf) + errs = append(errs, err) + case tokenInfo: + parseInfo(c.Conn.sess.buf) + case tokenReturnValue: + parseReturnValue(c.Conn.sess.buf, c.Conn.sess) + default: + return nil, StreamError{ + InnerError: fmt.Errorf("unknown token type returned: %v", token), + } + } + } +} + +func (d doneStruct) GetError() error { + n := len(d.errors) + if n == 0 { + return nil + } + + var err error + + for _, e := range d.errors { + err = errors.Join(err, e) + } + + return err +} + +func (c *Client) LoginEnvBytes() []byte { + return c.Conn.sess.loginEnvBytes +} + +func (c *Client) Database() string { + return c.Conn.sess.database +} + +func (c *Client) SendAttention(ctx context.Context, serverConn *ServerSession) ([]doneStruct, error) { + if err := sendAttention(c.Conn.sess.buf); err != nil { + return nil, err + } + + return c.processResponse(ctx, serverConn) +} diff --git a/tds.go b/tds.go index 9ddc2ce7..c7da8647 100644 --- a/tds.go +++ b/tds.go @@ -176,6 +176,7 @@ type tdsSession struct { connid UniqueIdentifier activityid UniqueIdentifier encoding msdsn.EncodeParameters + loginEnvBytes []byte } type alwaysEncryptedSettings struct { @@ -1044,6 +1045,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont } else { serverName = p.Host } + l = &login{ TDSVersion: TDSVersion, PacketSize: packetSize, @@ -1269,7 +1271,6 @@ initiate_connection: if err != nil { return nil, err } - err = sendLogin(outbuf, login) if err != nil { return nil, err diff --git a/token.go b/token.go index 8926ca58..5feb2f7f 100644 --- a/token.go +++ b/token.go @@ -144,15 +144,24 @@ type doneInProcStruct doneStruct // ENVCHANGE stream // http://msdn.microsoft.com/en-us/library/dd303449.aspx -func processEnvChg(ctx context.Context, sess *tdsSession) { +func processEnvChg(ctx context.Context, sess *tdsSession) []byte { size := sess.buf.uint16() - r := &io.LimitedReader{R: sess.buf, N: int64(size)} + rb := &io.LimitedReader{R: sess.buf, N: int64(size)} + + buf := new(bytes.Buffer) + _, err := io.Copy(buf, rb) + if err != nil { + badStreamPanic(err) + } + + r := bytes.NewReader(buf.Bytes()) + for { var err error var envtype uint8 err = binary.Read(r, binary.LittleEndian, &envtype) if err == io.EOF { - return + return buf.Bytes() } if err != nil { badStreamPanic(err) @@ -393,7 +402,7 @@ func processEnvChg(ctx context.Context, sess *tdsSession) { default: // ignore rest of records because we don't know how to skip those sess.LogF(ctx, msdsn.LogDebug, "WARN: Unknown ENVCHANGE record detected with type id = %d", envtype) - return + return buf.Bytes() } } } @@ -1085,7 +1094,9 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS } ch <- row case tokenEnvChange: - processEnvChg(ctx, sess) + tokenBytes := processEnvChg(ctx, sess) + sess.loginEnvBytes = append(sess.loginEnvBytes, []byte{byte(tokenEnvChange), byte(len(tokenBytes) & 0xFF), byte(len(tokenBytes) >> 8)}...) + sess.loginEnvBytes = append(sess.loginEnvBytes, tokenBytes...) case tokenError: err := parseError72(sess.buf) sess.LogF(ctx, msdsn.LogDebug, "got ERROR %d %s", err.Number, err.Message) @@ -1095,12 +1106,15 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: err}) } case tokenInfo: - info := parseInfo(sess.buf) - sess.LogF(ctx, msdsn.LogDebug, "got INFO %d %s", info.Number, info.Message) - sess.LogS(ctx, msdsn.LogMessages, info.Message) - if outs.msgq != nil { - _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info}) + length := sess.buf.uint16() + infoBytes := make([]byte, length) + _, err := sess.buf.Read(infoBytes) + if err != nil { + badStreamPanic(err) } + + sess.loginEnvBytes = append(sess.loginEnvBytes, []byte{byte(tokenInfo), byte(length & 0xFF), byte(length >> 8)}...) + sess.loginEnvBytes = append(sess.loginEnvBytes, infoBytes...) case tokenReturnValue: nv := parseReturnValue(sess.buf, sess) if len(nv.Name) > 0 { diff --git a/types.go b/types.go index 8e5c7dcc..30637167 100644 --- a/types.go +++ b/types.go @@ -516,6 +516,8 @@ func readShortLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} if size == 0xffff { return nil } + ti.Size = int(size) + ti.Buffer = make([]byte, ti.Size) r.ReadFull(ti.Buffer[:size]) buf := ti.Buffer[:size] switch ti.TypeId {