Mybatis拦截器小运用-分页拦截器

这是一个仅支持MySQL的简单的Mybatis拦截器小运用。

原理

利用Mybatis的拦截器在sql执行之前把sql取出来,添加上分页语法,再把sql赋值回去。

  • 利用ThreadLocal在线程内传递页码和每页大小参数,减少对原有代码的改动
  • 利用反射把修改后的sql 赋值回去

拦截器源码

/**
 * MyBatis plugin that appends a MySQL {@code LIMIT offset,size} clause to the SQL of a
 * statement being prepared, when {@link PageHelper} holds paging parameters for the
 * current thread.
 *
 * <p>Before rewriting the SQL it runs a {@code select count(*)} over the original query
 * and stores the result through {@link PageHelper#setTotal(Integer)}.
 *
 * <p>NOTE(review): the interceptor does not distinguish SELECT from other statement
 * kinds; callers should only enable paging around queries.
 */
@Component
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class,Integer.class})})
public class MyPageInterceptor implements Interceptor {
    private static final Logger logger = LoggerFactory.getLogger(MyPageInterceptor.class);

    /**
     * Intercepts {@code StatementHandler.prepare}; rewrites the bound SQL with a LIMIT
     * clause when paging is active, then lets the invocation proceed.
     *
     * @param invocation the intercepted call; its first argument is the JDBC connection
     * @return whatever the intercepted {@code prepare} call returns
     * @throws Throwable anything raised by the count query, the reflection write-back,
     *                   or the intercepted method itself
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Cast to the interface, not RoutingStatementHandler: when several plugins are
        // chained the target can be a proxy that only implements StatementHandler.
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = handler.getBoundSql();
        Object param = boundSql.getParameterObject();
        // Original SQL as produced by the mapper.
        String sql = boundSql.getSql();
        logger.info("拦截到 sql:{} 参数:{}", sql, param);

        // Nothing to do unless paging was requested on this thread.
        if (!PageHelper.isPage()) {
            return invocation.proceed();
        }

        // Total row count of the un-paged query.
        int count = getCount(invocation, sql);
        PageHelper.setTotal(count);

        int pageNo = PageHelper.getPageNo();
        int pageSize = PageHelper.getPageSize();
        // Compute in long to avoid int overflow and clamp at 0 so that a page number
        // below 1 cannot produce a negative (invalid) LIMIT offset.
        long offset = Math.max(0L, (long) (pageNo - 1) * pageSize);

        // Build the paged SQL and write it back into the BoundSql's private "sql" field.
        String pagedSql = sql + " limit " + offset + "," + pageSize;
        ReflectUtil.setValueByFieldName(boundSql, "sql", pagedSql);
        logger.info("分页后 sql:{} 参数:{}", boundSql.getSql(), param);

        return invocation.proceed();
    }

    /**
     * Wraps the target with this interceptor using MyBatis' {@code Plugin.wrap}.
     *
     * @param o the candidate target object
     * @return the object returned by {@code Plugin.wrap}
     */
    @Override
    public Object plugin(Object o) {
        return Plugin.wrap(o, this);
    }

    /**
     * Logs the properties handed to the plugin; they are not otherwise used.
     *
     * @param properties plugin properties, may be {@code null}
     */
    @Override
    public void setProperties(Properties properties) {
        // String.valueOf keeps the same output for non-null input and avoids an NPE on null.
        logger.warn(String.valueOf(properties));
    }

    /**
     * Runs {@code select count(*)} over the given SQL on the invocation's connection,
     * binding the same parameters as the original statement.
     *
     * @param invocation the intercepted call supplying the handler and connection
     * @param sql        the original (un-paged) SQL
     * @return the row count, or 0 if the count query returned no row
     * @throws SQLException if preparing, binding or executing the count query fails
     */
    private Integer getCount(Invocation invocation,String sql) throws SQLException {
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        Connection connection = (Connection) invocation.getArgs()[0];
        String countSql = "select count(*) from (" + sql + ") count";
        // try-with-resources closes the statement and result set even when binding or
        // execution throws; the connection itself is owned by MyBatis and left open.
        try (PreparedStatement ps = connection.prepareStatement(countSql)) {
            handler.getParameterHandler().setParameters(ps);
            try (ResultSet rs = ps.executeQuery()) {
                return rs.next() ? rs.getInt(1) : 0;
            }
        }
    }

}

PageHelper源码

 
/**
 * Holds the paging request ({@link PageInfo}) of the current thread so that it can be
 * passed between the caller and code running later on the same thread without changing
 * method signatures.
 *
 * <p>Usage as shown by this class: call {@link #startPage(Integer, Integer)} before the
 * query and {@link #getPageInfo(Object)} afterwards; the latter clears the thread-local.
 * The accessor methods other than {@link #isPage()} throw {@link NullPointerException}
 * when no paging request is active on the thread.
 */
