diff --git a/pkg/flag/remote_flags.go b/pkg/flag/remote_flags.go index c821041ec2b5..b214285b2de5 100644 --- a/pkg/flag/remote_flags.go +++ b/pkg/flag/remote_flags.go @@ -2,8 +2,11 @@ package flag import ( "net/http" + "net/url" "strings" + "golang.org/x/xerrors" + "github.com/aquasecurity/trivy/pkg/log" ) @@ -112,6 +115,12 @@ func (f *RemoteFlagGroup) Flags() []Flagger { func (f *RemoteFlagGroup) ToOptions(opts *Options) error { serverAddr := f.ServerAddr.Value() + + // Validate server schema + if err := validateServerSchema(serverAddr); err != nil { + return err + } + customHeaders := splitCustomHeaders(f.CustomHeaders.Value()) listen := f.Listen.Value() token := f.Token.Value() @@ -159,3 +168,22 @@ func splitCustomHeaders(headers []string) http.Header { } return result } + +func validateServerSchema(serverAddr string) error { + if serverAddr == "" { + return nil + } + + parsedURL, err := url.Parse(serverAddr) + if err != nil { + return xerrors.Errorf("invalid server address format: %w", err) + } + + if parsedURL.Scheme == "" { + return xerrors.Errorf("server address must include HTTP or HTTPS schema (e.g., http://localhost:4954 or https://localhost:4954)") + } else if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" { + return xerrors.Errorf("server address must use HTTP or HTTPS schema, got '%s' (e.g., use http://localhost:4954 instead of %s)", parsedURL.Scheme, serverAddr) + } + + return nil +} diff --git a/pkg/flag/remote_flags_test.go b/pkg/flag/remote_flags_test.go index c11b9e7e055d..7137431f52d5 100644 --- a/pkg/flag/remote_flags_test.go +++ b/pkg/flag/remote_flags_test.go @@ -24,6 +24,7 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) { fields fields want flag.RemoteOptions wantLogs []string + wantErr string }{ { name: "happy", @@ -93,6 +94,39 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) { `"--token-header" should be used with "--token"`, }, }, + { + name: "server address without schema", + fields: fields{ + Server: "localhost:8080", + }, + wantErr: "server address must use HTTP or HTTPS schema, got 'localhost'", + }, + { + name: "server address with invalid schema", + fields: fields{ + Server: "ftp://localhost:8080", + }, + wantErr: "server address must use HTTP or HTTPS schema, got 'ftp'", + }, + { + name: "server address with malformed URL", + fields: fields{ + Server: "http://[::1:8080", + }, + wantErr: "invalid server address format", + }, + { + name: "server address with https schema", + fields: fields{ + Server: "https://localhost:4954", + TokenHeader: "Trivy-Token", + }, + want: flag.RemoteOptions{ + CustomHeaders: http.Header{}, + ServerAddr: "https://localhost:4954", + TokenHeader: "Trivy-Token", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -112,6 +146,12 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) { } flags := flag.Flags{f} got, err := flags.ToOptions(nil) + + if tt.wantErr != "" { + assert.ErrorContains(t, err, tt.wantErr) + return + } + require.NoError(t, err) assert.Equal(t, tt.want, got.RemoteOptions)