@@ -20,7 +20,6 @@ package grpcsync
2020
2121import (
2222 "context"
23- "sync"
2423
2524 "google.golang.org/grpc/internal/buffer"
2625)
@@ -38,8 +37,6 @@ type CallbackSerializer struct {
3837 done chan struct {}
3938
4039 callbacks * buffer.Unbounded
41- closedMu sync.Mutex
42- closed bool
4340}
4441
4542// NewCallbackSerializer returns a new CallbackSerializer instance. The provided
@@ -65,53 +62,34 @@ func NewCallbackSerializer(ctx context.Context) *CallbackSerializer {
6562// callbacks to be executed by the serializer. It is not possible to add
6663// callbacks once the context passed to NewCallbackSerializer is cancelled.
6764func (cs * CallbackSerializer ) Schedule (f func (ctx context.Context )) bool {
68- cs .closedMu .Lock ()
69- defer cs .closedMu .Unlock ()
70-
71- if cs .closed {
72- return false
73- }
74- cs .callbacks .Put (f )
75- return true
65+ return cs .callbacks .Put (f ) == nil
7666}
7767
7868func (cs * CallbackSerializer ) run (ctx context.Context ) {
79- var backlog []func (context.Context )
80-
8169 defer close (cs .done )
70+
71+ // TODO: when Go 1.21 is the oldest supported version, this loop and Close
72+ // can be replaced with:
73+ //
74+ // context.AfterFunc(ctx, cs.callbacks.Close)
8275 for ctx .Err () == nil {
8376 select {
8477 case <- ctx .Done ():
8578 // Do nothing here. Next iteration of the for loop will not happen,
8679 // since ctx.Err() would be non-nil.
87- case callback := <- cs .callbacks .Get ():
80+ case cb := <- cs .callbacks .Get ():
8881 cs .callbacks .Load ()
89- callback .(func (ctx context.Context ))(ctx )
82+ cb .(func (context.Context ))(ctx )
9083 }
9184 }
9285
93- // Fetch pending callbacks if any, and execute them before returning from
94- // this method and closing cs.done.
95- cs .closedMu .Lock ()
96- cs .closed = true
97- backlog = cs .fetchPendingCallbacks ()
86+ // Close the buffer to prevent new callbacks from being added.
9887 cs .callbacks .Close ()
99- cs .closedMu .Unlock ()
100- for _ , b := range backlog {
101- b (ctx )
102- }
103- }
10488
105- func (cs * CallbackSerializer ) fetchPendingCallbacks () []func (context.Context ) {
106- var backlog []func (context.Context )
107- for {
108- select {
109- case b := <- cs .callbacks .Get ():
110- backlog = append (backlog , b .(func (context.Context )))
111- cs .callbacks .Load ()
112- default :
113- return backlog
114- }
89+ // Run all pending callbacks.
90+ for cb := range cs .callbacks .Get () {
91+ cs .callbacks .Load ()
92+ cb .(func (context.Context ))(ctx )
11593 }
11694}
11795
0 commit comments