Skip to content

Commit 5276ad4

Browse files
Merge pull request #17 from ibuildthecloud/dropconn
Add option to close connections on cert change
2 parents 3f92468 + 8545ce9 commit 5276ad4

File tree

1 file changed

+58
-5
lines changed

1 file changed

+58
-5
lines changed

listener.go

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@ func NewListener(l net.Listener, storage TLSStorage, caCert *x509.Certificate, c
5959
}
6060
dynamicListener.tlsConfig.GetCertificate = dynamicListener.getCertificate
6161

62+
if config.CloseConnOnCertChange {
63+
dynamicListener.conns = map[int]*closeWrapper{}
64+
}
65+
6266
if setter, ok := storage.(SetFactory); ok {
6367
setter.SetFactory(dynamicListener.factory)
6468
}
@@ -82,17 +86,22 @@ func (c *cancelClose) Close() error {
8286
}
8387

8488
type Config struct {
85-
CN string
86-
Organization []string
87-
TLSConfig *tls.Config
88-
SANs []string
89-
ExpirationDaysCheck int
89+
CN string
90+
Organization []string
91+
TLSConfig *tls.Config
92+
SANs []string
93+
ExpirationDaysCheck int
94+
CloseConnOnCertChange bool
9095
}
9196

9297
type listener struct {
9398
sync.RWMutex
9499
net.Listener
95100

101+
conns map[int]*closeWrapper
102+
connID int
103+
connLock sync.Mutex
104+
96105
factory TLSFactory
97106
storage TLSStorage
98107
version string
@@ -194,9 +203,45 @@ func (l *listener) Accept() (net.Conn, error) {
194203
}
195204
}
196205

206+
if l.conns != nil {
207+
conn = l.wrap(conn)
208+
}
209+
197210
return conn, nil
198211
}
199212

213+
func (l *listener) wrap(conn net.Conn) net.Conn {
214+
l.connLock.Lock()
215+
defer l.connLock.Unlock()
216+
l.connID++
217+
218+
wrapper := &closeWrapper{
219+
Conn: conn,
220+
id: l.connID,
221+
l: l,
222+
}
223+
l.conns[l.connID] = wrapper
224+
225+
return wrapper
226+
}
227+
228+
type closeWrapper struct {
229+
net.Conn
230+
id int
231+
l *listener
232+
}
233+
234+
func (c *closeWrapper) close() error {
235+
delete(c.l.conns, c.id)
236+
return c.Conn.Close()
237+
}
238+
239+
func (c *closeWrapper) Close() error {
240+
c.l.Lock()
241+
defer c.l.Unlock()
242+
return c.close()
243+
}
244+
200245
func (l *listener) getCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
201246
if hello.ServerName != "" {
202247
if err := l.updateCert(hello.ServerName); err != nil {
@@ -238,6 +283,14 @@ func (l *listener) updateCert(cn ...string) error {
238283
l.version = ""
239284
}
240285

286+
if l.conns != nil {
287+
l.connLock.Lock()
288+
for _, conn := range l.conns {
289+
_ = conn.close()
290+
}
291+
l.connLock.Unlock()
292+
}
293+
241294
return nil
242295
}
243296

0 commit comments

Comments
 (0)