基于 SQL 解析的 JPA 多租户方案

概述#

最近在对一个使用 JPA 的老项目进行多租户改造,由于年代过于久远,陈年屎山让人实在不敢轻举妄动,最后只能选择一个改造成本最小的方案,那就是通过拦截器改 SQL,动态添加租户 ID 作为查询条件。
本篇文章用于记录笔者基于该方案解决此问题的踩坑和思考过程,部分代码与实际代码有所出入。如果希望直接获取可运行的代码,可以直接在 github 仓库获取。

1.SQL 拦截器#

由于 JPA 底层是基于 Hibernate 实现的,而 Hibernate 本身提供了 StatementInspector 接口用于实现 SQL 拦截。因此我们只需要在这个阶段对 SQL 进行解析,然后为需要按租户进行隔离的资源表动态的添加租户 ID 的过滤条件即可。
这里我们选择使用 JSqlParser 作为我们的 SQL 解析器。它社区还算活跃,文档详细,最重要的是,API 比较简单易懂。下文若不特意澄清,则所有与 SQL 解析相关的类都来自于它。

1.1.简单实现#

在最开始,我们写一个简单的实现来验证一下可行性。
假设,我们需要指定拦截针对表 t_resource 的查询语句,为其添加 tenant_id = xxx 作为查询条件,那么这个 SQL 拦截器需要做到:

  1. 将 SQL 解析为 Statement 对象,然后检查其是否为查询类 SQL;
  2. 获取 SQL 的 from 语句,并判断查询的表是否为我们要拦截的表;
  3. 解析 where 语句:
    • 若原本没有任何条件,则为其生成一个 where t.tenant_id = xxx 的条件;
    • 如果原本已经有条件了,则为其在最后拼接 and t.tenant_id = xxx 的条件;

这里我们针对这个需求给出一个简单的实现:

@Slf4j
public class TenantSQLInterceptor {

    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @param table 要拦截器的租户表名
     * @param column 租户字段名
     * @param value 租户字段值
     * @return 处理后的SQL语句
     */
    public String handle(String sql, String table, String column, String value) {
        log.debug("租户拦截器拦截原始 SQL: {}", sql);
        String handledSql = doHandle(sql, table, column, value);
        log.info("租户拦截器拦截后 SQL: {}", handledSql);
        return Objects.isNull(handledSql) ? sql : handledSql;
    }

    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @return 处理后的SQL语句
     */
    @Nullable
    public String doHandle(String sql, String table, String column, String value) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(statement -> doHandle(statement, table, column, value))
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
            return statements;
        } catch (JSQLParserException e) {
            log.error("SQL 解析失败: {}", sql, e);
            throw new CloudPluginException(ResultCodingEnum.SchedulingError, "SQL 解析失败");
        }
    }

    private Statement doHandle(Statement statement, String table, String column, String value) {
        if (!(statement instanceof Select)) {
            return statement;
        }
        try {
            SelectBody selectBody = ((Select) statement).getSelectBody();
            // 目前只处理普通的 SQL 查询
            if (selectBody instanceof PlainSelect) {
                PlainSelect plainSelect = (PlainSelect) selectBody;
                FromItem fromItem = plainSelect.getFromItem();
                Expression where = plainSelect.getWhere();

                // 如果查询的表即为要拦截的租户表,则为查询条件添加租户条件
                if (fromItem instanceof Table) {
                    String queryTable = ((Table) fromItem).getName();
                    if (Objects.equals(queryTable, table)) {
                        where = appendTenantCondition(plainSelect.getWhere(), fromItem, value, column);
                        plainSelect.setWhere(where);
                    }
                }
            }
        } catch (Exception ex) {
            log.error("SQL 处理失败: {}", statement, ex);
            throw new RuntimeException("SQL 处理失败", ex);
        }
        return statement;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        // 生成一个 tenant_id = xxx 的条件
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在别名,则字段应该变“表别名.字段名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }
}

测试一下:

public static void main(String[] args) {
    String sql = "select * from t_resource r where r.order = 1";
    TenantSQLInterceptor tenantSQLInterceptor = new TenantSQLInterceptor();
    String handledSql = tenantSQLInterceptor.handle(sql, "t_resource", "tenant_id", "1");
    System.out.println(handledSql); // = SELECT * FROM resource r WHERE r.t_resource = 1 AND r.tenant_id = '1'
}

虽然还非常简陋,不过这个拦截器已经能够初步实现我们想要的功能了,不过要投入实际场景,显然还需要做出“一点点”改进。

1.2.从上下文获取租户信息#

首先,真实的使用场景中,一个 SQL 可能会同时涉及到多张需要拦截器的表,并且每张表对应的租户 ID 仍然有可能不同,因此我们最好直接将相关的配置信息提取出来,改为通过一个上下文对象进行获取:

@Slf4j
public class TenantSQLInterceptor {

    private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();

    /**
     * 设置租户信息
     *
     * @param tenantInfo 租户信息
     */
    public static void setTenantInfo(TenantInfo tenantInfo) {
        TENANT_INFO_CONTEXT.set(tenantInfo);
    }

    /**
     * 清除租户信息
     */
    public static void clearTenantInfo() {
        TENANT_INFO_CONTEXT.remove();
    }


