Mybatis之插件
mybatis插件开发
1、官网默认的拦截器实现方式-动态代理
那么就需要来明白mybatis插件开发中的原理,为什么要这么来进行书写。
而四大组件的加工都是需要通过configuration来进行构建的,可以看到在configuration对象中创建对应的四大对象的时候都是通过newXxx来进行构建的
最终都会在插件拦截器中来设置对应的插件:
executor = (Executor) interceptorChain.pluginAll(executor);
parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
看一下对应的拦截器链条:
public class InterceptorChain {
private final List<Interceptor> interceptors = new ArrayList<>();
public Object pluginAll(Object target) {
for (Interceptor interceptor : interceptors) {
target = interceptor.plugin(target);
}
return target;
}
public void addInterceptor(Interceptor interceptor) {
interceptors.add(interceptor);
}
public List<Interceptor> getInterceptors() {
return Collections.unmodifiableList(interceptors);
}
}
那么看一下对应的接口:
public interface Interceptor {
// 调用插件中的方法来执行
Object intercept(Invocation invocation) throws Throwable;
// 生成插件的方式,可以来重写。但是都用默认的
default Object plugin(Object target) {
return Plugin.wrap(target, this);
}
default void setProperties(Properties properties) {
// NOP
}
}
看下对应的插件类中的方法:
public class Plugin implements InvocationHandler {
private final Object target;
private final Interceptor interceptor;
private final Map<Class<?>, Set<Method>> signatureMap;
private Plugin(Object target, Interceptor interceptor, Map<Class<?>, Set<Method>> signatureMap) {
this.target = target;
this.interceptor = interceptor;
this.signatureMap = signatureMap;
}
// 调用静态方法
public static Object wrap(Object target, Interceptor interceptor) {
// 得到class对应的多个方法
Map<Class<?>, Set<Method>> signatureMap = getSignatureMap(interceptor);
// 获取得到对应的接口类型
Class<?> type = target.getClass();
Class<?>[] interfaces = getAllInterfaces(type, signatureMap);
// 创建出来对应的代理对象
if (interfaces.length > 0) {
return Proxy.newProxyInstance(
type.getClassLoader(),
interfaces,
new Plugin(target, interceptor, signatureMap));
}
return target;
}
// 当真正的插件在执行的之后,会执行到这里
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
try {
// 取出来对应的方法来依次进行执行
Set<Method> methods = signatureMap.get(method.getDeclaringClass());
// 如果包含,那么就在这里来进行执行
if (methods != null && methods.contains(method)) {
return interceptor.intercept(new Invocation(target, method, args));
}
return method.invoke(target, args);
} catch (Exception e) {
throw ExceptionUtil.unwrapThrowable(e);
}
}
private static Map<Class<?>, Set<Method>> getSignatureMap(Interceptor interceptor) {
// 首先判断插件是否实现了Interceptor这个类并且有Intercepts注解
Intercepts interceptsAnnotation = interceptor.getClass().getAnnotation(Intercepts.class);
if (interceptsAnnotation == null) {
throw new PluginException("No @Intercepts annotation was found in interceptor " + interceptor.getClass().getName());
}
// 拿到注解中的值
Signature[] sigs = interceptsAnnotation.value();
Map<Class<?>, Set<Method>> signatureMap = new HashMap<>();
// 循环遍历数组标签!这里可以看到Signature可以写多个Signature注解
for (Signature sig : sigs) {
// 一个类中的多个方法可以放入进去,如果没有,返回新的HashSet
Set<Method> methods = MapUtil.computeIfAbsent(signatureMap, sig.type(), k -> new HashSet<>());
try {
// 获取得到方法和参数值确定对应的方法并加入到method中去
Method method = sig.type().getMethod(sig.method(), sig.args());
methods.add(method);
} catch (NoSuchMethodException e) {
throw new PluginException("Could not find method on " + sig.type() + " named " + sig.method() + ". Cause: " + e, e);
}
}
return signatureMap;
}
private static Class<?>[] getAllInterfaces(Class<?> type, Map<Class<?>, Set<Method>> signatureMap) {
Set<Class<?>> interfaces = new HashSet<>();
while (type != null) {
for (Class<?> c : type.getInterfaces()) {
if (signatureMap.containsKey(c)) {
interfaces.add(c);
}
}
type = type.getSuperclass();
}
return interfaces.toArray(new Class<?>[0]);
}
}
最终会执行到我们的interceptor.intercept方法中来,最终执行到我们自己的逻辑中来。
那么也就是说执行四大组件的时候,先让自定义的拦截器先来执行,然后再让原有的拦截器来进行执行。
2、模拟案例
2.1、生成Executor插件
来给query方法做一个拦截
@Intercepts({@Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
), @Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
)})
public class PagingInterceptor implements Interceptor {....}
2.2、生成StatementHandler插件
@Intercepts({@Signature(
type = StatementHandler.class,
method = "prepare",
args = {Connection.class, Integer.class})})
public class PageIntercepter implements Interceptor {.....}
可以看到在@Intercepts是可以写多个的方法签名的,表示代理的都有哪些方法。
注意:这里的动态代理的模式是可以来进行学习一下的。
3、分页插件
不想使用PageHelper来进行分页了,所以这里想要自定义一款插件来进行使用。
在开发中没有开启二级缓存和一级缓存,也就是禁用的状态,所以直接关闭。
那么在四大插件中选择的有两个插件对象:Executor和StatementHandler两个接口,那么应该来选择哪一个?
我觉得是两个都可以!
即使Executor中使用到了二级缓存,但是二级缓存在进行缓存之后重新查询的时候,将没有对应的total结果,相对来说,二级缓存的功能也不是那么健全。
既然将二级缓存和一级缓存都关闭掉了,那么对于缓存来说,就没有太大的影响!所以可以直接来进行使用即可。
3.0、功能分析
- 1、既然能够来到插件位置,那么肯定是对应的执行器;
- 2、这里应该来对执行的SQL的类型(select、update、delete、insert)来做一个校验操作
- 3、读取该SQL能够查询得到的总数;
- 4、在SQL后面来进行分页操作;
3.1、实现方式一
V1版本
利用StatementHandler来进行实现,对应的实现代码如下所示:
@Intercepts({@Signature(
type = StatementHandler.class,
method = "prepare",
args = {Connection.class, Integer.class})})
public class PageIntercepter implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
BoundSql boundSql = statementHandler.getBoundSql();
Object parameterObject = boundSql.getParameterObject();
Page page = null;
if (parameterObject != null) {
if (parameterObject instanceof Page) {
page = (Page) parameterObject;
} else if (parameterObject instanceof Map) {
page = (Page) ((Map<?, ?>) parameterObject).values().stream().filter(value -> value instanceof Page).findFirst().orElse(null);
}
}
if (page != null) {
String sql = boundSql.getSql();
String newSql = String.format("select count(*) from (%s) as tmp_page", sql);
Connection connection = (Connection) invocation.getArgs()[0];
PreparedStatement preparedStatement = connection.prepareStatement(newSql);
statementHandler.getParameterHandler().setParameters(preparedStatement);
ResultSet resultSet = preparedStatement.executeQuery();
int count = -1;
if (resultSet != null) {
// 一定要添加这个,不然在查询出来的时候会有错误:before start of result set
while (resultSet.next()) {
page.setTotal(resultSet.getLong(1));
break;
}
}
resultSet.close();
preparedStatement.close();
String newExecutorSql = String.format("%s limit %s offset %s", boundSql.getSql(), page.getSize(), page.getOffset());
// 修改boundSql中的值
SystemMetaObject.forObject(boundSql).setValue("sql", newExecutorSql);
// 然后来执行新的SQL,将结果集封装到Page中来
Object proceedResult = invocation.proceed();
return proceedResult;
}
return invocation.proceed();
}
}
但是感觉到这类的逻辑是不太严谨的。缺少了对于MappedStatement的处理和校验。那么这里可以手动的来进行处理。
没有考虑到并发场景的情况,那么将并发场景融合进来进行考虑
V2并发版本
利用ThreadLocal来进行实现这种方式:
/**
* Statement prepare(Connection connection, Integer transactionTimeout)
*
* @author liguang
* @date 2022/6/1 11:41
*/
@Intercepts({@Signature(
type = StatementHandler.class,
method = "prepare",
args = {Connection.class, Integer.class})})
public class PageIntercepter implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
BoundSql boundSql = statementHandler.getBoundSql();
Object parameterObject = boundSql.getParameterObject();
Page page = null;
if (parameterObject != null) {
if (parameterObject instanceof Page) {
page = (Page) parameterObject;
} else if (parameterObject instanceof Map) {
page = (Page) ((Map<?, ?>) parameterObject).values().stream().filter(value -> value instanceof Page).findFirst().orElse(null);
}
}
if (page != null) {
String sql = boundSql.getSql();
String newSql = String.format("select count(*) from (%s) as tmp_page", sql);
Connection connection = (Connection) invocation.getArgs()[0];
PreparedStatement preparedStatement = connection.prepareStatement(newSql);
statementHandler.getParameterHandler().setParameters(preparedStatement);
ResultSet resultSet = preparedStatement.executeQuery();
int count = -1;
if (resultSet != null) {
// 一定要添加这个,不然在查询出来的时候会有错误:before start of result set
while (resultSet.next()) {
page.setTotal(resultSet.getLong(1));
break;
}
}
resultSet.close();
preparedStatement.close();
Map<String, Integer> page1 = PageHelper.getPage();
Integer index = page1.get("index");
Integer size = page1.get("size");
// String newExecutorSql = String.format("%s limit %s offset %s", boundSql.getSql(), page.getSize(), page.getOffset());
String newExecutorSql = String.format("%s limit %s offset %s", boundSql.getSql(), index, size);
// 修改boundSql中的值
SystemMetaObject.forObject(boundSql).setValue("sql", newExecutorSql);
// 然后来执行新的SQL,将结果集封装到Page中来
Object proceedResult = invocation.proceed();
return proceedResult;
}
return invocation.proceed();
}
}
3.2、实现方式二
利用Executor中的query来进行实现。
V1版本
@Intercepts({@Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}
), @Signature(
type = Executor.class,
method = "query",
args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}
)})
public class PagingInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {
System.out.println("--------------||||||||||----------------");
RowBounds rowBounds = (RowBounds)invocation.getArgs()[2];
Object[] args = invocation.getArgs();
if (rowBounds != null && rowBounds instanceof PagingRowBounds) {
System.out.println("是否使用我们的分页");
// 获取得到分页对象
PagingRowBounds pagingRowBounds = (PagingRowBounds)rowBounds;
// 获取得到执行器对象
Executor executor = (Executor)invocation.getTarget();
// 获取得到封装的映射语句对象
MappedStatement mappedStatement = (MappedStatement)invocation.getArgs()[0];
// 获取得到参数
Object parameterObject = invocation.getArgs()[1];
// 调用executor传入进来的第三个参数ResultHandler
ResultHandler resultHandler = (ResultHandler)invocation.getArgs()[3];
// 如果传入的参数的个数大于四个,那么应该有对应的缓存等条件,那么获取得到BoundSql。BoundSql中封装的有
BoundSql boundSql = invocation.getArgs().length > 4 ? (BoundSql)invocation.getArgs()[5] : mappedStatement.getBoundSql(parameterObject);
// 校验操作
this.validate(mappedStatement, boundSql);
// 统计多少条数据
Integer totalCount = this.executeCountSql(executor, mappedStatement, parameterObject, pagingRowBounds, resultHandler, boundSql);
// 执行真正的结构数据
List result = this.executePagingSql(executor, mappedStatement, parameterObject, pagingRowBounds, resultHandler, boundSql);
return new Page(totalCount, pagingRowBounds.getCurrPage(), pagingRowBounds.getPageSize(), result);
} else {
return invocation.proceed();
}
}
// 查询得到的最终数据
private Integer executeCountSql(Executor executor, MappedStatement mappedStatement, Object parameterObject, PagingRowBounds pagingRowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
BoundSql countBoundSql = this.createCountSql(mappedStatement, boundSql);
// 获取得到新的映射语句
MappedStatement countMappedStatement = this.createCountMappedStatement(mappedStatement);
// 创建缓存key,根据新的构建条件
CacheKey cacheKey = executor.createCacheKey(countMappedStatement, parameterObject, pagingRowBounds, countBoundSql);
return (Integer)executor.query(countMappedStatement, parameterObject, pagingRowBounds, resultHandler, cacheKey, countBoundSql).get(0);
}
private BoundSql createCountSql(MappedStatement mappedStatement, BoundSql originBoundSql) {
// select count(1) from (select * from dfkdsjkfjdsk wehre id != 666)
String countSql = "SELECT COUNT(1) FROM (" + originBoundSql.getSql() + ")";
// 创建新的SQL来进行执行统计!
return new BoundSql(mappedStatement.getConfiguration(), countSql, originBoundSql.getParameterMappings(), originBoundSql.getParameterObject());
}
// 重新创建MappedStatement,然后来创建新的MappedStatement用于查询count的值
// 这里主要是将这里的resultMap中的返回类型给修改成了int类型
private MappedStatement createCountMappedStatement(MappedStatement mappedStatement) {
List<ResultMap> resultMaps = new ArrayList();
ResultMap resultMap = (new Builder(mappedStatement.getConfiguration(), mappedStatement.getId(), Integer.class, new ArrayList())).build();
resultMaps.add(resultMap);
return (new org.apache.ibatis.mapping.MappedStatement.Builder(mappedStatement.getConfiguration(), mappedStatement.getId(), mappedStatement.getSqlSource(), mappedStatement.getSqlCommandType())).statementType(mappedStatement.getStatementType()).cache(mappedStatement.getCache()).fetchSize(mappedStatement.getFetchSize()).flushCacheRequired(mappedStatement.isFlushCacheRequired()).resource(mappedStatement.getResource()).keyGenerator(mappedStatement.getKeyGenerator()).parameterMap(mappedStatement.getParameterMap()).timeout(mappedStatement.getTimeout()).resultOrdered(mappedStatement.isResultOrdered()).resultSetType(mappedStatement.getResultSetType()).useCache(mappedStatement.isUseCache()).resultMaps(resultMaps).build();
}
private List executePagingSql(Executor executor, MappedStatement mappedStatement, Object parameterObject, PagingRowBounds pagingRowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
BoundSql pagingBoundSql = this.createPagingSql(mappedStatement, boundSql, pagingRowBounds);
CacheKey cacheKey = executor.createCacheKey(mappedStatement, parameterObject, pagingRowBounds, pagingBoundSql);
return executor.query(mappedStatement, parameterObject, pagingRowBounds, resultHandler, cacheKey, pagingBoundSql);
}
private BoundSql createPagingSql(MappedStatement mappedStatement, BoundSql originBoundSql, PagingRowBounds pagingRowBounds) {
// Oracle用法
// String pagingSql = "SELECT * FROM (SELECT ROWNUM AS ROW_NUM, A.* FROM (" + originBoundSql.getSql() + ") A) WHERE ROW_NUM >= " + pagingRowBounds.getStartRowNum() + " AND ROW_NUM <= " + pagingRowBounds.getEndRowNum();
// MySQL用法
String pagingSql = originBoundSql.getSql() + " limit " +pagingRowBounds.getStartRowNum()+" , "+pagingRowBounds.getEndRowNum();
return new BoundSql(mappedStatement.getConfiguration(), pagingSql, originBoundSql.getParameterMappings(), originBoundSql.getParameterObject());
}
// 校验类型是否是以select开头的
private void validate(MappedStatement mappedStatement, BoundSql boundSql) {
if (SqlCommandType.SELECT != mappedStatement.getSqlCommandType()) {
throw new IllegalArgumentException("该sql语句不是select语句: " + boundSql.getSql());
}
}
}
这里需要注意的是:返回的还是一个List,只不过是我们自己来进行实现的结果。那么这里最终不会影响到结果集的封装。
也可以将结果集使用一个Page来进行封装,采用泛型的方式来进行操作。
这里完全可以将上面两个步骤合在一起:通过参数来进行判断是否有分页对象Page,然后来进行判断其中的值,利用ThreadLocal来进行设置,然后通过从ThreadLocal获取得到对应的值来进行判断是否有对应的值,如果有的话,那么是分页;如果没有或者是配置了负数,那么应该报错提示参数设置错误,或者是说不是分页的查询。
4、总结
1、使用动态代理的方式来操作对应的插件,需要注意的是,这里的返回值,就是自定义插件中的返回值。所以在不影响封装条件的前提下,可以自定义操作。如上面的分页插件中使用了Page继承list,来封装对应的对象;
2、如果需要修改的SQL,只是添加了静态文本,那么这里重写设置给BoundSQL即可;如果在BoundSql中修改了对应的参数或者是加载了可以通过修改来进行实现,这里就需要注意到对应的SQL和返回值的封装条件了。