11package socket
22
33import (
4+ "errors"
5+ "io"
46 "io/fs"
57 "net"
68 "os"
79 "runtime"
810 "strings"
11+ "sync/atomic"
912 "testing"
1013 "time"
1114
1215 "gotest.tools/v3/assert"
1316 "gotest.tools/v3/poll"
1417)
1518
16- func TestSetupConn (t * testing.T ) {
17- t .Run ("updates conn when connected " , func (t * testing.T ) {
18- var conn * net. UnixConn
19- listener , err := SetupConn ( & conn )
19+ func TestPluginServer (t * testing.T ) {
20+ t .Run ("connection closes with EOF when server closes " , func (t * testing.T ) {
21+ called := make ( chan struct {})
22+ srv , err := NewPluginServer ( func ( _ net. Conn ) { close ( called ) } )
2023 assert .NilError (t , err )
21- assert .Check (t , listener != nil , "returned nil listener but no error" )
22- addr , err := net .ResolveUnixAddr ("unix" , listener .Addr ().String ())
24+ assert .Assert (t , srv != nil , "returned nil listener but no error" )
25+
26+ addr , err := net .ResolveUnixAddr ("unix" , srv .Addr ().String ())
2327 assert .NilError (t , err , "failed to resolve listener address" )
2428
25- _ , err = net .DialUnix ("unix" , nil , addr )
29+ conn , err : = net .DialUnix ("unix" , nil , addr )
2630 assert .NilError (t , err , "failed to dial returned listener" )
31+ defer conn .Close ()
32+
33+ done := make (chan error , 1 )
34+ go func () {
35+ _ , err := conn .Read (make ([]byte , 1 ))
36+ done <- err
37+ }()
38+
39+ select {
40+ case <- called :
41+ case <- time .After (10 * time .Millisecond ):
42+ t .Fatal ("handler not called" )
43+ }
44+
45+ srv .Close ()
2746
28- pollConnNotNil (t , & conn )
47+ select {
48+ case err := <- done :
49+ if ! errors .Is (err , io .EOF ) {
50+ t .Fatalf ("exepcted EOF error, got: %v" , err )
51+ }
52+ case <- time .After (10 * time .Millisecond ):
53+ }
2954 })
3055
3156 t .Run ("allows reconnects" , func (t * testing.T ) {
32- var conn * net.UnixConn
33- listener , err := SetupConn (& conn )
57+ var calls int32
58+ h := func (_ net.Conn ) {
59+ atomic .AddInt32 (& calls , 1 )
60+ }
61+
62+ srv , err := NewPluginServer (h )
3463 assert .NilError (t , err )
35- assert .Check (t , listener != nil , "returned nil listener but no error" )
36- addr , err := net .ResolveUnixAddr ("unix" , listener .Addr ().String ())
64+ defer srv .Close ()
65+
66+ addr := srv .Addr ().(* net.UnixAddr )
67+
68+ assert .Check (t , addr != nil , "returned nil listener but no error" )
69+
70+ _ , err = net .ResolveUnixAddr ("unix" , addr .String ())
3771 assert .NilError (t , err , "failed to resolve listener address" )
3872
73+ waitForCalls := func (n int ) {
74+ poll .WaitOn (t , func (t poll.LogT ) poll.Result {
75+ if atomic .LoadInt32 (& calls ) == int32 (n ) {
76+ return poll .Success ()
77+ }
78+ return poll .Continue ("waiting for handler to be called" )
79+ })
80+ }
81+
3982 otherConn , err := net .DialUnix ("unix" , nil , addr )
4083 assert .NilError (t , err , "failed to dial returned listener" )
41-
4284 otherConn .Close ()
4385
44- _ , err = net .DialUnix ("unix" , nil , addr )
86+ waitForCalls (1 )
87+
88+ conn , err := net .DialUnix ("unix" , nil , addr )
4589 assert .NilError (t , err , "failed to redial listener" )
90+ defer conn .Close ()
91+ waitForCalls (2 )
92+
93+ // and again but don't close the existing connection
94+ conn2 , err := net .DialUnix ("unix" , nil , addr )
95+ assert .NilError (t , err , "failed to redial listener" )
96+ defer conn2 .Close ()
97+ waitForCalls (3 )
98+
99+ srv .Close ()
100+
101+ // now make sure we get EOF on the existing connections
102+ buf := make ([]byte , 1 )
103+ _ , err = conn .Read (buf )
104+ assert .ErrorIs (t , err , io .EOF , "expected EOF error, got: %v" , err )
105+
106+ _ , err = conn2 .Read (buf )
107+ assert .ErrorIs (t , err , io .EOF , "expected EOF error, got: %v" , err )
46108 })
47109
48110 t .Run ("does not leak sockets to local directory" , func (t * testing.T ) {
49- var conn * net.UnixConn
50- listener , err := SetupConn (& conn )
111+ srv , err := NewPluginServer (nil )
51112 assert .NilError (t , err )
52- assert .Check (t , listener != nil , "returned nil listener but no error" )
53- checkDirNoPluginSocket (t )
113+ assert .Check (t , srv != nil , "returned nil server but no error" )
114+ checkDirNoNewPluginServer (t )
54115
55- addr , err : = net .ResolveUnixAddr ("unix" , listener .Addr ().String ())
116+ _ , err = net .ResolveUnixAddr ("unix" , srv .Addr ().String ())
56117 assert .NilError (t , err , "failed to resolve listener address" )
57- _ , err = net .DialUnix ("unix" , nil , addr )
118+
119+ _ , err = net .DialUnix ("unix" , nil , srv .Addr ().(* net.UnixAddr ))
58120 assert .NilError (t , err , "failed to dial returned listener" )
59- checkDirNoPluginSocket (t )
121+ checkDirNoNewPluginServer (t )
60122 })
61123}
62124
63- func checkDirNoPluginSocket (t * testing.T ) {
125+ func checkDirNoNewPluginServer (t * testing.T ) {
64126 t .Helper ()
65127
66128 files , err := os .ReadDir ("." )
@@ -78,18 +140,24 @@ func checkDirNoPluginSocket(t *testing.T) {
78140
79141func TestConnectAndWait (t * testing.T ) {
80142 t .Run ("calls cancel func on EOF" , func (t * testing.T ) {
81- var conn * net.UnixConn
82- listener , err := SetupConn (& conn )
143+ srv , err := NewPluginServer (nil )
83144 assert .NilError (t , err , "failed to setup listener" )
145+ defer srv .Close ()
84146
85147 done := make (chan struct {})
86- t .Setenv (EnvKey , listener .Addr ().String ())
148+ t .Setenv (EnvKey , srv .Addr ().String ())
87149 cancelFunc := func () {
88150 done <- struct {}{}
89151 }
90152 ConnectAndWait (cancelFunc )
91- pollConnNotNil (t , & conn )
92- conn .Close ()
153+
154+ select {
155+ case <- done :
156+ t .Fatal ("unexpectedly done" )
157+ default :
158+ }
159+
160+ srv .Close ()
93161
94162 select {
95163 case <- done :
@@ -101,17 +169,19 @@ func TestConnectAndWait(t *testing.T) {
101169 // TODO: this test cannot be executed with `t.Parallel()`, due to
102170 // relying on goroutine numbers to ensure correct behaviour
103171 t .Run ("connect goroutine exits after EOF" , func (t * testing.T ) {
104- var conn * net.UnixConn
105- listener , err := SetupConn (& conn )
172+ srv , err := NewPluginServer (nil )
106173 assert .NilError (t , err , "failed to setup listener" )
107- t .Setenv (EnvKey , listener .Addr ().String ())
174+
175+ defer srv .Close ()
176+
177+ t .Setenv (EnvKey , srv .Addr ().String ())
108178 numGoroutines := runtime .NumGoroutine ()
109179
110180 ConnectAndWait (func () {})
111181 assert .Equal (t , runtime .NumGoroutine (), numGoroutines + 1 )
112182
113- pollConnNotNil ( t , & conn )
114- conn . Close ()
183+ srv . Close ( )
184+
115185 poll .WaitOn (t , func (t poll.LogT ) poll.Result {
116186 if runtime .NumGoroutine () > numGoroutines + 1 {
117187 return poll .Continue ("waiting for connect goroutine to exit" )
@@ -120,14 +190,3 @@ func TestConnectAndWait(t *testing.T) {
120190 }, poll .WithDelay (1 * time .Millisecond ), poll .WithTimeout (10 * time .Millisecond ))
121191 })
122192}
123-
124- func pollConnNotNil (t * testing.T , conn * * net.UnixConn ) {
125- t .Helper ()
126-
127- poll .WaitOn (t , func (t poll.LogT ) poll.Result {
128- if * conn == nil {
129- return poll .Continue ("waiting for conn to not be nil" )
130- }
131- return poll .Success ()
132- }, poll .WithDelay (1 * time .Millisecond ), poll .WithTimeout (10 * time .Millisecond ))
133- }
0 commit comments