Skip to content

Commit 0b33bfe

Browse files
authored
transport: Discard the buffer when empty after http connect handshake (#7424)
* Discard the buffer when empty after http connect handshake * configure the proxy to wait for server hello * Extract test args to a struct * Change deadline sets
1 parent 566aad1 commit 0b33bfe

File tree

2 files changed

+81
-19
lines changed

2 files changed

+81
-19
lines changed

internal/transport/proxy.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,14 @@ func doHTTPConnectHandshake(ctx context.Context, conn net.Conn, backendAddr stri
107107
}
108108
return nil, fmt.Errorf("failed to do connect handshake, response: %q", dump)
109109
}
110-
111-
return &bufConn{Conn: conn, r: r}, nil
110+
// The buffer could contain extra bytes from the target server, so we can't
111+
// discard it. However, in many cases where the server waits for the client
112+
// to send the first message (e.g. when TLS is being used), the buffer will
113+
// be empty, so we can avoid the overhead of reading through this buffer.
114+
if r.Buffered() != 0 {
115+
return &bufConn{Conn: conn, r: r}, nil
116+
}
117+
return conn, nil
112118
}
113119

114120
// proxyDial dials, connecting to a proxy first if necessary. Checks if a proxy

internal/transport/proxy_test.go

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ package transport
2323

2424
import (
2525
"bufio"
26+
"bytes"
2627
"context"
2728
"encoding/base64"
2829
"fmt"
@@ -58,7 +59,7 @@ type proxyServer struct {
5859
requestCheck func(*http.Request) error
5960
}
6061

61-
func (p *proxyServer) run() {
62+
func (p *proxyServer) run(waitForServerHello bool) {
6263
in, err := p.lis.Accept()
6364
if err != nil {
6465
return
@@ -83,8 +84,26 @@ func (p *proxyServer) run() {
8384
p.t.Errorf("failed to dial to server: %v", err)
8485
return
8586
}
87+
out.SetDeadline(time.Now().Add(defaultTestTimeout))
8688
resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
87-
resp.Write(p.in)
89+
var buf bytes.Buffer
90+
resp.Write(&buf)
91+
if waitForServerHello {
92+
// Batch the first message from the server with the http connect
93+
// response. This is done to test the cases in which the grpc client has
94+
// the response to the connect request and proxied packets from the
95+
// destination server when it reads the transport.
96+
b := make([]byte, 50)
97+
bytesRead, err := out.Read(b)
98+
if err != nil {
99+
p.t.Errorf("Got error while reading server hello: %v", err)
100+
in.Close()
101+
out.Close()
102+
return
103+
}
104+
buf.Write(b[0:bytesRead])
105+
}
106+
p.in.Write(buf.Bytes())
88107
p.out = out
89108
go io.Copy(p.in, p.out)
90109
go io.Copy(p.out, p.in)
@@ -100,17 +119,23 @@ func (p *proxyServer) stop() {
100119
}
101120
}
102121

103-
func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
122+
type testArgs struct {
123+
proxyURLModify func(*url.URL) *url.URL
124+
proxyReqCheck func(*http.Request) error
125+
serverMessage []byte
126+
}
127+
128+
func testHTTPConnect(t *testing.T, args testArgs) {
104129
plis, err := net.Listen("tcp", "localhost:0")
105130
if err != nil {
106131
t.Fatalf("failed to listen: %v", err)
107132
}
108133
p := &proxyServer{
109134
t: t,
110135
lis: plis,
111-
requestCheck: proxyReqCheck,
136+
requestCheck: args.proxyReqCheck,
112137
}
113-
go p.run()
138+
go p.run(len(args.serverMessage) > 0)
114139
defer p.stop()
115140

116141
blis, err := net.Listen("tcp", "localhost:0")
@@ -128,13 +153,14 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
128153
return
129154
}
130155
defer in.Close()
156+
in.Write(args.serverMessage)
131157
in.Read(recvBuf)
132158
done <- nil
133159
}()
134160

135161
// Overwrite the function in the test and restore them in defer.
136162
hpfe := func(req *http.Request) (*url.URL, error) {
137-
return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
163+
return args.proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
138164
}
139165
defer overwrite(hpfe)()
140166

@@ -143,47 +169,76 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
143169
defer cancel()
144170
c, err := proxyDial(ctx, blis.Addr().String(), "test")
145171
if err != nil {
146-
t.Fatalf("http connect Dial failed: %v", err)
172+
t.Fatalf("HTTP connect Dial failed: %v", err)
147173
}
148174
defer c.Close()
175+
c.SetDeadline(time.Now().Add(defaultTestTimeout))
149176

150177
// Send msg on the connection.
151178
c.Write(msg)
152179
if err := <-done; err != nil {
153-
t.Fatalf("failed to accept: %v", err)
180+
t.Fatalf("Failed to accept: %v", err)
154181
}
155182

156183
// Check received msg.
157184
if string(recvBuf) != string(msg) {
158-
t.Fatalf("received msg: %v, want %v", recvBuf, msg)
185+
t.Fatalf("Received msg: %v, want %v", recvBuf, msg)
186+
}
187+
188+
if len(args.serverMessage) > 0 {
189+
gotServerMessage := make([]byte, len(args.serverMessage))
190+
if _, err := c.Read(gotServerMessage); err != nil {
191+
t.Errorf("Got error while reading message from server: %v", err)
192+
return
193+
}
194+
if string(gotServerMessage) != string(args.serverMessage) {
195+
t.Errorf("Message from server: %v, want %v", gotServerMessage, args.serverMessage)
196+
}
159197
}
160198
}
161199

162200
func (s) TestHTTPConnect(t *testing.T) {
163-
testHTTPConnect(t,
164-
func(in *url.URL) *url.URL {
201+
args := testArgs{
202+
proxyURLModify: func(in *url.URL) *url.URL {
165203
return in
166204
},
167-
func(req *http.Request) error {
205+
proxyReqCheck: func(req *http.Request) error {
168206
if req.Method != http.MethodConnect {
169207
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
170208
}
171209
return nil
172210
},
173-
)
211+
}
212+
testHTTPConnect(t, args)
213+
}
214+
215+
func (s) TestHTTPConnectWithServerHello(t *testing.T) {
216+
args := testArgs{
217+
proxyURLModify: func(in *url.URL) *url.URL {
218+
return in
219+
},
220+
proxyReqCheck: func(req *http.Request) error {
221+
if req.Method != http.MethodConnect {
222+
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
223+
}
224+
return nil
225+
},
226+
serverMessage: []byte("server-hello"),
227+
}
228+
testHTTPConnect(t, args)
174229
}
175230

176231
func (s) TestHTTPConnectBasicAuth(t *testing.T) {
177232
const (
178233
user = "notAUser"
179234
password = "notAPassword"
180235
)
181-
testHTTPConnect(t,
182-
func(in *url.URL) *url.URL {
236+
args := testArgs{
237+
proxyURLModify: func(in *url.URL) *url.URL {
183238
in.User = url.UserPassword(user, password)
184239
return in
185240
},
186-
func(req *http.Request) error {
241+
proxyReqCheck: func(req *http.Request) error {
187242
if req.Method != http.MethodConnect {
188243
return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
189244
}
@@ -195,7 +250,8 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) {
195250
}
196251
return nil
197252
},
198-
)
253+
}
254+
testHTTPConnect(t, args)
199255
}
200256

201257
func (s) TestMapAddressEnv(t *testing.T) {

0 commit comments

Comments
 (0)