    public String handle(String sql) {
        // 如果未设置租户信息,则直接返回原始SQL
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        if (Objects.isNull(tenantInfo)) {
            return sql;
        }
        log.debug("租户拦截器拦截原始 SQL: {}", sql);
        String handledSql = doHandle(sql);
        log.info("租户拦截器拦截后 SQL: {}", handledSql);
        return Objects.isNull(handledSql) ? sql : handledSql;
    }

    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @return 处理后的SQL语句
     */
    @Nullable
    public String doHandle(String sql) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(this::doHandle)
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
            return statements;
        } catch (JSQLParserException e) {
            log.error("SQL 解析失败: {}", sql, e);
            throw new CloudPluginException(ResultCodingEnum.SchedulingError, "SQL 解析失败");
        }
    }

    private Statement doHandle(Statement statement) {
        if (!(statement instanceof Select)) {
            return statement;
        }
        try {
            SelectBody selectBody = ((Select) statement).getSelectBody();
            if (selectBody instanceof PlainSelect) {
                PlainSelect plainSelect = (PlainSelect) selectBody;
                FromItem fromItem = plainSelect.getFromItem();
                Expression where = plainSelect.getWhere();

                // 如果查询的表即为要拦截的租户表,则为查询条件添加租户条件
                if (fromItem instanceof Table) {
                    String queryTable = ((Table) fromItem).getName();
                    TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
                    String tenantColumn = tenantInfo.tablesWithTenantColumn.get(queryTable);
                    if (Objects.nonNull(tenantColumn)) {
                        plainSelect.setWhere(appendTenantCondition(where, fromItem, tenantInfo.tenantId, tenantColumn));
                    }
                }
            }
        } catch (Exception ex) {
            log.error("SQL 处理失败: {}", statement, ex);
            throw new RuntimeException("SQL 处理失败", ex);
        }
        return statement;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在别名,则字段应该变“表别名.字段名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }

    /**
     * 租户信息
     */
    @RequiredArgsConstructor
    public static class TenantInfo {
        /**
         * 租户ID
         */
        private final String tenantId;
        /**
         * 要添加租户条件的表名称与对应的租户字段
         */
        private final Map<String, String> tablesWithTenantColumn;
    }
}

1.3.复杂 SQL 的解析#

在实际场景中,尤其是涉及到手写 SQL 的场景中,SQL 往往比较复杂,比如:

  • 查询可能基于一张虚拟表,比如: select * from (selecrt from t1 where t1.id = xx) t2 这种情况。
  • 可能会存在关联查,比如: select * from t1 left join t2 on t2.id = t1.tid 这种情况。
  • 可能会涉及到子查询,比如:select * from t where t.id in (select t2.tid from t2 where t2.id = xxx) 这种情况。

除上述这几种情况外,我们还需要考虑各种组合的场景,比如 union 类型的联合查询,函数与子查询的嵌套,基于虚拟表的联查……等等。

1.3.1.改进方案#

虽然情况有很多种,不过值得高兴的是,我们还是有办法为其归纳出一个处理流程。简单的来说,就是检查所有可能存在嵌套查询的语句,进行递归解析:

  1. 第一步,先解析语句本身,如果是 union 这种联合查询,则将其拆分为多条单体 SQL 进行递归解析;
  2. 第二步,对于单条 SQL,解析其 select 的字段,如果存在函数或者子查询,则将每个字段其作为一个单体 SQL 进行递归解析;
  3. 第三步,解析其 from 语句,如果存在函数或者基于子查询的临时表,则将子查询作为一个单体 SQL 进行递归解析;
  4. 第四步,解析 join 语句:
    1. 如果 join 的表本身是基于子查询的临时表,则将子查询作为一个单体 SQL 进行递归解析;
    2. 如果 on 条件中存在函数或者子查询,则将其作为单体 SQL 进行递归解析;
  5. 第五步,解析 where 条件,如果存在函数或者基于子查询的条件字段,则将其作为一个单体 SQL 进行递归解析。

基于上述分析,我们需要对现有的代码做出一点调整:

  • 在 doHandle 方法中,我们需要判断 fromItem 的类型,如果是子查询,则需要进行递归处理。
  • 在 doHandle 方法后,我们需要新增一部分对 join 语句的处理,由于 join 语句同样由 from 和 where 两部分组成,因此此处的逻辑应当与正常的 select 差不多。
  • 在 appendTenantCondition 方法之前,我们需要增加对特殊条件的处理,对应每个条件,我们都需要检查是否存在可能的子查询,如果存则需要进行递归处理。

1.3.2.改进后的代码#

根据改进方案,我们再次调整代码:

@Slf4j
public class TenantSQLInterceptor {
    
    private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();

    /**
     * 设置租户信息
     *
     * @param tenantInfo 租户信息
     */
    public static void setTenantInfo(TenantInfo tenantInfo) {
        TENANT_INFO_CONTEXT.set(tenantInfo);
    }

