@@ -59,6 +59,12 @@ const (
5959 maxClientSubscriptionBuffer = 20000
6060)
6161
62+ const (
63+ httpScheme = "http"
64+ wsScheme = "ws"
65+ ipcScheme = "ipc"
66+ )
67+
6268// BatchElem is an element in a batch request.
6369type BatchElem struct {
6470 Method string
@@ -75,7 +81,7 @@ type BatchElem struct {
7581// Client represents a connection to an RPC server.
7682type Client struct {
7783 idgen func () ID // for subscriptions
78- isHTTP bool
84+ scheme string // connection type: http, ws or ipc
7985 services * serviceRegistry
8086
8187 idCounter uint32
@@ -111,6 +117,10 @@ type clientConn struct {
111117
112118func (c * Client ) newClientConn (conn ServerCodec ) * clientConn {
113119 ctx := context .WithValue (context .Background (), clientContextKey {}, c )
120+ // Http connections have already set the scheme
121+ if ! c .isHTTP () && c .scheme != "" {
122+ ctx = context .WithValue (ctx , "scheme" , c .scheme )
123+ }
114124 handler := newHandler (ctx , conn , c .idgen , c .services )
115125 return & clientConn {conn , handler }
116126}
@@ -136,7 +146,7 @@ func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, erro
136146 select {
137147 case <- ctx .Done ():
138148 // Send the timeout to dispatch so it can remove the request IDs.
139- if ! c .isHTTP {
149+ if ! c .isHTTP () {
140150 select {
141151 case c .reqTimeout <- op :
142152 case <- c .closing :
@@ -203,10 +213,18 @@ func newClient(initctx context.Context, connect reconnectFunc) (*Client, error)
203213}
204214
205215func initClient (conn ServerCodec , idgen func () ID , services * serviceRegistry ) * Client {
206- _ , isHTTP := conn .(* httpConn )
216+ scheme := ""
217+ switch conn .(type ) {
218+ case * httpConn :
219+ scheme = httpScheme
220+ case * websocketCodec :
221+ scheme = wsScheme
222+ case * jsonCodec :
223+ scheme = ipcScheme
224+ }
207225 c := & Client {
208226 idgen : idgen ,
209- isHTTP : isHTTP ,
227+ scheme : scheme ,
210228 services : services ,
211229 writeConn : conn ,
212230 close : make (chan struct {}),
@@ -219,7 +237,7 @@ func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *C
219237 reqSent : make (chan error , 1 ),
220238 reqTimeout : make (chan * requestOp ),
221239 }
222- if ! isHTTP {
240+ if ! c . isHTTP () {
223241 go c .dispatch (conn )
224242 }
225243 return c
@@ -250,7 +268,7 @@ func (c *Client) SupportedModules() (map[string]string, error) {
250268
251269// Close closes the client, aborting any in-flight requests.
252270func (c * Client ) Close () {
253- if c .isHTTP {
271+ if c .isHTTP () {
254272 return
255273 }
256274 select {
@@ -264,7 +282,7 @@ func (c *Client) Close() {
264282// This method only works for clients using HTTP, it doesn't have
265283// any effect for clients using another transport.
266284func (c * Client ) SetHeader (key , value string ) {
267- if ! c .isHTTP {
285+ if ! c .isHTTP () {
268286 return
269287 }
270288 conn := c .writeConn .(* httpConn )
@@ -298,7 +316,7 @@ func (c *Client) CallContext(ctx context.Context, result interface{}, method str
298316 }
299317 op := & requestOp {ids : []json.RawMessage {msg .ID }, resp : make (chan * jsonrpcMessage , 1 )}
300318
301- if c .isHTTP {
319+ if c .isHTTP () {
302320 err = c .sendHTTP (ctx , op , msg )
303321 } else {
304322 err = c .send (ctx , op , msg )
@@ -357,7 +375,7 @@ func (c *Client) BatchCallContext(ctx context.Context, b []BatchElem) error {
357375 }
358376
359377 var err error
360- if c .isHTTP {
378+ if c .isHTTP () {
361379 err = c .sendBatchHTTP (ctx , op , msgs )
362380 } else {
363381 err = c .send (ctx , op , msgs )
@@ -402,7 +420,7 @@ func (c *Client) Notify(ctx context.Context, method string, args ...interface{})
402420 }
403421 msg .ID = nil
404422
405- if c .isHTTP {
423+ if c .isHTTP () {
406424 return c .sendHTTP (ctx , op , msg )
407425 }
408426 return c .send (ctx , op , msg )
@@ -440,7 +458,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
440458 if chanVal .IsNil () {
441459 panic ("channel given to Subscribe must not be nil" )
442460 }
443- if c .isHTTP {
461+ if c .isHTTP () {
444462 return nil , ErrNotificationsUnsupported
445463 }
446464
@@ -642,3 +660,7 @@ func (c *Client) read(codec ServerCodec) {
642660 c .readOp <- readOp {msgs , batch }
643661 }
644662}
663+
664+ func (c * Client ) isHTTP () bool {
665+ return c .scheme == httpScheme
666+ }
0 commit comments