Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* This class is a copy of {@link org.springframework.security.oauth2.client.web.server.DefaultServerOAuth2AuthorizationRequestResolver}
Expand Down Expand Up @@ -177,7 +181,7 @@ private Mono<OAuth2AuthorizationRequest> authorizationRequest(
builder = OAuth2AuthorizationRequest.authorizationCode();
Map<String, Object> additionalParameters = new HashMap<>();

addAttributesAndAdditionalParameters(clientRegistration, attributes, additionalParameters);
addAttributesAndAdditionalParameters(exchange, clientRegistration, attributes, additionalParameters);

builder.additionalParameters(additionalParameters);
// } else if (AuthorizationGrantType.IMPLICIT.equals(clientRegistration.getAuthorizationGrantType()))
Expand All @@ -199,6 +203,7 @@ private Mono<OAuth2AuthorizationRequest> authorizationRequest(
}

protected void addAttributesAndAdditionalParameters(
ServerWebExchange exchange,
ClientRegistration clientRegistration,
Map<String, Object> attributes,
Map<String, Object> additionalParameters) {
Expand All @@ -215,16 +220,137 @@ protected void addAttributesAndAdditionalParameters(
addPkceParameters(attributes, additionalParameters);
}
if (!commonConfig.getOauthAllowedDomains().isEmpty()) {
if (commonConfig.getOauthAllowedDomains().size() == 1) {
// Incase there's only 1 domain, we can do a further optimization to let the user select a specific one
// from the list
additionalParameters.put(
"hd", commonConfig.getOauthAllowedDomains().get(0));
List<String> allowedDomains = commonConfig.getOauthAllowedDomains();

if (allowedDomains.size() == 1) {
// Single domain case: use it directly
additionalParameters.put("hd", allowedDomains.get(0));
} else {
// Add multiple domains to the list of allowed domains
additionalParameters.put("hd", commonConfig.getOauthAllowedDomains());
// Multiple domains case: derive candidate domain from request context
String candidateDomain = deriveDomainFromRequest(exchange);

if (candidateDomain != null) {
// Domain was successfully derived and matched
additionalParameters.put("hd", candidateDomain);
log.debug("Using derived domain '{}' for hd parameter", candidateDomain);
} else {
// No domain could be derived or matched, fallback to first allowed domain
String fallbackDomain = allowedDomains.get(0);
additionalParameters.put("hd", fallbackDomain);
log.debug(
"No matching domain derived, using fallback domain '{}' for hd parameter", fallbackDomain);
}
}
}
}

/**
* Derives a candidate domain from the incoming request using existing tenant/domain logic.
* This method leverages the same mechanisms used elsewhere in the codebase for domain inference.
*
* @param exchange The ServerWebExchange containing the request
* @return The derived domain candidate, or null if no domain could be derived
*/
protected String deriveDomainFromRequest(ServerWebExchange exchange) {
try {
ServerHttpRequest request = exchange.getRequest();

// Extract host from request headers with fallback chain
String host = extractHostFromRequest(request);
if (host == null || host.isEmpty()) {
return null;
}

// Normalize host: strip port, lowercase, remove trailing dot
host = normalizeHost(host);

// Get and normalize allowed domains
List<String> allowedDomains = commonConfig.getOauthAllowedDomains();
if (allowedDomains == null || allowedDomains.isEmpty()) {
return null;
}

List<String> normalizedAllowed = allowedDomains.stream()
.filter(Objects::nonNull)
.map(d -> d.trim().toLowerCase(Locale.ROOT))
.filter(s -> !s.isEmpty())
.collect(Collectors.toList());

// Find the most specific domain match
return findBestDomainMatch(host, normalizedAllowed);

} catch (Exception e) {
log.debug("Error deriving domain from request", e);
return null;
}
}

/**
* Extracts host from request using fallback chain: X-Forwarded-Host -> URI host -> Host header
*/
private String extractHostFromRequest(ServerHttpRequest request) {
// Prefer X-Forwarded-Host header (for proxy environments)
String xfHost = request.getHeaders().getFirst("X-Forwarded-Host");
if (xfHost != null && !xfHost.isBlank()) {
// If comma-separated, take the first
int comma = xfHost.indexOf(',');
return (comma >= 0 ? xfHost.substring(0, comma) : xfHost).trim();
}

// Fallback to request URI host
if (request.getURI() != null && request.getURI().getHost() != null) {
return request.getURI().getHost();
}

// Final fallback to Host header
if (request.getHeaders().getHost() != null) {
return request.getHeaders().getHost().getHostString();
}

return null;
}

/**
* Normalizes host by removing port, converting to lowercase, and removing trailing dots
*/
private String normalizeHost(String host) {
if (host == null || host.isEmpty()) {
return host;
}

// Strip port
int colon = host.indexOf(':');
if (colon >= 0) {
host = host.substring(0, colon);
}

// Convert to lowercase
host = host.toLowerCase(Locale.ROOT);

// Remove trailing dot
if (host.endsWith(".")) {
host = host.substring(0, host.length() - 1);
}

return host;
}

/**
* Finds the most specific domain match using suffix matching
*/
private String findBestDomainMatch(String host, List<String> normalizedAllowed) {
String best = null;

for (String allowed : normalizedAllowed) {
if (host.equals(allowed) || host.endsWith("." + allowed)) {
// Prefer the most specific (longest) match
if (best == null || allowed.length() > best.length()) {
best = allowed;
}
}
}

return best;
}

/**
Expand Down
Loading
Loading