Skip to content

Commit 72a6233

Browse files
committed
Plugin conn refactor
This changes things to rely on a plugin server that manages all connections made to the server. An optional handler can be passed into the server when the caller wants to do extra things with the connection. It is the caller's responsibility to close the server. When the server is closed, first all existing connections are closed (and new connections are prevented). Now the signal loop only needs to close the server and not deal with `net.Conn`s directly (or double-indirects as the case was before this change). Signed-off-by: Brian Goff <[email protected]>
1 parent 19d02cd commit 72a6233

File tree

6 files changed

+199
-76
lines changed

6 files changed

+199
-76
lines changed

cli-plugins/socket/socket.go

Lines changed: 89 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,106 @@ import (
77
"io"
88
"net"
99
"os"
10+
"runtime"
11+
"sync"
1012
)
1113

1214
// EnvKey represents the well-known environment variable used to pass the plugin being
1315
// executed the socket name it should listen on to coordinate with the host CLI.
1416
const EnvKey = "DOCKER_CLI_PLUGIN_SOCKET"
1517

16-
// SetupConn sets up a Unix socket listener, establishes a goroutine to handle connections
17-
// and update the conn pointer, and returns the listener for the socket (which the caller
18-
// is responsible for closing when it's no longer needed).
19-
func SetupConn(conn **net.UnixConn) (*net.UnixListener, error) {
20-
listener, err := listen("docker_cli_" + randomID())
18+
// NewPluginServer creates a plugin server that listens on a new Unix domain socket.
19+
// `h` is called for each new connection to the socket in a goroutine.
20+
func NewPluginServer(h func(net.Conn)) (*PluginServer, error) {
21+
l, err := listen("docker_cli_" + randomID())
2122
if err != nil {
2223
return nil, err
2324
}
2425

25-
accept(listener, conn)
26+
if h == nil {
27+
h = func(net.Conn) {}
28+
}
29+
30+
pl := &PluginServer{
31+
l: l,
32+
h: h,
33+
}
34+
35+
unlinkOnce := sync.OnceFunc(func() {
36+
unlink(l)
37+
})
38+
39+
go func() {
40+
defer pl.Close()
41+
for {
42+
err := pl.accept()
43+
unlinkOnce()
44+
if err != nil {
45+
return
46+
}
47+
}
48+
}()
49+
50+
return pl, nil
51+
}
52+
53+
type PluginServer struct {
54+
mu sync.Mutex
55+
conns []net.Conn
56+
l *net.UnixListener
57+
h func(net.Conn)
58+
closed bool
59+
}
60+
61+
func (l *PluginServer) accept() error {
62+
conn, err := l.l.Accept()
63+
if err != nil {
64+
return err
65+
}
66+
67+
l.mu.Lock()
68+
defer l.mu.Unlock()
69+
70+
if l.closed {
71+
// handle potential race condition between Close and Accept
72+
conn.Close()
73+
return errors.New("plugin server is closed")
74+
}
75+
76+
l.conns = append(l.conns, conn)
77+
78+
go l.h(conn)
79+
return nil
80+
}
81+
82+
func (l *PluginServer) Addr() net.Addr {
83+
return l.l.Addr()
84+
}
85+
86+
// Close ensures that the server is no longer accepting new connections and closes all existing connections.
87+
// Existing connections will receive [io.EOF].
88+
func (l *PluginServer) Close() error {
89+
// close connections first to ensure the connections get io.EOF instead of a connection reset.
90+
l.closeAllConns()
91+
92+
// Try to ensure that any active connections have a chance to receive io.EOF
93+
runtime.Gosched()
94+
95+
return l.l.Close()
96+
}
2697

27-
return listener, nil
98+
func (l *PluginServer) closeAllConns() {
99+
l.mu.Lock()
100+
defer l.mu.Unlock()
101+
102+
// Prevent new connections from being accepted
103+
l.closed = true
104+
105+
for _, conn := range l.conns {
106+
conn.Close()
107+
}
108+
109+
l.conns = nil
28110
}
29111

30112
func randomID() string {
@@ -35,18 +117,6 @@ func randomID() string {
35117
return hex.EncodeToString(b)
36118
}
37119

38-
func accept(listener *net.UnixListener, conn **net.UnixConn) {
39-
go func() {
40-
for {
41-
// ignore error here, if we failed to accept a connection,
42-
// conn is nil and we fallback to previous behavior
43-
*conn, _ = listener.AcceptUnix()
44-
// perform any platform-specific actions on accept (e.g. unlink non-abstract sockets)
45-
onAccept(*conn, listener)
46-
}
47-
}()
48-
}
49-
50120
// ConnectAndWait connects to the socket passed via well-known env var,
51121
// if present, and attempts to read from it until it receives an EOF, at which
52122
// point cb is called.

cli-plugins/socket/socket_darwin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ func listen(socketname string) (*net.UnixListener, error) {
1414
})
1515
}
1616

