Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 72 additions & 21 deletions pkg/tlsconfig/certconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,35 @@ func (lc *LoadedCertificate) IsValid() bool {
return lc.valid
}

// loadCertificateConfig is an internal config for LoadCertificate.
type loadCertificateConfig struct {
// ignoreFilePermissions indicates if file permissions should be ignored during load.
ignoreFilePermissions bool
}

// LoadCertificateOpt are functions to change the behavior of LoadCertificate.
type LoadCertificateOpt func(*loadCertificateConfig)

// WithLoadCertificateIgnoreFilePermissions instructs LoadCertificate to ignore file permissions
// if ignore is true.
func WithLoadCertificateIgnoreFilePermissions(ignore bool) LoadCertificateOpt {
return func(c *loadCertificateConfig) {
c.ignoreFilePermissions = ignore
}
}

// LoadCertificate loads a key pair from certPath and keyPath, performing several checks
// along the way. If any checks fail or an error occurs loading the files, then an error is returned.
// If keyPath is empty, then certPath is assumed to contain both the certificate and the private key.
// Only trusted input (standard configuration files) should be used for certPath and keyPath.
func LoadCertificate(certPath, keyPath string) (LoadedCertificate, error) {
func LoadCertificate(certPath, keyPath string, opts ...LoadCertificateOpt) (LoadedCertificate, error) {
fail := func(err error) (LoadedCertificate, error) { return LoadedCertificate{valid: false}, err }

config := loadCertificateConfig{}
for _, o := range opts {
o(&config)
}

if certPath == "" {
return fail(fmt.Errorf("LoadCertificate: certificate: %w", ErrPathEmpty))
}
Expand Down Expand Up @@ -95,9 +117,11 @@ func LoadCertificate(certPath, keyPath string) (LoadedCertificate, error) {
}
}()

if err := file.VerifyFilePermissivenessF(f, maxPerms); err != nil {
// VerifyFilePermissivenessF includes a lot context in its errors. No need to add duplicate here.
return nil, fmt.Errorf("LoadCertificate: %w", err)
if !config.ignoreFilePermissions {
if err := file.VerifyFilePermissivenessF(f, maxPerms); err != nil {
// VerifyFilePermissivenessF includes a lot context in its errors. No need to add duplicate here.
return nil, fmt.Errorf("LoadCertificate: %w", err)
}
}
data, err := io.ReadAll(f)
if err != nil {
Expand Down Expand Up @@ -157,12 +181,18 @@ type TLSCertLoader struct {
// certificateCheckInterval determines the duration between each certificate check.
certificateCheckInterval time.Duration

// ignoreFilePermissions is true if file permission checks should be bypassed.
ignoreFilePermissions bool

// closeOnce is used to close closeCh exactly one time.
closeOnce sync.Once

// closeCh is used to trigger closing the monitor.
closeCh chan struct{}

// monitorStartWg can be used to detect if the monitor goroutine has started.
monitorStartWg sync.WaitGroup

// mu protects all members below.
mu sync.Mutex

Expand All @@ -181,28 +211,35 @@ type TLSCertLoader struct {

type TLSCertLoaderOpt func(*TLSCertLoader)

// WithExpirationAdvanced sets the how far ahead a CertLoader will
// WithCertLoaderExpirationAdvanced sets the how far ahead a CertLoader will
// warn about a certificate that is about to expire.
func WithExpirationAdvanced(d time.Duration) TLSCertLoaderOpt {
func WithCertLoaderExpirationAdvanced(d time.Duration) TLSCertLoaderOpt {
return func(cl *TLSCertLoader) {
cl.expirationAdvanced = d
}
}

// WithCertificateCheckInterval sets how often to check for certificate expiration.
func WithCertificateCheckInterval(d time.Duration) TLSCertLoaderOpt {
// WithCertLoaderCertificateCheckInterval sets how often to check for certificate expiration.
func WithCertLoaderCertificateCheckInterval(d time.Duration) TLSCertLoaderOpt {
return func(cl *TLSCertLoader) {
cl.certificateCheckInterval = d
}
}

// WithLogger assigns a logger for to use.
func WithLogger(logger *zap.Logger) TLSCertLoaderOpt {
// WithCertLoaderLogger assigns a logger for to use.
func WithCertLoaderLogger(logger *zap.Logger) TLSCertLoaderOpt {
return func(cl *TLSCertLoader) {
cl.logger = logger
}
}

// WithCertLoaderIgnoreFilePermissions skips file permission checking when loading certificates.
func WithCertLoaderIgnoreFilePermissions(ignore bool) TLSCertLoaderOpt {
return func(cl *TLSCertLoader) {
cl.ignoreFilePermissions = ignore
}
}

// NewTLSCertLoader creates a TLSCertLoader loaded with the certifcate found in certPath and keyPath.
// Only trusted input (standard configuration files) should be used for certPath and keyPath.
// If the certificate can not be loaded, an error is returned. On success, a monitor is setup to
Expand Down Expand Up @@ -230,10 +267,8 @@ func NewTLSCertLoader(certPath, keyPath string, opts ...TLSCertLoaderOpt) (rCert
}

// Start monitoring certificate.
var monitorWg sync.WaitGroup
monitorWg.Add(1)
go cl.monitorCert(&monitorWg)
monitorWg.Wait()
cl.monitorStartWg.Add(1)
go cl.monitorCert(&cl.monitorStartWg)

return cl, nil
}
Expand Down Expand Up @@ -308,16 +343,22 @@ func (cl *TLSCertLoader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate,
// certificate.
func (cl *TLSCertLoader) GetClientCertificate(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
if cri == nil {
return new(tls.Certificate), ErrCertificateRequestInfoNil
return new(tls.Certificate), fmt.Errorf("tls client: %w", ErrCertificateRequestInfoNil)
}
cert := cl.Certificate()
if cert == nil {
return new(tls.Certificate), ErrCertificateNil
return new(tls.Certificate), fmt.Errorf("tls client: %w", ErrCertificateNil)
}
if err := cri.SupportsCertificate(cert); err != nil {
return new(tls.Certificate), err

// Will our certificate be accepted by server?
if err := cri.SupportsCertificate(cert); err == nil {
return cert, nil
}
return cert, nil

// We don't have a certificate that would be accepted by the server. Don't return an error.
// This replicates Go's behavior when tls.Config.Certificates is used instead of GetClientCertificate
// and gives a better error on both the client and server side.
return new(tls.Certificate), nil
}

// Leaf returns the parsed x509 certificate of the currently loaded certificate.
Expand All @@ -328,13 +369,17 @@ func (cl *TLSCertLoader) Leaf() *x509.Certificate {
return cl.leaf
}

func (cl *TLSCertLoader) loadCertificate(certPath, keyPath string) (LoadedCertificate, error) {
return LoadCertificate(certPath, keyPath, WithLoadCertificateIgnoreFilePermissions(cl.ignoreFilePermissions))
}

// Load loads the certificate at the given certificate path and private keyfile path.
// Only trusted input (standard configuration files) should be used for certPath and keyPath.
func (cl *TLSCertLoader) Load(certPath, keyPath string) (rErr error) {
log, logEnd := logger.NewOperation(cl.logger, "Loading TLS certificate", "tls_load_cert", zap.String("cert", certPath), zap.String("key", keyPath))
defer logEnd()

loadedCert, err := LoadCertificate(certPath, keyPath)
loadedCert, err := cl.loadCertificate(certPath, keyPath)

cl.mu.Lock()
defer cl.mu.Unlock()
Expand Down Expand Up @@ -365,7 +410,7 @@ func (cl *TLSCertLoader) Load(certPath, keyPath string) (rErr error) {
// If the certificate can be loaded, a function that will apply the certificate reload is
// returned. Otherwise, an error is returned.
func (cl *TLSCertLoader) PrepareLoad(certPath, keyPath string) (func() error, error) {
loadedCert, err := LoadCertificate(certPath, keyPath)
loadedCert, err := cl.loadCertificate(certPath, keyPath)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -457,6 +502,12 @@ func (cl *TLSCertLoader) checkCurrentCert() {
}
}

// WaitForMonitorStart will wait for the certificate monitor goroutine to start. This is mainly useful
// for tests to avoid race conditions.
func (cl *TLSCertLoader) WaitForMonitorStart() {
cl.monitorStartWg.Wait()
}

// monitorCert periodically logs errors with the currently loaded certificate.
func (cl *TLSCertLoader) monitorCert(wg *sync.WaitGroup) {
cl.logger.Info("Starting TLS certificate monitor")
Expand Down
27 changes: 20 additions & 7 deletions pkg/tlsconfig/certconfig_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ func TestTLSCertLoader_HappyPath(t *testing.T) {
logger := zap.New(core)

// Start cert loader
cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath, WithLogger(logger))
cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath, WithCertLoaderLogger(logger))
require.NoError(t, err)
require.NotNil(t, cl)
defer func() {
require.NoError(t, cl.Close())
}()
cl.WaitForMonitorStart()
cl.WaitForMonitorStart() // should be able to safely call multiple times

{
// Check for expected log output
Expand Down Expand Up @@ -108,12 +110,13 @@ func TestTLSCertLoader_GoodCertPersists(t *testing.T) {
logger := zap.New(core)

// Start cert loader
cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath, WithLogger(logger))
cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath, WithCertLoaderLogger(logger))
require.NoError(t, err)
require.NotNil(t, cl)
defer func() {
require.NoError(t, cl.Close())
}()
cl.WaitForMonitorStart()

var goodSerial big.Int
{
Expand Down Expand Up @@ -281,12 +284,13 @@ func TestTLSCertLoader_PrematureCertificateLogging(t *testing.T) {
core, logs := observer.New(zapcore.InfoLevel)
logger := zap.New(core)

cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath, WithLogger(logger), WithCertificateCheckInterval(testCheckTime))
cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath, WithCertLoaderLogger(logger), WithCertLoaderCertificateCheckInterval(testCheckTime))
require.NoError(t, err)
require.NotNil(t, cl)
defer func() {
require.NoError(t, cl.Close())
}()
cl.WaitForMonitorStart()

checkWarning := func(t *testing.T) {
warning := logs.FilterMessage("Certificate is not valid yet").TakeAll()
Expand All @@ -313,12 +317,13 @@ func TestTLSCertLoader_ExpiredCertificateLogging(t *testing.T) {
core, logs := observer.New(zapcore.InfoLevel)
logger := zap.New(core)

cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath, WithLogger(logger), WithCertificateCheckInterval(testCheckTime))
cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath, WithCertLoaderLogger(logger), WithCertLoaderCertificateCheckInterval(testCheckTime))
require.NoError(t, err)
require.NotNil(t, cl)
defer func() {
require.NoError(t, cl.Close())
}()
cl.WaitForMonitorStart()

checkWarning := func(t *testing.T) {
warning := logs.FilterMessage("Certificate is expired").TakeAll()
Expand Down Expand Up @@ -346,12 +351,16 @@ func TestTLSCertLoader_CertificateExpiresSoonLogging(t *testing.T) {
core, logs := observer.New(zapcore.InfoLevel)
logger := zap.New(core)

cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath, WithLogger(logger), WithCertificateCheckInterval(testCheckTime), WithExpirationAdvanced(2*24*time.Hour))
cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath,
WithCertLoaderLogger(logger),
WithCertLoaderCertificateCheckInterval(testCheckTime),
WithCertLoaderExpirationAdvanced(2*24*time.Hour))
require.NoError(t, err)
require.NotNil(t, cl)
defer func() {
require.NoError(t, cl.Close())
}()
cl.WaitForMonitorStart()

checkWarning := func(t *testing.T) {
warning := logs.FilterMessage("Certificate will expire soon").TakeAll()
Expand Down Expand Up @@ -534,8 +543,10 @@ func TestTLSCertLoader_GetClientCertificate(t *testing.T) {
},
}

// We should get an empty certificate with no error. This replicates Go's behavior when
// tls.Config.Certificates is used and none of the certificates are accepted by the server.
cert, err := cl.GetClientCertificate(cri)
require.ErrorContains(t, err, "doesn't support any of the certificate's signature algorithms")
require.NoError(t, err)
// GetClientCertificate must return a non-nil certificate even on error
// (per the tls.Config.GetClientCertificate contract).
require.NotNil(t, cert)
Expand Down Expand Up @@ -591,8 +602,10 @@ func TestTLSCertLoader_GetClientCertificate(t *testing.T) {
AcceptableCAs: [][]byte{parsedCA2.RawSubject},
}

// This should return an empty certificate with no error to replicate
// Go's behavior when tls.Config.Certificates is used.
cert, err := cl.GetClientCertificate(cri)
require.ErrorContains(t, err, "not signed by an acceptable CA")
require.NoError(t, err)
require.NotNil(t, cert)
require.Empty(t, cert.Certificate)
})
Expand Down
Loading