2021年12月14日复盘(Oracle NOT IN 失效、IN 列表最多 1000 个元素)

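1、遇到 Oracle NOT IN 查不出任何数据(看似失效)的问题,原因是 NOT IN 里的子查询结果中含有 NULL:只要列表里有 NULL,`x NOT IN (...)` 的结果就只可能是 FALSE 或 UNKNOWN,永远不会是 TRUE,所以一行都查不出来。需要在子查询里用 `IS NOT NULL` 把空值过滤掉(或者改用 NOT EXISTS)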

2、Oracle 的 IN 列表最多只能有 1000 个元素(超出会报 ORA-01795),自己参照 MyBatis-Plus 的租户拦截器(TenantLineInnerInterceptor)改写了一个拦截器,把超长的 IN 拆分成多个 IN 再用 OR 连接

  1)、重点需要理解下表达式树,这个刚好旁边大佬学历高,跟我普及了下用二叉树做数学公式计算的原理

  2)、需要写递归:对 OR、AND 这类有左右两个子节点的表达式分别递归处理左右节点,遇到括号表达式先拆包再递归,最终处理 IN 表达式,其他类型的直接返回

  3)、MyBatis-Plus 版本是 3.4.2

 

import com.baomidou.mybatisplus.core.parser.SqlParserHelper;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringPool;
import com.baomidou.mybatisplus.extension.parser.JsqlParserSupport;
import com.baomidou.mybatisplus.extension.plugins.inner.InnerInterceptor;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.NoArgsConstructor;
import lombok.ToString;
import net.sf.jsqlparser.expression.BinaryExpression;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.NotExpression;
import net.sf.jsqlparser.expression.Parenthesis;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.expression.operators.relational.ItemsList;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.*;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.sql.Connection;
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * MyBatis-Plus inner interceptor that splits oversized {@code IN (...)} value lists so that no single
 * list exceeds Oracle's limit of 1000 expressions (ORA-01795).
 * <p>
 * {@code col IN (v1, ..., vN)} with {@code N > limit} is rewritten to
 * {@code (col IN (chunk1) OR col IN (chunk2) OR ...)}, and {@code col NOT IN (...)} is rewritten to
 * {@code (col NOT IN (chunk1) AND col NOT IN (chunk2) AND ...)}. SELECT statements are rewritten in
 * {@link #beforeQuery}, UPDATE/DELETE statements in {@link #beforePrepare}; INSERT is left untouched.
 * The statement traversal is modelled on MyBatis-Plus' tenant-line interceptor (MyBatis-Plus 3.4.2).
 *
 * @author linjiabin
 * @since 1.0.0
 */
@Data
@NoArgsConstructor
@ToString(callSuper = true)
@EqualsAndHashCode(callSuper = true)
public class OracleLimit1000InnerInterceptor extends JsqlParserSupport implements InnerInterceptor {

    /** Maximum number of expressions Oracle accepts in a single IN list. */
    public static final int ORACLE_IN_LIST_LIMIT = 1000;

    /**
     * Cheap pre-filter for statements worth parsing: the word {@code IN} followed (after optional
     * whitespace, including line breaks) by {@code (}. Matching is ASCII case-insensitive, so it does not
     * depend on the JVM default locale.
     */
    private static final Pattern IN_LIST_PATTERN = Pattern.compile("\\bIN\\s*\\(", Pattern.CASE_INSENSITIVE);

    /** Maximum number of expressions per generated IN list; a non-positive value disables splitting. */
    private int limit = ORACLE_IN_LIST_LIMIT;

    /**
     * Rewrites oversized IN lists of a SELECT statement before it is executed.
     * <p>
     * Skipped when {@link InterceptorIgnoreHelper#willIgnoreTenantLine} (this interceptor reuses the
     * tenant-line ignore switch) or {@link SqlParserHelper#getSqlParserInfo} returns {@code true}.
     * If parsing or rewriting fails, the error is logged and the original SQL is kept.
     */
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
        if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
            return;
        }
        if (SqlParserHelper.getSqlParserInfo(ms)) {
            return;
        }
        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        try {
            String sql = mpBs.sql();
            if (needRebuild(sql)) {
                mpBs.sql(parserSingle(sql, null));
            }
        } catch (Exception e) {
            // Best effort: on failure the BoundSql is not modified.
            logger.error(e.getMessage(), e);
        }
    }

    /**
     * Returns whether {@code sql} may contain an IN list and is therefore worth parsing.
     *
     * @param sql the SQL text, may be {@code null}
     * @return {@code true} if the SQL is non-blank and contains {@code IN (} in any case or spacing
     */
    private boolean needRebuild(String sql) {
        return StringUtils.isNotBlank(sql) && IN_LIST_PATTERN.matcher(sql).find();
    }

    /**
     * Rewrites oversized IN lists of UPDATE and DELETE statements before they are prepared.
     * <p>
     * Uses the same skip checks as {@link #beforeQuery}. A trailing semicolon present in the original
     * SQL is re-appended when the re-serialized SQL has lost it. If parsing or rewriting fails, the error
     * is logged and the original SQL is kept.
     */
    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = mpSh.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        if (sct != SqlCommandType.UPDATE && sct != SqlCommandType.DELETE) {
            return;
        }
        if (InterceptorIgnoreHelper.willIgnoreTenantLine(ms.getId())) {
            return;
        }
        if (SqlParserHelper.getSqlParserInfo(ms)) {
            return;
        }
        PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
        try {
            String sql = mpBs.sql();
            if (needRebuild(sql)) {
                String parserMulti = parserMulti(sql, null);
                // Re-serialization drops the trailing semicolon; restore it when the original had one.
                if (sql.endsWith(StringPool.SEMICOLON) && !parserMulti.endsWith(StringPool.SEMICOLON)) {
                    parserMulti += StringPool.SEMICOLON;
                }
                mpBs.sql(parserMulti);
            }
        } catch (Exception e) {
            logger.error(e.getMessage(), e);
        }
    }

    /** Processes the main select body and every WITH item of a SELECT statement. */
    @Override
    protected void processSelect(Select select, int index, String sql, Object obj) {
        processSelectBody(select.getSelectBody());
        List<WithItem> withItemsList = select.getWithItemsList();
        if (!CollectionUtils.isEmpty(withItemsList)) {
            withItemsList.forEach(this::processSelectBody);
        }
    }

    /**
     * Dispatches a select body to the matching handler.
     * <p>
     * Plain selects are rewritten, WITH items are unwrapped, and set operations (UNION etc.) recurse
     * into each member. Any other {@link SelectBody} implementation is left untouched instead of
     * failing with a {@link ClassCastException}.
     *
     * @param selectBody the body to process, may be {@code null}
     */
    protected void processSelectBody(SelectBody selectBody) {
        if (selectBody == null) {
            return;
        }
        if (selectBody instanceof PlainSelect) {
            processPlainSelect((PlainSelect) selectBody);
        } else if (selectBody instanceof WithItem) {
            processSelectBody(((WithItem) selectBody).getSelectBody());
        } else if (selectBody instanceof SetOperationList) {
            List<SelectBody> selects = ((SetOperationList) selectBody).getSelects();
            if (selects != null) {
                selects.forEach(this::processSelectBody);
            }
        }
    }

    /** INSERT statements are intentionally not rewritten. */
    @Override
    protected void processInsert(Insert insert, int index, String sql, Object obj) {
        // no do anything at insert
    }

    /** Rewrites IN lists in the WHERE clause of an UPDATE, including those inside sub-selects. */
    @Override
    protected void processUpdate(Update update, int index, String sql, Object obj) {
        processWhereSubSelect(update.getWhere());
        update.setWhere(this.andExpression(update.getWhere()));
    }

    /** Rewrites IN lists in the WHERE clause of a DELETE, including those inside sub-selects. */
    @Override
    protected void processDelete(Delete delete, int index, String sql, Object obj) {
        processWhereSubSelect(delete.getWhere());
        delete.setWhere(this.andExpression(delete.getWhere()));
    }

    /**
     * Recursively walks a condition tree and splits every oversized IN list found in it.
     * <p>
     * Binary expressions (AND, OR, comparisons, ...) recurse into both operands, and parenthesis and NOT
     * wrappers recurse into their content; these nodes are updated in place. IN expressions are handed
     * to {@link #builderExpression(InExpression)}. Every other node is returned unchanged.
     *
     * @param where the condition to process, may be {@code null}
     * @return the (possibly replaced) condition; {@code null} if {@code where} was {@code null}
     */
    protected Expression andExpression(Expression where) {
        if (where instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) where;
            binaryExpression.setLeftExpression(andExpression(binaryExpression.getLeftExpression()));
            binaryExpression.setRightExpression(andExpression(binaryExpression.getRightExpression()));
            return binaryExpression;
        }
        if (where instanceof Parenthesis) {
            Parenthesis parenthesis = (Parenthesis) where;
            parenthesis.setExpression(andExpression(parenthesis.getExpression()));
            return parenthesis;
        }
        if (where instanceof NotExpression) {
            NotExpression notExpression = (NotExpression) where;
            notExpression.setExpression(andExpression(notExpression.getExpression()));
            return notExpression;
        }
        if (where instanceof InExpression) {
            return builderExpression((InExpression) where);
        }
        return where;
    }

    /**
     * Rewrites a plain SELECT: its FROM item, WHERE clause (including sub-selects) and JOINs.
     * <p>
     * Unlike tenant handling, IN lists must be split regardless of whether the FROM item is a table,
     * so the WHERE clause is always processed.
     */
    protected void processPlainSelect(PlainSelect plainSelect) {
        processFromItem(plainSelect.getFromItem());
        Expression where = plainSelect.getWhere();
        processWhereSubSelect(where);
        plainSelect.setWhere(builderExpression(where));
        List<Join> joins = plainSelect.getJoins();
        if (joins != null) {
            for (Join join : joins) {
                processJoin(join);
                processFromItem(join.getRightItem());
            }
        }
    }

    /**
     * Processes sub-queries nested in a condition.
     * <p>
     * Supported forms:
     * 1. in
     * 2. =
     * 3. >
     * 4. <
     * 5. >=
     * 6. <=
     * 7. <>
     * 8. EXISTS
     * 9. NOT EXISTS
     * <p>
     * Preconditions:
     * 1. the sub-query must be enclosed in parentheses
     * 2. the sub-query is usually on the right-hand side of the comparison operator
     *
     * @param where the condition to inspect, may be {@code null}
     */
    protected void processWhereSubSelect(Expression where) {
        if (where == null) {
            return;
        }
        if (where instanceof FromItem) {
            processFromItem((FromItem) where);
            return;
        }
        if (!where.toString().contains("SELECT")) {
            // No sub-query anywhere below this node.
            return;
        }
        if (where instanceof BinaryExpression) {
            // comparison operators, AND, OR, ...
            BinaryExpression expression = (BinaryExpression) where;
            processWhereSubSelect(expression.getLeftExpression());
            processWhereSubSelect(expression.getRightExpression());
        } else if (where instanceof InExpression) {
            ItemsList itemsList = ((InExpression) where).getRightItemsList();
            if (itemsList instanceof SubSelect) {
                processSelectBody(((SubSelect) itemsList).getSelectBody());
            }
        } else if (where instanceof ExistsExpression) {
            processWhereSubSelect(((ExistsExpression) where).getRightExpression());
        } else if (where instanceof NotExpression) {
            // e.g. NOT EXISTS
            processWhereSubSelect(((NotExpression) where).getExpression());
        } else if (where instanceof Parenthesis) {
            processWhereSubSelect(((Parenthesis) where).getExpression());
        }
    }

    /**
     * Recurses into FROM items that can contain their own SELECT: sub-joins, sub-selects and lateral
     * sub-selects. Tables and other items are ignored.
     */
    protected void processFromItem(FromItem fromItem) {
        if (fromItem instanceof SubJoin) {
            SubJoin subJoin = (SubJoin) fromItem;
            if (subJoin.getJoinList() != null) {
                for (Join join : subJoin.getJoinList()) {
                    processJoin(join);
                    processFromItem(join.getRightItem());
                }
            }
            if (subJoin.getLeft() != null) {
                processFromItem(subJoin.getLeft());
            }
        } else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelectBody(subSelect.getSelectBody());
            }
        } else if (fromItem instanceof ValuesList) {
            logger.debug("Perform a subquery, if you do not give us feedback");
        } else if (fromItem instanceof LateralSubSelect) {
            SubSelect subSelect = ((LateralSubSelect) fromItem).getSubSelect();
            if (subSelect != null && subSelect.getSelectBody() != null) {
                processSelectBody(subSelect.getSelectBody());
            }
        }
    }

    /**
     * Splits oversized IN lists in a JOIN's ON condition, including sub-selects inside it, whatever
     * kind of item is being joined.
     */
    protected void processJoin(Join join) {
        processWhereSubSelect(join.getOnExpression());
        join.setOnExpression(builderExpression(join.getOnExpression()));
    }

    /**
     * Null-safe entry point for {@link #andExpression(Expression)}.
     *
     * @param currentExpression the condition to process, may be {@code null}
     * @return the processed condition, or {@code null} if the input was {@code null}
     */
    protected Expression builderExpression(Expression currentExpression) {
        if (currentExpression == null) {
            return null;
        }
        return andExpression(currentExpression);
    }

    /**
     * Splits a single IN / NOT IN expression whose value list exceeds {@link #limit}.
     * <p>
     * {@code col IN (...)} becomes {@code (col IN (c1) OR col IN (c2) ...)}.
     * {@code col NOT IN (...)} becomes {@code (col NOT IN (c1) AND col NOT IN (c2) ...)}.
     * The result is always wrapped in a {@link Parenthesis} so that surrounding AND/OR operators keep
     * their precedence. Every chunk is non-empty and holds at most {@code limit} values.
     * Sub-select and multi-column lists, and expressions without a single left-hand expression, are
     * returned unchanged.
     *
     * @param inExpression the IN expression to inspect
     * @return {@code inExpression} itself if no split is needed, otherwise the replacement expression
     */
    protected Expression builderExpression(InExpression inExpression) {
        Expression leftExpression = inExpression.getLeftExpression();
        ItemsList rightItemsList = inExpression.getRightItemsList();
        if (limit <= 0 || leftExpression == null || !(rightItemsList instanceof ExpressionList)) {
            return inExpression;
        }
        List<Expression> expressions = ((ExpressionList) rightItemsList).getExpressions();
        if (expressions == null || expressions.size() <= limit) {
            return inExpression;
        }
        boolean not = inExpression.isNot();
        int size = expressions.size();
        Expression combined = null;
        for (int from = 0; from < size; from += limit) {
            int to = Math.min(from + limit, size);
            InExpression chunk = new InExpression(leftExpression,
                    new ExpressionList(new ArrayList<>(expressions.subList(from, to))));
            chunk.setNot(not);
            if (combined == null) {
                combined = chunk;
            } else if (not) {
                // x NOT IN (A ∪ B)  <=>  x NOT IN A AND x NOT IN B
                combined = new AndExpression(combined, chunk);
            } else {
                // x IN (A ∪ B)  <=>  x IN A OR x IN B
                combined = new OrExpression(combined, chunk);
            }
        }
        return new Parenthesis(combined);
    }

}

 

由于遇到了 UPDATE 语句重组后末尾分号被自动去掉的问题,追加了一个判断:原 SQL 以分号结尾而重组后的 SQL 没有分号时,把分号补回去

 

并且将重组 SQL 的逻辑用 try-catch 包起来,避免重组 SQL 出错时导致整个语句无法继续执行

 

简单地过滤掉不需要解析的场景:目前仅判断 SQL 中是否包含 IN 关键词,不包含则直接跳过解析

posted @ 2021-12-14 20:04  gabin  阅读(374)  评论(0编辑  收藏  举报