【Mybatis】Using MyBatis's SQL template parsing on its own

Preface

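Our company's projects are riddled with bottomless pits left by historical design decisions: the new project has no time to fix them, and the old one can't be changed. As a result, production constantly needs irregular data fixes that can only be done with database scripts.
Every time I was halfway through a feature, a ticket would land on me asking me to go fix data in production. Writing scripts every single day: who could put up with that?
After a few quarters I'd had enough. I slammed the table and decided to chi..... build a script execution tool from scratch. It manages the scripts already written, lets scripts be chained together, and can run across databases, solving both the problem of writing throwaway scripts and the problem of sharing them.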

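In the first design, the tool's scripts only had placeholders such as "{paramName}", filled in with replaceAll. Then I thought: why not work like MyBatis? Better yet, why not borrow MyBatis's XML template parsing directly to parse the scripts?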
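For contrast, the early placeholder approach amounts to something like the sketch below. The class and method names are hypothetical, and it uses a regex with `Matcher.appendReplacement` instead of a bare `replaceAll` so each `{name}` can be looked up in a map. It splices raw text into the SQL, with no escaping and no optional clauses, which is exactly what MyBatis's `<if>` and `#{}` improve on.

```java
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

final class PlaceholderRender {

    // Matches {name}; the name is anything except braces (hypothetical syntax of the early tool)
    private static final Pattern PLACEHOLDER = Pattern.compile("\\{([^{}]+)\\}");

    /** Replaces every {name} with params.get(name); unknown names are left untouched. */
    static String render(String template, Map<String, String> params) {
        Matcher m = PLACEHOLDER.matcher(template);
        StringBuffer out = new StringBuffer();
        while (m.find()) {
            String value = params.get(m.group(1));
            // quoteReplacement keeps '$' or '\' in the value from being read as group references
            m.appendReplacement(out, Matcher.quoteReplacement(value != null ? value : m.group()));
        }
        m.appendTail(out);
        return out.toString();
    }
}
```

For example, `render("WHERE NUM = {num}", params)` with `num = 1` yields `WHERE NUM = 1`; whatever the value contains goes into the SQL verbatim.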
————————————————————————————————

Main text

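Walking through the source shows that MyBatis parses SQL in createSqlSource of the XMLLanguageDriver class, and that it supports string templates wrapped in <script></script>.
OK, in that case, all we need to do is run the same code path ourselves.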

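Modifying XPathParser directly would be too complicated, since I'd have to untangle its internal logic, so instead I imitate the calls MyBatis itself makes.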

Building the Configuration

First we need a minimal MyBatis XML configuration, written here as a string:

  String EMPTY_XML = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\r\n" //
            + "<!DOCTYPE configuration\r\n" //
            + " PUBLIC \"-//mybatis.org//DTD Config 3.0//EN\"\r\n" //
            + " \"http://mybatis.org/dtd/mybatis-3-config.dtd\">\r\n"//
            + "<configuration>\r\n" //
            + "</configuration>";

Then build the Configuration the same way MyBatis does:

InputStream inputStream = new ByteArrayInputStream(EMPTY_XML.getBytes(StandardCharsets.UTF_8));
XMLConfigBuilder xmlConfigBuilder = new XMLConfigBuilder(inputStream, null, null);
Configuration configuration = xmlConfigBuilder.parse();

With the Configuration in place, an XPathParser can parse the script and the SQL source can be built from it:

String script = "<script>\nSELECT COUNT(0) FROM TABLE_NAME where 1=1 <if test=\"param!=null\"> AND num = #{param}</if> \n</script>";
XPathParser parser = new XPathParser(script, false, new Properties(), new XMLMapperEntityResolver());
// createSqlSource mirrors XMLLanguageDriver.createSqlSource (see the full class at the end)
SqlSource source = createSqlSource(configuration, parser.evalNode("/script"), null);

Map<String, String> params = new HashMap<>();
params.put("param", "1");
BoundSql boundSql = source.getBoundSql(params);
String sql = boundSql.getSql();

Result

For the SQL parsed by the code above, when the parameter param is supplied, the result is:
SELECT COUNT(0) FROM TABLE_NAME where 1=1 AND num = ?

This is precompiled SQL with placeholders, which is relatively safe. To use it, you build a PreparedStatement (for example through JDBC) and set the parameters on it by hand.
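To see why only `?` is left, here is a simplified stand-in for what happens to `#{...}`. In MyBatis this step is done by `SqlSourceBuilder` with a `GenericTokenParser`; the code below is not that implementation, just the idea: each `#{name}` becomes `?`, and the names are recorded in order, so parameter i can later be bound at JDBC index i + 1. The class names are hypothetical, and it ignores `${}`, escaping and attributes such as `jdbcType=`.

```java
import java.util.ArrayList;
import java.util.List;

/** Result of turning a #{}-style template into JDBC SQL (hypothetical helper type). */
final class ParsedSql {
    final String sql;
    final List<String> paramNames;

    ParsedSql(String sql, List<String> paramNames) {
        this.sql = sql;
        this.paramNames = paramNames;
    }
}

final class HashBraceParser {

    /** Replaces each #{name} with ? and records the names in order of appearance. */
    static ParsedSql parse(String template) {
        StringBuilder sql = new StringBuilder();
        List<String> names = new ArrayList<>();
        int pos = 0;
        while (true) {
            int start = template.indexOf("#{", pos);
            int end = start < 0 ? -1 : template.indexOf('}', start + 2);
            if (start < 0 || end < 0) {
                // No more complete #{...} tokens: copy the rest verbatim
                sql.append(template, pos, template.length());
                break;
            }
            sql.append(template, pos, start).append('?');
            names.add(template.substring(start + 2, end).trim());
            pos = end + 1;
        }
        return new ParsedSql(sql.toString(), names);
    }
}
```

The ordered name list is what makes automatic binding possible: the n-th name tells you which value goes to the n-th `?`.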

Setting PreparedStatement parameters automatically

Follow setParameters in MyBatis's DefaultParameterHandler class.

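The PreparedStatement can be created with, for example, Spring JDBC (a connection obtained this way is not returned to the pool automatically, so close it when you are done):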
PreparedStatement ps = jdbcTemplate.getDataSource().getConnection().prepareStatement(boundSql.getSql());

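Then set the parameters the way DefaultParameterHandler.setParameters does. Its constructor expects a MappedStatement, which this standalone setup doesn't have, so the code later in this post copies that logic into its own method instead of calling the handler directly.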
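Stripped of TypeHandler and MetaObject, the loop in setParameters boils down to: walk the parameter mappings in order, resolve each value by property name, and set it at JDBC index i + 1 (JDBC indexes start at 1). A minimal sketch under those simplifications, assuming the parameters come from a flat Map; `ParameterBinder` and `IndexedSetter` are hypothetical names, the latter existing only so that `ps::setObject` can be passed in.

```java
import java.sql.SQLException;
import java.util.List;
import java.util.Map;

final class ParameterBinder {

    /** Hypothetical stand-in for PreparedStatement#setObject(int, Object). */
    @FunctionalInterface
    interface IndexedSetter {
        void set(int jdbcIndex, Object value) throws SQLException;
    }

    /**
     * Binds params in the order the placeholders appeared.
     * Missing keys are bound as null (the real handler also picks a JdbcType for nulls).
     */
    static void bind(List<String> paramNames, Map<String, ?> params, IndexedSetter setter) throws SQLException {
        for (int i = 0; i < paramNames.size(); i++) {
            Object value = params == null ? null : params.get(paramNames.get(i));
            setter.set(i + 1, value); // JDBC parameter indexes are 1-based
        }
    }
}
```

With a real statement this would be called as `ParameterBinder.bind(names, params, ps::setObject)`. Some drivers reject `setObject(i, null)`, which is why MyBatis falls back to `configuration.getJdbcTypeForNull()`.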

Getting the complete SQL

MyBatis-Plus's PerformanceInterceptor class can extract the SQL from a Statement, so I borrowed from it.
The code needed:

    /**
     * COPY FROM {@link PerformanceInterceptor}
     */
    private static final String DruidPooledPreparedStatement = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
    private static final String T4CPreparedStatement = "oracle.jdbc.driver.T4CPreparedStatement";
    private static final String OraclePreparedStatementWrapper = "oracle.jdbc.driver.OraclePreparedStatementWrapper";
    private Method oracleGetOriginalSqlMethod;
    private Method druidGetSQLMethod;

    /**
     * Get the original SQL, COPY FROM {@link PerformanceInterceptor}
     */
    private String getOriginSql(PreparedStatement statement) {
        String originalSql = null;
        String stmtClassName = statement.getClass().getName();
        if (DruidPooledPreparedStatement.equals(stmtClassName)) {
            try {
                if (druidGetSQLMethod == null) {
                    Class<?> clazz = Class.forName(DruidPooledPreparedStatement);
                    druidGetSQLMethod = clazz.getMethod("getSql");
                }
                Object stmtSql = druidGetSQLMethod.invoke(statement);
                if (stmtSql instanceof String) {
                    originalSql = (String) stmtSql;
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else if (T4CPreparedStatement.equals(stmtClassName) || OraclePreparedStatementWrapper.equals(stmtClassName)) {
            try {
                if (oracleGetOriginalSqlMethod != null) {
                    Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                    if (stmtSql instanceof String) {
                        originalSql = (String) stmtSql;
                    }
                } else {
                    Class<?> clazz = Class.forName(stmtClassName);
                    // getMethodRegular is defined in the full class at the end of this post
                    oracleGetOriginalSqlMethod = getMethodRegular(clazz, "getOriginalSql");
                    if (oracleGetOriginalSqlMethod != null) {
                        // OraclePreparedStatementWrapper is not a public class, need set this.
                        oracleGetOriginalSqlMethod.setAccessible(true);
                        if (null != oracleGetOriginalSqlMethod) {
                            Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                            if (stmtSql instanceof String) {
                                originalSql = (String) stmtSql;
                            }
                        }
                    }
                }
            } catch (Exception e) {
                // ignore
            }
        }
        if (originalSql == null) {
            originalSql = statement.toString();
        }

        return originalSql;
    }
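The Oracle branch relies on getMethodRegular to find a possibly non-public method by walking up the class hierarchy, then calls setAccessible(true) before invoking it. A stdlib-only sketch of that lookup, with hypothetical nested classes standing in for a driver's non-public statement classes; unlike the original it does not cache the Method in a field.

```java
import java.lang.reflect.Method;

final class ReflectiveSql {

    /** Walks up the hierarchy looking for a no-arg declared method by name (any visibility). */
    static Method findMethod(Class<?> clazz, String name) {
        for (Class<?> c = clazz; c != null && c != Object.class; c = c.getSuperclass()) {
            for (Method m : c.getDeclaredMethods()) {
                if (m.getName().equals(name) && m.getParameterCount() == 0) {
                    return m;
                }
            }
        }
        return null;
    }

    /** Invokes a no-arg String getter reflectively; returns null if it is missing or fails. */
    static String readSql(Object statement, String getterName) {
        Method m = findMethod(statement.getClass(), getterName);
        if (m == null) {
            return null;
        }
        try {
            m.setAccessible(true); // needed when the method or its class is not public
            Object result = m.invoke(statement);
            return result instanceof String ? (String) result : null;
        } catch (ReflectiveOperationException | RuntimeException e) {
            return null;
        }
    }

    // Hypothetical stand-ins for a driver's statement classes; the getter is private on purpose
    static class BaseStatement {
        private String getOriginalSql() {
            return "SELECT 1 FROM DUAL";
        }
    }

    static class WrappedStatement extends BaseStatement {
    }
}
```

`readSql(new ReflectiveSql.WrappedStatement(), "getOriginalSql")` finds the private getter on the superclass and returns its SQL.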

In MyBatis the SQL obtained this way looks roughly like com.mxxxx : SELECT.... (a prefix in front of the statement), so it has to be trimmed:

    /**
     * Find where the SQL statement starts, COPY FROM {@link PerformanceInterceptor}
     */
    private int indexOfSqlStart(String sql) {
        String upperCaseSql = sql.toUpperCase();
        Set<Integer> set = new HashSet<>();
        set.add(upperCaseSql.indexOf("SELECT "));
        set.add(upperCaseSql.indexOf("UPDATE "));
        set.add(upperCaseSql.indexOf("INSERT "));
        set.add(upperCaseSql.indexOf("DELETE "));
        set.remove(-1);
        if (CollectionUtils.isEmpty(set)) {
            return -1;
        }
        List<Integer> list = new ArrayList<>(set);
        list.sort(Comparator.naturalOrder());
        return list.get(0);
    }
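The same search can be written with the JDK alone, without mybatis-plus's CollectionUtils or a HashSet: keep the smallest non-negative index among the keywords. This sketch (hypothetical class name) upper-cases with Locale.ROOT so the result does not depend on the JVM's default locale. Like the original it is a heuristic: it only knows these four keywords followed by a space, and a keyword inside the prefix would fool it.

```java
import java.util.Locale;

final class SqlStart {

    private static final String[] KEYWORDS = {"SELECT ", "UPDATE ", "INSERT ", "DELETE "};

    /** Returns the index of the earliest DML keyword, or -1 if none is present. */
    static int indexOfSqlStart(String sql) {
        String upper = sql.toUpperCase(Locale.ROOT);
        int best = -1;
        for (String keyword : KEYWORDS) {
            int idx = upper.indexOf(keyword);
            if (idx >= 0 && (best < 0 || idx < best)) {
                best = idx;
            }
        }
        return best;
    }
}
```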

That gives us the complete SQL. Hooray! But!

With an Oracle database this method doesn't work: what gets printed is still the precompiled SQL.

A second approach to the complete SQL

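When the first approach fails (you can scan the string for leftover ? to tell), add this approach as a fallback:
Reference: https://www.cnblogs.com/aipan/p/7237854.html
Add LoggableStatement, and adapt DefaultParameterHandler's setParameters (copied out as a new method):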

    @SuppressWarnings({"unchecked", "rawtypes"})
    private LoggableStatement buildPreparedStatement(JdbcTemplate jdbcTemplate, BoundSql boundSql) throws SQLException {
        PreparedStatement ps = jdbcTemplate.getDataSource().getConnection().prepareStatement(boundSql.getSql());
        // Modification: wrap the PreparedStatement in a LoggableStatement so the bound values are recorded
        LoggableStatement ls = new LoggableStatement(ps, boundSql.getSql());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        Object parameterObject = boundSql.getParameterObject();
        if (parameterMappings != null) {
            for (int i = 0; i < parameterMappings.size(); i++) {
                ParameterMapping parameterMapping = parameterMappings.get(i);
                if (parameterMapping.getMode() != ParameterMode.OUT) {
                    Object value;
                    String propertyName = parameterMapping.getProperty();
                    if (boundSql.hasAdditionalParameter(propertyName)) {
                        value = boundSql.getAdditionalParameter(propertyName);
                    } else if (parameterObject == null) {
                        value = null;
                    } else {
                        MetaObject metaObject = configuration.newMetaObject(parameterObject);
                        value = metaObject.getValue(propertyName);
                    }
                    TypeHandler typeHandler = parameterMapping.getTypeHandler();
                    JdbcType jdbcType = parameterMapping.getJdbcType();
                    if (value == null && jdbcType == null) {
                        jdbcType = configuration.getJdbcTypeForNull();
                    }
                    try {
                        typeHandler.setParameter(ls, i + 1, value, jdbcType);
                    } catch (TypeException | SQLException e) {
                        throw new TypeException("Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
                    }
                }
            }
        }

        return ls;
    }
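LoggableStatement (from the referenced article, not part of MyBatis) works by recording each value passed to setXxx and substituting the values back into the ? positions when getQueryString() is called. A minimal sketch of just that substitution step, assuming the values were collected in index order; the class name is hypothetical. It quotes strings with single quotes, doubles embedded quotes, and does not understand ? inside literals. The output is for logging and review only; executing it would give up the protection prepared statements provide.

```java
import java.util.List;

final class SqlInliner {

    /** Replaces the n-th '?' with a literal rendering of values.get(n); extra '?' are kept as-is. */
    static String inline(String sql, List<?> values) {
        StringBuilder out = new StringBuilder();
        int next = 0;
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (c == '?' && next < values.size()) {
                out.append(toLiteral(values.get(next++)));
            } else {
                out.append(c);
            }
        }
        return out.toString();
    }

    /** Renders a value as SQL text: NULL, a bare number, or a single-quoted string with quotes doubled. */
    private static String toLiteral(Object value) {
        if (value == null) {
            return "NULL";
        }
        if (value instanceof Number) {
            return value.toString();
        }
        return "'" + value.toString().replace("'", "''") + "'";
    }
}
```

Dates and other types fall through to toString(), which is usually good enough for a log line but is not guaranteed to be valid SQL for every database.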

If getting the complete SQL that way fails, the SQL can be taken from the LoggableStatement instead:

if (!isCorrectGetSql(boundSql, originalSql)) {
   originalSql = statement.getQueryString();
}
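The check used here, isCorrectGetSql in the full class, compares placeholder counts: if the fetched SQL has as many ? as the BoundSql, the driver did not inline the values. A stdlib sketch of the same comparison (hypothetical class and method names), mainly to make its edge cases concrete.

```java
final class PlaceholderCheck {

    /** Counts '?' characters; like the original, it does not skip quoted literals or comments. */
    static int countQuestionMarks(String sql) {
        int count = 0;
        for (int i = 0; i < sql.length(); i++) {
            if (sql.charAt(i) == '?') {
                count++;
            }
        }
        return count;
    }

    /** True when fewer '?' remain in the fetched SQL than in the prepared SQL, i.e. values were inlined. */
    static boolean valuesWereInlined(String preparedSql, String fetchedSql) {
        return countQuestionMarks(preparedSql) > countQuestionMarks(fetchedSql);
    }
}
```

Two consequences follow: a statement with no parameters always takes the fallback (harmless, since there is nothing to inline), and an inlined value that itself contains ?, such as 'what?', makes a successful result look like a failure.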

————————————————————————————

The complete code of the class from the project mentioned above:

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;

import org.apache.ibatis.builder.xml.XMLConfigBuilder;
import org.apache.ibatis.builder.xml.XMLMapperEntityResolver;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ParameterMode;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.parsing.XNode;
import org.apache.ibatis.parsing.XPathParser;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.scripting.xmltags.XMLScriptBuilder;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.TypeException;
import org.apache.ibatis.type.TypeHandler;
import org.springframework.jdbc.core.JdbcTemplate;

import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;

/**
 * Executor that builds SQL using MyBatis's XML parsing
 */
public class MybatisTemplateSqlExcutor {

    /**
     * Script template
     */
    private static final String SCRIPT_TEMPLATE = "<script>\n%s\n</script>";

    /**
     * MyBatis configuration
     */
    private Configuration configuration;

    /**
     * COPY FROM {@link PerformanceInterceptor}
     */
    private static final String DruidPooledPreparedStatement = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
    private static final String T4CPreparedStatement = "oracle.jdbc.driver.T4CPreparedStatement";
    private static final String OraclePreparedStatementWrapper = "oracle.jdbc.driver.OraclePreparedStatementWrapper";
    private Method oracleGetOriginalSqlMethod;
    private Method druidGetSQLMethod;

    /**
     * Constructor: initializes the MyBatis configuration
     */
    public MybatisTemplateSqlExcutor() {
        InputStream inputStream = new ByteArrayInputStream(EMPTY_XML.getBytes(StandardCharsets.UTF_8));
        XMLConfigBuilder xmlConfigBuilder = new XMLConfigBuilder(inputStream, null, null);
        configuration = xmlConfigBuilder.parse();
    }

    public String parseSql(JdbcTemplate jdbcTemplate, String sqlTemplate, Map<String, Object> params) throws SQLException {
        String script = String.format(SCRIPT_TEMPLATE, sqlTemplate);
        XPathParser parser = new XPathParser(script, false, new Properties(), new XMLMapperEntityResolver());
        SqlSource source = createSqlSource(configuration, parser.evalNode("/script"), Map.class);
        BoundSql boundSql = source.getBoundSql(params);
        // The statement is only needed to render the SQL: close it and return the connection to the pool afterwards
        try (java.sql.Connection connection = jdbcTemplate.getDataSource().getConnection();
             PreparedStatement ps = connection.prepareStatement(boundSql.getSql())) {
            LoggableStatement statement = buildPreparedStatement(ps, boundSql);
            String originalSql = getOriginSql(statement.getPreparedStatement());
            int index = indexOfSqlStart(originalSql);
            if (index > 0) {
                originalSql = originalSql.substring(index);
            }
            // Fall back to LoggableStatement when the driver still reports the placeholder SQL (e.g. Oracle)
            if (!isCorrectGetSql(boundSql, originalSql)) {
                originalSql = statement.getQueryString();
            }
            return originalSql;
        }
    }

    /**
     * Script parsing method copied from XMLLanguageDriver.createSqlSource
     */
    private SqlSource createSqlSource(Configuration configuration, XNode script, Class<?> parameterType) {
        XMLScriptBuilder builder = new XMLScriptBuilder(configuration, script, parameterType);
        return builder.parseScriptNode();
    }

    /**
     * Binds the BoundSql parameters onto the given PreparedStatement, used to obtain the actual SQL
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private LoggableStatement buildPreparedStatement(PreparedStatement ps, BoundSql boundSql) throws SQLException {
        LoggableStatement ls = new LoggableStatement(ps, boundSql.getSql());
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        Object parameterObject = boundSql.getParameterObject();
        if (parameterMappings != null) {
            for (int i = 0; i < parameterMappings.size(); i++) {
                ParameterMapping parameterMapping = parameterMappings.get(i);
                if (parameterMapping.getMode() != ParameterMode.OUT) {
                    Object value;
                    String propertyName = parameterMapping.getProperty();
                    if (boundSql.hasAdditionalParameter(propertyName)) {
                        value = boundSql.getAdditionalParameter(propertyName);
                    } else if (parameterObject == null) {
                        value = null;
                    } else {
                        MetaObject metaObject = configuration.newMetaObject(parameterObject);
                        value = metaObject.getValue(propertyName);
                    }
                    TypeHandler typeHandler = parameterMapping.getTypeHandler();
                    JdbcType jdbcType = parameterMapping.getJdbcType();
                    if (value == null && jdbcType == null) {
                        jdbcType = configuration.getJdbcTypeForNull();
                    }
                    try {
                        typeHandler.setParameter(ls, i + 1, value, jdbcType);
                    } catch (TypeException | SQLException e) {
                        throw new TypeException("Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
                    }
                }
            }
        }

        return ls;
    }

    /**
     * Find where the SQL statement starts
     */
    private int indexOfSqlStart(String sql) {
        String upperCaseSql = sql.toUpperCase();
        Set<Integer> set = new HashSet<>();
        set.add(upperCaseSql.indexOf("SELECT "));
        set.add(upperCaseSql.indexOf("UPDATE "));
        set.add(upperCaseSql.indexOf("INSERT "));
        set.add(upperCaseSql.indexOf("DELETE "));
        set.remove(-1);
        if (CollectionUtils.isEmpty(set)) {
            return -1;
        }
        List<Integer> list = new ArrayList<>(set);
        list.sort(Comparator.naturalOrder());
        return list.get(0);
    }

    /**
     * Get the original SQL, COPY FROM {@link PerformanceInterceptor}
     */
    private String getOriginSql(PreparedStatement statement) {
        String originalSql = null;
        String stmtClassName = statement.getClass().getName();
        if (DruidPooledPreparedStatement.equals(stmtClassName)) {
            try {
                if (druidGetSQLMethod == null) {
                    Class<?> clazz = Class.forName(DruidPooledPreparedStatement);
                    druidGetSQLMethod = clazz.getMethod("getSql");
                }
                Object stmtSql = druidGetSQLMethod.invoke(statement);
                if (stmtSql instanceof String) {
                    originalSql = (String) stmtSql;
                }
            } catch (Exception e) {
                e.printStackTrace();
            }
        } else if (T4CPreparedStatement.equals(stmtClassName) || OraclePreparedStatementWrapper.equals(stmtClassName)) {
            try {
                if (oracleGetOriginalSqlMethod != null) {
                    Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                    if (stmtSql instanceof String) {
                        originalSql = (String) stmtSql;
                    }
                } else {
                    Class<?> clazz = Class.forName(stmtClassName);
                    oracleGetOriginalSqlMethod = getMethodRegular(clazz, "getOriginalSql");
                    if (oracleGetOriginalSqlMethod != null) {
                        // OraclePreparedStatementWrapper is not a public class, need set this.
                        oracleGetOriginalSqlMethod.setAccessible(true);
                        if (null != oracleGetOriginalSqlMethod) {
                            Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
                            if (stmtSql instanceof String) {
                                originalSql = (String) stmtSql;
                            }
                        }
                    }
                }
            } catch (Exception e) {
                // ignore
            }
        }
        if (originalSql == null) {
            originalSql = statement.toString();
        }

        return originalSql;
    }

    /**
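     * Finds the Method with the given name, searching superclasses as well
     *
     * @param clazz the class object
     * @param methodName the method name
     * @return the method, or null if not found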
     */
    public Method getMethodRegular(Class<?> clazz, String methodName) {
        if (Object.class.equals(clazz)) {
            return null;
        }
        for (Method method : clazz.getDeclaredMethods()) {
            if (method.getName().equals(methodName)) {
                return method;
            }
        }
        return getMethodRegular(clazz.getSuperclass(), methodName);
    }

    /**
     * Checks whether the SQL was obtained correctly, i.e. the fetched SQL has fewer placeholders than the BoundSql
     */
    private boolean isCorrectGetSql(BoundSql boundSql, String originSql) {
        return countQuestionMark(boundSql.getSql()) > countQuestionMark(originSql);
    }

    /**
     * Counts '?' placeholders
     */
    private int countQuestionMark(String sql) {
        int result = 0;
        for (char c : sql.toCharArray()) {
            if (c == '?') {
                result++;
            }
        }
        return result;
    }

    /**
     * Empty MyBatis config template used to build the Configuration (placed at the end because the blog editor otherwise mis-highlights the code)
     */
    private static final String EMPTY_XML = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\r\n"
            + "<!DOCTYPE configuration\r\n"
            + " PUBLIC \"-//mybatis.org//DTD Config 3.0//EN\"\r\n"
            + " \"http://mybatis.org/dtd/mybatis-3-config.dtd\">\r\n"
            + "<configuration>\r\n"
            + "</configuration>";
}

It's called an Excutor because in the project it is actually a subclass; here it only provides parseSql.

Other

If you know a better way, please do tell me (๑•̀ㅂ•́)و✧

posted @ 2021-11-29 19:21  四方田春海