分布式环境中spring cloud oauth2授权服务异常处理
环境
springboot 2.3.7
spring cloud 2.2.6
spring security oauth2 2.3.8
分布式部署多个spring security oauth2授权服务器实例,使用redis session同步会话
问题
客户端通过授权码模式获取令牌时会出现异常报错
分析
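spring security oauth2 授权服务器默认使用 InMemoryAuthorizationCodeServices 管理授权码,授权码只保存在签发它的那个实例的本地内存中,分布式部署的多个授权服务实例之间并不共享授权码。当负载均衡把获取令牌的请求转发到并非签发该授权码(即用户并未在其上完成登录认证)的实例时,该实例查不到这个授权码,AuthorizationCodeTokenGranter 便会抛出 InvalidGrantException("Invalid authorization code"),导致报错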
解决
自定义RedisAuthorizationCodeServices,使用Redis集中管理分布式环境下的授权码
import org.apache.commons.lang3.SerializationUtils;
import org.psrframework.core.util.UUIDUtil;
import org.springframework.dao.DataAccessException;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.security.oauth2.common.exceptions.InvalidGrantException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.code.AuthorizationCodeServices;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.util.Base64;
import java.util.Objects;
/**
 * Redis-backed {@link AuthorizationCodeServices}, so that every authorization-server instance
 * behind a load balancer can redeem authorization codes issued by any other instance.
 *
 * <p>Each code is stored under the key {@code "auth_code:" + code} as a Java-serialized
 * {@link OAuth2Authentication} with a TTL. The key is deleted when the code is consumed, so a
 * code can be redeemed at most once. All commands go through the raw {@link RedisConnection},
 * so the template's key/value serializers are not involved.
 *
 * <p>NOTE(review): stored values are read back with Java deserialization. This is only safe
 * while the Redis instance is trusted (no untrusted party can write {@code auth_code:*} keys).
 */
public class RedisAuthorizationCodeServicesImpl implements AuthorizationCodeServices {

    /** Prefix prepended to every authorization code to form its Redis key. */
    private static final String REDIS_KEY_AUTH_CODE_PREFIX = "auth_code:";

    /**
     * Default code lifetime in seconds (25 minutes). RFC 6749 section 4.1.2 recommends at most
     * 10 minutes; use the two-argument constructor to choose a shorter lifetime.
     */
    private static final long DEFAULT_CODE_VALIDITY_SECONDS = 1500L;

    /** Random bytes per code: 256 bits makes guessing or colliding with a live code infeasible. */
    private static final int CODE_LENGTH_BYTES = 32;

    /** Cryptographically strong source of randomness for codes. */
    private static final SecureRandom RANDOM = new SecureRandom();

    /** Used only to obtain raw Redis connections. */
    private final RedisTemplate<?, ?> redisTemplate;

    /** Lifetime of each issued code, in seconds. */
    private final long codeValiditySeconds;

    /**
     * Creates a service whose codes expire after {@value #DEFAULT_CODE_VALIDITY_SECONDS} seconds.
     *
     * @param redisTemplate template used to obtain Redis connections; must not be null
     * @throws NullPointerException if {@code redisTemplate} is null
     */
    public RedisAuthorizationCodeServicesImpl(RedisTemplate<?, ?> redisTemplate) {
        this(redisTemplate, DEFAULT_CODE_VALIDITY_SECONDS);
    }

    /**
     * Creates a service with an explicit code lifetime.
     *
     * @param redisTemplate template used to obtain Redis connections; must not be null
     * @param codeValiditySeconds lifetime of each issued code in seconds; must be positive
     * @throws NullPointerException if {@code redisTemplate} is null
     * @throws IllegalArgumentException if {@code codeValiditySeconds} is not positive
     */
    public RedisAuthorizationCodeServicesImpl(RedisTemplate<?, ?> redisTemplate, long codeValiditySeconds) {
        this.redisTemplate = Objects.requireNonNull(redisTemplate, "redisTemplate");
        if (codeValiditySeconds <= 0) {
            throw new IllegalArgumentException("codeValiditySeconds must be positive: " + codeValiditySeconds);
        }
        this.codeValiditySeconds = codeValiditySeconds;
    }

    /**
     * Issues a new authorization code and stores {@code authentication} under it in Redis.
     *
     * @param authentication the authentication to hand back when the code is consumed
     * @return the new code (URL-safe Base64 without padding, 43 characters)
     * @throws IllegalStateException if Redis reports that the code was not stored
     */
    @Override
    public String createAuthorizationCode(OAuth2Authentication authentication) {
        final String code = generateCode();
        final byte[] key = toRedisKey(code);
        final byte[] data = SerializationUtils.serialize(authentication);
        Boolean stored = redisTemplate.execute(new RedisCallback<Boolean>() {
            @Override
            public Boolean doInRedis(RedisConnection connection) throws DataAccessException {
                // SETEX writes the value and its TTL in a single command. A code can therefore
                // never be left without an expiry, as a separate SETNX + EXPIRE pair could if
                // the EXPIRE were never executed.
                return connection.setEx(key, codeValiditySeconds, data);
            }
        });
        if (Boolean.FALSE.equals(stored)) {
            throw new IllegalStateException("Failed to store authorization code in Redis");
        }
        return code;
    }

    /**
     * Redeems {@code code} at most once: returns the authentication stored for it and removes it
     * from Redis so that no other request (on this or any other instance) can use it again.
     *
     * @param code the authorization code presented in the token request
     * @return the authentication stored when the code was issued; never null
     * @throws InvalidGrantException if the code is unknown, expired, or already consumed
     */
    @Override
    public OAuth2Authentication consumeAuthorizationCode(String code) throws InvalidGrantException {
        final byte[] key = toRedisKey(code);
        byte[] data = redisTemplate.execute(new RedisCallback<byte[]>() {
            @Override
            public byte[] doInRedis(RedisConnection connection) throws DataAccessException {
                byte[] stored = connection.get(key);
                if (stored == null) {
                    return null;
                }
                // GET and DEL are separate commands, so concurrent requests may both read the
                // value. Only the caller whose DEL actually removed the key may redeem the code.
                Long removed = connection.del(key);
                return (removed != null && removed > 0) ? stored : null;
            }
        });
        if (data == null) {
            // Report an OAuth2 error instead of letting SerializationUtils fail on null input.
            throw new InvalidGrantException("Invalid authorization code: " + code);
        }
        return SerializationUtils.deserialize(data);
    }

    /**
     * Generates a new code from {@link #CODE_LENGTH_BYTES} SecureRandom bytes. The code is
     * encoded as URL-safe Base64 without padding, so it can go into a redirect query string as-is.
     * No collision check is done: with 256 random bits a clash with a live code is practically
     * impossible.
     */
    private static String generateCode() {
        byte[] bytes = new byte[CODE_LENGTH_BYTES];
        RANDOM.nextBytes(bytes);
        return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
    }

    /** Builds the UTF-8 encoded Redis key under which {@code code} is stored. */
    private static byte[] toRedisKey(String code) {
        return (REDIS_KEY_AUTH_CODE_PREFIX + code).getBytes(StandardCharsets.UTF_8);
    }
}
注册RedisAuthorizationCodeServices
// Authorization-server configuration; unrelated settings are elided with "...".
public class AuthorizationServerConfig extends AuthorizationServerConfigurerAdapter {
...
@Override
public void configure(AuthorizationServerEndpointsConfigurer endpoints) throws Exception {
...
// Replace the default authorization code store (an InMemoryAuthorizationCodeServices created
// when none is configured) with the Redis-backed one. A code issued by one instance can then
// be redeemed on any other.
// NOTE(review): redisTemplate is presumably an injected RedisTemplate field elided above — confirm.
endpoints.authorizationCodeServices(new RedisAuthorizationCodeServicesImpl(redisTemplate));
}
}
spring授权过程源码
- 客户端
org.springframework.security.oauth2.client.OAuth2RestTemplate
/**
 * Obtains a new access token from {@code accessTokenProvider} for the request held in
 * {@code oauth2Context}, and stores the result back into the context.
 *
 * @throws AccessTokenRequiredException if the context holds no access token request
 * @throws IllegalStateException if the provider returns a null token or a token with a null value
 * @throws UserRedirectRequiredException declared for the case where the provider requires the
 *         user to be redirected (presumably thrown by the provider — see obtainAccessToken)
 */
protected OAuth2AccessToken acquireAccessToken(OAuth2ClientContext oauth2Context)
throws UserRedirectRequiredException {
AccessTokenRequest accessTokenRequest = oauth2Context.getAccessTokenRequest();
if (accessTokenRequest == null) {
throw new AccessTokenRequiredException(
"No OAuth 2 security context has been established. Unable to access resource '"
+ this.resource.getId() + "'.", resource);
}
// Transfer the preserved state from the (longer lived) context to the current request.
String stateKey = accessTokenRequest.getStateKey();
// Only when the request carries a state key: move the preserved state stored under that key
// out of the context (it is removed there) and onto this request.
if (stateKey != null) {
accessTokenRequest.setPreservedState(oauth2Context.removePreservedState(stateKey));
}
// Pass along any token the context already holds.
OAuth2AccessToken existingToken = oauth2Context.getAccessToken();
if (existingToken != null) {
accessTokenRequest.setExistingToken(existingToken);
}
OAuth2AccessToken accessToken = null;
accessToken = accessTokenProvider.obtainAccessToken(resource, accessTokenRequest);
if (accessToken == null || accessToken.getValue() == null) {
throw new IllegalStateException(
"Access token provider returned a null access token, which is illegal according to the contract.");
}
oauth2Context.setAccessToken(accessToken);
return accessToken;
}
org.springframework.security.oauth2.client.token.grant.code.AuthorizationCodeAccessTokenProvider
/**
 * Returns an access token for the authorization-code grant described by {@code details}
 * (which is cast to {@link AuthorizationCodeResourceDetails}).
 *
 * <p>The request goes through one of three paths:
 * <ul>
 * <li>No authorization code and no state key: the user has not been through the authorization
 * endpoint yet, so the exception built by {@code getRedirectForAuthorization} is thrown.</li>
 * <li>A state key but no code: {@code obtainAuthorizationCode} is called first.</li>
 * <li>Either way, the code is then exchanged for a token via {@link #retrieveToken}.</li>
 * </ul>
 */
public OAuth2AccessToken obtainAccessToken(OAuth2ProtectedResourceDetails details, AccessTokenRequest request)
throws UserRedirectRequiredException, UserApprovalRequiredException, AccessDeniedException,
OAuth2AccessDeniedException {
AuthorizationCodeResourceDetails resource = (AuthorizationCodeResourceDetails) details;
// No authorization code on the request yet.
if (request.getAuthorizationCode() == null) {
// No state key either: throw the redirect to the authorization page.
if (request.getStateKey() == null) {
throw getRedirectForAuthorization(resource, request);
}
// NOTE(review): helper not shown here; presumably it obtains the code and sets it on
// the request — verify against the Spring source.
obtainAuthorizationCode(resource, request);
}
// Exchange the authorization code on the request for an access token at the token endpoint.
return retrieveToken(request, resource, getParametersForTokenRequest(resource, request),
getHeadersForTokenRequest(request));
}
/**
 * Calls the token endpoint with {@code form} and {@code headers} and extracts the access token
 * from the response.
 *
 * <p>Before the call, {@code authenticationHandler} and {@code tokenRequestEnhancer} may modify
 * the form and headers. If the response carries a {@code Set-Cookie} header, its first value is
 * copied onto {@code request}.
 *
 * @throws OAuth2AccessDeniedException "Access token denied." wrapping an {@link OAuth2Exception},
 *         or "Error requesting access token." wrapping a {@link RestClientException}. An error
 *         response from the token endpoint presumably surfaces through one of these — e.g. an
 *         authorization code that the receiving server instance does not know.
 */
protected OAuth2AccessToken retrieveToken(AccessTokenRequest request, OAuth2ProtectedResourceDetails resource,
MultiValueMap<String, String> form, HttpHeaders headers) throws OAuth2AccessDeniedException {
try {
// Prepare headers and form before going into rest template call in case the URI is affected by the result
authenticationHandler.authenticateTokenRequest(resource, form, headers);
// Opportunity to customize form and headers
tokenRequestEnhancer.enhance(request, resource, form, headers);
final AccessTokenRequest copy = request;
final ResponseExtractor<OAuth2AccessToken> delegate = getResponseExtractor();
// Wraps the default extractor so that any Set-Cookie from the token response is recorded on the request.
ResponseExtractor<OAuth2AccessToken> extractor = new ResponseExtractor<OAuth2AccessToken>() {
@Override
public OAuth2AccessToken extractData(ClientHttpResponse response) throws IOException {
if (response.getHeaders().containsKey("Set-Cookie")) {
copy.setCookie(response.getHeaders().getFirst("Set-Cookie"));
}
return delegate.extractData(response);
}
};
return getRestTemplate().execute(getAccessTokenUri(resource, form), getHttpMethod(),
getRequestCallback(resource, form, headers), extractor , form.toSingleValueMap());
}
catch (OAuth2Exception oe) {
throw new OAuth2AccessDeniedException("Access token denied.", resource, oe);
}
catch (RestClientException rce) {
throw new OAuth2AccessDeniedException("Error requesting access token.", resource, rce);
}
}
- 授权服务器
org.springframework.security.oauth2.provider.code.AuthorizationCodeTokenGranter
/**
 * Rebuilds the {@link OAuth2Authentication} for an authorization-code token request from the
 * authentication stored when the code was issued.
 *
 * <p>The stored and new request parameters are merged, with token-request parameters winning
 * on clashes.
 *
 * @throws InvalidRequestException if the request has no {@code code} parameter
 * @throws InvalidGrantException if {@code authorizationCodeServices} has no authentication for the code
 * @throws RedirectMismatchException if the {@code redirect_uri} does not match the one used at authorization
 * @throws InvalidClientException if the client id differs from the one the code was issued to
 */
@Override
protected OAuth2Authentication getOAuth2Authentication(ClientDetails client, TokenRequest tokenRequest) {
Map<String, String> parameters = tokenRequest.getRequestParameters();
String authorizationCode = parameters.get("code");
String redirectUri = parameters.get(OAuth2Utils.REDIRECT_URI);
if (authorizationCode == null) {
throw new InvalidRequestException("An authorization code must be supplied.");
}
// Look up (and consume) the authentication stored for this code in the authorization code service.
// With the default InMemoryAuthorizationCodeServices, an instance that did not issue the code finds
// nothing here, so the request fails with "Invalid authorization code" below.
OAuth2Authentication storedAuth = authorizationCodeServices.consumeAuthorizationCode(authorizationCode);
if (storedAuth == null) {
throw new InvalidGrantException("Invalid authorization code: " + authorizationCode);
}
OAuth2Request pendingOAuth2Request = storedAuth.getOAuth2Request();
// https://jira.springsource.org/browse/SECOAUTH-333
// This might be null, if the authorization was done without the redirect_uri parameter
String redirectUriApprovalParameter = pendingOAuth2Request.getRequestParameters().get(
OAuth2Utils.REDIRECT_URI);
if ((redirectUri != null || redirectUriApprovalParameter != null)
&& !pendingOAuth2Request.getRedirectUri().equals(redirectUri)) {
throw new RedirectMismatchException("Redirect URI mismatch.");
}
String pendingClientId = pendingOAuth2Request.getClientId();
String clientId = tokenRequest.getClientId();
if (clientId != null && !clientId.equals(pendingClientId)) {
// just a sanity check.
throw new InvalidClientException("Client ID mismatch");
}
// Secret is not required in the authorization request, so it won't be available
// in the pendingAuthorizationRequest. We do want to check that a secret is provided
// in the token request, but that happens elsewhere.
Map<String, String> combinedParameters = new HashMap<String, String>(pendingOAuth2Request
.getRequestParameters());
// Combine the parameters adding the new ones last so they override if there are any clashes
combinedParameters.putAll(parameters);
// Make a new stored request with the combined parameters
OAuth2Request finalStoredOAuth2Request = pendingOAuth2Request.createOAuth2Request(combinedParameters);
Authentication userAuth = storedAuth.getUserAuthentication();
return new OAuth2Authentication(finalStoredOAuth2Request, userAuth);
}
org.springframework.security.oauth2.config.annotation.web.configurers.AuthorizationServerEndpointsConfigurer
// Returns the configured authorization code service. When none has been configured (e.g. via
// endpoints.authorizationCodeServices(...)), it lazily falls back to an in-memory store, which is
// local to this JVM.
private AuthorizationCodeServices authorizationCodeServices() {
if (authorizationCodeServices == null) {
authorizationCodeServices = new InMemoryAuthorizationCodeServices();
}
return authorizationCodeServices;
}
org.springframework.security.oauth2.provider.code.InMemoryAuthorizationCodeServices
// Authorization code store backed by a ConcurrentHashMap in the current JVM.
// NOTE(review): code generation and the consume logic presumably live in
// RandomValueAuthorizationCodeServices, which calls store/remove — verify against the Spring source.
public class InMemoryAuthorizationCodeServices extends RandomValueAuthorizationCodeServices {
// Codes are kept only in this process's memory, so other server instances cannot see them.
protected final ConcurrentHashMap<String, OAuth2Authentication> authorizationCodeStore = new ConcurrentHashMap<String, OAuth2Authentication>();
@Override
protected void store(String code, OAuth2Authentication authentication) {
this.authorizationCodeStore.put(code, authentication);
}
// Removes and returns the authentication for the code; null if this instance never stored it.
@Override
public OAuth2Authentication remove(String code) {
OAuth2Authentication auth = this.authorizationCodeStore.remove(code);
return auth;
}
}