Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions pkg/testing/selfsigned/selfsigned.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ type CertOptions struct {

// CombinedFile indicates if the certificate and key should be combined into a single file
CombinedFile bool

// CAOrganization sets the CA certificate's Subject.Organization field
CAOrganization string

// CACommonName sets the CA certificate's Subject.CommonName field
CACommonName string
}

type CertOpt func(*CertOptions)
Expand Down Expand Up @@ -82,6 +88,13 @@ func WithCombinedFile() CertOpt {
}
}

func WithCASubject(organization, commonName string) CertOpt {
return func(o *CertOptions) {
o.CAOrganization = organization
o.CACommonName = commonName
}
}

func NewSelfSignedCert(t *testing.T, opts ...CertOpt) *Cert {
t.Helper()
tmpdir := t.TempDir()
Expand All @@ -108,6 +121,14 @@ func NewSelfSignedCert(t *testing.T, opts ...CertOpt) *Cert {
options.NotAfter = time.Now().Add(7 * 24 * time.Hour)
}

if options.CAOrganization == "" {
options.CAOrganization = "my_test_ca"
}

if options.CACommonName == "" {
options.CACommonName = "My Test CA"
}

// Sanity check options.
require.NotEmpty(t, options.DNSNames)

Expand All @@ -129,8 +150,8 @@ func NewSelfSignedCert(t *testing.T, opts ...CertOpt) *Cert {
BasicConstraintsValid: true,

Subject: pkix.Name{
Organization: []string{"my_test_ca"},
CommonName: "My Test CA",
Organization: []string{options.CAOrganization},
CommonName: options.CACommonName,
},

IsCA: true,
Expand Down
38 changes: 34 additions & 4 deletions pkg/tlsconfig/certconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@ const (
)

var (
ErrCertificateNil = errors.New("TLS certificate is nil")
ErrCertificateEmpty = errors.New("TLS certificate is empty")
ErrLoadedCertificateInvalid = errors.New("LoadedCertificate is invalid")
ErrPathEmpty = errors.New("empty path")
ErrCertificateNil = errors.New("TLS certificate is nil")
ErrCertificateEmpty = errors.New("TLS certificate is empty")
ErrCertificateRequestInfoNil = errors.New("CertificateRequestInfo is nil")
ErrLoadedCertificateInvalid = errors.New("LoadedCertificate is invalid")
ErrPathEmpty = errors.New("empty path")
)

// LoadedCertificate encapsulates information about a loaded certificate.
Expand Down Expand Up @@ -275,6 +276,18 @@ func (cl *TLSCertLoader) Certificate() *tls.Certificate {
return cl.cert
}

// SetupTLSConfig modifies tlsConfig to use cl for server and client certificates.
// tlsConfig may be nil. If other fields like tlsConfig.Certificates or
// tlsConfig.NameToCertificate have been set, then cl's certificate may not be used
// as expected.
func (cl *TLSCertLoader) SetupTLSConfig(tlsConfig *tls.Config) {
if tlsConfig == nil {
return
}
tlsConfig.GetCertificate = cl.GetCertificate
tlsConfig.GetClientCertificate = cl.GetClientCertificate
}

// GetCertificate is for use with a tls.Config's GetCertificate member. This allows a
// tls.Config to dynamically update its certificate when Load changes the active
// certificate.
Expand All @@ -290,6 +303,23 @@ func (cl *TLSCertLoader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate,
}
}

// GetClientCertificate is for use with a tls.Config's GetClientCertificate member. This allows a
// tls.Config to dynamically update its client certificates when Load changes the active
// certificate.
func (cl *TLSCertLoader) GetClientCertificate(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
if cri == nil {
return new(tls.Certificate), ErrCertificateRequestInfoNil
}
cert := cl.Certificate()
if cert == nil {
return new(tls.Certificate), ErrCertificateNil
}
if err := cri.SupportsCertificate(cert); err != nil {
return new(tls.Certificate), err
}
return cert, nil
}

// Leaf returns the parsed x509 certificate of the currently loaded certificate.
// If no certificate is loaded then nil is returned.
func (cl *TLSCertLoader) Leaf() *x509.Certificate {
Expand Down
138 changes: 138 additions & 0 deletions pkg/tlsconfig/certconfig_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package tlsconfig

import (
"crypto/tls"
"crypto/x509"
"encoding/pem"
"fmt"
"math/big"
"os"
Expand Down Expand Up @@ -488,3 +490,139 @@ func TestTLSCertLoader_VerifyLoad(t *testing.T) {
require.Equal(t, sn2, cl.Leaf().SerialNumber.String())
}
}

func TestTLSCertLoader_GetClientCertificate(t *testing.T) {
ss := selfsigned.NewSelfSignedCert(t, selfsigned.WithDNSName("client.influxdata.edge"))

cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath)
require.NoError(t, err)
require.NotNil(t, cl)
defer func() {
require.NoError(t, cl.Close())
}()

// Test happy path: certificate supports the request.
// The selfsigned package creates RSA certificates, so we use RSA signature schemes.
t.Run("supported certificate", func(t *testing.T) {
cri := &tls.CertificateRequestInfo{
SignatureSchemes: []tls.SignatureScheme{
tls.PKCS1WithSHA256,
tls.PKCS1WithSHA384,
tls.PKCS1WithSHA512,
},
}

cert, err := cl.GetClientCertificate(cri)
require.NoError(t, err)
require.NotNil(t, cert)
require.Equal(t, cl.Certificate(), cert)
})

t.Run("nil CertificateRequestInfo", func(t *testing.T) {
cert, err := cl.GetClientCertificate(nil)
require.ErrorIs(t, err, ErrCertificateRequestInfoNil)
require.NotNil(t, cert)
require.Empty(t, cert.Certificate)
})

// Test unsupported certificate: CertificateRequestInfo only accepts Ed25519,
// but our certificate uses RSA.
t.Run("unsupported certificate", func(t *testing.T) {
cri := &tls.CertificateRequestInfo{
SignatureSchemes: []tls.SignatureScheme{
tls.Ed25519, // Our RSA cert doesn't support this
},
}

cert, err := cl.GetClientCertificate(cri)
require.ErrorContains(t, err, "doesn't support any of the certificate's signature algorithms")
// GetClientCertificate must return a non-nil certificate even on error
// (per the tls.Config.GetClientCertificate contract).
require.NotNil(t, cert)
// The returned certificate should be an empty certificate, not the loaded one.
require.NotEqual(t, cl.Certificate(), cert)
require.Empty(t, cert.Certificate)
})

// Test with AcceptableCAs that include our CA.
t.Run("acceptable CA", func(t *testing.T) {
// Verify that if we change cri to ss's CA subject then we do get cert.
caCert, err := os.ReadFile(ss.CACertPath)
require.NoError(t, err)

// Parse the CA cert to get its RawSubject for AcceptableCAs.
block, _ := pem.Decode(caCert)
require.NotNil(t, block)
parsedCA, err := x509.ParseCertificate(block.Bytes)
require.NoError(t, err)

cri := &tls.CertificateRequestInfo{
SignatureSchemes: []tls.SignatureScheme{
tls.PKCS1WithSHA256,
},
AcceptableCAs: [][]byte{parsedCA.RawSubject},
}

cert, err := cl.GetClientCertificate(cri)
require.NoError(t, err)
require.NotNil(t, cert)
require.Equal(t, cl.Certificate(), cert)
})

// Test with AcceptableCAs that don't include our CA.
t.Run("unacceptable CA", func(t *testing.T) {
// Create a certificate with a different CA subject.
ss2 := selfsigned.NewSelfSignedCert(t,
selfsigned.WithCASubject("different_org", "Different CA"),
)
caCert2, err := os.ReadFile(ss2.CACertPath)
require.NoError(t, err)

// Parse the CA cert to get its RawSubject for AcceptableCAs.
block2, _ := pem.Decode(caCert2)
require.NotNil(t, block2)
parsedCA2, err := x509.ParseCertificate(block2.Bytes)
require.NoError(t, err)

cri := &tls.CertificateRequestInfo{
SignatureSchemes: []tls.SignatureScheme{
tls.PKCS1WithSHA256,
},
AcceptableCAs: [][]byte{parsedCA2.RawSubject},
}

cert, err := cl.GetClientCertificate(cri)
require.ErrorContains(t, err, "not signed by an acceptable CA")
require.NotNil(t, cert)
require.Empty(t, cert.Certificate)
})
}

func TestTLSCertLoader_SetupTLSConfig(t *testing.T) {
ss := selfsigned.NewSelfSignedCert(t)

cl, err := NewTLSCertLoader(ss.CertPath, ss.KeyPath)
require.NoError(t, err)
require.NotNil(t, cl)
defer func() {
require.NoError(t, cl.Close())
}()

t.Run("nil config", func(t *testing.T) {
require.NotPanics(t, func() {
cl.SetupTLSConfig(nil)
})
})

t.Run("sets callbacks", func(t *testing.T) {
tlsConfig := &tls.Config{}

require.Nil(t, tlsConfig.GetCertificate)
require.Nil(t, tlsConfig.GetClientCertificate)

cl.SetupTLSConfig(tlsConfig)

require.NotNil(t, tlsConfig.GetCertificate)
require.NotNil(t, tlsConfig.GetClientCertificate)
})
}
2 changes: 1 addition & 1 deletion services/httpd/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func (s *Service) Open() error {
}

tlsConfig := s.tlsConfig.Clone()
tlsConfig.GetCertificate = s.certLoader.GetCertificate
s.certLoader.SetupTLSConfig(tlsConfig)

listener, err := tls.Listen("tcp", s.addr, tlsConfig)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion services/opentsdb/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ func (s *Service) Open() error {
s.certLoader = certLoader

tlsConfig := s.tlsConfig.Clone()
tlsConfig.GetCertificate = s.certLoader.GetCertificate
s.certLoader.SetupTLSConfig(tlsConfig)

listener, err := tls.Listen("tcp", s.BindAddress, tlsConfig)
if err != nil {
Expand Down