diff --git a/internal/buffer/unbounded.go b/internal/buffer/unbounded.go index 4399c3df4959..11f91668ac9b 100644 --- a/internal/buffer/unbounded.go +++ b/internal/buffer/unbounded.go @@ -18,7 +18,10 @@ // Package buffer provides an implementation of an unbounded buffer. package buffer -import "sync" +import ( + "errors" + "sync" +) // Unbounded is an implementation of an unbounded buffer which does not use // extra goroutines. This is typically used for passing updates from one entity @@ -36,6 +39,7 @@ import "sync" type Unbounded struct { c chan any closed bool + closing bool mu sync.Mutex backlog []any } @@ -45,32 +49,32 @@ func NewUnbounded() *Unbounded { return &Unbounded{c: make(chan any, 1)} } +var errBufferClosed = errors.New("Put called on closed buffer.Unbounded") + // Put adds t to the unbounded buffer. -func (b *Unbounded) Put(t any) { +func (b *Unbounded) Put(t any) error { b.mu.Lock() defer b.mu.Unlock() - if b.closed { - return + if b.closing { + return errBufferClosed } if len(b.backlog) == 0 { select { case b.c <- t: - return + return nil default: } } b.backlog = append(b.backlog, t) + return nil } -// Load sends the earliest buffered data, if any, onto the read channel -// returned by Get(). Users are expected to call this every time they read a +// Load sends the earliest buffered data, if any, onto the read channel returned +// by Get(). Users are expected to call this every time they successfully read a // value from the read channel. func (b *Unbounded) Load() { b.mu.Lock() defer b.mu.Unlock() - if b.closed { - return - } if len(b.backlog) > 0 { select { case b.c <- b.backlog[0]: @@ -78,6 +82,8 @@ func (b *Unbounded) Load() { b.backlog = b.backlog[1:] default: } + } else if b.closing && !b.closed { + close(b.c) } } @@ -88,18 +94,23 @@ func (b *Unbounded) Load() { // send the next buffered value onto the channel if there is any. // // If the unbounded buffer is closed, the read channel returned by this method -// is closed. +// is closed after all data is drained. func (b *Unbounded) Get() <-chan any { return b.c } -// Close closes the unbounded buffer. +// Close closes the unbounded buffer. No subsequent data may be Put(), and the +// channel returned from Get() will be closed after all the data is read and +// Load() is called for the final time. func (b *Unbounded) Close() { b.mu.Lock() defer b.mu.Unlock() - if b.closed { + if b.closing { return } - b.closed = true - close(b.c) + b.closing = true + if len(b.backlog) == 0 { + b.closed = true + close(b.c) + } } diff --git a/internal/buffer/unbounded_test.go b/internal/buffer/unbounded_test.go index 1708391e7f27..ef24d0fb7a8c 100644 --- a/internal/buffer/unbounded_test.go +++ b/internal/buffer/unbounded_test.go @@ -52,7 +52,7 @@ func init() { } // TestSingleWriter starts one reader and one writer goroutine and makes sure -// that the reader gets all the value added to the buffer by the writer. +// that the reader gets all the values added to the buffer by the writer. func (s) TestSingleWriter(t *testing.T) { ub := NewUnbounded() reads := []int{} @@ -124,14 +124,25 @@ func (s) TestMultipleWriters(t *testing.T) { // buffer is closed. func (s) TestClose(t *testing.T) { ub := NewUnbounded() + if err := ub.Put(1); err != nil { + t.Fatalf("Unbounded.Put() = %v; want nil", err) + } ub.Close() - if v, ok := <-ub.Get(); ok { - t.Errorf("Unbounded.Get() = %v, want closed channel", v) + if err := ub.Put(1); err == nil { + t.Fatalf("Unbounded.Put() = ; want non-nil error") + } + if v, ok := <-ub.Get(); !ok { + t.Errorf("Unbounded.Get() = %v, %v, want %v, %v", v, ok, 1, true) + } + if err := ub.Put(1); err == nil { + t.Fatalf("Unbounded.Put() = ; want non-nil error") } - ub.Put(1) ub.Load() if v, ok := <-ub.Get(); ok { t.Errorf("Unbounded.Get() = %v, want closed channel", v) } - ub.Close() + if err := ub.Put(1); err == nil { + t.Fatalf("Unbounded.Put() = ; want non-nil error") + } + ub.Close() // ignored } diff --git a/internal/grpcsync/callback_serializer.go b/internal/grpcsync/callback_serializer.go index 900917dbe6c1..f7f40a16acee 100644 --- a/internal/grpcsync/callback_serializer.go +++ b/internal/grpcsync/callback_serializer.go @@ -20,7 +20,6 @@ package grpcsync import ( "context" - "sync" "google.golang.org/grpc/internal/buffer" ) @@ -38,8 +37,6 @@ type CallbackSerializer struct { done chan struct{} callbacks *buffer.Unbounded - closedMu sync.Mutex - closed bool } // NewCallbackSerializer returns a new CallbackSerializer instance. The provided @@ -65,56 +62,34 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer { // callbacks to be executed by the serializer. It is not possible to add // callbacks once the context passed to NewCallbackSerializer is cancelled. func (cs *CallbackSerializer) Schedule(f func(ctx context.Context)) bool { - cs.closedMu.Lock() - defer cs.closedMu.Unlock() - - if cs.closed { - return false - } - cs.callbacks.Put(f) - return true + return cs.callbacks.Put(f) == nil } func (cs *CallbackSerializer) run(ctx context.Context) { - var backlog []func(context.Context) - defer close(cs.done) + + // TODO: when Go 1.21 is the oldest supported version, this loop and Close + // can be replaced with: + // + // context.AfterFunc(ctx, cs.callbacks.Close) for ctx.Err() == nil { select { case <-ctx.Done(): // Do nothing here. Next iteration of the for loop will not happen, // since ctx.Err() would be non-nil. - case callback, ok := <-cs.callbacks.Get(): - if !ok { - return - } + case cb := <-cs.callbacks.Get(): cs.callbacks.Load() - callback.(func(ctx context.Context))(ctx) + cb.(func(context.Context))(ctx) } } - // Fetch pending callbacks if any, and execute them before returning from - // this method and closing cs.done. - cs.closedMu.Lock() - cs.closed = true - backlog = cs.fetchPendingCallbacks() + // Close the buffer to prevent new callbacks from being added. cs.callbacks.Close() - cs.closedMu.Unlock() - for _, b := range backlog { - b(ctx) - } -} -func (cs *CallbackSerializer) fetchPendingCallbacks() []func(context.Context) { - var backlog []func(context.Context) - for { - select { - case b := <-cs.callbacks.Get(): - backlog = append(backlog, b.(func(context.Context))) - cs.callbacks.Load() - default: - return backlog - } + // Run all pending callbacks. + for cb := range cs.callbacks.Get() { + cs.callbacks.Load() + cb.(func(context.Context))(ctx) } }