Mybatis之插件

mybatis插件开发

1、官网默认的拦截器实现方式-动态代理

那么就需要来明白mybatis插件开发中的原理,为什么要这么来进行书写。

而四大组件的加工都是需要通过configuration来进行构建的,可以看到在configuration对象中创建对应的四大对象的时候都是通过newXxx来进行构建的

最终都会在插件拦截器中来设置对应的插件:

executor = (Executor) interceptorChain.pluginAll(executor);
parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);

看一下对应的拦截器链条:

public class InterceptorChain {

  private final List<Interceptor> interceptors = new ArrayList<>();

  public Object pluginAll(Object target) {
    for (Interceptor interceptor : interceptors) {
      target = interceptor.plugin(target);
    }
    return target;
  }

  public void addInterceptor(Interceptor interceptor) {
    interceptors.add(interceptor);
  }

  public List<Interceptor> getInterceptors() {
    return Collections.unmodifiableList(interceptors);
  }

}

那么看一下对应的接口:

public interface Interceptor {
  // 调用插件中的方法来执行
  Object intercept(Invocation invocation) throws Throwable;
  
  // 生成插件的方式,可以来重写。但是都用默认的
  default Object plugin(Object target) {
    return Plugin.wrap(target, this);
  }

  default void setProperties(Properties properties) {
    // NOP
  }

}

看下对应的插件类中的方法:

public class Plugin implements InvocationHandler {

  private final Object target;
  private final Interceptor interceptor;
  private final Map<Class<?>, Set<Method>> signatureMap;

  private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
    this.target = target;
    this.interceptor = interceptor;
    this.signatureMap = signatureMap;
  }
  // 调用静态方法
  public static Object wrap(Object target, Interceptor interceptor) {
    // 得到class对应的多个方法
    Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
    // 获取得到对应的接口类型
    Class<?> type = target.getClass();
    Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
    // 创建出来对应的代理对象
    if (interfaces.length > 0) {
      return Proxy.newProxyInstance(
          type.getClassLoader(),
          interfaces,
          new Plugin(target, interceptor, signatureMap));
    }
    return target;
  }

  // 当真正的插件在执行的之后,会执行到这里
  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
    try {
      // 取出来对应的方法来依次进行执行
      Set<Method> methods = signatureMap.get(method.getDeclaringClass());
      // 如果包含,那么就在这里来进行执行
      if (methods != null && methods.contains(method)) {
        return interceptor.intercept(new Invocation(target, method, args));
      }
      return method.invoke(target, args);
    } catch (Exception e) {
      throw ExceptionUtil.unwrapThrowable(e);
    }
  }

  private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
    // 首先判断插件是否实现了Interceptor这个类并且有Intercepts注解
    Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);   
    if (interceptsAnnotation == null) {
      throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
    }
    // 拿到注解中的值
    Signature[] sigs = interceptsAnnotation.value();
    Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
    // 循环遍历数组标签!这里可以看到Signature可以写多个Signature注解
    for (Signature sig : sigs) {
      // 一个类中的多个方法可以放入进去,如果没有,返回新的HashSet
      Set<Method> methods = MapUtil.computeIfAbsent(signatureMap, sig.type(), k -> new HashSet<>());
      try {
        // 获取得到方法和参数值确定对应的方法并加入到method中去
        Method method = sig.type().getMethod(sig.method(), sig.args());
        methods.add(method);
      } catch (NoSuchMethodException e) {
        throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
      }
    }
    return signatureMap;
  }

  private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
    Set<Class<?>> interfaces = new HashSet<>();
    while (type != null) {
      for (Class<?> c : type.getInterfaces()) {
        if (signatureMap.containsKey(c)) {
          interfaces.add(c);
        }
      }
      type = type.getSuperclass();
    }
    return interfaces.toArray(new Class<?>[0]);
  }

}

最终会执行到我们的interceptor.intercept方法中来,最终执行到我们自己的逻辑中来。

那么也就是说执行四大组件的时候,先让自定义的拦截器先来执行,然后再让原有的拦截器来进行执行。

2、模拟案例

2.1、生成Executor插件

来给query方法做一个拦截

@Intercepts({@Signature(
    type = Executor.class,
    method = "query",
    args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
), @Signature(
    type = Executor.class,
    method = "query",
    args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
)})
public class PagingInterceptor implements Interceptor {....}

