spring自定义session分布式session

spring实现自定义session、springboot实现自定义session、自定义sessionid的key、value、实现分布式会话

一、原始方案

自定义生成sessionid的值

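修改 Tomcat 中 org.apache.catalina.util.SessionIdGeneratorBase(默认实现类为 StandardSessionIdGenerator)的 session id 生成方法。与其直接修改 Tomcat 源码,更稳妥的做法是继承该类重写生成逻辑,再通过 Manager#setSessionIdGenerator 注册自定义的生成器。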

/**
 * Generates and returns a new session identifier by delegating to
 * {@code generateSessionId(String)} with this generator's {@code jvmRoute}.
 *
 * @return the newly generated session id
 */
@Override
public String generateSessionId() {
    final String route = jvmRoute;
    return generateSessionId(route);
}

二、使用spring-session框架

maven

<dependency>
    <groupId>org.springframework.session</groupId>
    <artifactId>spring-session-core</artifactId>
</dependency>

<dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.session</groupId>
            <artifactId>spring-session-bom</artifactId>
            <version>Corn-SR2</version>
            <type>pom</type>
            <scope>import</scope>
        </dependency>
    </dependencies>
</dependencyManagement>

自定义 session id 的生成方式:

import org.springframework.session.MapSession;
import org.springframework.session.MapSessionRepository;
import org.springframework.session.Session;

import java.time.Duration;
import java.util.Map;

/**
 * @Author 绫小路
 * @Date 2021/3/10
 * @Description Extends MapSessionRepository, so sessions are stored in the map passed to the
 * constructor; new session ids are produced by {@link #generateId()}, which subclasses may override.
 */
public class MySessionRepository extends MapSessionRepository {
  private Integer defaultMaxInactiveInterval;

  public MySessionRepository(Map<String, Session> sessions) {
    super(sessions);
  }

  /**
   * Sets the idle timeout applied to sessions created by {@link #createSession()}.
   * While never set, new sessions keep the {@link MapSession} default.
   *
   * @param defaultMaxInactiveInterval idle timeout in seconds
   */
  public void setDefaultMaxInactiveInterval(int defaultMaxInactiveInterval) {
    this.defaultMaxInactiveInterval = defaultMaxInactiveInterval;
  }

  /**
   * Creates a new session whose id comes from {@link #generateId()}.
   *
   * @return the new, not yet saved session
   */
  @Override
  public MapSession createSession() {
    MapSession result = new MapSession(generateId());

    if (this.defaultMaxInactiveInterval != null) {
      result.setMaxInactiveInterval(Duration.ofSeconds(this.defaultMaxInactiveInterval));
    }
    return result;
  }

  /**
   * Generates the id of a new session; override to plug in a different custom format.
   *
   * <p>The id is the only credential the client presents, so it must be unique and unguessable.
   * The millisecond timestamp prefix keeps the custom, readable part (visible after Base64-decoding
   * the cookie); the random UUID suffix ({@code UUID.randomUUID()} uses a cryptographically strong
   * generator) keeps two sessions created in the same millisecond from sharing an id and makes the
   * id impossible to predict.
   *
   * @return a new session id
   */
  protected String generateId() {
    return System.currentTimeMillis() + "-" + java.util.UUID.randomUUID();
  }
}

配置

import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.session.Session;
import org.springframework.session.SessionRepository;
import org.springframework.session.config.annotation.web.http.EnableSpringHttpSession;
import org.springframework.session.web.http.CookieSerializer;
import org.springframework.session.web.http.DefaultCookieSerializer;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @Author 绫小路
 * @Date 2021/3/10
 * Spring Session setup: sessions live in the in-memory {@link #sessions} map and the session
 * cookie is written by the explicitly defined {@link CookieSerializer} bean below.
 */
@EnableSpringHttpSession
@Configuration
public class MyConfig {
  public static Map<String, Session> sessions = new ConcurrentHashMap<>();

  /**
   * DefaultCookieSerializer uses the first capture group of the server name as the cookie domain,
   * e.g. "www.example.com" yields "example.com", so the cookie is shared by all subdomains.
   */
  private static final String PARENT_DOMAIN_PATTERN = "^.+?\\.(\\w+\\.[a-z]+)$";

  /** Session repository backed by the shared {@link #sessions} map. */
  @Bean
  public SessionRepository mySessionRepository() {
    MySessionRepository repository = new MySessionRepository(sessions);
    return repository;
  }

  /**
   * Session cookie settings. Because this bean is defined explicitly, these are the settings
   * actually applied to the session cookie.
   */
  @Bean
  public CookieSerializer cookieSerializer() {
    DefaultCookieSerializer cookieSerializer = new DefaultCookieSerializer();
    cookieSerializer.setCookieName("JSESSIONID");
    cookieSerializer.setCookiePath("/");
    // Share the cookie across subdomains (this is not CORS).
    cookieSerializer.setDomainNamePattern(PARENT_DOMAIN_PATTERN);
    // DefaultCookieSerializer Base64-encodes the cookie value by default; write the raw id instead.
    cookieSerializer.setUseBase64Encoding(false);
    return cookieSerializer;
  }
}

application.properties(注意:上面已经显式定义了 CookieSerializer Bean,cookie 名称以 Bean 中配置的 JSESSIONID 为准,此处配置的 aa 不会生效)

server.servlet.session.cookie.name=aa

效果:
[图:运行效果截图]

覆盖 CookieSerializer @Bean 后的效果:
[图:运行效果截图]

三、通过包装请求会话进行高度自定义

原理是在过滤器中包装请求对象,用自定义的会话替换容器自带的会话,从而完全掌控会话的创建、存储与淘汰,可以自由地进行定制开发。spring-session 的原理同样是包装请求,因此可以自由选择会话的存储位置,例如内存、Redis、数据库或其他 NoSQL 存储。

首先实现 HttpSession 接口,并让它可以序列化(实现 Serializable),其中添加了一个可以传入自定义 id 的构造方法:

package top.lingkang.testdemo;

import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

import javax.servlet.ServletContext;
import javax.servlet.http.HttpSession;
import javax.servlet.http.HttpSessionBindingEvent;
import javax.servlet.http.HttpSessionBindingListener;
import javax.servlet.http.HttpSessionContext;
import java.io.Serializable;
import java.util.*;

/**
 * A {@link Serializable} {@link HttpSession} whose whole state (id, timestamps, attributes) lives
 * in this object, so a filter can keep it in a map or write it to an external store such as Redis.
 *
 * <p>For the session to serialize, every attribute value must be {@link Serializable} as well. The
 * {@link ServletContext} is {@code transient} and is therefore {@code null} after deserialization.
 * The max inactive interval is only stored here; enforcing it is up to whoever stores the session.
 * This class performs no synchronization.
 *
 * @author lingkang
 * Created by 2022/1/24
 */
public class MyHttpSession implements HttpSession, Serializable {
    private static final long serialVersionUID = 1L;

    private String id;
    private final long creationTime;
    private int maxInactiveInterval;
    private long lastAccessedTime;
    // transient: ServletContext is not Serializable; keeping it would make any session created
    // with a context impossible to write to Redis or any other serializing store.
    private final transient ServletContext servletContext;
    private final Map<String, Object> attributes;
    private boolean invalid;
    private boolean isNew;

    /**
     * Creates a session with the given id and no servlet context.
     *
     * @param id the session id; a random id is generated when {@code null}
     */
    public MyHttpSession(String id) {
        this((ServletContext) null, id);
    }

    /** Creates a session with a random id and no servlet context. */
    public MyHttpSession() {
        this((ServletContext) null);
    }

    /**
     * Creates a session with a random id.
     *
     * @param servletContext the owning servlet context, may be {@code null}
     */
    public MyHttpSession(@Nullable ServletContext servletContext) {
        this(servletContext, (String) null);
    }

    /**
     * Creates a new, not yet accessed session.
     *
     * @param servletContext the owning servlet context, may be {@code null}
     * @param id             the session id; a random id is generated when {@code null}
     */
    public MyHttpSession(@Nullable ServletContext servletContext, @Nullable String id) {
        long now = System.currentTimeMillis();
        this.creationTime = now;
        this.lastAccessedTime = now;
        this.attributes = new LinkedHashMap<>();
        this.invalid = false;
        this.isNew = true;
        this.servletContext = servletContext;
        this.id = id != null ? id : generateId();
    }

    /**
     * Returns a new random session id. Session ids act as bearer credentials, so they must be
     * unique across threads and JVMs and unguessable; {@code UUID.randomUUID()} draws from a
     * cryptographically strong generator.
     */
    private static String generateId() {
        return UUID.randomUUID().toString();
    }

    public long getCreationTime() {
        this.assertIsValid();
        return this.creationTime;
    }

    public String getId() {
        return this.id;
    }

    /**
     * Replaces this session's id with a fresh random id and returns it. Whoever stores the session
     * is responsible for re-keying the store and the client cookie that still hold the old id.
     *
     * @return the new session id
     */
    public String changeSessionId() {
        this.id = generateId();
        return this.id;
    }

    /** Records an access: updates the last-accessed time and clears the "new" flag. */
    public void access() {
        this.lastAccessedTime = System.currentTimeMillis();
        this.isNew = false;
    }

    public long getLastAccessedTime() {
        this.assertIsValid();
        return this.lastAccessedTime;
    }

    /** Returns the servlet context given at construction; {@code null} after deserialization. */
    public ServletContext getServletContext() {
        return this.servletContext;
    }

    /** Stores the idle timeout in seconds; this class does not enforce it. */
    public void setMaxInactiveInterval(int interval) {
        this.maxInactiveInterval = interval;
    }

    public int getMaxInactiveInterval() {
        return this.maxInactiveInterval;
    }

    public HttpSessionContext getSessionContext() {
        throw new UnsupportedOperationException("getSessionContext");
    }

    public Object getAttribute(String name) {
        this.assertIsValid();
        Assert.notNull(name, "Attribute name must not be null");
        return this.attributes.get(name);
    }

    public Object getValue(String name) {
        return this.getAttribute(name);
    }

    public Enumeration<String> getAttributeNames() {
        this.assertIsValid();
        // Snapshot the names so callers may modify the session while enumerating.
        return Collections.enumeration(new LinkedHashSet<>(this.attributes.keySet()));
    }

    public String[] getValueNames() {
        this.assertIsValid();
        return StringUtils.toStringArray(this.attributes.keySet());
    }

    /**
     * Binds {@code value} under {@code name}, notifying {@link HttpSessionBindingListener}s of the
     * replaced and the new value; a {@code null} value removes the attribute.
     */
    public void setAttribute(String name, @Nullable Object value) {
        this.assertIsValid();
        Assert.notNull(name, "Attribute name must not be null");
        if (value != null) {
            Object oldValue = this.attributes.put(name, value);
            if (value != oldValue) {
                if (oldValue instanceof HttpSessionBindingListener) {
                    ((HttpSessionBindingListener) oldValue).valueUnbound(new HttpSessionBindingEvent(this, name, oldValue));
                }

                if (value instanceof HttpSessionBindingListener) {
                    ((HttpSessionBindingListener) value).valueBound(new HttpSessionBindingEvent(this, name, value));
                }
            }
        } else {
            this.removeAttribute(name);
        }

    }

    public void putValue(String name, Object value) {
        this.setAttribute(name, value);
    }

    public void removeAttribute(String name) {
        this.assertIsValid();
        Assert.notNull(name, "Attribute name must not be null");
        Object value = this.attributes.remove(name);
        if (value instanceof HttpSessionBindingListener) {
            ((HttpSessionBindingListener) value).valueUnbound(new HttpSessionBindingEvent(this, name, value));
        }

    }

    public void removeValue(String name) {
        this.removeAttribute(name);
    }

    /** Removes every attribute, notifying {@link HttpSessionBindingListener} values. */
    public void clearAttributes() {
        Iterator<Map.Entry<String, Object>> it = this.attributes.entrySet().iterator();

        while (it.hasNext()) {
            Map.Entry<String, Object> entry = it.next();
            String name = entry.getKey();
            Object value = entry.getValue();
            it.remove();
            if (value instanceof HttpSessionBindingListener) {
                ((HttpSessionBindingListener) value).valueUnbound(new HttpSessionBindingEvent(this, name, value));
            }
        }

    }

    /** Marks the session invalid and clears its attributes; later accessors throw. */
    public void invalidate() {
        this.assertIsValid();
        this.invalid = true;
        this.clearAttributes();
    }

    public boolean isInvalid() {
        return this.invalid;
    }

    private void assertIsValid() {
        Assert.state(!this.isInvalid(), "The session has already been invalidated");
    }

    public void setNew(boolean value) {
        this.isNew = value;
    }

    public boolean isNew() {
        this.assertIsValid();
        return this.isNew;
    }

    /**
     * Moves the attributes out of this session: serializable values are returned in a map,
     * non-serializable {@link HttpSessionBindingListener}s are notified of unbinding, and all other
     * non-serializable values are dropped. The session is left without attributes.
     *
     * @return a {@link HashMap} of the serializable attributes
     */
    public Serializable serializeState() {
        HashMap<String, Serializable> state = new HashMap<>();
        Iterator<Map.Entry<String, Object>> it = this.attributes.entrySet().iterator();

        while (it.hasNext()) {
            Map.Entry<String, Object> entry = it.next();
            String name = entry.getKey();
            Object value = entry.getValue();
            it.remove();
            if (value instanceof Serializable) {
                state.put(name, (Serializable) value);
            } else if (value instanceof HttpSessionBindingListener) {
                ((HttpSessionBindingListener) value).valueUnbound(new HttpSessionBindingEvent(this, name, value));
            }
        }

        return state;
    }

    /**
     * Adds the attributes captured by {@link #serializeState()} back into this session.
     *
     * @param state a {@link Map} of attribute name to value
     * @throws IllegalArgumentException if {@code state} is not a {@link Map}
     */
    @SuppressWarnings("unchecked")
    public void deserializeState(Serializable state) {
        Assert.isTrue(state instanceof Map, "Serialized state needs to be of type [java.util.Map]");
        this.attributes.putAll((Map<String, Object>) state);
    }
}

再创建一个请求包装类MyServletRequestWrapper

package top.lingkang.testdemo;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpSession;

/**
 * Request wrapper that exposes the session installed by the session filter instead of the
 * container's own session.
 *
 * @author lingkang
 * Created by 2022/1/24
 */
public class MyServletRequestWrapper extends HttpServletRequestWrapper {

    private HttpSession session;

    public MyServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    /**
     * Returns the session installed via {@link #setSession(HttpSession)}.
     *
     * @return the installed session, or {@code null} if none was set
     */
    @Override
    public HttpSession getSession() {
        return this.session;
    }

    /**
     * Returns the session installed via {@link #setSession(HttpSession)}. This must be overridden
     * as well as {@link #getSession()}: the base wrapper forwards {@code getSession(boolean)} to the
     * wrapped request, so callers using {@code getSession(true)} / {@code getSession(false)} would
     * otherwise silently get (or create) the container's session instead of the custom one.
     *
     * @param create ignored; this wrapper never creates sessions itself
     * @return the installed session, or {@code null} if none was set
     */
    @Override
    public HttpSession getSession(boolean create) {
        return this.session;
    }

    /**
     * Installs the session that both {@code getSession} variants return.
     *
     * @param session the session to expose
     */
    public void setSession(HttpSession session) {
        this.session = session;
    }
}

最后在过滤器(Filter)中把请求替换为包装类。注意,应将该过滤器放在过滤器链的最前面,确保后续的过滤器和业务代码拿到的都是包装后的请求。

package top.lingkang.testdemo;

import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

/**
 * Servlet filter that replaces the container session with a {@link MyHttpSession} kept in an
 * in-process map, keyed by the value of the {@value #COOKIE_NAME} cookie.
 *
 * <p>Sessions are never evicted here, so the map grows with every new visitor; add an expiry
 * policy (or use an external store such as Redis) before relying on this in production.
 *
 * @author lingkang
 * Created by 2022/1/24
 */
@Component
public class ClusterSessionFilter implements Filter {
    private static final String COOKIE_NAME = "custom-cookie-name";

    // The filter is a singleton shared by all request threads, so the store must be thread-safe.
    private final Map<String, MyHttpSession> sessionMap = new java.util.concurrent.ConcurrentHashMap<>();

    /**
     * Resolves (or creates) the caller's session, exposes it through a request wrapper for the rest
     * of the chain, and writes it back to the store afterwards, even if the chain throws.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        MyHttpSession session = findSession(getCookieValue(COOKIE_NAME, request.getCookies()));
        if (session == null) {
            // New visitor (or stale/invalidated cookie): issue a fresh random id.
            session = new MyHttpSession(UUID.randomUUID().toString());
            response.addCookie(newSessionCookie(session.getId()));
        }
        String originalId = session.getId();

        MyServletRequestWrapper wrappedRequest = new MyServletRequestWrapper(request);
        wrappedRequest.setSession(session);

        try {
            filterChain.doFilter(wrappedRequest, servletResponse);
        } finally {
            saveSession(session, originalId, response);
        }
    }

    /**
     * Looks up a live session and records the access.
     *
     * @param id session id from the cookie, may be {@code null}
     * @return the session, or {@code null} if the id is absent, unknown, or the session was invalidated
     */
    private MyHttpSession findSession(String id) {
        if (id == null) {
            return null;
        }
        MyHttpSession session = sessionMap.get(id);
        if (session == null) {
            return null;
        }
        if (session.isInvalid()) {
            // Every accessor of an invalidated session throws, so treat it as gone.
            sessionMap.remove(id, session);
            return null;
        }
        session.access();
        return session;
    }

    /**
     * Writes the session back after the request. An invalidated session is removed. If the id was
     * changed during the request, the old key is removed and the cookie is re-issued while the
     * response is still uncommitted; once committed the client keeps the old id and loses the session.
     */
    private void saveSession(MyHttpSession session, String originalId, HttpServletResponse response) {
        String currentId = session.getId();
        if (session.isInvalid()) {
            sessionMap.remove(originalId);
            sessionMap.remove(currentId);
            return;
        }
        if (!currentId.equals(originalId)) {
            sessionMap.remove(originalId);
            if (!response.isCommitted()) {
                response.addCookie(newSessionCookie(currentId));
            }
        }
        sessionMap.put(currentId, session);
    }

    /** Builds the session cookie: site-wide path, and HttpOnly so page scripts cannot read the id. */
    private static Cookie newSessionCookie(String id) {
        Cookie cookie = new Cookie(COOKIE_NAME, id);
        cookie.setPath("/");
        cookie.setHttpOnly(true);
        return cookie;
    }

    /** Returns the value of the first cookie called {@code name}, or {@code null} if there is none. */
    private String getCookieValue(String name, Cookie[] cookies) {
        if (cookies == null)
            return null;
        for (Cookie cookie : cookies) {
            if (name.equals(cookie.getName())) {
                return cookie.getValue();
            }
        }
        return null;
    }
}

需要注意的是,上面的方案并没有实现 session 淘汰机制。由于会话存储在内存中,若不淘汰过期会话,内存占用会随访问量不断增长,最终导致内存溢出。

如果将会话存储到redis可以这样,即分布式会话存储方案:

package top.lingkang.testdemo;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

/**
 * Servlet filter that keeps {@link MyHttpSession} objects in Redis so that every node of a cluster
 * sees the same sessions (distributed session). The session id travels in the
 * {@value #COOKIE_NAME} cookie, and the Redis key TTL doubles as the session eviction policy.
 *
 * @author lingkang
 * Created by 2022/1/24
 */
@Component
public class ClusterSessionFilter implements Filter {
    private static final String COOKIE_NAME = "custom-cookie-name";

    // Namespace for session keys: the cookie value is client-controlled, so without a prefix a
    // client could make this filter read any key of the shared Redis instance.
    private static final String KEY_PREFIX = "custom-session:";

    // TTL used when the session has no positive max inactive interval: 30 minutes, the same as
    // the 1800000 ms this filter used before.
    private static final long DEFAULT_TIMEOUT_SECONDS = 1800;

    @Autowired
    private RedisTemplate redisTemplate;

    /**
     * Resolves (or creates) the caller's session, exposes it through a request wrapper for the rest
     * of the chain, and writes it back to Redis afterwards, even if the chain throws.
     */
    @Override
    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) servletRequest;
        HttpServletResponse response = (HttpServletResponse) servletResponse;

        MyHttpSession session = findSession(getCookieValue(COOKIE_NAME, request.getCookies()));
        if (session == null) {
            // New visitor (or expired/invalidated session): issue a fresh random id.
            session = new MyHttpSession(UUID.randomUUID().toString());
            response.addCookie(newSessionCookie(session.getId()));
        }
        String originalId = session.getId();

        MyServletRequestWrapper wrappedRequest = new MyServletRequestWrapper(request);
        wrappedRequest.setSession(session);

        try {
            filterChain.doFilter(wrappedRequest, servletResponse);
        } finally {
            saveSession(session, originalId, response);
        }
    }

    /**
     * Loads a live session from Redis and records the access.
     *
     * @param id session id from the cookie, may be {@code null}
     * @return the session, or {@code null} if the id is absent, the key is missing or holds
     *         something other than a {@link MyHttpSession}, or the session was invalidated
     */
    private MyHttpSession findSession(String id) {
        if (id == null) {
            return null;
        }
        Object stored = redisTemplate.opsForValue().get(KEY_PREFIX + id);
        if (!(stored instanceof MyHttpSession)) {
            return null;
        }
        MyHttpSession session = (MyHttpSession) stored;
        if (session.isInvalid()) {
            // Every accessor of an invalidated session throws, so treat it as gone.
            redisTemplate.delete(KEY_PREFIX + id);
            return null;
        }
        session.access();
        return session;
    }

    /**
     * Writes the session back to Redis with a fresh TTL. An invalidated session is deleted. If the
     * id was changed during the request, the old key is deleted and the cookie is re-issued while
     * the response is still uncommitted; once committed the client keeps the old id and loses the
     * session.
     */
    private void saveSession(MyHttpSession session, String originalId, HttpServletResponse response) {
        String currentId = session.getId();
        if (session.isInvalid()) {
            redisTemplate.delete(KEY_PREFIX + originalId);
            if (!currentId.equals(originalId)) {
                redisTemplate.delete(KEY_PREFIX + currentId);
            }
            return;
        }
        if (!currentId.equals(originalId)) {
            redisTemplate.delete(KEY_PREFIX + originalId);
            if (!response.isCommitted()) {
                response.addCookie(newSessionCookie(currentId));
            }
        }
        // Store in Redis; the key expires after the session's idle timeout.
        long timeoutSeconds = session.getMaxInactiveInterval() > 0
                ? session.getMaxInactiveInterval()
                : DEFAULT_TIMEOUT_SECONDS;
        redisTemplate.opsForValue().set(KEY_PREFIX + currentId, session, timeoutSeconds, TimeUnit.SECONDS);
    }

    /** Builds the session cookie: site-wide path, and HttpOnly so page scripts cannot read the id. */
    private static Cookie newSessionCookie(String id) {
        Cookie cookie = new Cookie(COOKIE_NAME, id);
        cookie.setPath("/");
        cookie.setHttpOnly(true);
        return cookie;
    }

    /** Returns the value of the first cookie called {@code name}, or {@code null} if there is none. */
    private String getCookieValue(String name, Cookie[] cookies) {
        if (cookies == null)
            return null;
        for (Cookie cookie : cookies) {
            if (name.equals(cookie.getName())) {
                return cookie.getValue();
            }
        }
        return null;
    }

}

Redis 的过期机制相当于 session 淘汰机制,但同时引入了一个新问题,即极端情况下的“空读”问题:例如某个 get 请求需要执行 3 分钟,而 session 在 1 分钟后就已到期,等请求执行完再回写会话时,session 其实已经被淘汰了。解决方案:获取会话前先判断 session 的剩余存活时间,若剩余时间少于 5 分钟(这个阈值应大于最长请求的执行时间),则直接淘汰该会话,让用户重新登录。合理地设置超时时间同样重要;将会话存储在其他介质时也要考虑这一极端情况。
下面是一个普通的访问示例:

@RestController
public class WebController {
    @Autowired
    private HttpServletRequest request;

    /**
     * Stores the current time (epoch milliseconds) in the session under the name "a".
     *
     * @return a fixed confirmation message
     */
    @GetMapping("")
    public Object index() {
        HttpSession session = request.getSession();
        long now = System.currentTimeMillis();
        session.setAttribute("a", now);
        return "create session";
    }

    /**
     * Reads back the value that {@link #index()} stored under "a".
     *
     * @return the stored value, or {@code null} if nothing is bound under "a"
     */
    @GetMapping("get")
    public Object get() {
        return request.getSession().getAttribute("a");
    }
}

后记:通过上面的会话存储方式即可搭建分布式集群。理论上,应用集群的扩展上限取决于 Redis 集群的读写上限:假设 Redis 的写并发为 10w/s,那么应用集群的并发处理能力也能达到 10w/s。
若对 session 进一步优化,去掉每次请求都回写最后访问时间的操作,就变成了读多写少的场景,理论上集群可以近乎无限地扩展。
若使用数据库存储,可以将会话序列化为二进制后再存储。
基于最后一点原理,我开发了分布式会话框架:
https://gitee.com/lingkang_top/final-session

posted @ 2022-09-16 00:09  凌康  阅读(317)  评论(0编辑  收藏  举报