public class PageHelper {

    /** Paging state of the current thread; absent when no paged query is in progress. */
    static final ThreadLocal<PageInfo> pageInfo = new ThreadLocal<>();

    /**
     * Starts a paging request on the current thread, replacing any previous one.
     *
     * @param pageNo   page number stored into the new {@link PageInfo}
     * @param pageSize page size stored into the new {@link PageInfo}
     */
    public static void startPage(Integer pageNo, Integer pageSize) {
        PageInfo p = new PageInfo();
        p.setPageNo(pageNo);
        p.setPageSize(pageSize);
        pageInfo.set(p);
    }

    /**
     * Finishes the current thread's paging request: attaches the query result to the
     * stored {@link PageInfo}, clears the thread-local and returns the info.
     *
     * @param data the query result to attach
     * @return the completed paging info
     * @throws ServiceException if no paging request is active on this thread
     */
    public static PageInfo getPageInfo(Object data) throws ServiceException {
        PageInfo p = pageInfo.get();
        if (p == null) {
            throw new ServiceException("此线程不存在分页查询。");
        }
        p.setData(data);
        // remove() rather than set(null): it drops the entry from the thread's map, so
        // pooled threads do not keep a stale slot around.
        pageInfo.remove();
        return p;
    }

    /**
     * Records the total row count on the active paging request.
     *
     * @param count total number of rows
     */
    public static void setTotal(Integer count) {
        pageInfo.get().setTotal(count);
    }

    /**
     * @return {@code true} if a paging request is active on the current thread
     */
    public static boolean isPage() {
        return pageInfo.get() != null;
    }

    /**
     * @return the page number of the active paging request
     */
    public static Integer getPageNo(){
        return pageInfo.get().getPageNo();
    }

    /**
     * @return the page size of the active paging request
     */
    public static Integer getPageSize(){
        return pageInfo.get().getPageSize();
    }
}

反射工具源码

import java.lang.reflect.Field;

/**
 * Small reflection helpers for reading and writing a field by name, looking the field
 * up on the object's class and then on each superclass below {@code Object}.
 */
public class ReflectUtil {

    /**
     * Reads the value of the named field from {@code target}, regardless of the field's
     * access modifier.
     *
     * @param target the object to read from
     * @param field  the field name
     * @return the field's value, or {@code null} if no such field exists in the hierarchy
     * @throws IllegalAccessException if the field cannot be read
     */
    public static Object getFieldValue(Object target, String field) throws IllegalAccessException {
        Field f = getFieldByFieldName(target, field);
        if (f == null) {
            return null;
        }
        // setAccessible(true) is harmless when the field is already accessible.
        f.setAccessible(true);
        return f.get(target);
    }

    /**
     * Writes {@code value} into the named field of {@code obj}, regardless of the field's
     * access modifier.
     *
     * @param obj       the object to modify
     * @param fieldName the field name
     * @param value     the new value
     * @throws NoSuchFieldException     if the field is not declared on the object's class
     *                                  or any superclass below {@code Object}
     * @throws IllegalAccessException   if the field cannot be written
     * @throws IllegalArgumentException if {@code value} is not assignable to the field
     * @throws SecurityException        if a security manager denies access
     */
    public static void setValueByFieldName(Object obj, String fieldName, Object value) throws SecurityException, NoSuchFieldException,
            IllegalArgumentException, IllegalAccessException {
        Field field = getFieldByFieldName(obj, fieldName);
        if (field == null) {
            // Report the declared exception instead of failing with a NullPointerException.
            throw new NoSuchFieldException(
                    fieldName + " not found on " + obj.getClass().getName() + " or its superclasses");
        }
        // The Field instance is local to this call, so there is no need to restore its
        // accessibility flag afterwards.
        field.setAccessible(true);
        field.set(obj, value);
    }

    /**
     * Finds a declared field by name on the object's class or the nearest superclass
     * that declares it; {@code Object} itself is not searched.
     *
     * @param obj       the object whose class hierarchy is searched
     * @param fieldName the field name
     * @return the matching field, or {@code null} if none is found
     */
    public static Field getFieldByFieldName(Object obj, String fieldName) {
        // Advance to the superclass on every miss so the search always terminates.
        for (Class<?> type = obj.getClass(); type != null && type != Object.class; type = type.getSuperclass()) {
            try {
                return type.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // Expected while walking the hierarchy: not declared here, try the parent.
            }
        }
        return null;
    }


}

posted @ 2019-12-12 08:56  A_yes  阅读(381)  评论(0)  编辑  收藏  举报