diff --git a/agent/agent.go b/agent/agent.go index 5c59c1416..b05170810 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -255,7 +255,7 @@ func (r *Agent) setupSink(ctx context.Context, sr recipe.PluginRecipe, stream *s ) } stream.subscribe(func(records []models.Record) error { - err := r.retrier.retry(func() error { + err := r.retrier.retry(ctx, func() error { err := sink.Sink(ctx, records) return err }, retryNotification) diff --git a/agent/agent_test.go b/agent/agent_test.go index 2d6aa6641..fdbfeb83b 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -704,6 +704,66 @@ func TestAgentRun(t *testing.T) { assert.NoError(t, run.Error) assert.Equal(t, validRecipe, run.Recipe) }) + + t.Run("should respect context cancellation and stop retries", func(t *testing.T) { + err := errors.New("some-error") + ctx, cancel := context.WithCancel(ctx) + data := []models.Record{ + models.NewRecord(&v1beta2.Asset{}), + } + + extr := mocks.NewExtractor() + extr.SetEmit(data) + extr.On("Init", utils.OfTypeContext(), buildPluginConfig(validRecipe.Source)).Return(nil).Once() + extr.On("Extract", utils.OfTypeContext(), mock.AnythingOfType("plugins.Emit")).Return(nil) + ef := registry.NewExtractorFactory() + if err := ef.Register("test-extractor", newExtractor(extr)); err != nil { + t.Fatal(err) + } + + proc := mocks.NewProcessor() + proc.On("Init", utils.OfTypeContext(), buildPluginConfig(validRecipe.Processors[0])).Return(nil).Once() + proc.On("Process", utils.OfTypeContext(), data[0]).Return(data[0], nil) + defer proc.AssertExpectations(t) + pf := registry.NewProcessorFactory() + if err := pf.Register("test-processor", newProcessor(proc)); err != nil { + t.Fatal(err) + } + + sink := mocks.NewSink() + sink.On("Init", utils.OfTypeContext(), buildPluginConfig(validRecipe.Sinks[0])).Return(nil).Once() + // Sink should not be called more than once in total since we cancel the context after the first call. + sink.On("Sink", utils.OfTypeContext(), data).Return(plugins.NewRetryError(err)).Once().Run(func(args mock.Arguments) { + go func() { + cancel() + }() + }) + sink.On("Close").Return(nil) + defer sink.AssertExpectations(t) + sf := registry.NewSinkFactory() + if err := sf.Register("test-sink", newSink(sink)); err != nil { + t.Fatal(err) + } + + monitor := newMockMonitor() + monitor.On("RecordRun", mock.AnythingOfType("agent.Run")).Once() + monitor.On("RecordPlugin", mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("string"), mock.AnythingOfType("bool")) + defer monitor.AssertExpectations(t) + + r := agent.NewAgent(agent.Config{ + ExtractorFactory: ef, + ProcessorFactory: pf, + SinkFactory: sf, + Logger: utils.Logger, + Monitor: monitor, + MaxRetries: 5, + RetryInitialInterval: 10 * time.Second, + }) + run := r.Run(ctx, validRecipe) + assert.NoError(t, run.Error) + assert.Equal(t, validRecipe, run.Recipe) + }) + } func TestAgentRunMultiple(t *testing.T) { diff --git a/agent/retrier.go b/agent/retrier.go index 5404a5c8e..84a246b9a 100644 --- a/agent/retrier.go +++ b/agent/retrier.go @@ -1,6 +1,7 @@ package agent import ( + "context" "errors" "time" @@ -30,8 +31,9 @@ func newRetrier(maxRetries int, initialInterval time.Duration) *retrier { return &r } -func (r *retrier) retry(operation func() error, notify func(e error, d time.Duration)) error { +func (r *retrier) retry(ctx context.Context, operation func() error, notify func(e error, d time.Duration)) error { bo := backoff.WithMaxRetries(r.createExponentialBackoff(r.initialInterval), uint64(r.maxRetries)) + bo = backoff.WithContext(bo, ctx) return backoff.RetryNotify(func() error { err := operation() if err == nil { diff --git a/test/utils/arg_matcher.go b/test/utils/arg_matcher.go new file mode 100644 index 000000000..752dc490d --- /dev/null +++ b/test/utils/arg_matcher.go @@ -0,0 +1,13 @@ +package utils + +import ( + "context" + + "github.com/stretchr/testify/mock" +) + +type ArgMatcher interface{ Matches(interface{}) bool } + +func OfTypeContext() ArgMatcher { + return mock.MatchedBy(func(ctx context.Context) bool { return ctx != nil }) +}