在 Spring Boot 项目中添加 MyBatis 自定义 SQL 拦截器

import cn.hutool.core.util.ReUtil;
import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * MyBatis {@link Interceptor} that adds a soft-delete filter ({@code <alias>.is_del = 0}) to
 * query statements.
 *
 * <p>Both {@code Executor.query} overloads are intercepted. A statement is rewritten only when
 * <ul>
 *   <li>its mapper method name (the part of the statement id after the last {@code '.'})
 *       contains one of {@link #METHOD_LIST};</li>
 *   <li>the SQL does not already mention {@code is_del} and contains no {@code UNION};</li>
 *   <li>the main table can be identified from the first {@code FROM} clause and is not listed
 *       in {@link #WHITE_LIST_TABLE}.</li>
 * </ul>
 * Every other statement is passed on untouched.
 *
 * <p>The rewrite is a text-based heuristic, not a SQL parser. Statements whose first table is a
 * derived table are skipped. Sub-queries that contain {@code WHERE}, {@code ORDER BY},
 * {@code LIMIT}, etc. are not understood and may still be rewritten incorrectly.
 *
 * @version 1.0
 * @date 2022/5/7 14:42
 * @since : JDK 11
 */
@Slf4j
@Intercepts(
        {
                @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
                @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class})
        }
)
public class MyBatisInterceptor implements Interceptor {

    /** Lower-case mapper method-name fragments that mark a statement as a soft-delete-aware query. */
    private static final List<String> METHOD_LIST = Arrays.asList("list", "query", "find", "count", "select");

    /**
     * Whitelisted tables that never receive the soft-delete filter.
     * Entries are lower case, without backticks, e.g. {@code "sys_log"}.
     */
    private static final List<String> WHITE_LIST_TABLE = Arrays.asList();

    /** Existing soft-delete condition anywhere in the statement (case-insensitive). */
    private static final Pattern DELETED_FLAG_PATTERN = Pattern.compile("is_del", Pattern.CASE_INSENSITIVE);

    /** UNION statements are not rewritten: a single appended condition cannot cover every branch. */
    private static final Pattern UNION_PATTERN = Pattern.compile("\\bunion\\b", Pattern.CASE_INSENSITIVE);

    /**
     * First FROM clause. Group 1 is the table reference up to the next JOIN, WHERE, GROUP BY,
     * HAVING, ORDER BY, LIMIT or FOR UPDATE keyword, or up to the end of the statement.
     */
    private static final Pattern FROM_PATTERN = Pattern.compile(
            "\\bfrom\\s+(.+?)(?=\\s+(?:(?:left|right|full|inner|cross|natural)\\s+(?:outer\\s+)?)?join\\b"
                    + "|\\s+where\\b|\\s+group\\s+by\\b|\\s+having\\b|\\s+order\\s+by\\b|\\s+limit\\b"
                    + "|\\s+for\\s+update\\b|\\s*$)",
            Pattern.CASE_INSENSITIVE | Pattern.DOTALL);

    /** Clauses that must come after WHERE; the filter is inserted before the first one after FROM. */
    private static final Pattern TAIL_CLAUSE_PATTERN = Pattern.compile(
            "\\b(?:group\\s+by|having|order\\s+by|limit|for\\s+update)\\b", Pattern.CASE_INSENSITIVE);

    /** WHERE keyword. */
    private static final Pattern WHERE_PATTERN = Pattern.compile("\\bwhere\\b", Pattern.CASE_INSENSITIVE);

