@@ -23,6 +23,7 @@ package transport
2323
2424import (
2525 "bufio"
26+ "bytes"
2627 "context"
2728 "encoding/base64"
2829 "fmt"
@@ -58,7 +59,7 @@ type proxyServer struct {
5859 requestCheck func (* http.Request ) error
5960}
6061
61- func (p * proxyServer ) run () {
62+ func (p * proxyServer ) run (waitForServerHello bool ) {
6263 in , err := p .lis .Accept ()
6364 if err != nil {
6465 return
@@ -83,8 +84,26 @@ func (p *proxyServer) run() {
8384 p .t .Errorf ("failed to dial to server: %v" , err )
8485 return
8586 }
87+ out .SetDeadline (time .Now ().Add (defaultTestTimeout ))
8688 resp := http.Response {StatusCode : http .StatusOK , Proto : "HTTP/1.0" }
87- resp .Write (p .in )
89+ var buf bytes.Buffer
90+ resp .Write (& buf )
91+ if waitForServerHello {
92+ // Batch the first message from the server with the http connect
93+ // response. This is done to test the cases in which the grpc client has
94+ // the response to the connect request and proxied packets from the
95+ // destination server when it reads the transport.
96+ b := make ([]byte , 50 )
97+ bytesRead , err := out .Read (b )
98+ if err != nil {
99+ p .t .Errorf ("Got error while reading server hello: %v" , err )
100+ in .Close ()
101+ out .Close ()
102+ return
103+ }
104+ buf .Write (b [0 :bytesRead ])
105+ }
106+ p .in .Write (buf .Bytes ())
88107 p .out = out
89108 go io .Copy (p .in , p .out )
90109 go io .Copy (p .out , p .in )
@@ -100,17 +119,23 @@ func (p *proxyServer) stop() {
100119 }
101120}
102121
103- func testHTTPConnect (t * testing.T , proxyURLModify func (* url.URL ) * url.URL , proxyReqCheck func (* http.Request ) error ) {
122+ type testArgs struct {
123+ proxyURLModify func (* url.URL ) * url.URL
124+ proxyReqCheck func (* http.Request ) error
125+ serverMessage []byte
126+ }
127+
128+ func testHTTPConnect (t * testing.T , args testArgs ) {
104129 plis , err := net .Listen ("tcp" , "localhost:0" )
105130 if err != nil {
106131 t .Fatalf ("failed to listen: %v" , err )
107132 }
108133 p := & proxyServer {
109134 t : t ,
110135 lis : plis ,
111- requestCheck : proxyReqCheck ,
136+ requestCheck : args . proxyReqCheck ,
112137 }
113- go p .run ()
138+ go p .run (len ( args . serverMessage ) > 0 )
114139 defer p .stop ()
115140
116141 blis , err := net .Listen ("tcp" , "localhost:0" )
@@ -128,13 +153,14 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
128153 return
129154 }
130155 defer in .Close ()
156+ in .Write (args .serverMessage )
131157 in .Read (recvBuf )
132158 done <- nil
133159 }()
134160
135161 // Overwrite the function in the test and restore them in defer.
136162 hpfe := func (req * http.Request ) (* url.URL , error ) {
137- return proxyURLModify (& url.URL {Host : plis .Addr ().String ()}), nil
163+ return args . proxyURLModify (& url.URL {Host : plis .Addr ().String ()}), nil
138164 }
139165 defer overwrite (hpfe )()
140166
@@ -143,47 +169,76 @@ func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxy
143169 defer cancel ()
144170 c , err := proxyDial (ctx , blis .Addr ().String (), "test" )
145171 if err != nil {
146- t .Fatalf ("http connect Dial failed: %v" , err )
172+ t .Fatalf ("HTTP connect Dial failed: %v" , err )
147173 }
148174 defer c .Close ()
175+ c .SetDeadline (time .Now ().Add (defaultTestTimeout ))
149176
150177 // Send msg on the connection.
151178 c .Write (msg )
152179 if err := <- done ; err != nil {
153- t .Fatalf ("failed to accept: %v" , err )
180+ t .Fatalf ("Failed to accept: %v" , err )
154181 }
155182
156183 // Check received msg.
157184 if string (recvBuf ) != string (msg ) {
158- t .Fatalf ("received msg: %v, want %v" , recvBuf , msg )
185+ t .Fatalf ("Received msg: %v, want %v" , recvBuf , msg )
186+ }
187+
188+ if len (args .serverMessage ) > 0 {
189+ gotServerMessage := make ([]byte , len (args .serverMessage ))
190+ if _ , err := c .Read (gotServerMessage ); err != nil {
191+ t .Errorf ("Got error while reading message from server: %v" , err )
192+ return
193+ }
194+ if string (gotServerMessage ) != string (args .serverMessage ) {
195+ t .Errorf ("Message from server: %v, want %v" , gotServerMessage , args .serverMessage )
196+ }
159197 }
160198}
161199
162200func (s ) TestHTTPConnect (t * testing.T ) {
163- testHTTPConnect ( t ,
164- func (in * url.URL ) * url.URL {
201+ args := testArgs {
202+ proxyURLModify : func (in * url.URL ) * url.URL {
165203 return in
166204 },
167- func (req * http.Request ) error {
205+ proxyReqCheck : func (req * http.Request ) error {
168206 if req .Method != http .MethodConnect {
169207 return fmt .Errorf ("unexpected Method %q, want %q" , req .Method , http .MethodConnect )
170208 }
171209 return nil
172210 },
173- )
211+ }
212+ testHTTPConnect (t , args )
213+ }
214+
215+ func (s ) TestHTTPConnectWithServerHello (t * testing.T ) {
216+ args := testArgs {
217+ proxyURLModify : func (in * url.URL ) * url.URL {
218+ return in
219+ },
220+ proxyReqCheck : func (req * http.Request ) error {
221+ if req .Method != http .MethodConnect {
222+ return fmt .Errorf ("unexpected Method %q, want %q" , req .Method , http .MethodConnect )
223+ }
224+ return nil
225+ },
226+ serverMessage : []byte ("server-hello" ),
227+ }
228+ testHTTPConnect (t , args )
174229}
175230
176231func (s ) TestHTTPConnectBasicAuth (t * testing.T ) {
177232 const (
178233 user = "notAUser"
179234 password = "notAPassword"
180235 )
181- testHTTPConnect ( t ,
182- func (in * url.URL ) * url.URL {
236+ args := testArgs {
237+ proxyURLModify : func (in * url.URL ) * url.URL {
183238 in .User = url .UserPassword (user , password )
184239 return in
185240 },
186- func (req * http.Request ) error {
241+ proxyReqCheck : func (req * http.Request ) error {
187242 if req .Method != http .MethodConnect {
188243 return fmt .Errorf ("unexpected Method %q, want %q" , req .Method , http .MethodConnect )
189244 }
@@ -195,7 +250,8 @@ func (s) TestHTTPConnectBasicAuth(t *testing.T) {
195250 }
196251 return nil
197252 },
198- )
253+ }
254+ testHTTPConnect (t , args )
199255}
200256
201257func (s ) TestMapAddressEnv (t * testing.T ) {
0 commit comments