Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cmd/trivy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ func run() error {
return nil
}

// Set up signal handling for graceful shutdown
ctx, stop := commands.NotifyContext(context.Background())
defer stop()

app := commands.NewApp()
return app.Execute()
return app.ExecuteContext(ctx)
}
37 changes: 37 additions & 0 deletions pkg/commands/signal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package commands

import (
"context"
"os"
"os/signal"
"syscall"

"github.com/aquasecurity/trivy/pkg/log"
)

// NotifyContext returns a context that is canceled when SIGINT or SIGTERM is received.
// It also ensures cleanup of temporary files when the signal is received.
//
// When a signal is received, Trivy will attempt to gracefully shut down by canceling
// the context and waiting for all operations to complete. If users want to force an
// immediate exit, they can send a second SIGINT or SIGTERM signal.
func NotifyContext(parent context.Context) (context.Context, context.CancelFunc) {
ctx, stop := signal.NotifyContext(parent, os.Interrupt, syscall.SIGTERM)

// Start a goroutine to handle cleanup when context is done
go func() {
<-ctx.Done()

// Log that we're shutting down gracefully
log.Info("Received signal, attempting graceful shutdown...")
log.Info("Press Ctrl+C again to force exit")

// TODO: Add any necessary cleanup logic here

// Clean up signal handling
// After calling stop(), a second signal will cause immediate termination
stop()
}()

return ctx, stop
}
3 changes: 2 additions & 1 deletion pkg/oci/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/aquasecurity/trivy/pkg/log"
"github.com/aquasecurity/trivy/pkg/remote"
"github.com/aquasecurity/trivy/pkg/version/doc"
xio "github.com/aquasecurity/trivy/pkg/x/io"
)

const (
Expand Down Expand Up @@ -188,7 +189,7 @@ func (a *Artifact) download(ctx context.Context, layer v1.Layer, fileName, dir s
}()

// Download the layer content into a temporal file
if _, err = io.Copy(f, pr); err != nil {
if _, err = xio.Copy(ctx, f, pr); err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to overwrite xio.Copy for other cases?
e.g. we can copy big files to cache:

if _, err = io.Copy(f, o.reader); err != nil {
o.err = xerrors.Errorf("failed to copy: %w", err)
return
}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, possibly, as I mentioned in the PR description. Since xio.Copy has some trade-off, I would like to carefully replace other xio.Copy one by one.

return xerrors.Errorf("copy error: %w", err)
}

Expand Down
37 changes: 32 additions & 5 deletions pkg/rpc/server/listen.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package server
import (
"context"
"encoding/json"
"errors"
"net/http"
"os"
"strings"
Expand Down Expand Up @@ -62,20 +63,46 @@ func (s Server) ListenAndServe(ctx context.Context, serverCache cache.Cache, ski
requestWg := &sync.WaitGroup{}
dbUpdateWg := &sync.WaitGroup{}

server := &http.Server{
Addr: s.addr,
Handler: s.NewServeMux(ctx, serverCache, dbUpdateWg, requestWg),
ReadHeaderTimeout: 10 * time.Second,
}

// Start DB update worker
go func() {
worker := newDBWorker(db.NewClient(s.dbDir, true, db.WithDBRepository(s.dbRepositories)))
ticker := time.NewTicker(updateInterval)
defer ticker.Stop()

for {
time.Sleep(updateInterval)
if err := worker.update(ctx, s.appVersion, s.dbDir, skipDBUpdate, dbUpdateWg, requestWg, s.RegistryOptions); err != nil {
log.Errorf("%+v\n", err)
select {
case <-ctx.Done():
log.Debug("Server shutting down gracefully...")

// Give active requests time to complete
shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
if err := server.Shutdown(shutdownCtx); err != nil {
log.Errorf("Server shutdown error: %v", err)
}
cancel()
return
case <-ticker.C:
if err := worker.update(ctx, s.appVersion, s.dbDir, skipDBUpdate, dbUpdateWg, requestWg, s.RegistryOptions); err != nil {
log.Errorf("%+v\n", err)
}
}
}
}()

mux := s.NewServeMux(ctx, serverCache, dbUpdateWg, requestWg)
log.Infof("Listening %s...", s.addr)

return http.ListenAndServe(s.addr, mux)
// This will block until Shutdown is called
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return xerrors.Errorf("server error: %w", err)
}

return nil
}

func (s Server) NewServeMux(ctx context.Context, serverCache cache.Cache, dbUpdateWg, requestWg *sync.WaitGroup) *http.ServeMux {
Expand Down
25 changes: 25 additions & 0 deletions pkg/x/io/io.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package io

import (
"bytes"
"context"
"io"

"golang.org/x/xerrors"
Expand Down Expand Up @@ -71,3 +72,27 @@ type nopCloser struct {
}

func (nopCloser) Close() error { return nil }

// readerFunc is a function that implements io.Reader
type readerFunc func([]byte) (int, error)

func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}

// Copy copies from src to dst until either EOF is reached on src or the context is canceled.
// It returns the number of bytes copied and the first error encountered while copying, if any.
//
// Note: This implementation wraps the reader with a context check, which means it won't
// benefit from WriterTo optimization in io.Copy if the source implements it. This is a trade-off
// for being able to cancel the operation on context cancellation.
func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
return io.Copy(dst, readerFunc(func(p []byte) (int, error) {
select {
case <-ctx.Done():
return 0, ctx.Err()
default:
return src.Read(p)
}
}))
}
66 changes: 66 additions & 0 deletions pkg/x/io/io_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package io

import (
"bytes"
"context"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestCopy(t *testing.T) {
t.Run("successful copy", func(t *testing.T) {
ctx := t.Context()
src := strings.NewReader("hello world")
dst := &bytes.Buffer{}

n, err := Copy(ctx, dst, src)
require.NoError(t, err)
assert.Equal(t, int64(11), n)
assert.Equal(t, "hello world", dst.String())
})

t.Run("context canceled before read", func(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
cancel() // Cancel immediately

src := strings.NewReader("hello world")
dst := &bytes.Buffer{}

n, err := Copy(ctx, dst, src)
require.ErrorIs(t, err, context.Canceled)
assert.Equal(t, int64(0), n)
assert.Empty(t, dst.String())
})

t.Run("context canceled during read", func(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())

// Create a reader that will be canceled after first read
reader := &dummyReader{
cancel: cancel, // Cancel after first read
}
dst := &bytes.Buffer{}

n, err := Copy(ctx, dst, reader)
require.ErrorIs(t, err, context.Canceled)
// Should have written first chunk before cancellation
assert.Equal(t, int64(5), n)
assert.Equal(t, "dummy", dst.String())
})
}

// dummyReader returns the same data on every Read call
type dummyReader struct {
cancel context.CancelFunc
}

func (r *dummyReader) Read(p []byte) (int, error) {
n := copy(p, "dummy")
if r.cancel != nil {
r.cancel() // Simulate cancellation after first read
}
return n, nil
}
Loading