2.2、生成StatementHandler插件

@Intercepts({@Signature(
        type = StatementHandler.class,
        method = "prepare",
        args = {Connection.class, Integer.class})})
public class PageIntercepter implements Interceptor {.....}

可以看到在@Intercepts是可以写多个的方法签名的,表示代理的都有哪些方法。

注意:这里的动态代理的模式是可以来进行学习一下的。

3、分页插件

不想使用PageHelper来进行分页了,所以这里想要自定义一款插件来进行使用。

在开发中没有开启二级缓存和一级缓存,也就是禁用的状态,所以直接关闭。

那么在四大插件中选择的有两个插件对象:Executor和StatementHandler两个接口,那么应该来选择哪一个?

我觉得是两个都可以!

即使Executor中使用到了二级缓存,但是二级缓存在进行缓存之后重新查询的时候,将没有对应的total结果,相对来说,二级缓存的功能也不是那么健全。

既然将二级缓存和一级缓存都关闭掉了,那么对于缓存来说,就没有太大的影响!所以可以直接来进行使用即可。

3.0、功能分析

  • 1、既然能够来到插件位置,那么肯定是对应的执行器;
  • 2、这里应该来对执行的SQL的类型(select、update、delete、insert)来做一个校验操作
  • 3、读取该SQL能够查询得到的总数;
  • 4、在SQL后面来进行分页操作;

3.1、实现方式一

V1版本

利用StatementHandler来进行实现,对应的实现代码如下所示:

@Intercepts({@Signature(
        type = StatementHandler.class,
        method = "prepare",
        args = {Connection.class, Integer.class})})
public class PageIntercepter implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        Object parameterObject = boundSql.getParameterObject();
        Page page = null;
        if (parameterObject != null) {
            if (parameterObject instanceof Page) {
                page = (Page) parameterObject;
            } else if (parameterObject instanceof Map) {
                page = (Page) ((Map<?, ?>) parameterObject).values().stream().filter(value -> value instanceof Page).findFirst().orElse(null);
            }
        }
        if (page != null) {
            String sql = boundSql.getSql();
            String newSql = String.format("select count(*) from (%s) as tmp_page", sql);
            Connection connection = (Connection) invocation.getArgs()[0];
            PreparedStatement preparedStatement = connection.prepareStatement(newSql);
            statementHandler.getParameterHandler().setParameters(preparedStatement);
            ResultSet resultSet = preparedStatement.executeQuery();
            int count = -1;
            if (resultSet != null) {
                // 一定要添加这个,不然在查询出来的时候会有错误:before start of result set
                while (resultSet.next()) {
                    page.setTotal(resultSet.getLong(1));
                    break;
                }
            }
            resultSet.close();
            preparedStatement.close();
            String newExecutorSql = String.format("%s  limit %s  offset %s", boundSql.getSql(), page.getSize(), page.getOffset());
            // 修改boundSql中的值
            SystemMetaObject.forObject(boundSql).setValue("sql", newExecutorSql);
            // 然后来执行新的SQL,将结果集封装到Page中来
            Object proceedResult = invocation.proceed();
            return proceedResult;
        }
        return invocation.proceed();
    }
}

但是感觉到这类的逻辑是不太严谨的。缺少了对于MappedStatement的处理和校验。那么这里可以手动的来进行处理。

没有考虑到并发场景的情况,那么将并发场景融合进来进行考虑

V2并发版本

利用ThreadLocal来进行实现这种方式:

/**
 * Statement prepare(Connection connection, Integer transactionTimeout)
 *
 * @author liguang
 * @date 2022/6/1 11:41
 */
@Intercepts({@Signature(
        type = StatementHandler.class,
        method = "prepare",
        args = {Connection.class, Integer.class})})
