JSqlParser + MyBatis 拦截器:改写查询字段并返回替换后的 SQL
1. JSqlParser 可用于解析 SQL 语句。MyBatis 拦截器的实现网上已有大量讲解,此处不再赘述,直接给出结果。
@Intercepts( { @Signature(type = StatementHandler.class, method = "prepare", args = { Connection.class, Integer.class }) } ) @Order(1) @Component public class DataPermissionInterceptor implements Interceptor { private final static Logger logger = LoggerFactory.getLogger(DataPermissionInterceptor.class); @Override public Object intercept(Invocation invocation) throws Throwable { RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget(); StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate"); //从当前线程获取需要进行数据权限控制的业务 //TODO 此处后续替换为自己的权限控制,也可以去掉,所有人员均涉及权限控制,不存在白名单用户 DataPermission dataPermission = DPHelper.getLocalDataPermissions(); //判断有没有进行数据权限控制,是不是最高权限的管理员(这里指的是数据权限的白名单用户) if (dataPermission != null && dataPermission.getAdmin() == false && !dataPermission.getTables().isEmpty()) { BoundSql boundSql = delegate.getBoundSql(); //获取真实sql String sql = boundSql.getSql(); if(CCJSqlParserUtil.parse(sql) instanceof Select){ //获得方法类型 Select select = (Select) CCJSqlParserUtil.parse(sql); //此处调用accept方法,new SelectVisitorImpl()用于sql解析 select.getSelectBody().accept(new SelectVisitorImpl()); //修改sql ReflectUtil.setFieldValue(boundSql, "sql", select.toString()); } } return invocation.proceed(); } @Override public Object plugin(Object target) { return Plugin.wrap(target, this); } @Override public void setProperties(Properties properties) { }
2. SelectVisitorImpl 类实现了 SelectVisitor 接口,并实现了其中的各个 visit() 方法。
SelectVisitor接口内容如下:
package net.sf.jsqlparser.statement.select; public interface SelectVisitor { //PlainSelect 为普通select查询 void visit(PlainSelect plainSelect); void visit(SetOperationList setOpList); void visit(WithItem withItem); }
SelectVisitorImpl 类的内容如下(此处做了部分优化):
package com.example.mybatissqls.visitor; import com.example.mybatissqls.sys.ColumnInfo; import com.example.mybatissqls.sys.TableInfo; import com.example.mybatissqls.utils.UserUtils; import net.sf.jsqlparser.JSQLParserException; import net.sf.jsqlparser.expression.Expression; import net.sf.jsqlparser.expression.Parenthesis; import net.sf.jsqlparser.expression.StringValue; import net.sf.jsqlparser.expression.operators.conditional.AndExpression; import net.sf.jsqlparser.parser.CCJSqlParserUtil; import net.sf.jsqlparser.schema.Column; import net.sf.jsqlparser.schema.Table; import net.sf.jsqlparser.statement.select.AllColumns; import net.sf.jsqlparser.statement.select.AllTableColumns; import net.sf.jsqlparser.statement.select.FromItem; import net.sf.jsqlparser.statement.select.Join; import net.sf.jsqlparser.statement.select.OrderByElement; import net.sf.jsqlparser.statement.select.PlainSelect; import net.sf.jsqlparser.statement.select.Select; import net.sf.jsqlparser.statement.select.SelectBody; import net.sf.jsqlparser.statement.select.SelectExpressionItem; import net.sf.jsqlparser.statement.select.SelectItem; import net.sf.jsqlparser.statement.select.SelectItemVisitor; import net.sf.jsqlparser.statement.select.SelectVisitor; import net.sf.jsqlparser.statement.select.SetOperationList; import net.sf.jsqlparser.statement.select.WithItem; import net.sf.jsqlparser.util.TablesNamesFinder; import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.function.Function; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; public class SelectVisitorImpl implements SelectVisitor { private UserUtils userUtils = UserUtils.getBean(UserUtils.class); @Override public void visit(PlainSelect plainSelect) { // 访问 select 获取sql查询字段,将字段替换为有权限的字段 if (plainSelect.getSelectItems() != null) {
// 业务代码,获取表和字段的权限 List<TableInfo> tableInfo = userUtils.getTableInfo(); Map<String, TableInfo> tableMap = userUtils.getTableMap(tableInfo);
SelectItemVisitorImpl selectItemVisitor = new SelectItemVisitorImpl(); //替换之后的列字段 List<String> columnList = new ArrayList<>(); // 解释在下方 List<AllTableColumns> allTableColumns = new ArrayList<>(); List<SelectExpressionItem> selectExpressionItems = new ArrayList<>(); List<String> tableList = getTableList(plainSelect.toString()); if (tableList != null) { for (SelectItem item : plainSelect.getSelectItems()) { item.accept(selectItemVisitor); //设置可查询的列 if (selectItemVisitor.getEnhancedCondition() != null) { String aliasColumn = selectItemVisitor.getEnhancedCondition().toString(); //正则校验以'开头,以'结尾,存在替换 String regex = "'([\\s\\S]*?)'"; Matcher matcher = Pattern.compile(regex).matcher(aliasColumn); // 满足正则校验,剔除字符串的第一位和最后一位 if (matcher.find()){ aliasColumn = aliasColumn.substring(1,aliasColumn.length()-1); } //将字符串按照.进行分隔,前面为别名后面的字段名 String[] tableArr = aliasColumn.split("\\."); String table = tableList.size() ==1 ? tableList.get(0):tableArr[0]; if (aliasColumn.contains("*")) { TableInfo tableInfo1 = tableMap.get(table); if (tableInfo1 != null) { List<ColumnInfo> columnInfos = tableInfo1.getColumnInfos(); //设置权限且有权限的 List<ColumnInfo> collect = columnInfos.stream() .filter(e -> e.getIsSetPermission() && e .getIsHavePermission()).collect( Collectors.toList()); //不做权限设置的,通用的 List<ColumnInfo> collect1 = columnInfos.stream() .filter(e -> !e.getIsSetPermission()).collect( Collectors.toList()); for (ColumnInfo info : collect) { selectExpressionItems.add(new SelectExpressionItem(new Column(new Table(table),info.getColumnName()))); columnList.add(table+"."+info.getColumnName()); } for (ColumnInfo info : collect1) { selectExpressionItems.add(new SelectExpressionItem(new Column(new Table(table),info.getColumnName()))); columnList.add(table+"."+info.getColumnName()); } } else { allTableColumns.add(new AllTableColumns(new Table(table))); } } else { TableInfo tableInfo1 = tableMap.get(table); boolean flag = tableArr.length == 1; if (tableInfo1 != null) { List<ColumnInfo> columnInfos = 
tableInfo1.getColumnInfos(); //对比字段 Map<String, ColumnInfo> collect = columnInfos.stream() .collect(Collectors.toMap(ColumnInfo::getColumnName, Function.identity())); if (flag){ ColumnInfo columnInfo = collect.get(tableArr[0]); if ((columnInfo.getIsSetPermission() && columnInfo.getIsHavePermission()) || !columnInfo.getIsSetPermission()) { selectExpressionItems.add(new SelectExpressionItem(new Column(new Table(table),tableArr[0]))); } } else { ColumnInfo columnInfo = collect.get(tableArr[1]); if ((columnInfo.getIsSetPermission() && columnInfo.getIsHavePermission()) || !columnInfo.getIsSetPermission()) { selectExpressionItems.add(new SelectExpressionItem(new Column(new Table(table),tableArr[1]))); } } } } } } } // 替换字段 SelectItem[] list = new SelectItem[allTableColumns.size()+selectExpressionItems.size()]; plainSelect.setSelectItems(null); for (int i = 0 ; i < allTableColumns.size();i++) { list[i] = allTableColumns.get(i); plainSelect.addSelectItems(list[i]); } for(int j = 0 ; j < selectExpressionItems.size();j++){ list[j+allTableColumns.size()] = selectExpressionItems.get(j); plainSelect.addSelectItems(list[j+allTableColumns.size()]); } } // 访问from FromItem fromItem = plainSelect.getFromItem(); FromItemVisitorImpl fromItemVisitorImpl = new FromItemVisitorImpl(); fromItem.accept(fromItemVisitorImpl); // 访问join if (plainSelect.getJoins() != null) { for (Join join : plainSelect.getJoins()) { join.getRightItem().accept(fromItemVisitorImpl); } } // 访问where if (plainSelect.getWhere() != null) { plainSelect.getWhere().accept(new ExpressionVisitorImpl()); } //过滤增强的条件 if (fromItemVisitorImpl.getEnhancedCondition() != null) { if (plainSelect.getWhere() != null) { Expression expr = new Parenthesis(plainSelect.getWhere()); Expression enhancedCondition = new Parenthesis(fromItemVisitorImpl.getEnhancedCondition()); AndExpression and = new AndExpression(enhancedCondition, expr); plainSelect.setWhere(and); } else { plainSelect.setWhere(fromItemVisitorImpl.getEnhancedCondition()); } 
} // 访问 order by if (plainSelect.getOrderByElements() != null) { for (OrderByElement orderByElement : plainSelect .getOrderByElements()) { orderByElement.getExpression().accept( new ExpressionVisitorImpl()); } } // 访问group by having if (plainSelect.getHaving() != null) { plainSelect.getHaving().accept(new ExpressionVisitorImpl()); } } /** * 获取sql中的表名 * @param sql * @return */ private List<String> getTableList(String sql){ //获取table表名,替换字段 TablesNamesFinder tablesNamesFinder = new TablesNamesFinder(); try { return tablesNamesFinder .getTableList(CCJSqlParserUtil.parse(sql)); } catch (JSQLParserException e) { e.printStackTrace(); return null; } } @Override public void visit(SetOperationList setOpList) { for (SelectBody plainSelect : setOpList.getSelects()) { plainSelect.accept(new SelectVisitorImpl()); } } @Override public void visit(WithItem withItem) { // withItem.getSelectBody().accept(new SelectVisitorImpl()); } }
3. 从获取查询字段的代码中可以看到:
plainSelect.getSelectItems() 返回所有查询字段,集合元素的类型为 SelectItem。相关代码如下:
// Excerpt from SelectVisitorImpl.visit(PlainSelect); the braces are intentionally left open, see the full listing.
// One SelectItemVisitorImpl instance is reused for every select item of the query.
SelectItemVisitorImpl selectItemVisitor = new SelectItemVisitorImpl(); List<AllTableColumns> allTableColumns = new ArrayList<>(); List<SelectExpressionItem> selectExpressionItems = new ArrayList<>(); List<String> tableList = getTableList(plainSelect.toString()); if (tableList != null) { for (SelectItem item : plainSelect.getSelectItems()) { item.accept(selectItemVisitor);
SelectItem 接口只有一个 accept 方法,其参数类型为 SelectItemVisitor:
package net.sf.jsqlparser.statement.select;

/**
 * One entry of a SELECT list. Its only operation hands the item to a {@link SelectItemVisitor}.
 */
public interface SelectItem {

    /**
     * Calls back the {@code visit} overload of {@code visitor} expected to match this item's type.
     *
     * @param visitor the visitor to call back
     */
    void accept(SelectItemVisitor visitor);
}
因此调用 item.accept(selectItemVisitor) 时,需要传入自定义的 SelectItemVisitorImpl 实现类。该类新增了 enhancedCondition 属性,用于记录当前访问到的查询项,供后续字段替换使用。
public class SelectItemVisitorImpl implements SelectItemVisitor { @Override public void visit(AllColumns allColumns) { //查询项为* 单表或者多表,不带别名的*走这里 enhancedCondition = new StringValue(allColumns.toString()); } @Override public void visit(AllTableColumns allTableColumns) { //带别名的*走这里 enhancedCondition = new StringValue(allTableColumns.toString()); } @Override public void visit(SelectExpressionItem selectExpressionItem) { //查询具体字段 enhancedCondition = selectExpressionItem.getExpression(); } // 声明增强条件 private Expression enhancedCondition; public Expression getEnhancedCondition() { return enhancedCondition; } }
SelectItemVisitorImpl 的具体内容如上所示。
替换 SQL 中的字段时,创建具体的字段使用:
// Excerpt: builds a "table.column" select item for a column the user is allowed to see.
selectExpressionItems.add(new SelectExpressionItem(new Column(new Table(table),info.getColumnName())));
创建 "表名.*" 形式的查询项时使用:
// Excerpt: keeps "table.*" when the table has no column-permission configuration.
allTableColumns.add(new AllTableColumns(new Table(table)));
最后创建 SelectItem 数组,先将原有 SQL 的查询列置为 null,再把转换后的查询项依次添加进去:
// Excerpt: clears the original select list, then appends all "table.*" items followed by all column items.
// The original item order is therefore not preserved, and the intermediate array only mirrors what is added.
SelectItem[] list = new SelectItem[allTableColumns.size()+selectExpressionItems.size()]; plainSelect.setSelectItems(null); for (int i = 0 ; i < allTableColumns.size();i++) { list[i] = allTableColumns.get(i); plainSelect.addSelectItems(list[i]); } for(int j = 0 ; j < selectExpressionItems.size();j++){ list[j+allTableColumns.size()] = selectExpressionItems.get(j); plainSelect.addSelectItems(list[j+allTableColumns.size()]); }
至此,SQL 查询列的替换问题得到了解决。但此处尚未处理查询列只写列名、没有带表名(或别名)前缀的情况:多表查询时,代码会把列名本身当作表名去查找权限配置,从而无法匹配到对应的表。该问题留待后续优化。