Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions example_conn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package proxyproto_test

import (
"net"
"time"

"github.com/pires/go-proxyproto"
)

func ExampleNewConn_default() {
serverConn, clientConn := net.Pipe()
defer func() { _ = serverConn.Close() }()
defer func() { _ = clientConn.Close() }()

go func() {
_, _ = clientConn.Write([]byte("x"))
_ = clientConn.Close()
}()

conn := proxyproto.NewConn(serverConn)
buf := make([]byte, 1)
_, _ = conn.Read(buf)
// Output:
}

func ExampleNewConn_withBufferSize() {
serverConn, clientConn := net.Pipe()
defer func() { _ = serverConn.Close() }()
defer func() { _ = clientConn.Close() }()

go func() {
_, _ = clientConn.Write([]byte("y"))
_ = clientConn.Close()
}()

conn := proxyproto.NewConn(serverConn, proxyproto.WithBufferSize(4096))
buf := make([]byte, 1)
_, _ = conn.Read(buf)
// Output:
}

func ExampleNewConn_withReadHeaderTimeout() {
serverConn, clientConn := net.Pipe()
defer func() { _ = serverConn.Close() }()
defer func() { _ = clientConn.Close() }()

go func() {
_, _ = clientConn.Write([]byte("z"))
_ = clientConn.Close()
}()

conn := proxyproto.NewConn(serverConn, proxyproto.SetReadHeaderTimeout(time.Second))
buf := make([]byte, 1)
_, _ = conn.Read(buf)
// Output:
}

func ExampleNewConn_withPolicy() {
serverConn, clientConn := net.Pipe()
defer func() { _ = serverConn.Close() }()
defer func() { _ = clientConn.Close() }()

go func() {
_, _ = clientConn.Write([]byte(proxyV1Line))
_, _ = clientConn.Write([]byte("p"))
_ = clientConn.Close()
}()

conn := proxyproto.NewConn(serverConn, proxyproto.WithPolicy(proxyproto.REQUIRE))
buf := make([]byte, 1)
_, _ = conn.Read(buf)
// Output:
}

func ExampleNewConn_combined() {
serverConn, clientConn := net.Pipe()
defer func() { _ = serverConn.Close() }()
defer func() { _ = clientConn.Close() }()

go func() {
_, _ = clientConn.Write([]byte("c"))
_ = clientConn.Close()
}()

conn := proxyproto.NewConn(serverConn,
proxyproto.WithBufferSize(2048),
proxyproto.SetReadHeaderTimeout(2*time.Second),
)
buf := make([]byte, 1)
_, _ = conn.Read(buf)
// Output:
}
136 changes: 136 additions & 0 deletions example_listener_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package proxyproto_test

import (
"net"
"time"

"github.com/pires/go-proxyproto"
)

// proxyV1Line is a minimal PROXY protocol v1 header for examples.
const proxyV1Line = "PROXY TCP4 192.168.1.1 192.168.1.2 12345 443\r\n"

func ExampleListener_default() {
l, _ := net.Listen("tcp", "127.0.0.1:0")
pl := &proxyproto.Listener{Listener: l}
defer func() { _ = pl.Close() }()

go func() {
c, _ := net.Dial("tcp", pl.Addr().String())
if c != nil {
_, _ = c.Write([]byte("x"))
_ = c.Close()
}
}()

conn, _ := pl.Accept()
if conn != nil {
buf := make([]byte, 1)
_, _ = conn.Read(buf)
_ = conn.Close()
}
// Output:
}

func ExampleListener_readHeaderTimeout() {
l, _ := net.Listen("tcp", "127.0.0.1:0")
pl := &proxyproto.Listener{
Listener: l,
ReadHeaderTimeout: 2 * time.Second,
}
defer func() { _ = pl.Close() }()

go func() {
c, _ := net.Dial("tcp", pl.Addr().String())
if c != nil {
_, _ = c.Write([]byte("a"))
_ = c.Close()
}
}()

conn, _ := pl.Accept()
if conn != nil {
_ = conn.SetReadDeadline(time.Now().Add(time.Second))
buf := make([]byte, 1)
_, _ = conn.Read(buf)
_ = conn.Close()
}
// Output:
}

func ExampleListener_readBufferSize() {
l, _ := net.Listen("tcp", "127.0.0.1:0")
pl := &proxyproto.Listener{
Listener: l,
ReadBufferSize: 4096,
}
defer func() { _ = pl.Close() }()

go func() {
c, _ := net.Dial("tcp", pl.Addr().String())
if c != nil {
_, _ = c.Write([]byte("b"))
_ = c.Close()
}
}()

conn, _ := pl.Accept()
if conn != nil {
buf := make([]byte, 1)
_, _ = conn.Read(buf)
_ = conn.Close()
}
// Output:
}

