Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions common/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ const (
ArgoCDUserAgentName = "argocd-client"
// AuthCookieName is the HTTP cookie name where we store our auth token
AuthCookieName = "argocd.token"
// StateCookieName is the HTTP cookie name that holds temporary nonce tokens for CSRF protection
StateCookieName = "argocd.oauthstate"
// StateCookieMaxAge is the maximum age of the oauth state cookie
StateCookieMaxAge = time.Minute * 5

// ChangePasswordSSOTokenMaxAge is the max token age for password change operation
ChangePasswordSSOTokenMaxAge = time.Minute * 5
Expand Down
17 changes: 0 additions & 17 deletions server/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
cacheutil "github.com/argoproj/argo-cd/v2/util/cache"
appstatecache "github.com/argoproj/argo-cd/v2/util/cache/appstate"
"github.com/argoproj/argo-cd/v2/util/env"
"github.com/argoproj/argo-cd/v2/util/oidc"
)

var ErrCacheMiss = appstatecache.ErrCacheMiss
Expand All @@ -25,8 +24,6 @@ type Cache struct {
loginAttemptsExpiration time.Duration
}

var _ oidc.OIDCStateStorage = &Cache{}

func NewCache(
cache *appstatecache.Cache,
connectionStatusCacheExpiration time.Duration,
Expand Down Expand Up @@ -91,20 +88,6 @@ func (c *Cache) SetClusterInfo(server string, res *appv1.ClusterInfo) error {
return c.cache.SetClusterInfo(server, res)
}

func oidcStateKey(key string) string {
return fmt.Sprintf("oidc|%s", key)
}

func (c *Cache) GetOIDCState(key string) (*oidc.OIDCState, error) {
res := oidc.OIDCState{}
err := c.cache.GetItem(oidcStateKey(key), &res)
return &res, err
}

func (c *Cache) SetOIDCState(key string, state *oidc.OIDCState) error {
return c.cache.SetItem(oidcStateKey(key), state, c.oidcCacheExpiration, state == nil)
}

func (c *Cache) GetCache() *cacheutil.Cache {
return c.cache.Cache
}
18 changes: 0 additions & 18 deletions server/cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
. "github.com/argoproj/argo-cd/v2/pkg/apis/application/v1alpha1"
cacheutil "github.com/argoproj/argo-cd/v2/util/cache"
appstatecache "github.com/argoproj/argo-cd/v2/util/cache/appstate"
"github.com/argoproj/argo-cd/v2/util/oidc"
)

type fixtures struct {
Expand Down Expand Up @@ -46,23 +45,6 @@ func TestCache_GetRepoConnectionState(t *testing.T) {
assert.Equal(t, ConnectionState{Status: "my-state"}, value)
}

func TestCache_GetOIDCState(t *testing.T) {
cache := newFixtures().Cache
// cache miss
_, err := cache.GetOIDCState("my-key")
assert.Equal(t, ErrCacheMiss, err)
// populate cache
err = cache.SetOIDCState("my-key", &oidc.OIDCState{ReturnURL: "my-return-url"})
assert.NoError(t, err)
//cache miss
_, err = cache.GetOIDCState("other-key")
assert.Equal(t, ErrCacheMiss, err)
// cache hit
value, err := cache.GetOIDCState("my-key")
assert.NoError(t, err)
assert.Equal(t, &oidc.OIDCState{ReturnURL: "my-return-url"}, value)
}

func TestAddCacheFlagsToCmd(t *testing.T) {
cache, err := AddCacheFlagsToCmd(&cobra.Command{})()
assert.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ func (a *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
tlsConfig := a.settings.TLSConfig()
tlsConfig.InsecureSkipVerify = true
}
a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.Cache, a.DexServerAddr, a.BaseHRef)
a.ssoClientApp, err = oidc.NewClientApp(a.settings, a.DexServerAddr, a.BaseHRef)
errors.CheckError(err)
mux.HandleFunc(common.LoginEndpoint, a.ssoClientApp.HandleLogin)
mux.HandleFunc(common.CallbackEndpoint, a.ssoClientApp.HandleCallback)
Expand Down
49 changes: 49 additions & 0 deletions util/crypto/crypto.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package crypto

import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"io"
)

// Encrypt encrypts the given data with the given passphrase.
func Encrypt(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
ciphertext := gcm.Seal(nonce, nonce, data, nil)
return ciphertext, nil
}

