diff --git a/pkg/cover/store.go b/pkg/cover/store.go index 0929039e..467021e8 100644 --- a/pkg/cover/store.go +++ b/pkg/cover/store.go @@ -105,12 +105,15 @@ func (l *fileStore) GetAll() map[string][]string { // Remove the service from the memory store and the file store func (l *fileStore) Remove(addr string) error { + l.mu.Lock() + defer l.mu.Unlock() + err := l.memoryStore.Remove(addr) if err != nil { return err } - return l.Set(l.memoryStore.GetAll()) + return syncToFile(l.persistentFile, l.memoryStore.GetAll()) } // Init cleanup all the registered service information @@ -177,24 +180,7 @@ func (l *fileStore) Set(services map[string][]string) error { return err } - f, err := os.OpenFile(l.persistentFile, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0600) - if err != nil { - return err - } - - s := "" - for name, addrs := range services { - for _, addr := range addrs { - s += fmt.Sprintf("%s&%s\n", name, addr) - } - } - - _, err = f.WriteString(s) - if err != nil { - return err - } - - return f.Sync() + return syncToFile(l.persistentFile, services) } func (l *fileStore) appendToFile(s ServiceUnderTest) error { @@ -221,6 +207,27 @@ func split(r rune) bool { return r == '&' } +func syncToFile(persistentFile string, services map[string][]string) error { + f, err := os.OpenFile(persistentFile, os.O_TRUNC|os.O_WRONLY|os.O_CREATE, 0600) + if err != nil { + return err + } + + s := "" + for name, addrs := range services { + for _, addr := range addrs { + s += fmt.Sprintf("%s&%s\n", name, addr) + } + } + + _, err = f.WriteString(s) + if err != nil { + return err + } + + return f.Sync() +} + // memoryStore holds the registered services only into memory type memoryStore struct { mu sync.RWMutex @@ -287,7 +294,11 @@ func (l *memoryStore) Set(services map[string][]string) error { l.mu.Lock() defer l.mu.Unlock() - l.servicesMap = services + newMap := make(map[string][]string) + for k, v := range services { + newMap[k] = append(make([]string, 0), v...) + } + l.servicesMap = newMap return nil } @@ -317,7 +328,7 @@ func (l *memoryStore) Remove(removeAddr string) error { } if !flag { - return fmt.Errorf("no service found") + return fmt.Errorf("no service found: %s", removeAddr) } return nil diff --git a/pkg/cover/store_test.go b/pkg/cover/store_test.go index a4898d85..e1a32acb 100644 --- a/pkg/cover/store_test.go +++ b/pkg/cover/store_test.go @@ -18,10 +18,10 @@ package cover import ( "fmt" + "github.com/stretchr/testify/assert" "os" + "sync" "testing" - - "github.com/stretchr/testify/assert" ) func TestLocalStore(t *testing.T) { @@ -130,3 +130,32 @@ func TestFileStoreRemove(t *testing.T) { err = store.Remove("http") assert.Error(t, err, fmt.Errorf("no service found")) } + +// verify issue fix https://github.com/golang/go/issues/56552 +func TestConcurrentRemoval(t *testing.T) { + store, _ := NewFileStore("_svrs_address.txt") + _ = store.Init() + + for i := 0; i < 100; i++ { + _ = store.Add(ServiceUnderTest{ + Name: fmt.Sprintf("test%d", i), + Address: fmt.Sprintf("http://127.0.0.1:890%d", i), + }) + } + + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + j := i // for loop trap in golang, avoid goroutine uses the same value for i pointer + wg.Add(1) + go func() { + defer wg.Done() + err := store.Remove(fmt.Sprintf("http://127.0.0.1:890%d", j)) + if err != nil { + t.Errorf("fileStore.Remove Error: %v", err) + } + }() + } + wg.Wait() + + assert.Equal(t, 0, len(store.GetAll())) +}