diff --git a/claims.go b/claims.go index f0228f02..ae3bf951 100644 --- a/claims.go +++ b/claims.go @@ -61,7 +61,7 @@ func (c StandardClaims) Valid() error { // Compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (c *StandardClaims) VerifyAudience(cmp string, req bool) bool { - return verifyAud(c.Audience, cmp, req) + return verifyAud([]string{c.Audience}, cmp, req) } // Compares the exp claim against cmp. @@ -90,15 +90,16 @@ func (c *StandardClaims) VerifyNotBefore(cmp int64, req bool) bool { // ----- helpers -func verifyAud(aud string, cmp string, required bool) bool { - if aud == "" { +func verifyAud(auds []string, cmp string, required bool) bool { + if auds == nil || len(auds) == 0 { return !required } - if subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 { - return true - } else { - return false + for _, aud := range auds { + if len(aud) == len(cmp) && subtle.ConstantTimeCompare([]byte(aud), []byte(cmp)) != 0 { + return true + } } + return false } func verifyExp(exp int64, now int64, required bool) bool { diff --git a/map_claims.go b/map_claims.go index 291213c4..4395f757 100644 --- a/map_claims.go +++ b/map_claims.go @@ -13,7 +13,13 @@ type MapClaims map[string]interface{} // Compares the aud claim against cmp. // If required is false, this method will return true if the value matches or is unset func (m MapClaims) VerifyAudience(cmp string, req bool) bool { - aud, _ := m["aud"].(string) + var aud []string + switch exp := m["aud"].(type) { + case string: + aud = []string{exp} + case []string: + aud = exp + } return verifyAud(aud, cmp, req) } diff --git a/map_claims_test.go b/map_claims_test.go new file mode 100644 index 00000000..1ce4b757 --- /dev/null +++ b/map_claims_test.go @@ -0,0 +1,120 @@ +package jwt_test + +import ( + "encoding/json" + "fmt" + "github.com/dgrijalva/jwt-go" + "github.com/stretchr/testify/assert" + "testing" + "time" +) + +func TestMapClaims_Valid(t *testing.T) { + now := time.Now() + oneMinuteFromNow := json.Number(fmt.Sprint(now.Add(time.Minute).Unix())) + oneMinuteAgo := json.Number(fmt.Sprint(now.Add(-time.Minute).Unix())) + twoMinutesAgo := json.Number(fmt.Sprint(now.Add(-2 * time.Minute).Unix())) + thirtySecondFromNow := json.Number(fmt.Sprint(now.Add(30*time.Second).Unix())) + nowStr := json.Number(fmt.Sprint(now.Unix())) + validClaims := jwt.MapClaims{ + "exp": oneMinuteFromNow, + "iat": nowStr, + "nbf": nowStr, + } + assert.NoError(t, validClaims.Valid()) + expiredClaims := jwt.MapClaims{ + "exp": oneMinuteAgo, + "iat": twoMinutesAgo, + "nbf": twoMinutesAgo, + } + assert.Error(t, expiredClaims.Valid()) + notYetValidClaims := jwt.MapClaims{ + "exp": oneMinuteFromNow, + "iat": nowStr, + "nbf": thirtySecondFromNow, + } + assert.Error(t, notYetValidClaims.Valid()) + notYetIssuedClaims := jwt.MapClaims{ + "exp": oneMinuteFromNow, + "iat": thirtySecondFromNow, + "nbf": thirtySecondFromNow, + } + assert.Error(t, notYetIssuedClaims.Valid()) +} + +func TestMapClaims_Valid_Float(t *testing.T) { + now := time.Now() + oneMinuteFromNow := float64(now.Add(time.Minute).Unix()) + oneMinuteAgo := float64(now.Add(-time.Minute).Unix()) + twoMinutesAgo := float64(now.Add(-2 * time.Minute).Unix()) + thirtySecondFromNow := float64(now.Add(30*time.Second).Unix()) + nowStr := float64(now.Unix()) + validClaims := jwt.MapClaims{ + "exp": oneMinuteFromNow, + "iat": nowStr, + "nbf": nowStr, + } + assert.NoError(t, validClaims.Valid()) + expiredClaims := jwt.MapClaims{ + "exp": oneMinuteAgo, + "iat": twoMinutesAgo, + "nbf": twoMinutesAgo, + } + assert.Error(t, expiredClaims.Valid()) + notYetValidClaims := jwt.MapClaims{ + "exp": oneMinuteFromNow, + "iat": nowStr, + "nbf": thirtySecondFromNow, + } + assert.Error(t, notYetValidClaims.Valid()) + notYetIssuedClaims := jwt.MapClaims{ + "exp": oneMinuteFromNow, + "iat": thirtySecondFromNow, + "nbf": thirtySecondFromNow, + } + assert.Error(t, notYetIssuedClaims.Valid()) +} + +func TestMapClaims_VerifyAudience(t *testing.T) { + joe := "joe" + jill := "jill" + jack := "jack" + + claims := jwt.MapClaims{} + assert.True(t, claims.VerifyAudience(joe, false)) + assert.False(t, claims.VerifyAudience(joe, true)) + + claims = jwt.MapClaims{"aud":[]string{}} + assert.True(t, claims.VerifyAudience(joe, false)) + assert.False(t, claims.VerifyAudience(joe, true)) + + claims = jwt.MapClaims{ + "aud": joe, + } + assert.True(t, claims.VerifyAudience(joe, false)) + assert.False(t, claims.VerifyAudience(jill, false)) + assert.True(t, claims.VerifyAudience(joe, true)) + assert.False(t, claims.VerifyAudience(jill, true)) + + claims = jwt.MapClaims{ + "aud": []string {joe, jill}, + } + assert.True(t, claims.VerifyAudience(joe, false)) + assert.True(t, claims.VerifyAudience(joe, true)) + assert.False(t, claims.VerifyAudience(jack, false)) + assert.False(t, claims.VerifyAudience(jack, true)) +} + +func TestMapClaims_VerifyIssuer(t *testing.T) { + claims := jwt.MapClaims{} + assert.True(t, claims.VerifyIssuer("service1", false)) + assert.False(t, claims.VerifyIssuer("service1", true)) + + claims = jwt.MapClaims{"iss": "service1"} + assert.True(t, claims.VerifyIssuer("service1", false)) + assert.True(t, claims.VerifyIssuer("service1", true)) + + claims = jwt.MapClaims{"iss": "service2"} + assert.False(t, claims.VerifyIssuer("service1", false)) + assert.False(t, claims.VerifyIssuer("service1", true)) +} \ No newline at end of file