public class PageIntercepter implements Interceptor {
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        BoundSql boundSql = statementHandler.getBoundSql();
        Object parameterObject = boundSql.getParameterObject();
        Page page = null;
        if (parameterObject != null) {
            if (parameterObject instanceof Page) {
                page = (Page) parameterObject;
            } else if (parameterObject instanceof Map) {
                page = (Page) ((Map<?, ?>) parameterObject).values().stream().filter(value -> value instanceof Page).findFirst().orElse(null);
            }
        }
        if (page != null) {
            String sql = boundSql.getSql();
            String newSql = String.format("select count(*) from (%s) as tmp_page", sql);
            Connection connection = (Connection) invocation.getArgs()[0];
            PreparedStatement preparedStatement = connection.prepareStatement(newSql);
            statementHandler.getParameterHandler().setParameters(preparedStatement);
            ResultSet resultSet = preparedStatement.executeQuery();
            int count = -1;
            if (resultSet != null) {
                // 一定要添加这个,不然在查询出来的时候会有错误:before start of result set
                while (resultSet.next()) {
                    page.setTotal(resultSet.getLong(1));
                    break;
                }
            }
            resultSet.close();
            preparedStatement.close();
            Map<String, Integer> page1 = PageHelper.getPage();
            Integer index = page1.get("index");
            Integer size = page1.get("size");
//            String newExecutorSql = String.format("%s  limit %s  offset %s", boundSql.getSql(), page.getSize(), page.getOffset());
            String newExecutorSql = String.format("%s  limit %s  offset %s", boundSql.getSql(), index, size);
            // 修改boundSql中的值
            SystemMetaObject.forObject(boundSql).setValue("sql", newExecutorSql);
            // 然后来执行新的SQL,将结果集封装到Page中来
            Object proceedResult = invocation.proceed();
            return proceedResult;
        }
        return invocation.proceed();
    }
}

3.2、实现方式二

利用Executor中的query来进行实现。

V1版本

