diff --git a/allocate.go b/allocate.go index df7ce451..e42c9925 100644 --- a/allocate.go +++ b/allocate.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io" + "net/http" "os" "os/exec" "path/filepath" @@ -30,6 +31,8 @@ type Allocator interface { // Cancelling the allocator context will already perform this operation, // so normally there's no need to call Wait directly. Wait() + + getDialHeader() http.Header } // setupExecAllocator is similar to NewExecAllocator, but it allows NewContext @@ -327,6 +330,10 @@ func (a *ExecAllocator) Wait() { a.wg.Wait() } +func (a *ExecAllocator) getDialHeader() http.Header { + return nil +} + // ExecPath returns an ExecAllocatorOption which uses the given path to execute // browser processes. The given path can be an absolute path to a binary, or // just the name of the program to find via exec.LookPath. @@ -552,6 +559,8 @@ type RemoteAllocator struct { modifyURLFunc func(ctx context.Context, wsURL string) (string, error) wg sync.WaitGroup + + dialHeader http.Header } // Allocate satisfies the Allocator interface. @@ -583,6 +592,7 @@ func (a *RemoteAllocator) Allocate(ctx context.Context, opts ...BrowserOption) ( a.wg.Done() }() + opts = append(opts, WithDialHeaderBrowser(FromContext(ctx).Allocator.getDialHeader())) browser, err := NewBrowser(wctx, wsURL, opts...) if err != nil { return nil, err @@ -605,8 +615,23 @@ func (a *RemoteAllocator) Wait() { a.wg.Wait() } +func (a *RemoteAllocator) getDialHeader() http.Header { + return a.dialHeader +} + // NoModifyURL is a RemoteAllocatorOption that prevents the remote allocator // from modifying the websocket debugger URL passed to it. func NoModifyURL(a *RemoteAllocator) { a.modifyURLFunc = nil } + +func WithDialHeader(h http.Header) RemoteAllocatorOption { + return func(a *RemoteAllocator) { + if a.dialHeader == nil { + a.dialHeader = make(http.Header) + } + for k, v := range h { + a.dialHeader[k] = v + } + } +} diff --git a/browser.go b/browser.go index 9b857d94..d98581cb 100644 --- a/browser.go +++ b/browser.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log" + "net/http" "os" "sync" "sync/atomic" @@ -41,6 +42,7 @@ type Browser struct { closingGracefully chan struct{} dialTimeout time.Duration + dialHeader http.Header // pages keeps track of the attached targets, indexed by each's session // ID. The only reason this is a field is so that the tests can check the @@ -109,7 +111,7 @@ func NewBrowser(ctx context.Context, urlstr string, opts ...BrowserOption) (*Bro } var err error - b.conn, err = DialContext(dialCtx, urlstr, WithConnDebugf(b.dbgf)) + b.conn, err = DialContext(dialCtx, urlstr, b.dialHeader, WithConnDebugf(b.dbgf)) if err != nil { return nil, fmt.Errorf("could not dial %q: %w", urlstr, err) } @@ -358,3 +360,17 @@ func WithConsolef(f func(string, ...interface{})) BrowserOption { func WithDialTimeout(d time.Duration) BrowserOption { return func(b *Browser) { b.dialTimeout = d } } + +func WithDialHeaderBrowser(header http.Header) BrowserOption { + if header == nil { + return func(b *Browser) {} + } + return func(b *Browser) { + if b.dialHeader == nil { + b.dialHeader = make(http.Header) + } + for k, v := range header { + b.dialHeader[k] = v + } + } +} diff --git a/conn.go b/conn.go index edcb3af4..7ce60265 100644 --- a/conn.go +++ b/conn.go @@ -5,6 +5,7 @@ import ( "context" "io" "net" + "net/http" "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" @@ -41,9 +42,13 @@ type Conn struct { } // DialContext dials the specified websocket URL using gobwas/ws. -func DialContext(ctx context.Context, urlstr string, opts ...DialOption) (*Conn, error) { +func DialContext(ctx context.Context, urlstr string, header http.Header, opts ...DialOption) (*Conn, error) { // connect - conn, br, _, err := ws.Dial(ctx, urlstr) + // h := FromContext(ctx).Allocator.getDialHeader() + dialer := ws.Dialer{ + Header: ws.HandshakeHeaderHTTP(header), + } + conn, br, _, err := dialer.Dial(ctx, urlstr) if err != nil { return nil, err } diff --git a/util.go b/util.go index adcd59d9..b07b53dc 100644 --- a/util.go +++ b/util.go @@ -95,6 +95,8 @@ func modifyURL(ctx context.Context, urlstr string) (string, error) { // to get "webSocketDebuggerUrl" in the response req, err := http.NewRequestWithContext(lctx, "GET", u.String(), nil) + req.Header = FromContext(ctx).Allocator.getDialHeader() + if err != nil { return "", err }