diff --git a/internal/handlers/git_server_test.go b/internal/handlers/git_server_test.go index 2fda941..9d0c6cf 100644 --- a/internal/handlers/git_server_test.go +++ b/internal/handlers/git_server_test.go @@ -179,14 +179,15 @@ func TestGitServerHandler_AuthenticatedAccessToGitHubRepos(t *testing.T) { req := httptest.NewRequest("GET", fmt.Sprintf("https://github.com/%s", tt.repoNWO), nil) req, _ = handler.HandleRequest(req, nil) - if tt.expectedCredential != nil { + switch { + case tt.expectedCredential != nil: assertHasBasicAuth(t, req, tt.expectedCredential.GetString("username"), tt.expectedCredential.GetString("password"), "valid github request") - } else if tt.isAuthenticated { + case tt.isAuthenticated: assertAuthenticated(t, req, "valid github request") - } else { + default: assertUnauthenticated(t, req, "valid github request") } }) diff --git a/internal/handlers/github_api_test.go b/internal/handlers/github_api_test.go index cda8ded..39f64ec 100644 --- a/internal/handlers/github_api_test.go +++ b/internal/handlers/github_api_test.go @@ -159,11 +159,12 @@ func TestGitHubAPIHandler_AuthenticatedAccessToGitHubRepos(t *testing.T) { req := httptest.NewRequest("GET", fmt.Sprintf("https://api.github.com/%s", tt.repoNWO), nil) req, _ = handler.HandleRequest(req, nil) - if tt.expectedCredential != nil { + switch { + case tt.expectedCredential != nil: assertHasTokenAuth(t, req, "token", tt.expectedCredential.GetString("password"), "valid api request") - } else if tt.isAuthenticated { + case tt.isAuthenticated: assertAuthenticated(t, req, "valid github request") - } else { + default: assertUnauthenticated(t, req, "valid github request") } }) diff --git a/internal/handlers/nuget_feed.go b/internal/handlers/nuget_feed.go index eeb6f39..7e0a6b5 100644 --- a/internal/handlers/nuget_feed.go +++ b/internal/handlers/nuget_feed.go @@ -189,13 +189,14 @@ func extraUrlsFromSourceResponse(body []byte, url string) []string { var urls []string bodyString := strings.TrimSpace(string(body)) bodyReader := bytes.NewReader(body) - if strings.HasPrefix(bodyString, "<") { + switch { + case strings.HasPrefix(bodyString, "<"): // XML v2 API urls = handleV2Response(bodyReader, url) - } else if strings.HasPrefix(bodyString, "{") { + case strings.HasPrefix(bodyString, "{"): // JSON v3 API urls = handleV3Response(bodyReader, url) - } else { + default: logging.RequestLogf(nil, "unknown API response: %s...", bodyString[:10]) } diff --git a/internal/handlers/python_index.go b/internal/handlers/python_index.go index 876a935..29e6319 100644 --- a/internal/handlers/python_index.go +++ b/internal/handlers/python_index.go @@ -14,6 +14,8 @@ import ( "github.com/dependabot/proxy/internal/oidc" ) +var simpleSuffixRe = regexp.MustCompile(`/\+?simple/?\z`) + // PythonIndexHandler handles requests to Python indexes, adding auth. type PythonIndexHandler struct { credentials []pythonIndexCredentials @@ -89,8 +91,7 @@ func (h *PythonIndexHandler) HandleRequest(req *http.Request, ctx *goproxy.Proxy // Fall back to static credentials for _, cred := range h.credentials { - re, _ := regexp.Compile(`/\+?simple/?\z`) - indexURL := re.ReplaceAllString(cred.indexURL, "/") + indexURL := simpleSuffixRe.ReplaceAllString(cred.indexURL, "/") if !helpers.UrlMatchesRequest(req, indexURL, true) && !helpers.CheckHost(req, cred.host) { continue } diff --git a/internal/logging/logging.go b/internal/logging/logging.go index d4245d4..8ebd328 100644 --- a/internal/logging/logging.go +++ b/internal/logging/logging.go @@ -47,7 +47,7 @@ func requestLog(ctx *goproxy.ProxyCtx, message string) { argv = append([]any{reqId}, argv...) if cache.WasResponseCached(ctx) { - format = format + " (cached)" + format += " (cached)" } } formatted := fmt.Sprintf(format, argv...) diff --git a/internal/oidc/oidc_credential.go b/internal/oidc/oidc_credential.go index c25465f..3914b18 100644 --- a/internal/oidc/oidc_credential.go +++ b/internal/oidc/oidc_credential.go @@ -86,12 +86,13 @@ func CreateOIDCCredential(cred config.Credential) (*OIDCCredential, error) { domain := cred.GetString("domain") domainOwner := cred.GetString("domain-owner") - if tenantID != "" && clientID != "" { + switch { + case tenantID != "" && clientID != "": parameters = &AzureOIDCParameters{ TenantID: tenantID, ClientID: clientID, } - } else if jfrogOidcProviderName != "" && feedUrl != "" { + case jfrogOidcProviderName != "" && feedUrl != "": // jfrog domain is extracted from feed url jfrogUrlParsed, err := url.Parse(feedUrl) if err != nil { @@ -105,7 +106,7 @@ func CreateOIDCCredential(cred config.Credential) (*OIDCCredential, error) { Audience: cred.GetString("audience"), IdentityMappingName: cred.GetString("identity-mapping-name"), } - } else if awsRegion != "" && accountID != "" && roleName != "" && domain != "" && domainOwner != "" { + case awsRegion != "" && accountID != "" && roleName != "" && domain != "" && domainOwner != "": audience := cred.GetString("audience") if audience == "" { audience = "sts.amazonaws.com" // defaults to this diff --git a/main.go b/main.go index 8c6628a..534bdb0 100644 --- a/main.go +++ b/main.go @@ -40,12 +40,14 @@ func main() { cfg, err := config.Parse(*configPath) if err != nil { - log.Fatal(err) + log.Println(err) + return } sentry, err := setupSentry() if err != nil { - log.Fatal(err) + log.Println(err) + return } envSettings := config.ProxyEnvSettings{ @@ -89,11 +91,13 @@ func main() { log.Printf("Listening (%s)", *addr) if err := server.ListenAndServe(); err != http.ErrServerClosed { - log.Fatal(err) + log.Println(err) + return } if err := proxy.Close(); err != nil { - log.Fatal(err) + log.Println(err) + return } }