17-
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
17+
func unlink(listener *net.UnixListener) {
1818
syscall.Unlink(listener.Addr().String())
1919
}

cli-plugins/socket/socket_nodarwin.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ func listen(socketname string) (*net.UnixListener, error) {
1313
})
1414
}
1515

16-
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
16+
func unlink(listener *net.UnixListener) {
1717
// do nothing
1818
// while on darwin and OpenBSD we would unlink here;
1919
// on non-darwin the socket is abstract and not present on the filesystem

cli-plugins/socket/socket_openbsd.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@ func listen(socketname string) (*net.UnixListener, error) {
1414
})
1515
}
1616

17-
func onAccept(conn *net.UnixConn, listener *net.UnixListener) {
17+
func unlink(listener *net.UnixListener) {
1818
syscall.Unlink(listener.Addr().String())
1919
}

cli-plugins/socket/socket_test.go

Lines changed: 102 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,128 @@
11
package socket
22

33
import (
4+
"errors"
5+
"io"
46
"io/fs"
57
"net"
68
"os"
79
"runtime"
810
"strings"
11+
"sync/atomic"
912
"testing"
1013
"time"
1114

1215
"gotest.tools/v3/assert"
1316
"gotest.tools/v3/poll"
1417
)
1518

16-
func TestSetupConn(t *testing.T) {
17-
t.Run("updates conn when connected", func(t *testing.T) {
18-
var conn *net.UnixConn
19-
listener, err := SetupConn(&conn)
19+
func TestPluginServer(t *testing.T) {
20+
t.Run("connection closes with EOF when server closes", func(t *testing.T) {
21+
called := make(chan struct{})
22+
srv, err := NewPluginServer(func(_ net.Conn) { close(called) })
2023
assert.NilError(t, err)
21-
assert.Check(t, listener != nil, "returned nil listener but no error")
22-
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
24+
assert.Assert(t, srv != nil, "returned nil listener but no error")
25+
26+
addr, err := net.ResolveUnixAddr("unix", srv.Addr().String())
2327
assert.NilError(t, err, "failed to resolve listener address")
2428

25-
_, err = net.DialUnix("unix", nil, addr)
29+
conn, err := net.DialUnix("unix", nil, addr)
2630
assert.NilError(t, err, "failed to dial returned listener")
31+
defer conn.Close()
32+
33+
done := make(chan error, 1)
34+
go func() {
35+
_, err := conn.Read(make([]byte, 1))
36+
done <- err
37+
}()
38+
39+
select {
40+
case <-called:
41+
case <-time.After(10 * time.Millisecond):
42+
t.Fatal("handler not called")
43+
}
44+
45+
srv.Close()
2746

28-
pollConnNotNil(t, &conn)
47+
select {
48+
case err := <-done:
49+
if !errors.Is(err, io.EOF) {
50+
t.Fatalf("exepcted EOF error, got: %v", err)
51+
}
52+
case <-time.After(10 * time.Millisecond):
53+
}
2954
})
3055

3156
t.Run("allows reconnects", func(t *testing.T) {
32-
var conn *net.UnixConn
33-
listener, err := SetupConn(&conn)
57+
var calls int32
58+
h := func(_ net.Conn) {
59+
atomic.AddInt32(&calls, 1)
60+
}
61+
62+
srv, err := NewPluginServer(h)
3463
assert.NilError(t, err)
35-
assert.Check(t, listener != nil, "returned nil listener but no error")
36-
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
64+
defer srv.Close()
65+
66+
addr := srv.Addr().(*net.UnixAddr)
67+
68+
assert.Check(t, addr != nil, "returned nil listener but no error")
69+
70+
_, err = net.ResolveUnixAddr("unix", addr.String())
3771
assert.NilError(t, err, "failed to resolve listener address")
3872

73+
waitForCalls := func(n int) {
74+
poll.WaitOn(t, func(t poll.LogT) poll.Result {
75+
if atomic.LoadInt32(&calls) == int32(n) {
76+
return poll.Success()
77+
}
78+
return poll.Continue("waiting for handler to be called")
79+
})
80+
}
81+
3982
otherConn, err := net.DialUnix("unix", nil, addr)
4083
assert.NilError(t, err, "failed to dial returned listener")
41-
4284
otherConn.Close()
4385

44-
_, err = net.DialUnix("unix", nil, addr)
86+
waitForCalls(1)
87+
88+
conn, err := net.DialUnix("unix", nil, addr)
4589
assert.NilError(t, err, "failed to redial listener")
90+
defer conn.Close()
91+
waitForCalls(2)
92+
93+
// and again but don't close the existing connection
94+
conn2, err := net.DialUnix("unix", nil, addr)
95+
assert.NilError(t, err, "failed to redial listener")
96+
defer conn2.Close()
97+
waitForCalls(3)
98+
99+
srv.Close()
100+
101+
// now make sure we get EOF on the existing connections
102+
buf := make([]byte, 1)
103+
_, err = conn.Read(buf)
104+
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
105+
106+
_, err = conn2.Read(buf)
107+
assert.ErrorIs(t, err, io.EOF, "expected EOF error, got: %v", err)
46108
})
47109

48110
t.Run("does not leak sockets to local directory", func(t *testing.T) {
49-
var conn *net.UnixConn
50-
listener, err := SetupConn(&conn)
111+
srv, err := NewPluginServer(nil)
51112
assert.NilError(t, err)
52-
assert.Check(t, listener != nil, "returned nil listener but no error")
53-
checkDirNoPluginSocket(t)
113+
assert.Check(t, srv != nil, "returned nil server but no error")
114+
checkDirNoNewPluginServer(t)
54115

55-
addr, err := net.ResolveUnixAddr("unix", listener.Addr().String())
116+
_, err = net.ResolveUnixAddr("unix", srv.Addr().String())
56117
assert.NilError(t, err, "failed to resolve listener address")
57-
_, err = net.DialUnix("unix", nil, addr)
118+
119+
_, err = net.DialUnix("unix", nil, srv.Addr().(*net.UnixAddr))
58120
assert.NilError(t, err, "failed to dial returned listener")
59-
checkDirNoPluginSocket(t)
121+
checkDirNoNewPluginServer(t)
60122
})
61123
}
62124

63-
func checkDirNoPluginSocket(t *testing.T) {
125+
func checkDirNoNewPluginServer(t *testing.T) {
64126
t.Helper()
65127

66128
files, err := os.ReadDir(".")
@@ -78,18 +140,24 @@ func checkDirNoPluginSocket(t *testing.T) {
78140

79141
func TestConnectAndWait(t *testing.T) {
80142
t.Run("calls cancel func on EOF", func(t *testing.T) {
81-
var conn *net.UnixConn
82-
listener, err := SetupConn(&conn)
143+
srv, err := NewPluginServer(nil)
83144
assert.NilError(t, err, "failed to setup listener")
145+
defer srv.Close()
84146

85147
done := make(chan struct{})
86-
t.Setenv(EnvKey, listener.Addr().String())
148+
t.Setenv(EnvKey, srv.Addr().String())
87149
cancelFunc := func() {
88150
done <- struct{}{}
89151
}
90152
ConnectAndWait(cancelFunc)
91-
pollConnNotNil(t, &conn)
92-
conn.Close()
153+
154+
select {
155+
case <-done:
156+
t.Fatal("unexpectedly done")
157+
default:
158+
}
159+
160+
srv.Close()
93161

94162
select {
95163
case <-done:
@@ -101,17 +169,19 @@ func TestConnectAndWait(t *testing.T) {
101169
// TODO: this test cannot be executed with `t.Parallel()`, due to
102170
// relying on goroutine numbers to ensure correct behaviour
103171
t.Run("connect goroutine exits after EOF", func(t *testing.T) {
104-
var conn *net.UnixConn
105-
listener, err := SetupConn(&conn)
172+
srv, err := NewPluginServer(nil)
106173
assert.NilError(t, err, "failed to setup listener")
107-
t.Setenv(EnvKey, listener.Addr().String())
174+
175+
defer srv.Close()
176+
177+
t.Setenv(EnvKey, srv.Addr().String())
108178
numGoroutines := runtime.NumGoroutine()
109179

110180
ConnectAndWait(func() {})
111181
assert.Equal(t, runtime.NumGoroutine(), numGoroutines+1)
112182

113-
pollConnNotNil(t, &conn)
114-
conn.Close()
183+
srv.Close()
184+
115185
poll.WaitOn(t, func(t poll.LogT) poll.Result {
116186
if runtime.NumGoroutine() > numGoroutines+1 {
117187
return poll.Continue("waiting for connect goroutine to exit")
@@ -120,14 +190,3 @@ func TestConnectAndWait(t *testing.T) {
120190
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
121191
})
122192
}
123-
124-
func pollConnNotNil(t *testing.T, conn **net.UnixConn) {
125-
t.Helper()
126-
127-
poll.WaitOn(t, func(t poll.LogT) poll.Result {
128-
if *conn == nil {
129-
return poll.Continue("waiting for conn to not be nil")
130-
}
131-
return poll.Success()
132-
}, poll.WithDelay(1*time.Millisecond), poll.WithTimeout(10*time.Millisecond))
133-
}

0 commit comments

Comments
 (0)