From 49df8424acb4ea3382cf3bac095d6eeb42d42984 Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Fri, 15 Dec 2023 11:45:48 +0100 Subject: [PATCH 01/11] add proxy --- buf.go | 23 ++ proxy.go | 690 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 713 insertions(+) create mode 100644 proxy.go diff --git a/buf.go b/buf.go index 68663421..47cd625c 100644 --- a/buf.go +++ b/buf.go @@ -57,6 +57,8 @@ type tdsBuffer struct { // before the first use. It is executed after the first packet is // written and then removed. afterFirst func() + + serverConn *tdsSession } func newTdsBuffer(bufsize uint16, transport io.ReadWriteCloser) *tdsBuffer { @@ -185,10 +187,31 @@ func (r *tdsBuffer) readNextPacket() error { r.rsize = int(h.Size) r.final = h.Status != 0 r.rPacketType = h.PacketType + + if r.serverConn != nil { + _, err := r.serverConn.buf.Write(r.rbuf[r.rpos:r.rsize]) + if err != nil { + return err + } + + if r.final { + if err := r.serverConn.buf.FinishPacket(); err != nil { + return err + } + } else { + if err := r.serverConn.buf.flush(); err != nil { + return err + } + } + } return nil } func (r *tdsBuffer) BeginRead() (packetType, error) { + if r.serverConn != nil { + r.serverConn.buf.BeginPacket(r.rPacketType, false) + } + err := r.readNextPacket() if err != nil { return 0, err diff --git a/proxy.go b/proxy.go new file mode 100644 index 00000000..09ef021f --- /dev/null +++ b/proxy.go @@ -0,0 +1,690 @@ +package mssql + +import ( + "bytes" + "context" + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "time" +) + +const ( + defaultServerProgName = "GO MSSQL Server" + defaultServerVerion = "v16.0.0" +) + +type Client struct { + Conn *Conn +} + +type Server struct { + ConnTimeout time.Duration + PacketSize uint16 + Logger ContextLogger + Version uint32 + ProgName string + Encryption byte +} + +type ServerConfig struct { + ConnTimeout *time.Duration + PacketSize *uint16 + Logger ContextLogger + Version *string + Encryption *string + ProgName *string +} + +func NewServer(config ServerConfig) (*Server, error) { + server := &Server{} + + if config.PacketSize == nil { + server.PacketSize = defaultPacketSize + } else { + server.PacketSize = *config.PacketSize + } + // Ensure packet size falls within the TDS protocol range of 512 to 32767 bytes + // NOTE: Encrypted connections have a maximum size of 16383 bytes. If you request + // a higher packet size, the server will respond with an ENVCHANGE request to + // alter the packet size to 16383 bytes. + if server.PacketSize < 512 { + server.PacketSize = 512 + } else if server.PacketSize > 32767 { + server.PacketSize = 32767 + } + + if config.ConnTimeout != nil { + server.ConnTimeout = *config.ConnTimeout + } + + if config.Logger != nil { + server.Logger = config.Logger + } + + if config.Version != nil { + server.Version = getDriverVersion(*config.Version) + } else { + server.Version = getDriverVersion(defaultServerVerion) + } + + if config.ProgName != nil { + server.ProgName = *config.ProgName + } else { + server.ProgName = defaultServerProgName + } + + if config.Encryption != nil { + switch *config.Encryption { + case "strict": + server.Encryption = encryptStrict + case "required": + server.Encryption = encryptReq + case "on": + server.Encryption = encryptOn + case "off": + server.Encryption = encryptOff + default: + return nil, errors.New("invalid encryption option") + } + } else { + server.Encryption = encryptNotSup + } + + return server, nil +} + +func (s *Server) NewTdsServerSession(conn net.Conn) (*tdsSession, *login, error) { + toconn := newTimeoutConn(conn, s.ConnTimeout) + inbuf := newTdsBuffer(s.PacketSize, toconn) + + login, err := s.handshake(inbuf) + if err != nil { + return nil, nil, err + } + + sess := tdsSession{ + buf: inbuf, + logger: s.Logger, + } + + return &sess, &login, nil +} + +func (s *tdsSession) ReadCommand() (packetType, error) { + for { + _, err := s.buf.BeginRead() + if err != nil { + return 0, err + } + + if s.buf.final { + return s.buf.rPacketType, nil + } + } +} + +func (s *Server) handshake(r *tdsBuffer) (login, error) { + var login login + + err := s.readPrelogin(r) + if err != nil { + return login, err + } + + err = s.writePrelogin(r) + if err != nil { + return login, err + } + + login, err = s.readLogin(r) + if err != nil { + return login, err + } + + err = s.writeLogin(r) + if err != nil { + return login, err + } + + return login, nil +} + +func (s *Server) readPrelogin(r *tdsBuffer) error { + packet_type, err := r.BeginRead() + if err != nil { + return err + } + struct_buf, err := io.ReadAll(r) + if err != nil { + return err + } + if packet_type != packPrelogin { + return errors.New("invalid request, expected pre-login packet") + } + if len(struct_buf) == 0 { + return errors.New("invalid empty PRELOGIN request, it must contain at least one byte") + } + + offset := 0 + results := map[uint8][]byte{} + for { + // read prelogin option + plOption, err := readPreloginOption(struct_buf, offset) + if err != nil { + return err + } + + if plOption.token == preloginTERMINATOR { + break + } + + // read prelogin option data + value, err := readPreloginOptionData(plOption, struct_buf) + if err != nil { + return err + } + results[plOption.token] = value + + offset += preloginOptionSize + } + + return nil +} + +func (s *Server) writePrelogin(r *tdsBuffer) error { + if err := writePrelogin(packReply, r, s.preparePreloginResponseFields()); err != nil { + return err + } + + return nil +} + +func (s *Server) preparePreloginResponseFields() map[uint8][]byte { + fields := map[uint8][]byte{ + // 4 bytes for version and 2 bytes for minor version + preloginVERSION: {byte(s.Version), byte(s.Version >> 8), byte(s.Version >> 16), byte(s.Version >> 24), 0, 0}, + preloginENCRYPTION: {s.Encryption}, + preloginINSTOPT: {0}, + preloginTHREADID: {0, 0, 0, 0}, + preloginMARS: {0}, // MARS disabled + } + + return fields +} + +func (s *Server) readLogin(r *tdsBuffer) (login, error) { + var login login + + packet_type, err := r.BeginRead() + if err != nil { + return login, err + } + + if packet_type != packLogin7 { + return login, errors.New("invalid request, expected login packet") + } + + struct_buf, err := io.ReadAll(r) + if err != nil { + return login, err + } + + if len(struct_buf) == 0 { + return login, errors.New("invalid empty login request, it must contain at least one byte") + } + + var loginHeader loginHeader + if err := binary.Read(bytes.NewReader(struct_buf), binary.LittleEndian, &loginHeader); err != nil { + return login, fmt.Errorf("failed to read login packet: %w", err) + } + + login.TDSVersion = loginHeader.TDSVersion + login.ClientPID = loginHeader.ClientPID + login.ConnectionID = loginHeader.ConnectionID + login.OptionFlags1 = loginHeader.OptionFlags1 + login.OptionFlags2 = loginHeader.OptionFlags2 + login.TypeFlags = loginHeader.TypeFlags + login.OptionFlags3 = loginHeader.OptionFlags3 + login.ClientTimeZone = loginHeader.ClientTimeZone + login.ClientLCID = loginHeader.ClientLCID + login.ClientID = loginHeader.ClientID + + login.HostName, err = readLoginFieldString(struct_buf, loginHeader.HostNameOffset, loginHeader.HostNameLength) + if err != nil { + return login, fmt.Errorf("failed to read hostname: %w", err) + } + login.UserName, err = readLoginFieldString(struct_buf, loginHeader.UserNameOffset, loginHeader.UserNameLength) + if err != nil { + return login, fmt.Errorf("failed to read username: %w", err) + } + login.AppName, err = readLoginFieldString(struct_buf, loginHeader.AppNameOffset, loginHeader.AppNameLength) + if err != nil { + return login, fmt.Errorf("failed to read username: %w", err) + } + login.ServerName, err = readLoginFieldString(struct_buf, loginHeader.ServerNameOffset, loginHeader.ServerNameLength) + if err != nil { + return login, fmt.Errorf("failed to read servername: %w", err) + } + login.CtlIntName, err = readLoginFieldString(struct_buf, loginHeader.CtlIntNameOffset, loginHeader.CtlIntNameLength) + if err != nil { + return login, fmt.Errorf("failed to read servername: %w", err) + } + login.Language, err = readLoginFieldString(struct_buf, loginHeader.LanguageOffset, loginHeader.LanguageLength) + if err != nil { + return login, fmt.Errorf("failed to read servername: %w", err) + } + login.Database, err = readLoginFieldString(struct_buf, loginHeader.DatabaseOffset, loginHeader.DatabaseLength) + if err != nil { + return login, fmt.Errorf("failed to read servername: %w", err) + } + login.SSPI, err = readLoginFieldBytes(struct_buf, loginHeader.SSPIOffset, loginHeader.SSPILength) + if err != nil { + return login, fmt.Errorf("failed to read sspi: %w", err) + } + login.AtchDBFile, err = readLoginFieldString(struct_buf, loginHeader.AtchDBFileOffset, loginHeader.AtchDBFileLength) + if err != nil { + return login, fmt.Errorf("failed to read sspi: %w", err) + } + login.ChangePassword, err = readLoginFieldString(struct_buf, loginHeader.ChangePasswordOffset, loginHeader.ChangePasswordLength) + if err != nil { + return login, fmt.Errorf("failed to read sspi: %w", err) + } + + return login, nil +} + +func readLoginFieldString(b []byte, offset uint16, length uint16) (string, error) { + if len(b) < int(offset)+int(length)*2 { + return "", fmt.Errorf("invalid login packet, expected %d bytes, got %d", offset+length*2, len(b)) + } + + return ucs22str(b[offset : offset+length*2]) +} + +func readLoginFieldBytes(b []byte, offset uint16, length uint16) ([]byte, error) { + if len(b) < int(offset)+int(length)*2 { + return nil, fmt.Errorf("invalid login packet, expected %d bytes, got %d", offset+length*2, len(b)) + } + + return b[offset : offset+length*2], nil +} + +func (s *Server) writeLogin(r *tdsBuffer) error { + loginAckStruct := loginAckStruct{ + Interface: 1, + TDSVersion: verTDS74, + ProgName: s.ProgName, + ProgVer: s.Version, + } + + doneStruct := doneStruct{ + Status: 0, + CurCmd: 0, + RowCount: 0, + errors: []Error{}, + } + + r.BeginPacket(packReply, false) + r.Write(writeLoginAck(loginAckStruct)) + r.Write(writeDone(doneStruct)) + + return r.FinishPacket() +} + +func UCS2String(s []byte) (string, error) { + return ucs22str(s) +} + +func (c *Conn) Transport() io.ReadWriteCloser { + if c.sess == nil || c.sess.buf == nil { + return nil + } + + return c.sess.buf.transport +} + +func (c *Conn) Buffer() *tdsBuffer { + if c.sess == nil || c.sess.buf == nil { + return nil + } + + return c.sess.buf +} + +func (c *Conn) Session() *tdsSession { + return c.sess +} + +func (s *tdsSession) ParseHeader() (header, error) { + var h header + err := binary.Read(s.buf, binary.LittleEndian, &h) + if err != nil { + return header{}, err + } + + return h, nil +} + +func (s *tdsSession) ParseSQLBatch() ([]headerStruct, string, error) { + headers, err := readAllHeaders(s.buf) + if err != nil { + return nil, "", err + } + + query, err := readUcs2(s.buf, (s.buf.rsize-s.buf.rpos)/2) + if err != nil { + return nil, "", err + } + + return headers, query, nil +} + +func (s *tdsSession) ParseRPC() ([]headerStruct, procId, uint16, []param, []interface{}, error) { + headers, err := readAllHeaders(s.buf) + if err != nil { + return nil, procId{}, 0, nil, nil, err + } + + var nameLength uint16 + if err := binary.Read(s.buf, binary.LittleEndian, &nameLength); err != nil { + return nil, procId{}, 0, nil, nil, err + } + + var proc procId + var idswitch uint16 = 0xffff + if nameLength == idswitch { + if err := binary.Read(s.buf, binary.LittleEndian, &proc.id); err != nil { + return nil, procId{}, 0, nil, nil, err + } + } else { + proc.name, err = readUcs2(s.buf, int(nameLength)) + if err != nil { + return nil, procId{}, 0, nil, nil, err + } + } + + var flags uint16 + if err := binary.Read(s.buf, binary.LittleEndian, &flags); err != nil { + return nil, procId{}, 0, nil, nil, err + } + + params, values, err := parseParams(s.buf) + if err != nil { + return nil, procId{}, 0, nil, nil, err + } + + return headers, proc, flags, params, values, nil +} + +func parseParams(b *tdsBuffer) ([]param, []interface{}, error) { + var params []param + var values []interface{} + for { + if b.rpos >= b.rsize { + break + } + + var p param + + name, err := readBVarChar(b) + if err != nil { + return nil, nil, err + } + p.Name = name + + var flags uint8 + if err := binary.Read(b, binary.LittleEndian, &flags); err != nil { + return nil, nil, err + } + + p.Flags = flags + + p.ti = readTypeInfo(b, b.byte(), nil) + val := p.ti.Reader(&p.ti, b, nil) + p.buffer = p.ti.Buffer + + params = append(params, p) + values = append(values, val) + } + return params, values, nil +} + +func readAllHeaders(r io.Reader) ([]headerStruct, error) { + var totalLength uint32 + err := binary.Read(r, binary.LittleEndian, &totalLength) + if err != nil { + return nil, err + } + + if totalLength < 4 { + return nil, errors.New("invalid total length") + } + + var headers []headerStruct + remainingLength := totalLength - 4 // Subtracting the length of the totalLength field + + for remainingLength > 0 { + var headerLength uint32 + err = binary.Read(r, binary.LittleEndian, &headerLength) + if err != nil { + return nil, err + } + + if headerLength < 6 || headerLength-6 > remainingLength { + return nil, errors.New("invalid header length") + } + + var hdrtype uint16 + err = binary.Read(r, binary.LittleEndian, &hdrtype) + if err != nil { + return nil, err + } + + dataLength := headerLength - 6 // Subtracting the length of the headerLength and hdrtype fields + data := make([]byte, dataLength) + _, err = io.ReadFull(r, data) + if err != nil { + return nil, err + } + + headers = append(headers, headerStruct{ + hdrtype: hdrtype, + data: data, + }) + + remainingLength -= headerLength + } + + if remainingLength != 0 { + return nil, errors.New("inconsistent header length") + } + + return headers, nil +} + +func (p *procId) Id() uint16 { + return p.id +} + +func (p *procId) Name() string { + return p.name +} + +func writeDone(d doneStruct) []byte { + data := make([]byte, 0, 12) + + // Append tokenDone and the calculated size + data = append(data, byte(tokenDone)) + + // Append Status + statusBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(statusBytes, d.Status) + data = append(data, statusBytes...) + + // Append CurCmd + curCmdBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(curCmdBytes, d.CurCmd) + data = append(data, curCmdBytes...) + + // Append RowCount + rowCountBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(rowCountBytes, d.RowCount) + data = append(data, rowCountBytes...) + + return data +} + +func writeLoginAck(l loginAckStruct) []byte { + progNameUCS2 := str2ucs2(l.ProgName) + + // Prepare the slice with preallocated size for efficiency + data := make([]byte, 0, 10+len(progNameUCS2)) + + // Append tokenLoginAck + data = append(data, byte(tokenLoginAck)) + + // Append calculated size + size := uint16(10 + len(progNameUCS2)) + sizeBytes := make([]byte, 2) + binary.LittleEndian.PutUint16(sizeBytes, size) + data = append(data, sizeBytes...) + + // Append Interface + data = append(data, l.Interface) + + // Append TDSVersion + tdsVersionBytes := make([]byte, 4) + binary.BigEndian.PutUint32(tdsVersionBytes, l.TDSVersion) + data = append(data, tdsVersionBytes...) + + // Append ProgName Length and ProgName + data = append(data, byte(len(progNameUCS2)/2)) + data = append(data, progNameUCS2...) + + // Append ProgVer + progVerBytes := make([]byte, 4) + binary.BigEndian.PutUint32(progVerBytes, l.ProgVer) + data = append(data, progVerBytes...) + + return data +} + +func NewClient(ctx context.Context, dsn string) (*Client, error) { + c, err := NewConnector(dsn) + if err != nil { + return nil, err + } + + conn, err := c.Connect(ctx) + if err != nil { + return nil, err + } + + mssqlConn, ok := conn.(*Conn) + if !ok { + return nil, fmt.Errorf("invalid conn") + } + + return &Client{ + Conn: mssqlConn, + }, nil +} + +func (c *Client) Close() error { + return c.Conn.Close() +} + +func (c *Client) SendSqlBatch(ctx context.Context, serverConn *tdsSession, query string, headers []headerStruct, resetSession bool) ([]doneStruct, error) { + if err := sendSqlBatch72(c.Conn.sess.buf, query, headers, resetSession); err != nil { + return nil, err + } + + return c.processResponse(ctx, serverConn) +} + +func (c *Client) SendRpc(ctx context.Context, serverConn *tdsSession, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) ([]doneStruct, error) { + if err := sendRpc(c.Conn.sess.buf, headers, proc, flags, params, resetSession); err != nil { + return nil, err + } + + return c.processResponse(ctx, serverConn) +} + +func (c *Client) processResponse(ctx context.Context, sess *tdsSession) ([]doneStruct, error) { + c.Conn.sess.buf.serverConn = sess + + packet_type, err := c.Conn.sess.buf.BeginRead() + if err != nil { + switch e := err.(type) { + case *net.OpError: + return nil, e + default: + return nil, &net.OpError{Op: "Read", Err: err} + } + } + + if packet_type != packReply { + return nil, StreamError{ + InnerError: fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply), + } + } + + var dones []doneStruct + var columns []columnStruct + var errs []Error + for { + token := token(c.Conn.sess.buf.byte()) + switch token { + case tokenReturnStatus: + parseReturnStatus(c.Conn.sess.buf) + case tokenOrder: + parseOrder(c.Conn.sess.buf) + case tokenDone, tokenDoneProc, tokenDoneInProc: + res := parseDone(c.Conn.sess.buf) + res.errors = errs + dones = append(dones, res) + if res.Status&doneSrvError != 0 { + return dones, ServerError{res.getError()} + } + + if res.Status&doneMore == 0 { + return dones, nil + } + case tokenColMetadata: + columns = parseColMetadata72(c.Conn.sess.buf, sess) + case tokenRow: + row := make([]interface{}, len(columns)) + err = parseRow(ctx, c.Conn.sess.buf, c.Conn.sess, columns, row) + if err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to parse row: %w", err), + } + } + case tokenNbcRow: + row := make([]interface{}, len(columns)) + err = parseNbcRow(ctx, c.Conn.sess.buf, c.Conn.sess, columns, row) + if err != nil { + return nil, StreamError{ + InnerError: fmt.Errorf("failed to parse row: %w", err), + } + } + case tokenEnvChange: + processEnvChg(ctx, c.Conn.sess) + case tokenError: + err := parseError72(c.Conn.sess.buf) + errs = append(errs, err) + case tokenInfo: + info := parseInfo(c.Conn.sess.buf) + fmt.Printf("got INFO %d %s\n", info.Number, info.Message) + case tokenReturnValue: + parseReturnValue(c.Conn.sess.buf, c.Conn.sess) + default: + return nil, StreamError{ + InnerError: fmt.Errorf("unknown token type returned: %v", token), + } + } + } +} From e5028045008986e2ef09cd65f32d71ac4055abff Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Fri, 15 Dec 2023 13:27:39 +0100 Subject: [PATCH 02/11] use msdsn config --- proxy.go | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/proxy.go b/proxy.go index 09ef021f..8a36e74b 100644 --- a/proxy.go +++ b/proxy.go @@ -9,6 +9,8 @@ import ( "io" "net" "time" + + "github.com/microsoft/go-mssqldb/msdsn" ) const ( @@ -572,11 +574,8 @@ func writeLoginAck(l loginAckStruct) []byte { return data } -func NewClient(ctx context.Context, dsn string) (*Client, error) { - c, err := NewConnector(dsn) - if err != nil { - return nil, err - } +func NewClient(ctx context.Context, params msdsn.Config) (*Client, error) { + c := newConnector(params, driverInstanceNoProcess) conn, err := c.Connect(ctx) if err != nil { From 69d8747064a82f15b2c8b7395794e721df4d10cf Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Fri, 15 Dec 2023 13:50:02 +0100 Subject: [PATCH 03/11] add server session --- proxy.go | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/proxy.go b/proxy.go index 8a36e74b..770e38f2 100644 --- a/proxy.go +++ b/proxy.go @@ -40,6 +40,10 @@ type ServerConfig struct { ProgName *string } +type ServerSession struct { + *tdsSession +} + func NewServer(config ServerConfig) (*Server, error) { server := &Server{} @@ -98,7 +102,7 @@ func NewServer(config ServerConfig) (*Server, error) { return server, nil } -func (s *Server) NewTdsServerSession(conn net.Conn) (*tdsSession, *login, error) { +func (s *Server) NewTdsServerSession(conn net.Conn) (*ServerSession, *login, error) { toconn := newTimeoutConn(conn, s.ConnTimeout) inbuf := newTdsBuffer(s.PacketSize, toconn) @@ -107,10 +111,10 @@ func (s *Server) NewTdsServerSession(conn net.Conn) (*tdsSession, *login, error) return nil, nil, err } - sess := tdsSession{ + sess := ServerSession{&tdsSession{ buf: inbuf, logger: s.Logger, - } + }} return &sess, &login, nil } @@ -596,7 +600,7 @@ func (c *Client) Close() error { return c.Conn.Close() } -func (c *Client) SendSqlBatch(ctx context.Context, serverConn *tdsSession, query string, headers []headerStruct, resetSession bool) ([]doneStruct, error) { +func (c *Client) SendSqlBatch(ctx context.Context, serverConn *ServerSession, query string, headers []headerStruct, resetSession bool) ([]doneStruct, error) { if err := sendSqlBatch72(c.Conn.sess.buf, query, headers, resetSession); err != nil { return nil, err } @@ -604,7 +608,7 @@ func (c *Client) SendSqlBatch(ctx context.Context, serverConn *tdsSession, query return c.processResponse(ctx, serverConn) } -func (c *Client) SendRpc(ctx context.Context, serverConn *tdsSession, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) ([]doneStruct, error) { +func (c *Client) SendRpc(ctx context.Context, serverConn *ServerSession, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) ([]doneStruct, error) { if err := sendRpc(c.Conn.sess.buf, headers, proc, flags, params, resetSession); err != nil { return nil, err } @@ -612,8 +616,8 @@ func (c *Client) SendRpc(ctx context.Context, serverConn *tdsSession, headers [] return c.processResponse(ctx, serverConn) } -func (c *Client) processResponse(ctx context.Context, sess *tdsSession) ([]doneStruct, error) { - c.Conn.sess.buf.serverConn = sess +func (c *Client) processResponse(ctx context.Context, sess *ServerSession) ([]doneStruct, error) { + c.Conn.sess.buf.serverConn = sess.tdsSession packet_type, err := c.Conn.sess.buf.BeginRead() if err != nil { @@ -653,7 +657,7 @@ func (c *Client) processResponse(ctx context.Context, sess *tdsSession) ([]doneS return dones, nil } case tokenColMetadata: - columns = parseColMetadata72(c.Conn.sess.buf, sess) + columns = parseColMetadata72(c.Conn.sess.buf, c.Conn.sess) case tokenRow: row := make([]interface{}, len(columns)) err = parseRow(ctx, c.Conn.sess.buf, c.Conn.sess, columns, row) From 41cb1b5aa55981c1353dc9eb88f32aa9e9f67d23 Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Fri, 15 Dec 2023 13:59:38 +0100 Subject: [PATCH 04/11] test --- proxy.go | 1 + 1 file changed, 1 insertion(+) diff --git a/proxy.go b/proxy.go index 770e38f2..761c6eac 100644 --- a/proxy.go +++ b/proxy.go @@ -691,3 +691,4 @@ func (c *Client) processResponse(ctx context.Context, sess *ServerSession) ([]do } } } + From b0d9e55d200b89935edc20a68b058bd540144065 Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Fri, 15 Dec 2023 15:40:33 +0100 Subject: [PATCH 05/11] publish errors --- proxy.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/proxy.go b/proxy.go index 761c6eac..d3913b80 100644 --- a/proxy.go +++ b/proxy.go @@ -692,3 +692,17 @@ func (c *Client) processResponse(ctx context.Context, sess *ServerSession) ([]do } } +func (d doneStruct) GetError() error { + n := len(d.errors) + if n == 0 { + return nil + } + + var err error + + for _, e := range d.errors { + err = errors.Join(err, e) + } + + return err +} From a517545d1c9f0a3dd1e944dc378e516b0c5366a3 Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Tue, 19 Dec 2023 10:59:49 +0100 Subject: [PATCH 06/11] proxy changes --- proxy.go | 29 +++++++++++++++++------------ tds.go | 1 + token.go | 32 +++++++++++++++++++++++++++----- types.go | 2 ++ 4 files changed, 47 insertions(+), 17 deletions(-) diff --git a/proxy.go b/proxy.go index d3913b80..6b8a734d 100644 --- a/proxy.go +++ b/proxy.go @@ -102,7 +102,7 @@ func NewServer(config ServerConfig) (*Server, error) { return server, nil } -func (s *Server) NewTdsServerSession(conn net.Conn) (*ServerSession, *login, error) { +func (s *Server) ReadLogin(conn net.Conn) (*ServerSession, *login, error) { toconn := newTimeoutConn(conn, s.ConnTimeout) inbuf := newTdsBuffer(s.PacketSize, toconn) @@ -150,11 +150,6 @@ func (s *Server) handshake(r *tdsBuffer) (login, error) { return login, err } - err = s.writeLogin(r) - if err != nil { - return login, err - } - return login, nil } @@ -211,7 +206,7 @@ func (s *Server) writePrelogin(r *tdsBuffer) error { func (s *Server) preparePreloginResponseFields() map[uint8][]byte { fields := map[uint8][]byte{ // 4 bytes for version and 2 bytes for minor version - preloginVERSION: {byte(s.Version), byte(s.Version >> 8), byte(s.Version >> 16), byte(s.Version >> 24), 0, 0}, + preloginVERSION: {byte(s.Version >> 24), byte(s.Version >> 16), byte(s.Version >> 8), byte(s.Version), 0, 0}, preloginENCRYPTION: {s.Encryption}, preloginINSTOPT: {0}, preloginTHREADID: {0, 0, 0, 0}, @@ -318,7 +313,7 @@ func readLoginFieldBytes(b []byte, offset uint16, length uint16) ([]byte, error) return b[offset : offset+length*2], nil } -func (s *Server) writeLogin(r *tdsBuffer) error { +func (s *Server) WriteLogin(session *ServerSession, loginEnvBytes []byte) error { loginAckStruct := loginAckStruct{ Interface: 1, TDSVersion: verTDS74, @@ -333,11 +328,12 @@ func (s *Server) writeLogin(r *tdsBuffer) error { errors: []Error{}, } - r.BeginPacket(packReply, false) - r.Write(writeLoginAck(loginAckStruct)) - r.Write(writeDone(doneStruct)) + session.buf.BeginPacket(packReply, false) + session.buf.Write(loginEnvBytes) + session.buf.Write(writeLoginAck(loginAckStruct)) + session.buf.Write(writeDone(doneStruct)) - return r.FinishPacket() + return session.buf.FinishPacket() } func UCS2String(s []byte) (string, error) { @@ -682,6 +678,7 @@ func (c *Client) processResponse(ctx context.Context, sess *ServerSession) ([]do case tokenInfo: info := parseInfo(c.Conn.sess.buf) fmt.Printf("got INFO %d %s\n", info.Number, info.Message) + case tokenReturnValue: parseReturnValue(c.Conn.sess.buf, c.Conn.sess) default: @@ -706,3 +703,11 @@ func (d doneStruct) GetError() error { return err } + +func (c *Client) LoginEnvBytes() []byte { + return c.Conn.sess.loginEnvBytes +} + +func (c *Client) Database() string { + return c.Conn.sess.database +} diff --git a/tds.go b/tds.go index 9ddc2ce7..e6b88881 100644 --- a/tds.go +++ b/tds.go @@ -176,6 +176,7 @@ type tdsSession struct { connid UniqueIdentifier activityid UniqueIdentifier encoding msdsn.EncodeParameters + loginEnvBytes []byte } type alwaysEncryptedSettings struct { diff --git a/token.go b/token.go index 8926ca58..710e7886 100644 --- a/token.go +++ b/token.go @@ -144,15 +144,24 @@ type doneInProcStruct doneStruct // ENVCHANGE stream // http://msdn.microsoft.com/en-us/library/dd303449.aspx -func processEnvChg(ctx context.Context, sess *tdsSession) { +func processEnvChg(ctx context.Context, sess *tdsSession) []byte { size := sess.buf.uint16() - r := &io.LimitedReader{R: sess.buf, N: int64(size)} + rb := &io.LimitedReader{R: sess.buf, N: int64(size)} + + buf := new(bytes.Buffer) + _, err := io.Copy(buf, rb) + if err != nil { + badStreamPanic(err) + } + + r := bytes.NewReader(buf.Bytes()) + for { var err error var envtype uint8 err = binary.Read(r, binary.LittleEndian, &envtype) if err == io.EOF { - return + return buf.Bytes() } if err != nil { badStreamPanic(err) @@ -393,7 +402,7 @@ func processEnvChg(ctx context.Context, sess *tdsSession) { default: // ignore rest of records because we don't know how to skip those sess.LogF(ctx, msdsn.LogDebug, "WARN: Unknown ENVCHANGE record detected with type id = %d", envtype) - return + return buf.Bytes() } } } @@ -1085,7 +1094,9 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS } ch <- row case tokenEnvChange: - processEnvChg(ctx, sess) + tokenBytes := processEnvChg(ctx, sess) + sess.loginEnvBytes = append(sess.loginEnvBytes, []byte{byte(tokenEnvChange), byte(len(tokenBytes) & 0xFF), byte(len(tokenBytes) >> 8)}...) + sess.loginEnvBytes = append(sess.loginEnvBytes, tokenBytes...) case tokenError: err := parseError72(sess.buf) sess.LogF(ctx, msdsn.LogDebug, "got ERROR %d %s", err.Number, err.Message) @@ -1095,12 +1106,23 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: err}) } case tokenInfo: +<<<<<<< HEAD info := parseInfo(sess.buf) sess.LogF(ctx, msdsn.LogDebug, "got INFO %d %s", info.Number, info.Message) sess.LogS(ctx, msdsn.LogMessages, info.Message) if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info}) +======= + length := sess.buf.uint16() + infoBytes := make([]byte, length) + _, err := sess.buf.Read(infoBytes) + if err != nil { + badStreamPanic(err) +>>>>>>> 1158ce2 (proxy changes) } + + sess.loginEnvBytes = append(sess.loginEnvBytes, []byte{byte(tokenInfo), byte(length & 0xFF), byte(length >> 8)}...) + sess.loginEnvBytes = append(sess.loginEnvBytes, infoBytes...) case tokenReturnValue: nv := parseReturnValue(sess.buf, sess) if len(nv.Name) > 0 { diff --git a/types.go b/types.go index 8e5c7dcc..30637167 100644 --- a/types.go +++ b/types.go @@ -516,6 +516,8 @@ func readShortLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} if size == 0xffff { return nil } + ti.Size = int(size) + ti.Buffer = make([]byte, ti.Size) r.ReadFull(ti.Buffer[:size]) buf := ti.Buffer[:size] switch ti.TypeId { From efe52c0bb290475a3f9bd1320748d193840e61e2 Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Tue, 19 Dec 2023 11:03:03 +0100 Subject: [PATCH 07/11] remove debug --- proxy.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/proxy.go b/proxy.go index 6b8a734d..0c2ccbd5 100644 --- a/proxy.go +++ b/proxy.go @@ -676,9 +676,7 @@ func (c *Client) processResponse(ctx context.Context, sess *ServerSession) ([]do err := parseError72(c.Conn.sess.buf) errs = append(errs, err) case tokenInfo: - info := parseInfo(c.Conn.sess.buf) - fmt.Printf("got INFO %d %s\n", info.Number, info.Message) - + parseInfo(c.Conn.sess.buf) case tokenReturnValue: parseReturnValue(c.Conn.sess.buf, c.Conn.sess) default: From 1f4f890f4b38ffff1c9445e353d5ee5726d4065b Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Thu, 28 Dec 2023 11:14:53 +0100 Subject: [PATCH 08/11] fix types --- buf.go | 9 +++- examples/simple/simple.go | 18 ++++--- proxy.go | 101 +++++++++++++++++++++++++++++++++++--- tds.go | 2 +- 4 files changed, 114 insertions(+), 16 deletions(-) diff --git a/buf.go b/buf.go index 47cd625c..ceecd2a7 100644 --- a/buf.go +++ b/buf.go @@ -11,7 +11,7 @@ type packetType uint8 type header struct { PacketType packetType - Status uint8 + Status byte Size uint16 Spid uint16 PacketNo uint8 @@ -45,6 +45,7 @@ type tdsBuffer struct { wpos int wPacketSeq byte wPacketType packetType + wSpid uint16 // Read fields. rbuf []byte @@ -88,6 +89,7 @@ func (w *tdsBuffer) flush() (err error) { // Write packet size. w.wbuf[0] = byte(w.wPacketType) binary.BigEndian.PutUint16(w.wbuf[2:], uint16(w.wpos)) + binary.BigEndian.PutUint16(w.wbuf[4:], w.wSpid) w.wbuf[6] = w.wPacketSeq // Write packet into underlying transport. @@ -171,12 +173,14 @@ func (r *tdsBuffer) readNextPacket() error { PacketNo: buf[6], Pad: buf[7], } + if int(h.Size) > r.packetSize { return errors.New("invalid packet size, it is longer than buffer size") } if headerSize > int(h.Size) { return errors.New("invalid packet size, it is shorter than header size") } + _, err = io.ReadFull(r.transport, r.rbuf[headerSize:h.Size]) //s := base64.StdEncoding.EncodeToString(r.rbuf[headerSize:h.Size]) //fmt.Print(s) @@ -185,7 +189,7 @@ func (r *tdsBuffer) readNextPacket() error { } r.rpos = headerSize r.rsize = int(h.Size) - r.final = h.Status != 0 + r.final = h.Status&0x1 != 0 r.rPacketType = h.PacketType if r.serverConn != nil { @@ -195,6 +199,7 @@ func (r *tdsBuffer) readNextPacket() error { } if r.final { + r.serverConn.buf.wSpid = h.Spid if err := r.serverConn.buf.FinishPacket(); err != nil { return err } diff --git a/examples/simple/simple.go b/examples/simple/simple.go index 7cb5cfcc..faccef2d 100644 --- a/examples/simple/simple.go +++ b/examples/simple/simple.go @@ -5,6 +5,7 @@ import ( "flag" "fmt" "log" + "strings" _ "github.com/microsoft/go-mssqldb" ) @@ -37,21 +38,24 @@ func main() { } defer conn.Close() - stmt, err := conn.Prepare("select 1, 'abc'") + stmt, err := conn.Prepare("EXEC UpdateUserComments @UserID=?, @Comment=?") + if err != nil { + log.Fatal("Error preparing statement: ", err.Error()) + } + defer stmt.Close() + + //comment := "Your long comment text here..." + comment := strings.Repeat("A long comment text ", 500) // Simulate a long text + + _, err = stmt.Exec(10, comment) if err != nil { log.Fatal("Prepare failed:", err.Error()) } defer stmt.Close() - row := stmt.QueryRow() - var somenumber int64 - var somechars string - err = row.Scan(&somenumber, &somechars) if err != nil { log.Fatal("Scan failed:", err.Error()) } - fmt.Printf("somenumber:%d\n", somenumber) - fmt.Printf("somechars:%s\n", somechars) fmt.Printf("bye\n") } diff --git a/proxy.go b/proxy.go index 0c2ccbd5..cb5de66f 100644 --- a/proxy.go +++ b/proxy.go @@ -120,13 +120,21 @@ func (s *Server) ReadLogin(conn net.Conn) (*ServerSession, *login, error) { } func (s *tdsSession) ReadCommand() (packetType, error) { + var buf []byte for { _, err := s.buf.BeginRead() if err != nil { return 0, err } + bytes := make([]byte, s.buf.rsize-s.buf.rpos) + s.buf.ReadFull(bytes) + buf = append(buf, bytes...) + if s.buf.final { + copy(s.buf.rbuf, buf) + s.buf.rsize = len(buf) + s.buf.rpos = 0 return s.buf.rPacketType, nil } } @@ -218,7 +226,6 @@ func (s *Server) preparePreloginResponseFields() map[uint8][]byte { func (s *Server) readLogin(r *tdsBuffer) (login, error) { var login login - packet_type, err := r.BeginRead() if err != nil { return login, err @@ -243,6 +250,7 @@ func (s *Server) readLogin(r *tdsBuffer) (login, error) { } login.TDSVersion = loginHeader.TDSVersion + login.ClientProgVer = loginHeader.ClientProgVer login.ClientPID = loginHeader.ClientPID login.ConnectionID = loginHeader.ConnectionID login.OptionFlags1 = loginHeader.OptionFlags1 @@ -306,11 +314,11 @@ func readLoginFieldString(b []byte, offset uint16, length uint16) (string, error } func readLoginFieldBytes(b []byte, offset uint16, length uint16) ([]byte, error) { - if len(b) < int(offset)+int(length)*2 { - return nil, fmt.Errorf("invalid login packet, expected %d bytes, got %d", offset+length*2, len(b)) + if len(b) < int(offset)+int(length) { + return nil, fmt.Errorf("invalid login packet, expected %d bytes, got %d", offset+length, len(b)) } - return b[offset : offset+length*2], nil + return b[offset : offset+length], nil } func (s *Server) WriteLogin(session *ServerSession, loginEnvBytes []byte) error { @@ -384,6 +392,60 @@ func (s *tdsSession) ParseSQLBatch() ([]headerStruct, string, error) { return headers, query, nil } +func (s *tdsSession) ParseTransMgrReq() ([]headerStruct, uint16, isoLevel, string, string, uint8, error) { + headers, err := readAllHeaders(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + + var rqtype uint16 + if err := binary.Read(s.buf, binary.LittleEndian, &rqtype); err != nil { + return nil, 0, 0, "", "", 0, err + } + + switch rqtype { + case tmBeginXact: + var isolationLevel isoLevel + if err := binary.Read(s.buf, binary.LittleEndian, &isolationLevel); err != nil { + return nil, 0, 0, "", "", 0, err + } + + name, err := readBVarChar(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + + return headers, rqtype, isolationLevel, name, "", 0, nil + case tmCommitXact, tmRollbackXact: + name, err := readBVarChar(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + + var flags uint8 + if err := binary.Read(s.buf, binary.LittleEndian, &flags); err != nil { + return nil, 0, 0, "", "", 0, err + } + + var newname string + if flags&fBeginXact != 0 { + var isolationLevel isoLevel + if err := binary.Read(s.buf, binary.LittleEndian, &isolationLevel); err != nil { + return nil, 0, 0, "", "", 0, err + } + + newname, err = readBVarChar(s.buf) + if err != nil { + return nil, 0, 0, "", "", 0, err + } + } + + return headers, rqtype, 0, name, newname, flags, nil + default: + return nil, 0, 0, "", "", 0, fmt.Errorf("invalid transaction manager request type: %d", rqtype) + } +} + func (s *tdsSession) ParseRPC() ([]headerStruct, procId, uint16, []param, []interface{}, error) { headers, err := readAllHeaders(s.buf) if err != nil { @@ -443,11 +505,9 @@ func parseParams(b *tdsBuffer) ([]param, []interface{}, error) { } p.Flags = flags - p.ti = readTypeInfo(b, b.byte(), nil) val := p.ti.Reader(&p.ti, b, nil) p.buffer = p.ti.Buffer - params = append(params, p) values = append(values, val) } @@ -612,6 +672,27 @@ func (c *Client) SendRpc(ctx context.Context, serverConn *ServerSession, headers return c.processResponse(ctx, serverConn) } +func (c *Client) TransMgrReq(ctx context.Context, serverConn *ServerSession, headers []headerStruct, rqtype uint16, isolationLevel isoLevel, name, newname string, flags uint8, resetSession bool) ([]doneStruct, error) { + switch rqtype { + case tmBeginXact: + if err := sendBeginXact(c.Conn.sess.buf, headers, isolationLevel, name, resetSession); err != nil { + return nil, err + } + case tmCommitXact: + if err := sendCommitXact(c.Conn.sess.buf, headers, name, flags, uint8(isolationLevel), newname, resetSession); err != nil { + return nil, err + } + case tmRollbackXact: + if err := sendRollbackXact(c.Conn.sess.buf, headers, name, flags, uint8(isolationLevel), newname, resetSession); err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("invalid transaction manager request type: %d", rqtype) + } + + return c.processResponse(ctx, serverConn) +} + func (c *Client) processResponse(ctx context.Context, sess *ServerSession) ([]doneStruct, error) { c.Conn.sess.buf.serverConn = sess.tdsSession @@ -709,3 +790,11 @@ func (c *Client) LoginEnvBytes() []byte { func (c *Client) Database() string { return c.Conn.sess.database } + +func (c *Client) SendAttention(ctx context.Context, serverConn *ServerSession) ([]doneStruct, error) { + if err := sendAttention(c.Conn.sess.buf); err != nil { + return nil, err + } + + return c.processResponse(ctx, serverConn) +} diff --git a/tds.go b/tds.go index e6b88881..c7da8647 100644 --- a/tds.go +++ b/tds.go @@ -1045,6 +1045,7 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont } else { serverName = p.Host } + l = &login{ TDSVersion: TDSVersion, PacketSize: packetSize, @@ -1270,7 +1271,6 @@ initiate_connection: if err != nil { return nil, err } - err = sendLogin(outbuf, login) if err != nil { return nil, err From 020decf2ca1976ef260a393f415438adc9992fbf Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Wed, 3 Jan 2024 14:53:28 +0100 Subject: [PATCH 09/11] add support for gcp cloudsql proxy and aws rds proxy --- proxy.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/proxy.go b/proxy.go index cb5de66f..7136a589 100644 --- a/proxy.go +++ b/proxy.go @@ -634,9 +634,19 @@ func writeLoginAck(l loginAckStruct) []byte { return data } -func NewClient(ctx context.Context, params msdsn.Config) (*Client, error) { +func NewClient(ctx context.Context, params msdsn.Config, tokenProvider func(ctx context.Context) (string, error), dialer Dialer) (*Client, error) { c := newConnector(params, driverInstanceNoProcess) + if tokenProvider != nil { + c.fedAuthRequired = true + c.fedAuthLibrary = FedAuthLibrarySecurityToken + c.securityTokenProvider = tokenProvider + } + + if dialer != nil { + c.Dialer = dialer + } + conn, err := c.Connect(ctx) if err != nil { return nil, err From edc2f5ae60611d3cf7426accfd0856f8bddb9558 Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Mon, 29 Jan 2024 10:31:58 +0100 Subject: [PATCH 10/11] support for azuread --- proxy.go | 24 +++++++++++------------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/proxy.go b/proxy.go index 7136a589..3a96af3d 100644 --- a/proxy.go +++ b/proxy.go @@ -634,31 +634,29 @@ func writeLoginAck(l loginAckStruct) []byte { return data } -func NewClient(ctx context.Context, params msdsn.Config, tokenProvider func(ctx context.Context) (string, error), dialer Dialer) (*Client, error) { - c := newConnector(params, driverInstanceNoProcess) - - if tokenProvider != nil { - c.fedAuthRequired = true - c.fedAuthLibrary = FedAuthLibrarySecurityToken - c.securityTokenProvider = tokenProvider - } +func NewConnectorFromConfig(config msdsn.Config) *Connector { + return newConnector(config, driverInstanceNoProcess) +} +func NewClient(ctx context.Context, c *Connector, dialer Dialer, database string) (*Client, error) { if dialer != nil { c.Dialer = dialer } - conn, err := c.Connect(ctx) + params := c.params + params.Database = database + + conn, err := c.driver.connect(ctx, c, c.params) if err != nil { return nil, err } - mssqlConn, ok := conn.(*Conn) - if !ok { - return nil, fmt.Errorf("invalid conn") + if err := conn.ResetSession(ctx); err != nil { + return nil, err } return &Client{ - Conn: mssqlConn, + Conn: conn, }, nil } From 8e5aa44c57181f4bd27af52920bdcd730395272b Mon Sep 17 00:00:00 2001 From: Bas Toonk Date: Tue, 22 Apr 2025 15:21:41 +0200 Subject: [PATCH 11/11] update upstream --- proxy.go | 8 ++++---- token.go | 8 -------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/proxy.go b/proxy.go index 3a96af3d..5c8d3cfa 100644 --- a/proxy.go +++ b/proxy.go @@ -475,7 +475,7 @@ func (s *tdsSession) ParseRPC() ([]headerStruct, procId, uint16, []param, []inte return nil, procId{}, 0, nil, nil, err } - params, values, err := parseParams(s.buf) + params, values, err := parseParams(s.buf, s.encoding) if err != nil { return nil, procId{}, 0, nil, nil, err } @@ -483,7 +483,7 @@ func (s *tdsSession) ParseRPC() ([]headerStruct, procId, uint16, []param, []inte return headers, proc, flags, params, values, nil } -func parseParams(b *tdsBuffer) ([]param, []interface{}, error) { +func parseParams(b *tdsBuffer, encoding msdsn.EncodeParameters) ([]param, []interface{}, error) { var params []param var values []interface{} for { @@ -505,7 +505,7 @@ func parseParams(b *tdsBuffer) ([]param, []interface{}, error) { } p.Flags = flags - p.ti = readTypeInfo(b, b.byte(), nil) + p.ti = readTypeInfo(b, b.byte(), nil, encoding) val := p.ti.Reader(&p.ti, b, nil) p.buffer = p.ti.Buffer params = append(params, p) @@ -673,7 +673,7 @@ func (c *Client) SendSqlBatch(ctx context.Context, serverConn *ServerSession, qu } func (c *Client) SendRpc(ctx context.Context, serverConn *ServerSession, headers []headerStruct, proc procId, flags uint16, params []param, resetSession bool) ([]doneStruct, error) { - if err := sendRpc(c.Conn.sess.buf, headers, proc, flags, params, resetSession); err != nil { + if err := sendRpc(c.Conn.sess.buf, headers, proc, flags, params, resetSession, c.Conn.sess.encoding); err != nil { return nil, err } diff --git a/token.go b/token.go index 710e7886..5feb2f7f 100644 --- a/token.go +++ b/token.go @@ -1106,19 +1106,11 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: err}) } case tokenInfo: -<<<<<<< HEAD - info := parseInfo(sess.buf) - sess.LogF(ctx, msdsn.LogDebug, "got INFO %d %s", info.Number, info.Message) - sess.LogS(ctx, msdsn.LogMessages, info.Message) - if outs.msgq != nil { - _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info}) -======= length := sess.buf.uint16() infoBytes := make([]byte, length) _, err := sess.buf.Read(infoBytes) if err != nil { badStreamPanic(err) ->>>>>>> 1158ce2 (proxy changes) } sess.loginEnvBytes = append(sess.loginEnvBytes, []byte{byte(tokenInfo), byte(length & 0xFF), byte(length >> 8)}...)