shiro权限控制(二):分布式架构中shiro的实现
前言:
前段时间在搭建公司游戏框架安全验证的时候,就想到之前web最火的shiro框架,虽然后面实践发现在netty中不太适用,最后自己模仿shiro写了一个缩减版的,但是中间花费两天时间弄出来的shiro可不能白费,这里给大家出个简单的教程说明吧。
shiro的基本介绍这里就不再说了,可以自行翻阅博主之前写的shiro教程,这篇文章主要说明分布式架构下shiro的session共享问题。
一、原理描述
无论分布式、还是集群下,项目都需要获取登录用户的信息,而不可能做的就是让客户在每个系统或者每个模块中反复登录,也不存在让客户端存载用户信息给服务端,这是很常识的问题
而单机模式下,我们用shiro做了登录验证,他的主要方式就是在第一次登陆的时候,把我们设置的用户信息保存在cache(内存)中和自带的ehcahe(缓存管理器)中,然后给客户端一个cookie,在每次客户端访问时获取cookie值,从而得到用户信息。
好了,那么逻辑就清楚了,分布式架构下,要与多系统共享用户信息,其实就是共享shiro保存的cache。
要在多项目中共享,内存是不可能的了,ehcache对分布式支持不太好,或者说根本不支持。那么剩下只能是我么熟悉的mysql,redis,mongdb啥的数据库了。这么一对比,不用我说大家也明白了,最适合的无疑是redis了,速度快,主从啥的。
二、流程描述
查看源码我们可以知道,cacheManager最终会被set到sessionDAO中,所以我们要自己写sessionDAO。有两个类去操作保存的,那么我们只需要重写,实现这两个类,然后在注册的时候声明即可。
1.shiroCache:cache类,可以自己写一个定时消除的MAP存放更好,文章结尾我会给出map的代码。而这里的代码我是放在redis的。
package com.result.shiro.distributed; import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Set; import org.apache.shiro.cache.Cache; import org.apache.shiro.cache.CacheException; import com.result.redis.RedisKey; import com.result.redis.RedisUtil; import com.result.tools.KyroUtil; /** * @author 作者 huangxinyu * @version 创建时间:2018年1月8日 下午9:33:23 * cache共享 */ @SuppressWarnings("unchecked") public class ShiroCache<K, V> implements Cache<K, V> { private static final String REDIS_SHIRO_CACHE = RedisKey.CACHEKEY; private String cacheKey; private long globExpire = 30; @SuppressWarnings("rawtypes") public ShiroCache(String name) { this.cacheKey = REDIS_SHIRO_CACHE + name + ":"; } @Override public V get(K key) throws CacheException { Object obj = RedisUtil.get(KyroUtil.serialization(getCacheKey(key))); if(obj==null){ return null; } return (V) KyroUtil.deserialization((String)obj); } @Override public V put(K key, V value) throws CacheException { V old = get(key); RedisUtil.setex(KyroUtil.serialization(getCacheKey(key)), 18000, KyroUtil.serialization(value)); return old; } @Override public V remove(K key) throws CacheException { V old = get(key); RedisUtil.del(KyroUtil.serialization(getCacheKey(key))); return old; } @Override public void clear() throws CacheException { for(String key : (Set<String>)keys()){ RedisUtil.del(key); } } @Override public int size() { return keys().size(); } @Override public Set<K> keys() { return (Set<K>) RedisUtil.keys(KyroUtil.serialization(getCacheKey("*"))); } @Override public Collection<V> values() { Set<K> set = keys(); List<V> list = new ArrayList<>(); for (K s : set) { list.add(get(s)); } return list; } private K getCacheKey(Object k) { return (K) (this.cacheKey + k); } }
2.session操作类:这里用来把用户信息存放在redis中共享的。
package com.result.shiro.distributed; /** * @author 作者 huangxinyu * @version 创建时间:2018年1月6日 上午10:12:42 * redis实现共享session */ import java.io.Serializable; import java.util.Collection; import java.util.HashSet; import java.util.Set; import org.apache.shiro.session.Session; import org.apache.shiro.session.UnknownSessionException; import org.apache.shiro.session.mgt.eis.EnterpriseCacheSessionDAO; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import com.result.redis.RedisKey; import com.result.redis.RedisUtil; import com.result.tools.KyroUtil; import com.result.tools.SerializationUtil; public class RedisSessionDao extends EnterpriseCacheSessionDAO { private static Logger logger = LoggerFactory.getLogger(RedisSessionDao.class); @Override public void update(Session session) throws UnknownSessionException { this.saveSession(session); } /** * 删除session */ @Override public void delete(Session session) { if (session == null || session.getId() == null) { logger.error("==========session或sessionI 不存在"); return; } RedisUtil.del(KyroUtil.serialization(RedisKey.SESSIONKEY + session.getId())); } /** * 获取存活的sessions */ @Override public Collection<Session> getActiveSessions() { Set<Session> sessions = new HashSet<>(); Set<String> keys = RedisUtil.keys(KyroUtil.serialization(RedisKey.SESSIONKEY + "*")); for(String key:keys){ sessions.add((Session)KyroUtil.deserialization((String)RedisUtil.get(key))); } return sessions; } /** * 创建session */ @Override protected Serializable doCreate(Session session) { Serializable sessionId = this.generateSessionId(session); this.assignSessionId(session, sessionId); this.saveSession(session); return sessionId; } /** * 获取session */ @Override protected Session doReadSession(Serializable sessionId) { if(sessionId == null){ logger.error("==========session id 不存在"); return null; } Object obj = RedisUtil.get(KyroUtil.serialization(RedisKey.SESSIONKEY + sessionId)); if(obj==null){ return null; } Session s = (Session)KyroUtil.deserialization((String)obj); return s; } /** * 保存session并存储过期时间 * @param session * @throws UnknownSessionException */ public static void saveSession(String sessionId,Object obj) throws UnknownSessionException{ if (obj == null) { logger.error("要存入的session为空"); return; } //设置过期时间 int expireTime = 1800; RedisUtil.setex(sessionId,expireTime,SerializationUtil.serializeToString(obj)); } } 然后还有一个类也是必要的 package com.result.shiro.distributed; import org.apache.shiro.cache.Cache; import org.apache.shiro.cache.CacheException; import org.apache.shiro.cache.CacheManager; /** * @author 作者 huangxinyu * @version 创建时间:2018年1月8日 下午9:32:41 * 类说明 */ public class RedisCacheManager implements CacheManager { @Override public <K, V> Cache<K, V> getCache(String name) throws CacheException { return new ShiroCache<K, V>(name); } }
三:辅助类说明
用户信息的session存放在redis中肯定是需要序列化的,然而用json这种可读性太强的东西安全性显得极低,而且长度太大,浪费存储空间和IO。所以需要找其他的序列化工具。
常规的好用的序列化工具有kyro,protobuff,这些是性能极高而且序列化之后长度极小的序列化工具,其中protobuf支持跨语言。不过这些在之后的文章再和大家介绍去了,因为~!!session不支持这两种操作(因为上面两个类中操作的session实际是一个接口)。
那么序列化用的什么,emmmm~一个很原生的东西,测试效率也挺高的,和protobuf差不太多。下面贴出的代码实际就是上面类中kyroUtils中的方法,因为shiro分布式在项目中被废掉了,我也没去改名字了。大家自己看仔细点就可以了。
被注释掉的代码是kyro的序列化工具。
package com.result.tools; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** * @author 作者 huangxinyu * @version 创建时间:2018年1月6日 下午2:22:14 * Kryo工具类 */ public class KyroUtil { private static Logger logger = LoggerFactory.getLogger(KyroUtil.class); //private static KryoPool pool; //原本打算使用kyro序列化session,后来发现kyro对session序列化不支持,反序列后得不到value。 这种out序列化测试性能消耗时间更短,但是长度变大4倍意思,待优化 // static{ // KryoFactory factory = new KryoFactory() { // public Kryo create() { // Kryo kryo = new Kryo(); // kryo.setReferences(false); // //把shiroSession的结构注册到Kryo注册器里面,提高序列化/反序列化效率 // kryo.register(Session.class, new JavaSerializer()); // kryo.register(String.class, new JavaSerializer()); // kryo.register(User.class, new JavaSerializer()); // kryo.setInstantiatorStrategy(new StdInstantiatorStrategy()); // return kryo; // } // }; // pool = new KryoPool.Builder(factory).build(); // logger.info("KryoPool初始化成功===================================="); // } /** * 对象编码 * @param value * @return */ public static String serialization(Object value) { // String str =""; // try { // Kryo kryo = pool.borrow(); // ByteArrayOutputStream baos = new ByteArrayOutputStream(); // Output output = new Output(baos); // kryo.writeClassAndObject(output, value); // output.flush(); // output.close(); // byte[] b = baos.toByteArray(); // baos.flush(); // baos.close(); // str = new String(b, "ISO8859-1"); // } catch (IOException e) { // e.printStackTrace(); // } // return str; // ByteArrayOutputStream bos = null; ObjectOutputStream oos = null; try { bos = new ByteArrayOutputStream(); oos = new ObjectOutputStream(bos); oos.writeObject(value); return new String(bos.toByteArray(), "ISO8859-1"); } catch (Exception e) { throw new RuntimeException("serialize session error", e); } finally { try { oos.close(); bos.close(); } catch (IOException e) { e.printStackTrace(); } } // return new String(new Base64().encode(b)); } /** * 对象解码 * @param <T> * @param <T> * @param obj * @param clazz * @return */ public static Object deserialization(String obj) { // try { // Kryo kryo = pool.borrow(); // ByteArrayInputStream bais; // bais = new ByteArrayInputStream(obj.getBytes("ISO8859-1")); // //new Base64().decode(obj)); // Input input = new Input(bais); // return kryo.readClassAndObject(input); // } catch (UnsupportedEncodingException e) { // // TODO Auto-generated catch block // e.printStackTrace(); // } // return null; ByteArrayInputStream bis = null; ObjectInputStream ois = null; try { bis = new ByteArrayInputStream(obj.getBytes("ISO8859-1")); ois = new ObjectInputStream(bis); return ois.readObject(); } catch (Exception e) { throw new RuntimeException("deserialize session error", e); } finally { try { ois.close(); bis.close(); } catch (IOException e) { e.printStackTrace(); } } } }
四、注册
好了,该重写的都重写了,那么最后一步就是整合spring的时候我们要告诉spring,我们要用的是我们重写过的sessiondao了。
我这里用的是代码的方式,因为某些原因在写框架的时候不太好用xml去整合。
反正原理都差不多,大家看看就明白了:
package com.business.shiro; import java.util.LinkedHashMap; import java.util.Map; import org.apache.shiro.authc.credential.HashedCredentialsMatcher; import org.apache.shiro.cache.CacheManager; import org.apache.shiro.cache.ehcache.EhCacheManager; import org.apache.shiro.codec.Base64; import org.apache.shiro.realm.AuthorizingRealm; import org.apache.shiro.session.mgt.ExecutorServiceSessionValidationScheduler; import org.apache.shiro.session.mgt.eis.EnterpriseCacheSessionDAO; import org.apache.shiro.spring.LifecycleBeanPostProcessor; import org.apache.shiro.spring.security.interceptor.AuthorizationAttributeSourceAdvisor; import org.apache.shiro.spring.web.ShiroFilterFactoryBean; import org.apache.shiro.web.mgt.CookieRememberMeManager; import org.apache.shiro.web.mgt.DefaultWebSecurityManager; import org.apache.shiro.web.servlet.SimpleCookie; import org.apache.shiro.web.session.mgt.DefaultWebSessionManager; import org.springframework.aop.framework.autoproxy.DefaultAdvisorAutoProxyCreator; import org.springframework.beans.factory.config.MethodInvokingFactoryBean; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.DependsOn; import com.result.shiro.distributed.RedisCacheManager; import com.result.shiro.distributed.RedisSessionDao; /** * @author 作者 huangxinyu * @version 创建时间:2018年1月8日 下午8:29:12 * 类说明 */ @Configuration public class ShiroConfiguration { private static Map<String, String> filterChainDefinitionMap = new LinkedHashMap<String, String>(); @Bean(name = "cacheShiroManager") public CacheManager getCacheManage() { return new RedisCacheManager(); } @Bean(name = "lifecycleBeanPostProcessor") public LifecycleBeanPostProcessor getLifecycleBeanPostProcessor() { return new LifecycleBeanPostProcessor(); } @Bean(name = "sessionValidationScheduler") public ExecutorServiceSessionValidationScheduler getExecutorServiceSessionValidationScheduler() { ExecutorServiceSessionValidationScheduler scheduler = new ExecutorServiceSessionValidationScheduler(); scheduler.setInterval(900000); return scheduler; } @Bean(name = "hashedCredentialsMatcher") public HashedCredentialsMatcher getHashedCredentialsMatcher() { HashedCredentialsMatcher credentialsMatcher = new HashedCredentialsMatcher(); credentialsMatcher.setHashAlgorithmName("MD5"); credentialsMatcher.setHashIterations(1); credentialsMatcher.setStoredCredentialsHexEncoded(true); return credentialsMatcher; } @Bean(name = "sessionIdCookie") public SimpleCookie getSessionIdCookie() { SimpleCookie cookie = new SimpleCookie("sid"); cookie.setHttpOnly(true); cookie.setMaxAge(-1); return cookie; } @Bean(name = "rememberMeCookie") public SimpleCookie getRememberMeCookie() { SimpleCookie simpleCookie = new SimpleCookie("rememberMe"); simpleCookie.setHttpOnly(true); simpleCookie.setMaxAge(2592000); return simpleCookie; } @Bean public CookieRememberMeManager getRememberManager(){ CookieRememberMeManager meManager = new CookieRememberMeManager(); meManager.setCipherKey(Base64.decode("4AvVhmFLUs0KTA3Kprsdag==")); meManager.setCookie(getRememberMeCookie()); return meManager; } @Bean(name = "sessionManager") public DefaultWebSessionManager getSessionManage() { DefaultWebSessionManager sessionManager = new DefaultWebSessionManager(); sessionManager.setGlobalSessionTimeout(1800000); sessionManager.setSessionValidationScheduler(getExecutorServiceSessionValidationScheduler()); sessionManager.setSessionValidationSchedulerEnabled(true); sessionManager.setDeleteInvalidSessions(true); sessionManager.setSessionIdCookieEnabled(true); sessionManager.setSessionIdCookie(getSessionIdCookie()); RedisSessionDao cacheSessionDAO = new RedisSessionDao(); cacheSessionDAO.setCacheManager(getCacheManage()); sessionManager.setSessionDAO(cacheSessionDAO); // -----可以添加session 创建、删除的监听器 return sessionManager; } @Bean(name = "myRealm") public AuthorizingRealm getShiroRealm() { MyRealm realm = new MyRealm(); // realm.setName("shiro_auth_cache"); // realm.setAuthenticationCache(getCacheManage().getCache(realm.getName())); // realm.setAuthenticationTokenClass(UserAuthenticationToken.class); return realm; } @Bean(name = "securityManager") public DefaultWebSecurityManager getSecurityManager() { DefaultWebSecurityManager securityManager = new DefaultWebSecurityManager(); securityManager.setCacheManager(getCacheManage()); securityManager.setSessionManager(getSessionManage()); securityManager.setRememberMeManager(getRememberManager()); securityManager.setRealm(getShiroRealm()); return securityManager; } @Bean public MethodInvokingFactoryBean getMethodInvokingFactoryBean(){ MethodInvokingFactoryBean factoryBean = new MethodInvokingFactoryBean(); factoryBean.setStaticMethod("org.apache.shiro.SecurityUtils.setSecurityManager"); factoryBean.setArguments(new Object[]{getSecurityManager()}); return factoryBean; } @Bean @DependsOn("lifecycleBeanPostProcessor") public DefaultAdvisorAutoProxyCreator getAutoProxyCreator(){ DefaultAdvisorAutoProxyCreator creator = new DefaultAdvisorAutoProxyCreator(); creator.setProxyTargetClass(true); return creator; } @Bean public AuthorizationAttributeSourceAdvisor getAuthorizationAttributeSourceAdvisor(){ AuthorizationAttributeSourceAdvisor advisor = new AuthorizationAttributeSourceAdvisor(); advisor.setSecurityManager(getSecurityManager()); return advisor; } /** * @return */ @Bean(name = "shiroFilter") public ShiroFilterFactoryBean getShiroFilterFactoryBean(){ ShiroFilterFactoryBean factoryBean = new ShiroFilterFactoryBean(); factoryBean.setSecurityManager(getSecurityManager()); factoryBean.setLoginUrl("/toLogin"); factoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap); return factoryBean; } }
优化:伪定时消除map,最好配合quartz清楚,不然内存中MAP如果不访问就不消除,容易累计。
package com.result.security; import java.util.Collection; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import java.util.Set; import com.result.NettyGoConstant; /** * @author 作者 huangxinyu * @version 创建时间:2018年1月29日 上午10:31:50 类说明 */ public class ExpiryMap<K, V> extends HashMap<K, V> { private static final long serialVersionUID = 1L; /** * default expiry time 2m */ private long EXPIRY = NettyGoConstant.LOGINSESSIONTIMEOUT; private HashMap<K, Long> expiryMap = new HashMap<>(); public ExpiryMap() { super(); } public ExpiryMap(long defaultExpiryTime) { this(1 << 4, defaultExpiryTime); } public ExpiryMap(int initialCapacity, long defaultExpiryTime) { super(initialCapacity); this.EXPIRY = defaultExpiryTime; } public V put(K key, V value) { expiryMap.put(key, System.currentTimeMillis() + EXPIRY); return super.put(key, value); } public boolean containsKey(Object key) { return !checkExpiry(key, true) && super.containsKey(key); } /** * @param key * @param value * @param expiryTime * 键值对有效期 毫秒 * @return */ public V put(K key, V value, long expiryTime) { expiryMap.put(key, System.currentTimeMillis() + expiryTime); return super.put(key, value); } public int size() { return entrySet().size(); } public boolean isEmpty() { return entrySet().size() == 0; } public boolean containsValue(Object value) { if (value == null) return Boolean.FALSE; Set<java.util.Map.Entry<K, V>> set = super.entrySet(); Iterator<java.util.Map.Entry<K, V>> iterator = set.iterator(); while (iterator.hasNext()) { java.util.Map.Entry<K, V> entry = iterator.next(); if (value.equals(entry.getValue())) { if (checkExpiry(entry.getKey(), false)) { iterator.remove(); return Boolean.FALSE; } else return Boolean.TRUE; } } return Boolean.FALSE; } public Collection<V> values() { Collection<V> values = super.values(); if (values == null || values.size() < 1) return values; Iterator<V> iterator = values.iterator(); while (iterator.hasNext()) { V next = iterator.next(); if (!containsValue(next)) iterator.remove(); } return values; } public V get(Object key) { if (key == null) return null; if (checkExpiry(key, true)) return null; return super.get(key); } /** * * @Description: 是否过期 * @param key * @return null:不存在或key为null -1:过期 存在且没过期返回value 因为过期的不是实时删除,所以稍微有点作用 */ public Object isInvalid(Object key) { if (key == null) return null; if (!expiryMap.containsKey(key)) { return null; } long expiryTime = expiryMap.get(key); boolean flag = System.currentTimeMillis() > expiryTime; if (flag) { super.remove(key); expiryMap.remove(key); return -1; } return super.get(key); } public void putAll(Map<? extends K, ? extends V> m) { for (Map.Entry<? extends K, ? extends V> e : m.entrySet()) expiryMap.put(e.getKey(), System.currentTimeMillis() + EXPIRY); super.putAll(m); } public Set<Map.Entry<K, V>> entrySet() { Set<java.util.Map.Entry<K, V>> set = super.entrySet(); Iterator<java.util.Map.Entry<K, V>> iterator = set.iterator(); while (iterator.hasNext()) { java.util.Map.Entry<K, V> entry = iterator.next(); if (checkExpiry(entry.getKey(), false)) iterator.remove(); } return set; } /** * * @Description: 是否过期 * @author: qd-ankang * @date: 2016-11-24 下午4:05:02 * @param expiryTime * true 过期 * @param isRemoveSuper * true super删除 * @return */ private boolean checkExpiry(Object key, boolean isRemoveSuper) { if (!expiryMap.containsKey(key)) { return Boolean.FALSE; } long expiryTime = expiryMap.get(key); boolean flag = System.currentTimeMillis() > expiryTime; if (flag) { if (isRemoveSuper) super.remove(key); expiryMap.remove(key); } return flag; } /** * 删除 * @param key */ public void del(Object key){ super.remove(key); expiryMap.remove(key); } public static void main(String[] args) throws InterruptedException { ExpiryMap<String, String> map = new ExpiryMap<>(10); map.put("test", "ankang"); map.put("test1", "ankang"); map.put("test2", "ankang", 3000); System.out.println("test1" + map.get("test")); Thread.sleep(1000); System.out.println("isInvalid:" + map.isInvalid("test")); System.out.println("size:" + map.size()); System.out.println("size:" + ((HashMap<String, String>) map).size()); for (Map.Entry<String, String> m : map.entrySet()) { System.out.println("isInvalid:" + map.isInvalid(m.getKey())); map.containsKey(m.getKey()); System.out.println("key:" + m.getKey() + " value:" + m.getValue()); } System.out.println("test1" + map.get("test")); } /** * 是否超过过期的一半时间 * @param key * @return */ public boolean isHalfExpiryTime(Object key ){ if (!expiryMap.containsKey(key)) { return false; } long expiryTime = expiryMap.get(key); boolean flag = System.currentTimeMillis()-(expiryTime-NettyGoConstant.LOGINSESSIONTIMEOUT)>=NettyGoConstant.LOGINSESSIONTIMEOUT/2; return flag; } }