2020import static org .mockito .BDDMockito .given ;
2121import static org .mockito .ArgumentMatchers .any ;
2222import static org .mockito .Mockito .mock ;
23+ import static org .mockito .Mockito .spy ;
2324import static org .mockito .Mockito .verify ;
2425import static org .mockito .Mockito .verifyZeroInteractions ;
2526import static org .mockito .Mockito .when ;
2627import static org .springframework .security .config .Customizer .withDefaults ;
28+ import static org .springframework .test .util .ReflectionTestUtils .getField ;
2729
2830import java .util .Arrays ;
2931import java .util .List ;
3537import org .junit .Before ;
3638import org .junit .Test ;
3739import org .junit .runner .RunWith ;
40+ import org .mockito .ArgumentCaptor ;
3841import org .mockito .Mock ;
3942import org .mockito .junit .MockitoJUnitRunner ;
4043
4144import org .springframework .security .core .Authentication ;
4245import org .springframework .security .oauth2 .client .registration .ReactiveClientRegistrationRepository ;
4346import org .springframework .security .oauth2 .client .web .server .ServerAuthorizationRequestRepository ;
47+ import org .springframework .security .oauth2 .client .web .server .authentication .OAuth2LoginAuthenticationWebFilter ;
4448import org .springframework .security .oauth2 .core .endpoint .OAuth2AuthorizationRequest ;
4549import org .springframework .security .oauth2 .core .endpoint .TestOAuth2AuthorizationRequests ;
4650import org .springframework .security .web .authentication .preauth .x509 .X509PrincipalExtractor ;
4751import org .springframework .security .web .server .authentication .ServerX509AuthenticationConverter ;
52+ import org .springframework .security .web .server .savedrequest .ServerRequestCache ;
53+ import org .springframework .security .web .server .savedrequest .WebSessionServerRequestCache ;
4854import reactor .core .publisher .Mono ;
4955import reactor .test .publisher .TestPublisher ;
5056
6470import org .springframework .security .web .server .csrf .CsrfServerLogoutHandler ;
6571import org .springframework .security .web .server .csrf .CsrfWebFilter ;
6672import org .springframework .security .web .server .csrf .ServerCsrfTokenRepository ;
67- import org .springframework .test .util .ReflectionTestUtils ;
6873import org .springframework .test .web .reactive .server .EntityExchangeResult ;
6974import org .springframework .test .web .reactive .server .FluxExchangeResult ;
7075import org .springframework .test .web .reactive .server .WebTestClient ;
@@ -200,7 +205,7 @@ public void csrfServerLogoutHandlerNotAppliedIfCsrfIsntEnabled() {
200205 .isNotPresent ();
201206
202207 Optional <ServerLogoutHandler > logoutHandler = getWebFilter (securityWebFilterChain , LogoutWebFilter .class )
203- .map (logoutWebFilter -> (ServerLogoutHandler ) ReflectionTestUtils . getField (logoutWebFilter , LogoutWebFilter .class , "logoutHandler" ));
208+ .map (logoutWebFilter -> (ServerLogoutHandler ) getField (logoutWebFilter , LogoutWebFilter .class , "logoutHandler" ));
204209
205210 assertThat (logoutHandler )
206211 .get ()
@@ -213,17 +218,17 @@ public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() {
213218
214219 assertThat (getWebFilter (securityWebFilterChain , CsrfWebFilter .class ))
215220 .get ()
216- .extracting (csrfWebFilter -> ReflectionTestUtils . getField (csrfWebFilter , "csrfTokenRepository" ))
221+ .extracting (csrfWebFilter -> getField (csrfWebFilter , "csrfTokenRepository" ))
217222 .isEqualTo (this .csrfTokenRepository );
218223
219224 Optional <ServerLogoutHandler > logoutHandler = getWebFilter (securityWebFilterChain , LogoutWebFilter .class )
220- .map (logoutWebFilter -> (ServerLogoutHandler ) ReflectionTestUtils . getField (logoutWebFilter , LogoutWebFilter .class , "logoutHandler" ));
225+ .map (logoutWebFilter -> (ServerLogoutHandler ) getField (logoutWebFilter , LogoutWebFilter .class , "logoutHandler" ));
221226
222227 assertThat (logoutHandler )
223228 .get ()
224229 .isExactlyInstanceOf (DelegatingServerLogoutHandler .class )
225230 .extracting (delegatingLogoutHandler ->
226- ((List <ServerLogoutHandler >) ReflectionTestUtils . getField (delegatingLogoutHandler , DelegatingServerLogoutHandler .class , "delegates" )).stream ()
231+ ((List <ServerLogoutHandler >) getField (delegatingLogoutHandler , DelegatingServerLogoutHandler .class , "delegates" )).stream ()
227232 .map (ServerLogoutHandler ::getClass )
228233 .collect (Collectors .toList ()))
229234 .isEqualTo (Arrays .asList (SecurityContextServerLogoutHandler .class , CsrfServerLogoutHandler .class ));
@@ -479,6 +484,33 @@ public void postWhenCustomCsrfTokenRepositoryThenUsed() {
479484 verify (customServerCsrfTokenRepository ).loadToken (any ());
480485 }
481486
487+ @ Test
488+ public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler () {
489+ ServerRequestCache requestCache = spy (new WebSessionServerRequestCache ());
490+ ReactiveClientRegistrationRepository clientRegistrationRepository = mock (ReactiveClientRegistrationRepository .class );
491+
492+ SecurityWebFilterChain securityFilterChain = this .http
493+ .oauth2Login ()
494+ .clientRegistrationRepository (clientRegistrationRepository )
495+ .and ()
496+ .authorizeExchange ().anyExchange ().authenticated ()
497+ .and ()
498+ .requestCache (c -> c .requestCache (requestCache ))
499+ .build ();
500+
501+ WebTestClient client = WebTestClientBuilder .bindToWebFilters (securityFilterChain ).build ();
502+ client .get ().uri ("/test" ).exchange ();
503+ ArgumentCaptor <ServerWebExchange > captor = ArgumentCaptor .forClass (ServerWebExchange .class );
504+ verify (requestCache ).saveRequest (captor .capture ());
505+ assertThat (captor .getValue ().getRequest ().getURI ().toString ()).isEqualTo ("/test" );
506+
507+
508+ OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
509+ getWebFilter (securityFilterChain , OAuth2LoginAuthenticationWebFilter .class ).get ();
510+ Object handler = getField (authenticationWebFilter , "authenticationSuccessHandler" );
511+ assertThat (getField (handler , "requestCache" )).isSameAs (requestCache );
512+ }
513+
482514 @ Test
483515 public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login () {
484516 ServerAuthorizationRequestRepository <OAuth2AuthorizationRequest > authorizationRequestRepository = mock (ServerAuthorizationRequestRepository .class );
@@ -503,7 +535,7 @@ public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
503535
504536 private boolean isX509Filter (WebFilter filter ) {
505537 try {
506- Object converter = ReflectionTestUtils . getField (filter , "authenticationConverter" );
538+ Object converter = getField (filter , "authenticationConverter" );
507539 return converter .getClass ().isAssignableFrom (ServerX509AuthenticationConverter .class );
508540 } catch (IllegalArgumentException e ) {
509541 // field doesn't exist
0 commit comments