    /**
     * 清除租户信息
     */
    public static void clearTenantInfo() {
        TENANT_INFO_CONTEXT.remove();
    }
    
    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @return 处理后的SQL语句
     */
    @Nullable
    public String handle(String sql) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(this::doHandle)
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
        } catch (JSQLParserException e) {
            log.error("SQL 解析失败: {}", sql, e);
            throw new RuntimeException("SQL 解析失败", e);
        }
        return statements;
    }

    private Statement doHandle(Statement statement) {
        try {
            if (statement instanceof Select) {
                processSelect(((Select) statement).getSelectBody());
            } else if (statement instanceof Update) {
                processUpdate((Update) statement);
            } else if (statement instanceof Delete) {
                processDelete((Delete) statement);
            } else if (statement instanceof Insert) {
                processInsert((Insert) statement);
            }
        } catch (Exception ex) {
            log.error("SQL 处理失败: {}", statement, ex);
            throw new RuntimeException("SQL 处理失败", ex);
        }
        return statement;
    }

    private void processSelect(SelectBody selectBody) {
        // 普通查询
        if (selectBody instanceof PlainSelect) {
            processSelect((PlainSelect) selectBody);
        }
        // 嵌套查询,比如 select xx from (select yy from t)
        else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            if (withItem.getSelectBody() != null) {
                processSelect(withItem.getSelectBody());
            }
        }
        // 联合查询,比如 union
        else if (selectBody instanceof SetOperationList) {
            SetOperationList operationList = (SetOperationList) selectBody;
            if (CollUtil.isNotEmpty(operationList.getSelects())) {
                operationList.getSelects().forEach(this::processSelect);
            }
        }
        // 值查询,比如 select 1, 2, 3
        else if (selectBody instanceof ValuesStatement) {
            List<Expression> expressions = ((ValuesStatement) selectBody).getExpressions();
            if (CollUtil.isNotEmpty(expressions)) {
                expressions.forEach(exp -> processCondition(exp, null));
            }
        } else {
            log.error("无法解析的 select 语句:{}({})", selectBody, selectBody.getClass());
            throw new RuntimeException("不支持的查询语句:" + selectBody.getClass().getName()
        }
    }

    /**
     * 处理插入语句
     *
     * @param insert 插入语句
     */
    protected void processInsert(Insert insert) {
        // do nothing
    }

    /**
     * 处理删除语句
     *
     * @param delete 删除语句
     */
    protected void processDelete(Delete delete) {
        Table table = delete.getTable();
        delete.setWhere(processCondition(delete.getWhere(), table));
        // 如果还存在关联查询
        List<Join> joins = delete.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理更新语句
     *
     * @param update 更新语句
     */
    protected void processUpdate(Update update) {
        Table table = update.getTable();
        update.setWhere(processCondition(update.getWhere(), table));
        // 如果还存在关联查询
        List<Join> joins = update.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理查询语句
     *
     * @param plainSelect 查询语句
     */
    protected void processSelect(PlainSelect plainSelect) {
        FromItem fromItem = plainSelect.getFromItem();
        // 如果是普通的表名
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), fromTable));
        }
        // 如果是子查询,比如 select * from (select xxx from yyy)
        else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelect(subSelect.getSelectBody());
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), subSelect));
        }
        // 如果是带有特殊函数的子查询,比如 lateral (select sum(*) from yyy)
        else if (fromItem instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) fromItem;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), specialSubSelect));
        }
        // 未知类型的查询,直接报错
        else {
            log.error("无法解析的 from 语句:{}({})", fromItem, fromItem.getClass());
            throw new RuntimeException("不支持的查询语句:" + fromItem.getClass().getName()
        }

        // 如果还存在关联查询
        List<Join> joins = plainSelect.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理关联查询
     *
     * @param join 关联查询
     */
    protected void processJoin(Join join) {
        FromItem joinTable = join.getRightItem();
        if (joinTable instanceof Table) {
            Table table = (Table) joinTable;
            join.setOnExpression(processCondition(join.getOnExpression(), table));
        }
        else if (joinTable instanceof SubSelect) {
            processSelect(((SubSelect) joinTable).getSelectBody());
        }
        else if (joinTable instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) joinTable;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
        }
        else {
            log.error("无法解析的 join 语句:{}({})", joinTable, joinTable.getClass());
            throw new RuntimeException("不支持的查询语句:" + joinTable.getClass().getName());
        }
    }

    /**
     * <p>获取添加了租户条件的查询条件,若条件中存在子查询,则也会为子查询添加租户条件。
     *
     * @param expression 条件表达式
     * @param table 表
     * @return 添加租户条件后的条件表达式
     */
    protected Expression processCondition(@Nullable Expression expression, FromItem table) {
        // 如果已经不可拆分的表达式,则直接返回
        if (isBasicExpression(expression)) {
            return expression;
        }
        // 如果是子查询,则需要对子查询进行递归处理
        else if (expression instanceof SubSelect) {
            processSelect(((SubSelect) expression).getSelectBody());
        }
        // 如果是 in 条件,比如:xxx in (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof InExpression) {
            InExpression inExp = (InExpression) expression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelect(((SubSelect) rightItems).getSelectBody());
            }
        }
        // 如果是 not 或者 != 条件,则需要对里面的条件进行递归处理
        else if (expression instanceof NotExpression) {
            NotExpression notExpression = (NotExpression) expression;
            processCondition(notExpression.getExpression(), table);
        }
        // 如果是 (xxx != xxx),则需要对括号里面的表达式进行递归处理
        else if (expression instanceof Parenthesis) {
            Parenthesis parenthesis = (Parenthesis) expression;
            Expression content = parenthesis.getExpression();
            processCondition(content, table);
        }
        // 如果是二元表达式,比如:xx = xx,xx > xx,则需要对左右两边的表达式进行递归处理
        else if (expression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) expression;
            Expression left = binaryExpression.getLeftExpression();
            processCondition(left, table);
            Expression right = binaryExpression.getRightExpression();
            processCondition(right, table);
        }
        // 如果是函数,比如:if(xx, xx) ,则需要对函数的参数进行递归处理
        else if (expression instanceof Function) {
            Function function = (Function) expression;
            ExpressionList parameters = function.getParameters();
            if (parameters != null) {
                parameters.getExpressions().forEach(param -> processCondition(param, table));
            }
        }
        // 如果是 case when 语句,则需要对 when 和 then 两个条件进行递归处理
        else if (expression instanceof WhenClause) {
            WhenClause whenClause = (WhenClause) expression;
            processCondition(whenClause.getWhenExpression(), table);
            processCondition(whenClause.getThenExpression(), table);
        }
        // 如果是 case 语句,则需要对 switch、when、then、else 四个条件进行递归处理
        else if (expression instanceof CaseExpression) {
            CaseExpression caseExpression = (CaseExpression) expression;
            processCondition(caseExpression.getSwitchExpression(), table);
            List<WhenClause> whenClauses = caseExpression.getWhenClauses();
            if (CollUtil.isNotEmpty(whenClauses)) {
                whenClauses.forEach(whenClause -> {
                    processCondition(whenClause.getWhenExpression(), table);
                    processCondition(whenClause.getThenExpression(), table);
                });
            }
            processCondition(caseExpression.getElseExpression(), table);
        }
        // 如果是 exists 语句,比如:exists (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof ExistsExpression) {
            Expression existsExpression = ((ExistsExpression) expression).getRightExpression();
            if (existsExpression instanceof SubSelect) {
                processSelect(((SubSelect) existsExpression).getSelectBody());
            }
        }
        // 如果是 all 或者 any 语句,比如:xx > all (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof AllComparisonExpression) {
            AllComparisonExpression allComparisonExpression = (AllComparisonExpression) expression;
            processSelect(allComparisonExpression.getSubSelect().getSelectBody());
        }
        else if (expression instanceof AnyComparisonExpression) {
            AnyComparisonExpression anyComparisonExpression = (AnyComparisonExpression) expression;
            processSelect(anyComparisonExpression.getSubSelect().getSelectBody());
        }
        // 如果是 cast 语句,比如:cast(xx as xx),则需要对子查询进行递归处理
        else if (expression instanceof CastExpression) {
            CastExpression castExpression = (CastExpression) expression;
            processCondition(castExpression.getLeftExpression(), table);
        }

        // 拼接查询条件
        Expression appendCondition = handleCondition(expression, table);
        return Objects.isNull(appendCondition) ? expression : appendCondition;
    }

    /**
     * 判断是否是已经是无法再拆分的基本表达式 <br/>
     * 比如:列名、常量、函数等
     *
     * @param expression 表达式
     * @return 是否是基本表达式
     */
    protected boolean isBasicExpression(@Nullable Expression expression) {
        return expression instanceof Column
            || expression instanceof LongValue
            || expression instanceof StringValue
            || expression instanceof DoubleValue
            || expression instanceof NullValue
            || expression instanceof TimeValue
            || expression instanceof TimestampValue
            || expression instanceof DateValue;
    }

    /**
     * 返回一个查询条件,该查询条件将替换{@code table}原有的{@code where}条件
     *
     * @param expression 原有的查询条件
     * @param table 指定的表
     * @return 查询条件
     */
    @Nullable
    protected Expression handleCondition(@Nullable Expression expression, FromItem table) {
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        // 如果是一个标准表名,且改表名在租户表列表中,则为查询条件添加租户条件
        if (!(table instanceof Table)) {
            return null;
        }
        String tenantColumn = tenantInfo.tablesWithTenantColumn.get(((Table) table).getName());
        if (Objects.nonNull(tenantColumn)) {
            return appendTenantCondition(expression, table, tenantInfo.tenantId, tenantColumn);
        }
        return null;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在别名,则字段应该变“表别名.字段名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }

    /**
     * 租户信息
     */
    @RequiredArgsConstructor
    public static class TenantInfo {
        /**
         * 租户ID
         */
        private final String tenantId;
        /**
         * 要添加租户条件的表名称与对应的租户字段
         */
        private final Map<String, String> tablesWithTenantColumn;
    }
}

