diff --git a/custom_components/auth_oidc/__init__.py b/custom_components/auth_oidc/__init__.py index af14a9e..de87b87 100644 --- a/custom_components/auth_oidc/__init__.py +++ b/custom_components/auth_oidc/__init__.py @@ -30,6 +30,8 @@ FEATURES_DISABLE_FRONTEND_INJECTION, FEATURES_FORCE_HTTPS, REQUIRED_SCOPES, + VERBOSE_DEBUG_MODE, + NETWORK_USERINFO_FALLBACK, ) from .config import convert_ui_config_entry_to_internal_format @@ -134,6 +136,8 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam claims=my_config.get(CLAIMS, {}), roles=my_config.get(ROLES, {}), network=my_config.get(NETWORK, {}), + enable_verbose_debug_mode=my_config.get(VERBOSE_DEBUG_MODE, False), + userinfo_fallback=my_config.get(NETWORK_USERINFO_FALLBACK, False), ) # Register the views diff --git a/custom_components/auth_oidc/config/const.py b/custom_components/auth_oidc/config/const.py index 8538262..334f882 100644 --- a/custom_components/auth_oidc/config/const.py +++ b/custom_components/auth_oidc/config/const.py @@ -40,12 +40,14 @@ NETWORK = "network" NETWORK_TLS_VERIFY = "tls_verify" NETWORK_TLS_CA_PATH = "tls_ca_path" +NETWORK_USERINFO_FALLBACK = "userinfo_fallback" +VERBOSE_DEBUG_MODE = "enable_verbose_debug_mode" ## === ## Default configurations for providers ## === -REQUIRED_SCOPES = "openid profile" +REQUIRED_SCOPES = "openid profile email" DEFAULT_ID_TOKEN_SIGNING_ALGORITHM = "RS256" DEFAULT_GROUPS_SCOPE = "groups" diff --git a/custom_components/auth_oidc/config/schema.py b/custom_components/auth_oidc/config/schema.py index 2167503..3a249c1 100644 --- a/custom_components/auth_oidc/config/schema.py +++ b/custom_components/auth_oidc/config/schema.py @@ -26,8 +26,10 @@ NETWORK, NETWORK_TLS_VERIFY, NETWORK_TLS_CA_PATH, + NETWORK_USERINFO_FALLBACK, DOMAIN, DEFAULT_GROUPS_SCOPE, + VERBOSE_DEBUG_MODE, ) CONFIG_SCHEMA = vol.Schema( @@ -53,6 +55,11 @@ # Additional scopes to request from the OIDC provider # Optional, this field is unnecessary if you only use the openid and profile scopes. vol.Optional(ADDITIONAL_SCOPES, default=[]): vol.Coerce(list[str]), + # Added for debugging purposes + # If enabled, logging will include more detailed information regarding + # the full OIDC auth chain (including tokens) and is captured within: + # /custom_components/auth_oidc/verbose_debug/ + vol.Optional(VERBOSE_DEBUG_MODE, default=False): vol.Coerce(bool), # Which features should be enabled/disabled? # Optional, defaults to sane/secure defaults vol.Optional(FEATURES): vol.Schema( @@ -115,6 +122,11 @@ ), # Load custom certificate chain for private CAs vol.Optional(NETWORK_TLS_CA_PATH): vol.Coerce(str), + # Constructed Userinfo endpoint fallback if not provided in discovery + # Some OPs omit this endpoint + vol.Optional( + NETWORK_USERINFO_FALLBACK, default=False + ): vol.Coerce(bool), } ), } diff --git a/custom_components/auth_oidc/endpoints/injected_auth_page.py b/custom_components/auth_oidc/endpoints/injected_auth_page.py index b38eab1..406aeed 100644 --- a/custom_components/auth_oidc/endpoints/injected_auth_page.py +++ b/custom_components/auth_oidc/endpoints/injected_auth_page.py @@ -79,7 +79,7 @@ async def frontend_injection(hass: HomeAssistant, sso_name: str) -> None: "/auth/oidc/static/style.css", hass.config.path("custom_components/auth_oidc/static/style.css"), cache_headers=False, - ) + ), ] ) diff --git a/custom_components/auth_oidc/stores/code_store.py b/custom_components/auth_oidc/stores/code_store.py index d3aedab..c532927 100644 --- a/custom_components/auth_oidc/stores/code_store.py +++ b/custom_components/auth_oidc/stores/code_store.py @@ -1,7 +1,6 @@ """Code Store, stores the codes and their associated authenticated user temporarily.""" -import random -import string +import secrets from datetime import datetime, timedelta, timezone from typing import cast, Optional @@ -37,8 +36,8 @@ async def _async_save(self) -> None: await self._store.async_save(self._data) def _generate_code(self) -> str: - """Generate a random six-digit code.""" - return "".join(random.choices(string.digits, k=6)) + """Generate a secure URL-safe code for temporary handoff.""" + return secrets.token_urlsafe(16) async def async_generate_code_for_userinfo(self, user_info: UserDetails) -> str: """Generates a one time code and adds it to the database for 5 minutes.""" diff --git a/custom_components/auth_oidc/tools/helpers.py b/custom_components/auth_oidc/tools/helpers.py index ec984dd..93cf21b 100644 --- a/custom_components/auth_oidc/tools/helpers.py +++ b/custom_components/auth_oidc/tools/helpers.py @@ -1,6 +1,12 @@ """Helper functions for the integration.""" +import logging +from pathlib import Path +from typing import Optional + +import aiofiles from homeassistant.components import http + from ..views.loader import AsyncTemplateRenderer @@ -22,3 +28,107 @@ async def get_view(template: str, parameters: dict | None = None) -> str: renderer = AsyncTemplateRenderer() return await renderer.render_template(f"{template}.html", **parameters) + + +def compute_allowed_signing_algs( + discovery: dict, + id_token_signing_alg: Optional[str], + verbose_debug_mode: bool, + logger: logging.Logger, +) -> list[str]: + """Compute allowed ID token signing algorithms from config and OP discovery document. + + - If `id_token_signing_alg` set: Use only it (warn if not in OP-supported). + - Else: Use OP's `id_token_signing_alg_values_supported` (fallback ['RS256']). + + Args: + discovery: Fetched OIDC discovery document. + id_token_signing_alg: Configured alg from + self.id_token_signing_alg (or None; falls + back to DEFAULT_ID_TOKEN_SIGNING_ALGORITHM="RS256"). + verbose_debug_mode: Enable debug logs. + logger: Logger instance (e.g., _LOGGER). + + Returns: + List of allowed algs (e.g., ['RS256', 'ES256']). + """ + supported_algs = discovery.get("id_token_signing_alg_values_supported", []) + + if id_token_signing_alg: + allowed_algs = [id_token_signing_alg] + if id_token_signing_alg not in supported_algs: + logger.warning( + ( + "Configured signing algorithm '%s' is not in OP" + " supported algorithms: %s. Proceeding anyway." + ), + id_token_signing_alg, + supported_algs, + ) + else: + allowed_algs = supported_algs or ["RS256"] + if not supported_algs: + logger.info( + ( + "No signing algorithms supported from OP" + " discovery document! Will default to RS256" + ) + ) + + if verbose_debug_mode: + logger.debug("Allowed ID token signing algorithms: %s", allowed_algs) + + return allowed_algs + + +async def capture_auth_flows( + log_info: tuple[logging.Logger, int], + verbose_debug_mode: bool, + capture_dir: Path | None, + debug_msg: str, + filename: str, + content: str, + mode: str = "a", + header: str = "", + is_request: bool = False, +) -> None: + """Helper to log verbose debug messages and optionally capture content to file. + + Reduces repetition in OIDCClient/OIDCDiscoveryClient verbose logging and file captures. + Only writes/captures if verbose_debug_mode is True and capture_dir exists. + + Args: + log_info: Tuple containing logger instance (e.g., (_LOGGER, 10) is Debug level). + verbose_debug_mode: Whether verbose mode is enabled. + capture_dir: Directory path for captures (if None, skips file write). + debug_msg: Message for _LOGGER.debug(). + filename: Base filename for capture file (e.g., 'get_discovery.txt'). + content: Content to write (e.g., JSON string or URL). + mode: File write mode ('w' to overwrite, 'a' to append). + header: Prepend header comment to content (e.g., discovery endpoint info). + is_request: If True, uses 'BEGIN REQUEST' header; else 'BEGIN RESPONSE'. + """ + + # Unpack logger and log level + logger, log_level = log_info + + if verbose_debug_mode: + logger.log(log_level, debug_msg) + + if verbose_debug_mode and capture_dir: + header_str = ( + f"/*\n----------BEGIN {'REQUEST' if is_request else 'RESPONSE'}----------\n" + f"{header}*/\n\n" + if header + else "" + ) + full_content = header_str + content + file_path = capture_dir / filename + async with aiofiles.open(file_path, mode=mode, encoding="utf-8") as f: + await f.write(full_content) + logger.log( + log_level, + "Check %s capture in: %s for more details...", + filename, + file_path, + ) diff --git a/custom_components/auth_oidc/tools/oidc_client.py b/custom_components/auth_oidc/tools/oidc_client.py index 1d3413a..e41b526 100644 --- a/custom_components/auth_oidc/tools/oidc_client.py +++ b/custom_components/auth_oidc/tools/oidc_client.py @@ -1,29 +1,35 @@ """OIDC Client class""" -import urllib.parse -import logging -import os import base64 import hashlib +import json +import logging +import os import ssl -from typing import Optional +import urllib.parse from functools import partial +from pathlib import Path +from typing import Optional + import aiohttp -from joserfc import jwt, jwk, jws, errors as joserfc_errors from homeassistant.core import HomeAssistant +from joserfc import errors as joserfc_errors +from joserfc import jwk, jws, jwt -from .types import UserDetails from ..config.const import ( - FEATURES_DISABLE_PKCE, CLAIMS_DISPLAY_NAME, - CLAIMS_USERNAME, CLAIMS_GROUPS, + CLAIMS_USERNAME, + DEFAULT_ID_TOKEN_SIGNING_ALGORITHM, + FEATURES_DISABLE_PKCE, + NETWORK_TLS_CA_PATH, + NETWORK_TLS_VERIFY, + NETWORK_USERINFO_FALLBACK, ROLE_ADMINS, ROLE_USERS, - NETWORK_TLS_VERIFY, - NETWORK_TLS_CA_PATH, - DEFAULT_ID_TOKEN_SIGNING_ALGORITHM, ) +from .helpers import capture_auth_flows, compute_allowed_signing_algs +from .types import UserDetails from .validation import validate_url _LOGGER = logging.getLogger(__name__) @@ -79,6 +85,10 @@ class OIDCIdTokenSigningAlgorithmInvalid(OIDCTokenResponseInvalid): "Raised when the id_token is signed with the wrong algorithm, adjust your config accordingly." +class OIDCIdTokenInvalid(OIDCClientException): + """Raised when the ID token is invalid, unverifiable, or claims validation fails.""" + + class HTTPClientError(aiohttp.ClientResponseError): "Raised when the HTTP client encounters not OK (200) status code." @@ -117,16 +127,53 @@ def __init__( discovery_url: str, http_session: aiohttp.ClientSession, verification_context: dict, + verbose_debug_mode: bool = False, + capture_dir: Optional[Path] = None, ): self.discovery_url = discovery_url self.http_session = http_session self.verification_context = verification_context + self.verbose_debug_mode = verbose_debug_mode + self.capture_dir = capture_dir async def _fetch_discovery_document(self): """Fetches discovery document from the given URL.""" + # Pass verbose context from OIDCClient (additive) + verbose_mode = getattr(self, "verbose_debug_mode", False) + capture_dir = getattr(self, "capture_dir", None) + try: + await capture_auth_flows( + (_LOGGER, 10), # logger.DEBUG is 10 + verbose_mode, + capture_dir, + f"Attempting to fetch discovery document from: {self.discovery_url}", + "get_discovery.txt", + f"Discovery Endpoint URL: {self.discovery_url}", + mode="w", + header="", + is_request=True, + ) + async with self.http_session.get(self.discovery_url) as response: await http_raise_for_status(response) + response_text = await response.text() + + await capture_auth_flows( + (_LOGGER, 10), + verbose_mode, + capture_dir, + f"Discovery response received: Status {response.status}", + "get_discovery.txt", + ( + f"Fetch Discovery Doc Response Status: {response.status}" + f"\n//Response Body:\n{response_text}" + ), + mode="a", + header="", + is_request=False, + ) + return await response.json() except HTTPClientError as e: if e.status == 404: @@ -139,9 +186,42 @@ async def _fetch_discovery_document(self): async def _fetch_jwks(self, jwks_uri): """Fetches JWKS from the given URL.""" + # Pass verbose context (additive) + verbose_mode = getattr(self, "verbose_debug_mode", False) + capture_dir = getattr(self, "capture_dir", None) + try: + await capture_auth_flows( + (_LOGGER, 10), + verbose_mode, + capture_dir, + f"Retrieving JWKS keys from endpoint: {jwks_uri}", + "get_jwks.txt", + f"JWKS Endpoint URL: {jwks_uri}", + mode="w", + header="", + is_request=True, + ) + async with self.http_session.get(jwks_uri) as response: await http_raise_for_status(response) + response_text = await response.text() + + await capture_auth_flows( + (_LOGGER, 10), + verbose_mode, + capture_dir, + f"JWKS response received: Status {response.status}", + "get_jwks.txt", + ( + f"Fetch JWKS Keys Status: {response.status}" + f"\n//Response Body:\n{response_text}" + ), + mode="a", + header="", + is_request=False, + ) + return await response.json() except HTTPClientError as e: _LOGGER.warning("Error fetching JWKS: %s", e) @@ -181,6 +261,28 @@ async def _validate_discovery_document(self, document): details={"endpoint": endpoint, "url": document[endpoint]}, ) + # OpenID Connect Discovery 1.0 §2.1 & Core 1.0 §3.1.3.7.2: Explicitly validate + # that the 'issuer' from discovery document exactly matches the discovery URL + # (normalized: scheme/host only, lowercase scheme, no path/query/fragment). + # Prevents issuer mismatch attacks or misconfigs. + def normalize_issuer(issuer_url: str) -> str: + """Normalize issuer URL per OIDC §8.1 (scheme/host only, lowercase scheme).""" + parsed = urllib.parse.urlparse(issuer_url.rstrip("/")) + return f"{parsed.scheme.lower()}://{parsed.netloc.lower()}" + + expected_issuer = normalize_issuer(self.discovery_url) + actual_issuer = normalize_issuer(document["issuer"]) + if expected_issuer != actual_issuer: + _LOGGER.warning( + "Error: Discovery issuer mismatch. Expected (normalized): %s, got: %s", + expected_issuer, + actual_issuer, + ) + raise OIDCDiscoveryInvalid( + type="issuer_mismatch", + details={"expected": expected_issuer, "actual": actual_issuer}, + ) + # Verify optional response_modes_supported if "response_modes_supported" in document: if "query" not in document["response_modes_supported"]: @@ -256,20 +358,26 @@ async def _validate_discovery_document(self, document): ) raise OIDCDiscoveryInvalid(type="missing_id_token_signing_alg_values") - # Verify that the requested id_token_signing_alg is supported + # Verify that the requested id_token_signing_alg is supported (WARN only, flexible) requested_alg = self.verification_context.get("id_token_signing_alg", None) - if requested_alg is not None and requested_alg not in signing_values: + signing_values = document.get("id_token_signing_alg_values_supported", None) + if signing_values is None: _LOGGER.warning( - "Error: Discovery document %s does not support requested " - "id_token_signing_alg '%s', only supports: %s", + "Error: Discovery document %s does not have " + "'id_token_signing_alg_values_supported' field", + self.discovery_url, + ) + raise OIDCDiscoveryInvalid(type="missing_id_token_signing_alg_values") + + if requested_alg is not None and requested_alg not in signing_values: + _LOGGER.warning( # WARN, not raise (flexible via compute_allowed_signing_algs) + "Discovery document %s does not support requested " + "id_token_signing_alg '%s', only supports: %s. Proceeding anyway.", self.discovery_url, requested_alg, signing_values, ) - raise OIDCDiscoveryInvalid( - type="does_not_support_id_token_signing_alg", - details={"requested": requested_alg, "supported": signing_values}, - ) + # raise ... # REMOVED: Now handled flexibly in _parse_id_token async def fetch_discovery_document(self): """Fetches discovery document.""" @@ -289,9 +397,6 @@ async def fetch_jwks(self, jwks_uri: str | None = None): class OIDCClient: """OIDC Client implementation for Python, including PKCE.""" - # Flows stores the state, code_verifier and nonce of all current flows. - flows = {} - # HTTP session to be used http_session: aiohttp.ClientSession = None @@ -309,6 +414,10 @@ def __init__( self.hass = hass self.discovery_url = discovery_url self.discovery_document = None + # Instance-level discovery caching with TTL (1h) for efficiency/freshness + # Prevents stale data on OP endpoint/JWKS rotations while minimizing fetches. + self.discovery_timestamp = None + self.discovery_ttl = 3600 # 1 hour self.client_id = client_id self.scope = scope @@ -316,9 +425,9 @@ def __init__( self.client_secret = kwargs.get("client_secret") # Default id_token_signing_alg to RS256 if not specified - self.id_token_signing_alg = kwargs.get("id_token_signing_alg") - if self.id_token_signing_alg is None: - self.id_token_signing_alg = DEFAULT_ID_TOKEN_SIGNING_ALGORITHM + self.id_token_signing_alg = kwargs.get( + "id_token_signing_alg", DEFAULT_ID_TOKEN_SIGNING_ALGORITHM + ) features = kwargs.get("features") claims = kwargs.get("claims") @@ -333,6 +442,34 @@ def __init__( self.admin_role = roles.get(ROLE_ADMINS, "admins") self.tls_verify = network.get(NETWORK_TLS_VERIFY, True) self.tls_ca_path = network.get(NETWORK_TLS_CA_PATH) + self.userinfo_fallback = network.get(NETWORK_USERINFO_FALLBACK, False) + + self.verbose_debug_mode = kwargs.get("enable_verbose_debug_mode", False) + if self.verbose_debug_mode: + _LOGGER.warning( + "VERBOSE_DEBUG_MODE is enabled so detailed token request and response " + + "logging is active. Do NOT leave this enabled in production!" + ) + self.capture_dir = ( + Path(self.hass.config.config_dir) + / "custom_components" + / "auth_oidc" + / "verbose_debug" + ) + self.capture_dir.mkdir(parents=True, exist_ok=True) + _LOGGER.info( + "The following scopes will be included in auth request: %s", self.scope + ) + if self.verbose_debug_mode: + _LOGGER.debug( + "Configured ID token signing algorithm: %s", + self.id_token_signing_alg or "none (will use OP discovery)", + ) + + # Flows stores the state, code_verifier and nonce of all current flows. + # Made instance-level to prevent collisions across multiple OIDCClient instances + # (e.g., multiple providers). Previously class-level caused state sharing/leaks. + self.flows = {} def __del__(self): """Cleanup the HTTP session.""" @@ -342,8 +479,8 @@ def __del__(self): _LOGGER.debug("Closing HTTP session") self.http_session.close() - def _base64url_encode(self, value: str) -> str: - """Uses base64url encoding on a given string""" + def _base64url_encode(self, value: bytes) -> str: + """Uses base64url encoding on a given byte string""" return base64.urlsafe_b64encode(value).rstrip(b"=").decode("utf-8") def _generate_random_url_string(self, length: int = 16) -> str: @@ -381,9 +518,61 @@ async def _make_token_request(self, token_endpoint, query_params): try: session = await self._get_http_session() + await capture_auth_flows( + (_LOGGER, 10), + self.verbose_debug_mode, + self.capture_dir, + f"Attempting Token request via Endpoint URL: {token_endpoint}", + "get_token.txt", + ( + f"Token Endpoint URL: {token_endpoint}\n//Query Parameters:" + f"\n{json.dumps(query_params, indent=2)}" + ), + mode="w", + header="", + is_request=True, + ) + async with session.post(token_endpoint, data=query_params) as response: await http_raise_for_status(response) - return await response.json() + response_text = await response.text() + + await capture_auth_flows( + (_LOGGER, 10), + self.verbose_debug_mode, + self.capture_dir, + f"Token response received: Status {response.status}", + "get_token.txt", + ( + f"Fetch Token Status: {response.status}" + f"\n//Response Body:\n{response_text}" + ), + mode="a", + header="", + is_request=False, + ) + + try: + parsed_json = json.loads(response_text) + if self.verbose_debug_mode: + _LOGGER.debug( + "Success! Token received from Endpoint: %s", token_endpoint + ) + return parsed_json + except json.JSONDecodeError as e: + await capture_auth_flows( + (_LOGGER, 10), + self.verbose_debug_mode, + self.capture_dir, + "Unhandled token response (not JSON)", + "unhandled_token_response.txt", + response_text, + mode="w", + header="", + is_request=False, + ) + _LOGGER.error("Unhandled Exception: Token Response is not json!") + raise OIDCTokenResponseInvalid("Token response not JSON") from e except HTTPClientError as e: if e.status == 400: _LOGGER.warning( @@ -404,16 +593,56 @@ async def _get_userinfo(self, userinfo_uri, access_token): session = await self._get_http_session() headers = {"Authorization": "Bearer " + access_token} + await capture_auth_flows( + (_LOGGER, 10), + self.verbose_debug_mode, + self.capture_dir, + f"Sending request to: {userinfo_uri} to collect Userinfo", + "get_userinfo.txt", + ( + f"Userinfo URL: {userinfo_uri}\n//Request Headers:" + f"\n{json.dumps(headers, indent=2)}", + ), + mode="w", + header="", + is_request=True, + ) + async with session.get(userinfo_uri, headers=headers) as response: await http_raise_for_status(response) - return await response.json() + response_text = await response.text() + + await capture_auth_flows( + (_LOGGER, 10), + self.verbose_debug_mode, + self.capture_dir, + f"Userinfo response received: Status {response.status}", + "get_userinfo.txt", + ( + f"Userinfo Response Status: {response.status}" + f"\n//Response Body:\n{response_text}" + ), + mode="a", + header="", + is_request=False, + ) + + return json.loads(response_text) except HTTPClientError as e: _LOGGER.warning("Error fetching userinfo: %s", e) raise OIDCUserinfoInvalid from e async def _fetch_discovery_document(self): - """Fetches discovery document.""" - if self.discovery_document is not None: + """Fetches discovery document if missing or expired (TTL=1h).""" + # Local import for TTL check + import time # pylint: disable=import-outside-toplevel + + now = time.time() + if ( + self.discovery_document is not None + and self.discovery_timestamp is not None + and (now - self.discovery_timestamp) < self.discovery_ttl + ): return self.discovery_document if self.discovery_class is None: @@ -424,25 +653,38 @@ async def _fetch_discovery_document(self): verification_context={ "id_token_signing_alg": self.id_token_signing_alg, }, + verbose_debug_mode=self.verbose_debug_mode, + capture_dir=self.capture_dir, ) self.discovery_document = await self.discovery_class.fetch_discovery_document() + self.discovery_timestamp = now return self.discovery_document async def _fetch_jwks(self, jwks_uri: str): """Fetches JWKS.""" return await self.discovery_class.fetch_jwks(jwks_uri) - async def _parse_id_token(self, id_token: str) -> Optional[dict]: + async def _parse_id_token( + self, id_token: str, access_token: Optional[str] = None + ) -> Optional[dict]: """Parses the ID token into a dict containing token contents.""" if self.discovery_document is None: self.discovery_document = await self._fetch_discovery_document() + # Flexible algorithm handling + allowed_algs = compute_allowed_signing_algs( + self.discovery_document, + self.id_token_signing_alg, + self.verbose_debug_mode, + _LOGGER, + ) + jwks_uri = self.discovery_document["jwks_uri"] jwks_data = await self._fetch_jwks(jwks_uri) try: - # Obtain the id_token header + # Obtain the (unverified) id_token header token_obj = jws.extract_compact(id_token.encode()) unverified_header = token_obj.protected if not unverified_header: @@ -451,14 +693,20 @@ async def _parse_id_token(self, id_token: str) -> Optional[dict]: # Obtain the signing algorithm from the header of the id_token alg = unverified_header.get("alg") - if alg != self.id_token_signing_alg: - # Verify that it matches our requested algorithm + if not alg: + _LOGGER.warning("JWT does not have alg") + return None + + if alg not in allowed_algs: _LOGGER.warning( - "ID Token received signed with the wrong algorithm: %s, expected %s", + "ID Token received signed with unsupported algorithm: %s (allowed: %s)", alg, - self.id_token_signing_alg, + allowed_algs, ) - raise OIDCIdTokenSigningAlgorithmInvalid() + raise OIDCIdTokenSigningAlgorithmInvalid + + if self.verbose_debug_mode: + _LOGGER.debug("ID token signed with algorithm '%s'", alg) # OpenID Connect Core 1.0 Section 3.1.3.7.8 # If the JWT alg Header Parameter uses a MAC based algorithm @@ -476,50 +724,139 @@ async def _parse_id_token(self, id_token: str) -> Optional[dict]: jwk_obj = jwk.import_key( { "kty": "oct", - "k": base64.urlsafe_b64encode( - self.client_secret.encode() - ).decode(), + "k": self._base64url_encode( + self.client_secret.encode("utf-8") + ), # RFC 7517 §4.2: base64url without padding "alg": alg, } ) + else: - # TODO: Deal with cases where kid is not specified (just take the first key?) - # Obtain the kid (Key ID) from the header of the id_token + # RFC 7515 (JWS) §4.1.11: "kid" (Key ID) is OPTIONAL but RECOMMENDED. + # If absent, select key via other means + # (e.g., try candidates until verification succeeds). + # Priority: 1. Exact "kid" match. 2. Matching key["alg"]. 3. All keys. + # OpenID Connect Core 1.0 §3.1.3.7: MUST validate signature using header "alg". + # RFC 7518 (JWK) §7.2: Inherit "alg" from header if missing in key. kid = unverified_header.get("kid") if not kid: - _LOGGER.warning("JWT does not have kid (Key ID)") - return None - - # Get the correct key - signing_key = None + if self.verbose_debug_mode: + _LOGGER.debug( + "JWT header lacks 'kid'; will try all JWKS candidates" + ) + else: + _LOGGER.warning( + "JWT does not have 'kid' (Key ID); trying all JWKS keys" + " (add 'kid' to provider config for efficiency)" + ) + + # Collect candidate keys from JWKS (jwks_data["keys"] is list of dicts) + candidates = [] + if kid: + # Priority 1: Exact kid match + matching_kid = next( + (key for key in jwks_data["keys"] if key.get("kid") == kid), + None, + ) + if matching_kid: + candidates.append(matching_kid) + if self.verbose_debug_mode: + _LOGGER.debug( + "Selected JWKS key by exact 'kid' match: %s", kid + ) + + # Priority 2-3: No kid or no match → add keys matching alg, then all (avoid dupes) for key in jwks_data["keys"]: - if key["kid"] == kid: - signing_key = key - break + if key.get("alg") == alg: + if key not in candidates: # Avoid dupes + candidates.append(key) + if self.verbose_debug_mode: + _LOGGER.debug( + "Added JWKS candidate by 'alg' match: %s (kid=%s)", + alg, + key.get("kid", "none"), + ) + elif ( + kid is None or key.get("kid") != kid + ) and key not in candidates: # Fallback: all non-dupe keys + candidates.append(key) + if self.verbose_debug_mode: + _LOGGER.debug( + "Added JWKS fallback candidate (kid=%s, alg=%s)", + key.get("kid", "none"), + key.get("alg", "none"), + ) + + if not candidates: + _LOGGER.warning( + "No candidate keys found in JWKS for alg '%s' (kid='%s')", + alg, + kid or "none", + ) + return None - if not signing_key: - _LOGGER.warning("Could not find matching key with kid: %s", kid) + # Try verification on each candidate until success (RFC 7515 compliant) + decoded_token = None + selected_key_info = None + for candidate_key in candidates: + try: + # If key lacks "alg", inherit from header (per JWK §7.2, optional) + key_dict = candidate_key.copy() + if "alg" not in key_dict: + key_dict["alg"] = alg + + jwk_obj = jwk.import_key(key_dict) + + # Attempt decode+verify (raises on sig fail/mismatch) + candidate_decoded = jwt.decode( + id_token, + jwk_obj, + # OpenID Connect Core 1.0 Section 3.1.3.7.6 + # The Client MUST validate the signature of all other ID Tokens + # according to JWS [JWS] using the algorithm specified in the JWT + # alg Header Parameter. + algorithms=[alg], + ) + decoded_token = candidate_decoded + selected_key_info = { + "kid": candidate_key.get("kid", "none"), + "alg": candidate_key.get("alg", alg), + "kty": candidate_key.get("kty"), + } + if self.verbose_debug_mode: + _LOGGER.debug( + "Signature verified successfully with JWKS key: %s", + selected_key_info, + ) + break # Success! Proceed + except joserfc_errors.JoseError as verify_err: + if self.verbose_debug_mode: + _LOGGER.debug( + "Key candidate failed verification (kid=%s): %s", + candidate_key.get("kid", "none"), + verify_err, + ) + continue # Try next + + if decoded_token is None: + _LOGGER.warning( + ( + "No JWKS key verified the ID token signature (alg='%s', " + "tried %d candidates; check JWKS rotation/provider config)" + ), + alg, + len(candidates), + ) return None - # If signing_key does not have alg, set it to the one passed in the token - if "alg" not in signing_key: - signing_key["alg"] = alg - - # Construct the JWK from the RSA key - jwk_obj = jwk.import_key(signing_key) - - # Decode the token, decode does not verify it - decoded_token = jwt.decode( - id_token, - jwk_obj, - # OpenID Connect Core 1.0 Section 3.1.3.7.6 - # The Client MUST validate the signature of all other ID Tokens - # according to JWS [JWS] using the algorithm specified in the JWT - # alg Header Parameter. - algorithms=[self.id_token_signing_alg], - ) + # Log successful key selection + if self.verbose_debug_mode: + _LOGGER.debug( + "Final selected key for verification: %s", selected_key_info + ) - # Create Claims Registry for validation + # Claims validation (post-signature verification) + # Create Claims Registry for validation (aud/iss/sub/exp/nbf/iat + leeway) id_token_validator = jwt.JWTClaimsRegistry( leeway=5, # OpenID Connect Core 1.0 Section 3.1.3.7.3 @@ -538,6 +875,37 @@ async def _parse_id_token(self, id_token: str) -> Optional[dict]: ) id_token_validator.validate(decoded_token.claims) + + # OpenID Connect Core 1.0 §3.1.3.6: Validate at_hash if access_token present + # Binds ID token to access_token (prevents tampering/replay). + # at_hash = base64url(SHA256(left half of access_token))[0:hash_len/2] + if access_token: # Pass access_token to method + try: + # Compute at_hash (OpenID Connect Core §3.1.3.6) + access_token_bytes = access_token.encode("utf-8") + hashed_access_token = hashlib.sha256(access_token_bytes).digest() + left_half_hash = hashed_access_token[ + : len(hashed_access_token) // 2 + ] + expected_at_hash = self._base64url_encode(left_half_hash) + + actual_at_hash = decoded_token.claims.get("at_hash") + if actual_at_hash != expected_at_hash: + _LOGGER.warning( + ( + "ID token at_hash mismatch! Expected: %s, " + "got: %s (access_token tampering?)" + ), + expected_at_hash, + actual_at_hash, + ) + return None + if self.verbose_debug_mode: + _LOGGER.debug("at_hash validated successfully") + except (UnicodeEncodeError, AttributeError) as e: + _LOGGER.warning("at_hash computation/validation failed: %s", e) + return None # Fail closed + return decoded_token.claims except joserfc_errors.JoseError as e: @@ -581,6 +949,12 @@ async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]: if not self.disable_pkce: query_params["code_challenge"] = code_challenge query_params["code_challenge_method"] = "S256" + else: + # Warn once per-flow if PKCE disabled (security risk for legacy OPs) + _LOGGER.warning( + "PKCE (RFC 7636) disabled via features.disable_rfc7636! " + "Authorization code interception risk increased. Only for legacy OPs." + ) url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}" return url @@ -588,14 +962,21 @@ async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]: _LOGGER.warning("Error generating authorization URL: %s", e) return None - async def parse_user_details(self, id_token: str, access_token: str) -> UserDetails: + async def parse_user_details( + self, id_token_claims: dict, access_token: str + ) -> UserDetails: """Parses the ID token and/or userinfo into user details.""" # Fetch userinfo if there is an userinfo_endpoint available # and use the data to supply the missing values in id_token discovery_document = await self._fetch_discovery_document() - if "userinfo_endpoint" in discovery_document: - userinfo_endpoint = discovery_document["userinfo_endpoint"] + userinfo_endpoint = discovery_document.get("userinfo_endpoint") + # Users may attempt fallback userinfo endpoint if OP doesn't advertise it + # Commonly /userinfo even if not in discovery document + if not userinfo_endpoint and self.userinfo_fallback: + userinfo_endpoint = f"{discovery_document['issuer'].rstrip('/')}/userinfo" + _LOGGER.info("Using userinfo fallback endpoint: %s", userinfo_endpoint) + if userinfo_endpoint: userinfo = await self._get_userinfo(userinfo_endpoint, access_token) # Replace missing claims in the id_token with their userinfo version @@ -604,15 +985,40 @@ async def parse_user_details(self, id_token: str, access_token: str) -> UserDeta self.display_name_claim, self.username_claim, ): - if claim not in id_token and claim in userinfo: - id_token[claim] = userinfo[claim] + if claim not in id_token_claims and claim in userinfo: + id_token_claims[claim] = userinfo[claim] # Get and parse groups (to check if it's an array) - groups = id_token.get(self.groups_claim, []) + groups = id_token_claims.get(self.groups_claim, []) if not isinstance(groups, list): _LOGGER.warning("Groups claim is not a list, using empty list instead.") groups = [] + # Extract case insensitive username and apply email + # stripping if configured to use 'email' claim. + # This converts full email (e.g., 'user@domain.com') + # to local-part (e.g., 'user') for username only. + # 1. Not all OP's support username / preferred_username + # claim, so email is often used, but + # this is not ideal for usernames in HA (even without + # username linking support **currently**). + # 2. Many RPs/OPs provide some level of claim matching + # / processing to increase flexibility. + username_raw = id_token_claims.get(self.username_claim) + username = username_raw + if ( + (self.username_claim.lower() in ["email", "e-mail"]) + and username_raw + and "@" in username_raw + ): + username = username_raw.split("@")[0] + if self.verbose_debug_mode: + _LOGGER.debug( + "Stripped email '%s' to username '%s' (local-part before '@')", + username_raw, + username, + ) + # Assign role if user has the required groups role = "invalid" if self.user_role in groups or self.user_role is None: @@ -622,18 +1028,22 @@ async def parse_user_details(self, id_token: str, access_token: str) -> UserDeta role = "system-admin" # Create a user details dict based on the contents of the id_token & userinfo + # Note: if user username claim is email, will be processed with local var 'username' above + # Other claims use originals from id_token_claims/userinfo merge. return { # Subject Identifier. A locally unique and never reassigned identifier within the # Issuer for the End-User, which is intended to be consumed by the Client # Only unique per issuer, so we combine it with the issuer and hash it. # This might allow multiple OIDC providers to be used with this integration. "sub": hashlib.sha256( - f"{discovery_document['issuer']}.{id_token.get('sub')}".encode("utf-8") + f"{discovery_document['issuer']}.{id_token_claims.get('sub')}".encode( + "utf-8" + ) ).hexdigest(), # Display name, configurable - "display_name": id_token.get(self.display_name_claim), - # Username, configurable - "username": id_token.get(self.username_claim), + "display_name": id_token_claims.get(self.display_name_claim), + # Username, configurable (uses processed 'username' var: email-stripped if applicable + "username": username, # Role "role": role, } @@ -644,11 +1054,10 @@ async def async_complete_token_flow( """Completes the OIDC token flow to obtain a user's details.""" try: - if state not in self.flows: + flow = self.flows.pop(state, None) + if flow is None: raise OIDCStateInvalid - flow = self.flows[state] - discovery_document = await self._fetch_discovery_document() token_endpoint = discovery_document["token_endpoint"] @@ -673,12 +1082,14 @@ async def async_complete_token_flow( token_endpoint, query_params ) - id_token = token_response.get("id_token") + id_token_str = token_response.get("id_token") + access_token = token_response.get("access_token") # Parse the id token to obtain the relevant details - id_token = await self._parse_id_token(id_token) - - if id_token is None: + id_token_claims = await self._parse_id_token( + id_token_str, access_token=access_token + ) + if id_token_claims is None: _LOGGER.warning("ID token could not be parsed!") return None @@ -686,19 +1097,18 @@ async def async_complete_token_flow( # If a nonce value was sent in the Authentication Request, # a nonce Claim MUST be present and its value checked to verify # that it is the same value as the one that was sent in the Authentication Request. - if id_token.get("nonce") != flow["nonce"]: + if id_token_claims.get("nonce") != flow["nonce"]: _LOGGER.warning("Nonce mismatch!") return None - access_token = token_response.get("access_token") - data = await self.parse_user_details(id_token, access_token) + data = await self.parse_user_details(id_token_claims, access_token) # Log which details were obtained for debugging # Also log the original subject identifier such that you can look it up in your provider _LOGGER.debug( "Obtained user details from OIDC provider: %s (issuer subject: %s)", data, - id_token.get("sub"), + id_token_claims.get("sub"), ) return data except OIDCClientException as e: