diff --git a/protocol/client.go b/protocol/client.go index e5c21672..8c814b30 100644 --- a/protocol/client.go +++ b/protocol/client.go @@ -108,25 +108,11 @@ func (c *CollectedClientData) Verify(storedChallenge string, ceremony CeremonyTy // Registration Step 5 & Assertion Step 9. Verify that the value of C.origin matches // the Relying Party's origin. - var fqOrigin string - if fqOrigin, err = FullyQualifiedOrigin(c.Origin); err != nil { - return ErrParsingData.WithDetails("Error decoding clientData origin as URL").WithError(err) - } - - found := false - - for _, origin := range rpOrigins { - if strings.EqualFold(fqOrigin, origin) { - found = true - break - } - } - - if !found { + if !IsOriginInHaystack(c.Origin, rpOrigins) { return ErrVerification. WithDetails("Error validating origin"). - WithInfo(fmt.Sprintf("Expected Values: %s, Received: %s", rpOrigins, fqOrigin)) + WithInfo(fmt.Sprintf("Expected Values: %s, Received: %s", rpOrigins, c.Origin)) } if rpTopOriginsVerify != TopOriginIgnoreVerificationMode { @@ -145,10 +131,6 @@ func (c *CollectedClientData) Verify(storedChallenge string, ceremony CeremonyTy possibleTopOrigins []string ) - if fqTopOrigin, err = FullyQualifiedOrigin(c.TopOrigin); err != nil { - return ErrParsingData.WithDetails("Error decoding clientData topOrigin as URL").WithError(err) - } - switch rpTopOriginsVerify { case TopOriginExplicitVerificationMode: possibleTopOrigins = rpTopOrigins @@ -160,16 +142,7 @@ func (c *CollectedClientData) Verify(storedChallenge string, ceremony CeremonyTy return ErrNotImplemented.WithDetails("Error handling unknown Top Origin verification mode") } - found = false - - for _, origin := range possibleTopOrigins { - if strings.EqualFold(fqTopOrigin, origin) { - found = true - break - } - } - - if !found { + if !IsOriginInHaystack(c.TopOrigin, possibleTopOrigins) { return ErrVerification. WithDetails("Error validating top origin"). WithInfo(fmt.Sprintf("Expected Values: %s, Received: %s", possibleTopOrigins, fqTopOrigin)) @@ -221,3 +194,92 @@ const ( // Top Origin is verified against the allowed Top Origins values. TopOriginExplicitVerificationMode ) + +// IsOriginInHaystack checks if the needle is in the haystack using the mechanism to determine origin equality defined +// in HTML5 Section 5.3 and RFC3986 Section 6.2.1. +// +// Specifically if the needle value has the 'http://' or 'https://' prefix (case-insensitive) and can be parsed as a +// URL; we check each item in the haystack to see if it matches the same rules, and then if the scheme and host (with +// a normalized port) components match case-insensitively then they're considered a match. +// +// If the needle value does not have the 'http://' or 'https://' prefix (case-insensitive) or can't be parsed as a URL +// equality is determined using simple string comparison. +// +// It is important to note that this function completely ignores Apple Associated Domains entirely as Apple is using +// an unassigned Well-Known URI in breech of Well-Known Uniform Resource Identifiers (RFC8615). +// +// See (Origin Definition): https://www.w3.org/TR/2011/WD-html5-20110525/origin-0.html +// +// See (Simple String Comparison Definition): https://datatracker.ietf.org/doc/html/rfc3986#section-6.2.1 +// +// See (Apple Associated Domains): https://developer.apple.com/documentation/xcode/supporting-associated-domains +// +// See (IANA Well Known URI Assignments): https://www.iana.org/assignments/well-known-uris/well-known-uris.xhtml +// +// See (Well-Known Uniform Resource Identifiers): https://datatracker.ietf.org/doc/html/rfc8615 +func IsOriginInHaystack(needle string, haystack []string) bool { + needleURI := parseOriginURI(needle) + + if needleURI != nil { + for _, hay := range haystack { + if hayURI := parseOriginURI(hay); hayURI != nil { + if isOriginEqual(needleURI, hayURI) { + return true + } + } + } + } else { + for _, hay := range haystack { + if needle == hay { + return true + } + } + } + + return false +} + +func isOriginEqual(a *url.URL, b *url.URL) bool { + if !strings.EqualFold(a.Scheme, b.Scheme) { + return false + } + + if !strings.EqualFold(a.Host, b.Host) { + return false + } + + return true +} + +func parseOriginURI(raw string) *url.URL { + if !isPossibleFQDN(raw) { + return nil + } + + // We can ignore the error here because it's effectively not a FQDN if this fails. + uri, _ := url.Parse(raw) + + if uri == nil { + return nil + } + + // Normalize the port if necessary. + switch uri.Scheme { + case "http": + if uri.Port() == "80" { + uri.Host = uri.Hostname() + } + case "https": + if uri.Port() == "443" { + uri.Host = uri.Hostname() + } + } + + return uri +} + +func isPossibleFQDN(raw string) bool { + normalized := strings.ToLower(raw) + + return strings.HasPrefix(normalized, "http://") || strings.HasPrefix(normalized, "https://") +} diff --git a/protocol/client_test.go b/protocol/client_test.go index 5ab5049f..9f44ed15 100644 --- a/protocol/client_test.go +++ b/protocol/client_test.go @@ -1,17 +1,20 @@ package protocol import ( + "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func setupCollectedClientData(challenge URLEncodedBase64, origin string) *CollectedClientData { +func setupCollectedClientData(challenge URLEncodedBase64, origin, topOrigin string, crossOrigin bool) *CollectedClientData { ccd := &CollectedClientData{ - Type: CreateCeremony, - Origin: origin, - Challenge: challenge.String(), + Type: CreateCeremony, + Origin: origin, + TopOrigin: topOrigin, + CrossOrigin: crossOrigin, + Challenge: challenge.String(), } return ccd @@ -21,11 +24,77 @@ func TestVerifyCollectedClientData(t *testing.T) { newChallenge, err := CreateChallenge() require.NoError(t, err) - ccd := setupCollectedClientData(newChallenge, "http://example.com") + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example.com", true) var storedChallenge = newChallenge - require.NoError(t, ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.Origin}, nil, TopOriginIgnoreVerificationMode)) + require.NoError(t, ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.Origin}, []string{ccd.TopOrigin}, TopOriginExplicitVerificationMode)) +} + +func TestVerifyCollectedClientDataNoTopOrigin(t *testing.T) { + newChallenge, err := CreateChallenge() + require.NoError(t, err) + + ccd := setupCollectedClientData(newChallenge, "http://example.com", "", true) + + var storedChallenge = newChallenge + + require.NoError(t, ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.Origin}, []string{ccd.TopOrigin}, TopOriginExplicitVerificationMode)) +} + +func TestVerifyCollectedClientDataTopOrigin(t *testing.T) { + newChallenge, err := CreateChallenge() + require.NoError(t, err) + + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example2.com", true) + + var storedChallenge = newChallenge + + require.NoError(t, ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.Origin}, []string{ccd.TopOrigin}, TopOriginExplicitVerificationMode)) +} + +func TestVerifyCollectedClientDataTopOriginIgnore(t *testing.T) { + newChallenge, err := CreateChallenge() + require.NoError(t, err) + + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example2.com", true) + + var storedChallenge = newChallenge + + require.NoError(t, ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.Origin}, []string{"https://example3.com"}, TopOriginIgnoreVerificationMode)) +} + +func TestVerifyCollectedClientDataTopOriginImplicit(t *testing.T) { + newChallenge, err := CreateChallenge() + require.NoError(t, err) + + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example.com", true) + + var storedChallenge = newChallenge + + require.NoError(t, ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.Origin}, nil, TopOriginImplicitVerificationMode)) +} + +func TestVerifyCollectedClientDataTopOriginAuto(t *testing.T) { + newChallenge, err := CreateChallenge() + require.NoError(t, err) + + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example.com", true) + + var storedChallenge = newChallenge + + require.NoError(t, ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.Origin}, []string{"https://example.com"}, TopOriginAutoVerificationMode)) +} + +func TestVerifyCollectedClientDataTopOriginInvalidValue(t *testing.T) { + newChallenge, err := CreateChallenge() + require.NoError(t, err) + + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example.com", true) + + var storedChallenge = newChallenge + + AssertIsProtocolError(t, ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.Origin}, []string{"https://example.com"}, -1), "not_implemented", "Error handling unknown Top Origin verification mode", "") } func TestVerifyCollectedClientDataIncorrectChallenge(t *testing.T) { @@ -34,12 +103,12 @@ func TestVerifyCollectedClientDataIncorrectChallenge(t *testing.T) { t.Fatalf("error creating challenge: %s", err) } - ccd := setupCollectedClientData(newChallenge, "http://example.com") + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example.com", true) bogusChallenge, err := CreateChallenge() require.NoError(t, err) - assert.EqualError(t, ccd.Verify(bogusChallenge.String(), ccd.Type, []string{ccd.Origin}, nil, TopOriginIgnoreVerificationMode), "Error validating challenge") + AssertIsProtocolError(t, ccd.Verify(bogusChallenge.String(), ccd.Type, []string{ccd.Origin}, []string{ccd.TopOrigin}, TopOriginExplicitVerificationMode), "verification_error", "Error validating challenge", fmt.Sprintf("Expected b Value: \"%s\"\nReceived b: \"%s\"\n", bogusChallenge.String(), newChallenge.String())) } func TestVerifyCollectedClientDataUnexpectedOrigin(t *testing.T) { @@ -48,11 +117,36 @@ func TestVerifyCollectedClientDataUnexpectedOrigin(t *testing.T) { t.Fatalf("error creating challenge: %s", err) } - ccd := setupCollectedClientData(newChallenge, "http://example.com") + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example.com", true) + storedChallenge := newChallenge + expectedOrigins := []string{"http://different.com"} + + AssertIsProtocolError(t, ccd.Verify(storedChallenge.String(), ccd.Type, expectedOrigins, nil, TopOriginExplicitVerificationMode), "verification_error", "Error validating origin", "Expected Values: [http://different.com], Received: http://example.com") +} + +func TestVerifyCollectedClientDataUnexpectedTopOriginCrossOrigin(t *testing.T) { + newChallenge, err := CreateChallenge() + if err != nil { + t.Fatalf("error creating challenge: %s", err) + } + + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example2.com", false) + storedChallenge := newChallenge + + AssertIsProtocolError(t, ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.Origin}, []string{ccd.TopOrigin}, TopOriginExplicitVerificationMode), "verification_error", "Error validating topOrigin", "The topOrigin can't have values unless crossOrigin is true.") +} + +func TestVerifyCollectedClientDataUnexpectedTopOrigin(t *testing.T) { + newChallenge, err := CreateChallenge() + if err != nil { + t.Fatalf("error creating challenge: %s", err) + } + + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example.com", true) storedChallenge := newChallenge expectedOrigins := []string{"http://different.com"} - if err = ccd.Verify(storedChallenge.String(), ccd.Type, expectedOrigins, nil, TopOriginIgnoreVerificationMode); err == nil { + if err = ccd.Verify(storedChallenge.String(), ccd.Type, []string{ccd.TopOrigin}, expectedOrigins, TopOriginExplicitVerificationMode); err == nil { t.Fatalf("error expected but not received. expected %#v got %#v", expectedOrigins, ccd.Origin) } } @@ -63,7 +157,7 @@ func TestVerifyCollectedClientDataWithMultipleExpectedOrigins(t *testing.T) { t.Fatalf("error creating challenge: %s", err) } - ccd := setupCollectedClientData(newChallenge, "http://example.com") + ccd := setupCollectedClientData(newChallenge, "http://example.com", "http://example.com", true) var storedChallenge = newChallenge @@ -106,3 +200,157 @@ func TestFullyQualifiedOrigin(t *testing.T) { }) } } + +func TestIsOriginInHaystack(t *testing.T) { + testCases := []struct { + name string + origin string + haystack []string + expected bool + }{ + { + "ShouldHandleFullyQualifiedOrigin", + "https://app.example.com", + []string{"https://app.example.com"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginCaseInsensitiveScheme", + "https://app.example.com", + []string{"HTTPS://app.example.com"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginCaseInsensitiveHost", + "https://app.EXAMPLE.com", + []string{"https://app.example.com"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginWithPort", + "https://app.example.com:443", + []string{"https://app.example.com:443"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginDifferentScheme", + "http://app.example.com", + []string{"https://app.example.com"}, + false, + }, + { + "ShouldHandleFullyQualifiedOriginDifferentPort", + "https://app.example.com:443", + []string{"https://app.example.com"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginDifferentPortNotMatchingScheme", + "https://app.example.com:80", + []string{"https://app.example.com"}, + false, + }, + { + "ShouldHandleFullyQualifiedOriginDifferentPath", + "https://app.example.com/abc", + []string{"https://app.example.com"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginDifferentQuery", + "https://app.example.com/?abc=123", + []string{"https://app.example.com"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginDifferentQueryCount", + "https://app.example.com/?abc=123", + []string{"https://app.example.com/?zyz=123&abc=123"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginDifferentQueryOrder", + "https://app.example.com/?abc=123&xyz=123", + []string{"https://app.example.com/?xyz=123&abc=123"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginDifferentQueryValue", + "https://app.example.com/?abc=123&xyz=123", + []string{"https://app.example.com/?xyz=1234&abc=123"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginFragment", + "https://app.example.com/#abc", + []string{"https://app.example.com/#abc"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginFragmentDifferent", + "https://app.example.com/#abc", + []string{"https://app.example.com/#abc2"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginWithoutAllowed", + "https://app.example.com", + nil, + false, + }, + { + "ShouldHandleFullyQualifiedOriginWithTrailingSlashes", + "https://app.example.com/", + []string{"https://app.example.com"}, + true, + }, + { + "ShouldHandleNativeAppAndroid", + "android:apk-key-hash:7d1043473d55bfa90e8530d35801d4e381bc69f0", + []string{"android:apk-key-hash:7d1043473d55bfa90e8530d35801d4e381bc69f0"}, + true, + }, + { + "ShouldHandleNativeAppAndroidCaseSensitive", + "android:apk-key-hash:7d1043473d55bfa90e8530d35801d4e381bc69F0", + []string{"android:apk-key-hash:7d1043473d55bfa90e8530d35801d4e381bc69f0"}, + false, + }, + { + "ShouldHandleNonFQDNOrigin", + "https://user:password@app.example.com/", + []string{"https://app.example.com/"}, + true, + }, + { + "ShouldHandleNonFQDNOriginExactStringMatch", + "https://user:password@app.example.com/", + []string{"https://user:password@app.example.com/"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginDefaultPortEquivalentHTTPS", + "https://app.example.com:443", + []string{"https://app.example.com"}, + true, + }, + { + "ShouldHandleFullyQualifiedOriginDefaultPortEquivalentHTTP", + "http://app.example.com:80", + []string{"http://app.example.com"}, + true, + }, + { + "ShouldHandleInvalidURLAsSimpleStringMatch", + "http://app.example.%%%&123?1", + []string{"http://app.example.%%%&123?1"}, + true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expected, IsOriginInHaystack(tc.origin, tc.haystack)) + }) + } +} diff --git a/protocol/func_test.go b/protocol/func_test.go index 933b4ad5..a887265e 100644 --- a/protocol/func_test.go +++ b/protocol/func_test.go @@ -2,18 +2,42 @@ package protocol import ( "errors" + "regexp" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func AssertIsProtocolError(t *testing.T, err error, errType, errDetails, errInfo string) { +func AssertIsProtocolError(t *testing.T, err error, errType, errDetails, errInfo any) { var e *Error require.True(t, errors.As(err, &e)) - assert.Equal(t, errType, e.Type) - assert.Equal(t, errDetails, e.Details) - assert.Equal(t, errInfo, e.DevInfo) + switch et := errType.(type) { + case string: + assert.Equal(t, et, e.Type) + case *regexp.Regexp: + assert.Regexp(t, et, e.Type) + default: + t.Fatalf("%T is not a known type", errType) + } + + switch ed := errDetails.(type) { + case string: + assert.Equal(t, ed, e.Details) + case *regexp.Regexp: + assert.Regexp(t, ed, e.Details) + default: + t.Fatalf("%T is not a known type", errDetails) + } + + switch ed := errInfo.(type) { + case string: + assert.Equal(t, ed, e.DevInfo) + case *regexp.Regexp: + assert.Regexp(t, ed, e.DevInfo) + default: + t.Fatalf("%T is not a known type", errInfo) + } } diff --git a/webauthn/types.go b/webauthn/types.go index 5af251da..91777db6 100644 --- a/webauthn/types.go +++ b/webauthn/types.go @@ -33,12 +33,18 @@ type Config struct { // RPDisplayName configures the display name for the Relying Party Server. This can be any string. RPDisplayName string - // RPOrigins configures the list of Relying Party Server Origins that are permitted. These should be fully - // qualified origins. + // RPOrigins configures the list of Relying Party Server Origins that are permitted. The provided origins can either + // be fully qualified origins or strings for simple string comparison. The strings are matched using canonical + // origin matching semantics specifically if they start with 'http://' or 'https://' if the provided origin has a + // case-insensitive equal scheme and host component they are equal, otherwise simple string comparison is utilized + // to determine equality. RPOrigins []string - // RPTopOrigins configures the list of Relying Party Server Top Origins that are permitted. These should be fully - // qualified origins. + // RPTopOrigins configures the list of Relying Party Server Top Origins that are permitted. The provided origins can + // either be fully qualified origins or strings for simple string comparison. The strings are matched using + // canonical origin matching semantics specifically if they start with 'http://' or 'https://' if the provided + // origin has a case-insensitive equal scheme and host component they are equal, otherwise simple string comparison + // is utilized to determine equality. RPTopOrigins []string // RPTopOriginVerificationMode determines the verification mode for the Top Origin value. By default the @@ -90,13 +96,11 @@ type TimeoutConfig struct { } // Validate that the config flags in Config are properly set -func (config *Config) validate() error { +func (config *Config) validate() (err error) { if config.validated { return nil } - var err error - if len(config.RPID) != 0 { if _, err = url.Parse(config.RPID); err != nil { return fmt.Errorf(errFmtFieldNotValidURI, "RPID", err) @@ -129,9 +133,9 @@ func (config *Config) validate() error { switch config.RPTopOriginVerificationMode { case protocol.TopOriginDefaultVerificationMode: config.RPTopOriginVerificationMode = protocol.TopOriginIgnoreVerificationMode - case protocol.TopOriginImplicitVerificationMode: + case protocol.TopOriginExplicitVerificationMode: if len(config.RPTopOrigins) == 0 { - return fmt.Errorf("must provide at least one value to the 'RPTopOrigins' field when 'RPTopOriginVerificationMode' field is set to protocol.TopOriginImplicitVerificationMode") + return fmt.Errorf("must provide at least one value to the 'RPTopOrigins' field when 'RPTopOriginVerificationMode' field is set to protocol.TopOriginExplicitVerificationMode") } } diff --git a/webauthn/types_test.go b/webauthn/types_test.go index b154bd93..095015c3 100644 --- a/webauthn/types_test.go +++ b/webauthn/types_test.go @@ -1,5 +1,13 @@ package webauthn +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/go-webauthn/webauthn/protocol" +) + type defaultUser struct { id []byte credentials []Credential @@ -22,3 +30,60 @@ func (user *defaultUser) WebAuthnDisplayName() string { func (user *defaultUser) WebAuthnCredentials() []Credential { return user.credentials } + +func TestNew(t *testing.T) { + testCases := []struct { + description string + config *Config + err string + }{ + { + "ShouldPassMinimalConfig", + &Config{ + RPID: "https://example.com/", + RPOrigins: []string{"https://example.com"}, + }, + "", + }, + { + "ShouldFailBadRPID", + &Config{ + RPID: "%%&&", + RPOrigins: []string{"https://example.com"}, + }, + `error occurred validating the configuration: field 'RPID' is not a valid URI: parse "%%&&": invalid URL escape "%%&"`, + }, + { + "ShouldFailNoRPOrigins", + &Config{ + RPID: "https://example.com/", + }, + "error occurred validating the configuration: must provide at least one value to the 'RPOrigins' field", + }, + { + "ShouldFailBadTopOrigins", + &Config{ + RPID: "https://example.com/", + RPOrigins: []string{"https://example.com"}, + RPTopOriginVerificationMode: protocol.TopOriginExplicitVerificationMode, + }, + "error occurred validating the configuration: must provide at least one value to the 'RPTopOrigins' field when 'RPTopOriginVerificationMode' field is set to protocol.TopOriginExplicitVerificationMode", + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + w, err := New(tc.config) + + if tc.err == "" { + assert.NotNil(t, w) + assert.NoError(t, err) + assert.NoError(t, tc.config.validate()) + } else { + assert.Nil(t, w) + assert.EqualError(t, err, tc.err) + assert.Error(t, tc.config.validate()) + } + }) + } +}