@@ -59,6 +59,10 @@ func NewListener(l net.Listener, storage TLSStorage, caCert *x509.Certificate, c
5959 }
6060 dynamicListener .tlsConfig .GetCertificate = dynamicListener .getCertificate
6161
62+ if config .CloseConnOnCertChange {
63+ dynamicListener .conns = map [int ]* closeWrapper {}
64+ }
65+
6266 if setter , ok := storage .(SetFactory ); ok {
6367 setter .SetFactory (dynamicListener .factory )
6468 }
@@ -82,17 +86,22 @@ func (c *cancelClose) Close() error {
8286}
8387
8488type Config struct {
85- CN string
86- Organization []string
87- TLSConfig * tls.Config
88- SANs []string
89- ExpirationDaysCheck int
89+ CN string
90+ Organization []string
91+ TLSConfig * tls.Config
92+ SANs []string
93+ ExpirationDaysCheck int
94+ CloseConnOnCertChange bool
9095}
9196
9297type listener struct {
9398 sync.RWMutex
9499 net.Listener
95100
101+ conns map [int ]* closeWrapper
102+ connID int
103+ connLock sync.Mutex
104+
96105 factory TLSFactory
97106 storage TLSStorage
98107 version string
@@ -194,9 +203,45 @@ func (l *listener) Accept() (net.Conn, error) {
194203 }
195204 }
196205
206+ if l .conns != nil {
207+ conn = l .wrap (conn )
208+ }
209+
197210 return conn , nil
198211}
199212
213+ func (l * listener ) wrap (conn net.Conn ) net.Conn {
214+ l .connLock .Lock ()
215+ defer l .connLock .Unlock ()
216+ l .connID ++
217+
218+ wrapper := & closeWrapper {
219+ Conn : conn ,
220+ id : l .connID ,
221+ l : l ,
222+ }
223+ l .conns [l .connID ] = wrapper
224+
225+ return wrapper
226+ }
227+
228+ type closeWrapper struct {
229+ net.Conn
230+ id int
231+ l * listener
232+ }
233+
234+ func (c * closeWrapper ) close () error {
235+ delete (c .l .conns , c .id )
236+ return c .Conn .Close ()
237+ }
238+
239+ func (c * closeWrapper ) Close () error {
240+ c .l .Lock ()
241+ defer c .l .Unlock ()
242+ return c .close ()
243+ }
244+
200245func (l * listener ) getCertificate (hello * tls.ClientHelloInfo ) (* tls.Certificate , error ) {
201246 if hello .ServerName != "" {
202247 if err := l .updateCert (hello .ServerName ); err != nil {
@@ -238,6 +283,14 @@ func (l *listener) updateCert(cn ...string) error {
238283 l .version = ""
239284 }
240285
286+ if l .conns != nil {
287+ l .connLock .Lock ()
288+ for _ , conn := range l .conns {
289+ _ = conn .close ()
290+ }
291+ l .connLock .Unlock ()
292+ }
293+
241294 return nil
242295}
243296
0 commit comments