From e47f4abfaace92a737ed2f2f9054533aa0cab33a Mon Sep 17 00:00:00 2001 From: Geoffrey Wossum Date: Thu, 8 Jan 2026 17:37:55 -0600 Subject: [PATCH] feat: add TLSConfigManager.DialWithDialer method --- pkg/tlsconfig/configmanager.go | 11 ++++ pkg/tlsconfig/configmanager_test.go | 93 +++++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) diff --git a/pkg/tlsconfig/configmanager.go b/pkg/tlsconfig/configmanager.go index 62047bee286..941eca9d103 100644 --- a/pkg/tlsconfig/configmanager.go +++ b/pkg/tlsconfig/configmanager.go @@ -67,6 +67,7 @@ func (cm *TLSConfigManager) TLSCertLoader() *TLSCertLoader { return cm.certLoader } +// Return a net.Listener for network and address based on current configuration. func (cm *TLSConfigManager) Listen(network, address string) (net.Listener, error) { if cm.useTLS { return tls.Listen(network, address, cm.tlsConfig) @@ -75,6 +76,7 @@ func (cm *TLSConfigManager) Listen(network, address string) (net.Listener, error } } +// Dial a remote for network and addressing using the current configuration. func (cm *TLSConfigManager) Dial(network, address string) (net.Conn, error) { if cm.useTLS { return tls.Dial(network, address, cm.tlsConfig) @@ -83,6 +85,15 @@ func (cm *TLSConfigManager) Dial(network, address string) (net.Conn, error) { } } +// Dial a remote for network and addressing using the given dialer and current configuration. +func (cm *TLSConfigManager) DialWithDialer(dialer *net.Dialer, network, address string) (net.Conn, error) { + if cm.useTLS { + return tls.DialWithDialer(dialer, network, address, cm.tlsConfig) + } else { + return dialer.Dial(network, address) + } +} + // PrepareCertificateLoad is a wrapper for the TLSCertLoader's PrepareLoad method. If TLS is not // enabled, then a NOP callback is returned. func (cm *TLSConfigManager) PrepareCertificateLoad(certPath, keyPath string) (func() error, error) { diff --git a/pkg/tlsconfig/configmanager_test.go b/pkg/tlsconfig/configmanager_test.go index 438cec10659..bf8e7b838b8 100644 --- a/pkg/tlsconfig/configmanager_test.go +++ b/pkg/tlsconfig/configmanager_test.go @@ -657,3 +657,96 @@ func TestTLSConfigManager_Dial(t *testing.T) { require.ErrorContains(t, err, "address invalid:address:format: too many colons in address") }) } + +func TestTLSConfigManager_DialWithDialer(t *testing.T) { + testDialWithDialerConnection := func(t *testing.T, listener net.Listener, dial func(dialer *net.Dialer, addr string) (net.Conn, error)) { + t.Helper() + + testData := []byte("hello from client") + + // Server: accept connection and read data + serverResult := make(chan error, 1) + serverData := make(chan []byte, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + serverResult <- err + return + } + + buf := make([]byte, len(testData)) + var n int + n, err = conn.Read(buf) + err = errors.Join(err, conn.Close()) + serverData <- buf[:n] + serverResult <- err + }() + + // Bind to a specific local address to verify the dialer is used + localAddr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") + require.NoError(t, err) + dialer := &net.Dialer{LocalAddr: localAddr} + + // Client: connect and send data + conn, err := dial(dialer, listener.Addr().String()) + require.NoError(t, err) + + // Verify the connection's local address is from 127.0.0.1 (dialer's LocalAddr) + localTCPAddr, ok := conn.LocalAddr().(*net.TCPAddr) + require.True(t, ok) + require.Equal(t, "127.0.0.1", localTCPAddr.IP.String()) + + _, err = conn.Write(testData) + require.NoError(t, err) + defer func() { + require.NoError(t, conn.Close()) + }() + + require.NoError(t, <-serverResult) + require.Equal(t, testData, <-serverData) + } + + t.Run("dialer LocalAddr is used for plain TCP", func(t *testing.T) { + manager, err := NewTLSConfigManager(false, nil, "/any/cert.pem", "/any/key.pem", false) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close()) + }() + + // Create plain TCP server + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { + require.NoError(t, listener.Close()) + }() + + testDialWithDialerConnection(t, listener, func(dialer *net.Dialer, addr string) (net.Conn, error) { + return manager.DialWithDialer(dialer, "tcp", addr) + }) + }) + + t.Run("dialer LocalAddr is used for TLS", func(t *testing.T) { + ss := selfsigned.NewSelfSignedCert(t) + + manager, err := NewTLSConfigManager(true, nil, "", "", true) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close()) + }() + + // Create TLS server + cert, err := tls.LoadX509KeyPair(ss.CertPath, ss.KeyPath) + require.NoError(t, err) + listener, err := tls.Listen("tcp", "127.0.0.1:0", &tls.Config{ + Certificates: []tls.Certificate{cert}, + }) + require.NoError(t, err) + defer func() { + require.NoError(t, listener.Close()) + }() + + testDialWithDialerConnection(t, listener, func(dialer *net.Dialer, addr string) (net.Conn, error) { + return manager.DialWithDialer(dialer, "tcp", addr) + }) + }) +}