现在,针对预期的复杂场景,我们再来测试一下:

public static void main(String[] args) {
    Map<String, String> tablesWithTenantColumn = Maps.newHashMap();
    tablesWithTenantColumn.put("t", "tenant_id");
    TenantInfo tenantInfo = new TenantInfo("1", tablesWithTenantColumn);
    TenantSQLInterceptor.setTenantInfo(tenantInfo);

    // 处理包含的复杂子查询的SQL
    String sql = "select * " +
        "from (select * from t where a = 1) t " +
        "left join (select * from t where b = 2) t2 on t.id = t2.id " +
        "where b in (select * from t where c = 2) and d = 3";
    TenantSQLInterceptor interceptor = new TenantSQLInterceptor();
    String handledSql = interceptor.handle(sql);
    System.out.println(handledSql);
    // 输出结果:
    // select * 
    // from (select * from t where a = 1 and tenant_id = '1') t 
    // left join (select * from t where b = 2 and tenant_id = '1') t2 on t.id = t2.id 
    // where b in (select * from t where c = 2 and tenant_id = '1') and d = 3
}

完美!

1.4.分离公共代码#

这个 SQL 拦截器已经可以完美满足我们的大部分需求了。现在功能已经实现,可以看看代码层面有什么可以优化的地方了。
我们再次分析一下上述代码,会注意到,上面的解析器其实干了两件事情:

  • 解析 SQL,并在递归获取不可再拆分的“根” SQL 后,替换其 where 条件。
  • 将 SQL 的 where 条件替换或追加上租户条件。

换而言之,第一步的逻辑似乎与“租户拦截”这个需求无关,它显然可以抽离为一个独立的组件以便后续复用。此外,我们现在实现的其实是一个行级别的租户拦截,如果我们日后需要表级别的租户拦截,最好也有办法基于它来实现。
综上考虑,这里我们将这个新组件根据其功能命名为 AbstractSqlHandler,并且为其添加一个 handleTable 抽象方法,使其具备拦截表名的能力:

/**
 * <p>SQL处理器,用于拦截SQL语句并修改其中的查询条件,
 * 该处理器支持处理嵌套查询、联合查询、关联查询等多种查询方式。
 *
 * @author huangchengxing
 * @see #handle
 * @see #handleCondition
 */
@Setter
@Slf4j
public abstract class AbstractSqlHandler {

    /**
     * 处理SQL语句
     *
     * @param sql SQL语句
     * @return 处理后的SQL语句
     */
    @Nullable
    public String handle(String sql) {
        Statements statements = parseStatements(sql);
        if (Objects.isNull(statements)) {
            return null;
        }
        List<Statement> statementList = statements.getStatements();
        if (CollUtil.isEmpty(statementList)) {
            return null;
        }
        return statements.getStatements().stream()
            .map(this::doHandle)
            .map(Statement::toString)
            .collect(Collectors.joining(";"));
    }

    @Nullable
    private Statements parseStatements(String sql) {
        Statements statements = null;
        try {
            statements = CCJSqlParserUtil.parseStatements(sql);
        } catch (JSQLParserException e) {
            log.error("SQL 解析失败: {}", sql, e);
            throw new RuntimeException("SQL 解析失败");
        }
        return statements;
    }

    private Statement doHandle(Statement statement) {
        try {
            if (statement instanceof Select) {
                processSelect(((Select) statement).getSelectBody());
            } else if (statement instanceof Update) {
                processUpdate((Update) statement);
            } else if (statement instanceof Delete) {
                processDelete((Delete) statement);
            } else if (statement instanceof Insert) {
                processInsert((Insert) statement);
            }
        } catch (Exception ex) {
            log.error("SQL 处理失败: {}", statement, ex);
            throw new RuntimeException("SQL 处理失败");
        }
        return statement;
    }

    private void processSelect(SelectBody selectBody) {
        // 普通查询
        if (selectBody instanceof PlainSelect) {
            processSelect((PlainSelect) selectBody);
        }
        // 嵌套查询,比如 select xx from (select yy from t)
        else if (selectBody instanceof WithItem) {
            WithItem withItem = (WithItem) selectBody;
            if (withItem.getSelectBody() != null) {
                processSelect(withItem.getSelectBody());
            }
        }
        // 联合查询,比如 union
        else if (selectBody instanceof SetOperationList) {
            SetOperationList operationList = (SetOperationList) selectBody;
            if (CollUtil.isNotEmpty(operationList.getSelects())) {
                operationList.getSelects().forEach(this::processSelect);
            }
        }
        // 值查询,比如 select 1, 2, 3
        else if (selectBody instanceof ValuesStatement) {
            List<Expression> expressions = ((ValuesStatement) selectBody).getExpressions();
            if (CollUtil.isNotEmpty(expressions)) {
                expressions.forEach(exp -> processCondition(exp, null));
            }
        } else {
            log.error("无法解析的 select 语句:{}({})", selectBody, selectBody.getClass());
            throw new RuntimeException("不支持的查询语句:" + selectBody.getClass().getName());
        }
    }

