Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions cert/cert.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ const (
duration365d = time.Hour * 24 * 365
)

var (
ErrStaticCert = errors.New("cannot renew static certificate")
)

// Config contains the basic fields required for creating a certificate
type Config struct {
CommonName string
Expand Down Expand Up @@ -119,7 +123,13 @@ func NewSignedCert(cfg Config, key crypto.Signer, caCert *x509.Certificate, caKe
if err != nil {
return nil, err
}
return x509.ParseCertificate(certDERBytes)

parsedCert, err := x509.ParseCertificate(certDERBytes)
if err == nil {
logrus.Infof("certificate %s signed by %s: notBefore=%s notAfter=%s",
parsedCert.Subject, caCert.Subject, parsedCert.NotBefore, parsedCert.NotAfter)
}
return parsedCert, err
}

// MakeEllipticPrivateKeyPEM creates an ECDSA private key
Expand Down Expand Up @@ -271,11 +281,11 @@ func ipsToStrings(ips []net.IP) []string {
}

// IsCertExpired checks if the certificate about to expire
func IsCertExpired(cert *x509.Certificate) bool {
func IsCertExpired(cert *x509.Certificate, days int) bool {
expirationDate := cert.NotAfter
diffDays := expirationDate.Sub(time.Now()).Hours() / 24.0
if diffDays <= 90 {
logrus.Infof("certificate will expire in %f days", diffDays)
diffDays := time.Until(expirationDate).Hours() / 24.0
if diffDays <= float64(days) {
logrus.Infof("certificate %s will expire in %f days at %s", cert.Subject, diffDays, cert.NotAfter)
return true
}
return false
Expand Down
6 changes: 3 additions & 3 deletions cert/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ func CanReadCertAndKey(certPath, keyPath string) (bool, error) {
certReadable := canReadFile(certPath)
keyReadable := canReadFile(keyPath)

if certReadable == false && keyReadable == false {
if !certReadable && !keyReadable {
return false, nil
}

if certReadable == false {
if !certReadable {
return false, fmt.Errorf("error reading %s, certificate and key must be supplied as a pair", certPath)
}

if keyReadable == false {
if !keyReadable {
return false, fmt.Errorf("error reading %s, certificate and key must be supplied as a pair", keyPath)
}

Expand Down
9 changes: 8 additions & 1 deletion factory/cert_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"math/big"
"net"
"time"

"github.com/sirupsen/logrus"
)

const (
Expand Down Expand Up @@ -92,7 +94,12 @@ func NewSignedCert(signer crypto.Signer, caCert *x509.Certificate, caKey crypto.
return nil, err
}

return x509.ParseCertificate(cert)
parsedCert, err := x509.ParseCertificate(cert)
if err == nil {
logrus.Infof("certificate %s signed by %s: notBefore=%s notAfter=%s",
parsedCert.Subject, caCert.Subject, parsedCert.NotBefore, parsedCert.NotAfter)
}
return parsedCert, err
}

func ParseCertPEM(pemCerts []byte) (*x509.Certificate, error) {
Expand Down
77 changes: 52 additions & 25 deletions factory/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/sha1"
"crypto/x509"
"encoding/hex"
"encoding/pem"
"fmt"
"net"
"regexp"
"sort"
Expand All @@ -20,9 +20,9 @@ import (
)

const (
cnPrefix = "listener.cattle.io/cn-"
Static = "listener.cattle.io/static"
hashKey = "listener.cattle.io/hash"
cnPrefix = "listener.cattle.io/cn-"
Static = "listener.cattle.io/static"
fingerprint = "listener.cattle.io/fingerprint"
)

var (
Expand All @@ -49,16 +49,14 @@ func cns(secret *v1.Secret) (cns []string) {
return
}

func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string, err error) {
func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, err error) {
var (
cns = cns(secret)
digest = sha256.New()
cns = cns(secret)
)

sort.Strings(cns)

for _, v := range cns {
digest.Write([]byte(v))
ip := net.ParseIP(v)
if ip == nil {
domains = append(domains, v)
Expand All @@ -67,40 +65,61 @@ func collectCNs(secret *v1.Secret) (domains []string, ips []net.IP, hash string,
}
}

hash = hex.EncodeToString(digest.Sum(nil))
return
}

// Merge combines the SAN lists from the target and additional Secrets, and returns a potentially modified Secret,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I love all of the doc comments you added. Thank you very much.

// along with a bool indicating if the returned Secret has been updated or not. If the two SAN lists alread matched
// and no merging was necessary, but the Secrets' certificate fingerprints differed, the second secret is returned
// and the updated bool is set to true despite neither certificate having actually been modified. This is required
// to support handling certificate renewal within the kubernetes storage provider.
func (t *TLS) Merge(target, additional *v1.Secret) (*v1.Secret, bool, error) {
return t.AddCN(target, cns(additional)...)
secret, updated, err := t.AddCN(target, cns(additional)...)
if !updated {
if target.Annotations[fingerprint] != additional.Annotations[fingerprint] {
secret = additional
updated = true
}
}
return secret, updated, err
}

func (t *TLS) Refresh(secret *v1.Secret) (*v1.Secret, error) {
// Renew returns a copy of the given certificate that has been re-signed
// to extend the NotAfter date. It is an error to attempt to renew
// a static (user-provided) certificate.
func (t *TLS) Renew(secret *v1.Secret) (*v1.Secret, error) {
if IsStatic(secret) {
return secret, cert.ErrStaticCert
}
cns := cns(secret)
secret = secret.DeepCopy()
secret.Annotations = map[string]string{}
secret, _, err := t.AddCN(secret, cns...)
secret, _, err := t.generateCert(secret, cns...)
return secret, err
}

// Filter ensures that the CNs are all valid accorting to both internal logic, and any filter callbacks.
// The returned list will contain only approved CN entries.
func (t *TLS) Filter(cn ...string) []string {
if t.FilterCN == nil {
if len(cn) == 0 || t.FilterCN == nil {
return cn
}
return t.FilterCN(cn...)
}

// AddCN attempts to add a list of CN strings to a given Secret, returning the potentially-modified
// Secret along with a bool indicating whether or not it has been updated. The Secret will not be changed
// if it has an attribute indicating that it is static (aka user-provided), or if no new CNs were added.
func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) {
var (
err error
)

cn = t.Filter(cn...)

if !NeedsUpdate(0, secret, cn...) {
if IsStatic(secret) || !NeedsUpdate(0, secret, cn...) {
return secret, false, nil
}
return t.generateCert(secret, cn...)
}

func (t *TLS) generateCert(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) {
secret = secret.DeepCopy()
if secret == nil {
secret = &v1.Secret{}
Expand All @@ -113,7 +132,7 @@ func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) {
return nil, false, err
}

domains, ips, hash, err := collectCNs(secret)
domains, ips, err := collectCNs(secret)
if err != nil {
return nil, false, err
}
Expand All @@ -133,7 +152,7 @@ func (t *TLS) AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error) {
}
secret.Data[v1.TLSCertKey] = certBytes
secret.Data[v1.TLSPrivateKeyKey] = keyBytes
secret.Annotations[hashKey] = hash
secret.Annotations[fingerprint] = fmt.Sprintf("SHA1=%X", sha1.Sum(newCert.Raw))

return secret, true, nil
}
Expand All @@ -157,15 +176,21 @@ func populateCN(secret *v1.Secret, cn ...string) *v1.Secret {
return secret
}

// IsStatic returns true if the Secret has an attribute indicating that it contains
// a static (aka user-provided) certificate, which should not be modified.
func IsStatic(secret *v1.Secret) bool {
return secret.Annotations[Static] == "true"
}

// NeedsUpdate returns true if any of the CNs are not currently present on the
// secret's Certificate, as recorded in the cnPrefix annotations. It will return
// false if all requested CNs are already present, or if maxSANs is non-zero and has
// been exceeded.
func NeedsUpdate(maxSANs int, secret *v1.Secret, cn ...string) bool {
if secret == nil {
return true
}

if secret.Annotations[Static] == "true" {
return false
}

for _, cn := range cn {
if secret.Annotations[cnPrefix+cn] == "" {
if maxSANs > 0 && len(cns(secret)) >= maxSANs {
Expand All @@ -192,6 +217,7 @@ func getPrivateKey(secret *v1.Secret) (crypto.Signer, error) {
return NewPrivateKey()
}

// Marshal returns the given cert and key as byte slices.
func Marshal(x509Cert *x509.Certificate, privateKey crypto.Signer) ([]byte, []byte, error) {
certBlock := pem.Block{
Type: CertificateBlockType,
Expand All @@ -206,6 +232,7 @@ func Marshal(x509Cert *x509.Certificate, privateKey crypto.Signer) ([]byte, []by
return pem.EncodeToMemory(&certBlock), keyBytes, nil
}

// NewPrivateKey returnes a new ECDSA key
func NewPrivateKey() (crypto.Signer, error) {
return ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
}
46 changes: 29 additions & 17 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"sync"
"time"

"github.com/rancher/dynamiclistener/cert"
"github.com/rancher/dynamiclistener/factory"
"github.com/sirupsen/logrus"
v1 "k8s.io/api/core/v1"
Expand All @@ -22,7 +23,7 @@ type TLSStorage interface {
}

type TLSFactory interface {
Refresh(secret *v1.Secret) (*v1.Secret, error)
Renew(secret *v1.Secret) (*v1.Secret, error)
Copy link

@dweomer dweomer Aug 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would like to understand the reasoning behind changing this exported interface method name. Is it because dynamiclistener is the only client of such therefore no big deal?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The industry standard term for what we are doing here is renewal, not refreshing. See:
https://cabforum.org/wp-content/uploads/CA-Browser-Forum-BR-1.7.0-1.pdf section 4.6. Using standard terminology makes it clear what this method is responsible for doing.

Copy link

@dweomer dweomer Aug 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The industry standard term for what we are doing here is renewal, not refreshing. See:
https://cabforum.org/wp-content/uploads/CA-Browser-Forum-BR-1.7.0-1.pdf section 4.6. Using standard terminology makes it clear what this method is responsible for doing.

Sorry, I assumed this. I agree it is a better name but are there consequences that we might wish to avoid with such a breaking change in the public API? Yes we own it. I guess what I am asking is, is it only ever "internally" consumed? If so, fine. If not, are we comfortable breaking downstream on upgrade?

Copy link
Member Author

@brandond brandond Aug 10, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any direct use of the Refresh method in Rancher or K3s; it appears to only be called internal to this module. Given that we don't even have a README in this repo, I somewhat doubt anyone else is using this module at the moment, but I very well could be wrong.

AddCN(secret *v1.Secret, cn ...string) (*v1.Secret, bool, error)
Merge(target *v1.Secret, additional *v1.Secret) (*v1.Secret, bool, error)
Filter(cn ...string) []string
Expand Down Expand Up @@ -152,13 +153,18 @@ type listener struct {
func (l *listener) WrapExpiration(days int) net.Listener {
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(5 * time.Minute)
time.Sleep(30 * time.Second)

for {
wait := 6 * time.Hour
if err := l.checkExpiration(days); err != nil {
logrus.Errorf("failed to check and refresh dynamic cert: %v", err)
wait = 5 + time.Minute
logrus.Errorf("failed to check and renew dynamic cert: %v", err)
// Don't go into short retry loop if we're using a static (user-provided) cert.
// We will still check and print an error every six hours until the user updates the secret with
// a cert that is not about to expire. Hopefully this will prompt them to take action.
if err != cert.ErrStaticCert {
wait = 5 * time.Minute
}
}
select {
case <-ctx.Done():
Expand Down Expand Up @@ -191,22 +197,26 @@ func (l *listener) checkExpiration(days int) error {
return err
}

cert, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey])
keyPair, err := tls.X509KeyPair(secret.Data[v1.TLSCertKey], secret.Data[v1.TLSPrivateKeyKey])
if err != nil {
return err
}

certParsed, err := x509.ParseCertificate(cert.Certificate[0])
certParsed, err := x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
return err
}

if time.Now().UTC().Add(time.Hour * 24 * time.Duration(days)).After(certParsed.NotAfter) {
secret, err := l.factory.Refresh(secret)
if cert.IsCertExpired(certParsed, days) {
secret, err := l.factory.Renew(secret)
if err != nil {
return err
}
return l.storage.Update(secret)
if err := l.storage.Update(secret); err != nil {
return err
}
// clear version to force cert reload
l.version = ""
}

return nil
Expand Down Expand Up @@ -304,7 +314,7 @@ func (l *listener) updateCert(cn ...string) error {
return err
}

if !factory.NeedsUpdate(l.maxSANs, secret, cn...) {
if !factory.IsStatic(secret) && !factory.NeedsUpdate(l.maxSANs, secret, cn...) {
return nil
}

Expand All @@ -324,13 +334,6 @@ func (l *listener) updateCert(cn ...string) error {
}
// clear version to force cert reload
l.version = ""
if l.conns != nil {
l.connLock.Lock()
for _, conn := range l.conns {
_ = conn.close()
}
l.connLock.Unlock()
}
}

return nil
Expand Down Expand Up @@ -366,6 +369,15 @@ func (l *listener) loadCert() (*tls.Certificate, error) {
return nil, err
}

// cert has changed, close closeWrapper wrapped connections
if l.conns != nil {
l.connLock.Lock()
for _, conn := range l.conns {
_ = conn.close()
}
l.connLock.Unlock()
}

l.cert = &cert
l.version = secret.ResourceVersion
return l.cert, nil
Expand Down
5 changes: 2 additions & 3 deletions storage/kubernetes/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,9 @@ func (s *storage) saveInK8s(secret *v1.Secret) (*v1.Secret, error) {
if targetSecret.UID == "" {
logrus.Infof("Creating new TLS secret for %v (count: %d): %v", targetSecret.Name, len(targetSecret.Annotations)-1, targetSecret.Annotations)
return s.secrets.Create(targetSecret)
} else {
logrus.Infof("Updating TLS secret for %v (count: %d): %v", targetSecret.Name, len(targetSecret.Annotations)-1, targetSecret.Annotations)
return s.secrets.Update(targetSecret)
}
logrus.Infof("Updating TLS secret for %v (count: %d): %v", targetSecret.Name, len(targetSecret.Annotations)-1, targetSecret.Annotations)
return s.secrets.Update(targetSecret)
}

func (s *storage) Update(secret *v1.Secret) (err error) {
Expand Down
Loading