【Mybatis】Using MyBatis's SQL template parsing on its own
Preface
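Our company's projects carry historical design problems, and the pits seem bottomless. There is no time to fix the new projects, and the old ones can't be changed. As a result, production sees a great many irregular operations where data needs to be changed, and can only be changed, through database scripts.
Every time I'm halfway through development, a ticket lands on me asking me to fix data in production. Writing scripts every single day: who could stand that?
After a few quarters of this I'd had enough. I slammed the table and decided to chi..... develop a script execution tool from scratch. It manages the scripts already written, lets scripts be chained together, and works across databases, solving both the ad-hoc scripting problem and the script-sharing problem.
In the original design, the tool's scripts only supported simple placeholders such as "{paramName}", filled in with replaceAll. Later I thought: why can't it work like MyBatis? Or rather, why not just borrow MyBatis's XML template parsing to parse the scripts?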
————————————————————————————————
Main content
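Walking through the code, you can see that MyBatis implements SQL parsing in the createSqlSource method of the XMLLanguageDriver class, and that it supports running string templates wrapped in <script></script>.
OK, in that case, all we need is a way to run that same code path.
Modifying XPathParser directly would be too complicated, since I'd have to untangle its internal logic. Instead, I simulate what MyBatis does.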
Building the Configuration
First we need a minimal MyBatis XML configuration, written here as a string:
String EMPTY_XML = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\r\n" //
+ "<!DOCTYPE configuration\r\n" //
+ " PUBLIC \"-//mybatis.org//DTD Config 3.0//EN\"\r\n" //
+ " \"http://mybatis.org/dtd/mybatis-3-config.dtd\">\r\n"//
+ "<configuration>\r\n" //
+ "</configuration>";
Then build the Configuration the same way MyBatis does:
InputStream inputStream = new ByteArrayInputStream(EMPTY_XML.getBytes(StandardCharsets.UTF_8));
XMLConfigBuilder xmlConfigBuilder = new XMLConfigBuilder(inputStream, null, null);
Configuration configuration = xmlConfigBuilder.parse();
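With the Configuration in hand, we can create an XPathParser to parse the SQL. Here createSqlSource is the private helper defined in the full class at the end of this post: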
String script = "<script>\nSELECT COUNT(0) FROM TABLE_NAME where 1=1 <if test=\"param!=null\"> AND num = #{param}</if> \n</script>";
XPathParser parser = new XPathParser(script, false, new Properties(), new XMLMapperEntityResolver());
SqlSource source = createSqlSource(configuration, parser.evalNode("/script"), null);
Map<String, String> params = new HashMap<>();
params.put("param", "1");
BoundSql boundSql = source.getBoundSql(params);
String sql = boundSql.getSql();
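Wrapping the template in <script> has one practical consequence: the whole thing must be well-formed XML. XPathParser builds its document with the JDK's DOM parser, so a bare < in a comparison such as a < b makes parsing fail. The stdlib-only check below is a sketch. ScriptXmlCheck is a made-up name. It wraps a template the same way SCRIPT_TEMPLATE in the full class does and tries to parse it. Escaping the operator as &lt;, or putting it in a CDATA section, avoids the problem. MyBatis treats CDATA content as plain SQL text.

```java
import java.io.StringReader;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import org.xml.sax.ErrorHandler;
import org.xml.sax.InputSource;
import org.xml.sax.SAXException;
import org.xml.sax.SAXParseException;

public class ScriptXmlCheck {

    /** Wraps the template like SCRIPT_TEMPLATE does and reports whether the result is well-formed XML. */
    public static boolean isWellFormed(String sqlTemplate) {
        String script = String.format("<script>\n%s\n</script>", sqlTemplate);
        try {
            DocumentBuilder builder = DocumentBuilderFactory.newInstance().newDocumentBuilder();
            // Surface errors as exceptions instead of leaving them to the default handler
            builder.setErrorHandler(new ErrorHandler() {
                public void warning(SAXParseException e) {
                }

                public void error(SAXParseException e) throws SAXException {
                    throw e;
                }

                public void fatalError(SAXParseException e) throws SAXException {
                    throw e;
                }
            });
            builder.parse(new InputSource(new StringReader(script)));
            return true;
        } catch (Exception e) {
            // Parser setup or parse failure: treat as not well-formed
            return false;
        }
    }
}
```

For example, "WHERE a < b" fails the check. "WHERE a &lt; b", "WHERE a <![CDATA[ < ]]> b" and the <if> template above all pass.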
Result
If the parameter param is passed in (non-null), the code above produces:
SELECT COUNT(0) FROM TABLE_NAME where 1=1 AND num = ?
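This is precompiled SQL with placeholders, which is much safer against injection. To use it, you build a PreparedStatement and then set the parameters on it manually, for example through plain JDBC.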
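The BoundSql hands back a pair: the SQL, with ? in place of every #{...}, and the parameter mappings in placeholder order. "Setting the parameters manually" means walking that list and binding each value by position. The toy below reproduces only the #{name} to ? rewrite, using a regular expression, so the pairing is easy to see. It is a simplified stand-in, not MyBatis's parser. It has no OGNL, no nested properties, and no options such as jdbcType=..., and the class name PlaceholderSketch is made up.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class PlaceholderSketch {

    // Matches only the plain #{name} form; #{name,jdbcType=VARCHAR} and similar are left untouched
    private static final Pattern PARAM = Pattern.compile("#\\{(\\w+)\\}");

    /** SQL with '?' placeholders plus the parameter names in placeholder order. */
    public static final class Converted {
        public final String sql;
        public final List<String> names;

        Converted(String sql, List<String> names) {
            this.sql = sql;
            this.names = names;
        }
    }

    /** Replaces every #{name} with '?' and records each name in the order it appears. */
    public static Converted convert(String template) {
        Matcher matcher = PARAM.matcher(template);
        StringBuffer sql = new StringBuffer();
        List<String> names = new ArrayList<>();
        while (matcher.find()) {
            names.add(matcher.group(1));
            matcher.appendReplacement(sql, "?");
        }
        matcher.appendTail(sql);
        return new Converted(sql.toString(), names);
    }
}
```

A name that appears twice produces two placeholders and two list entries. That mirrors how each occurrence gets its own parameter mapping.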
Automating PreparedStatement parameter binding
Follow the setParameters method of MyBatis's DefaultParameterHandler class.
Creating the PreparedStatement
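It can be built with, for example, Spring JDBC. A connection obtained directly from the DataSource like this is not managed by JdbcTemplate, so close the statement and its connection when you are done: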
PreparedStatement ps = jdbcTemplate.getDataSource().getConnection().prepareStatement(boundSql.getSql());
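Then call setParameters on a DefaultParameterHandler. Note that its constructor expects a MappedStatement, the parameter object and the BoundSql.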
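Roughly speaking, setParameters walks the parameter mappings, looks up each value, and lets a TypeHandler bind it. When the value is null and no jdbcType was given, it falls back to the configured jdbcTypeForNull, which defaults to JdbcType.OTHER. The sketch below mimics that loop with plain JDBC. It assumes a simple Map of parameters and uses setObject in place of MyBatis's TypeHandler dispatch. ParameterBinder is a hypothetical name.

```java
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.List;
import java.util.Map;

public class ParameterBinder {

    // MyBatis's Configuration defaults jdbcTypeForNull to JdbcType.OTHER, i.e. java.sql.Types.OTHER
    private static final int JDBC_TYPE_FOR_NULL = Types.OTHER;

    /** Binds params by position, following the order of propertyNames (one entry per '?'). */
    public static void setParameters(PreparedStatement ps, List<String> propertyNames,
                                     Map<String, ?> params) throws SQLException {
        for (int i = 0; i < propertyNames.size(); i++) {
            Object value = params == null ? null : params.get(propertyNames.get(i));
            int index = i + 1; // JDBC parameter indexes start at 1
            if (value == null) {
                ps.setNull(index, JDBC_TYPE_FOR_NULL);
            } else {
                // Stand-in for TypeHandler.setParameter: let the driver infer the SQL type
                ps.setObject(index, value);
            }
        }
    }
}
```

Some drivers, Oracle's among them, are known to reject Types.OTHER in setNull. That is why jdbcTypeForNull is often changed to NULL when working against Oracle.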
Getting the complete SQL
Follow mybatis-plus's PerformanceInterceptor class, which can recover the SQL from a Statement.
The code needed is as follows:
/**
* COPY FROM {@link PerformanceInterceptor}
*/
private static final String DruidPooledPreparedStatement = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
private static final String T4CPreparedStatement = "oracle.jdbc.driver.T4CPreparedStatement";
private static final String OraclePreparedStatementWrapper = "oracle.jdbc.driver.OraclePreparedStatementWrapper";
private Method oracleGetOriginalSqlMethod;
private Method druidGetSQLMethod;
/**
* Gets the original SQL, COPY FROM {@link PerformanceInterceptor}
*/
private String getOriginSql(PreparedStatement statement) {
String originalSql = null;
String stmtClassName = statement.getClass().getName();
if (DruidPooledPreparedStatement.equals(stmtClassName)) {
try {
if (druidGetSQLMethod == null) {
Class<?> clazz = Class.forName(DruidPooledPreparedStatement);
druidGetSQLMethod = clazz.getMethod("getSql");
}
Object stmtSql = druidGetSQLMethod.invoke(statement);
if (stmtSql instanceof String) {
originalSql = (String) stmtSql;
}
} catch (Exception e) {
e.printStackTrace();
}
} else if (T4CPreparedStatement.equals(stmtClassName) || OraclePreparedStatementWrapper.equals(stmtClassName)) {
try {
if (oracleGetOriginalSqlMethod != null) {
Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
if (stmtSql instanceof String) {
originalSql = (String) stmtSql;
}
} else {
Class<?> clazz = Class.forName(stmtClassName);
oracleGetOriginalSqlMethod = getMethodRegular(clazz, "getOriginalSql");
if (oracleGetOriginalSqlMethod != null) {
// OraclePreparedStatementWrapper is not a public class, need set this.
oracleGetOriginalSqlMethod.setAccessible(true);
if (null != oracleGetOriginalSqlMethod) {
Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
if (stmtSql instanceof String) {
originalSql = (String) stmtSql;
}
}
}
}
} catch (Exception e) {
// ignore
}
}
if (originalSql == null) {
originalSql = statement.toString();
}
return originalSql;
}
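getOriginSql calls a helper, getMethodRegular, which is only defined in the full class at the end. That helper walks getDeclaredMethods up the superclass chain for two reasons. First, getMethod only returns public methods, while getDeclaredMethods also returns non-public ones. Second, OraclePreparedStatementWrapper is not a public class, so the method must also be made accessible before it is invoked, as the code comment says. The self-contained sketch below shows the difference. Base and Wrapper are made-up stand-ins for driver classes.

```java
import java.lang.reflect.Method;

public class HierarchyLookup {

    /** Same idea as getMethodRegular: search declared (including non-public) methods up the class chain. */
    public static Method find(Class<?> clazz, String name) {
        for (Class<?> c = clazz; c != null && c != Object.class; c = c.getSuperclass()) {
            for (Method m : c.getDeclaredMethods()) {
                if (m.getName().equals(name)) {
                    return m;
                }
            }
        }
        return null;
    }

    // Hypothetical stand-ins for a driver's statement classes
    public static class Base {
        String getOriginalSql() { // package-private, so getMethod cannot see it
            return "SELECT 1 FROM DUAL";
        }
    }

    public static class Wrapper extends Base {
    }

    /** Looks the method up starting from the subclass and invokes it on an instance. */
    public static String demo() throws ReflectiveOperationException {
        Object stmt = new Wrapper();
        Method m = find(stmt.getClass(), "getOriginalSql");
        // Required when the caller has no access, e.g. a non-public driver class in another package
        m.setAccessible(true);
        return (String) m.invoke(stmt);
    }
}
```

Wrapper.class.getMethod("getOriginalSql") throws NoSuchMethodException here. find locates the method on Base.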
In my case, the SQL obtained this way came back roughly as com.mxxxx : SELECT...., so the prefix in front has to be cut off:
/**
* Finds where the SQL statement starts, COPY FROM {@link PerformanceInterceptor}
*/
private int indexOfSqlStart(String sql) {
String upperCaseSql = sql.toUpperCase();
Set<Integer> set = new HashSet<>();
set.add(upperCaseSql.indexOf("SELECT "));
set.add(upperCaseSql.indexOf("UPDATE "));
set.add(upperCaseSql.indexOf("INSERT "));
set.add(upperCaseSql.indexOf("DELETE "));
set.remove(-1);
if (CollectionUtils.isEmpty(set)) {
return -1;
}
List<Integer> list = new ArrayList<>(set);
list.sort(Comparator.naturalOrder());
return list.get(0);
}
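The same prefix stripping can be done without the HashSet and without mybatis-plus's CollectionUtils, by tracking the smallest non-negative index directly. Below is a stdlib-only sketch. The class name is hypothetical, and the prefix in the example is invented for illustration. Real drivers format toString() in their own way.

```java
import java.util.Locale;

public class SqlStartSketch {

    private static final String[] KEYWORDS = {"SELECT ", "UPDATE ", "INSERT ", "DELETE "};

    /** Returns the index of the earliest DML keyword, or -1 if none is found. */
    public static int indexOfSqlStart(String sql) {
        // Locale.ROOT so that e.g. a Turkish default locale does not turn "insert" into "İNSERT"
        String upper = sql.toUpperCase(Locale.ROOT);
        int best = -1;
        for (String keyword : KEYWORDS) {
            int idx = upper.indexOf(keyword);
            if (idx >= 0 && (best < 0 || idx < best)) {
                best = idx;
            }
        }
        return best;
    }

    /** Drops whatever precedes the first keyword, e.g. a driver class-name prefix. */
    public static String stripPrefix(String sql) {
        int index = indexOfSqlStart(sql);
        return index > 0 ? sql.substring(index) : sql;
    }
}
```

For example, stripPrefix("com.example.FakeStatement: SELECT COUNT(0) FROM T") returns "SELECT COUNT(0) FROM T". Like the original, it only knows these four keywords. A query that begins with WITH would therefore be cut at its first inner SELECT.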
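That gets us the complete SQL. Hooray! But!
With an Oracle database this approach doesn't work: what gets printed is still the precompiled SQL with placeholders.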
Second approach to getting the complete SQL
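If the first approach fails (you can check whether the string still contains ? placeholders), fall back to this one:
Reference: https://www.cnblogs.com/aipan/p/7237854.html
Add a LoggableStatement, and adapt DefaultParameterHandler's setParameters (copied out as a new method):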
@SuppressWarnings({"unchecked", "rawtypes"})
private LoggableStatement buildPreparedStatement(JdbcTemplate jdbcTemplate, BoundSql boundSql) throws SQLException {
PreparedStatement ps = jdbcTemplate.getDataSource().getConnection().prepareStatement(boundSql.getSql());
// Modification: wrap the statement in a LoggableStatement so the bound values are recorded
LoggableStatement ls = new LoggableStatement(ps, boundSql.getSql());
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
Object parameterObject = boundSql.getParameterObject();
if (parameterMappings != null) {
for (int i = 0; i < parameterMappings.size(); i++) {
ParameterMapping parameterMapping = parameterMappings.get(i);
if (parameterMapping.getMode() != ParameterMode.OUT) {
Object value;
String propertyName = parameterMapping.getProperty();
if (boundSql.hasAdditionalParameter(propertyName)) {
value = boundSql.getAdditionalParameter(propertyName);
} else if (parameterObject == null) {
value = null;
} else {
MetaObject metaObject = configuration.newMetaObject(parameterObject);
value = metaObject.getValue(propertyName);
}
TypeHandler typeHandler = parameterMapping.getTypeHandler();
JdbcType jdbcType = parameterMapping.getJdbcType();
if (value == null && jdbcType == null) {
jdbcType = configuration.getJdbcTypeForNull();
}
try {
typeHandler.setParameter(ls, i + 1, value, jdbcType);
} catch (TypeException | SQLException e) {
throw new TypeException("Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
}
}
}
}
return ls;
}
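LoggableStatement itself comes from the referenced article and isn't reproduced here. Its getQueryString() presumably remembers each value passed to the setXxx methods and splices those values into the ? positions of the SQL template. Below is a minimal standalone version of that splicing step, assuming the values arrive in placeholder order. InlineSqlSketch and its rendering rules are my own illustration, not the article's code.

```java
import java.util.List;

public class InlineSqlSketch {

    /** Replaces each '?' in sql with the next value, left to right; extra '?' are kept as-is. */
    public static String inline(String sql, List<?> values) {
        StringBuilder out = new StringBuilder();
        int next = 0;
        for (int i = 0; i < sql.length(); i++) {
            char c = sql.charAt(i);
            if (c == '?' && next < values.size()) {
                out.append(render(values.get(next++)));
            } else {
                out.append(c);
            }
        }
        return out.toString();
    }

    /** Numbers and booleans as-is, null as NULL, everything else single-quoted with '' escaping. */
    private static String render(Object value) {
        if (value == null) {
            return "NULL";
        }
        if (value instanceof Number || value instanceof Boolean) {
            return value.toString();
        }
        return "'" + value.toString().replace("'", "''") + "'";
    }
}
```

The result is for reading and logging only. Quoting like this is no substitute for binding parameters. Also, a ? inside a string literal in the template would be mistaken for a placeholder.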
Then, if the complete SQL could not be obtained, get it from the LoggableStatement instead:
if (!isCorrectGetSql(boundSql, originalSql)) {
originalSql = statement.getQueryString();
}
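The check isCorrectGetSql, defined in the full class, simply compares how many ? characters remain. A stdlib copy makes its behavior easy to probe. It reports success only when the driver's output has fewer ? than the BoundSql. So a statement with no parameters, or a bound value that itself contains ?, counts as a "failure". In both cases the only effect is an unnecessary fallback to LoggableStatement, which should still return usable SQL. SqlCheckSketch is a made-up name.

```java
public class SqlCheckSketch {

    /** Counts '?' characters; like the original, it does not skip '?' inside quoted literals. */
    public static int countQuestionMark(String sql) {
        int count = 0;
        for (int i = 0; i < sql.length(); i++) {
            if (sql.charAt(i) == '?') {
                count++;
            }
        }
        return count;
    }

    /** True when fewer placeholders remain than in the bound SQL, i.e. some values were inlined. */
    public static boolean isCorrectGetSql(String boundSql, String originSql) {
        return countQuestionMark(boundSql) > countQuestionMark(originSql);
    }
}
```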
————————————————————————————
The full code of the class from the project mentioned above:
import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import org.apache.ibatis.builder.xml.XMLConfigBuilder;
import org.apache.ibatis.builder.xml.XMLMapperEntityResolver;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.ParameterMode;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.parsing.XNode;
import org.apache.ibatis.parsing.XPathParser;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.scripting.xmltags.XMLScriptBuilder;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.JdbcType;
import org.apache.ibatis.type.TypeException;
import org.apache.ibatis.type.TypeHandler;
import org.springframework.jdbc.core.JdbcTemplate;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
/**
* Executor that builds SQL using MyBatis's XML template parsing
*/
public class MybatisTemplateSqlExcutor {
/**
* Script template
*/
private static final String SCRIPT_TEMPLATE = "<script>\n%s\n</script>";
/**
* MyBatis environment configuration
*/
private Configuration configuration;
/**
* COPY FROM {@link PerformanceInterceptor}
*/
private static final String DruidPooledPreparedStatement = "com.alibaba.druid.pool.DruidPooledPreparedStatement";
private static final String T4CPreparedStatement = "oracle.jdbc.driver.T4CPreparedStatement";
private static final String OraclePreparedStatementWrapper = "oracle.jdbc.driver.OraclePreparedStatementWrapper";
private Method oracleGetOriginalSqlMethod;
private Method druidGetSQLMethod;
/**
* Constructor: initializes the MyBatis environment configuration
*/
public MybatisTemplateSqlExcutor() {
InputStream inputStream = new ByteArrayInputStream(EMPTY_XML.getBytes(StandardCharsets.UTF_8));
XMLConfigBuilder xmlConfigBuilder = new XMLConfigBuilder(inputStream, null, null);
configuration = xmlConfigBuilder.parse();
}
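public String parseSql(JdbcTemplate jdbcTemplate, String sqlTemplate, Map<String, Object> params) throws SQLException {
String script = String.format(SCRIPT_TEMPLATE, sqlTemplate);
XPathParser parser = new XPathParser(script, false, new Properties(), new XMLMapperEntityResolver());
SqlSource source = createSqlSource(configuration, parser.evalNode("/script"), Map.class);
BoundSql boundSql = source.getBoundSql(params);
LoggableStatement statement = buildPreparedStatement(jdbcTemplate, boundSql);
PreparedStatement ps = statement.getPreparedStatement();
// Keep a handle on the connection opened in buildPreparedStatement so it can be released below
java.sql.Connection connection = ps.getConnection();
try {
String originalSql = getOriginSql(ps);
int index = indexOfSqlStart(originalSql);
if (index > 0) {
originalSql = originalSql.substring(index);
}
// Placeholders still present (e.g. on Oracle): fall back to the LoggableStatement
if (!isCorrectGetSql(boundSql, originalSql)) {
originalSql = statement.getQueryString();
}
return originalSql;
} finally {
// The statement is only used to render the SQL; close it and its connection to avoid leaking connections
ps.close();
connection.close();
}
}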
/**
* Script parsing method, copied from MyBatis's XMLLanguageDriver#createSqlSource
*/
private SqlSource createSqlSource(Configuration configuration, XNode script, Class<?> parameterType) {
XMLScriptBuilder builder = new XMLScriptBuilder(configuration, script, parameterType);
return builder.parseScriptNode();
}
/**
* Builds a PreparedStatement from the BoundSql, used to obtain the actual SQL
*/
@SuppressWarnings({"unchecked", "rawtypes"})
private LoggableStatement buildPreparedStatement(JdbcTemplate jdbcTemplate, BoundSql boundSql) throws SQLException {
PreparedStatement ps = jdbcTemplate.getDataSource().getConnection().prepareStatement(boundSql.getSql());
LoggableStatement ls = new LoggableStatement(ps, boundSql.getSql());
List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
Object parameterObject = boundSql.getParameterObject();
if (parameterMappings != null) {
for (int i = 0; i < parameterMappings.size(); i++) {
ParameterMapping parameterMapping = parameterMappings.get(i);
if (parameterMapping.getMode() != ParameterMode.OUT) {
Object value;
String propertyName = parameterMapping.getProperty();
if (boundSql.hasAdditionalParameter(propertyName)) {
value = boundSql.getAdditionalParameter(propertyName);
} else if (parameterObject == null) {
value = null;
} else {
MetaObject metaObject = configuration.newMetaObject(parameterObject);
value = metaObject.getValue(propertyName);
}
TypeHandler typeHandler = parameterMapping.getTypeHandler();
JdbcType jdbcType = parameterMapping.getJdbcType();
if (value == null && jdbcType == null) {
jdbcType = configuration.getJdbcTypeForNull();
}
try {
typeHandler.setParameter(ls, i + 1, value, jdbcType);
} catch (TypeException | SQLException e) {
throw new TypeException("Could not set parameters for mapping: " + parameterMapping + ". Cause: " + e, e);
}
}
}
}
return ls;
}
/**
* Finds where the SQL statement starts
*/
private int indexOfSqlStart(String sql) {
String upperCaseSql = sql.toUpperCase();
Set<Integer> set = new HashSet<>();
set.add(upperCaseSql.indexOf("SELECT "));
set.add(upperCaseSql.indexOf("UPDATE "));
set.add(upperCaseSql.indexOf("INSERT "));
set.add(upperCaseSql.indexOf("DELETE "));
set.remove(-1);
if (CollectionUtils.isEmpty(set)) {
return -1;
}
List<Integer> list = new ArrayList<>(set);
list.sort(Comparator.naturalOrder());
return list.get(0);
}
/**
* Gets the original SQL, COPY FROM {@link PerformanceInterceptor}
*/
private String getOriginSql(PreparedStatement statement) {
String originalSql = null;
String stmtClassName = statement.getClass().getName();
if (DruidPooledPreparedStatement.equals(stmtClassName)) {
try {
if (druidGetSQLMethod == null) {
Class<?> clazz = Class.forName(DruidPooledPreparedStatement);
druidGetSQLMethod = clazz.getMethod("getSql");
}
Object stmtSql = druidGetSQLMethod.invoke(statement);
if (stmtSql instanceof String) {
originalSql = (String) stmtSql;
}
} catch (Exception e) {
e.printStackTrace();
}
} else if (T4CPreparedStatement.equals(stmtClassName) || OraclePreparedStatementWrapper.equals(stmtClassName)) {
try {
if (oracleGetOriginalSqlMethod != null) {
Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
if (stmtSql instanceof String) {
originalSql = (String) stmtSql;
}
} else {
Class<?> clazz = Class.forName(stmtClassName);
oracleGetOriginalSqlMethod = getMethodRegular(clazz, "getOriginalSql");
if (oracleGetOriginalSqlMethod != null) {
// OraclePreparedStatementWrapper is not a public class, need set this.
oracleGetOriginalSqlMethod.setAccessible(true);
if (null != oracleGetOriginalSqlMethod) {
Object stmtSql = oracleGetOriginalSqlMethod.invoke(statement);
if (stmtSql instanceof String) {
originalSql = (String) stmtSql;
}
}
}
}
} catch (Exception e) {
// ignore
}
}
if (originalSql == null) {
originalSql = statement.toString();
}
return originalSql;
}
/**
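* Finds the Method with the given name, searching superclasses as well
*
* @param clazz the class to search
* @param methodName the method name
* @return the method, or null if not found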
*/
public Method getMethodRegular(Class<?> clazz, String methodName) {
if (Object.class.equals(clazz)) {
return null;
}
for (Method method : clazz.getDeclaredMethods()) {
if (method.getName().equals(methodName)) {
return method;
}
}
return getMethodRegular(clazz.getSuperclass(), methodName);
}
/**
* Checks whether the SQL was obtained correctly (fewer placeholders left than in the BoundSql)
*/
private boolean isCorrectGetSql(BoundSql boundSql, String originSql) {
return countQuestionMark(boundSql.getSql()) > countQuestionMark(originSql);
}
/**
* Counts '?' placeholders
*/
private int countQuestionMark(String sql) {
int result = 0;
for (char c : sql.toCharArray())
if (c == '?')
result++;
return result;
}
/**
* Empty MyBatis config template used to build the Configuration (placed at the end because the blog editor's syntax highlighting breaks otherwise)
*/
private static final String EMPTY_XML = "<?xml version=\"1.0\" encoding=\"UTF-8\" ?>\r\n"
+ "<!DOCTYPE configuration\r\n"
+ " PUBLIC \"-//mybatis.org//DTD Config 3.0//EN\"\r\n"
+ " \"http://mybatis.org/dtd/mybatis-3-config.dtd\">\r\n"
+ "<configuration>\r\n"
+ "</configuration>";
}
It's called an Excutor because in the project it is actually a subclass; the version here only provides parseSql.
Other
If you know a better way, please do tell me (๑•̀ㅂ•́)و✧