🍉Spring Authorization Server (8): Extending the Authorization Server's Default Authentication Method


Stack versions
Spring Boot 3.1
Spring Authorization Server 1.1.1
spring-cloud 2022.0.3
spring-cloud-alibaba 2022.0.0.0
Full code 👉 watermelon-cloud

Default authentication method of the authorization server

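By default there is only /login, and it only supports username + password. This is the default page:
[Screenshot: the default /login page]
What if I forget my password one day? Look at pretty much any website today: which one doesn't offer phone number + verification code login? Without it we look way too low-end, so we definitely have to add it!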

So first let's look at how the default /login works, and then do the same thing ourselves.

The authorization server's default login

The default login URL is /login, and the filter that handles it is UsernamePasswordAuthenticationFilter. That is where we start, so let's open it up.

UsernamePasswordAuthenticationFilter

public class UsernamePasswordAuthenticationFilter extends AbstractAuthenticationProcessingFilter {

   public static final String SPRING_SECURITY_FORM_USERNAME_KEY = "username";

   public static final String SPRING_SECURITY_FORM_PASSWORD_KEY = "password";

   private static final AntPathRequestMatcher DEFAULT_ANT_PATH_REQUEST_MATCHER = new AntPathRequestMatcher("/login",
           "POST");

   private String usernameParameter = SPRING_SECURITY_FORM_USERNAME_KEY;

   private String passwordParameter = SPRING_SECURITY_FORM_PASSWORD_KEY;

   private boolean postOnly = true;

   public UsernamePasswordAuthenticationFilter() {
       super(DEFAULT_ANT_PATH_REQUEST_MATCHER);
   }

   public UsernamePasswordAuthenticationFilter(AuthenticationManager authenticationManager) {
       super(DEFAULT_ANT_PATH_REQUEST_MATCHER, authenticationManager);
   }

   @Override
   public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
           throws AuthenticationException {
       if (this.postOnly && !request.getMethod().equals("POST")) {
           throw new AuthenticationServiceException("Authentication method not supported: " + request.getMethod());
       }
       String username = obtainUsername(request);
       username = (username != null) ? username.trim() : "";
       String password = obtainPassword(request);
       password = (password != null) ? password : "";
       UsernamePasswordAuthenticationToken authRequest = UsernamePasswordAuthenticationToken.unauthenticated(username,
               password);
       // Allow subclasses to set the "details" property
       setDetails(request, authRequest);
       return this.getAuthenticationManager().authenticate(authRequest);
   }
}

UsernamePasswordAuthenticationFilter extends AbstractAuthenticationProcessingFilter and overrides attemptAuthentication(), which:
builds a UsernamePasswordAuthenticationToken;
calls this.getAuthenticationManager().authenticate(authRequest);
the UsernamePasswordAuthenticationToken is then handled by its matching DaoAuthenticationProvider, which returns an Authentication. You should know this flow well by now.
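To see why the phone-login provider we write later needs a supports() method, it helps to picture how the AuthenticationManager picks a provider. Below is a heavily simplified, servlet-free model of that dispatch. The types AuthRequest, PasswordRequest, PhoneCodeRequest, Provider and SimpleProviderManager are made up for illustration; this is not Spring Security's ProviderManager code. Each provider is asked whether it supports the token's class, and the first one that returns a result wins.

```java
import java.util.List;

// Hypothetical stand-in for an unauthenticated Authentication token.
interface AuthRequest {
    String principal();
}

// Plays the role of UsernamePasswordAuthenticationToken.
record PasswordRequest(String principal, String password) implements AuthRequest {}

// Plays the role of a phone + verification code token.
record PhoneCodeRequest(String principal, String code) implements AuthRequest {}

// Hypothetical stand-in for AuthenticationProvider.
interface Provider {
    boolean supports(Class<?> requestType);

    // Returns the authenticated principal, or null to let the next provider try.
    String authenticate(AuthRequest request);
}

final class SimpleProviderManager {
    private final List<Provider> providers;

    SimpleProviderManager(List<Provider> providers) {
        this.providers = List.copyOf(providers);
    }

    String authenticate(AuthRequest request) {
        for (Provider provider : providers) {
            // Providers that do not support this token type are skipped entirely
            if (!provider.supports(request.getClass())) {
                continue;
            }
            String result = provider.authenticate(request);
            if (result != null) {
                return result;
            }
        }
        // Spring's ProviderManager throws ProviderNotFoundException in the equivalent situation
        throw new IllegalStateException("No provider for " + request.getClass().getSimpleName());
    }
}
```

The real ProviderManager also collects exceptions from providers and can fall back to a parent manager; the model keeps only the supports() routing, which is exactly what lets a new token type reach a new provider.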

Next, let's see where the flow goes from there: the parent of UsernamePasswordAuthenticationFilter is the abstract class AbstractAuthenticationProcessingFilter.

AbstractAuthenticationProcessingFilter

public abstract class AbstractAuthenticationProcessingFilter extends GenericFilterBean
       implements ApplicationEventPublisherAware, MessageSourceAware {

  
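   // For UsernamePasswordAuthenticationFilter this ends up as a CompositeSessionAuthenticationStrategy, because AbstractAuthenticationFilterConfigurer.configure() sets it
   private SessionAuthenticationStrategy sessionStrategy = new NullAuthenticatedSessionStrategy();

   // For UsernamePasswordAuthenticationFilter this is replaced (by default) in AbstractAuthenticationFilterConfigurer.configure() with the repository from SecurityContextConfigurer
   private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository();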

   
   public abstract Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
           throws AuthenticationException, IOException, ServletException;
   
   
   private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
           throws IOException, ServletException {
       if (!requiresAuthentication(request, response)) {
           chain.doFilter(request, response);
           return;
       }
       try {
           Authentication authenticationResult = attemptAuthentication(request, response);
           if (authenticationResult == null) {
               // return immediately as subclass has indicated that it hasn't completed
               return;
           }
           this.sessionStrategy.onAuthentication(authenticationResult, request, response);
           // Authentication success
           if (this.continueChainBeforeSuccessfulAuthentication) {
               chain.doFilter(request, response);
           }
           successfulAuthentication(request, response, chain, authenticationResult);
       } catch (InternalAuthenticationServiceException failed) {
           this.logger.error("An internal error occurred while trying to authenticate the user.", failed);
           unsuccessfulAuthentication(request, response, failed);
       } catch (AuthenticationException ex) {
           // Authentication failed
           unsuccessfulAuthentication(request, response, ex);
       }
   }
}

The line Authentication authenticationResult = attemptAuthentication(request, response); calls the subclass's attemptAuthentication(); the rest of the method then handles the successful (or failed) login.

So for our extension, attemptAuthentication() is clearly the key.
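The template-method contract is easy to lose in the full source, so here is a stripped-down model of it in plain Java. A path and a parameter map stand in for the servlet request, IllegalArgumentException stands in for AuthenticationException, and ProcessingFilterModel and CodeLoginFilterModel are hypothetical names. doFilter() only reacts to its own URL, delegates to attemptAuthentication(), treats null as "not finished yet", and routes exceptions to the failure path.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Simplified model of AbstractAuthenticationProcessingFilter's doFilter() flow (not Spring code).
abstract class ProcessingFilterModel {
    final List<String> events = new ArrayList<>();
    private final String processesUrl;

    ProcessingFilterModel(String processesUrl) {
        this.processesUrl = processesUrl;
    }

    // The hook subclasses implement: return the principal, null if still in progress,
    // or throw to signal a failed login.
    abstract String attemptAuthentication(Map<String, String> params);

    final void doFilter(String path, Map<String, String> params) {
        if (!processesUrl.equals(path)) {
            events.add("chain"); // not our URL: hand over to the next filter
            return;
        }
        try {
            String result = attemptAuthentication(params);
            if (result == null) {
                return; // subclass has not completed authentication yet
            }
            events.add("session-strategy"); // sessionStrategy.onAuthentication(...)
            events.add("success:" + result); // save context, then success handler
        } catch (IllegalArgumentException ex) {
            events.add("failure:" + ex.getMessage()); // clear context, then failure handler
        }
    }
}

// Hypothetical subclass: only the hook differs, the surrounding flow is inherited.
final class CodeLoginFilterModel extends ProcessingFilterModel {
    CodeLoginFilterModel() {
        super("/sso-login");
    }

    @Override
    String attemptAuthentication(Map<String, String> params) {
        String code = params.get("code");
        if (code == null || code.isBlank()) {
            throw new IllegalArgumentException("missing code");
        }
        return params.getOrDefault("username", "");
    }
}
```

Everything after attemptAuthentication() runs in the parent class, which is why the fields the parent uses there (the session strategy and the context repository) matter so much.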

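Simply extending AbstractAuthenticationProcessingFilter to implement our own filter does not work out of the box, though. AbstractAuthenticationProcessingFilter has two fields, securityContextRepository and sessionStrategy, and they get different values when the instance is a UsernamePasswordAuthenticationFilter. With the defaults, our filter's post-success handling differs from UsernamePasswordAuthenticationFilter's: the default RequestAttributeSecurityContextRepository keeps the security context only for the current request, so after the post-login redirect the user is unauthenticated again and login appears to fail every time.

This will show up in the code later, when we register the filter.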

Extension

Phone number + verification code

UserAuthenticationProcessingFilter

Define a custom UserAuthenticationProcessingFilter, similar to AbstractAuthenticationProcessingFilter (essentially a copy of it):

public abstract class UserAuthenticationProcessingFilter extends GenericFilterBean
     implements ApplicationEventPublisherAware, MessageSourceAware {

  private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
        .getContextHolderStrategy();

  protected ApplicationEventPublisher eventPublisher;

  protected AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();

  private AuthenticationManager authenticationManager;

  protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor();

  private RememberMeServices rememberMeServices = new NullRememberMeServices();

  private RequestMatcher requiresAuthenticationRequestMatcher;

  private boolean continueChainBeforeSuccessfulAuthentication = false;

  private SessionAuthenticationStrategy sessionStrategy = new NullAuthenticatedSessionStrategy();

  private boolean allowSessionCreation = true;

  private AuthenticationSuccessHandler successHandler = new SavedRequestAwareAuthenticationSuccessHandler();

  private AuthenticationFailureHandler failureHandler = new SimpleUrlAuthenticationFailureHandler();

  private SecurityContextRepository securityContextRepository = new RequestAttributeSecurityContextRepository();


  protected UserAuthenticationProcessingFilter(String defaultFilterProcessesUrl) {
     setFilterProcessesUrl(defaultFilterProcessesUrl);
  }


  protected UserAuthenticationProcessingFilter(RequestMatcher requiresAuthenticationRequestMatcher) {
     Assert.notNull(requiresAuthenticationRequestMatcher, "requiresAuthenticationRequestMatcher cannot be null");
     this.requiresAuthenticationRequestMatcher = requiresAuthenticationRequestMatcher;
  }

  
  protected UserAuthenticationProcessingFilter(String defaultFilterProcessesUrl,
                                                AuthenticationManager authenticationManager) {
     setFilterProcessesUrl(defaultFilterProcessesUrl);
     setAuthenticationManager(authenticationManager);
  }

  
  protected UserAuthenticationProcessingFilter(RequestMatcher requiresAuthenticationRequestMatcher,
                                                AuthenticationManager authenticationManager) {
     setRequiresAuthenticationRequestMatcher(requiresAuthenticationRequestMatcher);
     setAuthenticationManager(authenticationManager);
  }

  @Override
  public void afterPropertiesSet() {
     Assert.notNull(this.authenticationManager, "authenticationManager must be specified");
  }


  @Override
  public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
        throws IOException, ServletException {
     doFilter((HttpServletRequest) request, (HttpServletResponse) response, chain);
  }

  private void doFilter(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
        throws IOException, ServletException {
     if (!requiresAuthentication(request, response)) {
        chain.doFilter(request, response);
        return;
     }
     try {
        Authentication authenticationResult = attemptAuthentication(request, response);
        if (authenticationResult == null) {
           // return immediately as subclass has indicated that it hasn't completed
           return;
        }
        this.sessionStrategy.onAuthentication(authenticationResult, request, response);
        // Authentication success
        if (this.continueChainBeforeSuccessfulAuthentication) {
           chain.doFilter(request, response);
        }
        successfulAuthentication(request, response, chain, authenticationResult);
     }
     catch (InternalAuthenticationServiceException failed) {
        this.logger.error("An internal error occurred while trying to authenticate the user.", failed);
        unsuccessfulAuthentication(request, response, failed);
     }
     catch (AuthenticationException ex) {
        // Authentication failed
        unsuccessfulAuthentication(request, response, ex);
     }
  }


  protected boolean requiresAuthentication(HttpServletRequest request, HttpServletResponse response) {
     if (this.requiresAuthenticationRequestMatcher.matches(request)) {
        return true;
     }
     if (this.logger.isTraceEnabled()) {
        this.logger
              .trace(LogMessage.format("Did not match request to %s", this.requiresAuthenticationRequestMatcher));
     }
     return false;
  }


  public abstract Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
        throws AuthenticationException, IOException, ServletException;


  protected void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
        Authentication authResult) throws IOException, ServletException {
     SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
     context.setAuthentication(authResult);
     this.securityContextHolderStrategy.setContext(context);
     this.securityContextRepository.saveContext(context, request, response);
     if (this.logger.isDebugEnabled()) {
        this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult));
     }
     this.rememberMeServices.loginSuccess(request, response, authResult);
     if (this.eventPublisher != null) {
        this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
     }
     this.successHandler.onAuthenticationSuccess(request, response, authResult);
  }


  protected void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
        AuthenticationException failed) throws IOException, ServletException {
     this.securityContextHolderStrategy.clearContext();
     this.logger.trace("Failed to process authentication request", failed);
     this.logger.trace("Cleared SecurityContextHolder");
     this.logger.trace("Handling authentication failure");
     this.rememberMeServices.loginFail(request, response);
     this.failureHandler.onAuthenticationFailure(request, response, failed);
  }

  protected AuthenticationManager getAuthenticationManager() {
     return this.authenticationManager;
  }

  public void setAuthenticationManager(AuthenticationManager authenticationManager) {
     this.authenticationManager = authenticationManager;
  }


  public void setFilterProcessesUrl(String filterProcessesUrl) {
     setRequiresAuthenticationRequestMatcher(new AntPathRequestMatcher(filterProcessesUrl));
  }

  public final void setRequiresAuthenticationRequestMatcher(RequestMatcher requestMatcher) {
     Assert.notNull(requestMatcher, "requestMatcher cannot be null");
     this.requiresAuthenticationRequestMatcher = requestMatcher;
  }

  public RememberMeServices getRememberMeServices() {
     return this.rememberMeServices;
  }

  public void setRememberMeServices(RememberMeServices rememberMeServices) {
     Assert.notNull(rememberMeServices, "rememberMeServices cannot be null");
     this.rememberMeServices = rememberMeServices;
  }


  public void setContinueChainBeforeSuccessfulAuthentication(boolean continueChainBeforeSuccessfulAuthentication) {
     this.continueChainBeforeSuccessfulAuthentication = continueChainBeforeSuccessfulAuthentication;
  }

  @Override
  public void setApplicationEventPublisher(ApplicationEventPublisher eventPublisher) {
     this.eventPublisher = eventPublisher;
  }

  public void setAuthenticationDetailsSource(
        AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource) {
     Assert.notNull(authenticationDetailsSource, "AuthenticationDetailsSource required");
     this.authenticationDetailsSource = authenticationDetailsSource;
  }

  @Override
  public void setMessageSource(MessageSource messageSource) {
     this.messages = new MessageSourceAccessor(messageSource);
  }

  protected boolean getAllowSessionCreation() {
     return this.allowSessionCreation;
  }

  public void setAllowSessionCreation(boolean allowSessionCreation) {
     this.allowSessionCreation = allowSessionCreation;
  }

  public void setSessionAuthenticationStrategy(SessionAuthenticationStrategy sessionStrategy) {
     this.sessionStrategy = sessionStrategy;
  }

  public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler successHandler) {
     Assert.notNull(successHandler, "successHandler cannot be null");
     this.successHandler = successHandler;
  }

  public void setAuthenticationFailureHandler(AuthenticationFailureHandler failureHandler) {
     Assert.notNull(failureHandler, "failureHandler cannot be null");
     this.failureHandler = failureHandler;
  }


  public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
     Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
     this.securityContextRepository = securityContextRepository;
  }


  public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
     Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
     this.securityContextHolderStrategy = securityContextHolderStrategy;
  }

  protected AuthenticationSuccessHandler getSuccessHandler() {
     return this.successHandler;
  }

  protected AuthenticationFailureHandler getFailureHandler() {
     return this.failureHandler;
  }

}

PhoneCaptchaAuthenticationToken

public class PhoneCaptchaAuthenticationToken extends AbstractAuthenticationToken {

   private final Object principal;

   private Object credentials;

   public PhoneCaptchaAuthenticationToken(Object principal, Object credentials) {
       super(null);
       this.principal = principal;
       this.credentials = credentials;
       setAuthenticated(false);
   }


   public PhoneCaptchaAuthenticationToken(Object principal, Object credentials,
                                              Collection<? extends GrantedAuthority> authorities) {
       super(authorities);
       this.principal = principal;
       this.credentials = credentials;
       super.setAuthenticated(true); // must use super, as we override
   }


   public static PhoneCaptchaAuthenticationToken unauthenticated(Object principal, Object credential) {
       return new PhoneCaptchaAuthenticationToken(principal, credential);
   }


   public static PhoneCaptchaAuthenticationToken authenticated(Object principal, Object credentials,
                                                                   Collection<? extends GrantedAuthority> authorities) {
       return new PhoneCaptchaAuthenticationToken(principal, credentials, authorities);
   }

   @Override
   public Object getCredentials() {
       return this.credentials;
   }

   @Override
   public Object getPrincipal() {
       return this.principal;
   }

   @Override
   public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
       Assert.isTrue(!isAuthenticated,
               "Cannot set this token to trusted - use constructor which takes a GrantedAuthority list instead");
       super.setAuthenticated(false);
   }

   @Override
   public void eraseCredentials() {
       super.eraseCredentials();
       this.credentials = null;
   }
}

PhoneCaptchaAuthenticationProvider

@Component
public class PhoneCaptchaAuthenticationProvider
  	implements AuthenticationProvider {

  protected final Log logger = LogFactory.getLog(getClass());

  protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor();

  private UserCache userCache = new NullUserCache();

  private final boolean forcePrincipalAsString = false;

  protected boolean hideUserNotFoundExceptions = true;

  private UserDetailsChecker preAuthenticationChecks = new DefaultPreAuthenticationChecks();

  private UserDetailsChecker postAuthenticationChecks = new DefaultPostAuthenticationChecks();

  private GrantedAuthoritiesMapper authoritiesMapper = new NullAuthoritiesMapper();

  @Lazy
  @Autowired
  public UserDetailsService userDetailsService;


  @Override
  public Authentication authenticate(Authentication authentication) throws AuthenticationException {
  	Assert.isInstanceOf(PhoneCaptchaAuthenticationToken.class, authentication,
  			() -> this.messages.getMessage("PhoneCaptchaAuthenticationProvider.onlySupports",
  					"Only PhoneCaptchaAuthenticationToken is supported"));
  	String username = determineUsername(authentication);
  	boolean cacheWasUsed = true;
  	UserDetails user = this.userCache.getUserFromCache(username);
  	if (user == null) {
  		cacheWasUsed = false;
  		try {
  			user = retrieveUser(username, (PhoneCaptchaAuthenticationToken) authentication);
  		}
  		catch (UsernameNotFoundException ex) {
  			this.logger.debug("Failed to find phone '" + username + "'");
  			if (!this.hideUserNotFoundExceptions) {
  				throw ex;
  			}
  			throw new BadCredentialsException(this.messages
  					.getMessage("PhoneCaptchaAuthenticationProvider.badCredentials", "Bad credentials"));
  		}
  		Assert.notNull(user, "retrieveUser returned null - a violation of the interface contract");
  	}
  	try {
  		this.preAuthenticationChecks.check(user);
  	}
  	catch (AuthenticationException ex) {
  		if (!cacheWasUsed) {
  			throw ex;
  		}
  		// There was a problem, so try again after checking
  		// we're using latest data (i.e. not from the cache)
  		cacheWasUsed = false;
  		user = retrieveUser(username, (PhoneCaptchaAuthenticationToken) authentication);
  		this.preAuthenticationChecks.check(user);
  	}
  	this.postAuthenticationChecks.check(user);
  	if (!cacheWasUsed) {
  		this.userCache.putUserInCache(user);
  	}
  	Object principalToReturn = user;
  	if (this.forcePrincipalAsString) {
  		principalToReturn = user.getUsername();
  	}
  	Authentication successAuthentication = createSuccessAuthentication(principalToReturn, authentication, user);
  	return successAuthentication;
  }

  private String determineUsername(Authentication authentication) {
  	return (authentication.getPrincipal() == null) ? "NONE_PROVIDED" : authentication.getName();
  }


  protected Authentication createSuccessAuthentication(Object principal, Authentication authentication,
  		UserDetails user) {
  	PhoneCaptchaAuthenticationToken result = PhoneCaptchaAuthenticationToken.authenticated(principal,
  			authentication.getCredentials(), this.authoritiesMapper.mapAuthorities(user.getAuthorities()));
  	result.setDetails(authentication.getDetails());
  	this.logger.debug("Authenticated user");
  	return result;
  }



  public UserCache getUserCache() {
  	return this.userCache;
  }

  public boolean isForcePrincipalAsString() {
  	return this.forcePrincipalAsString;
  }

  public boolean isHideUserNotFoundExceptions() {
  	return this.hideUserNotFoundExceptions;
  }


  protected  UserDetails retrieveUser(String username, PhoneCaptchaAuthenticationToken authentication)
  		throws AuthenticationException{
  	try {
  		UserDetails loadedUser = this.getUserDetailsService().loadUserByUsername(username);
  		if (loadedUser == null) {
  			throw new InternalAuthenticationServiceException(
  					"UserDetailsService returned null, which is an interface contract violation");
  		}
  		return loadedUser;
  	}
  	catch (UsernameNotFoundException ex) {
  		throw ex;
  	}
  	catch (InternalAuthenticationServiceException ex) {
  		throw ex;
  	}
  	catch (Exception ex) {
  		throw new InternalAuthenticationServiceException(ex.getMessage(), ex);
  	}

  }


  protected UserDetailsService getUserDetailsService() {
  	return this.userDetailsService;
  }


  public void setUserCache(UserCache userCache) {
  	this.userCache = userCache;
  }

  @Override
  public boolean supports(Class<?> authentication) {
  	return (PhoneCaptchaAuthenticationToken.class.isAssignableFrom(authentication));
  }

  protected UserDetailsChecker getPreAuthenticationChecks() {
  	return this.preAuthenticationChecks;
  }

  /**
   * Sets the policy will be used to verify the status of the loaded
   * <tt>UserDetails</tt> <em>before</em> validation of the credentials takes place.
   * @param preAuthenticationChecks strategy to be invoked prior to authentication.
   */
  public void setPreAuthenticationChecks(UserDetailsChecker preAuthenticationChecks) {
  	this.preAuthenticationChecks = preAuthenticationChecks;
  }

  protected UserDetailsChecker getPostAuthenticationChecks() {
  	return this.postAuthenticationChecks;
  }

  public void setPostAuthenticationChecks(UserDetailsChecker postAuthenticationChecks) {
  	this.postAuthenticationChecks = postAuthenticationChecks;
  }

  public void setAuthoritiesMapper(GrantedAuthoritiesMapper authoritiesMapper) {
  	this.authoritiesMapper = authoritiesMapper;
  }


  private class DefaultPreAuthenticationChecks implements UserDetailsChecker {

  	@Override
  	public void check(UserDetails user) {
  		if (!user.isAccountNonLocked()) {
  			PhoneCaptchaAuthenticationProvider.this.logger
  					.debug("Failed to authenticate since user account is locked");
  			throw new LockedException(PhoneCaptchaAuthenticationProvider.this.messages
  					.getMessage("AbstractUserDetailsAuthenticationProvider.locked", "User account is locked"));
  		}
  		if (!user.isEnabled()) {
  			PhoneCaptchaAuthenticationProvider.this.logger
  					.debug("Failed to authenticate since user account is disabled");
  			throw new DisabledException(PhoneCaptchaAuthenticationProvider.this.messages
  					.getMessage("AbstractUserDetailsAuthenticationProvider.disabled", "User is disabled"));
  		}
  		if (!user.isAccountNonExpired()) {
  			PhoneCaptchaAuthenticationProvider.this.logger
  					.debug("Failed to authenticate since user account has expired");
  			throw new AccountExpiredException(PhoneCaptchaAuthenticationProvider.this.messages
  					.getMessage("AbstractUserDetailsAuthenticationProvider.expired", "User account has expired"));
  		}
  	}

  }

  private class DefaultPostAuthenticationChecks implements UserDetailsChecker {

  	@Override
  	public void check(UserDetails user) {
  		if (!user.isCredentialsNonExpired()) {
  			logger.debug("Failed to authenticate since user account credentials have expired");
  			throw new CredentialsExpiredException(PhoneCaptchaAuthenticationProvider.this.messages
  					.getMessage("PhoneCaptchaAuthenticationProvider.credentialsExpired",
  							"User credentials have expired"));
  		}
  	}

  }

}

UserAuthenticationFilter

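Define a custom UserAuthenticationFilter that extends UserAuthenticationProcessingFilter, similar to UsernamePasswordAuthenticationFilter. If a password is submitted it builds a UsernamePasswordAuthenticationToken; otherwise it checks the verification code and builds a PhoneCaptchaAuthenticationToken. Note that the code is verified here in the filter: PhoneCaptchaAuthenticationProvider only loads the user and runs the account status checks.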

public class UserAuthenticationFilter extends UserAuthenticationProcessingFilter {

  public static final String SPRING_SECURITY_FORM_USERNAME_KEY = "username";

  public static final String SPRING_SECURITY_FORM_PASSWORD_KEY = "password";

  public static final String SPRING_SECURITY_FORM_CODE_KEY = "code";

  private  String usernameParameter = SPRING_SECURITY_FORM_USERNAME_KEY;

  private String passwordParameter = SPRING_SECURITY_FORM_PASSWORD_KEY;

  private final String codeParameter = SPRING_SECURITY_FORM_CODE_KEY;

  private boolean postOnly = true;

  private static final AntPathRequestMatcher DEFAULT_ANT_PATH_REQUEST_MATCHER = new AntPathRequestMatcher("/sso-login",
  		"POST");

  public UserAuthenticationFilter() {
  	super(DEFAULT_ANT_PATH_REQUEST_MATCHER);
  }

  public UserAuthenticationFilter(AuthenticationManager authenticationManager) {
  	super(DEFAULT_ANT_PATH_REQUEST_MATCHER, authenticationManager);
  }

  @Override
  public Authentication attemptAuthentication(HttpServletRequest request, HttpServletResponse response)
  		throws AuthenticationException {
  	if (this.postOnly && !request.getMethod().equals("POST")) {
  		throw new AuthenticationServiceException("Authentication method not supported: " + request.getMethod());
  	}
  	String username = obtainUsername(request);
  	username = (username != null) ? username.trim() : "";
  	String password = obtainPassword(request);
  	if (StringUtils.hasText(password)) {
  		// A password was submitted: use the regular username + password flow
  		UsernamePasswordAuthenticationToken authRequest = UsernamePasswordAuthenticationToken.unauthenticated(username,
  				password);
  		// Allow subclasses to set the "details" property
  		setDetails(request, authRequest);
  		return this.getAuthenticationManager().authenticate(authRequest);
  	}
  	// No password: "username" carries the phone number and "code" the verification code.
  	// "000000" is a hard-coded placeholder; a missing code is rejected instead of causing an NPE.
  	String code = obtainCode(request);
  	if (!"000000".equals(code)) {
  		throw new BadCredentialsException("Invalid verification code!");
  	}
  	PhoneCaptchaAuthenticationToken authRequest = PhoneCaptchaAuthenticationToken.unauthenticated(username, code);
  	// Allow subclasses to set the "details" property
  	this.setDetails(request, authRequest);
  	return this.getAuthenticationManager().authenticate(authRequest);

  }

  @Nullable
  protected String obtainPassword(HttpServletRequest request) {
  	return request.getParameter(this.passwordParameter);
  }

  @Nullable
  protected String obtainUsername(HttpServletRequest request) {
  	return request.getParameter(this.usernameParameter);
  }

  @Nullable
  protected String obtainCode(HttpServletRequest request) {
  	return request.getParameter(this.codeParameter);
  }

  protected void setDetails(HttpServletRequest request, UsernamePasswordAuthenticationToken authRequest) {
  	authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
  }

  protected void setDetails(HttpServletRequest request, PhoneCaptchaAuthenticationToken authRequest) {
  	authRequest.setDetails(this.authenticationDetailsSource.buildDetails(request));
  }

  public void setPasswordParameter(String passwordParameter) {
  	this.passwordParameter = passwordParameter;
  }

  public void setPostOnly(boolean postOnly) {
  	this.postOnly = postOnly;
  }

  public final String getUsernameParameter() {
  	return this.usernameParameter;
  }

}
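The "000000" check in UserAuthenticationFilter is only a placeholder. Here is a minimal sketch of what could replace it, assuming codes are generated and sent elsewhere (for example by an SMS endpoint you add yourself). InMemoryCaptchaStore is a hypothetical class, not part of the original project: it issues 6-digit codes with an expiry and accepts each code at most once. In a real deployment this would more likely live in Redis or another shared store.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.SecureRandom;
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.LongSupplier;

// Hypothetical in-memory verification code store (illustration only).
final class InMemoryCaptchaStore {
    private record Entry(String code, long expiresAtMillis) {}

    private final Map<String, Entry> codes = new ConcurrentHashMap<>();
    private final SecureRandom random = new SecureRandom();
    private final long ttlMillis;
    private final LongSupplier clock; // current time in millis, injectable for tests

    InMemoryCaptchaStore(Duration ttl, LongSupplier clock) {
        this.ttlMillis = ttl.toMillis();
        this.clock = clock;
    }

    // Generates a 6-digit code for the phone number, replacing any previous one.
    String issue(String phone) {
        String code = String.format("%06d", random.nextInt(1_000_000));
        codes.put(phone, new Entry(code, clock.getAsLong() + ttlMillis));
        return code;
    }

    // Single use: the stored code is removed on every attempt, right or wrong.
    boolean verify(String phone, String code) {
        if (phone == null || code == null) {
            return false;
        }
        Entry entry = codes.remove(phone);
        if (entry == null || clock.getAsLong() > entry.expiresAtMillis()) {
            return false;
        }
        // Constant-time comparison so timing does not reveal matching digits
        return MessageDigest.isEqual(
                entry.code().getBytes(StandardCharsets.UTF_8),
                code.getBytes(StandardCharsets.UTF_8));
    }
}
```

In UserAuthenticationFilter the placeholder check would then become something like if (!captchaStore.verify(username, code)), with captchaStore a hypothetical field you inject into the filter. Dropping the entry even after a wrong guess is a deliberate trade-off: the user has to request a new code after one mistake, but a 6-digit code can no longer be brute-forced against a single issued value.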

That completes the extension classes:

UserAuthenticationProcessingFilter
UserAuthenticationFilter
PhoneCaptchaAuthenticationToken
PhoneCaptchaAuthenticationProvider
How do we make the filter defined above take effect in the SecurityFilterChain?

Looking for where AbstractAuthenticationProcessingFilter gets configured leads to AbstractAuthenticationFilterConfigurer:

public abstract class AbstractAuthenticationFilterConfigurer<B extends HttpSecurityBuilder<B>, T extends AbstractAuthenticationFilterConfigurer<B, T, F>, F extends AbstractAuthenticationProcessingFilter>
        extends AbstractHttpConfigurer<T, B> {
    @Override
    public void configure(B http) throws Exception {
        PortMapper portMapper = http.getSharedObject(PortMapper.class);
        if (portMapper != null) {
            this.authenticationEntryPoint.setPortMapper(portMapper);
        }
        RequestCache requestCache = http.getSharedObject(RequestCache.class);
        if (requestCache != null) {
            this.defaultSuccessHandler.setRequestCache(requestCache);
        }
        this.authFilter.setAuthenticationManager(http.getSharedObject(AuthenticationManager.class));
        this.authFilter.setAuthenticationSuccessHandler(this.successHandler);
        this.authFilter.setAuthenticationFailureHandler(this.failureHandler);
        if (this.authenticationDetailsSource != null) {
            this.authFilter.setAuthenticationDetailsSource(this.authenticationDetailsSource);
        }
        SessionAuthenticationStrategy sessionAuthenticationStrategy = http
                .getSharedObject(SessionAuthenticationStrategy.class);
        if (sessionAuthenticationStrategy != null) {
            this.authFilter.setSessionAuthenticationStrategy(sessionAuthenticationStrategy);
        }
        RememberMeServices rememberMeServices = http.getSharedObject(RememberMeServices.class);
        if (rememberMeServices != null) {
            this.authFilter.setRememberMeServices(rememberMeServices);
        }
        SecurityContextConfigurer securityContextConfigurer = http.getConfigurer(SecurityContextConfigurer.class);
        if (securityContextConfigurer != null && securityContextConfigurer.isRequireExplicitSave()) {
            SecurityContextRepository securityContextRepository = securityContextConfigurer
                    .getSecurityContextRepository();
            this.authFilter.setSecurityContextRepository(securityContextRepository);
        }
        this.authFilter.setSecurityContextHolderStrategy(getSecurityContextHolderStrategy());
        F filter = postProcess(this.authFilter);
        http.addFilter(filter);
    }
}

Here we can see where the SessionAuthenticationStrategy and SecurityContextRepository are set: this is exactly how they are initialized when the AbstractAuthenticationProcessingFilter is a UsernamePasswordAuthenticationFilter. So let's inject the SessionAuthenticationStrategy and SecurityContextRepository of our UserAuthenticationProcessingFilter the same way.

Add the filter to the SecurityFilterChain in DefaultSecurityConfig

@EnableWebSecurity
@Configuration(proxyBeanMethods = false)
public class DefaultSecurityConfig {


  @Autowired
  private PhoneCaptchaAuthenticationProvider phoneCaptchaAuthenticationProvider;



  	// Security filter chain
  @Bean
  public SecurityFilterChain defaultSecurityFilterChain(HttpSecurity http) throws Exception {
  	UserAuthenticationFilter userAuthenticationFilter = new UserAuthenticationFilter();
  	http
  			.addFilterAt(userAuthenticationFilter, UsernamePasswordAuthenticationFilter.class)
  			.authorizeHttpRequests(authorize ->// ① configure authorization rules
  					authorize
  							.requestMatchers(
  									"/assets/**",
  									"/webjars/**",
  									AuthorizationServerConfigurationConsent.LOGIN_PAGE_URL,
  									"/oauth2/**",
  									"/oauth2/token"
  							).permitAll() // ② URLs that do not require authentication
  							.anyRequest().authenticated()// ③ every other URL requires authentication
  			)
  			.csrf(AbstractHttpConfigurer::disable)
  			.authenticationProvider(phoneCaptchaAuthenticationProvider)
  			.formLogin(login->
  					login.loginPage(AuthorizationServerConfigurationConsent.LOGIN_PAGE_URL)
  							.defaultSuccessUrl("/test") // redirect target after a successful login
  			)
  			.oauth2Login(oauth2Login->
  					oauth2Login.loginPage(AuthorizationServerConfigurationConsent.LOGIN_PAGE_URL)
  			);

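  	DefaultSecurityFilterChain build = http.build();
  	// These shared objects are populated while http.build() runs, so wire them into the custom filter
  	// afterwards, mirroring what AbstractAuthenticationFilterConfigurer.configure() does for form login
  	userAuthenticationFilter.setAuthenticationManager(http.getSharedObject(AuthenticationManager.class));
  	userAuthenticationFilter.setSessionAuthenticationStrategy(http.getSharedObject(SessionAuthenticationStrategy.class));
  	// Keep the context in the request and in the HTTP session (Spring Security 6's default combination),
  	// so the login survives the redirect that follows a successful authentication
  	SecurityContextRepository securityContextRepository = new DelegatingSecurityContextRepository(
  			new RequestAttributeSecurityContextRepository(), new HttpSessionSecurityContextRepository());
  	userAuthenticationFilter.setSecurityContextRepository(securityContextRepository);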

  	return build;
  }
}
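Why the DelegatingSecurityContextRepository in this configuration matters: after a successful login the browser is redirected, and that redirect is a brand-new request. The toy model below uses plain Java maps, not Spring's classes (ContextRepositoryModel and its key name are made up), to compare a repository that only stores the context in the current request with one that also stores it in the session.

```java
import java.util.HashMap;
import java.util.Map;

// Simplified model of request-scoped vs session-backed context storage (not Spring code).
final class ContextRepositoryModel {
    private static final String KEY = "context"; // illustrative key, not Spring's attribute name

    private Map<String, String> requestAttributes = new HashMap<>();
    private final Map<String, String> session = new HashMap<>();
    private final boolean alsoSaveToSession;

    ContextRepositoryModel(boolean alsoSaveToSession) {
        this.alsoSaveToSession = alsoSaveToSession;
    }

    // Like saveContext(): always store in the request; optionally also in the session.
    void saveContext(String principal) {
        requestAttributes.put(KEY, principal);
        if (alsoSaveToSession) {
            session.put(KEY, principal);
        }
    }

    // Simulates the browser following the post-login redirect: request attributes are gone.
    void nextRequest() {
        requestAttributes = new HashMap<>();
    }

    // Like loadContext() on a delegating repository: first store that has a context wins.
    String loadContext() {
        String fromRequest = requestAttributes.get(KEY);
        return fromRequest != null ? fromRequest : session.get(KEY);
    }
}
```

With the request-only variant, loadContext() returns null after nextRequest(), which is the "authentication keeps failing" symptom; with the session-backed variant the principal is still there on the next request.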

sso-login.html

<!DOCTYPE html>
<html lang="en" xmlns="http://www.w3.org/1999/xhtml" xmlns:th="https://www.thymeleaf.org">
<head>
   <meta charset="utf-8" />
   <meta name="viewport" content="width=device-width, initial-scale=1">
   <title>Spring Authorization Server sample</title>
   <link rel="stylesheet" href="/webjars/bootstrap/css/bootstrap.css" th:href="@{/webjars/bootstrap/css/bootstrap.css}" />
   <link rel="stylesheet" href="/assets/css/signin.css" th:href="@{/assets/css/signin.css}" />
</head>
<body>
<div class="container">
   <form class="form-signin w-100 m-auto" method="post" th:action="@{/sso-login}">
       <div th:if="${param.error}" class="alert alert-danger" role="alert">
           Invalid username or code.
       </div>
       <div th:if="${param.logout}" class="alert alert-success" role="alert">
           You have been logged out.
       </div>
       <h1 class="h3 mb-3 fw-normal">Please sign in</h1>
       <div class="form-floating">
           <input type="text" id="username" name="username" class="form-control" required autofocus>
           <label for="username">Username</label>
       </div>
       <div class="form-floating">
           <input type="text" id="code" name="code" class="form-control" required>
           <label for="code">code</label>
       </div>
       <div>
           <button class="w-100 btn btn-lg btn-primary btn-block" type="submit">Sign in</button>
           <a class="w-100 btn btn-light btn-block bg-white" href="/oauth2/authorization/gitee" role="link" style="margin-top: 10px">
               <img src="/assets/img/gitee.png" th:src="@{/assets/img/gitee.png}" width="20" style="margin-right: 5px;" alt="Sign in with Gitee">
               Sign in with Gitee
           </a>
           <a class="w-100 btn btn-light btn-block bg-white" href="/oauth2/authorization/github-idp" role="link" style="margin-top: 10px">
               <img src="/assets/img/github.png" th:src="@{/assets/img/github.png}" width="24" style="margin-right: 5px;" alt="Sign in with Github">
               Sign in with Github
           </a>
       </div>
   </form>
</div>
</body>
</html>
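The form posts two fields, `username` and `code`, to `/sso-login` as an `application/x-www-form-urlencoded` body. The JDK-only sketch below decodes such a body. It then extracts the two values the same way the default UsernamePasswordAuthenticationFilter extracts `username` and `password`. `FormLoginParser` is a hypothetical helper. The sketch assumes the custom UserAuthenticationFilter reads parameters named `username` and `code`, matching the `name` attributes in the form. In the real filter, `request.getParameter(...)` does the decoding.

```java
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper: decodes a form-urlencoded body into parameter name -> value
final class FormLoginParser {

    static Map<String, String> parse(String body) {
        Map<String, String> params = new HashMap<>();
        if (body == null || body.isEmpty()) {
            return params;
        }
        for (String pair : body.split("&")) {
            int eq = pair.indexOf('=');
            String name = eq >= 0 ? pair.substring(0, eq) : pair;
            String value = eq >= 0 ? pair.substring(eq + 1) : "";
            // keep the first value for a repeated name, as getParameter(...) does
            params.putIfAbsent(URLDecoder.decode(name, StandardCharsets.UTF_8),
                    URLDecoder.decode(value, StandardCharsets.UTF_8));
        }
        return params;
    }

    // Mirrors obtainUsername(...) handling: null becomes "", otherwise trimmed
    static String username(Map<String, String> params) {
        String username = params.get("username");
        return username != null ? username.trim() : "";
    }

    // "code" takes the place of "password"; null becomes "", not trimmed (like the password)
    static String code(Map<String, String> params) {
        String code = params.get("code");
        return code != null ? code : "";
    }
}
```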

Extension complete 🤓!

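Be very clear about why you use spring-authorization-server: the point is to stop reinventing the wheel for the OAuth2 authorization flow. You don't even need to write frontend code for it. spring-boot-starter-oauth2-client already handles the whole exchange of the authorization code for a token. So make sure you understand why you are using spring-authorization-server, and don't drift off course.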
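For reference, here is roughly the request the client starter sends when it trades the authorization code for tokens. In Spring Authorization Server the token endpoint is `/oauth2/token` by default. This JDK-only sketch only builds the form body from RFC 6749 section 4.1.3 and the HTTP Basic header for `client_secret_basic` authentication (section 2.3.1). The client id, secret, and redirect URI are made-up values. The request Spring's client actually produces may carry more parameters, for example a PKCE `code_verifier`.

```java
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical sketch of an authorization_code token request
final class TokenRequestSketch {

    // Form body: grant_type, code, redirect_uri, each form-urlencoded
    static String formBody(String code, String redirectUri) {
        Map<String, String> form = new LinkedHashMap<>(); // keeps insertion order
        form.put("grant_type", "authorization_code");
        form.put("code", code);
        form.put("redirect_uri", redirectUri);
        return form.entrySet().stream()
                .map(e -> enc(e.getKey()) + "=" + enc(e.getValue()))
                .collect(Collectors.joining("&"));
    }

    // client_secret_basic: form-encode id and secret, join with ':', then Base64
    static String basicAuthHeader(String clientId, String clientSecret) {
        String credentials = enc(clientId) + ":" + enc(clientSecret);
        return "Basic " + Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharsets.UTF_8));
    }

    private static String enc(String value) {
        return URLEncoder.encode(value, StandardCharsets.UTF_8);
    }
}
```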

If you run into the following exception

java.lang.IllegalArgumentException: The class with com.watermelon.authorization.filter.PhoneCaptchaAuthenticationToken \
and name of com.watermelon.authorization.filter.PhoneCaptchaAuthenticationToken is not in the allowlist. \
If you believe this class is safe to deserialize, please provide an explicit mapping using Jackson annotations or by providing a Mixin.\
If the serialization is only done by a trusted source, you can also enable default typing. \
See https://github.com/spring-projects/spring-security/issues/4370 for details \
  at org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService$OAuth2AuthorizationRowMapper.parseMap(JdbcOAuth2AuthorizationService.java:517) 
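The exception is raised while JdbcOAuth2AuthorizationService reads a stored authorization back. It keeps the attributes, including the authenticated principal, as JSON. It deserializes them through Spring Security's Jackson modules, which only accept allowlisted types (hence the hint about a Mixin). Storing the authorization with JDK serialization instead, which is what `RedisSerializer.java()` does, only requires the object graph to be `Serializable`. Spring Security's `Authentication` interface already extends `Serializable`. Below is a JDK-only sketch of that round trip. `PhoneCodeAuthentication` is a hypothetical stand-in for PhoneCaptchaAuthenticationToken.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

// Hypothetical stand-in for a custom authentication token
final class PhoneCodeAuthentication implements Serializable {
    private static final long serialVersionUID = 1L;

    final String phone;
    final boolean authenticated;

    PhoneCodeAuthentication(String phone, boolean authenticated) {
        this.phone = phone;
        this.authenticated = authenticated;
    }
}

final class JdkRoundTrip {

    // Object graph -> bytes with plain JDK serialization
    static byte[] serialize(Object value) throws IOException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(value);
        }
        return bytes.toByteArray();
    }

    // Bytes -> object graph; no Jackson type allowlist is involved
    static Object deserialize(byte[] data) throws IOException, ClassNotFoundException {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        }
    }
}
```

JDK deserialization instantiates whatever `Serializable` classes the byte stream names. It is a reasonable choice here only because Redis is a private store that this server alone writes to.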

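All we need to do is replace the OAuth2AuthorizationService implementation with a Redis-backed store. It replaces the JdbcOAuth2AuthorizationService bean rather than sitting next to it, since the authorization server expects a single OAuth2AuthorizationService bean.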

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.TimeUnit;

import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.RedisSerializer;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2DeviceCode;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.core.OAuth2UserCode;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.oidc.OidcIdToken;
import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;

@Component
public class RedisOAuth2AuthorizationServiceImpl implements OAuth2AuthorizationService {


   private final static String AUTHORIZATION_TYPE = "authorization_type";

   private final static String OAUTH2_PARAMETER_NAME_ID = "id";

   // Minimum TTL in seconds; also used when no token carries an expiry (e.g. during consent)
   private final static Long TIMEOUT = 600L;

   // Token classes whose expiry drives the TTL of the stored entries
   private static final List<Class<? extends OAuth2Token>> TOKEN_CLASSES = List.of(
           OAuth2AuthorizationCode.class, OAuth2AccessToken.class, OAuth2RefreshToken.class,
           OidcIdToken.class, OAuth2DeviceCode.class, OAuth2UserCode.class);

   // Dedicated template with String keys and JDK-serialized values, configured once here
   // instead of calling setValueSerializer(...) on a shared template before every operation
   private final RedisTemplate<String, Object> redisTemplate;

   public RedisOAuth2AuthorizationServiceImpl(RedisConnectionFactory connectionFactory) {
       RedisTemplate<String, Object> template = new RedisTemplate<>();
       template.setConnectionFactory(connectionFactory);
       template.setKeySerializer(RedisSerializer.string());
       template.setValueSerializer(RedisSerializer.java());
       template.afterPropertiesSet();
       this.redisTemplate = template;
   }

   @Override
   public void save(OAuth2Authorization authorization) {
       Assert.notNull(authorization, "authorization cannot be null");
       long timeout = resolveTimeout(authorization);
       // The same authorization is stored under its id and under every state/token index key
       for (String key : buildKeys(authorization)) {
           redisTemplate.opsForValue().set(key, authorization, timeout, TimeUnit.SECONDS);
       }
   }

   @Override
   public void remove(OAuth2Authorization authorization) {
       Assert.notNull(authorization, "authorization cannot be null");
       redisTemplate.delete(buildKeys(authorization));
   }

   @Override
   public OAuth2Authorization findById(String id) {
       Assert.hasText(id, "id cannot be empty");
       return (OAuth2Authorization) redisTemplate.opsForValue().get(buildAuthorizationKey(OAUTH2_PARAMETER_NAME_ID, id));
   }

   @Override
   public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
       Assert.hasText(token, "token cannot be empty");
       if (tokenType != null) {
           return (OAuth2Authorization) redisTemplate.opsForValue().get(buildAuthorizationKey(tokenType.getValue(), token));
       }
       // tokenType is nullable in the interface (e.g. revocation without token_type_hint), so try every index
       for (String type : List.of(OAuth2ParameterNames.STATE, OAuth2ParameterNames.CODE,
               OAuth2ParameterNames.ACCESS_TOKEN, OAuth2ParameterNames.REFRESH_TOKEN, OidcParameterNames.ID_TOKEN,
               OAuth2ParameterNames.DEVICE_CODE, OAuth2ParameterNames.USER_CODE)) {
           Object authorization = redisTemplate.opsForValue().get(buildAuthorizationKey(type, token));
           if (authorization != null) {
               return (OAuth2Authorization) authorization;
           }
       }
       return null;
   }

   // The id key plus one index key per state/token present in the authorization
   private List<String> buildKeys(OAuth2Authorization authorization) {
       List<String> keys = new ArrayList<>();
       keys.add(buildAuthorizationKey(OAUTH2_PARAMETER_NAME_ID, authorization.getId()));
       if (isState(authorization)) {
           String state = authorization.getAttribute(OAuth2ParameterNames.STATE);
           keys.add(buildAuthorizationKey(OAuth2ParameterNames.STATE, state));
       }
       if (isAuthorizationCode(authorization)) {
           keys.add(buildAuthorizationKey(OAuth2ParameterNames.CODE, tokenValue(authorization, OAuth2AuthorizationCode.class)));
       }
       if (isAccessToken(authorization)) {
           keys.add(buildAuthorizationKey(OAuth2ParameterNames.ACCESS_TOKEN, tokenValue(authorization, OAuth2AccessToken.class)));
       }
       if (isRefreshToken(authorization)) {
           keys.add(buildAuthorizationKey(OAuth2ParameterNames.REFRESH_TOKEN, tokenValue(authorization, OAuth2RefreshToken.class)));
       }
       if (isIdToken(authorization)) {
           keys.add(buildAuthorizationKey(OidcParameterNames.ID_TOKEN, tokenValue(authorization, OidcIdToken.class)));
       }
       if (isDeviceCode(authorization)) {
           keys.add(buildAuthorizationKey(OAuth2ParameterNames.DEVICE_CODE, tokenValue(authorization, OAuth2DeviceCode.class)));
       }
       if (isUserCode(authorization)) {
           keys.add(buildAuthorizationKey(OAuth2ParameterNames.USER_CODE, tokenValue(authorization, OAuth2UserCode.class)));
       }
       return keys;
   }

   // Callers check presence first (isAccessToken etc.), so getToken(...) is non-null here
   private static String tokenValue(OAuth2Authorization authorization, Class<? extends OAuth2Token> tokenClass) {
       return authorization.getToken(tokenClass).getToken().getTokenValue();
   }

   // TTL = time until the latest-expiring token expires, never less than TIMEOUT,
   // so e.g. a refresh token that outlives 600 s can still be found
   private long resolveTimeout(OAuth2Authorization authorization) {
       long timeout = TIMEOUT;
       Instant now = Instant.now();
       for (Class<? extends OAuth2Token> tokenClass : TOKEN_CLASSES) {
           OAuth2Authorization.Token<? extends OAuth2Token> token = authorization.getToken(tokenClass);
           Instant expiresAt = (token != null) ? token.getToken().getExpiresAt() : null;
           if (expiresAt != null) {
               timeout = Math.max(timeout, Duration.between(now, expiresAt).getSeconds());
           }
       }
       return timeout;
   }


   private boolean isState(OAuth2Authorization authorization) {
       return Objects.nonNull(authorization.getAttribute(OAuth2ParameterNames.STATE));
   }


   private boolean isAuthorizationCode(OAuth2Authorization authorization) {
       OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
               authorization.getToken(OAuth2AuthorizationCode.class);
       return Objects.nonNull(authorizationCode);
   }


   private boolean isAccessToken(OAuth2Authorization authorization) {
       OAuth2Authorization.Token<OAuth2AccessToken> accessToken =
               authorization.getToken(OAuth2AccessToken.class);
       return Objects.nonNull(accessToken) && Objects.nonNull(accessToken.getToken().getTokenType());
   }

   private boolean isRefreshToken(OAuth2Authorization authorization) {
       OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken =
               authorization.getToken(OAuth2RefreshToken.class);
       return Objects.nonNull(refreshToken) && Objects.nonNull(refreshToken.getToken().getTokenValue());
   }

   private boolean isIdToken(OAuth2Authorization authorization) {
       OAuth2Authorization.Token<OidcIdToken> idToken =
               authorization.getToken(OidcIdToken.class);
       return Objects.nonNull(idToken) && Objects.nonNull(idToken.getToken().getTokenValue());
   }

   private boolean isDeviceCode(OAuth2Authorization authorization) {
       OAuth2Authorization.Token<OAuth2DeviceCode> deviceCode =
               authorization.getToken(OAuth2DeviceCode.class);
       return Objects.nonNull(deviceCode) && Objects.nonNull(deviceCode.getToken().getTokenValue());
   }

   private boolean isUserCode(OAuth2Authorization authorization) {
       OAuth2Authorization.Token<OAuth2UserCode> userCode =
               authorization.getToken(OAuth2UserCode.class);
       return Objects.nonNull(userCode) && Objects.nonNull(userCode.getToken().getTokenValue());
   }


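   /**
    * Builds the Redis key
    *
    * @param type  key type (id, state, or a token type)
    * @param value the corresponding value
    * @return the key in the form authorization_type::{type}::{value}
    */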
   private String buildAuthorizationKey(String type, String value) {
       return AUTHORIZATION_TYPE.concat("::").concat(type).concat("::").concat(value);
   }
}
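The Redis service stores one authorization under several keys. There is one key for its id, plus one index key per state or token value, all using the `authorization_type::{type}::{value}` layout. That is what lets `findByToken` go straight from a token value to the authorization. Below is a JDK-only, in-memory model of that layout. It also includes a small TTL helper that keeps entries until the latest-expiring token expires, but never less than the 600-second `TIMEOUT`. A flat 600 s alone would drop refresh tokens that live longer. `Authz` and `IndexedAuthzStore` are hypothetical simplifications, not Spring types, and the model does not expire entries itself.

```java
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, simplified authorization: an id plus token values keyed by type
record Authz(String id, Map<String, String> tokens) {}

// In-memory model of the Redis key layout: every key points at the same authorization
final class IndexedAuthzStore {
    private static final String PREFIX = "authorization_type";
    private final Map<String, Authz> store = new HashMap<>();

    static String key(String type, String value) {
        return PREFIX + "::" + type + "::" + value;
    }

    void save(Authz authz) {
        for (String k : keysOf(authz)) {
            store.put(k, authz);
        }
    }

    Authz findById(String id) {
        return store.get(key("id", id));
    }

    Authz findByToken(String value, String type) {
        return store.get(key(type, value));
    }

    void remove(Authz authz) {
        keysOf(authz).forEach(store::remove);
    }

    private static List<String> keysOf(Authz authz) {
        List<String> keys = new ArrayList<>();
        keys.add(key("id", authz.id()));
        authz.tokens().forEach((type, value) -> keys.add(key(type, value)));
        return keys;
    }

    // TTL rule: seconds until the latest expiry, but never below the fallback
    static long ttlSeconds(List<Instant> expiries, Instant now, long fallbackSeconds) {
        long latest = expiries.stream()
                .mapToLong(e -> Duration.between(now, e).getSeconds())
                .max()
                .orElse(fallbackSeconds);
        return Math.max(latest, fallbackSeconds);
    }
}
```

Because the index keys embed the token type, a refresh token value looked up as an access token finds nothing. That matches how `findByToken` uses `tokenType.getValue()` in the key.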