Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions pkg/flag/remote_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ package flag

import (
"net/http"
"net/url"
"strings"

"golang.org/x/xerrors"

"github.com/aquasecurity/trivy/pkg/log"
)

Expand Down Expand Up @@ -112,6 +115,12 @@ func (f *RemoteFlagGroup) Flags() []Flagger {

func (f *RemoteFlagGroup) ToOptions(opts *Options) error {
serverAddr := f.ServerAddr.Value()

// Validate server schema
if err := validateServerSchema(serverAddr); err != nil {
return err
}

customHeaders := splitCustomHeaders(f.CustomHeaders.Value())
listen := f.Listen.Value()
token := f.Token.Value()
Expand Down Expand Up @@ -159,3 +168,22 @@ func splitCustomHeaders(headers []string) http.Header {
}
return result
}

func validateServerSchema(serverAddr string) error {
if serverAddr == "" {
return nil
}

parsedURL, err := url.Parse(serverAddr)
if err != nil {
return xerrors.Errorf("invalid server address format: %w", err)
}

if parsedURL.Scheme == "" {
return xerrors.Errorf("server address must include HTTP or HTTPS schema (e.g., http://localhost:4954 or https://localhost:4954)")
} else if parsedURL.Scheme != "http" && parsedURL.Scheme != "https" {
return xerrors.Errorf("server address must use HTTP or HTTPS schema, got '%s' (e.g., use http://localhost:4954 instead of %s)", parsedURL.Scheme, serverAddr)
}

return nil
}
40 changes: 40 additions & 0 deletions pkg/flag/remote_flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) {
fields fields
want flag.RemoteOptions
wantLogs []string
wantErr string
}{
{
name: "happy",
Expand Down Expand Up @@ -93,6 +94,39 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) {
`"--token-header" should be used with "--token"`,
},
},
{
name: "server address without schema",
fields: fields{
Server: "localhost:8080",
},
wantErr: "server address must use HTTP or HTTPS schema, got 'localhost'",
},
{
name: "server address with invalid schema",
fields: fields{
Server: "ftp://localhost:8080",
},
wantErr: "server address must use HTTP or HTTPS schema, got 'ftp'",
},
{
name: "server address with malformed URL",
fields: fields{
Server: "http://[::1:8080",
},
wantErr: "invalid server address format",
},
{
name: "server address with https schema",
fields: fields{
Server: "https://localhost:4954",
TokenHeader: "Trivy-Token",
},
want: flag.RemoteOptions{
CustomHeaders: http.Header{},
ServerAddr: "https://localhost:4954",
TokenHeader: "Trivy-Token",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -112,6 +146,12 @@ func TestRemoteFlagGroup_ToOptions(t *testing.T) {
}
flags := flag.Flags{f}
got, err := flags.ToOptions(nil)

if tt.wantErr != "" {
assert.ErrorContains(t, err, tt.wantErr)
return
}

require.NoError(t, err)
assert.Equal(t, tt.want, got.RemoteOptions)

Expand Down
Loading