@@ -59,9 +59,17 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
5959 return func (ctx context.Context , config shim.Config ) (_ shimapi.ShimService , _ io.Closer , err error ) {
6060 socket , err := newSocket (address )
6161 if err != nil {
62- return nil , nil , err
62+ if ! eaddrinuse (err ) {
63+ return nil , nil , err
64+ }
65+ if err := RemoveSocket (address ); err != nil {
66+ return nil , nil , errors .Wrap (err , "remove already used socket" )
67+ }
68+ if socket , err = newSocket (address ); err != nil {
69+ return nil , nil , err
70+ }
6371 }
64- defer socket . Close ()
72+
6573 f , err := socket .File ()
6674 if err != nil {
6775 return nil , nil , errors .Wrapf (err , "failed to get fd for socket %s" , address )
@@ -108,6 +116,8 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
108116 if stderrLog != nil {
109117 stderrLog .Close ()
110118 }
119+ socket .Close ()
120+ RemoveSocket (address )
111121 }()
112122 log .G (ctx ).WithFields (logrus.Fields {
113123 "pid" : cmd .Process .Pid ,
@@ -142,6 +152,26 @@ func WithStart(binary, address, daemonAddress, cgroup string, debug bool, exitHa
142152 }
143153}
144154
155+ func eaddrinuse (err error ) bool {
156+ cause := errors .Cause (err )
157+ netErr , ok := cause .(* net.OpError )
158+ if ! ok {
159+ return false
160+ }
161+ if netErr .Op != "listen" {
162+ return false
163+ }
164+ syscallErr , ok := netErr .Err .(* os.SyscallError )
165+ if ! ok {
166+ return false
167+ }
168+ errno , ok := syscallErr .Err .(syscall.Errno )
169+ if ! ok {
170+ return false
171+ }
172+ return errno == syscall .EADDRINUSE
173+ }
174+
145175// setupOOMScore gets containerd's oom score and adds +1 to it
146176// to ensure a shim has a lower* score than the daemons
147177func setupOOMScore (shimPid int ) error {
@@ -214,31 +244,73 @@ func writeFile(path, address string) error {
214244 return os .Rename (tempPath , path )
215245}
216246
247+ const (
248+ abstractSocketPrefix = "\x00 "
249+ socketPathLimit = 106
250+ )
251+
252+ type socket string
253+
254+ func (s socket ) isAbstract () bool {
255+ return ! strings .HasPrefix (string (s ), "unix://" )
256+ }
257+
258+ func (s socket ) path () string {
259+ path := strings .TrimPrefix (string (s ), "unix://" )
260+ // if there was no trim performed, we assume an abstract socket
261+ if len (path ) == len (s ) {
262+ path = abstractSocketPrefix + path
263+ }
264+ return path
265+ }
266+
217267func newSocket (address string ) (* net.UnixListener , error ) {
218- if len (address ) > 106 {
219- return nil , errors .Errorf ("%q: unix socket path too long (> 106)" , address )
268+ if len (address ) > socketPathLimit {
269+ return nil , errors .Errorf ("%q: unix socket path too long (> %d)" , address , socketPathLimit )
270+ }
271+ var (
272+ sock = socket (address )
273+ path = sock .path ()
274+ )
275+ if ! sock .isAbstract () {
276+ if err := os .MkdirAll (filepath .Dir (path ), 0600 ); err != nil {
277+ return nil , errors .Wrapf (err , "%s" , path )
278+ }
220279 }
221- l , err := net .Listen ("unix" , " \x00 " + address )
280+ l , err := net .Listen ("unix" , path )
222281 if err != nil {
223- return nil , errors .Wrapf (err , "failed to listen to abstract unix socket %q" , address )
282+ return nil , errors .Wrapf (err , "failed to listen to unix socket %q (abstract: %t)" , address , sock .isAbstract ())
283+ }
284+ if err := os .Chmod (path , 0600 ); err != nil {
285+ l .Close ()
286+ return nil , err
224287 }
225288
226289 return l .(* net.UnixListener ), nil
227290}
228291
292+ // RemoveSocket removes the socket at the specified address if
293+ // it exists on the filesystem
294+ func RemoveSocket (address string ) error {
295+ sock := socket (address )
296+ if ! sock .isAbstract () {
297+ return os .Remove (sock .path ())
298+ }
299+ return nil
300+ }
301+
229302func connect (address string , d func (string , time.Duration ) (net.Conn , error )) (net.Conn , error ) {
230303 return d (address , 100 * time .Second )
231304}
232305
233- func annonDialer (address string , timeout time.Duration ) (net.Conn , error ) {
234- address = strings .TrimPrefix (address , "unix://" )
235- return dialer .Dialer ("\x00 " + address , timeout )
306+ func anonDialer (address string , timeout time.Duration ) (net.Conn , error ) {
307+ return dialer .Dialer (socket (address ).path (), timeout )
236308}
237309
238310// WithConnect connects to an existing shim
239311func WithConnect (address string , onClose func ()) Opt {
240312 return func (ctx context.Context , config shim.Config ) (shimapi.ShimService , io.Closer , error ) {
241- conn , err := connect (address , annonDialer )
313+ conn , err := connect (address , anonDialer )
242314 if err != nil {
243315 return nil , nil , err
244316 }
0 commit comments