Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pkg/tlsconfig/configmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func (cm *TLSConfigManager) TLSCertLoader() *TLSCertLoader {
return cm.certLoader
}

// Return a net.Listener for network and address based on current configuration.
func (cm *TLSConfigManager) Listen(network, address string) (net.Listener, error) {
if cm.useTLS {
return tls.Listen(network, address, cm.tlsConfig)
Expand All @@ -75,6 +76,7 @@ func (cm *TLSConfigManager) Listen(network, address string) (net.Listener, error
}
}

// Dial a remote for network and addressing using the current configuration.
func (cm *TLSConfigManager) Dial(network, address string) (net.Conn, error) {
if cm.useTLS {
return tls.Dial(network, address, cm.tlsConfig)
Expand All @@ -83,6 +85,15 @@ func (cm *TLSConfigManager) Dial(network, address string) (net.Conn, error) {
}
}

// Dial a remote for network and addressing using the given dialer and current configuration.
func (cm *TLSConfigManager) DialWithDialer(dialer *net.Dialer, network, address string) (net.Conn, error) {
if cm.useTLS {
return tls.DialWithDialer(dialer, network, address, cm.tlsConfig)
} else {
return dialer.Dial(network, address)
}
}

// PrepareCertificateLoad is a wrapper for the TLSCertLoader's PrepareLoad method. If TLS is not
// enabled, then a NOP callback is returned.
func (cm *TLSConfigManager) PrepareCertificateLoad(certPath, keyPath string) (func() error, error) {
Expand Down
93 changes: 93 additions & 0 deletions pkg/tlsconfig/configmanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -657,3 +657,96 @@ func TestTLSConfigManager_Dial(t *testing.T) {
require.ErrorContains(t, err, "address invalid:address:format: too many colons in address")
})
}

func TestTLSConfigManager_DialWithDialer(t *testing.T) {
testDialWithDialerConnection := func(t *testing.T, listener net.Listener, dial func(dialer *net.Dialer, addr string) (net.Conn, error)) {
t.Helper()

testData := []byte("hello from client")

// Server: accept connection and read data
serverResult := make(chan error, 1)
serverData := make(chan []byte, 1)
go func() {
conn, err := listener.Accept()
if err != nil {
serverResult <- err
return
}

buf := make([]byte, len(testData))
var n int
n, err = conn.Read(buf)
err = errors.Join(err, conn.Close())
serverData <- buf[:n]
serverResult <- err
}()

// Bind to a specific local address to verify the dialer is used
localAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0")
require.NoError(t, err)
dialer := &net.Dialer{LocalAddr: localAddr}

// Client: connect and send data
conn, err := dial(dialer, listener.Addr().String())
require.NoError(t, err)

// Verify the connection's local address is from 127.0.0.1 (dialer's LocalAddr)
localTCPAddr, ok := conn.LocalAddr().(*net.TCPAddr)
require.True(t, ok)
require.Equal(t, "127.0.0.1", localTCPAddr.IP.String())

_, err = conn.Write(testData)
require.NoError(t, err)
defer func() {
require.NoError(t, conn.Close())
}()

require.NoError(t, <-serverResult)
require.Equal(t, testData, <-serverData)
}

t.Run("dialer LocalAddr is used for plain TCP", func(t *testing.T) {
manager, err := NewTLSConfigManager(false, nil, "/any/cert.pem", "/any/key.pem", false)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close())
}()

// Create plain TCP server
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
require.NoError(t, listener.Close())
}()

testDialWithDialerConnection(t, listener, func(dialer *net.Dialer, addr string) (net.Conn, error) {
return manager.DialWithDialer(dialer, "tcp", addr)
})
})

t.Run("dialer LocalAddr is used for TLS", func(t *testing.T) {
ss := selfsigned.NewSelfSignedCert(t)

manager, err := NewTLSConfigManager(true, nil, "", "", true)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close())
}()

// Create TLS server
cert, err := tls.LoadX509KeyPair(ss.CertPath, ss.KeyPath)
require.NoError(t, err)
listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{
Certificates: []tls.Certificate{cert},
})
require.NoError(t, err)
defer func() {
require.NoError(t, listener.Close())
}()

testDialWithDialerConnection(t, listener, func(dialer *net.Dialer, addr string) (net.Conn, error) {
return manager.DialWithDialer(dialer, "tcp", addr)
})
})
}