    /**
     * 处理插入语句
     *
     * @param insert 插入语句
     */
    protected void processInsert(Insert insert) {
        // do nothing
    }

    /**
     * 处理删除语句
     *
     * @param delete 删除语句
     */
    protected void processDelete(Delete delete) {
        Table table = delete.getTable();
        delete.setWhere(processCondition(delete.getWhere(), table));
        // 如果还存在关联查询
        List<Join> joins = delete.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理更新语句
     *
     * @param update 更新语句
     */
    protected void processUpdate(Update update) {
        Table table = update.getTable();
        update.setWhere(processCondition(update.getWhere(), table));
        // 如果还存在关联查询
        List<Join> joins = update.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理查询语句
     *
     * @param plainSelect 查询语句
     */
    protected void processSelect(PlainSelect plainSelect) {
        FromItem fromItem = plainSelect.getFromItem();
        // 如果是普通的表名
        if (fromItem instanceof Table) {
            Table fromTable = (Table) fromItem;
            plainSelect.setFromItem(handleTable(fromTable));
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), fromTable));
        }
        // 如果是子查询,比如 select * from (select xxx from yyy)
        else if (fromItem instanceof SubSelect) {
            SubSelect subSelect = (SubSelect) fromItem;
            if (subSelect.getSelectBody() != null) {
                processSelect(subSelect.getSelectBody());
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), subSelect));
        }
        // 如果是带有特殊函数的子查询,比如 lateral (select sum(*) from yyy)
        else if (fromItem instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) fromItem;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
            plainSelect.setWhere(processCondition(plainSelect.getWhere(), specialSubSelect));
        }
        // 未知类型的查询,直接报错
        else {
            log.error("无法解析的 from 语句:{}({})", fromItem, fromItem.getClass());
            throw new RuntimeException("不支持的查询语句:" + fromItem.getClass().getName());
        }

        // 如果还存在关联查询
        List<Join> joins = plainSelect.getJoins();
        if (CollUtil.isNotEmpty(joins)) {
            joins.forEach(this::processJoin);
        }
    }

    /**
     * 处理关联查询
     *
     * @param join 关联查询
     */
    protected void processJoin(Join join) {
        FromItem joinTable = join.getRightItem();
        if (joinTable instanceof Table) {
            Table table = (Table) joinTable;
            join.setRightItem(handleTable((Table) joinTable));
            join.setOnExpression(processCondition(join.getOnExpression(), table));
        }
        else if (joinTable instanceof SubSelect) {
            processSelect(((SubSelect) joinTable).getSelectBody());
        }
        else if (joinTable instanceof SpecialSubSelect) {
            SpecialSubSelect specialSubSelect = (SpecialSubSelect) joinTable;
            if (specialSubSelect.getSubSelect() != null) {
                SubSelect subSelect = specialSubSelect.getSubSelect();
                if (subSelect.getSelectBody() != null) {
                    processSelect(subSelect.getSelectBody());
                }
            }
        }
        else {
            log.error("无法解析的 join 语句:{}({})", joinTable, joinTable.getClass());
            throw new RuntimeException("不支持的查询语句:" + joinTable.getClass().getName());
        }
    }

    /**
     * <p>获取添加了租户条件的查询条件,若条件中存在子查询,则也会为子查询添加租户条件。
     *
     * @param expression 条件表达式
     * @param table 表
     * @return 添加租户条件后的条件表达式
     */
    @SuppressWarnings({"java:S6541", "java:S3776"})
    protected Expression processCondition(@Nullable Expression expression, FromItem table) {
        // 如果已经不可拆分的表达式,则直接返回
        if (isBasicExpression(expression)) {
            return expression;
        }
        // 如果是子查询,则需要对子查询进行递归处理
        else if (expression instanceof SubSelect) {
            processSelect(((SubSelect) expression).getSelectBody());
        }
        // 如果是 in 条件,比如:xxx in (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof InExpression) {
            InExpression inExp = (InExpression) expression;
            ItemsList rightItems = inExp.getRightItemsList();
            if (rightItems instanceof SubSelect) {
                processSelect(((SubSelect) rightItems).getSelectBody());
            }
        }
        // 如果是 not 或者 != 条件,则需要对里面的条件进行递归处理
        else if (expression instanceof NotExpression) {
            NotExpression notExpression = (NotExpression) expression;
            processCondition(notExpression.getExpression(), table);
        }
        // 如果是 (xxx != xxx),则需要对括号里面的表达式进行递归处理
        else if (expression instanceof Parenthesis) {
            Parenthesis parenthesis = (Parenthesis) expression;
            Expression content = parenthesis.getExpression();
            processCondition(content, table);
        }
        // 如果是二元表达式,比如:xx = xx,xx > xx,则需要对左右两边的表达式进行递归处理
        else if (expression instanceof BinaryExpression) {
            BinaryExpression binaryExpression = (BinaryExpression) expression;
            Expression left = binaryExpression.getLeftExpression();
            processCondition(left, table);
            Expression right = binaryExpression.getRightExpression();
            processCondition(right, table);
        }
        // 如果是函数,比如:if(xx, xx) ,则需要对函数的参数进行递归处理
        else if (expression instanceof Function) {
            Function function = (Function) expression;
            ExpressionList parameters = function.getParameters();
            if (parameters != null) {
                parameters.getExpressions().forEach(param -> processCondition(param, table));
            }
        }
        // 如果是 case when 语句,则需要对 when 和 then 两个条件进行递归处理
        else if (expression instanceof WhenClause) {
            WhenClause whenClause = (WhenClause) expression;
            processCondition(whenClause.getWhenExpression(), table);
            processCondition(whenClause.getThenExpression(), table);
        }
        // 如果是 case 语句,则需要对 switch、when、then、else 四个条件进行递归处理
        else if (expression instanceof CaseExpression) {
            CaseExpression caseExpression = (CaseExpression) expression;
            processCondition(caseExpression.getSwitchExpression(), table);
            List<WhenClause> whenClauses = caseExpression.getWhenClauses();
            if (CollUtil.isNotEmpty(whenClauses)) {
                whenClauses.forEach(whenClause -> {
                    processCondition(whenClause.getWhenExpression(), table);
                    processCondition(whenClause.getThenExpression(), table);
                });
            }
            processCondition(caseExpression.getElseExpression(), table);
        }
        // 如果是 exists 语句,比如:exists (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof ExistsExpression) {
            Expression existsExpression = ((ExistsExpression) expression).getRightExpression();
            if (existsExpression instanceof SubSelect) {
                processSelect(((SubSelect) existsExpression).getSelectBody());
            }
        }
        // 如果是 all 或者 any 语句,比如:xx > all (select xx from yy……),则需要对子查询进行递归处理
        else if (expression instanceof AllComparisonExpression) {
            AllComparisonExpression allComparisonExpression = (AllComparisonExpression) expression;
            processSelect(allComparisonExpression.getSubSelect().getSelectBody());
        }
        else if (expression instanceof AnyComparisonExpression) {
            AnyComparisonExpression anyComparisonExpression = (AnyComparisonExpression) expression;
            processSelect(anyComparisonExpression.getSubSelect().getSelectBody());
        }
        // 如果是 cast 语句,比如:cast(xx as xx),则需要对子查询进行递归处理
        else if (expression instanceof CastExpression) {
            CastExpression castExpression = (CastExpression) expression;
            processCondition(castExpression.getLeftExpression(), table);
        }

        // 拼接查询条件
        Expression appendCondition = handleCondition(expression, table);
        return Objects.isNull(appendCondition) ? expression : appendCondition;
    }

    /**
     * 返回一个查询条件,该查询条件将替换{@code table}原有的{@code where}条件
     *
     * @param expression 原有的查询条件
     * @param table 指定的表
     * @return 查询条件
     */
    protected abstract Expression handleCondition(@Nullable Expression expression, FromItem table);
    
    /**
     * 返回一个表名,该表名将替换原有的表名
     *
     * @param table 表名
     * @return 处理后的表名
     */
    protected FromItem handleTable(Table table) {
        return table;
    }

    /**
     * 判断是否是已经是无法再拆分的基本表达式 <br/>
     * 比如:列名、常量、函数等
     *
     * @param expression 表达式
     * @return 是否是基本表达式
     */
    protected boolean isBasicExpression(@Nullable Expression expression) {
        return expression instanceof Column
            || expression instanceof LongValue
            || expression instanceof StringValue
            || expression instanceof DoubleValue
            || expression instanceof NullValue
            || expression instanceof TimeValue
            || expression instanceof TimestampValue
            || expression instanceof DateValue;
    }
}