    /**
     * Rewrites the statement's SQL with {@link #handler(String, String)} and continues the chain.
     *
     * <p>For the 4-argument {@code query} overload, the {@link BoundSql} is built here from the
     * parameter object. The 6-argument overload supplies the BoundSql together with a cache key;
     * that key was computed from the original SQL and is left as-is.
     *
     * <p>If the handler leaves the SQL unchanged, the original arguments are passed on untouched.
     * Otherwise a copy of the {@link MappedStatement} that yields the rewritten SQL replaces
     * {@code args[0]}, and also {@code args[5]} for the 6-argument form.
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameter = args[1];
        BoundSql boundSql = args.length == 4 ? ms.getBoundSql(parameter) : (BoundSql) args[5];

        // handler() matches case-insensitively, so the SQL keeps its original case and whitespace.
        // Normalising it would also alter string literals sent to the database.
        String origSql = boundSql.getSql();
        String newSql = handler(ms.getId(), origSql);
        if (newSql.equals(origSql)) {
            return invocation.proceed();
        }

        BoundSql newBoundSql = new BoundSql(ms.getConfiguration(), newSql,
                boundSql.getParameterMappings(), boundSql.getParameterObject());
        // Copy the additional parameters registered on the original BoundSql so that placeholders
        // referring to them still resolve against the new one.
        for (ParameterMapping mapping : boundSql.getParameterMappings()) {
            String prop = mapping.getProperty();
            if (boundSql.hasAdditionalParameter(prop)) {
                newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
            }
        }
        MappedStatement newMs = newMappedStatement(ms, new BoundSqlSource(newBoundSql));

        args[0] = newMs;
        if (args.length == 6) {
            args[5] = newMs.getBoundSql(parameter);
        }
        return invocation.proceed();
    }

    /**
     * Copies {@code ms} with a different {@link SqlSource}, keeping its other settings.
     *
     * @param ms           statement to copy
     * @param newSqlSource source that yields the rewritten SQL
     * @return a new statement with the same id as {@code ms}
     */
    private MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder = new MappedStatement.Builder(
                ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length > 0) {
            // The builder takes a comma-separated list; copying only [0] dropped the remaining keys.
            builder.keyProperty(String.join(",", ms.getKeyProperties()));
        }
        if (ms.getKeyColumns() != null && ms.getKeyColumns().length > 0) {
            builder.keyColumn(String.join(",", ms.getKeyColumns()));
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());
        builder.resultOrdered(ms.isResultOrdered());
        builder.databaseId(ms.getDatabaseId());
        builder.lang(ms.getLang());
        if (ms.getResultSets() != null && ms.getResultSets().length > 0) {
            builder.resultSets(String.join(",", ms.getResultSets()));
        }
        return builder.build();
    }

    /**
     * Wraps {@code target} in a proxy when it matches one of the {@link Intercepts} signatures.
     */
    @Override
    public Object plugin(Object target) {
        // MyBatis calls plugin() for every Executor / StatementHandler / ParameterHandler /
        // ResultSetHandler it creates, so logging here at info level floods the log.
        log.debug("MysqlInterceptor plugin>>>>>>>{}", target);
        return Plugin.wrap(target, this);
    }

    /**
     * Logs the optional {@code dialect} property; no other property is read.
     */
    @Override
    public void setProperties(Properties properties) {
        String dialect = properties.getProperty("dialect");
        log.info("mybatis intercept dialect:>>>>>>>{}", dialect);
    }

    /**
     * {@link SqlSource} that always returns the same pre-built {@link BoundSql}; it carries the
     * rewritten SQL into the copied {@link MappedStatement}.
     */
    static final class BoundSqlSource implements SqlSource {
        private final BoundSql boundSql;

        public BoundSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }

        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }

    /**
     * Returns whether the statement's method name contains a {@link #METHOD_LIST} fragment.
     *
     * @param mapperId full statement id, e.g. {@code com.example.UserMapper.selectById}
     */
    private static boolean isQueryMethod(String mapperId) {
        String simpleId = mapperId.substring(mapperId.lastIndexOf('.') + 1);
        // <selectKey> statements are registered as "<method>!selectKey"; judge them by their owner.
        int keySuffix = simpleId.indexOf('!');
        String methodName = (keySuffix >= 0 ? simpleId.substring(0, keySuffix) : simpleId)
                .toLowerCase(Locale.ROOT);
        return METHOD_LIST.stream().anyMatch(methodName::contains);
    }

    /**
     * Adds {@code <alias>.is_del = 0} to a qualifying query statement.
     *
     * <p>The condition is inserted before the first GROUP BY / HAVING / ORDER BY / LIMIT /
     * FOR UPDATE that follows the FROM clause, or appended at the end. An existing WHERE predicate
     * is wrapped in parentheses so that an {@code OR} inside it cannot bypass the filter.
     *
     * @param mapperId full statement id
     * @param sql      SQL generated by MyBatis for this invocation
     * @return the rewritten SQL, or {@code sql} itself if the statement does not qualify or its
     *         structure is not understood
     */
    private String handler(String mapperId, String sql) {
        if (!isQueryMethod(mapperId)
                || DELETED_FLAG_PATTERN.matcher(sql).find()
                || UNION_PATTERN.matcher(sql).find()) {
            return sql;
        }
        Matcher from = FROM_PATTERN.matcher(sql);
        if (!from.find()) {
            // e.g. "select last_insert_id()": there is no table to filter.
            return sql;
        }

        // Only the first table of a comma-separated list is filtered.
        String tableRef = from.group(1);
        int comma = tableRef.indexOf(',');
        String firstTable = (comma >= 0 ? tableRef.substring(0, comma) : tableRef).trim();
        String[] tokens = firstTable.split("\\s+");
        String tableName = tokens[0];
        // "t", "t u" and "t as u" all resolve to the last token.
        String alias = tokens[tokens.length - 1];
        if (tableName.isEmpty()
                || tableName.startsWith("(")
                || alias.contains("(")
                || alias.contains(")")
                || WHITE_LIST_TABLE.contains(tableName.replace("`", "").toLowerCase(Locale.ROOT))) {
            return sql;
        }
        String condition = alias + ".is_del = 0";

        Matcher tail = TAIL_CLAUSE_PATTERN.matcher(sql);
        int insertAt = tail.find(from.end()) ? tail.start() : sql.length();

        Matcher where = WHERE_PATTERN.matcher(sql);
        where.region(from.end(), insertAt);
        if (where.find()) {
            String predicate = sql.substring(where.end(), insertAt).trim();
            return sql.substring(0, where.end()) + " (" + predicate + ") and " + condition + " "
                    + sql.substring(insertAt);
        }
        return sql.substring(0, insertAt) + " where " + condition + " " + sql.substring(insertAt);
    }
}

import org.apache.ibatis.session.SqlSessionFactory;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

/**
 * @author JHL
 * @version 1.0
 * @date 2022/5/7 14:41
 * @since : JDK 11
 */
@Component
public class MyBatisSqlInterceptorConfiguration implements ApplicationContextAware {

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        SqlSessionFactory sqlSessionFactory = applicationContext.getBean(SqlSessionFactory.class);
        sqlSessionFactory.getConfiguration().addInterceptor(new MyBatisInterceptor());
    }
}

参考：MyBatis 自定义 SQL 拦截器

posted @ 2022-05-24 16:21  黄河大道东  阅读(427)  评论(0)  编辑  收藏  举报