@Intercepts({@Signature(
    type = Executor.class,
    method = "query",
    args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
), @Signature(
    type = Executor.class,
    method = "query",
    args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
)})
public class PagingInterceptor implements Interceptor {

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        System.out.println("--------------||||||||||----------------");
        RowBounds rowBounds = (RowBounds)invocation.getArgs()[2];
        Object[] args = invocation.getArgs();
        if (rowBounds != null && rowBounds instanceof PagingRowBounds) {
            System.out.println("是否使用我们的分页");
            // 获取得到分页对象
            PagingRowBounds pagingRowBounds = (PagingRowBounds)rowBounds;
            // 获取得到执行器对象
            Executor executor = (Executor)invocation.getTarget();
            // 获取得到封装的映射语句对象
            MappedStatement mappedStatement = (MappedStatement)invocation.getArgs()[0];
            // 获取得到参数
            Object parameterObject = invocation.getArgs()[1];
            // 调用executor传入进来的第三个参数ResultHandler
            ResultHandler resultHandler = (ResultHandler)invocation.getArgs()[3];
            // 如果传入的参数的个数大于四个,那么应该有对应的缓存等条件,那么获取得到BoundSql。BoundSql中封装的有
            BoundSql boundSql = invocation.getArgs().length > 4 ? (BoundSql)invocation.getArgs()[5] : mappedStatement.getBoundSql(parameterObject);
            // 校验操作
            this.validate(mappedStatement, boundSql);
            // 统计多少条数据
            Integer totalCount = this.executeCountSql(executor, mappedStatement, parameterObject, pagingRowBounds, resultHandler, boundSql);
            // 执行真正的结构数据
            List result = this.executePagingSql(executor, mappedStatement, parameterObject, pagingRowBounds, resultHandler, boundSql);
            return new Page(totalCount, pagingRowBounds.getCurrPage(), pagingRowBounds.getPageSize(), result);
        } else {
            return invocation.proceed();
        }
    }

    // 查询得到的最终数据
    private Integer executeCountSql(Executor executor, MappedStatement mappedStatement, Object parameterObject, PagingRowBounds pagingRowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        BoundSql countBoundSql = this.createCountSql(mappedStatement, boundSql);
        // 获取得到新的映射语句
        MappedStatement countMappedStatement = this.createCountMappedStatement(mappedStatement);
        // 创建缓存key,根据新的构建条件
        CacheKey cacheKey = executor.createCacheKey(countMappedStatement, parameterObject, pagingRowBounds, countBoundSql);
        return (Integer)executor.query(countMappedStatement, parameterObject, pagingRowBounds, resultHandler, cacheKey, countBoundSql).get(0);
    }

    private BoundSql createCountSql(MappedStatement mappedStatement, BoundSql originBoundSql) {
        // select count(1) from (select * from dfkdsjkfjdsk wehre id != 666)
        String countSql = "SELECT COUNT(1) FROM (" + originBoundSql.getSql() + ")";
        // 创建新的SQL来进行执行统计!
        return new BoundSql(mappedStatement.getConfiguration(), countSql, originBoundSql.getParameterMappings(), originBoundSql.getParameterObject());
    }

    // 重新创建MappedStatement,然后来创建新的MappedStatement用于查询count的值
    // 这里主要是将这里的resultMap中的返回类型给修改成了int类型
    private MappedStatement createCountMappedStatement(MappedStatement mappedStatement) {
        List<ResultMap> resultMaps = new ArrayList();
        ResultMap resultMap = (new Builder(mappedStatement.getConfiguration(), mappedStatement.getId(), Integer.class, new ArrayList())).build();
        resultMaps.add(resultMap);
        return (new org.apache.ibatis.mapping.MappedStatement.Builder(mappedStatement.getConfiguration(), mappedStatement.getId(), mappedStatement.getSqlSource(), mappedStatement.getSqlCommandType())).statementType(mappedStatement.getStatementType()).cache(mappedStatement.getCache()).fetchSize(mappedStatement.getFetchSize()).flushCacheRequired(mappedStatement.isFlushCacheRequired()).resource(mappedStatement.getResource()).keyGenerator(mappedStatement.getKeyGenerator()).parameterMap(mappedStatement.getParameterMap()).timeout(mappedStatement.getTimeout()).resultOrdered(mappedStatement.isResultOrdered()).resultSetType(mappedStatement.getResultSetType()).useCache(mappedStatement.isUseCache()).resultMaps(resultMaps).build();
    }

    private List executePagingSql(Executor executor, MappedStatement mappedStatement, Object parameterObject, PagingRowBounds pagingRowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        BoundSql pagingBoundSql = this.createPagingSql(mappedStatement, boundSql, pagingRowBounds);
        CacheKey cacheKey = executor.createCacheKey(mappedStatement, parameterObject, pagingRowBounds, pagingBoundSql);
        return executor.query(mappedStatement, parameterObject, pagingRowBounds, resultHandler, cacheKey, pagingBoundSql);
    }

    private BoundSql createPagingSql(MappedStatement mappedStatement, BoundSql originBoundSql, PagingRowBounds pagingRowBounds) {
        // Oracle用法
        // String pagingSql = "SELECT * FROM (SELECT ROWNUM AS ROW_NUM, A.* FROM (" + originBoundSql.getSql() + ") A) WHERE ROW_NUM >= " + pagingRowBounds.getStartRowNum() + " AND ROW_NUM <= " + pagingRowBounds.getEndRowNum();
        // MySQL用法
        String pagingSql = originBoundSql.getSql() +  " limit " +pagingRowBounds.getStartRowNum()+" , "+pagingRowBounds.getEndRowNum();
        return new BoundSql(mappedStatement.getConfiguration(), pagingSql, originBoundSql.getParameterMappings(), originBoundSql.getParameterObject());
    }

    // 校验类型是否是以select开头的
    private void validate(MappedStatement mappedStatement, BoundSql boundSql) {
        if (SqlCommandType.SELECT != mappedStatement.getSqlCommandType()) {
            throw new IllegalArgumentException("该sql语句不是select语句: " + boundSql.getSql());
        }
    }
}

这里需要注意的是:返回的还是一个List,只不过是我们自己来进行实现的结果。那么这里最终不会影响到结果集的封装。

也可以将结果集使用一个Page来进行封装,采用泛型的方式来进行操作。

这里完全可以将上面两个步骤合在一起:通过参数来进行判断是否有分页对象Page,然后来进行判断其中的值,利用ThreadLocal来进行设置,然后通过从ThreadLocal获取得到对应的值来进行判断是否有对应的值,如果有的话,那么是分页;如果没有或者是配置了负数,那么应该报错提示参数设置错误,或者是说不是分页的查询。

4、总结

1、使用动态代理的方式来操作对应的插件,需要注意的是,这里的返回值,就是自定义插件中的返回值。所以在不影响封装条件的前提下,可以自定义操作。如上面的分页插件中使用了Page继承list,来封装对应的对象;

2、如果需要修改的SQL,只是添加了静态文本,那么这里重写设置给BoundSQL即可;如果在BoundSql中修改了对应的参数或者是加载了可以通过修改来进行实现,这里就需要注意到对应的SQL和返回值的封装条件了。

posted @ 2022-06-02 09:59  写的代码很烂  阅读(1005)  评论(0编辑  收藏  举报