Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 57 additions & 18 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,13 @@ func (p *Listener) Addr() net.Addr {
return p.Listener.Addr()
}

// NewConn is used to wrap a net.Conn that may be speaking
// the proxy protocol into a proxyproto.Conn.
// NewConn is used to wrap a net.Conn that may be speaking the PROXY protocol
// into a proxyproto.Conn.
//
// NOTE: NewConn may interfere with previously set ReadDeadline on the provided net.Conn,
// because it sets a temporary deadline when detecting and reading the PROXY protocol header.
// If you need to enforce a specific ReadDeadline on the connection, be sure to call Conn.SetReadDeadline
// again after NewConn returns, to restore your desired deadline.
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
// For v1 the header length is at most 108 bytes.
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
Expand All @@ -176,18 +181,20 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
// the initial scan. If there is an error parsing the header,
// it is returned and the socket is closed.
func (p *Conn) Read(b []byte) (int, error) {
p.once.Do(func() {
p.readErr = p.readHeader()
})
if p.readErr != nil {
return 0, p.readErr
// Ensure header processing runs at most once and surface any errors.
if err := p.ensureHeaderProcessed(); err != nil {
return 0, err
}

return p.reader.Read(b)
}

// Write wraps original conn.Write.
func (p *Conn) Write(b []byte) (int, error) {
// Ensure header processing has completed before writing.
if err := p.ensureHeaderProcessed(); err != nil {
return 0, err
}
return p.conn.Write(b)
}

Expand All @@ -199,7 +206,8 @@ func (p *Conn) Close() error {
// ProxyHeader returns the proxy protocol header, if any. If an error occurs
// while reading the proxy header, nil is returned.
func (p *Conn) ProxyHeader() *Header {
p.once.Do(func() { p.readErr = p.readHeader() })
// Ensure header processing runs at most once.
_ = p.ensureHeaderProcessed()
return p.header
}

Expand All @@ -210,7 +218,8 @@ func (p *Conn) ProxyHeader() *Header {
// from the proxy header even if the proxy header itself is
// syntactically correct.
func (p *Conn) LocalAddr() net.Addr {
p.once.Do(func() { p.readErr = p.readHeader() })
// Ensure header processing runs at most once.
_ = p.ensureHeaderProcessed()
if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil {
return p.conn.LocalAddr()
}
Expand All @@ -225,7 +234,8 @@ func (p *Conn) LocalAddr() net.Addr {
// from the proxy header even if the proxy header itself is
// syntactically correct.
func (p *Conn) RemoteAddr() net.Addr {
p.once.Do(func() { p.readErr = p.readHeader() })
// Ensure header processing runs at most once.
_ = p.ensureHeaderProcessed()
if p.header == nil || p.header.Command.IsLocal() || p.readErr != nil {
return p.conn.RemoteAddr()
}
Expand Down Expand Up @@ -291,11 +301,25 @@ func (p *Conn) SetWriteDeadline(t time.Time) error {
// readHeader reads the proxy protocol header from the connection.
func (p *Conn) readHeader() error {
// If the connection's readHeaderTimeout is more than 0,
// push our deadline back to now plus the timeout. This should only
// run on the connection, as we don't want to override the previous
// read deadline the user may have used.
// apply a temporary deadline without extending a user-configured
// deadline. If the user has no deadline, we use now + timeout.
if p.readHeaderTimeout > 0 {
if err := p.conn.SetReadDeadline(time.Now().Add(p.readHeaderTimeout)); err != nil {
var (
storedDeadline time.Time
hasDeadline bool
)
if t := p.readDeadline.Load(); t != nil {
storedDeadline = t.(time.Time)
hasDeadline = !storedDeadline.IsZero()
}

headerDeadline := time.Now().Add(p.readHeaderTimeout)
if hasDeadline && storedDeadline.Before(headerDeadline) {
// Clamp to the user's earlier deadline to avoid extending it.
headerDeadline = storedDeadline
}

if err := p.conn.SetReadDeadline(headerDeadline); err != nil {
return err
}
}
Expand All @@ -304,7 +328,7 @@ func (p *Conn) readHeader() error {

// If the connection's readHeaderTimeout is more than 0, undo the change to the
// deadline that we made above. Because we retain the readDeadline as part of our
// SetReadDeadline override, we know the user's desired deadline so we use that.
// SetReadDeadline override, we can restore the user's deadline (if any).
// Therefore, we check whether the error is a net.Timeout and if it is, we decide
// the proxy proto does not exist and set the error accordingly.
if p.readHeaderTimeout > 0 {
Expand Down Expand Up @@ -352,8 +376,23 @@ func (p *Conn) readHeader() error {
return err
}

// ensureHeaderProcessed runs header processing once.
func (p *Conn) ensureHeaderProcessed() error {
p.once.Do(func() {
p.readErr = p.readHeader()
})
if p.readErr != nil {
return p.readErr
}
return nil
}

// ReadFrom implements the io.ReaderFrom ReadFrom method.
func (p *Conn) ReadFrom(r io.Reader) (int64, error) {
// Ensure header processing has completed before reading/writing.
if err := p.ensureHeaderProcessed(); err != nil {
return 0, err
}
if rf, ok := p.conn.(io.ReaderFrom); ok {
return rf.ReadFrom(r)
}
Expand All @@ -362,9 +401,9 @@ func (p *Conn) ReadFrom(r io.Reader) (int64, error) {

// WriteTo implements io.WriterTo.
func (p *Conn) WriteTo(w io.Writer) (int64, error) {
p.once.Do(func() { p.readErr = p.readHeader() })
if p.readErr != nil {
return 0, p.readErr
// Ensure header processing has completed before reading/writing.
if err := p.ensureHeaderProcessed(); err != nil {
return 0, err
}

b := make([]byte, p.bufReader.Buffered())
Expand Down
Loading
Loading