diff --git a/.changelog/22732.txt b/.changelog/22732.txt new file mode 100644 index 000000000000..c3f8ac66fc60 --- /dev/null +++ b/.changelog/22732.txt @@ -0,0 +1,3 @@ +```release-note:feature +oidc: add client authentication using JWT assertion and PKCE. default PKCE is enabled. +``` diff --git a/api/acl.go b/api/acl.go index 1a13fd38ac5a..8961b12ec169 100644 --- a/api/acl.go +++ b/api/acl.go @@ -483,12 +483,14 @@ type OIDCAuthMethodConfig struct { OIDCDiscoveryURL string `json:",omitempty"` OIDCDiscoveryCACert string `json:",omitempty"` // just for type=oidc - OIDCClientID string `json:",omitempty"` - OIDCClientSecret string `json:",omitempty"` - OIDCScopes []string `json:",omitempty"` - OIDCACRValues []string `json:",omitempty"` - AllowedRedirectURIs []string `json:",omitempty"` - VerboseOIDCLogging bool `json:",omitempty"` + OIDCClientID string `json:",omitempty"` + OIDCClientSecret string `json:",omitempty"` + OIDCClientAssertion *OIDCClientAssertion `json:",omitempty"` + OIDCClientUsePKCE *bool `json:",omitempty"` + OIDCScopes []string `json:",omitempty"` + OIDCACRValues []string `json:",omitempty"` + AllowedRedirectURIs []string `json:",omitempty"` + VerboseOIDCLogging bool `json:",omitempty"` // just for type=jwt JWKSURL string `json:",omitempty"` JWKSCACert string `json:",omitempty"` @@ -513,6 +515,8 @@ func (c *OIDCAuthMethodConfig) RenderToConfig() map[string]interface{} { // just for type=oidc "OIDCClientID": c.OIDCClientID, "OIDCClientSecret": c.OIDCClientSecret, + "OIDCClientAssertion": c.OIDCClientAssertion, + "OIDCClientUsePKCE": c.OIDCClientUsePKCE, "OIDCScopes": c.OIDCScopes, "OIDCACRValues": c.OIDCACRValues, "AllowedRedirectURIs": c.AllowedRedirectURIs, @@ -528,6 +532,16 @@ func (c *OIDCAuthMethodConfig) RenderToConfig() map[string]interface{} { } } +type OIDCClientAssertion struct { + Audience []string + PrivateKey *OIDCClientAssertionKey + KeyAlgorithm string +} + +type OIDCClientAssertionKey struct { + PemKey string +} + type ACLLoginParams struct { AuthMethod string BearerToken string diff --git a/go.mod b/go.mod index 8c8e365d0f1b..dde322e31f35 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/armon/go-metrics v0.4.1 github.com/armon/go-radix v1.0.0 github.com/aws/aws-sdk-go v1.55.7 - github.com/coreos/go-oidc/v3 v3.9.0 + github.com/coreos/go-oidc/v3 v3.11.0 github.com/deckarep/golang-set/v2 v2.3.1 github.com/docker/go-connections v0.4.0 github.com/envoyproxy/go-control-plane v0.13.4 @@ -37,12 +37,14 @@ require ( github.com/go-openapi/runtime v0.26.2 github.com/go-openapi/strfmt v0.23.0 github.com/go-viper/mapstructure/v2 v2.4.0 + github.com/golang-jwt/jwt/v5 v5.2.2 github.com/google/go-cmp v0.7.0 github.com/google/gofuzz v1.2.0 github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 github.com/google/tcpproxy v0.0.0-20180808230851-dfa16c61dad2 github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 github.com/hashi-derek/grpc-proxy v0.0.0-20231207191910-191266484d75 + github.com/hashicorp/cap v0.10.0 github.com/hashicorp/consul-awsauth v0.0.0-20250825122907-9e35fe9ded3a github.com/hashicorp/consul-net-rpc v0.0.0-20221205195236-156cfab66a69 github.com/hashicorp/consul/api v1.31.2 @@ -191,6 +193,7 @@ require ( github.com/emicklei/go-restful/v3 v3.10.1 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/go-jose/go-jose/v4 v4.1.1 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-ole/go-ole v1.2.6 // indirect diff --git a/go.sum b/go.sum index 35704ed182da..86863f485b52 100644 --- a/go.sum +++ b/go.sum @@ -197,8 +197,8 @@ github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc github.com/coreos/etcd v3.3.27+incompatible h1:QIudLb9KeBsE5zyYxd1mjzRSkzLg9Wf9QlRwFgd6oTA= github.com/coreos/etcd v3.3.27+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-oidc/v3 v3.9.0 h1:0J/ogVOd4y8P0f0xUh8l9t07xRP/d8tccvjHl2dcsSo= -github.com/coreos/go-oidc/v3 v3.9.0/go.mod h1:rTKz2PYwftcrtoCzV5g5kvfJoWcm0Mk8AF8y1iAQro4= +github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI= +github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0= github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/coreos/go-systemd v0.0.0-20191104093116-d3cd4ed1dbcf h1:iW4rZ826su+pqaw19uhpSCzhj44qo35pNgKFGqzDKkU= @@ -279,6 +279,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= +github.com/go-jose/go-jose/v4 v4.1.1 h1:JYhSgy4mXXzAdF3nUx3ygx347LRXJRrpgyU3adRmkAI= +github.com/go-jose/go-jose/v4 v4.1.1/go.mod h1:BdsZGqgdO3b6tTc6LSE56wcDbMMLuPsw5d4ZD5f94kA= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= @@ -331,6 +333,8 @@ github.com/golang-jwt/jwt/v4 v4.0.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzw github.com/golang-jwt/jwt/v4 v4.2.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v1.2.5 h1:DrW6hGnjIhtvhOIiAKT6Psh/Kd/ldepEa81DKeiRJ5I= github.com/golang/glog v1.2.5/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= @@ -445,6 +449,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0 h1:YBftPWNWd4WwGqtY2yeZL2ef8rH github.com/grpc-ecosystem/grpc-gateway/v2 v2.16.0/go.mod h1:YN5jB8ie0yfIUg6VvR9Kz84aCaG7AsGZnLjhHbUqwPg= github.com/hashi-derek/grpc-proxy v0.0.0-20231207191910-191266484d75 h1:V5Uqf7VoWMd6UhNf/5EMA8LMPUm95GYvk2YF5SzT24o= github.com/hashi-derek/grpc-proxy v0.0.0-20231207191910-191266484d75/go.mod h1:5eEnHfK72jOkp4gC1dI/Q/E9MFNOM/ewE/vql5ijV3g= +github.com/hashicorp/cap v0.10.0 h1:OJM3JQTwVO1DigRIPNTxM387oqXlokKhttZHotU0b1s= +github.com/hashicorp/cap v0.10.0/go.mod h1:HKbv27kfps+wONFNyNTHpAQmU/DCjjDuB5HF6mFsqPQ= github.com/hashicorp/consul-awsauth v0.0.0-20250825122907-9e35fe9ded3a h1:Qd0N8lIr1QP/d7FYxseYjRLUtJp2+2R8k+mjiC2rmiY= github.com/hashicorp/consul-awsauth v0.0.0-20250825122907-9e35fe9ded3a/go.mod h1:++exZ1sI8JLIv4QvzGvTjZdf1eZARoZcaNEjNT9SZYA= github.com/hashicorp/consul-net-rpc v0.0.0-20221205195236-156cfab66a69 h1:wzWurXrxfSyG1PHskIZlfuXlTSCj1Tsyatp9DtaasuY= diff --git a/internal/go-sso/oidcauth/auth.go b/internal/go-sso/oidcauth/auth.go index e4a1fe766eff..f9f5189d9f7e 100644 --- a/internal/go-sso/oidcauth/auth.go +++ b/internal/go-sso/oidcauth/auth.go @@ -17,6 +17,7 @@ import ( "sync" "github.com/coreos/go-oidc/v3/oidc" + capOidc "github.com/hashicorp/cap/oidc" "github.com/hashicorp/go-hclog" "github.com/patrickmn/go-cache" ) @@ -41,8 +42,14 @@ type Authenticator struct { // parsedJWTPubKeys is the parsed form of config.JWTValidationPubKeys parsedJWTPubKeys []interface{} - provider *oidc.Provider - keySet oidc.KeySet + + // provider is the coreos/go-oidc provider used for JWT validation + provider *oidc.Provider + + // capProvider is the HashiCorp CAP library provider used for OIDC flows + // with support for private key JWT client authentication + capProvider *capOidc.Provider + keySet oidc.KeySet // httpClient should be configured with all relevant root CA certs and be // reused for all OIDC or JWKS operations. This will be nil for the static @@ -86,13 +93,43 @@ func New(c *Config, logger hclog.Logger) (*Authenticator, error) { } a.backgroundCtx, a.backgroundCtxCancel = context.WithCancel(context.Background()) + var err error if c.Type == TypeOIDC { a.oidcStates = cache.New(oidcStateTimeout, oidcStateCleanupInterval) } - var err error switch c.authType() { - case authOIDCDiscovery, authOIDCFlow: + case authOIDCFlow: + var supported []capOidc.Alg + if len(a.config.JWTSupportedAlgs) == 0 { + // Default to RS256 if nothing is specified. + supported = []capOidc.Alg{capOidc.RS256} + } else { + for _, alg := range a.config.JWTSupportedAlgs { + supported = append(supported, capOidc.Alg(alg)) + } + } + // Use CAP's OIDC provider to leverage its built-in support for + // both standard client secret and JWT assertion authentication methods + providerConfig, err := capOidc.NewConfig( + a.config.OIDCDiscoveryURL, + a.config.OIDCClientID, + capOidc.ClientSecret(a.config.OIDCClientSecret), + supported, + a.config.AllowedRedirectURIs, + capOidc.WithAudiences(a.config.BoundAudiences...), + capOidc.WithProviderCA(a.config.OIDCDiscoveryCACert), + ) + if err != nil { + return nil, fmt.Errorf("error creating provider config: %v", err) + } + + provider, err := capOidc.NewProvider(providerConfig) + if err != nil { + return nil, fmt.Errorf("error creating provider: %v", err) + } + a.capProvider = provider + case authOIDCDiscovery: a.httpClient, err = createHTTPClient(a.config.OIDCDiscoveryCACert) if err != nil { return nil, fmt.Errorf("error parsing OIDCDiscoveryCACert: %v", err) diff --git a/internal/go-sso/oidcauth/auth_test.go b/internal/go-sso/oidcauth/auth_test.go new file mode 100644 index 000000000000..5b92683dba2c --- /dev/null +++ b/internal/go-sso/oidcauth/auth_test.go @@ -0,0 +1,166 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: BUSL-1.1 + +package oidcauth + +import ( + "testing" + "time" + + "github.com/hashicorp/consul/internal/go-sso/oidcauth/oidcauthtest" + "github.com/hashicorp/go-hclog" + "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/assert" +) + +func mockConfig(typ string, t *testing.T) *Config { + t.Helper() + + srv := oidcauthtest.Start(t) + srv.SetClientCreds("abc", "def") + cfg := &Config{ + Type: typ, + } + if typ == TypeJWT { + cfg.JWKSURL = srv.Addr() + "/certs" + cfg.JWKSCACert = srv.CACert() + } + if typ == TypeOIDC { + cfg.OIDCDiscoveryURL = srv.Addr() + cfg.OIDCClientID = "abc" + cfg.OIDCClientSecret = "def" + cfg.AllowedRedirectURIs = []string{"https://redirect"} + cfg.OIDCDiscoveryCACert = srv.CACert() + } + return cfg +} + +const testPublicKeyPEM = `-----BEGIN PUBLIC KEY----- +MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEEVs/o5+uQbTjL3chynL4wXgUg2R9 +q9UU8I5mEovUf86QZ7kOBIjJwqnzD1omageEHWwHdBO6B+dFabmdT9POxg== +-----END PUBLIC KEY-----` + +func TestAuthenticator_JWTGroup(t *testing.T) { + t.Run("JWTType static keys", func(t *testing.T) { + cfg := mockConfig(TypeJWT, t) + cfg.JWTValidationPubKeys = []string{testPublicKeyPEM} + cfg.JWKSURL = "" + cfg.JWKSCACert = "" + logger := hclog.NewNullLogger() + auth, err := New(cfg, logger) + assert.NoError(t, err) + assert.NotNil(t, auth) + assert.Equal(t, cfg, auth.config) + assert.NotEmpty(t, auth.parsedJWTPubKeys) + }) + + t.Run("JWTType JWKS", func(t *testing.T) { + cfg := mockConfig(TypeJWT, t) + logger := hclog.NewNullLogger() + auth, err := New(cfg, logger) + assert.NoError(t, err) + assert.NotNil(t, auth) + assert.Equal(t, cfg, auth.config) + }) + + t.Run("JWTType failure", func(t *testing.T) { + cfg := mockConfig(TypeJWT, t) + cfg.OIDCClientID = "abc" + logger := hclog.NewNullLogger() + _, err := New(cfg, logger) + assert.Error(t, err) + requireErrorContains(t, err, "'OIDCClientID' must not be set for type") + }) + + t.Run("Stop", func(t *testing.T) { + cfg := mockConfig(TypeJWT, t) + logger := hclog.NewNullLogger() + auth, err := New(cfg, logger) + assert.NoError(t, err) + assert.NotNil(t, auth.backgroundCtxCancel) + auth.Stop() + assert.Nil(t, auth.backgroundCtxCancel) + }) + + t.Run("BackgroundContextCancel", func(t *testing.T) { + cfg := mockConfig(TypeJWT, t) + logger := hclog.NewNullLogger() + auth, err := New(cfg, logger) + assert.NoError(t, err) + done := make(chan struct{}) + go func() { + <-auth.backgroundCtx.Done() + close(done) + }() + auth.Stop() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("backgroundCtx was not cancelled") + } + }) +} + +func TestAuthenticator_OIDCGroup(t *testing.T) { + t.Run("OIDCType", func(t *testing.T) { + cfg := mockConfig(TypeOIDC, t) + logger := hclog.NewNullLogger() + auth, err := New(cfg, logger) + assert.NoError(t, err) + assert.NotNil(t, auth.capProvider) + assert.NotNil(t, auth.oidcStates) + }) + + t.Run("OIDCDiscovery", func(t *testing.T) { + srv := oidcauthtest.Start(t) + srv.SetClientCreds("abc", "def") + cfg := mockConfig(TypeJWT, t) + cfg.JWKSURL = "" + cfg.JWKSCACert = "" + cfg.OIDCDiscoveryURL = srv.Addr() + cfg.OIDCDiscoveryCACert = srv.CACert() + + logger := hclog.NewNullLogger() + auth, err := New(cfg, logger) + assert.NoError(t, err) + assert.NotNil(t, auth) + assert.NotNil(t, auth.provider) + assert.NotNil(t, auth.httpClient) + }) + + t.Run("OIDCStatesCache", func(t *testing.T) { + cfg := mockConfig(TypeOIDC, t) + logger := hclog.NewNullLogger() + auth, err := New(cfg, logger) + assert.NoError(t, err) + assert.NotNil(t, auth.oidcStates) + auth.oidcStates.Set("state", "value", cache.DefaultExpiration) + val, found := auth.oidcStates.Get("state") + assert.True(t, found) + assert.Equal(t, "value", val) + }) +} + +func TestAuthenticator_OIDCFlow_Failure(t *testing.T) { + t.Run("InvalidCACert", func(t *testing.T) { + cfg := mockConfig(TypeOIDC, t) + cfg.OIDCDiscoveryCACert = "invalid cert data" + + logger := hclog.NewNullLogger() + _, err := New(cfg, logger) + + assert.Error(t, err) + requireErrorContains(t, err, "could not parse CA PEM value successfully") + }) + + t.Run("ProviderConfig_error", func(t *testing.T) { + cfg := mockConfig(TypeOIDC, t) + cfg.OIDCDiscoveryURL = "::invalid-url::" + + logger := hclog.NewNullLogger() + _, err := New(cfg, logger) + + assert.Error(t, err) + requireErrorContains(t, err, "error checking OIDCDiscoveryURL") + }) +} diff --git a/internal/go-sso/oidcauth/config.go b/internal/go-sso/oidcauth/config.go index 9f95363b352e..1e53d30a5cd4 100644 --- a/internal/go-sso/oidcauth/config.go +++ b/internal/go-sso/oidcauth/config.go @@ -87,6 +87,13 @@ type Config struct { // Valid only if Type=oidc OIDCClientSecret string + // Optionally send a signed JWT ("private key jwt") as a client assertion + // for client authentication. This enables enhanced + // security by using asymmetric cryptography instead of shared secrets. + OIDCClientAssertion *OIDCClientAssertion + // Disable S256 PKCE challenge verification + OIDCClientUsePKCE *bool + // Comma-separated list of OIDC scopes // // Valid only if Type=oidc @@ -166,6 +173,30 @@ type Config struct { ClockSkewLeeway time.Duration } +// OIDCClientAssertion configures private key JWT client authentication +// for enhanced security in OIDC flows. This allowing clients to authenticate +// using signed JWTs instead of shared secrets. +// See also: structs.OIDCClientAssertion +type OIDCClientAssertion struct { + // Audience is/are who will be processing the assertion. + // Typically set to the OIDC provider's token endpoint URL. + // Defaults to the parent ACLAuthMethodConfig's OIDCDiscoveryURL + Audience []string + + // PrivateKey contains external key material provided by users. + // KeySource must be "private_key" to enable this. + PrivateKey *OIDCClientAssertionKey + + KeyAlgorithm string +} + +// OIDCClientAssertionKey holds the private key used for signing client assertions +type OIDCClientAssertionKey struct { + // PemKey is the private key, in pem format. It is used to sign the JWT. + // Mutually exclusive with PemKeyFile. + PemKey string +} + // Validate returns an error if the config is not valid. func (c *Config) Validate() error { validateCtx, validateCtxCancel := context.WithCancel(context.Background()) @@ -179,8 +210,10 @@ func (c *Config) Validate() error { return fmt.Errorf("'OIDCDiscoveryURL' must be set for type %q", c.Type) case c.OIDCClientID == "": return fmt.Errorf("'OIDCClientID' must be set for type %q", c.Type) - case c.OIDCClientSecret == "": - return fmt.Errorf("'OIDCClientSecret' must be set for type %q", c.Type) + case c.OIDCClientSecret == "" && c.OIDCClientAssertion == nil: + return fmt.Errorf("'OIDCClientSecret' or 'OIDCClientAssertion' must be set for type %q", c.Type) + case c.OIDCClientAssertion != nil && c.OIDCClientAssertion.PrivateKey == nil: + return fmt.Errorf("'OIDCClientAssertion.PrivateKey' must be set when 'OIDCClientAssertion' is set for type %q", c.Type) case len(c.AllowedRedirectURIs) == 0: return fmt.Errorf("'AllowedRedirectURIs' must be set for type %q", c.Type) } @@ -189,6 +222,8 @@ func (c *Config) Validate() error { switch { case c.JWKSURL != "": return fmt.Errorf("'JWKSURL' must not be set for type %q", c.Type) + case c.OIDCClientSecret != "" && c.OIDCClientAssertion != nil: + return fmt.Errorf("only one of 'OIDCClientSecret' or 'OIDCClientAssertion' can be set for type %q", c.Type) case c.JWKSCACert != "": return fmt.Errorf("'JWKSCACert' must not be set for type %q", c.Type) case len(c.JWTValidationPubKeys) != 0: @@ -213,6 +248,14 @@ func (c *Config) Validate() error { return fmt.Errorf("Invalid AllowedRedirectURIs provided: %v", bad) } + if c.OIDCClientAssertion != nil { + // Validate KeyAlgorithm if set + if c.OIDCClientAssertion.KeyAlgorithm != "" && + c.OIDCClientAssertion.KeyAlgorithm != "RS256" { + return fmt.Errorf("'OIDCClientAssertion.KeyAlgorithm' must be 'RS256' currently") + } + } + case TypeJWT: // not allowed switch { @@ -220,6 +263,8 @@ func (c *Config) Validate() error { return fmt.Errorf("'OIDCClientID' must not be set for type %q", c.Type) case c.OIDCClientSecret != "": return fmt.Errorf("'OIDCClientSecret' must not be set for type %q", c.Type) + case c.OIDCClientAssertion != nil: + return fmt.Errorf("'OIDCClientAssertion' must not be set for type %q", c.Type) case len(c.OIDCScopes) != 0: return fmt.Errorf("'OIDCScopes' must not be set for type %q", c.Type) case len(c.OIDCACRValues) != 0: @@ -347,6 +392,8 @@ func (c *Config) authType() int { case c.OIDCDiscoveryURL != "": if c.OIDCClientID != "" && c.OIDCClientSecret != "" { return authOIDCFlow + } else if c.OIDCClientID != "" && c.OIDCClientAssertion != nil { + return authOIDCFlow } return authOIDCDiscovery default: @@ -354,4 +401,4 @@ func (c *Config) authType() int { } } -const testJWT = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk" +const testJWT = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.Hf3E3iCHzqC5QIQ0nCqS1kw78IiQTRVzsLTuKoDIpdk" diff --git a/internal/go-sso/oidcauth/config_test.go b/internal/go-sso/oidcauth/config_test.go index 2766c6c7bfdc..a65095e60f87 100644 --- a/internal/go-sso/oidcauth/config_test.go +++ b/internal/go-sso/oidcauth/config_test.go @@ -56,7 +56,20 @@ func TestConfigValidate(t *testing.T) { }, expectErr: "must be set for type", }, - "missing required OIDCClientSecret": { + "missing required OIDCClientAssertion.PrivateKey": { + config: Config{ + Type: TypeOIDC, + OIDCDiscoveryURL: srv.Addr(), + OIDCDiscoveryCACert: srv.CACert(), + OIDCClientID: "abc", + OIDCClientAssertion: &OIDCClientAssertion{ + Audience: []string{srv.Addr()}, + }, + AllowedRedirectURIs: []string{"http://foo.test"}, + }, + expectErr: "OIDCClientAssertion.PrivateKey' must be set when 'OIDCClientAssertion' is set for type", + }, + "missing required OIDCClientSecret or OIDCClientAssertion": { config: Config{ Type: TypeOIDC, OIDCDiscoveryURL: srv.Addr(), @@ -65,7 +78,7 @@ func TestConfigValidate(t *testing.T) { // OIDCClientSecret: "def", AllowedRedirectURIs: []string{"http://foo.test"}, }, - expectErr: "must be set for type", + expectErr: "OIDCClientSecret' or 'OIDCClientAssertion' must be set for type", }, "missing required AllowedRedirectURIs": { config: Config{ @@ -78,6 +91,36 @@ func TestConfigValidate(t *testing.T) { }, expectErr: "must be set for type", }, + "incompatible with OIDCClientSecret and OIDCClientAssertion": { + config: Config{ + Type: TypeOIDC, + OIDCDiscoveryURL: srv.Addr(), + OIDCDiscoveryCACert: srv.CACert(), + OIDCClientID: "abc", + OIDCClientSecret: "def", + OIDCClientAssertion: &OIDCClientAssertion{ + PrivateKey: &OIDCClientAssertionKey{PemKey: testRSAPrivateKey}, + Audience: []string{srv.Addr()}, + }, + AllowedRedirectURIs: []string{"http://foo.test"}, + }, + expectErr: "only one of 'OIDCClientSecret' or 'OIDCClientAssertion", + }, + "incompatible key algorithm": { + config: Config{ + Type: TypeOIDC, + OIDCDiscoveryURL: srv.Addr(), + OIDCDiscoveryCACert: srv.CACert(), + OIDCClientID: "abc", + OIDCClientAssertion: &OIDCClientAssertion{ + PrivateKey: &OIDCClientAssertionKey{PemKey: testRSAPrivateKey}, + Audience: []string{srv.Addr()}, + KeyAlgorithm: "foo", + }, + AllowedRedirectURIs: []string{"http://foo.test"}, + }, + expectErr: "OIDCClientAssertion.KeyAlgorithm' must be 'RS256' currently", + }, "incompatible with JWKSURL": { config: Config{ Type: TypeOIDC, @@ -366,6 +409,16 @@ func TestConfigValidate(t *testing.T) { }, expectErr: "must not be set for type", }, + "incompatible with OIDCClientAssertion": { + config: Config{ + Type: TypeJWT, + JWTValidationPubKeys: []string{testJWTPubKey}, + OIDCClientAssertion: &OIDCClientAssertion{ + Audience: []string{srv.Addr()}, + }, + }, + expectErr: "must not be set for type", + }, "incompatible with OIDCScopes": { config: Config{ Type: TypeJWT, diff --git a/internal/go-sso/oidcauth/jwt.go b/internal/go-sso/oidcauth/jwt.go index b71fdb2b31bc..84165ba7086a 100644 --- a/internal/go-sso/oidcauth/jwt.go +++ b/internal/go-sso/oidcauth/jwt.go @@ -15,6 +15,7 @@ import ( "fmt" "time" + "github.com/coreos/go-oidc/v3/oidc" "github.com/go-jose/go-jose/v3/jwt" ) @@ -208,3 +209,37 @@ func parsePublicKeyPEM(data []byte) (interface{}, error) { return nil, errors.New("data does not contain any valid RSA, ECDSA, or ED25519 public keys") } + +func (a *Authenticator) verifyOIDCToken(ctx context.Context, rawToken string) (map[string]any, error) { + allClaims := make(map[string]any) + + oidcConfig := &oidc.Config{ + SupportedSigningAlgs: a.config.JWTSupportedAlgs, + } + switch a.config.authType() { + case authOIDCFlow: + oidcConfig.ClientID = a.config.OIDCClientID + case authOIDCDiscovery: + oidcConfig.SkipClientIDCheck = true + default: + return nil, fmt.Errorf("unsupported auth type for this verifyOIDCToken: %d", a.config.authType()) + } + + verifier := a.provider.Verifier(oidcConfig) + + idToken, err := verifier.Verify(ctx, rawToken) + if err != nil { + return nil, fmt.Errorf("error validating signature: %v", err) + } + + if err := idToken.Claims(&allClaims); err != nil { + return nil, fmt.Errorf("unable to successfully parse all claims from token: %v", err) + } + // Follows behavior of hashicorp/vault-plugin-auth-jwt (non-strict validation). + // See https://developer.hashicorp.com/consul/docs/security/acl/auth-methods/oidc#oidc-configuration-troubleshooting. + if err := validateAudience(a.config.BoundAudiences, idToken.Audience, false); err != nil { + return nil, fmt.Errorf("error validating claims: %v", err) + } + + return allClaims, nil +} diff --git a/internal/go-sso/oidcauth/oidc.go b/internal/go-sso/oidcauth/oidc.go index 80b5131bd726..96a79ae0098c 100644 --- a/internal/go-sso/oidcauth/oidc.go +++ b/internal/go-sso/oidcauth/oidc.go @@ -11,9 +11,11 @@ import ( "strings" "time" - "github.com/coreos/go-oidc/v3/oidc" + "github.com/hashicorp/cap/oidc" + cass "github.com/hashicorp/cap/oidc/clientassertion" + + "github.com/golang-jwt/jwt/v5" "github.com/hashicorp/go-uuid" - "golang.org/x/oauth2" ) var ( @@ -39,31 +41,24 @@ func (a *Authenticator) GetAuthCodeURL(ctx context.Context, redirectURI string, return "", fmt.Errorf("unauthorized redirect_uri: %s", redirectURI) } - // "openid" is a required scope for OpenID Connect flows - scopes := append([]string{oidc.ScopeOpenID}, a.config.OIDCScopes...) + // Use HashiCorp CAP provider which supports advanced OIDC features + // including private key JWT client authentication configured during initialization + provider := a.capProvider + payload := statePayload - // Configure an OpenID Connect aware OAuth2 client - oauth2Config := oauth2.Config{ - ClientID: a.config.OIDCClientID, - ClientSecret: a.config.OIDCClientSecret, - RedirectURL: redirectURI, - Endpoint: a.provider.Endpoint(), - Scopes: scopes, + // Generate a secure state and nonce for the OIDC request + // The request object is stored for later use during token exchange + request, error := a.createOIDCState(redirectURI, payload) + if error != nil { + return "", fmt.Errorf("error generating OAuth state: %v", error) } - stateID, nonce, err := a.createOIDCState(redirectURI, statePayload) + authURL, err := provider.AuthURL(ctx, request) if err != nil { - return "", fmt.Errorf("error generating OAuth state: %v", err) + return "", fmt.Errorf("error while generating AuthURL %q", err) } - authCodeOpts := []oauth2.AuthCodeOption{ - oidc.Nonce(nonce), - } - if len(a.config.OIDCACRValues) > 0 { - authCodeOpts = append(authCodeOpts, oauth2.SetAuthURLParam("acr_values", strings.Join(a.config.OIDCACRValues, " "))) - } - - return oauth2Config.AuthCodeURL(stateID, authCodeOpts...), nil + return authURL, nil } // ClaimsFromAuthCode is the second part of the OIDC authorization code @@ -94,38 +89,39 @@ func (a *Authenticator) ClaimsFromAuthCode(ctx context.Context, stateParam, code } } - oidcCtx := contextWithHttpClient(ctx, a.httpClient) - - var oauth2Config = oauth2.Config{ - ClientID: a.config.OIDCClientID, - ClientSecret: a.config.OIDCClientSecret, - RedirectURL: state.redirectURI, - Endpoint: a.provider.Endpoint(), - Scopes: []string{oidc.ScopeOpenID}, + // Use the stored request object from the initial authorization request + if state.request == nil { + a.logger.Error("Request object not found in state", "stateParam", stateParam) + return nil, nil, &ProviderLoginFailedError{ + Err: fmt.Errorf("missing request object in OAuth state"), + } } - oauth2Token, err := oauth2Config.Exchange(oidcCtx, code) + // Use HashiCorp CAP provider for token exchange + // This provider supports private key JWT client authentication if configured + provider := a.capProvider + + tokens, err := provider.Exchange(ctx, state.request, stateParam, code) if err != nil { return nil, nil, &ProviderLoginFailedError{ Err: fmt.Errorf("Error exchanging oidc code: %w", err), } } - // Extract the ID Token from OAuth2 token. - rawToken, ok := oauth2Token.Extra("id_token").(string) - if !ok { + if !tokens.Valid() { return nil, nil, &TokenVerificationFailedError{ - Err: errors.New("No id_token found in response."), + Err: err, } } + idToken := tokens.IDToken() + if a.config.VerboseOIDCLogging && a.logger != nil { - a.logger.Debug("OIDC provider response", "ID token", rawToken) + a.logger.Debug("OIDC provider response", "ID token", idToken) } - // Parse and verify ID Token payload. - allClaims, err := a.verifyOIDCToken(ctx, rawToken) // TODO(sso): should this use oidcCtx? - if err != nil { + var allClaims map[string]any + if err := idToken.Claims(&allClaims); err != nil { return nil, nil, &TokenVerificationFailedError{ Err: err, } @@ -141,15 +137,15 @@ func (a *Authenticator) ClaimsFromAuthCode(ctx context.Context, stateParam, code // Attempt to fetch information from the /userinfo endpoint and merge it with // the existing claims data. A failure to fetch additional information from this // endpoint will not invalidate the authorization flow. - if userinfo, err := a.provider.UserInfo(oidcCtx, oauth2.StaticTokenSource(oauth2Token)); err == nil { - _ = userinfo.Claims(&allClaims) - } else { - if a.logger != nil { - logFunc := a.logger.Warn - if strings.Contains(err.Error(), "user info endpoint is not supported") { - logFunc = a.logger.Info + if tokenSource := tokens.StaticTokenSource(); tokenSource != nil { + if err := provider.UserInfo(ctx, tokenSource, allClaims["sub"].(string), &allClaims); err != nil { + if a.logger != nil { + logFunc := a.logger.Warn + if strings.Contains(err.Error(), "user info endpoint is not supported") { + logFunc = a.logger.Info + } + logFunc("error reading /userinfo endpoint", "error", err) } - logFunc("error reading /userinfo endpoint", "error", err) } } @@ -210,40 +206,6 @@ func (e *TokenVerificationFailedError) Error() string { func (e *TokenVerificationFailedError) Unwrap() error { return e.Err } -func (a *Authenticator) verifyOIDCToken(ctx context.Context, rawToken string) (map[string]interface{}, error) { - allClaims := make(map[string]interface{}) - - oidcConfig := &oidc.Config{ - SupportedSigningAlgs: a.config.JWTSupportedAlgs, - } - switch a.config.authType() { - case authOIDCFlow: - oidcConfig.ClientID = a.config.OIDCClientID - case authOIDCDiscovery: - oidcConfig.SkipClientIDCheck = true - default: - return nil, fmt.Errorf("unsupported auth type for this verifyOIDCToken: %d", a.config.authType()) - } - - verifier := a.provider.Verifier(oidcConfig) - - idToken, err := verifier.Verify(ctx, rawToken) - if err != nil { - return nil, fmt.Errorf("error validating signature: %v", err) - } - - if err := idToken.Claims(&allClaims); err != nil { - return nil, fmt.Errorf("unable to successfully parse all claims from token: %v", err) - } - // Follows behavior of hashicorp/vault-plugin-auth-jwt (non-strict validation). - // See https://developer.hashicorp.com/consul/docs/security/acl/auth-methods/oidc#oidc-configuration-troubleshooting. - if err := validateAudience(a.config.BoundAudiences, idToken.Audience, false); err != nil { - return nil, fmt.Errorf("error validating claims: %v", err) - } - - return allClaims, nil -} - // verifyOIDCState tests whether the provided state ID is valid and returns the // associated state object if so. A nil state is returned if the ID is not found // or expired. The state should only ever be retrieved once and is deleted as @@ -261,23 +223,31 @@ func (a *Authenticator) verifyOIDCState(stateID string) *oidcState { // createOIDCState make an expiring state object, associated with a random state ID // that is passed throughout the OAuth process. A nonce is also included in the // auth process, and for simplicity will be identical in length/format as the state ID. -func (a *Authenticator) createOIDCState(redirectURI string, payload interface{}) (string, string, error) { +func (a *Authenticator) createOIDCState(redirectURI string, payload interface{}) (*oidc.Req, error) { // Get enough bytes for 2 160-bit IDs (per rfc6749#section-10.10) bytes, err := uuid.GenerateRandomBytes(2 * 20) if err != nil { - return "", "", err + return nil, err } stateID := fmt.Sprintf("%x", bytes[:20]) nonce := fmt.Sprintf("%x", bytes[20:]) + // Create OIDC request object using CAP library + // This request will be reused during token exchange + request, error := a.oidcRequest(nonce, redirectURI, stateID) + if error != nil { + return nil, fmt.Errorf("error while creating oidc req %w", error) + } + a.oidcStates.SetDefault(stateID, &oidcState{ nonce: nonce, redirectURI: redirectURI, payload: payload, + request: request, }) - return stateID, nonce, nil + return request, nil } // oidcState is created when an authURL is requested. The state @@ -286,4 +256,74 @@ type oidcState struct { nonce string redirectURI string payload interface{} + request *oidc.Req // Store the request object for later use in exchange +} + +// oidcRequest builds the request to send to the HashiCorp CAP library. +// This method configures all necessary OIDC parameters including scopes, +// audiences, and security parameters like state and nonce. +func (a *Authenticator) oidcRequest(nonce, redirect string, stateID string) (*oidc.Req, error) { + opts := []oidc.Option{ + oidc.WithNonce(nonce), + oidc.WithState(stateID), + } + + if len(a.config.OIDCScopes) > 0 { + scopes := append([]string{"openid"}, a.config.OIDCScopes...) + opts = append(opts, oidc.WithScopes(scopes...)) + } + if len(a.config.BoundAudiences) > 0 { + opts = append(opts, oidc.WithAudiences(a.config.BoundAudiences...)) + } + if len(a.config.OIDCACRValues) > 0 { + acrValues := strings.Join(a.config.OIDCACRValues, " ") + opts = append(opts, oidc.WithACRValues(acrValues)) + } + + if a.config.OIDCClientUsePKCE == nil || *a.config.OIDCClientUsePKCE { + verifier, err := oidc.NewCodeVerifier() + if err != nil { + return nil, fmt.Errorf("failed to make pkce verifier: %w", err) + } + opts = append(opts, oidc.WithPKCE(verifier)) + } + + if a.config.OIDCClientAssertion != nil { + rsaKey, parseErr := jwt.ParseRSAPrivateKeyFromPEM([]byte(a.config.OIDCClientAssertion.PrivateKey.PemKey)) + if parseErr != nil { + return nil, fmt.Errorf("failed to parse RSA private key: %w", parseErr) + } + + // Create a JWT with the token endpoint as the audience + var audience []string + if len(a.config.OIDCClientAssertion.Audience) > 0 { + audience = a.config.OIDCClientAssertion.Audience + } else { + audience = []string{a.config.OIDCDiscoveryURL} + } + + var alg cass.RSAlgorithm + switch a.config.OIDCClientAssertion.KeyAlgorithm { + case "RS256", "": // Default to RS256 if empty + alg = cass.RS256 + default: + return nil, fmt.Errorf("unsupported key algorithm: %s", a.config.OIDCClientAssertion.KeyAlgorithm) + } + j, err := cass.NewJWTWithRSAKey(a.config.OIDCClientID, audience, alg, rsaKey) + if err != nil { + return nil, err + } + opts = append(opts, oidc.WithClientAssertionJWT(j)) + } + + req, err := oidc.NewRequest( + oidcStateTimeout, + redirect, + opts..., + ) + if err != nil { + return nil, fmt.Errorf("failed to create OIDC request: %w", err) + } + + return req, nil } diff --git a/internal/go-sso/oidcauth/oidc_test.go b/internal/go-sso/oidcauth/oidc_test.go index ac61755e6546..408557e06d9c 100644 --- a/internal/go-sso/oidcauth/oidc_test.go +++ b/internal/go-sso/oidcauth/oidc_test.go @@ -18,6 +18,9 @@ import ( "github.com/stretchr/testify/require" ) +const testRSAPrivateKey = "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDVMMi3HiDYhYmD\nbRi1MmacojGKP5HZMp4whUwp0oI+M0hpu2zQGv+p/vxUrsQQ6Esgp+sYA8aPRDky\ndMNLR+f0gjkAT7KglCIB6M4JfoLHaKrwCUPngQYWqdeVhl5abmRms/+gkZhqkkXr\nc3ax9yOCoyWMhaJjZeFyaeKt+DFDBB/VE8xNB5pVfCPgDJx5lmFwRtvzue65HLE6\nJHq8mA+y7k8qlH4H/yj5c1sZhJVUVxA/ixlDcYVI2vuPyoUAQsgUr4ZwlRVbbXOI\nUTMvnfy7OnScW0FxVNc66+7tp/qeTLBdkdMLa68hym/WbUnNqFbq7woraBoi+b+M\nNp1Uz5+7AgMBAAECggEATttsKvvcd2qxqmkEziVV+lMuUu5fswD7rYPo38Frhrlu\nbBm1Tqbl8coNKP+6K2zZOTuThL8Ex8KbC5RQFr0CyhkPH5PbRXV1vNIRwEZI9py7\nOe2bbfr2NxTc1wSsSvPxdGHZSNoCEE2JymVbvsllG7HgNkHKBs1NHoaXH/WhtyEX\nFoi2zEAl1xP3nrO6iJ/1Zjz0AHj+Ut0IL2abbT4ktQ4gkoSRjh7QMnBkQ4X5pyaM\nnQ1xhCMw8ryaV7zzCk5TuHiY2on1mp3F9TTq/lnyy712tY9g55IhX1vFu2iQ8Cv7\nfvOwZNvaxFJQVrs+kP3GZISEb34OvrKPycAN6lBEtQKBgQDrAXbFVW8TmPc38MjU\n/qEBfvzjzyUz9dGwuPK2y4ht1RdqYjT6n09FHTMUEcz+QxCTleAoZbI4TibAIWqG\nWu7HhjwEF0tIyXiEoUEWVhmbsPwbBc0yYFTJ3EhsyzLHwJ+tC0CSK5WTGygXpj6M\n1ZcpPjiiHDVpx/UQGtwKqvAk/wKBgQDoPGlgyKUTjpa6yj3Y2TMZnmWI4nJOBh7o\nEDX5vOhj7tfqrINllD6t4NFJFcVA30UK7RhmE0PkFAxnx/9+/+E0fUbRxFCNDGv5\nfVa6XaTqAwBsniGObkDjbzeNvRMloD3UzxeFdVkRVObXxJj7tLQ3ZymkYABBU3g7\nbEPt/cdZRQKBgEbKtBqRt9oxdBdX40e2RI4M0OVXGx/h5v7TV9oUyc48KMeVOdxd\nbSWmvCJJknTtgurSdSn2KI+piybJait67P8RwraAxd7xQerCILc3zJMH54nEX6HT\nPvdn8jFDrNJbhj48a4Ecu/wKbDNjkugd12FHKww6bySkZYAqdyqHf7vFAoGBAIXb\n5GWL4VKPeqP51II8V27p1N58n6QHdSMPzPzA/TY0wjGa9DXFqAczMY6txL+qsbIl\njU2wxw4c3DWpmsQKGzXVC8/3FvLl+QqaSzYqqdbUmhcBYpglRrORNHU3SWUDowAZ\nyhX72LXbuR8fS4qx0rqodOExEJSW1xNxSQpRn+j9AoGAXk6U9md5J3iRHtAPRIJI\nuWNm0rkBJnxBmxWVBlqghP5kXS6RGj72BlJjjT2nrbJeXjvHvBf31LHG+RrLuJUr\nl+P1QU5tkz6x487/yss7ZDkWgLoBuJmuVaTr8yQ548NJ48fuQEGRmQTtk/hXRqRp\n4keWXLkzwcqw6VF5MjfHaWA=\n-----END PRIVATE KEY-----\n" +const badRSAPrivateKey = "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgk\n-----END PRIVATE KEY-----\n" + func setupForOIDC(t *testing.T) (*Authenticator, *oidcauthtest.Server) { t.Helper() @@ -118,6 +121,117 @@ func TestOIDC_AuthURL(t *testing.T) { ) requireErrorContains(t, err, "missing redirect_uri") }) + + t.Run("custom scopes and PKCE enabled/disabled", func(t *testing.T) { + oa, _ := setupForOIDC(t) + oa.config.OIDCScopes = []string{"profile", "email"} + + authURL, err := oa.GetAuthCodeURL( + context.Background(), + "https://example.com", + map[string]string{"foo": "bar"}, + ) + require.NoError(t, err) + + parsedURL, err := url.Parse(authURL) + require.NoError(t, err) + + // Extract query parameters + params := parsedURL.Query() + require.Contains(t, authURL, "scope=openid+profile+email") + require.Contains(t, params, "code_challenge") + require.NotEmpty(t, params.Get("code_challenge")) + require.Equal(t, "S256", params.Get("code_challenge_method")) + + oa.config.OIDCClientUsePKCE = new(bool) + *oa.config.OIDCClientUsePKCE = false // PKCE disabled + authURL2, _ := oa.GetAuthCodeURL( + context.Background(), + "https://example.com", + map[string]string{"foo": "bar"}, + ) + parsedURL, err = url.Parse(authURL2) + require.NoError(t, err) + + // Extract query parameters + params = parsedURL.Query() + require.NoError(t, err) + require.Contains(t, authURL2, "scope=openid+profile+email") + require.NotContains(t, params, "code_challenge") + require.Empty(t, params.Get("code_challenge")) + }) + + t.Run("oidc client assertion (private key JWT)", func(t *testing.T) { + oa, _ := setupForOIDC(t) + oa.config.OIDCClientAssertion = &OIDCClientAssertion{ + PrivateKey: &OIDCClientAssertionKey{PemKey: testRSAPrivateKey}, + Audience: []string{oa.config.OIDCDiscoveryURL}, + KeyAlgorithm: "RS256", + } + authURL, err := oa.GetAuthCodeURL( + context.Background(), + "https://example.com", + map[string]string{"foo": "bar"}, + ) + require.NoError(t, err) + require.True(t, strings.HasPrefix(authURL, oa.config.OIDCDiscoveryURL+"/auth?")) + + expected := map[string]string{ + "client_id": "abc", + "redirect_uri": "https://example.com", + "response_type": "code", + "scope": "openid", + // optional values + "acr_values": "acr1 acr2", + } + + au, err := url.Parse(authURL) + require.NoError(t, err) + params := au.Query() + + for k, v := range expected { + assert.Equal(t, v, au.Query().Get(k), "key %q is incorrect", k) + } + + assert.Regexp(t, `^[a-z0-9]{40}$`, au.Query().Get("nonce")) + assert.Regexp(t, `^[a-z0-9]{40}$`, au.Query().Get("state")) + require.Contains(t, params, "code_challenge") + require.NotEmpty(t, params.Get("code_challenge")) + require.Equal(t, "S256", params.Get("code_challenge_method")) + }) + + t.Run("oidc client assertion invalid pemkey", func(t *testing.T) { + oa, _ := setupForOIDC(t) + oa.config.OIDCClientAssertion = &OIDCClientAssertion{ + PrivateKey: &OIDCClientAssertionKey{PemKey: badRSAPrivateKey}, + Audience: []string{oa.config.OIDCDiscoveryURL}, + KeyAlgorithm: "RS256", + } + _, err := oa.GetAuthCodeURL( + context.Background(), + "https://example.com", + map[string]string{"foo": "bar"}, + ) + requireErrorContains(t, err, "failed to parse RSA private key") + }) + + t.Run("unsupported key algorithm", func(t *testing.T) { + oa, _ := setupForOIDC(t) + + oa.config.OIDCClientAssertion = &OIDCClientAssertion{ + PrivateKey: &OIDCClientAssertionKey{PemKey: testRSAPrivateKey}, + Audience: []string{oa.config.OIDCDiscoveryURL}, + KeyAlgorithm: "ES256", + } + origPayload := map[string]string{"foo": "bar"} + _, err := oa.GetAuthCodeURL( + context.Background(), + "https://example.com", + origPayload, + ) + requireErrorContains(t, err, "unsupported key algorithm") + + }) } func TestOIDC_JWT_Functions_Fail(t *testing.T) { @@ -205,6 +319,73 @@ func TestOIDC_ClaimsFromAuthCode(t *testing.T) { require.Equal(t, expectedClaims, claims) }) + t.Run("multiple and nested claim mappings", func(t *testing.T) { + oa, srv := setupForOIDC(t) + + origPayload := map[string]string{"foo": "bar"} + authURL, err := oa.GetAuthCodeURL( + context.Background(), + "https://example.com", + origPayload, + ) + require.NoError(t, err) + + state := getQueryParam(t, authURL, "state") + nonce := getQueryParam(t, authURL, "nonce") + + // set provider claims that will be returned by the mock server + srv.SetCustomClaims(sampleClaims(nonce)) + + // set mock provider's expected code + srv.SetExpectedAuthCode("abc") + + oa.config.ClaimMappings = map[string]string{ + "email": "user_email", + "/nested/Size": "user_size", + } + oa.config.ListClaimMappings = map[string]string{ + "/nested/Groups": "groups", + } + + srv.SetExpectedAuthCode("abc") + + // Now use mockState in your test + resultClaims, _, err := oa.ClaimsFromAuthCode( + context.Background(), + state, "abc", + ) + require.NoError(t, err) + require.Equal(t, "bob@example.com", resultClaims.Values["user_email"]) + require.Equal(t, "medium", resultClaims.Values["user_size"]) + require.ElementsMatch(t, []string{"a", "b"}, resultClaims.Lists["groups"]) + }) + + t.Run("State not found", func(t *testing.T) { + oa, srv := setupForOIDC(t) + nonce := "test-nonce" + srv.SetCustomClaims(sampleClaims(nonce)) + srv.SetExpectedAuthCode("abc") + + _, _, err := oa.ClaimsFromAuthCode( + context.Background(), + "state", "abc", + ) + requireErrorContains(t, err, "Expired or missing OAuth state") + }) + + t.Run("multiple audiences", func(t *testing.T) { + oa, _ := setupForOIDC(t) + oa.config.BoundAudiences = []string{"abc", "def"} + authURL, err := oa.GetAuthCodeURL( + context.Background(), + "https://example.com", + map[string]string{"foo": "bar"}, + ) + require.NoError(t, err) + require.Contains(t, authURL, "client_id=abc") + // Optionally: test with a token that matches one of the audiences + }) + t.Run("failed login unusable claims", func(t *testing.T) { oa, srv := setupForOIDC(t) @@ -307,8 +488,8 @@ func TestOIDC_ClaimsFromAuthCode(t *testing.T) { context.Background(), state, "abc", ) - requireErrorContains(t, err, "Invalid ID token nonce") - requireTokenVerificationError(t, err) + requireErrorContains(t, err, "invalid id_token nonce: invalid nonce") + requireProviderError(t, err) }) t.Run("missing state", func(t *testing.T) { @@ -420,8 +601,8 @@ func TestOIDC_ClaimsFromAuthCode(t *testing.T) { context.Background(), state, "abc", ) - requireErrorContains(t, err, "No id_token found in response") - requireTokenVerificationError(t, err) + requireErrorContains(t, err, "id_token is missing from auth code exchange") + requireProviderError(t, err) }) t.Run("no response from provider", func(t *testing.T) { @@ -475,8 +656,8 @@ func TestOIDC_ClaimsFromAuthCode(t *testing.T) { context.Background(), state, "abc", ) - requireErrorContains(t, err, `error validating signature: oidc: expected audience "abc" got ["not_gonna_match"]`) - requireTokenVerificationError(t, err) + requireErrorContains(t, err, `invalid id_token audiences`) + requireProviderError(t, err) }) } @@ -502,13 +683,15 @@ func sampleClaims(nonce string) map[string]interface{} { func getQueryParam(t *testing.T, inputURL, param string) string { t.Helper() - m, err := url.ParseQuery(inputURL) + // Replace this function with one that properly parses full URLs + u, err := url.Parse(inputURL) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to parse URL %q: %v", inputURL, err) } - v, ok := m[param] - if !ok { - t.Fatalf("query param %q not found", param) + + v := u.Query().Get(param) + if v == "" { + t.Fatalf("Query param %q not found in URL", param) } - return v[0] + return v } diff --git a/internal/go-sso/oidcauth/oidcauthtest/testing.go b/internal/go-sso/oidcauth/oidcauthtest/testing.go index 3ba4befe0989..0691097dca2a 100644 --- a/internal/go-sso/oidcauth/oidcauthtest/testing.go +++ b/internal/go-sso/oidcauth/oidcauthtest/testing.go @@ -73,7 +73,8 @@ func Start(t TestingT) *Server { "https://example.com", }, replySubject: "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", - replyUserinfo: map[string]interface{}{ + replyUserinfo: map[string]any{ + "sub": "r3qXcK2bix9eFECzsU3Sbmh0K16fatW6@clients", "color": "red", "temperature": "76", "flavor": "umami", @@ -115,6 +116,12 @@ func (s *Server) SetExpectedAuthCode(code string) { s.expectedAuthCode = code } +func (s *Server) SetUserInfo(info map[string]any) { + s.mu.Lock() + defer s.mu.Unlock() + s.replyUserinfo = info +} + // SetExpectedAuthNonce configures the nonce value required for /auth. func (s *Server) SetExpectedAuthNonce(nonce string) { s.mu.Lock() diff --git a/ui/packages/consul-ui/app/components/consul/auth-method/view/index.hbs b/ui/packages/consul-ui/app/components/consul/auth-method/view/index.hbs index eab0fa48399c..101c73b201db 100644 --- a/ui/packages/consul-ui/app/components/consul/auth-method/view/index.hbs +++ b/ui/packages/consul-ui/app/components/consul/auth-method/view/index.hbs @@ -190,6 +190,10 @@ as |item|}}
{{t 'models.auth-method.Config.OIDCClientSecret'}}
{{@item.Config.OIDCClientSecret}}
{{/if}} + {{#if @item.Config.OIDCClientAssertion.PrivateKey.PemKey}} +
{{t 'models.auth-method.Config.OIDCClientAssertionPrivateKey'}}
+
{{@item.Config.OIDCClientAssertion.PrivateKey.PemKey}}
+ {{/if}} {{#if @item.Config.AllowedRedirectURIs}}
{{t 'models.auth-method.Config.AllowedRedirectURIs'}}
diff --git a/ui/packages/consul-ui/translations/models/en-us.yaml b/ui/packages/consul-ui/translations/models/en-us.yaml index bab5f63ef819..59bb5cd575dd 100644 --- a/ui/packages/consul-ui/translations/models/en-us.yaml +++ b/ui/packages/consul-ui/translations/models/en-us.yaml @@ -26,6 +26,7 @@ auth-method: OIDCDiscoveryCACert: OIDC discovery CA cert OIDCClientID: Client ID OIDCClientSecret: Client secret + OIDCClientAssertionPrivateKey: Client assertion private key AllowedRedirectURIs: Allowed redirect URIs OIDCScopes: OIDC scopes VerboseOIDCLogging: Verbose OIDC logging