diff --git a/internal/xds/clients/xdsclient/ads_stream.go b/internal/xds/clients/xdsclient/ads_stream.go index 83f8a5df03f1..911e02a67f3c 100644 --- a/internal/xds/clients/xdsclient/ads_stream.go +++ b/internal/xds/clients/xdsclient/ads_stream.go @@ -22,7 +22,6 @@ import ( "context" "fmt" "sync" - "sync/atomic" "time" "google.golang.org/grpc/grpclog" @@ -103,11 +102,11 @@ type adsStreamImpl struct { requestCh *buffer.Unbounded // Subscriptions and unsubscriptions are pushed here. runnerDoneCh chan struct{} // Notify completion of runner goroutine. cancel context.CancelFunc // To cancel the context passed to the runner goroutine. + fc *adsFlowControl // Flow control for ADS stream. // Guards access to the below fields (and to the contents of the map). mu sync.Mutex resourceTypeState map[ResourceType]*resourceTypeState // Map of resource types to their state. - fc *adsFlowControl // Flow control for ADS stream. firstRequest bool // False after the first request is sent out. } @@ -135,6 +134,7 @@ func newADSStreamImpl(opts adsStreamOpts) *adsStreamImpl { streamCh: make(chan clients.Stream, 1), requestCh: buffer.NewUnbounded(), runnerDoneCh: make(chan struct{}), + fc: newADSFlowControl(), resourceTypeState: make(map[ResourceType]*resourceTypeState), } @@ -150,6 +150,7 @@ func newADSStreamImpl(opts adsStreamOpts) *adsStreamImpl { // Stop blocks until the stream is closed and all spawned goroutines exit. func (s *adsStreamImpl) Stop() { s.cancel() + s.fc.stop() s.requestCh.Close() <-s.runnerDoneCh s.logger.Infof("Shutdown ADS stream") @@ -240,9 +241,6 @@ func (s *adsStreamImpl) runner(ctx context.Context) { } s.mu.Lock() - // Flow control is a property of the underlying streaming RPC call and - // needs to be initialized everytime a new one is created. - s.fc = newADSFlowControl(s.logger) s.firstRequest = true s.mu.Unlock() @@ -256,7 +254,7 @@ func (s *adsStreamImpl) runner(ctx context.Context) { // Backoff state is reset upon successful receipt of at least one // message from the server. - if s.recv(ctx, stream) { + if s.recv(stream) { return backoff.ErrResetBackoff } return nil @@ -318,11 +316,13 @@ func (s *adsStreamImpl) sendNew(stream clients.Stream, typ ResourceType) error { // This allows us to batch writes for requests which are generated as part // of local processing of a received response. state := s.resourceTypeState[typ] - if s.fc.pending.Load() { + bufferRequest := func() { select { case state.bufferedRequests <- struct{}{}: default: } + } + if s.fc.runIfPending(bufferRequest) { return nil } @@ -477,18 +477,19 @@ func (s *adsStreamImpl) sendMessageLocked(stream clients.Stream, names []string, // // It returns a boolean indicating whether at least one message was received // from the server. -func (s *adsStreamImpl) recv(ctx context.Context, stream clients.Stream) bool { +func (s *adsStreamImpl) recv(stream clients.Stream) bool { msgReceived := false for { - // Wait for ADS stream level flow control to be available, and send out - // a request if anything was buffered while we were waiting for local - // processing of the previous response to complete. - if !s.fc.wait(ctx) { + // Wait for ADS stream level flow control to be available. + if s.fc.wait() { if s.logger.V(2) { - s.logger.Infof("ADS stream context canceled") + s.logger.Infof("ADS stream stopped while waiting for flow control") } return msgReceived } + + // Send out a request if anything was buffered while we were waiting for + // local processing of the previous response to complete. 
		s.sendBuffered(stream)

		resources, url, version, nonce, err := s.recvMessage(stream)
@@ -508,8 +509,8 @@ func (s *adsStreamImpl) recv(ctx context.Context, stream clients.Stream) bool {
 		}
 		var resourceNames []string
 		var nackErr error
-		s.fc.setPending()
-		resourceNames, nackErr = s.eventHandler.onResponse(resp, s.fc.onDone)
+		s.fc.setPending(true)
+		resourceNames, nackErr = s.eventHandler.onResponse(resp, sync.OnceFunc(func() { s.fc.setPending(false) }))
 		if xdsresource.ErrType(nackErr) == xdsresource.ErrorTypeResourceTypeUnsupported {
 			// A general guiding principle is that if the server sends
 			// something the client didn't actually subscribe to, then the
@@ -707,69 +708,87 @@ func resourceNames(m map[string]*xdsresource.ResourceWatchState) []string {
 	return ret
 }
 
-// adsFlowControl implements ADS stream level flow control that enables the
-// transport to block the reading of the next message off of the stream until
-// the previous update is consumed by all watchers.
+// adsFlowControl implements ADS stream level flow control that enables the ADS
+// stream to block the reading of the next message until the previous update is
+// consumed by all watchers.
 //
-// The lifetime of the flow control is tied to the lifetime of the stream.
+// The lifetime of the flow control is tied to the lifetime of the stream. When
+// the stream is closed, it is the responsibility of the caller to stop the flow
+// control. This ensures that any goroutine blocked on the flow control's wait
+// method is unblocked.
 type adsFlowControl struct {
-	logger *igrpclog.PrefixLogger
-
-	// Whether the most recent update is pending consumption by all watchers.
-	pending atomic.Bool
-	// Channel used to notify when all the watchers have consumed the most
-	// recent update. Wait() blocks on reading a value from this channel.
-	readyCh chan struct{}
+	mu sync.Mutex
+	// cond is used to signal when the most recent update has been consumed, or
+	// the flow control has been stopped (in which case waiters should be
+	// unblocked as well).
+	cond    *sync.Cond
+	pending bool // indicates if the most recent update is pending consumption
+	stopped bool // indicates if the ADS stream has been stopped
 }
 
 // newADSFlowControl returns a new adsFlowControl.
-func newADSFlowControl(logger *igrpclog.PrefixLogger) *adsFlowControl {
-	return &adsFlowControl{
-		logger:  logger,
-		readyCh: make(chan struct{}, 1),
-	}
+func newADSFlowControl() *adsFlowControl {
+	fc := &adsFlowControl{}
+	fc.cond = sync.NewCond(&fc.mu)
+	return fc
 }
 
-// setPending changes the internal state to indicate that there is an update
-// pending consumption by all watchers.
-func (fc *adsFlowControl) setPending() {
-	fc.pending.Store(true)
+// stop marks the flow control as stopped and broadcasts on the condition
+// variable to unblock any goroutines waiting on it.
+func (fc *adsFlowControl) stop() {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	fc.stopped = true
+	fc.cond.Broadcast()
 }
 
-// wait blocks until all the watchers have consumed the most recent update and
-// returns true. If the context expires before that, it returns false.
-func (fc *adsFlowControl) wait(ctx context.Context) bool {
-	// If there is no pending update, there is no need to block.
-	if !fc.pending.Load() {
-		// If all watchers finished processing the most recent update before the
-		// `recv` goroutine made the next call to `Wait()`, there would be an
-		// entry in the readyCh channel that needs to be drained to ensure that
-		// the next call to `Wait()` doesn't unblock before it actually should.
-		select {
-		case <-fc.readyCh:
-		default:
-		}
-		return true
+// setPending changes the internal state to indicate whether there is an update
+// pending consumption by all watchers. Once the pending update is consumed,
+// the condition variable is broadcast to allow the recv method to proceed.
+func (fc *adsFlowControl) setPending(pending bool) {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	if fc.stopped {
+		return
 	}
 
-	select {
-	case <-ctx.Done():
+	fc.pending = pending
+	if !pending {
+		fc.cond.Broadcast()
+	}
+}
+
+// runIfPending runs the given function, while holding the lock, if there is an
+// update pending consumption by all watchers. It reports whether the function
+// was run; it is a no-op after the flow control has been stopped.
+func (fc *adsFlowControl) runIfPending(f func()) bool {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	if fc.stopped {
 		return false
-	case <-fc.readyCh:
-		return true
 	}
+
+	// If there's a pending update, run the function while still holding the
+	// lock. This ensures that the pending state does not change between the
+	// check and the function call.
+	if fc.pending {
+		f()
+	}
+	return fc.pending
 }
 
-// onDone indicates that all watchers have consumed the most recent update.
-func (fc *adsFlowControl) onDone() {
-	select {
-	// Writes to the readyCh channel should not block ideally. The default
-	// branch here is to appease the paranoid mind.
-	case fc.readyCh <- struct{}{}:
-	default:
-		if fc.logger.V(2) {
-			fc.logger.Infof("ADS stream flow control readyCh is full")
-		}
+// wait blocks until all the watchers have consumed the most recent update. It
+// returns true if the flow control was stopped while waiting, false otherwise.
+func (fc *adsFlowControl) wait() bool {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	for fc.pending && !fc.stopped {
+		fc.cond.Wait()
 	}
-	fc.pending.Store(false)
+
+	return fc.stopped
 }
diff --git a/internal/xds/clients/xdsclient/test/ads_stream_flow_control_test.go b/internal/xds/clients/xdsclient/test/ads_stream_flow_control_test.go
index 9c417a14b28d..83155a992438 100644
--- a/internal/xds/clients/xdsclient/test/ads_stream_flow_control_test.go
+++ b/internal/xds/clients/xdsclient/test/ads_stream_flow_control_test.go
@@ -20,7 +20,6 @@ package xdsclient_test
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"slices"
 	"sort"
@@ -125,7 +124,6 @@ func (t *transport) NewStream(ctx context.Context, method string) (clients.Strea
 	stream := &stream{
 		stream: s,
 		recvCh: make(chan struct{}, 1),
-		doneCh: make(chan struct{}),
 	}
 
 	t.adsStreamCh <- stream
@@ -138,9 +136,7 @@ func (t *transport) Close() {
 
 type stream struct {
 	stream grpc.ClientStream
-
 	recvCh chan struct{}
-	doneCh <-chan struct{}
 }
 
 func (s *stream) Send(msg []byte) error {
@@ -150,8 +146,8 @@ func (s *stream) Send(msg []byte) error {
 
 func (s *stream) Recv() ([]byte, error) {
 	select {
 	case s.recvCh <- struct{}{}:
-	case <-s.doneCh:
-		return nil, errors.New("Recv() called after the test has finished")
+	case <-s.stream.Context().Done():
+		// Unblock Recv() once the stream is done.
 	}
 
 	var typedRes []byte
diff --git a/internal/xds/clients/xdsclient/xdsclient.go b/internal/xds/clients/xdsclient/xdsclient.go
index cc7d5c4e264d..b1c6955484dc 100644
--- a/internal/xds/clients/xdsclient/xdsclient.go
+++ b/internal/xds/clients/xdsclient/xdsclient.go
@@ -319,7 +319,7 @@ func (c *XDSClient) releaseChannel(serverConfig *ServerConfig, state *channelSta
 	c.channelsMu.Lock()
 
 	if c.logger.V(2) {
-		c.logger.Infof("Received request to release a reference to an xdsChannel for server config %q", serverConfig)
+		c.logger.Infof("Received request to release a reference to an xdsChannel for server config %+v", serverConfig)
 	}
 	deInitLocked(state)
diff --git a/internal/xds/xdsclient/tests/dump_test.go b/internal/xds/xdsclient/tests/dump_test.go
index a476c29bca7c..11a43c1275a9 100644
--- a/internal/xds/xdsclient/tests/dump_test.go
+++ b/internal/xds/xdsclient/tests/dump_test.go
@@ -63,7 +63,7 @@ func makeGenericXdsConfig(typeURL, name, version string, status v3adminpb.Client
 }
 
 func checkResourceDump(ctx context.Context, want *v3statuspb.ClientStatusResponse, pool *xdsclient.Pool) error {
-	var cmpOpts = cmp.Options{
+	cmpOpts := cmp.Options{
 		protocmp.Transform(),
 		protocmp.IgnoreFields((*v3statuspb.ClientConfig_GenericXdsConfig)(nil), "last_updated"),
 		protocmp.IgnoreFields((*v3adminpb.UpdateFailureState)(nil), "last_update_attempt", "details"),
@@ -89,7 +89,7 @@ func checkResourceDump(ctx context.Context, want *v3statuspb.ClientStatusRespons
 		if diff == "" {
 			return nil
 		}
-		lastErr = fmt.Errorf("received unexpected resource dump, diff (-got, +want):\n%s, got: %s\n want:%s", diff, pretty.ToJSON(got), pretty.ToJSON(want))
+		lastErr = fmt.Errorf("received unexpected resource dump, diff (-want, +got):\n%s\ngot: %s\nwant: %s", diff, pretty.ToJSON(got), pretty.ToJSON(want))
 	}
 	return fmt.Errorf("timeout when waiting for resource dump to reach expected state: %v", lastErr)
 }
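
Reviewer note, not part of the patch: the heart of this change is swapping the atomic-bool-plus-readyCh scheme for a sync.Cond guarded by a mutex, which lets Stop() wake a reader blocked in wait(). Below is a minimal, self-contained sketch of that pattern under the same semantics; flowControl and the wiring in main are illustrative stand-ins, not the package's actual API.

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// flowControl mirrors the pattern in the diff: a mutex-guarded condition
// variable with pending/stopped flags.
type flowControl struct {
	mu      sync.Mutex
	cond    *sync.Cond
	pending bool // an update is awaiting consumption by all watchers
	stopped bool // the stream has been torn down
}

func newFlowControl() *flowControl {
	fc := &flowControl{}
	fc.cond = sync.NewCond(&fc.mu)
	return fc
}

// setPending marks whether an update is awaiting consumption. Clearing the
// flag wakes any reader blocked in wait.
func (fc *flowControl) setPending(pending bool) {
	fc.mu.Lock()
	defer fc.mu.Unlock()
	if fc.stopped {
		return
	}
	fc.pending = pending
	if !pending {
		fc.cond.Broadcast()
	}
}

// wait blocks while an update is pending. It returns true if the flow
// control was stopped while waiting.
func (fc *flowControl) wait() bool {
	fc.mu.Lock()
	defer fc.mu.Unlock()
	for fc.pending && !fc.stopped {
		fc.cond.Wait()
	}
	return fc.stopped
}

// stop permanently unblocks all current and future waiters.
func (fc *flowControl) stop() {
	fc.mu.Lock()
	defer fc.mu.Unlock()
	fc.stopped = true
	fc.cond.Broadcast()
}

func main() {
	fc := newFlowControl()
	fc.setPending(true) // a response has been handed to the watchers

	// done plays the role of the onDone callback passed to onResponse;
	// sync.OnceFunc guarantees the flag is cleared at most once even if
	// watchers invoke the callback multiple times.
	done := sync.OnceFunc(func() { fc.setPending(false) })
	go func() {
		time.Sleep(10 * time.Millisecond) // simulate watchers processing the update
		done()
	}()

	if stopped := fc.wait(); !stopped {
		fmt.Println("update consumed; safe to read the next message")
	}
	fc.stop() // on stream teardown, wake anyone still blocked in wait
}
```

The for-loop around cond.Wait() is what makes the stopped case safe: unlike the old buffered readyCh, there is no stale token to drain, and a single Broadcast from stop reaches every waiter.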