diff --git a/agent/agent.go b/agent/agent.go index 080a7e0e6..efcc0fe8c 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -174,12 +174,13 @@ func (r *Agent) Run(ctx context.Context, recipe recipe.Recipe) (run Run) { r.logger.Info(string(debug.Stack())) run.Error = fmt.Errorf("agent run: close stream: panic: %s", rcvr) } - stream.Close() + stream.Shutdown() }() if err := runExtractor(); err != nil { run.Error = errors.Wrap(err, "failed to run extractor") } }() + defer stream.Close() // start listening. // this process is blocking @@ -270,8 +271,6 @@ func (r *Agent) setupSink(ctx context.Context, sr recipe.PluginRecipe, stream *s return nil }, defaultBatchSize) - // TODO: the sink closes even though some records remain unpublished - // TODO: once fixed, file sink's Close needs to close *File stream.onClose(func() { if err := sink.Close(); err != nil { r.logger.Warn("error closing sink", "sink", sr.Name, "error", err) diff --git a/agent/agent_test.go b/agent/agent_test.go index 9fe62ee52..4df7efd6d 100644 --- a/agent/agent_test.go +++ b/agent/agent_test.go @@ -22,6 +22,7 @@ import ( v1beta2 "github.com/odpf/meteor/models/odpf/assets/v1beta2" _ "github.com/odpf/meteor/plugins/extractors" // populate extractors registry _ "github.com/odpf/meteor/plugins/processors" // populate processors registry + _ "github.com/odpf/meteor/plugins/sinks" // populate sinks registry ) var ( @@ -862,6 +863,36 @@ func TestAgentRun(t *testing.T) { utils.AssertEqualProto(t, expected, records[0].Data()) }) + t.Run("should close stream after sink finishes writing records", func(t *testing.T) { + r := agent.NewAgent(agent.Config{ + ExtractorFactory: registry.Extractors, + ProcessorFactory: registry.Processors, + SinkFactory: registry.Sinks, + Logger: utils.Logger, + StopOnSinkError: true, + }) + + run := r.Run(ctx, recipe.Recipe{ + Name: "sink_close-test", + Version: "v1beta1", + Source: recipe.PluginRecipe{ + Name: "application_yaml", + Scope: "application-test", + Config: map[string]interface{}{ + "file": "../plugins/extractors/application_yaml/testdata/application.detailed.yaml", + }, + }, + Sinks: []recipe.PluginRecipe{{ + Name: "file", + Config: map[string]interface{}{ + "path": "./application_yaml-sink[yaml].out", + "format": "yaml", + "overwrite": true, + }, + }}, + }) + assert.NoError(t, run.Error) + }) } func TestAgentRunMultiple(t *testing.T) { diff --git a/agent/stream.go b/agent/stream.go index b58793a83..07dd4b6df 100644 --- a/agent/stream.go +++ b/agent/stream.go @@ -20,6 +20,7 @@ type stream struct { subscribers []*subscriber onCloses []func() mu sync.Mutex + shutdown bool closed bool err error } @@ -116,23 +117,35 @@ func (s *stream) closeWithError(err error) { s.Close() } -// Close the emitter and signalling all subscriber of the event. -func (s *stream) Close() { +func (s *stream) Shutdown() { s.mu.Lock() defer s.mu.Unlock() - if s.closed { + if s.shutdown { return } for _, l := range s.subscribers { close(l.channel) } - s.closed = true + s.shutdown = true +} + +// Close the emitter and signalling all subscriber of the event. +func (s *stream) Close() { + s.Shutdown() + + s.mu.Lock() + defer s.mu.Unlock() + + if s.closed { + return + } for _, onClose := range s.onCloses { onClose() } + s.closed = true } func (s *stream) runMiddlewares(d models.Record) (models.Record, error) { diff --git a/plugins/sinks/file/file.go b/plugins/sinks/file/file.go index fd8019bf9..00b83c3cf 100644 --- a/plugins/sinks/file/file.go +++ b/plugins/sinks/file/file.go @@ -1,6 +1,7 @@ package file import ( + "bytes" "context" _ "embed" "fmt" @@ -52,8 +53,8 @@ func New(logger log.Logger) plugins.Syncer { return s } -func (s *Sink) Init(ctx context.Context, config plugins.Config) (err error) { - if err = s.BasePlugin.Init(ctx, config); err != nil { +func (s *Sink) Init(ctx context.Context, config plugins.Config) (error) { + if err := s.BasePlugin.Init(ctx, config); err != nil { return err } @@ -62,53 +63,53 @@ func (s *Sink) Init(ctx context.Context, config plugins.Config) (err error) { } s.format = s.config.Format + var ( + f *os.File + err error + ) if s.config.Overwrite { - s.File, err = os.Create(s.config.Path) - return err + f, err = os.Create(s.config.Path) + } else { + f, err = os.OpenFile(s.config.Path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0777) } - s.File, err = os.OpenFile(s.config.Path, os.O_RDWR|os.O_CREATE|os.O_APPEND, 0777) if err != nil { return err } - return + + s.File = f + return nil } func (s *Sink) Sink(ctx context.Context, batch []models.Record) (err error) { - var data []*assetsv1beta2.Asset + data := make([]*assetsv1beta2.Asset, 0, len(batch)) for _, record := range batch { data = append(data, record.Data()) } + if s.format == "ndjson" { - err := s.ndjsonOut(data) - if err != nil { - return err - } - return nil - } - err = s.yamlOut(data) - if err != nil { - return err + return s.ndjsonOut(data) } - return nil + + return s.yamlOut(data) } func (s *Sink) Close() (err error) { - // return s.File.Close() - return nil + return s.File.Close() } func (s *Sink) ndjsonOut(data []*assetsv1beta2.Asset) error { - result := "" + var result bytes.Buffer for _, asset := range data { jsonBytes, err := models.ToJSON(asset) if err != nil { return fmt.Errorf("error marshaling asset (%s): %w", asset.Urn, err) } - result += string(jsonBytes) + "\n" + result.Write(jsonBytes) + result.WriteRune('\n') } - if err := s.writeBytes([]byte(result)); err != nil { + if err := s.writeBytes(result.Bytes()); err != nil { return fmt.Errorf("error writing to file: %w", err) } @@ -120,16 +121,13 @@ func (s *Sink) yamlOut(data []*assetsv1beta2.Asset) error { if err != nil { return err } - err = s.writeBytes(ymlByte) - return err + + return s.writeBytes(ymlByte) } func (s *Sink) writeBytes(b []byte) error { _, err := s.File.Write(b) - if err != nil { - return err - } - return nil + return err } func (s *Sink) validateFilePath(path string) error { diff --git a/plugins/sinks/file/file_test.go b/plugins/sinks/file/file_test.go index 2ecd916b1..5138cc8ef 100644 --- a/plugins/sinks/file/file_test.go +++ b/plugins/sinks/file/file_test.go @@ -54,7 +54,7 @@ func TestInit(t *testing.T) { }) } -func TestMain(t *testing.T) { +func TestSink(t *testing.T) { t.Run("should return no error with for valid ndjson config", func(t *testing.T) { assert.NoError(t, sinkValidSetup(t, validConfig)) }) @@ -83,11 +83,15 @@ func TestMain(t *testing.T) { } func sinkInvalidPath(t *testing.T, config map[string]interface{}) error { + t.Helper() + fileSink := f.New(testUtils.Logger) return fileSink.Init(context.TODO(), plugins.Config{RawConfig: config}) } func sinkValidSetup(t *testing.T, config map[string]interface{}) error { + t.Helper() + fileSink := f.New(testUtils.Logger) err := fileSink.Init(context.TODO(), plugins.Config{RawConfig: config}) assert.NoError(t, err) @@ -97,6 +101,8 @@ func sinkValidSetup(t *testing.T, config map[string]interface{}) error { } func getExpectedVal(t *testing.T) []models.Record { + t.Helper() + table1, err := anypb.New(&v1beta2.Table{ Columns: []*v1beta2.Column{ {