func ExampleListener_policyRequire() {
l, _ := net.Listen("tcp", "127.0.0.1:0")
pl := &proxyproto.Listener{
Listener: l,
Policy: func(net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil },
}
defer func() { _ = pl.Close() }()

go func() {
c, _ := net.Dial("tcp", pl.Addr().String())
if c != nil {
_, _ = c.Write([]byte(proxyV1Line))
_, _ = c.Write([]byte("p"))
_ = c.Close()
}
}()

conn, _ := pl.Accept()
if conn != nil {
buf := make([]byte, 1)
_, _ = conn.Read(buf)
_ = conn.Close()
}
// Output:
}

func ExampleListener_validateHeader() {
l, _ := net.Listen("tcp", "127.0.0.1:0")
pl := &proxyproto.Listener{
Listener: l,
ValidateHeader: func(*proxyproto.Header) error { return nil },
}
defer func() { _ = pl.Close() }()

go func() {
c, _ := net.Dial("tcp", pl.Addr().String())
if c != nil {
_, _ = c.Write([]byte(proxyV1Line))
_, _ = c.Write([]byte("v"))
_ = c.Close()
}
}()

conn, _ := pl.Accept()
if conn != nil {
buf := make([]byte, 1)
_, _ = conn.Read(buf)
_ = conn.Close()
}
// Output:
}
50 changes: 39 additions & 11 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,33 @@ var (
// Only one of Policy or ConnPolicy should be provided. If both are provided then
// a panic would occur during accept.
type Listener struct {
// Listener is the underlying listener.
Listener net.Listener
// Deprecated: use ConnPolicyFunc instead. This will be removed in future release.
Policy PolicyFunc
ConnPolicy ConnPolicyFunc
ValidateHeader Validator
Policy PolicyFunc
// ConnPolicy is the policy function for accepted connections.
ConnPolicy ConnPolicyFunc
// ValidateHeader is the validator function for the proxy header.
ValidateHeader Validator
// ReadHeaderTimeout is the timeout for reading the proxy header.
ReadHeaderTimeout time.Duration
// ReadBufferSize is the read buffer size for accepted connections. When > 0,
// each accepted connection uses this size for proxy header detection; 0 means default.
ReadBufferSize int
}

// Conn is used to wrap and underlying connection which
// may be speaking the Proxy Protocol. If it is, the RemoteAddr() will
// return the address of the client instead of the proxy address. Each connection
// will have its own readHeaderTimeout and readDeadline set by the Accept() call.
type Conn struct {
readDeadline atomic.Value // time.Time
once sync.Once
readErr error
conn net.Conn
bufReader *bufio.Reader
readDeadline atomic.Value // time.Time
once sync.Once
readErr error
conn net.Conn
bufReader *bufio.Reader
// bufferSize is set when the client overrides via WithBufferSize; nil means use default.
bufferSize *int
header *Header
ProxyHeaderPolicy Policy
Validate Validator
Expand Down Expand Up @@ -89,6 +98,22 @@ func SetReadHeaderTimeout(t time.Duration) func(*Conn) {
}
}

// WithBufferSize sets the size of the read buffer used for proxy header detection.
// Values <= 0 are ignored and the default (256 bytes) is used. Values < 16 are
// effectively 16 due to bufio's minimum. The default is tuned for typical proxy
// protocol header lengths.
func WithBufferSize(length int) func(*Conn) {
return func(c *Conn) {
if length <= 0 {
return
}
p := new(int)
*p = length
c.bufferSize = p
c.bufReader = bufio.NewReaderSize(c.conn, length)
}
}

// Accept waits for and returns the next valid connection to the listener.
func (p *Listener) Accept() (net.Conn, error) {
for {
Expand Down Expand Up @@ -130,11 +155,14 @@ func (p *Listener) Accept() (net.Conn, error) {
}
}

newConn := NewConn(
conn,
opts := []func(*Conn){
WithPolicy(proxyHeaderPolicy),
ValidateHeader(p.ValidateHeader),
)
}
if p.ReadBufferSize > 0 {
opts = append(opts, WithBufferSize(p.ReadBufferSize))
}
newConn := NewConn(conn, opts...)

// If the ReadHeaderTimeout for the listener is unset, use the default timeout.
if p.ReadHeaderTimeout == 0 {
Expand Down
Loading
Loading