Skip to content
This repository was archived by the owner on Dec 20, 2024. It is now read-only.

Commit 523e98d

Browse files
committed
add test for RegisterProtocol
Signed-off-by: 楚贤 <[email protected]>
1 parent 92de56e commit 523e98d

File tree

3 files changed

+155
-2
lines changed

3 files changed

+155
-2
lines changed

pkg/httputils/http_util_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
package httputils
1818

1919
import (
20+
"crypto/tls"
2021
"encoding/json"
2122
"fmt"
2223
"math/rand"
2324
"net"
25+
"net/http"
2426
"sync"
2527
"testing"
2628
"time"
@@ -257,3 +259,36 @@ type testJSONReq struct {
257259
type testJSONRes struct {
258260
Sum int
259261
}
262+
263+
type testTransport struct {
264+
}
265+
266+
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
267+
return &http.Response{
268+
Proto: "HTTP/1.1",
269+
ProtoMajor: 1,
270+
ProtoMinor: 1,
271+
Body: http.NoBody,
272+
Status: http.StatusText(http.StatusOK),
273+
StatusCode: http.StatusOK,
274+
ContentLength: -1,
275+
}, nil
276+
}
277+
278+
func (s *HTTPUtilTestSuite) TestRegisterProtocol(c *check.C) {
279+
protocol := "test"
280+
RegisterProtocol(protocol, &testTransport{})
281+
resp, err := HTTPWithHeaders(http.MethodGet,
282+
protocol+"://test/test",
283+
map[string]string{
284+
"test": "test",
285+
},
286+
time.Second,
287+
&tls.Config{},
288+
)
289+
c.Assert(err, check.IsNil)
290+
defer resp.Body.Close()
291+
292+
c.Assert(resp, check.NotNil)
293+
c.Assert(resp.ContentLength, check.Equals, int64(-1))
294+
}

supernode/httpclient/origin_http_client.go

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,30 @@ type OriginHTTPClient interface {
4949

5050
// OriginClient is an implementation of the interface of OriginHTTPClient.
5151
type OriginClient struct {
52-
clientMap *sync.Map
52+
clientMap *sync.Map
53+
defaultHTTPClient *http.Client
5354
}
5455

5556
// NewOriginClient returns a new OriginClient.
5657
func NewOriginClient() OriginHTTPClient {
58+
defaultTransport := &http.Transport{
59+
Proxy: http.ProxyFromEnvironment,
60+
DialContext: (&net.Dialer{
61+
Timeout: 3 * time.Second,
62+
KeepAlive: 30 * time.Second,
63+
DualStack: true,
64+
}).DialContext,
65+
MaxIdleConns: 100,
66+
IdleConnTimeout: 90 * time.Second,
67+
TLSHandshakeTimeout: 10 * time.Second,
68+
ExpectContinueTimeout: 1 * time.Second,
69+
}
70+
httputils.RegisterProtocolOnTransport(defaultTransport)
5771
return &OriginClient{
5872
clientMap: &sync.Map{},
73+
defaultHTTPClient: &http.Client{
74+
Transport: defaultTransport,
75+
},
5976
}
6077
}
6178

@@ -195,7 +212,8 @@ func (client *OriginClient) HTTPWithHeaders(method, url string, headers map[stri
195212

196213
httpClientObject, existed := client.clientMap.Load(req.Host)
197214
if !existed {
198-
httpClientObject = http.DefaultClient
215+
// use client.defaultHTTPClient to support custom protocols
216+
httpClientObject = client.defaultHTTPClient
199217
}
200218

201219
httpClient, ok := httpClientObject.(*http.Client)
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright The Dragonfly Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package httpclient
18+
19+
import (
20+
"io/ioutil"
21+
"net/http"
22+
"net/http/httptest"
23+
"testing"
24+
"time"
25+
26+
"github.com/go-check/check"
27+
28+
"github.com/dragonflyoss/Dragonfly/pkg/httputils"
29+
)
30+
31+
func init() {
32+
check.Suite(&OriginHTTPClientTestSuite{})
33+
}
34+
35+
func Test(t *testing.T) {
36+
check.TestingT(t)
37+
}
38+
39+
type OriginHTTPClientTestSuite struct {
40+
client *OriginClient
41+
}
42+
43+
func (s *OriginHTTPClientTestSuite) SetUpSuite(c *check.C) {
44+
s.client = NewOriginClient().(*OriginClient)
45+
}
46+
47+
func (s *OriginHTTPClientTestSuite) TearDownSuite(c *check.C) {
48+
}
49+
50+
func (s *OriginHTTPClientTestSuite) TestHTTPWithHeaders(c *check.C) {
51+
testString := "test bytes"
52+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
53+
w.WriteHeader(http.StatusOK)
54+
w.Write([]byte(testString))
55+
if r.Method != "GET" {
56+
c.Errorf("Expected 'GET' request, got '%s'", r.Method)
57+
}
58+
}))
59+
defer ts.Close()
60+
61+
httptest.NewRecorder()
62+
resp, err := s.client.HTTPWithHeaders(http.MethodGet, ts.URL, map[string]string{}, time.Second)
63+
c.Check(err, check.IsNil)
64+
defer resp.Body.Close()
65+
66+
testBytes, err := ioutil.ReadAll(resp.Body)
67+
c.Check(err, check.IsNil)
68+
c.Check(string(testBytes), check.Equals, testString)
69+
}
70+
71+
type testTransport struct {
72+
}
73+
74+
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
75+
return &http.Response{
76+
Proto: "HTTP/1.1",
77+
ProtoMajor: 1,
78+
ProtoMinor: 1,
79+
Body: http.NoBody,
80+
Status: http.StatusText(http.StatusOK),
81+
StatusCode: http.StatusOK,
82+
ContentLength: -1,
83+
}, nil
84+
}
85+
86+
func (s *OriginHTTPClientTestSuite) TestRegisterTLSConfig(c *check.C) {
87+
protocol := "test"
88+
httputils.RegisterProtocol(protocol, &testTransport{})
89+
s.client.RegisterTLSConfig(protocol+"://test/test", true, nil)
90+
httpClientInterface, ok := s.client.clientMap.Load("test")
91+
c.Check(ok, check.Equals, true)
92+
httpClient, ok := httpClientInterface.(*http.Client)
93+
c.Check(ok, check.Equals, true)
94+
95+
resp, err := httpClient.Get(protocol + "://test/test")
96+
c.Assert(err, check.IsNil)
97+
defer resp.Body.Close()
98+
c.Assert(resp, check.NotNil)
99+
c.Assert(resp.ContentLength, check.Equals, int64(-1))
100+
}

0 commit comments

Comments
 (0)