diff --git a/cli/context/store/metadatastore.go b/cli/context/store/metadatastore.go index 2bb6a20e8c61..7b602f6d9205 100644 --- a/cli/context/store/metadatastore.go +++ b/cli/context/store/metadatastore.go @@ -20,6 +20,7 @@ import ( const ( metadataDir = "meta" metaFile = "meta.json" + lockFile = "lock" ) type metadataStore struct { diff --git a/cli/context/store/store.go b/cli/context/store/store.go index 45bb76ce6664..c009a5cef5ca 100644 --- a/cli/context/store/store.go +++ b/cli/context/store/store.go @@ -10,16 +10,19 @@ import ( "bytes" _ "crypto/sha256" // ensure ids can be computed "encoding/json" + "errors" + "fmt" "io" "net/http" + "os" "path" "path/filepath" "regexp" "strings" "github.com/docker/docker/errdefs" + "github.com/gofrs/flock" "github.com/opencontainers/go-digest" - "github.com/pkg/errors" ) const restrictedNamePattern = "^[a-zA-Z0-9][a-zA-Z0-9_.+-]+$" @@ -98,6 +101,7 @@ type ContextTLSData struct { // If the directory does not exist or is empty, initialize it func New(dir string, cfg Config) *ContextStore { metaRoot := filepath.Join(dir, metadataDir) + lockPath := filepath.Join(dir, lockFile) tlsRoot := filepath.Join(dir, tlsDir) return &ContextStore{ @@ -108,17 +112,44 @@ func New(dir string, cfg Config) *ContextStore { tls: &tlsStore{ root: tlsRoot, }, + lockFile: flock.New(lockPath), } } // ContextStore implements Store. type ContextStore struct { - meta *metadataStore - tls *tlsStore + meta *metadataStore + tls *tlsStore + lockFile *flock.Flock +} + +func (s *ContextStore) lock() error { + if err := os.MkdirAll(filepath.Dir(s.lockFile.Path()), 0o755); err != nil { + return fmt.Errorf("creating context store lock directory: %w", err) + } + if err := s.lockFile.Lock(); err != nil { + return fmt.Errorf("locking context store lock: %w", err) + } + return nil +} + +func (s *ContextStore) unlock() error { + if err := s.lockFile.Unlock(); err != nil { + return fmt.Errorf("unlocking context store lock: %w", err) + } + return nil } // List return all contexts. -func (s *ContextStore) List() ([]Metadata, error) { +func (s *ContextStore) List() (_ []Metadata, errs error) { + if err := s.lock(); err != nil { + return nil, err + } + defer func() { + if err := s.unlock(); err != nil { + errs = errors.Join(errs, err) + } + }() return s.meta.list() } @@ -136,30 +167,62 @@ func Names(s Lister) ([]string, error) { } // CreateOrUpdate creates or updates metadata for the context. -func (s *ContextStore) CreateOrUpdate(meta Metadata) error { +func (s *ContextStore) CreateOrUpdate(meta Metadata) (errs error) { + if err := s.lock(); err != nil { + return err + } + defer func() { + if err := s.unlock(); err != nil { + errs = errors.Join(errs, err) + } + }() return s.meta.createOrUpdate(meta) } // Remove deletes the context with the given name, if found. -func (s *ContextStore) Remove(name string) error { +func (s *ContextStore) Remove(name string) (errs error) { + if err := s.lock(); err != nil { + return err + } + defer func() { + if err := s.unlock(); err != nil { + errs = errors.Join(errs, err) + } + }() if err := s.meta.remove(name); err != nil { - return errors.Wrapf(err, "failed to remove context %s", name) + return fmt.Errorf("failed to remove context %s: %w", name, err) } if err := s.tls.remove(name); err != nil { - return errors.Wrapf(err, "failed to remove context %s", name) + return fmt.Errorf("failed to remove context %s: %w", name, err) } return nil } // GetMetadata returns the metadata for the context with the given name. // It returns an errdefs.ErrNotFound if the context was not found. -func (s *ContextStore) GetMetadata(name string) (Metadata, error) { +func (s *ContextStore) GetMetadata(name string) (_ Metadata, errs error) { + if err := s.lock(); err != nil { + return Metadata{}, err + } + defer func() { + if err := s.unlock(); err != nil { + errs = errors.Join(errs, err) + } + }() return s.meta.get(name) } // ResetTLSMaterial removes TLS data for all endpoints in the context and replaces // it with the new data. -func (s *ContextStore) ResetTLSMaterial(name string, data *ContextTLSData) error { +func (s *ContextStore) ResetTLSMaterial(name string, data *ContextTLSData) (errs error) { + if err := s.lock(); err != nil { + return err + } + defer func() { + if err := s.unlock(); err != nil { + errs = errors.Join(errs, err) + } + }() if err := s.tls.remove(name); err != nil { return err } @@ -178,7 +241,15 @@ func (s *ContextStore) ResetTLSMaterial(name string, data *ContextTLSData) error // ResetEndpointTLSMaterial removes TLS data for the given context and endpoint, // and replaces it with the new data. -func (s *ContextStore) ResetEndpointTLSMaterial(contextName string, endpointName string, data *EndpointTLSData) error { +func (s *ContextStore) ResetEndpointTLSMaterial(contextName string, endpointName string, data *EndpointTLSData) (errs error) { + if err := s.lock(); err != nil { + return err + } + defer func() { + if err := s.unlock(); err != nil { + errs = errors.Join(errs, err) + } + }() if err := s.tls.removeEndpoint(contextName, endpointName); err != nil { return err } @@ -195,13 +266,29 @@ func (s *ContextStore) ResetEndpointTLSMaterial(contextName string, endpointName // ListTLSFiles returns the list of TLS files present for each endpoint in the // context. -func (s *ContextStore) ListTLSFiles(name string) (map[string]EndpointFiles, error) { +func (s *ContextStore) ListTLSFiles(name string) (_ map[string]EndpointFiles, errs error) { + if err := s.lock(); err != nil { + return nil, err + } + defer func() { + if err := s.unlock(); err != nil { + errs = errors.Join(errs, err) + } + }() return s.tls.listContextData(name) } // GetTLSData reads, and returns the content of the given fileName for an endpoint. // It returns an errdefs.ErrNotFound if the file was not found. -func (s *ContextStore) GetTLSData(contextName, endpointName, fileName string) ([]byte, error) { +func (s *ContextStore) GetTLSData(contextName, endpointName, fileName string) (_ []byte, errs error) { + if err := s.lock(); err != nil { + return nil, err + } + defer func() { + if err := s.unlock(); err != nil { + errs = errors.Join(errs, err) + } + }() return s.tls.getData(contextName, endpointName, fileName) } @@ -223,7 +310,7 @@ func ValidateContextName(name string) error { return errors.New(`"default" is a reserved context name`) } if !restrictedNameRegEx.MatchString(name) { - return errors.Errorf("context name %q is invalid, names are validated against regexp %q", name, restrictedNamePattern) + return fmt.Errorf("context name %q is invalid, names are validated against regexp %q", name, restrictedNamePattern) } return nil } @@ -371,7 +458,7 @@ func importTar(name string, s Writer, reader io.Reader) error { continue } if err := isValidFilePath(hdr.Name); err != nil { - return errors.Wrap(err, hdr.Name) + return fmt.Errorf("%s: %w", hdr.Name, err) } if hdr.Name == metaFile { data, err := io.ReadAll(tr) @@ -423,7 +510,7 @@ func importZip(name string, s Writer, reader io.Reader) error { continue } if err := isValidFilePath(zf.Name); err != nil { - return errors.Wrap(err, zf.Name) + return fmt.Errorf("%s: %w", zf.Name, err) } if zf.Name == metaFile { f, err := zf.Open() diff --git a/cli/context/store/store_test.go b/cli/context/store/store_test.go index 8896d137d8b2..3290283c02cb 100644 --- a/cli/context/store/store_test.go +++ b/cli/context/store/store_test.go @@ -22,6 +22,15 @@ import ( is "gotest.tools/v3/assert/cmp" ) +func TestNew(t *testing.T) { + s := New(path.Join(t.TempDir(), "does", "not", "exist", "yet"), testCfg) + assert.Assert(t, s != nil) + // Check that the file lock works even when the directory does not exist yet. + all, err := s.List() + assert.NilError(t, err) + assert.Assert(t, len(all) == 0) +} + type endpoint struct { Foo string `json:"a_very_recognizable_field_name"` } diff --git a/scripts/vendor b/scripts/vendor index 0812e41c67e0..881e76db6af6 100755 --- a/scripts/vendor +++ b/scripts/vendor @@ -18,12 +18,12 @@ init() { cat > go.mod < The function requests an exclusive lock. Otherwise, it requests a shared +// > lock. +// +// https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx + +func lockFileEx(handle syscall.Handle, flags uint32, reserved uint32, numberOfBytesToLockLow uint32, numberOfBytesToLockHigh uint32, offset *syscall.Overlapped) (bool, syscall.Errno) { + r1, _, errNo := syscall.Syscall6( + uintptr(procLockFileEx), + 6, + uintptr(handle), + uintptr(flags), + uintptr(reserved), + uintptr(numberOfBytesToLockLow), + uintptr(numberOfBytesToLockHigh), + uintptr(unsafe.Pointer(offset))) + + if r1 != 1 { + if errNo == 0 { + return false, syscall.EINVAL + } + + return false, errNo + } + + return true, 0 +} + +func unlockFileEx(handle syscall.Handle, reserved uint32, numberOfBytesToLockLow uint32, numberOfBytesToLockHigh uint32, offset *syscall.Overlapped) (bool, syscall.Errno) { + r1, _, errNo := syscall.Syscall6( + uintptr(procUnlockFileEx), + 5, + uintptr(handle), + uintptr(reserved), + uintptr(numberOfBytesToLockLow), + uintptr(numberOfBytesToLockHigh), + uintptr(unsafe.Pointer(offset)), + 0) + + if r1 != 1 { + if errNo == 0 { + return false, syscall.EINVAL + } + + return false, errNo + } + + return true, 0 +} diff --git a/vendor/github.com/gofrs/flock/flock_windows.go b/vendor/github.com/gofrs/flock/flock_windows.go new file mode 100644 index 000000000000..ddb534ccef09 --- /dev/null +++ b/vendor/github.com/gofrs/flock/flock_windows.go @@ -0,0 +1,142 @@ +// Copyright 2015 Tim Heckman. All rights reserved. +// Use of this source code is governed by the BSD 3-Clause +// license that can be found in the LICENSE file. + +package flock + +import ( + "syscall" +) + +// ErrorLockViolation is the error code returned from the Windows syscall when a +// lock would block and you ask to fail immediately. +const ErrorLockViolation syscall.Errno = 0x21 // 33 + +// Lock is a blocking call to try and take an exclusive file lock. It will wait +// until it is able to obtain the exclusive file lock. It's recommended that +// TryLock() be used over this function. This function may block the ability to +// query the current Locked() or RLocked() status due to a RW-mutex lock. +// +// If we are already locked, this function short-circuits and returns +// immediately assuming it can take the mutex lock. +func (f *Flock) Lock() error { + return f.lock(&f.l, winLockfileExclusiveLock) +} + +// RLock is a blocking call to try and take a shared file lock. It will wait +// until it is able to obtain the shared file lock. It's recommended that +// TryRLock() be used over this function. This function may block the ability to +// query the current Locked() or RLocked() status due to a RW-mutex lock. +// +// If we are already locked, this function short-circuits and returns +// immediately assuming it can take the mutex lock. +func (f *Flock) RLock() error { + return f.lock(&f.r, winLockfileSharedLock) +} + +func (f *Flock) lock(locked *bool, flag uint32) error { + f.m.Lock() + defer f.m.Unlock() + + if *locked { + return nil + } + + if f.fh == nil { + if err := f.setFh(); err != nil { + return err + } + defer f.ensureFhState() + } + + if _, errNo := lockFileEx(syscall.Handle(f.fh.Fd()), flag, 0, 1, 0, &syscall.Overlapped{}); errNo > 0 { + return errNo + } + + *locked = true + return nil +} + +// Unlock is a function to unlock the file. This file takes a RW-mutex lock, so +// while it is running the Locked() and RLocked() functions will be blocked. +// +// This function short-circuits if we are unlocked already. If not, it calls +// UnlockFileEx() on the file and closes the file descriptor. It does not remove +// the file from disk. It's up to your application to do. +func (f *Flock) Unlock() error { + f.m.Lock() + defer f.m.Unlock() + + // if we aren't locked or if the lockfile instance is nil + // just return a nil error because we are unlocked + if (!f.l && !f.r) || f.fh == nil { + return nil + } + + // mark the file as unlocked + if _, errNo := unlockFileEx(syscall.Handle(f.fh.Fd()), 0, 1, 0, &syscall.Overlapped{}); errNo > 0 { + return errNo + } + + f.fh.Close() + + f.l = false + f.r = false + f.fh = nil + + return nil +} + +// TryLock is the preferred function for taking an exclusive file lock. This +// function does take a RW-mutex lock before it tries to lock the file, so there +// is the possibility that this function may block for a short time if another +// goroutine is trying to take any action. +// +// The actual file lock is non-blocking. If we are unable to get the exclusive +// file lock, the function will return false instead of waiting for the lock. If +// we get the lock, we also set the *Flock instance as being exclusive-locked. +func (f *Flock) TryLock() (bool, error) { + return f.try(&f.l, winLockfileExclusiveLock) +} + +// TryRLock is the preferred function for taking a shared file lock. This +// function does take a RW-mutex lock before it tries to lock the file, so there +// is the possibility that this function may block for a short time if another +// goroutine is trying to take any action. +// +// The actual file lock is non-blocking. If we are unable to get the shared file +// lock, the function will return false instead of waiting for the lock. If we +// get the lock, we also set the *Flock instance as being shared-locked. +func (f *Flock) TryRLock() (bool, error) { + return f.try(&f.r, winLockfileSharedLock) +} + +func (f *Flock) try(locked *bool, flag uint32) (bool, error) { + f.m.Lock() + defer f.m.Unlock() + + if *locked { + return true, nil + } + + if f.fh == nil { + if err := f.setFh(); err != nil { + return false, err + } + defer f.ensureFhState() + } + + _, errNo := lockFileEx(syscall.Handle(f.fh.Fd()), flag|winLockfileFailImmediately, 0, 1, 0, &syscall.Overlapped{}) + + if errNo > 0 { + if errNo == ErrorLockViolation || errNo == syscall.ERROR_IO_PENDING { + return false, nil + } + + return false, errNo + } + + *locked = true + + return true, nil +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 89ad29b72696..43214ecb5ad7 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -127,6 +127,9 @@ github.com/go-logr/logr/funcr # github.com/go-logr/stdr v1.2.2 ## explicit; go 1.16 github.com/go-logr/stdr +# github.com/gofrs/flock v0.8.1 +## explicit +github.com/gofrs/flock # github.com/gogo/protobuf v1.3.2 ## explicit; go 1.15 github.com/gogo/protobuf/gogoproto