�接着,对于原本的 SQL 拦截器,我们令其继承 AbstractSqlHandler,然后更换一个更合适的名字 LineLevelTenantSqlHandler

/**
 * SQL拦截器,用于为SQL语句添加租户条件。
 * 每次执行SQL时,将会检查当前线程上下文中是否存在租户信息,如果存在,则会为查询语句添加租户条件,否则直接略过。
 *
 * @author huangchengxing
 * @see ContextTenantConditionSqlHandlerAdvisor
 */
@Slf4j
public class LineLevelTenantSqlHandler extends AbstractConditionSqlHandler {

    private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();

    /**
     * 设置租户信息
     *
     * @param tenantInfo 租户信息
     */
    public static void setTenantInfo(TenantInfo tenantInfo) {
        TENANT_INFO_CONTEXT.set(tenantInfo);
    }

    /**
     * 清除租户信息
     */
    public static void clearTenantInfo() {
        TENANT_INFO_CONTEXT.remove();
    }

    @Override
    public String handle(String sql) {
        // 如果未设置租户信息,则直接返回原始SQL
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        if (Objects.isNull(tenantInfo)) {
            return sql;
        }
        log.debug("租户拦截器拦截原始 SQL: {}", sql);
        String handledSql = super.handle(sql);
        log.info("租户拦截器拦截后 SQL: {}", handledSql);
        return Objects.isNull(handledSql) ? sql : handledSql;
    }

    @Override
    @Nullable
    protected Expression handleCondition(@Nullable Expression expression, FromItem table) {
        TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
        // 如果是一个标准表名,且改表名在租户表列表中,则为查询条件添加租户条件
        if (!(table instanceof Table)) {
            return null;
        }
        String tenantColumn = tenantInfo.tablesWithTenantColumn.get(((Table) table).getName());
        if (Objects.nonNull(tenantColumn)) {
            return appendTenantCondition(expression, table, tenantInfo.tenantId, tenantColumn);
        }
        return null;
    }

    private static Expression appendTenantCondition(
        @Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
        EqualsTo equalsTo = new EqualsTo();
        equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
        equalsTo.setRightExpression(new StringValue(tenantId));
        if (Objects.isNull(original)) {
            return equalsTo;
        }
        return original instanceof OrExpression ?
            new AndExpression(equalsTo, new Parenthesis(original)) :
            new AndExpression(original, equalsTo);
    }

    private static Column getColumnWithTableAlias(FromItem table, String column) {
        // 如果表存在别名,则字段应该变“表别名.字段名”的格式
        return Optional.ofNullable(table)
            .map(FromItem::getAlias)
            .map(alias -> alias.getName() + "." + column)
            .map(Column::new)
            .orElse(new Column(column));
    }

    /**
     * 租户信息
     */
    @RequiredArgsConstructor
    public static class TenantInfo {
        /**
         * 租户ID
         */
        private final String tenantId;
        /**
         * 要添加租户条件的表名称与对应的租户字段
         */
        private final Map<String, String> tablesWithTenantColumn;
    }
}

1.5.与 JPA 结合使用#

JPA 的默认实现 Hibernate 提供了 StatementInspector 接口,我们实现一个自定义的实现类,然后让基础上文实现好的租户解析器即可 LineLevelTenantSqlHandler

/**
 * SQL拦截器,用于为SQL语句添加租户条件。
 * 每次执行SQL时,将会检查当前线程中是否存在租户信息,如果存在,则会为查询语句添加租户条件,否则直接略过。
 *
 * @author huangchengxing
 */
@Slf4j
@RequiredArgsConstructor
public class HibernateLineLevelTenantStatementInspector
    extends LineLevelTenantSqlHandler implements StatementInspector {

    @Override
    public String inspect(String sql) {
        return handle(sql);
    }
}