// Decrypt decrypts the given data using the given passphrase.
func Decrypt(data []byte, key []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonceSize := gcm.NonceSize()
if len(data) < nonceSize {
return nil, errors.New("data length is less than nonce size")
}
nonce, ciphertext := data[:nonceSize], data[nonceSize:]
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
if err != nil {
return nil, err
}
return plaintext, nil
}
43 changes: 43 additions & 0 deletions util/crypto/crypto_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package crypto

import (
"crypto/rand"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func newKey() ([]byte, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
b = nil
}
return b, err
}

func TestEncryptDecrypt_Successful(t *testing.T) {
key, err := newKey()
require.NoError(t, err)
encrypted, err := Encrypt([]byte("test"), key)
require.NoError(t, err)

decrypted, err := Decrypt(encrypted, key)
require.NoError(t, err)

assert.Equal(t, "test", string(decrypted))
}

func TestEncryptDecrypt_Failed(t *testing.T) {
key, err := newKey()
require.NoError(t, err)
encrypted, err := Encrypt([]byte("test"), key)
require.NoError(t, err)

wrongKey, err := newKey()
require.NoError(t, err)

_, err = Decrypt(encrypted, wrongKey)
assert.Error(t, err)
}
105 changes: 67 additions & 38 deletions util/oidc/oidc.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package oidc

