基于 SQL 解析的 JPA 多租户方案
概述#
最近在对一个使用 JPA 的老项目进行多租户改造,由于年代过于久远,陈年屎山让人实在不敢轻举妄动,最后只能选择一个改造成本最小的方案,那就是通过拦截器改 SQL,动态添加租户 ID 作为查询条件。
本篇文章用于记录笔者基于该方案解决此问题的踩坑和思考过程,部分代码与实际代码有所出入。如果希望直接获取可运行的代码,可以直接在 github 仓库获取。
1.SQL 拦截器#
由于 JPA 底层是基于 Hibernate 实现的,而 Hibernate 本身提供了 StatementInspector
接口用于实现 SQL 拦截。因此我们只需要在这个阶段对 SQL 进行解析,然后为需要按租户进行隔离的资源表动态的添加租户 ID 的过滤条件即可。
这里我们选择使用 JSqlParser 作为我们的 SQL 解析器。它社区还算活跃,文档详细,最重要的是,API 比较简单易懂。下文若不特意澄清,则所有与 SQL 解析相关的类都来自于它。
1.1.简单实现#
在最开始,我们写一个简单的实现来验证一下可行性。
假设,我们需要指定拦截针对表 t_resource
的查询语句,为其添加 tenant_id = xxx
作为查询条件,那么这个 SQL 拦截器需要做到:
- 将 SQL 解析为
Statement
对象,然后检查其是否为查询类 SQL; - 获取 SQL 的
from
语句,并判断查询的表是否为我们要拦截的表; - 解析
where
语句:- 若原本没有任何条件,则为其生成一个
where t.tenant_id = xxx
的条件; - 如果原本已经有条件了,则为其在最后拼接
and t.tenant_id = xxx
的条件;
- 若原本没有任何条件,则为其生成一个
这里我们针对这个需求给出一个简单的实现:
@Slf4j
public class TenantSQLInterceptor {
/**
* 处理SQL语句
*
* @param sql SQL语句
* @param table 要拦截器的租户表名
* @param column 租户字段名
* @param value 租户字段值
* @return 处理后的SQL语句
*/
public String handle(String sql, String table, String column, String value) {
log.debug("租户拦截器拦截原始 SQL: {}", sql);
String handledSql = doHandle(sql, table, column, value);
log.info("租户拦截器拦截后 SQL: {}", handledSql);
return Objects.isNull(handledSql) ? sql : handledSql;
}
/**
* 处理SQL语句
*
* @param sql SQL语句
* @return 处理后的SQL语句
*/
@Nullable
public String doHandle(String sql, String table, String column, String value) {
Statements statements = parseStatements(sql);
if (Objects.isNull(statements)) {
return null;
}
List<Statement> statementList = statements.getStatements();
if (CollUtil.isEmpty(statementList)) {
return null;
}
return statements.getStatements().stream()
.map(statement -> doHandle(statement, table, column, value))
.map(Statement::toString)
.collect(Collectors.joining(";"));
}
@Nullable
private Statements parseStatements(String sql) {
Statements statements = null;
try {
statements = CCJSqlParserUtil.parseStatements(sql);
return statements;
} catch (JSQLParserException e) {
log.error("SQL 解析失败: {}", sql, e);
throw new CloudPluginException(ResultCodingEnum.SchedulingError, "SQL 解析失败");
}
}
private Statement doHandle(Statement statement, String table, String column, String value) {
if (!(statement instanceof Select)) {
return statement;
}
try {
SelectBody selectBody = ((Select) statement).getSelectBody();
// 目前只处理普通的 SQL 查询
if (selectBody instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectBody;
FromItem fromItem = plainSelect.getFromItem();
Expression where = plainSelect.getWhere();
// 如果查询的表即为要拦截的租户表,则为查询条件添加租户条件
if (fromItem instanceof Table) {
String queryTable = ((Table) fromItem).getName();
if (Objects.equals(queryTable, table)) {
where = appendTenantCondition(plainSelect.getWhere(), fromItem, value, column);
plainSelect.setWhere(where);
}
}
}
} catch (Exception ex) {
log.error("SQL 处理失败: {}", statement, ex);
throw new RuntimeException("SQL 处理失败", ex);
}
return statement;
}
private static Expression appendTenantCondition(
@Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
// 生成一个 tenant_id = xxx 的条件
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
equalsTo.setRightExpression(new StringValue(tenantId));
if (Objects.isNull(original)) {
return equalsTo;
}
return original instanceof OrExpression ?
new AndExpression(equalsTo, new Parenthesis(original)) :
new AndExpression(original, equalsTo);
}
private static Column getColumnWithTableAlias(FromItem table, String column) {
// 如果表存在别名,则字段应该变“表别名.字段名”的格式
return Optional.ofNullable(table)
.map(FromItem::getAlias)
.map(alias -> alias.getName() + "." + column)
.map(Column::new)
.orElse(new Column(column));
}
}
测试一下:
public static void main(String[] args) {
String sql = "select * from t_resource r where r.order = 1";
TenantSQLInterceptor tenantSQLInterceptor = new TenantSQLInterceptor();
String handledSql = tenantSQLInterceptor.handle(sql, "t_resource", "tenant_id", "1");
System.out.println(handledSql); // = SELECT * FROM resource r WHERE r.t_resource = 1 AND r.tenant_id = '1'
}
虽然还非常简陋,不过这个拦截器已经能够初步实现我们想要的功能了,不过要投入实际场景,显然还需要做出“一点点”改进。
1.2.从上下文获取租户信息#
首先,真实的使用场景中,一个 SQL 可能会同时涉及到多张需要拦截器的表,并且每张表对应的租户 ID 仍然有可能不同,因此我们最好直接将相关的配置信息提取出来,改为通过一个上下文对象进行获取:
@Slf4j
public class TenantSQLInterceptor {
private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();
/**
* 设置租户信息
*
* @param tenantInfo 租户信息
*/
public static void setTenantInfo(TenantInfo tenantInfo) {
TENANT_INFO_CONTEXT.set(tenantInfo);
}
/**
* 清除租户信息
*/
public static void clearTenantInfo() {
TENANT_INFO_CONTEXT.remove();
}
public String handle(String sql) {
// 如果未设置租户信息,则直接返回原始SQL
TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
if (Objects.isNull(tenantInfo)) {
return sql;
}
log.debug("租户拦截器拦截原始 SQL: {}", sql);
String handledSql = doHandle(sql);
log.info("租户拦截器拦截后 SQL: {}", handledSql);
return Objects.isNull(handledSql) ? sql : handledSql;
}
/**
* 处理SQL语句
*
* @param sql SQL语句
* @return 处理后的SQL语句
*/
@Nullable
public String doHandle(String sql) {
Statements statements = parseStatements(sql);
if (Objects.isNull(statements)) {
return null;
}
List<Statement> statementList = statements.getStatements();
if (CollUtil.isEmpty(statementList)) {
return null;
}
return statements.getStatements().stream()
.map(this::doHandle)
.map(Statement::toString)
.collect(Collectors.joining(";"));
}
@Nullable
private Statements parseStatements(String sql) {
Statements statements = null;
try {
statements = CCJSqlParserUtil.parseStatements(sql);
return statements;
} catch (JSQLParserException e) {
log.error("SQL 解析失败: {}", sql, e);
throw new CloudPluginException(ResultCodingEnum.SchedulingError, "SQL 解析失败");
}
}
private Statement doHandle(Statement statement) {
if (!(statement instanceof Select)) {
return statement;
}
try {
SelectBody selectBody = ((Select) statement).getSelectBody();
if (selectBody instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectBody;
FromItem fromItem = plainSelect.getFromItem();
Expression where = plainSelect.getWhere();
// 如果查询的表即为要拦截的租户表,则为查询条件添加租户条件
if (fromItem instanceof Table) {
String queryTable = ((Table) fromItem).getName();
TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
String tenantColumn = tenantInfo.tablesWithTenantColumn.get(queryTable);
if (Objects.nonNull(tenantColumn)) {
plainSelect.setWhere(appendTenantCondition(where, fromItem, tenantInfo.tenantId, tenantColumn));
}
}
}
} catch (Exception ex) {
log.error("SQL 处理失败: {}", statement, ex);
throw new RuntimeException("SQL 处理失败", ex);
}
return statement;
}
private static Expression appendTenantCondition(
@Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
equalsTo.setRightExpression(new StringValue(tenantId));
if (Objects.isNull(original)) {
return equalsTo;
}
return original instanceof OrExpression ?
new AndExpression(equalsTo, new Parenthesis(original)) :
new AndExpression(original, equalsTo);
}
private static Column getColumnWithTableAlias(FromItem table, String column) {
// 如果表存在别名,则字段应该变“表别名.字段名”的格式
return Optional.ofNullable(table)
.map(FromItem::getAlias)
.map(alias -> alias.getName() + "." + column)
.map(Column::new)
.orElse(new Column(column));
}
/**
* 租户信息
*/
@RequiredArgsConstructor
public static class TenantInfo {
/**
* 租户ID
*/
private final String tenantId;
/**
* 要添加租户条件的表名称与对应的租户字段
*/
private final Map<String, String> tablesWithTenantColumn;
}
}
1.3.复杂 SQL 的解析#
在实际场景中,尤其是涉及到手写 SQL 的场景中,SQL 往往比较复杂,比如:
- 查询可能基于一张虚拟表,比如:
select * from (selecrt from t1 where t1.id = xx) t2
这种情况。 - 可能会存在关联查,比如:
select * from t1 left join t2 on t2.id = t1.tid
这种情况。 - 可能会涉及到子查询,比如:
select * from t where t.id in (select t2.tid from t2 where t2.id = xxx)
这种情况。
除上述这几种情况外,我们还需要考虑各种组合的场景,比如 union 类型的联合查询,函数与子查询的嵌套,基于虚拟表的联查……等等。
1.3.1.改进方案#
虽然情况有很多种,不过值得高兴的是,我们还是有办法为其归纳出一个处理流程。简单的来说,就是检查所有可能存在嵌套查询的语句,进行递归解析:
- 第一步,先解析语句本身,如果是 union 这种联合查询,则将其拆分为多条单体 SQL 进行递归解析;
- 第二步,对于单条 SQL,解析其 select 的字段,如果存在函数或者子查询,则将每个字段其作为一个单体 SQL 进行递归解析;
- 第三步,解析其 from 语句,如果存在函数或者基于子查询的临时表,则将子查询作为一个单体 SQL 进行递归解析;
- 第四步,解析 join 语句:
- 如果 join 的表本身是基于子查询的临时表,则将子查询作为一个单体 SQL 进行递归解析;
- 如果 on 条件中存在函数或者子查询,则将其作为单体 SQL 进行递归解析;
- 第五步,解析 where 条件,如果存在函数或者基于子查询的条件字段,则将其作为一个单体 SQL 进行递归解析。
基于上述分析,我们需要对现有的代码做出一点调整:
- 在 doHandle 方法中,我们需要判断 fromItem 的类型,如果是子查询,则需要进行递归处理。
- 在 doHandle 方法后,我们需要新增一部分对 join 语句的处理,由于 join 语句同样由 from 和 where 两部分组成,因此此处的逻辑应当与正常的 select 差不多。
- 在 appendTenantCondition 方法之前,我们需要增加对特殊条件的处理,对应每个条件,我们都需要检查是否存在可能的子查询,如果存则需要进行递归处理。
1.3.2.改进后的代码#
根据改进方案,我们再次调整代码:
@Slf4j
public class TenantSQLInterceptor {
private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();
/**
* 设置租户信息
*
* @param tenantInfo 租户信息
*/
public static void setTenantInfo(TenantInfo tenantInfo) {
TENANT_INFO_CONTEXT.set(tenantInfo);
}
/**
* 清除租户信息
*/
public static void clearTenantInfo() {
TENANT_INFO_CONTEXT.remove();
}
/**
* 处理SQL语句
*
* @param sql SQL语句
* @return 处理后的SQL语句
*/
@Nullable
public String handle(String sql) {
Statements statements = parseStatements(sql);
if (Objects.isNull(statements)) {
return null;
}
List<Statement> statementList = statements.getStatements();
if (CollUtil.isEmpty(statementList)) {
return null;
}
return statements.getStatements().stream()
.map(this::doHandle)
.map(Statement::toString)
.collect(Collectors.joining(";"));
}
@Nullable
private Statements parseStatements(String sql) {
Statements statements = null;
try {
statements = CCJSqlParserUtil.parseStatements(sql);
} catch (JSQLParserException e) {
log.error("SQL 解析失败: {}", sql, e);
throw new RuntimeException("SQL 解析失败", e);
}
return statements;
}
private Statement doHandle(Statement statement) {
try {
if (statement instanceof Select) {
processSelect(((Select) statement).getSelectBody());
} else if (statement instanceof Update) {
processUpdate((Update) statement);
} else if (statement instanceof Delete) {
processDelete((Delete) statement);
} else if (statement instanceof Insert) {
processInsert((Insert) statement);
}
} catch (Exception ex) {
log.error("SQL 处理失败: {}", statement, ex);
throw new RuntimeException("SQL 处理失败", ex);
}
return statement;
}
private void processSelect(SelectBody selectBody) {
// 普通查询
if (selectBody instanceof PlainSelect) {
processSelect((PlainSelect) selectBody);
}
// 嵌套查询,比如 select xx from (select yy from t)
else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
if (withItem.getSelectBody() != null) {
processSelect(withItem.getSelectBody());
}
}
// 联合查询,比如 union
else if (selectBody instanceof SetOperationList) {
SetOperationList operationList = (SetOperationList) selectBody;
if (CollUtil.isNotEmpty(operationList.getSelects())) {
operationList.getSelects().forEach(this::processSelect);
}
}
// 值查询,比如 select 1, 2, 3
else if (selectBody instanceof ValuesStatement) {
List<Expression> expressions = ((ValuesStatement) selectBody).getExpressions();
if (CollUtil.isNotEmpty(expressions)) {
expressions.forEach(exp -> processCondition(exp, null));
}
} else {
log.error("无法解析的 select 语句:{}({})", selectBody, selectBody.getClass());
throw new RuntimeException("不支持的查询语句:" + selectBody.getClass().getName()
}
}
/**
* 处理插入语句
*
* @param insert 插入语句
*/
protected void processInsert(Insert insert) {
// do nothing
}
/**
* 处理删除语句
*
* @param delete 删除语句
*/
protected void processDelete(Delete delete) {
Table table = delete.getTable();
delete.setWhere(processCondition(delete.getWhere(), table));
// 如果还存在关联查询
List<Join> joins = delete.getJoins();
if (CollUtil.isNotEmpty(joins)) {
joins.forEach(this::processJoin);
}
}
/**
* 处理更新语句
*
* @param update 更新语句
*/
protected void processUpdate(Update update) {
Table table = update.getTable();
update.setWhere(processCondition(update.getWhere(), table));
// 如果还存在关联查询
List<Join> joins = update.getJoins();
if (CollUtil.isNotEmpty(joins)) {
joins.forEach(this::processJoin);
}
}
/**
* 处理查询语句
*
* @param plainSelect 查询语句
*/
protected void processSelect(PlainSelect plainSelect) {
FromItem fromItem = plainSelect.getFromItem();
// 如果是普通的表名
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
plainSelect.setWhere(processCondition(plainSelect.getWhere(), fromTable));
}
// 如果是子查询,比如 select * from (select xxx from yyy)
else if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) {
processSelect(subSelect.getSelectBody());
}
plainSelect.setWhere(processCondition(plainSelect.getWhere(), subSelect));
}
// 如果是带有特殊函数的子查询,比如 lateral (select sum(*) from yyy)
else if (fromItem instanceof SpecialSubSelect) {
SpecialSubSelect specialSubSelect = (SpecialSubSelect) fromItem;
if (specialSubSelect.getSubSelect() != null) {
SubSelect subSelect = specialSubSelect.getSubSelect();
if (subSelect.getSelectBody() != null) {
processSelect(subSelect.getSelectBody());
}
}
plainSelect.setWhere(processCondition(plainSelect.getWhere(), specialSubSelect));
}
// 未知类型的查询,直接报错
else {
log.error("无法解析的 from 语句:{}({})", fromItem, fromItem.getClass());
throw new RuntimeException("不支持的查询语句:" + fromItem.getClass().getName()
}
// 如果还存在关联查询
List<Join> joins = plainSelect.getJoins();
if (CollUtil.isNotEmpty(joins)) {
joins.forEach(this::processJoin);
}
}
/**
* 处理关联查询
*
* @param join 关联查询
*/
protected void processJoin(Join join) {
FromItem joinTable = join.getRightItem();
if (joinTable instanceof Table) {
Table table = (Table) joinTable;
join.setOnExpression(processCondition(join.getOnExpression(), table));
}
else if (joinTable instanceof SubSelect) {
processSelect(((SubSelect) joinTable).getSelectBody());
}
else if (joinTable instanceof SpecialSubSelect) {
SpecialSubSelect specialSubSelect = (SpecialSubSelect) joinTable;
if (specialSubSelect.getSubSelect() != null) {
SubSelect subSelect = specialSubSelect.getSubSelect();
if (subSelect.getSelectBody() != null) {
processSelect(subSelect.getSelectBody());
}
}
}
else {
log.error("无法解析的 join 语句:{}({})", joinTable, joinTable.getClass());
throw new RuntimeException("不支持的查询语句:" + joinTable.getClass().getName());
}
}
/**
* <p>获取添加了租户条件的查询条件,若条件中存在子查询,则也会为子查询添加租户条件。
*
* @param expression 条件表达式
* @param table 表
* @return 添加租户条件后的条件表达式
*/
protected Expression processCondition(@Nullable Expression expression, FromItem table) {
// 如果已经不可拆分的表达式,则直接返回
if (isBasicExpression(expression)) {
return expression;
}
// 如果是子查询,则需要对子查询进行递归处理
else if (expression instanceof SubSelect) {
processSelect(((SubSelect) expression).getSelectBody());
}
// 如果是 in 条件,比如:xxx in (select xx from yy……),则需要对子查询进行递归处理
else if (expression instanceof InExpression) {
InExpression inExp = (InExpression) expression;
ItemsList rightItems = inExp.getRightItemsList();
if (rightItems instanceof SubSelect) {
processSelect(((SubSelect) rightItems).getSelectBody());
}
}
// 如果是 not 或者 != 条件,则需要对里面的条件进行递归处理
else if (expression instanceof NotExpression) {
NotExpression notExpression = (NotExpression) expression;
processCondition(notExpression.getExpression(), table);
}
// 如果是 (xxx != xxx),则需要对括号里面的表达式进行递归处理
else if (expression instanceof Parenthesis) {
Parenthesis parenthesis = (Parenthesis) expression;
Expression content = parenthesis.getExpression();
processCondition(content, table);
}
// 如果是二元表达式,比如:xx = xx,xx > xx,则需要对左右两边的表达式进行递归处理
else if (expression instanceof BinaryExpression) {
BinaryExpression binaryExpression = (BinaryExpression) expression;
Expression left = binaryExpression.getLeftExpression();
processCondition(left, table);
Expression right = binaryExpression.getRightExpression();
processCondition(right, table);
}
// 如果是函数,比如:if(xx, xx) ,则需要对函数的参数进行递归处理
else if (expression instanceof Function) {
Function function = (Function) expression;
ExpressionList parameters = function.getParameters();
if (parameters != null) {
parameters.getExpressions().forEach(param -> processCondition(param, table));
}
}
// 如果是 case when 语句,则需要对 when 和 then 两个条件进行递归处理
else if (expression instanceof WhenClause) {
WhenClause whenClause = (WhenClause) expression;
processCondition(whenClause.getWhenExpression(), table);
processCondition(whenClause.getThenExpression(), table);
}
// 如果是 case 语句,则需要对 switch、when、then、else 四个条件进行递归处理
else if (expression instanceof CaseExpression) {
CaseExpression caseExpression = (CaseExpression) expression;
processCondition(caseExpression.getSwitchExpression(), table);
List<WhenClause> whenClauses = caseExpression.getWhenClauses();
if (CollUtil.isNotEmpty(whenClauses)) {
whenClauses.forEach(whenClause -> {
processCondition(whenClause.getWhenExpression(), table);
processCondition(whenClause.getThenExpression(), table);
});
}
processCondition(caseExpression.getElseExpression(), table);
}
// 如果是 exists 语句,比如:exists (select xx from yy……),则需要对子查询进行递归处理
else if (expression instanceof ExistsExpression) {
Expression existsExpression = ((ExistsExpression) expression).getRightExpression();
if (existsExpression instanceof SubSelect) {
processSelect(((SubSelect) existsExpression).getSelectBody());
}
}
// 如果是 all 或者 any 语句,比如:xx > all (select xx from yy……),则需要对子查询进行递归处理
else if (expression instanceof AllComparisonExpression) {
AllComparisonExpression allComparisonExpression = (AllComparisonExpression) expression;
processSelect(allComparisonExpression.getSubSelect().getSelectBody());
}
else if (expression instanceof AnyComparisonExpression) {
AnyComparisonExpression anyComparisonExpression = (AnyComparisonExpression) expression;
processSelect(anyComparisonExpression.getSubSelect().getSelectBody());
}
// 如果是 cast 语句,比如:cast(xx as xx),则需要对子查询进行递归处理
else if (expression instanceof CastExpression) {
CastExpression castExpression = (CastExpression) expression;
processCondition(castExpression.getLeftExpression(), table);
}
// 拼接查询条件
Expression appendCondition = handleCondition(expression, table);
return Objects.isNull(appendCondition) ? expression : appendCondition;
}
/**
* 判断是否是已经是无法再拆分的基本表达式 <br/>
* 比如:列名、常量、函数等
*
* @param expression 表达式
* @return 是否是基本表达式
*/
protected boolean isBasicExpression(@Nullable Expression expression) {
return expression instanceof Column
|| expression instanceof LongValue
|| expression instanceof StringValue
|| expression instanceof DoubleValue
|| expression instanceof NullValue
|| expression instanceof TimeValue
|| expression instanceof TimestampValue
|| expression instanceof DateValue;
}
/**
* 返回一个查询条件,该查询条件将替换{@code table}原有的{@code where}条件
*
* @param expression 原有的查询条件
* @param table 指定的表
* @return 查询条件
*/
@Nullable
protected Expression handleCondition(@Nullable Expression expression, FromItem table) {
TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
// 如果是一个标准表名,且改表名在租户表列表中,则为查询条件添加租户条件
if (!(table instanceof Table)) {
return null;
}
String tenantColumn = tenantInfo.tablesWithTenantColumn.get(((Table) table).getName());
if (Objects.nonNull(tenantColumn)) {
return appendTenantCondition(expression, table, tenantInfo.tenantId, tenantColumn);
}
return null;
}
private static Expression appendTenantCondition(
@Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
equalsTo.setRightExpression(new StringValue(tenantId));
if (Objects.isNull(original)) {
return equalsTo;
}
return original instanceof OrExpression ?
new AndExpression(equalsTo, new Parenthesis(original)) :
new AndExpression(original, equalsTo);
}
private static Column getColumnWithTableAlias(FromItem table, String column) {
// 如果表存在别名,则字段应该变“表别名.字段名”的格式
return Optional.ofNullable(table)
.map(FromItem::getAlias)
.map(alias -> alias.getName() + "." + column)
.map(Column::new)
.orElse(new Column(column));
}
/**
* 租户信息
*/
@RequiredArgsConstructor
public static class TenantInfo {
/**
* 租户ID
*/
private final String tenantId;
/**
* 要添加租户条件的表名称与对应的租户字段
*/
private final Map<String, String> tablesWithTenantColumn;
}
}
现在,针对预期的复杂场景,我们再来测试一下:
public static void main(String[] args) {
Map<String, String> tablesWithTenantColumn = Maps.newHashMap();
tablesWithTenantColumn.put("t", "tenant_id");
TenantInfo tenantInfo = new TenantInfo("1", tablesWithTenantColumn);
TenantSQLInterceptor.setTenantInfo(tenantInfo);
// 处理包含的复杂子查询的SQL
String sql = "select * " +
"from (select * from t where a = 1) t " +
"left join (select * from t where b = 2) t2 on t.id = t2.id " +
"where b in (select * from t where c = 2) and d = 3";
TenantSQLInterceptor interceptor = new TenantSQLInterceptor();
String handledSql = interceptor.handle(sql);
System.out.println(handledSql);
// 输出结果:
// select *
// from (select * from t where a = 1 and tenant_id = '1') t
// left join (select * from t where b = 2 and tenant_id = '1') t2 on t.id = t2.id
// where b in (select * from t where c = 2 and tenant_id = '1') and d = 3
}
完美!
1.4.分离公共代码#
这个 SQL 拦截器已经可以完美满足我们的大部分需求了。现在功能已经实现,可以看看代码层面有什么可以优化的地方了。
我们再次分析一下上述代码,会注意到,上面的解析器其实干了两件事情:
- 解析 SQL,并在递归获取不可再拆分的“根” SQL 后,替换其 where 条件。
- 将 SQL 的 where 条件替换或追加上租户条件。
换而言之,第一步的逻辑似乎与“租户拦截”这个需求无关,它显然可以抽离为一个独立的组件以便后续复用。此外,我们现在实现的其实是一个行级别的租户拦截,如果我们日后需要表级别的租户拦截,最好也有办法基于它来实现。
综上考虑,这里我们将这个新组件根据其功能命名为 AbstractSqlHandler
,并且为其添加一个 handleTable
抽象方法,使其具备拦截表名的能力:
/**
* <p>SQL处理器,用于拦截SQL语句并修改其中的查询条件,
* 该处理器支持处理嵌套查询、联合查询、关联查询等多种查询方式。
*
* @author huangchengxing
* @see #handle
* @see #handleCondition
*/
@Setter
@Slf4j
public abstract class AbstractSqlHandler {
/**
* 处理SQL语句
*
* @param sql SQL语句
* @return 处理后的SQL语句
*/
@Nullable
public String handle(String sql) {
Statements statements = parseStatements(sql);
if (Objects.isNull(statements)) {
return null;
}
List<Statement> statementList = statements.getStatements();
if (CollUtil.isEmpty(statementList)) {
return null;
}
return statements.getStatements().stream()
.map(this::doHandle)
.map(Statement::toString)
.collect(Collectors.joining(";"));
}
@Nullable
private Statements parseStatements(String sql) {
Statements statements = null;
try {
statements = CCJSqlParserUtil.parseStatements(sql);
} catch (JSQLParserException e) {
log.error("SQL 解析失败: {}", sql, e);
throw new RuntimeException("SQL 解析失败");
}
return statements;
}
private Statement doHandle(Statement statement) {
try {
if (statement instanceof Select) {
processSelect(((Select) statement).getSelectBody());
} else if (statement instanceof Update) {
processUpdate((Update) statement);
} else if (statement instanceof Delete) {
processDelete((Delete) statement);
} else if (statement instanceof Insert) {
processInsert((Insert) statement);
}
} catch (Exception ex) {
log.error("SQL 处理失败: {}", statement, ex);
throw new RuntimeException("SQL 处理失败");
}
return statement;
}
private void processSelect(SelectBody selectBody) {
// 普通查询
if (selectBody instanceof PlainSelect) {
processSelect((PlainSelect) selectBody);
}
// 嵌套查询,比如 select xx from (select yy from t)
else if (selectBody instanceof WithItem) {
WithItem withItem = (WithItem) selectBody;
if (withItem.getSelectBody() != null) {
processSelect(withItem.getSelectBody());
}
}
// 联合查询,比如 union
else if (selectBody instanceof SetOperationList) {
SetOperationList operationList = (SetOperationList) selectBody;
if (CollUtil.isNotEmpty(operationList.getSelects())) {
operationList.getSelects().forEach(this::processSelect);
}
}
// 值查询,比如 select 1, 2, 3
else if (selectBody instanceof ValuesStatement) {
List<Expression> expressions = ((ValuesStatement) selectBody).getExpressions();
if (CollUtil.isNotEmpty(expressions)) {
expressions.forEach(exp -> processCondition(exp, null));
}
} else {
log.error("无法解析的 select 语句:{}({})", selectBody, selectBody.getClass());
throw new RuntimeException("不支持的查询语句:" + selectBody.getClass().getName());
}
}
/**
* 处理插入语句
*
* @param insert 插入语句
*/
protected void processInsert(Insert insert) {
// do nothing
}
/**
* 处理删除语句
*
* @param delete 删除语句
*/
protected void processDelete(Delete delete) {
Table table = delete.getTable();
delete.setWhere(processCondition(delete.getWhere(), table));
// 如果还存在关联查询
List<Join> joins = delete.getJoins();
if (CollUtil.isNotEmpty(joins)) {
joins.forEach(this::processJoin);
}
}
/**
* 处理更新语句
*
* @param update 更新语句
*/
protected void processUpdate(Update update) {
Table table = update.getTable();
update.setWhere(processCondition(update.getWhere(), table));
// 如果还存在关联查询
List<Join> joins = update.getJoins();
if (CollUtil.isNotEmpty(joins)) {
joins.forEach(this::processJoin);
}
}
/**
* 处理查询语句
*
* @param plainSelect 查询语句
*/
protected void processSelect(PlainSelect plainSelect) {
FromItem fromItem = plainSelect.getFromItem();
// 如果是普通的表名
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
plainSelect.setFromItem(handleTable(fromTable));
plainSelect.setWhere(processCondition(plainSelect.getWhere(), fromTable));
}
// 如果是子查询,比如 select * from (select xxx from yyy)
else if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) {
processSelect(subSelect.getSelectBody());
}
plainSelect.setWhere(processCondition(plainSelect.getWhere(), subSelect));
}
// 如果是带有特殊函数的子查询,比如 lateral (select sum(*) from yyy)
else if (fromItem instanceof SpecialSubSelect) {
SpecialSubSelect specialSubSelect = (SpecialSubSelect) fromItem;
if (specialSubSelect.getSubSelect() != null) {
SubSelect subSelect = specialSubSelect.getSubSelect();
if (subSelect.getSelectBody() != null) {
processSelect(subSelect.getSelectBody());
}
}
plainSelect.setWhere(processCondition(plainSelect.getWhere(), specialSubSelect));
}
// 未知类型的查询,直接报错
else {
log.error("无法解析的 from 语句:{}({})", fromItem, fromItem.getClass());
throw new RuntimeException("不支持的查询语句:" + fromItem.getClass().getName());
}
// 如果还存在关联查询
List<Join> joins = plainSelect.getJoins();
if (CollUtil.isNotEmpty(joins)) {
joins.forEach(this::processJoin);
}
}
/**
* 处理关联查询
*
* @param join 关联查询
*/
protected void processJoin(Join join) {
FromItem joinTable = join.getRightItem();
if (joinTable instanceof Table) {
Table table = (Table) joinTable;
join.setRightItem(handleTable((Table) joinTable));
join.setOnExpression(processCondition(join.getOnExpression(), table));
}
else if (joinTable instanceof SubSelect) {
processSelect(((SubSelect) joinTable).getSelectBody());
}
else if (joinTable instanceof SpecialSubSelect) {
SpecialSubSelect specialSubSelect = (SpecialSubSelect) joinTable;
if (specialSubSelect.getSubSelect() != null) {
SubSelect subSelect = specialSubSelect.getSubSelect();
if (subSelect.getSelectBody() != null) {
processSelect(subSelect.getSelectBody());
}
}
}
else {
log.error("无法解析的 join 语句:{}({})", joinTable, joinTable.getClass());
throw new RuntimeException("不支持的查询语句:" + joinTable.getClass().getName());
}
}
/**
* <p>获取添加了租户条件的查询条件,若条件中存在子查询,则也会为子查询添加租户条件。
*
* @param expression 条件表达式
* @param table 表
* @return 添加租户条件后的条件表达式
*/
@SuppressWarnings({"java:S6541", "java:S3776"})
protected Expression processCondition(@Nullable Expression expression, FromItem table) {
// 如果已经不可拆分的表达式,则直接返回
if (isBasicExpression(expression)) {
return expression;
}
// 如果是子查询,则需要对子查询进行递归处理
else if (expression instanceof SubSelect) {
processSelect(((SubSelect) expression).getSelectBody());
}
// 如果是 in 条件,比如:xxx in (select xx from yy……),则需要对子查询进行递归处理
else if (expression instanceof InExpression) {
InExpression inExp = (InExpression) expression;
ItemsList rightItems = inExp.getRightItemsList();
if (rightItems instanceof SubSelect) {
processSelect(((SubSelect) rightItems).getSelectBody());
}
}
// 如果是 not 或者 != 条件,则需要对里面的条件进行递归处理
else if (expression instanceof NotExpression) {
NotExpression notExpression = (NotExpression) expression;
processCondition(notExpression.getExpression(), table);
}
// 如果是 (xxx != xxx),则需要对括号里面的表达式进行递归处理
else if (expression instanceof Parenthesis) {
Parenthesis parenthesis = (Parenthesis) expression;
Expression content = parenthesis.getExpression();
processCondition(content, table);
}
// 如果是二元表达式,比如:xx = xx,xx > xx,则需要对左右两边的表达式进行递归处理
else if (expression instanceof BinaryExpression) {
BinaryExpression binaryExpression = (BinaryExpression) expression;
Expression left = binaryExpression.getLeftExpression();
processCondition(left, table);
Expression right = binaryExpression.getRightExpression();
processCondition(right, table);
}
// 如果是函数,比如:if(xx, xx) ,则需要对函数的参数进行递归处理
else if (expression instanceof Function) {
Function function = (Function) expression;
ExpressionList parameters = function.getParameters();
if (parameters != null) {
parameters.getExpressions().forEach(param -> processCondition(param, table));
}
}
// 如果是 case when 语句,则需要对 when 和 then 两个条件进行递归处理
else if (expression instanceof WhenClause) {
WhenClause whenClause = (WhenClause) expression;
processCondition(whenClause.getWhenExpression(), table);
processCondition(whenClause.getThenExpression(), table);
}
// 如果是 case 语句,则需要对 switch、when、then、else 四个条件进行递归处理
else if (expression instanceof CaseExpression) {
CaseExpression caseExpression = (CaseExpression) expression;
processCondition(caseExpression.getSwitchExpression(), table);
List<WhenClause> whenClauses = caseExpression.getWhenClauses();
if (CollUtil.isNotEmpty(whenClauses)) {
whenClauses.forEach(whenClause -> {
processCondition(whenClause.getWhenExpression(), table);
processCondition(whenClause.getThenExpression(), table);
});
}
processCondition(caseExpression.getElseExpression(), table);
}
// 如果是 exists 语句,比如:exists (select xx from yy……),则需要对子查询进行递归处理
else if (expression instanceof ExistsExpression) {
Expression existsExpression = ((ExistsExpression) expression).getRightExpression();
if (existsExpression instanceof SubSelect) {
processSelect(((SubSelect) existsExpression).getSelectBody());
}
}
// 如果是 all 或者 any 语句,比如:xx > all (select xx from yy……),则需要对子查询进行递归处理
else if (expression instanceof AllComparisonExpression) {
AllComparisonExpression allComparisonExpression = (AllComparisonExpression) expression;
processSelect(allComparisonExpression.getSubSelect().getSelectBody());
}
else if (expression instanceof AnyComparisonExpression) {
AnyComparisonExpression anyComparisonExpression = (AnyComparisonExpression) expression;
processSelect(anyComparisonExpression.getSubSelect().getSelectBody());
}
// 如果是 cast 语句,比如:cast(xx as xx),则需要对子查询进行递归处理
else if (expression instanceof CastExpression) {
CastExpression castExpression = (CastExpression) expression;
processCondition(castExpression.getLeftExpression(), table);
}
// 拼接查询条件
Expression appendCondition = handleCondition(expression, table);
return Objects.isNull(appendCondition) ? expression : appendCondition;
}
/**
* 返回一个查询条件,该查询条件将替换{@code table}原有的{@code where}条件
*
* @param expression 原有的查询条件
* @param table 指定的表
* @return 查询条件
*/
protected abstract Expression handleCondition(@Nullable Expression expression, FromItem table);
/**
* 返回一个表名,该表名将替换原有的表名
*
* @param table 表名
* @return 处理后的表名
*/
protected FromItem handleTable(Table table) {
return table;
}
/**
* 判断是否是已经是无法再拆分的基本表达式 <br/>
* 比如:列名、常量、函数等
*
* @param expression 表达式
* @return 是否是基本表达式
*/
protected boolean isBasicExpression(@Nullable Expression expression) {
return expression instanceof Column
|| expression instanceof LongValue
|| expression instanceof StringValue
|| expression instanceof DoubleValue
|| expression instanceof NullValue
|| expression instanceof TimeValue
|| expression instanceof TimestampValue
|| expression instanceof DateValue;
}
}
�接着,对于原本的 SQL 拦截器,我们令其继承 AbstractSqlHandler
,然后更换一个更合适的名字 LineLevelTenantSqlHandler
:
/**
* SQL拦截器,用于为SQL语句添加租户条件。
* 每次执行SQL时,将会检查当前线程上下文中是否存在租户信息,如果存在,则会为查询语句添加租户条件,否则直接略过。
*
* @author huangchengxing
* @see ContextTenantConditionSqlHandlerAdvisor
*/
@Slf4j
public class LineLevelTenantSqlHandler extends AbstractConditionSqlHandler {
private static final ThreadLocal<TenantInfo> TENANT_INFO_CONTEXT = new TransmittableThreadLocal<>();
/**
* 设置租户信息
*
* @param tenantInfo 租户信息
*/
public static void setTenantInfo(TenantInfo tenantInfo) {
TENANT_INFO_CONTEXT.set(tenantInfo);
}
/**
* 清除租户信息
*/
public static void clearTenantInfo() {
TENANT_INFO_CONTEXT.remove();
}
@Override
public String handle(String sql) {
// 如果未设置租户信息,则直接返回原始SQL
TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
if (Objects.isNull(tenantInfo)) {
return sql;
}
log.debug("租户拦截器拦截原始 SQL: {}", sql);
String handledSql = super.handle(sql);
log.info("租户拦截器拦截后 SQL: {}", handledSql);
return Objects.isNull(handledSql) ? sql : handledSql;
}
@Override
@Nullable
protected Expression handleCondition(@Nullable Expression expression, FromItem table) {
TenantInfo tenantInfo = TENANT_INFO_CONTEXT.get();
// 如果是一个标准表名,且改表名在租户表列表中,则为查询条件添加租户条件
if (!(table instanceof Table)) {
return null;
}
String tenantColumn = tenantInfo.tablesWithTenantColumn.get(((Table) table).getName());
if (Objects.nonNull(tenantColumn)) {
return appendTenantCondition(expression, table, tenantInfo.tenantId, tenantColumn);
}
return null;
}
private static Expression appendTenantCondition(
@Nullable Expression original, FromItem table, String tenantId, String tenantColumn) {
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(getColumnWithTableAlias(table, tenantColumn));
equalsTo.setRightExpression(new StringValue(tenantId));
if (Objects.isNull(original)) {
return equalsTo;
}
return original instanceof OrExpression ?
new AndExpression(equalsTo, new Parenthesis(original)) :
new AndExpression(original, equalsTo);
}
private static Column getColumnWithTableAlias(FromItem table, String column) {
// 如果表存在别名,则字段应该变“表别名.字段名”的格式
return Optional.ofNullable(table)
.map(FromItem::getAlias)
.map(alias -> alias.getName() + "." + column)
.map(Column::new)
.orElse(new Column(column));
}
/**
* 租户信息
*/
@RequiredArgsConstructor
public static class TenantInfo {
/**
* 租户ID
*/
private final String tenantId;
/**
* 要添加租户条件的表名称与对应的租户字段
*/
private final Map<String, String> tablesWithTenantColumn;
}
}
1.5.与 JPA 结合使用#
JPA 的默认实现 Hibernate 提供了 StatementInspector
接口,我们实现一个自定义的实现类,然后让基础上文实现好的租户解析器即可 LineLevelTenantSqlHandler
:
/**
* SQL拦截器,用于为SQL语句添加租户条件。
* 每次执行SQL时,将会检查当前线程中是否存在租户信息,如果存在,则会为查询语句添加租户条件,否则直接略过。
*
* @author huangchengxing
*/
@Slf4j
@RequiredArgsConstructor
public class HibernateLineLevelTenantStatementInspector
extends LineLevelTenantSqlHandler implements StatementInspector {
@Override
public String inspect(String sql) {
return handle(sql);
}
}
同理,我们也可以结合 Mybatis 或其他的框架实现类似的效果。
2.租户拦截器#
显然,我们不可能无条件的拦截所有的查询,有些查询本身不需要进行拦截,而有些查询当访问者为管理员时也不需要拦截……总而言之,对应租户拦截,我们需要采用白名单而不是黑名单的方式,因此最好的实现方法就是搞一个切面,然后只对带有特定注解的方法的调用进行拦截。
2.1.注解类#
我们定义一个 @TenantOperation
注解,该注解可以被用于方法或者类上,当用于类上的时候等于类中所有的方法都应用拦截:
/**
* 表明方法是一个租户操作方法,需要在相关的SQL中加入租户过滤条件
*
* @author huangchengxing
* @see ContextTenantConditionSqlHandlerAdvisor
*/
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.TYPE})
public @interface TenantOperation {
/**
* 表配置
*
* @return 表配置
*/
Tables[] value() default {};
/**
* 是否对当前方法与后续调用链不进行租户拦截
*
* @return boolean
* @see Ignore
*/
boolean ignore() default false;
/**
* 对当前方法与后续调用链不进行租户拦截
*/
@TenantOperation(ignore = true) // 基于 Springt 合成注解机制的扩展注解
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.ANNOTATION_TYPE})
@interface Ignore {}
/**
* 表配置
*/
@Documented
@Retention(RetentionPolicy.RUNTIME)
@interface Tables {
/**
* 租户字段名,不指定时默认遵循配置文件中的字段名
*
* @return String
*/
String column() default "";
/**
* 需要添加过滤条件的表名,不指定时默认遵循配置文件中的表名
*
* @return String
*/
String[] tables() default {};
}
}
此外,为了便于使用,注解还支持直接指定要拦截的表和字段,以便覆盖默认配置文件中的配置。
2.2.方法拦截器#
为了便于后续扩展,这里笔者没有基于 Aspect 注解,而是基于 Spring 的方法拦截器,自定义了切点来实现这个效果:
/**
* 方法拦截器,用于拦截带有{@link TenantOperation}注解的方法,为涉及的查询语句添加租户过滤条件
*
* @author huangchengxing
* @see LineLevelTenantOperationAdvisor
*/
@Slf4j
public class LineLevelTenantOperationAdvisor implements PointcutAdvisor, MethodInterceptor {
private static final String INTERCEPT_REQUEST_ENTRY = "tenant";
private static final TenantOpsInfo NULL = new TenantOpsInfo(null);
private final Map<Method, TenantOpsInfo> tenantInfoCaches = new ConcurrentReferenceHashMap<>();
private final TenantOpsInfo opsByDefault;
public LineLevelTenantOperationAdvisor(Map<String, String> tableWithColumns) {
this.opsByDefault = new TenantOpsInfo(tableWithColumns);
}
@Override
public Object invoke(MethodInvocation methodInvocation) throws Throwable {
// 从上下文获取租户ID
String tenantId = Optional.ofNullable(RequestUserContext.getUser())
.map(RequestUserContext.User::getUserId)
.orElse(null);
// 若没有上下文信息,则直接放行
if (Objects.isNull(tenantId)) {
return methodInvocation.proceed();
}
// 解析配置信息
TenantOpsInfo info = resolveMethod(methodInvocation.getMethod());
if (info == NULL) {
return methodInvocation.proceed();
}
// 设置租户信息
try {
LineLevelTenantSqlHandler.setTenantInfo(info.getTenantInfo(tenantId));
return methodInvocation.proceed();
} finally {
LineLevelTenantSqlHandler.clearTenantInfo();
}
}
private TenantOpsInfo resolveMethod(Method method) {
return tenantInfoCaches.computeIfAbsent(method, m -> {
// 从方法上或类上获取注解
TenantOperation annotation = Optional.ofNullable(AnnotatedElementUtils.findMergedAnnotation(method, TenantOperation.class))
.orElse(AnnotatedElementUtils.findMergedAnnotation(method.getDeclaringClass(), TenantOperation.class));
if (Objects.isNull(annotation)) {
return NULL;
}
// 若注解未指定column和tables,则使用默认值
TenantOperation.Tables[] tables = annotation.value();
if (ArrayUtil.isEmpty(tables)) {
return opsByDefault;
}
// 若指定了column和tables,则使用指定值
Map<String, String> tableWithColumns = new HashMap<>(tables.length);
for (TenantOperation.Tables table : tables) {
String column = table.column();
for (String tableName : table.tables()) {
tableWithColumns.put(tableName, column);
}
}
return new TenantOpsInfo(tableWithColumns);
});
}
@RequiredArgsConstructor
private static class TenantOpsInfo {
private final Map<String, String> tablesWithTenantColumn;
public LineLevelTenantSqlHandler.TenantInfo getTenantInfo(String tenantId) {
return new LineLevelTenantSqlHandler.TenantInfo(tenantId, tablesWithTenantColumn);
}
}
@Override
public @NonNull Pointcut getPointcut() {
return TenantOperationPointcut.INSTANCE;
}
@Override
public @NonNull Advice getAdvice() {
return this;
}
@Override
public boolean isPerInstance() {
return false;
}
// 自定义切点,拦截带有 @TenantOperation 注解的方法,或声明类上带有 @TenantOperation 注解的全部方法
private static class TenantOperationPointcut extends StaticMethodMatcher implements Pointcut {
public static final TenantOperationPointcut INSTANCE = new TenantOperationPointcut();
@Override
public @NonNull ClassFilter getClassFilter() {
return ClassFilter.TRUE;
}
@Override
public @NonNull MethodMatcher getMethodMatcher() {
return this;
}
@Override
public boolean matches(@NonNull Method method, @NonNull Class<?> type) {
return AnnotatedElementUtils.isAnnotated(method, TenantOperation.class)
|| AnnotatedElementUtils.isAnnotated(type, TenantOperation.class);
}
}
}
2.3.上下文传递问题#
如上文,我们选择使用方法拦截器在方法执行前设置租户信息,在方法执行后清空租户信息,这种做法在当同一条调用链上,同上触发了多次拦截时就会出现问题:
// 设置租户信息
try {
LineLevelTenantSqlHandler.setTenantInfo(info.getTenantInfo(tenantId));
return methodInvocation.proceed();
} finally {
LineLevelTenantSqlHandler.clearTenantInfo();
}
举个例子,假如我们存在如下的调用:
@TenantOperation(
@Tables(column = "userId")
)
public void method1() {
// do something
method2();
method3()
// do something
}
@TenantOperation(ignore = true) // 该方法不需要进行租户拦截
public void method1() {
// do something
}
@TenantOperation(
@Tables(column = "tenantId")
)
public void method1() {
// do something
}
如果我们假设每个方法都能被正确的拦截,那么按原有的代码,当执行了 method2 以后,由于直接清空了上下文,最终会导致后续的调用都没有办法正确获取到租户信息。同理,
这个问题与 Spring 的事务传播有点异曲同工,我们的解决方案也类似,那就引入“挂起”这个概念。简单的来说,如果有一个被拦截的方法触发了上下文租户信息的更新,纳那么:
- 如果上下文已经存在租户信息,说明当前方法只是调用链中的一个环节,那么就需要先将其挂起,先放入当前方法配置的租户信息,等到执行结束后,再将旧的租户信息放回上下文;
- 如果上下文中没有存在租户信息,说明当前方法已经是调用链的源头,那么当执行完毕后,可以直接请上下文清空。
对此,我们参照 Spring 的做法,稍微调整一下这部分代码即可:
// 暂时挂起上一层级方法设置的租户信息
LineLevelTenantSqlHandler.TenantInfo previous = LineLevelTenantSqlHandler.getTenantInfo();
// 若当前方法设置了忽略租户信息,则清空上下文,否则设置当前租户信息
if (info.isIgnore()) {
LineLevelTenantSqlHandler.clearTenantInfo();
} else {
LineLevelTenantSqlHandler.TenantInfo current = new LineLevelTenantSqlHandler.TenantInfo(tenantId, info.getTablesWithTenantColumn());
LineLevelTenantSqlHandler.setTenantInfo(current);
}
try {
return methodInvocation.proceed();
} finally {
// 恢复挂起的租户信息
if (Objects.nonNull(previous)) {
LineLevelTenantSqlHandler.setTenantInfo(previous);
}
// 若之前没有租户信息,则清空上下文
else {
LineLevelTenantSqlHandler.clearTenantInfo();
}
}
3.使用#
3.1.配置类#
首先,我们先定义一个配置类以在项目中启用上述组件:
/**
* <p>租户拦截器配置,启用后可以为指定的查询方法添加租户过滤条件。 <br/>
* 可通过配置文件进行配置:<br/>
* <pre>
* # JPA 启用租户 SQL 拦截器
* spring.jpa.properties.hibernate.session_factory.statement_inspector=io.github.createsequence.wheel.spring.tenant.HibernateTenantStatementInspector
* # 启用租户拦截器
* tenant.interceptor.enabled=true
* # 需要拦截的表
* tenant.interceptor.tables[0].column = tenant_id
* tenant.interceptor.tables[0].tableNames = table1, table2
* </pre>
*
* @author huangchengxing
*/
@Slf4j
@ConditionalOnProperty(prefix = TenantInterceptorConfig.Properties.CONFIG_PREFIX, name = "enabled", havingValue = "true")
@EnableConfigurationProperties(TenantInterceptorConfig.Properties.class)
@Configuration
public class TenantInterceptorConfig {
@Bean
public LineLevelTenantOperationAdvisor lineLevelTenantOperationAdvisor(Properties properties) {
log.info("启用租户拦截器,需要拦截的表:{}", properties.getTables());
Map<String, String> tableWithColumns = new HashMap<>(16);
properties.getTables().forEach(ts -> ts.getTableNames().forEach(t -> {
Assert.isFalse(tableWithColumns.containsKey(t), "同一张表具备只允许具备一个租户字段:{}", t);
tableWithColumns.put(t, ts.getColumn());
}));
return new LineLevelTenantOperationAdvisor(tableWithColumns);
}
/**
* @author huangchengxing
*/
@ConfigurationProperties(prefix = Properties.CONFIG_PREFIX)
@Data
public static class Properties {
public static final String CONFIG_PREFIX = "tenant.interceptor";
/**
* 表配置
*/
private List<Tables> tables = new ArrayList<>();
@Data
public static class Tables {
/**
* 租户字段名
*/
private String column;
/**
* 需要拦截的表名
*/
private Set<String> tableNames;
}
}
}
3.2.配置文件#
随后在配置文件中启用拦截器,并配置好要拦截的表:
# 启用租户 SQL 拦截器
spring.jpa.properties.hibernate.session_factory.statement_inspector=io.github.createsequence.wheel.spring.tenant.HibernateTenantStatementInspector
# 启用租户拦截器
tenant.interceptor.enabled=true
# 拦截 t_user, t_resource, t_assest 表中的 tenant_id 字段
tenant.interceptor.tables[0].column=tenant_id
tenant.interceptor.tables[0].table-names=t_user, t_resource, t_assest
3.3.添加注解#
最后,我们只要在对应的类或者方法上添加 @TenantOperation
即可:
@TenantOperation // 默认所有方法都要应用拦截
@RestController
public class ResourceController {
// @TenantOperation 因为类上已经加了,所以方法上可以不用加
@GetMapping
public List<Resource> listResource1(List<Integer> ids) {
// do something
}
@TenantOperation.Ingore // 该方法不进行拦截
@GetMapping
public List<Resource> listResource2(List<Integer> ids) {
// do something
}
}
作者:Createsequence
出处:https://www.cnblogs.com/Createsequence/p/18093433
版权:本作品采用「署名-非商业性使用-相同方式共享 4.0 国际」许可协议进行许可。
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步
· DeepSeek 开源周回顾「GitHub 热点速览」
· 物流快递公司核心技术能力-地址解析分单基础技术分享
· .NET 10首个预览版发布:重大改进与新特性概览!
· AI与.NET技术实操系列(二):开始使用ML.NET
· 单线程的Redis速度为什么快?