同理,我们也可以结合 Mybatis 或其他的框架实现类似的效果。

2.租户拦截器#

显然,我们不可能无条件的拦截所有的查询,有些查询本身不需要进行拦截,而有些查询当访问者为管理员时也不需要拦截……总而言之,对应租户拦截,我们需要采用白名单而不是黑名单的方式,因此最好的实现方法就是搞一个切面,然后只对带有特定注解的方法的调用进行拦截。

2.1.注解类#

我们定义一个 @TenantOperation 注解,该注解可以被用于方法或者类上,当用于类上的时候等于类中所有的方法都应用拦截:

/**
 * 表明方法是一个租户操作方法,需要在相关的SQL中加入租户过滤条件
 *
 * @author huangchengxing
 * @see ContextTenantConditionSqlHandlerAdvisor
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface TenantOperation {

    /**
     * 表配置
     *
     * @return 表配置
     */
    Tables[] value() default {};

    /**
     * 是否对当前方法与后续调用链不进行租户拦截
     *
     * @return boolean
     * @see Ignore
     */
    boolean ignore() default false;

    /**
     * 对当前方法与后续调用链不进行租户拦截
     */
    @TenantOperation(ignore = true) // 基于 Springt 合成注解机制的扩展注解
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE})
    @interface Ignore {}

    /**
     * 表配置
     */
    @Documented
    @Retention(RetentionPolicy.RUNTIME)
    @interface Tables {

        /**
         * 租户字段名,不指定时默认遵循配置文件中的字段名
         *
         * @return String
         */
        String column() default "";

        /**
         * 需要添加过滤条件的表名,不指定时默认遵循配置文件中的表名
         *
         * @return String
         */
        String[] tables() default {};
    }
}

此外,为了便于使用,注解还支持直接指定要拦截的表和字段,以便覆盖默认配置文件中的配置。

2.2.方法拦截器#

为了便于后续扩展,这里笔者没有基于 Aspect 注解,而是基于 Spring 的方法拦截器,自定义了切点来实现这个效果:

/**
 * 方法拦截器,用于拦截带有{@link TenantOperation}注解的方法,为涉及的查询语句添加租户过滤条件
 *
 * @author huangchengxing
 * @see LineLevelTenantOperationAdvisor
 */
@Slf4j
public class LineLevelTenantOperationAdvisor implements PointcutAdvisor, MethodInterceptor {

    private static final String INTERCEPT_REQUEST_ENTRY = "tenant";
    private static final TenantOpsInfo NULL = new TenantOpsInfo(null);
    private final Map<Method, TenantOpsInfo> tenantInfoCaches = new ConcurrentReferenceHashMap<>();
    private final TenantOpsInfo opsByDefault;

    public LineLevelTenantOperationAdvisor(Map<String, String> tableWithColumns) {
        this.opsByDefault = new TenantOpsInfo(tableWithColumns);
    }

    @Override
    public Object invoke(MethodInvocation methodInvocation) throws Throwable {
        // 从上下文获取租户ID
        String tenantId = Optional.ofNullable(RequestUserContext.getUser())
            .map(RequestUserContext.User::getUserId)
            .orElse(null);
        // 若没有上下文信息,则直接放行
        if (Objects.isNull(tenantId)) {
            return methodInvocation.proceed();
        }

        // 解析配置信息
        TenantOpsInfo info = resolveMethod(methodInvocation.getMethod());
        if (info == NULL) {
            return methodInvocation.proceed();
        }

        // 设置租户信息
        try {
            LineLevelTenantSqlHandler.setTenantInfo(info.getTenantInfo(tenantId));
            return methodInvocation.proceed();
        } finally {
            LineLevelTenantSqlHandler.clearTenantInfo();
        }
    }

    private TenantOpsInfo resolveMethod(Method method) {
        return tenantInfoCaches.computeIfAbsent(method, m -> {
            // 从方法上或类上获取注解
            TenantOperation annotation = Optional.ofNullable(AnnotatedElementUtils.findMergedAnnotation(method, TenantOperation.class))
                .orElse(AnnotatedElementUtils.findMergedAnnotation(method.getDeclaringClass(), TenantOperation.class));
            if (Objects.isNull(annotation)) {
                return NULL;
            }
            // 若注解未指定column和tables,则使用默认值
            TenantOperation.Tables[] tables = annotation.value();
            if (ArrayUtil.isEmpty(tables)) {
                return opsByDefault;
            }
            // 若指定了column和tables,则使用指定值
            Map<String, String> tableWithColumns = new HashMap<>(tables.length);
            for (TenantOperation.Tables table : tables) {
                String column = table.column();
                for (String tableName : table.tables()) {
                    tableWithColumns.put(tableName, column);
                }
            }
            return new TenantOpsInfo(tableWithColumns);
        });
    }

    @RequiredArgsConstructor
    private static class TenantOpsInfo {
        private final Map<String, String> tablesWithTenantColumn;
        public LineLevelTenantSqlHandler.TenantInfo getTenantInfo(String tenantId) {
            return new LineLevelTenantSqlHandler.TenantInfo(tenantId, tablesWithTenantColumn);
        }
    }

    @Override
    public @NonNull Pointcut getPointcut() {
        return TenantOperationPointcut.INSTANCE;
    }

    @Override
    public @NonNull Advice getAdvice() {
        return this;
    }

    @Override
    public boolean isPerInstance() {
        return false;
    }

    // 自定义切点,拦截带有 @TenantOperation 注解的方法,或声明类上带有 @TenantOperation 注解的全部方法
    private static class TenantOperationPointcut extends StaticMethodMatcher implements Pointcut {
        public static final TenantOperationPointcut INSTANCE = new TenantOperationPointcut();
        @Override
        public @NonNull ClassFilter getClassFilter() {
            return ClassFilter.TRUE;
        }
        @Override
        public @NonNull MethodMatcher getMethodMatcher() {
            return this;
        }
        @Override
        public boolean matches(@NonNull Method method, @NonNull Class<?> type) {
            return AnnotatedElementUtils.isAnnotated(method, TenantOperation.class)
                || AnnotatedElementUtils.isAnnotated(type, TenantOperation.class);
        }
    }
}