import (
"encoding/hex"
"encoding/json"
"fmt"
"html"
Expand All @@ -20,7 +21,7 @@ import (

"github.com/argoproj/argo-cd/v2/common"
"github.com/argoproj/argo-cd/v2/server/settings/oidc"
appstatecache "github.com/argoproj/argo-cd/v2/util/cache/appstate"
"github.com/argoproj/argo-cd/v2/util/crypto"
"github.com/argoproj/argo-cd/v2/util/dex"
httputil "github.com/argoproj/argo-cd/v2/util/http"
"github.com/argoproj/argo-cd/v2/util/rand"
Expand All @@ -45,16 +46,6 @@ type ClaimsRequest struct {
IDToken map[string]*oidc.Claim `json:"id_token"`
}

type OIDCState struct {
// ReturnURL is the URL in which to redirect a user back to after completing an OAuth2 login
ReturnURL string `json:"returnURL"`
}

type OIDCStateStorage interface {
GetOIDCState(key string) (*OIDCState, error)
SetOIDCState(key string, state *OIDCState) error
}

type ClientApp struct {
// OAuth2 client ID of this application (e.g. argo-cd)
clientID string
Expand All @@ -75,9 +66,6 @@ type ClientApp struct {
settings *settings.ArgoCDSettings
// provider is the OIDC provider
provider Provider
// cache holds temporary nonce tokens to which hold application state values
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
cache OIDCStateStorage
}

func GetScopesOrDefault(scopes []string) []string {
Expand All @@ -89,7 +77,7 @@ func GetScopesOrDefault(scopes []string) []string {

// NewClientApp will register the Argo CD client app (either via Dex or external OIDC) and return an
// object which has HTTP handlers for handling the HTTP responses for login and callback
func NewClientApp(settings *settings.ArgoCDSettings, cache OIDCStateStorage, dexServerAddr, baseHRef string) (*ClientApp, error) {
func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr, baseHRef string) (*ClientApp, error) {
redirectURL, err := settings.RedirectURL()
if err != nil {
return nil, err
Expand All @@ -100,7 +88,6 @@ func NewClientApp(settings *settings.ArgoCDSettings, cache OIDCStateStorage, dex
redirectURI: redirectURL,
issuerURL: settings.IssuerURL(),
baseHRef: baseHRef,
cache: cache,
}
log.Infof("Creating client app (%s)", a.clientID)
u, err := url.Parse(settings.URL)
Expand Down Expand Up @@ -149,31 +136,68 @@ func (a *ClientApp) oauth2Config(scopes []string) (*oauth2.Config, error) {
}

// generateAppState creates an app state nonce
func (a *ClientApp) generateAppState(returnURL string) string {
func (a *ClientApp) generateAppState(returnURL string, w http.ResponseWriter) (string, error) {
randStr := rand.RandString(10)
if returnURL == "" {
returnURL = a.baseHRef
}
err := a.cache.SetOIDCState(randStr, &OIDCState{ReturnURL: returnURL})
cookieValue := fmt.Sprintf("%s:%s", randStr, returnURL)
key, err := a.settings.GetServerSignatureKey()
if err != nil {
// This should never happen with the in-memory cache
log.Errorf("Failed to set app state: %v", err)
return "", err
}
return randStr
if encrypted, err := crypto.Encrypt([]byte(cookieValue), key); err != nil {
return "", err
} else {
cookieValue = hex.EncodeToString(encrypted)
}

http.SetCookie(w, &http.Cookie{
Name: common.StateCookieName,
Value: cookieValue,
Expires: time.Now().Add(common.StateCookieMaxAge),
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: a.secureCookie,
})
return randStr, nil
}

func (a *ClientApp) verifyAppState(state string) (*OIDCState, error) {
res, err := a.cache.GetOIDCState(state)
func (a *ClientApp) verifyAppState(r *http.Request, w http.ResponseWriter, state string) (string, error) {
c, err := r.Cookie(common.StateCookieName)
if err != nil {
if err == appstatecache.ErrCacheMiss {
return nil, fmt.Errorf("unknown app state %s", state)
} else {
return nil, fmt.Errorf("failed to verify app state %s: %v", state, err)
}
return "", err
}

_ = a.cache.SetOIDCState(state, nil)
return res, nil
val, err := hex.DecodeString(c.Value)
if err != nil {
return "", err
}
key, err := a.settings.GetServerSignatureKey()
if err != nil {
return "", err
}
val, err = crypto.Decrypt(val, key)
if err != nil {
return "", err
}
cookieVal := string(val)
returnURL := a.baseHRef
parts := strings.SplitN(cookieVal, ":", 2)
if len(parts) == 2 && parts[1] != "" {
returnURL = parts[1]
}
if parts[0] != state {
return "", fmt.Errorf("invalid state in '%s' cookie", common.AuthCookieName)
}
// set empty cookie to clear it
http.SetCookie(w, &http.Cookie{
Name: common.StateCookieName,
Value: "",
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
Secure: a.secureCookie,
})
return returnURL, nil
}

// isValidRedirectURL checks whether the given redirectURL matches on of the
Expand Down Expand Up @@ -248,7 +272,12 @@ func (a *ClientApp) HandleLogin(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Invalid redirect URL: the protocol and host (including port) must match and the path must be within allowed URLs if provided", http.StatusBadRequest)
return
}
stateNonce := a.generateAppState(returnURL)
stateNonce, err := a.generateAppState(returnURL, w)
if err != nil {
log.Errorf("Failed to initiate login flow: %v", err)
http.Error(w, "Failed to initiate login flow", http.StatusInternalServerError)
return
}
grantType := InferGrantType(oidcConf)
var url string
switch grantType {
Expand Down Expand Up @@ -281,10 +310,10 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
state := r.FormValue("state")
if code == "" {
// If code was not given, it implies implicit flow
a.handleImplicitFlow(w, state)
a.handleImplicitFlow(r, w, state)
return
}
appState, err := a.verifyAppState(state)
returnURL, err := a.verifyAppState(r, w, state)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
Expand Down Expand Up @@ -339,7 +368,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
claimsJSON, _ := json.MarshalIndent(claims, "", " ")
renderToken(w, a.redirectURI, idTokenRAW, token.RefreshToken, claimsJSON)
} else {
http.Redirect(w, r, appState.ReturnURL, http.StatusSeeOther)
http.Redirect(w, r, returnURL, http.StatusSeeOther)
}
}

Expand All @@ -366,7 +395,7 @@ if (state != "" && returnURL == "") {
// state nonce for verification, as well as looking up the return URL. Once verified, the client
// stores the id_token from the fragment as a cookie. Finally it performs the final redirect back to
// the return URL.
func (a *ClientApp) handleImplicitFlow(w http.ResponseWriter, state string) {
func (a *ClientApp) handleImplicitFlow(r *http.Request, w http.ResponseWriter, state string) {
type implicitFlowValues struct {
CookieName string
ReturnURL string
Expand All @@ -375,12 +404,12 @@ func (a *ClientApp) handleImplicitFlow(w http.ResponseWriter, state string) {
CookieName: common.AuthCookieName,
}
if state != "" {
appState, err := a.verifyAppState(state)
returnURL, err := a.verifyAppState(r, w, state)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
vals.ReturnURL = appState.ReturnURL
vals.ReturnURL = returnURL
}
renderTemplate(w, implicitFlowTmpl, vals)
}
Expand Down
Loading