2.3.上下文传递问题#

如上文,我们选择使用方法拦截器在方法执行前设置租户信息,在方法执行后清空租户信息,这种做法在当同一条调用链上,同上触发了多次拦截时就会出现问题:

// 设置租户信息
try {
    LineLevelTenantSqlHandler.setTenantInfo(info.getTenantInfo(tenantId));
    return methodInvocation.proceed();
} finally {
    LineLevelTenantSqlHandler.clearTenantInfo();
}

举个例子,假如我们存在如下的调用:

@TenantOperation(
    @Tables(column = "userId")
)
public void method1() {
    // do something
    method2();
    method3()
    // do something
}

@TenantOperation(ignore = true) // 该方法不需要进行租户拦截
public void method1() {
    // do something
}

@TenantOperation(
    @Tables(column = "tenantId")
)
public void method1() {
    // do something
}

如果我们假设每个方法都能被正确的拦截,那么按原有的代码,当执行了 method2 以后,由于直接清空了上下文,最终会导致后续的调用都没有办法正确获取到租户信息。同理,
这个问题与 Spring 的事务传播有点异曲同工,我们的解决方案也类似,那就引入“挂起”这个概念。简单的来说,如果有一个被拦截的方法触发了上下文租户信息的更新,纳那么:

  • 如果上下文已经存在租户信息,说明当前方法只是调用链中的一个环节,那么就需要先将其挂起,先放入当前方法配置的租户信息,等到执行结束后,再将旧的租户信息放回上下文;
  • 如果上下文中没有存在租户信息,说明当前方法已经是调用链的源头,那么当执行完毕后,可以直接请上下文清空。

对此,我们参照 Spring 的做法,稍微调整一下这部分代码即可:

// 暂时挂起上一层级方法设置的租户信息
LineLevelTenantSqlHandler.TenantInfo previous = LineLevelTenantSqlHandler.getTenantInfo();
// 若当前方法设置了忽略租户信息,则清空上下文,否则设置当前租户信息
if (info.isIgnore()) {
    LineLevelTenantSqlHandler.clearTenantInfo();
} else {
    LineLevelTenantSqlHandler.TenantInfo current = new LineLevelTenantSqlHandler.TenantInfo(tenantId, info.getTablesWithTenantColumn());
    LineLevelTenantSqlHandler.setTenantInfo(current);
}
try {
    return methodInvocation.proceed();
} finally {
    // 恢复挂起的租户信息
    if (Objects.nonNull(previous)) {
        LineLevelTenantSqlHandler.setTenantInfo(previous);
    }
    // 若之前没有租户信息,则清空上下文
    else {
        LineLevelTenantSqlHandler.clearTenantInfo();
    }
}

3.使用#

3.1.配置类#

首先,我们先定义一个配置类以在项目中启用上述组件:

/**
 * <p>租户拦截器配置,启用后可以为指定的查询方法添加租户过滤条件。 <br/>
 * 可通过配置文件进行配置:<br/>
 * <pre>
 * # JPA 启用租户 SQL 拦截器
 * spring.jpa.properties.hibernate.session_factory.statement_inspector=io.github.createsequence.wheel.spring.tenant.HibernateTenantStatementInspector
 * # 启用租户拦截器
 * tenant.interceptor.enabled=true
 * # 需要拦截的表
 * tenant.interceptor.tables[0].column = tenant_id
 * tenant.interceptor.tables[0].tableNames = table1, table2
 * </pre>
 *
 * @author huangchengxing
 */
@Slf4j
@ConditionalOnProperty(prefix = TenantInterceptorConfig.Properties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@EnableConfigurationProperties(TenantInterceptorConfig.Properties.class)
@Configuration
public class TenantInterceptorConfig {

    @Bean
    public LineLevelTenantOperationAdvisor lineLevelTenantOperationAdvisor(Properties properties) {
        log.info("启用租户拦截器,需要拦截的表:{}", properties.getTables());
        Map<String, String> tableWithColumns = new HashMap<>(16);
        properties.getTables().forEach(ts -> ts.getTableNames().forEach(t -> {
            Assert.isFalse(tableWithColumns.containsKey(t), "同一张表具备只允许具备一个租户字段:{}", t);
            tableWithColumns.put(t, ts.getColumn());
        }));
        return new LineLevelTenantOperationAdvisor(tableWithColumns);
    }

    /**
     * @author huangchengxing
     */
    @ConfigurationProperties(prefix = Properties.CONFIG_PREFIX)
    @Data
    public static class Properties {

        public static final String CONFIG_PREFIX = "tenant.interceptor";

        /**
         * 表配置
         */
        private List<Tables> tables = new ArrayList<>();

        @Data
        public static class Tables {

            /**
             * 租户字段名
             */
            private String column;

            /**
             * 需要拦截的表名
             */
            private Set<String> tableNames;
        }
    }
}

3.2.配置文件#

随后在配置文件中启用拦截器,并配置好要拦截的表:

# 启用租户 SQL 拦截器
spring.jpa.properties.hibernate.session_factory.statement_inspector=io.github.createsequence.wheel.spring.tenant.HibernateTenantStatementInspector
# 启用租户拦截器
tenant.interceptor.enabled=true
# 拦截 t_user, t_resource, t_assest 表中的 tenant_id 字段
tenant.interceptor.tables[0].column=tenant_id
tenant.interceptor.tables[0].table-names=t_user, t_resource, t_assest

3.3.添加注解#

最后,我们只要在对应的类或者方法上添加 @TenantOperation 即可:

@TenantOperation // 默认所有方法都要应用拦截
@RestController
public class ResourceController {

    // @TenantOperation 因为类上已经加了,所以方法上可以不用加
    @GetMapping
    public List<Resource> listResource1(List<Integer> ids) {
        // do something
    }

    @TenantOperation.Ingore // 该方法不进行拦截
    @GetMapping
    public List<Resource> listResource2(List<Integer> ids) {
        // do something
    }
}

作者:Createsequence

出处:https://www.cnblogs.com/Createsequence/p/18093433

版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。

posted @   Createsequence  阅读(998)  评论(1编辑  收藏  举报
相关博文:
阅读排行:
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?
more_horiz
keyboard_arrow_up dark_mode palette
选择主题
menu
点击右上角即可分享
微信分享提示