MybatisPlus通过插件自定义逻辑处理SQL实现数据隔离

1.应用场景:

1.最近公司在做数据隔离:在原有系统组织架构的基础上进行扩展,要求每个组织只能看到自己的数据。由于所有组织共用同一套数据库,因此在每张表中都增加了 dept_id 字段来绑定所属组织。如果在每个业务中单独修改 SQL 来实现数据隔离,改动量太大,于是采用 MyBatis 插件(拦截器)方案,在 SQL 执行前自动补充数据隔离字段。

2.其它应用场景:SQL 打印、SQL 拦截、SQL 分页等。MyBatis-Plus 自带了相关插件,详情可查看官网:

https://baomidou.com/pages/2976a3/

2.自定义插件实现

pom依赖:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.1.3.RELEASE</version>
        <relativePath/> <!-- lookup parent from repository -->
    </parent>
    <groupId>cn.tellsea</groupId>
    <artifactId>springboot-mybatis-plus</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>springboot-mybatis-plus</name>
    <description>Demo project for Spring Boot</description>

    <properties>
        <java.version>1.8</java.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus-boot-starter</artifactId>
            <version>3.3.1</version>
        </dependency>
        <dependency>
            <groupId>com.baomidou</groupId>
            <artifactId>mybatis-plus</artifactId>
            <version>3.3.1</version>
        </dependency>
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
        </dependency>

        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.7</version>
        </dependency>

        <!-- druid数据源 -->
        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>druid-spring-boot-starter</artifactId>
            <version>1.1.14</version>
        </dependency>
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>1.2.17</version>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
        </plugins>
    </build>

</project>

yml配置:

spring:
  datasource:
    type: com.alibaba.druid.pool.DruidDataSource
    driver-class-name: com.mysql.jdbc.Driver
    platform: mysql
    url: jdbc:mysql://127.0.0.1:3306/test?useUnicode=true&characterEncoding=UTF-8&serverTimezone=UTC
    username: root
    password: hekang
    initialSize: 5
    minIdle: 5
    maxActive: 20
    maxWait: 60000
    timeBetweenEvictionRunsMillis: 60000
    minEvictableIdleTimeMillis: 300000
    validationQuery: SELECT 1 FROM DUAL
    testWhileIdle: true
    testOnBorrow: false
    testOnReturn: false
    filters: stat,wall,log4j
    logSlowSql: true

mybatis-plus:
  configuration:
    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl

# 日志
logging:
  level:
    cn.tellsea.skeleton.business.mapper: debug

项目结构

插件实现:

package cn.tellsea.config.impl;

import cn.tellsea.config.ExecutorPluginUtils;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.select.FromItem;
import net.sf.jsqlparser.statement.select.PlainSelect;
import net.sf.jsqlparser.statement.select.Select;
import net.sf.jsqlparser.statement.update.Update;
import org.apache.ibatis.cache.CacheKey;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

/**
 * MyBatis executor interceptor that adds a data-isolation column (default {@code dept_id}) to
 * the SQL of mappers annotated with {@code @PlatformTag}.
 *
 * <p>Methods of a tagged mapper that carry {@code @PlatformTagIngore} are skipped. For the other
 * statements the SQL is parsed with JSqlParser and rewritten as follows:
 * <ul>
 *   <li>UPDATE: {@code <column> = <value>} is appended to the SET list;</li>
 *   <li>INSERT with an explicit column list and a single VALUES row: the column and value are
 *       appended, unless the column is already listed;</li>
 *   <li>plain (non-UNION) SELECT: {@code <column> = <value>} is AND-ed in front of the existing
 *       WHERE clause.</li>
 * </ul>
 * Any other statement, and every statement of an untagged mapper, runs unchanged.
 *
 * <p>NOTE(review): UPDATE and DELETE are not restricted by the isolation column in their WHERE
 * clause, so rows of other departments can still be modified or deleted.
 *
 * <p>Options read by {@link #setProperties(Properties)}: {@code dataColumnName} (default
 * {@code dept_id}) and {@code dataColumnValue} (an SQL literal, default {@code 1}).
 */
@Slf4j
@Intercepts({
        @Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
        @Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class})
})
public class PlatformInterceptor implements Interceptor {

    /** Default name of the isolation column. */
    private static final String DATA_COLUMN_NAME = "dept_id";

    /** Default isolation value, written into the SQL as a literal. */
    private static final String DEFAULT_DATA_COLUMN_VALUE = "1";

    /** Isolation column name; configurable through {@link #setProperties(Properties)}. */
    private String dataColumnName = DATA_COLUMN_NAME;

    /** Isolation value as an SQL literal; configurable through {@link #setProperties(Properties)}. */
    private String dataColumnValue = DEFAULT_DATA_COLUMN_VALUE;

    /**
     * Rewrites the SQL of a tagged statement and then proceeds with the executor call.
     *
     * @param invocation intercepted executor call
     * @return the executor's result
     * @throws Throwable a JSqlParser parse failure for a tagged statement, or whatever the
     *                   executor throws
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        MappedStatement mappedStatement = (MappedStatement) invocation.getArgs()[0];
        // Untagged mappers are passed straight through: no parsing, no MappedStatement copy.
        if (!isTagged(mappedStatement)) {
            return invocation.proceed();
        }

        // 获取sql
        String processSql = ExecutorPluginUtils.getSqlByInvocation(invocation);
        log.debug("Sql替换前:{}", processSql);
        // A parse failure of a tagged statement propagates to the caller (as before).
        Statement statement = CCJSqlParserUtil.parse(processSql);

        String sql2Reset = processSql;
        try {
            sql2Reset = rewrite(statement, processSql);
        } catch (Exception e) {
            // Best effort, as before: an unexpected rewrite error runs the original SQL.
            log.error("数据隔离SQL改写失败,使用原SQL执行:{}", processSql, e);
        }
        log.debug("sql替换后:{}", sql2Reset);

        // Only swap the MappedStatement when the SQL actually changed.
        if (!sql2Reset.equals(processSql)) {
            ExecutorPluginUtils.resetSql2Invocation(invocation, sql2Reset);
        }
        return invocation.proceed();
    }

    /** Returns whether the statement belongs to a tagged mapper and is not explicitly ignored. */
    private static boolean isTagged(MappedStatement mappedStatement) {
        try {
            return ExecutorPluginUtils.isPlatFormTag(mappedStatement);
        } catch (ClassNotFoundException e) {
            // The statement id does not name a loadable mapper class, so it cannot carry @PlatformTag.
            log.debug("未找到Mapper类,跳过数据隔离:{}", mappedStatement.getId());
            return false;
        }
    }

    /** Dispatches to the statement-specific rewrite; returns {@code originalSql} when none applies. */
    private String rewrite(Statement statement, String originalSql) throws Exception {
        if (statement instanceof Update) {
            return rewriteUpdate((Update) statement, originalSql);
        }
        if (statement instanceof Insert) {
            return rewriteInsert((Insert) statement, originalSql);
        }
        if (statement instanceof Select) {
            return rewriteSelect((Select) statement, originalSql);
        }
        return originalSql;
    }

    /**
     * Appends {@code column = value} to the SET list of an UPDATE.
     *
     * <p>NOTE(review): if the column is already assigned it ends up assigned twice; confirm the
     * target database accepts that (MySQL presumably lets the last assignment win).
     */
    private String rewriteUpdate(Update update, String originalSql) throws Exception {
        Table table = update.getTable();
        if (table == null) {
            return originalSql;
        }
        List<Column> columns = update.getColumns();
        List<Expression> expressions = update.getExpressions();
        columns.add(new Column(dataColumnName));
        // A literal (not a '?') keeps the BoundSql parameter mappings aligned.
        expressions.add(CCJSqlParserUtil.parseExpression(dataColumnValue));
        update.setColumns(columns);
        update.setExpressions(expressions);
        return update.toString();
    }

    /** Appends the isolation column and value to a single-row INSERT that does not list it yet. */
    private String rewriteInsert(Insert insert, String originalSql) throws Exception {
        List<Column> columns = insert.getColumns();
        if (columns == null || !(insert.getItemsList() instanceof ExpressionList)) {
            // No explicit column list, multi-row VALUES or INSERT ... SELECT: not supported.
            log.warn("数据隔离:不支持的INSERT语句,未改写:{}", originalSql);
            return originalSql;
        }
        if (containsDataColumn(columns)) {
            // The caller already supplies the column. Adding it again would name it twice, which
            // MySQL rejects, and replacing its '?' would shift the parameter mappings.
            return originalSql;
        }
        ExpressionList itemsList = (ExpressionList) insert.getItemsList();
        List<Expression> values = new ArrayList<>(itemsList.getExpressions());
        values.add(CCJSqlParserUtil.parseExpression(dataColumnValue));
        itemsList.setExpressions(values);
        columns.add(new Column(dataColumnName));
        insert.setItemsList(itemsList);
        insert.setColumns(columns);
        return insert.toString();
    }

    /** AND-s {@code column = value} in front of the WHERE clause of a plain SELECT. */
    private String rewriteSelect(Select select, String originalSql) throws Exception {
        if (!(select.getSelectBody() instanceof PlainSelect)) {
            log.warn("数据隔离:不支持的SELECT语句(如UNION),未改写:{}", originalSql);
            return originalSql;
        }
        PlainSelect plain = (PlainSelect) select.getSelectBody();
        FromItem fromItem = plain.getFromItem();
        if (fromItem == null) {
            return originalSql;
        }
        //增加sql语句的逻辑部分处理
        StringBuilder whereSql = new StringBuilder()
                .append(columnQualifier(plain, fromItem))
                .append(dataColumnName)
                .append(" = ")
                .append(dataColumnValue);
        Expression where = plain.getWhere();
        if (where != null) {
            // Parenthesize the existing predicate so an OR inside it cannot bypass the filter.
            whereSql.append(" and ( ").append(where.toString()).append(" )");
        }
        plain.setWhere(CCJSqlParserUtil.parseCondExpression(whereSql.toString()));
        return select.toString();
    }

    /**
     * Returns the prefix used to qualify the isolation column: the alias of the FROM item if it
     * has one, otherwise the table name when joins would make a bare column ambiguous, otherwise "".
     */
    private static String columnQualifier(PlainSelect plain, FromItem fromItem) {
        if (fromItem.getAlias() != null) {
            return fromItem.getAlias().getName() + ".";
        }
        if (fromItem instanceof Table && plain.getJoins() != null && !plain.getJoins().isEmpty()) {
            return ((Table) fromItem).getFullyQualifiedName() + ".";
        }
        return "";
    }

    /** Case-insensitive, quote-insensitive check whether the isolation column is already listed. */
    private boolean containsDataColumn(List<Column> columns) {
        for (Column column : columns) {
            String name = column.getColumnName();
            if (name != null && name.replace("`", "").replace("\"", "").equalsIgnoreCase(dataColumnName)) {
                return true;
            }
        }
        return false;
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    /**
     * Reads the optional {@code dataColumnName} and {@code dataColumnValue} settings. Missing keys
     * keep the defaults ({@code dept_id} and {@code 1}).
     *
     * @param properties plugin properties; {@code null} keeps the current settings
     */
    @Override
    public void setProperties(Properties properties) {
        if (properties == null) {
            return;
        }
        dataColumnName = properties.getProperty("dataColumnName", DATA_COLUMN_NAME);
        dataColumnValue = properties.getProperty("dataColumnValue", DEFAULT_DATA_COLUMN_VALUE);
    }
}

工具类实现注解判断及相关插件处理

import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.mapping.SqlSource;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.factory.DefaultObjectFactory;
import org.apache.ibatis.reflection.wrapper.DefaultObjectWrapperFactory;

import java.lang.reflect.Method;
import java.sql.SQLException;

/**
 * TODO
 *
 * @author tengwang8
 * @version 1.0
 * @date 2022/3/2 9:07
 */
public class ExecutorPluginUtils {

    /**
     * 获取sql语句
     * @param invocation
     * @return
     */
    public static String getSqlByInvocation(Invocation invocation) {
        final Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        Object parameterObject = args[1];
        BoundSql boundSql = ms.getBoundSql(parameterObject);
        return boundSql.getSql();
    }

    /**
     * 包装sql后,重置到invocation中
     * @param invocation
     * @param sql
     * @throws SQLException
     */
    public static void resetSql2Invocation(Invocation invocation, String sql) throws SQLException {
        final Object[] args = invocation.getArgs();
        MappedStatement statement = (MappedStatement) args[0];
        Object parameterObject = args[1];
        BoundSql boundSql = statement.getBoundSql(parameterObject);
        MappedStatement newStatement = newMappedStatement(statement, new BoundSqlSqlSource(boundSql));
        MetaObject msObject =  MetaObject.forObject(newStatement, new DefaultObjectFactory(), new DefaultObjectWrapperFactory(),new DefaultReflectorFactory());
        msObject.setValue("sqlSource.boundSql.sql", sql);
        args[0] = newStatement;
    }


    private static MappedStatement newMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
        MappedStatement.Builder builder =
                new MappedStatement.Builder(ms.getConfiguration(), ms.getId(), newSqlSource, ms.getSqlCommandType());
        builder.resource(ms.getResource());
        builder.fetchSize(ms.getFetchSize());
        builder.statementType(ms.getStatementType());
        builder.keyGenerator(ms.getKeyGenerator());
        if (ms.getKeyProperties() != null && ms.getKeyProperties().length != 0) {
            StringBuilder keyProperties = new StringBuilder();
            for (String keyProperty : ms.getKeyProperties()) {
                keyProperties.append(keyProperty).append(",");
            }
            keyProperties.delete(keyProperties.length() - 1, keyProperties.length());
            builder.keyProperty(keyProperties.toString());
        }
        builder.timeout(ms.getTimeout());
        builder.parameterMap(ms.getParameterMap());
        builder.resultMaps(ms.getResultMaps());
        builder.resultSetType(ms.getResultSetType());
        builder.cache(ms.getCache());
        builder.flushCacheRequired(ms.isFlushCacheRequired());
        builder.useCache(ms.isUseCache());

        return builder.build();
    }

    /**
     * 是否标记为区域字段
     * @return
     */
    public static boolean isPlatFormTag( MappedStatement mappedStatement) throws ClassNotFoundException {
        String id = mappedStatement.getId();
        Class<?> classType = Class.forName(id.substring(0,mappedStatement.getId().lastIndexOf(".")));

        //获取对应拦截方法名
        String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1);

        boolean ignore = false;

        for(Method method : classType.getDeclaredMethods()){
            if(method.isAnnotationPresent(PlatformTagIngore.class) && mName.equals(method.getName()) ) {
                ignore = true;
            }
        }

        if (classType.isAnnotationPresent(PlatformTag.class) && !ignore) {
            return true;
        }
        return false;
    }


    /**
     * 是否标记为区域字段
     * @return
     */
    public static boolean isAreaTagIngore( MappedStatement mappedStatement) throws ClassNotFoundException {
        String id = mappedStatement.getId();
        Class<?> classType = Class.forName(id.substring(0,mappedStatement.getId().lastIndexOf(".")));
        //获取对应拦截方法名
        String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1);
        boolean ignore = false;
        for(Method method : classType.getDeclaredMethods()){
            if(method.isAnnotationPresent(PlatformTagIngore.class) && mName.equals(method.getName()) ) {
                ignore = true;
            }
        }
        return ignore;
    }


    public static String getOperateType(Invocation invocation) {
        final Object[] args = invocation.getArgs();
        MappedStatement ms = (MappedStatement) args[0];
        SqlCommandType commondType = ms.getSqlCommandType();
        if (commondType.compareTo(SqlCommandType.SELECT) == 0) {
            return "select";
        }
        if (commondType.compareTo(SqlCommandType.INSERT) == 0) {
            return "insert";
        }
        if (commondType.compareTo(SqlCommandType.UPDATE) == 0) {
            return "update";
        }
        if (commondType.compareTo(SqlCommandType.DELETE) == 0) {
            return "delete";
        }
        return null;
    }
    //    定义一个内部辅助类,作用是包装sq
    static class BoundSqlSqlSource implements SqlSource {
        private BoundSql boundSql;
        public BoundSqlSqlSource(BoundSql boundSql) {
            this.boundSql = boundSql;
        }
        @Override
        public BoundSql getBoundSql(Object parameterObject) {
            return boundSql;
        }
    }



}


注入拦截器到拦截器链

package cn.tellsea.config;

import cn.tellsea.config.impl.PlatformInterceptor;
import com.baomidou.mybatisplus.autoconfigure.ConfigurationCustomizer;
import com.baomidou.mybatisplus.core.MybatisConfiguration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.Properties;

/**
 * Registers {@link PlatformInterceptor} with MyBatis through a MyBatis-Plus
 * {@link ConfigurationCustomizer}.
 */
@Configuration
public class MpConfig {

    /**
     * Adds the data-isolation interceptor to the MyBatis interceptor chain.
     *
     * <p>Interceptors are chained, so the order in which they are added determines the order in
     * which they run.
     *
     * @return customizer that registers {@link PlatformInterceptor}
     */
    @Bean
    public ConfigurationCustomizer configurationCustomizer() {
        // An implicitly typed lambda matches whatever parameter type customize() declares.
        // NOTE: mybatis-plus-boot-starter's customize() takes a MybatisConfiguration, so an
        // anonymous class declaring customize(org.apache.ibatis.session.Configuration) does not
        // override it.
        return configuration -> {
            PlatformInterceptor platformInterceptor = new PlatformInterceptor();
            // No options are set here, so the interceptor keeps its built-in defaults.
            platformInterceptor.setProperties(new Properties());
            configuration.addInterceptor(platformInterceptor);
        };
    }
}


Mapper标记注解

被 @PlatformTag 标记的 Mapper 会被拦截处理;其中不需要拦截的方法,可以用 @PlatformTagIngore 注解放行。

/**
 * Marks a mapper interface whose statements are subject to the data-isolation SQL rewrite.
 * Retained at runtime so it can be found by reflection.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface PlatformTag {
}


/**
 * Excludes a mapper method from the data-isolation SQL rewrite, even when its interface is
 * annotated with {@code @PlatformTag}. Retained at runtime so it can be found by reflection.
 * (The "Ingore" spelling is part of the public name and is kept as is.)
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface PlatformTagIngore {
}


实体Bean:

/**
 * Entity mapped to table {@code test_content}; accessors are generated by Lombok {@code @Data}.
 */
@Data
@TableName("test_content")
public class TestContent {
    /**
     * Primary key (MyBatis-Plus id type {@code IdType.ASSIGN_ID}).
     */
    @TableId(type = IdType.ASSIGN_ID)
    private Long id;


    /**
     * Data content.
     */
    private String content;

    /**
     * Department id, the newly added data-isolation column. Instead of adding it to every SQL by
     * hand, it is filled in (INSERT/UPDATE) or filtered on (SELECT) by the interceptor plugin.
     */
    private Integer deptId;
}

Mapper:

package cn.tellsea.mapper;

import cn.tellsea.config.PlatformTag;
import cn.tellsea.config.PlatformTagIngore;
import cn.tellsea.entity.TestContent;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;

/**
 * Mapper for {@link TestContent}. It is annotated with {@link PlatformTag}, so its statements are
 * rewritten by the data-isolation interceptor unless a method opts out.
 */
@PlatformTag
public interface TestContentMapper extends BaseMapper<TestContent> {

    /**
     * Excluded from the data-isolation rewrite via {@link PlatformTagIngore}.
     *
     * <p>NOTE(review): the exclusion check matches mapper methods by name only, so every statement
     * named {@code selectById} on this mapper is excluded. BaseMapper presumably already declares
     * {@code selectById(Serializable)}, in which case this Long overload mainly serves as a
     * carrier for the annotation; confirm.
     *
     * @param id primary key
     * @return the row with the given id
     */
    @PlatformTagIngore
    TestContent selectById(Long id);
}

其它代码:控制层(Controller)与服务实现层(Service)

/**
 * Demo endpoints that exercise the data-isolation interceptor through {@link TestContentService}.
 */
@Slf4j
@RestController
public class TestController {
    @Autowired
    private TestContentService testContentService;

    /**
     * Lists the rows with id 1, 2 or 3 (the interceptor adds the dept filter).
     *
     * @return matching rows
     */
    @GetMapping("/test1")
    public List<TestContent> query() {
        LambdaQueryWrapper<TestContent> wrapper = new LambdaQueryWrapper<>();
        wrapper.in(TestContent::getId, 1, 2, 3);
        return testContentService.getBaseMapper().selectList(wrapper);
    }

    /**
     * Inserts a row with random content.
     *
     * @return confirmation message
     */
    @GetMapping("/test2")
    public String add() {
        TestContent testContent = new TestContent();
        testContent.setContent(new Random().nextInt() + "自定义添加内容");
        // deptId is intentionally left unset: dept_id is filled in by PlatformInterceptor.
        testContent.setId(IdWorker.getId());
        testContentService.getBaseMapper().insert(testContent);
        log.info("插入成功:{}", testContent.getId());
        return "插入成功";
    }

    /**
     * Updates the content of row 1 (the interceptor adds the dept column to the SET list).
     *
     * @return confirmation message
     */
    @GetMapping("/test3")
    public String update() {
        TestContent testContent = new TestContent();
        testContent.setContent(new Random().nextInt() + "自定义修改内容");
        testContent.setId(1L);
        testContentService.getBaseMapper().updateById(testContent);
        log.info("修改成功:{}", testContent.getId());
        return "修改成功";
    }

    /**
     * Loads row 1 via selectById, which the mapper excludes from the rewrite.
     *
     * @return the row, or whatever selectById returns when it does not exist
     */
    @GetMapping("/test4")
    public TestContent queryIgnore() {
        TestContent testContent = testContentService.getBaseMapper().selectById(1L);
        // Pass the object itself: SLF4J renders null safely, whereas toString() would throw.
        log.info("查询:{}", testContent);
        return testContent;
    }
}


/**
 * Default {@link TestContentService} implementation; all behavior comes from MyBatis-Plus
 * {@code ServiceImpl} backed by {@link TestContentMapper}.
 */
@Service
public class TestContentServiceImpl extends ServiceImpl<TestContentMapper, TestContent> implements TestContentService  {
}

/**
 * Service for {@link TestContent}; declares no methods beyond those inherited from MyBatis-Plus
 * {@code IService}.
 */
public interface TestContentService extends IService<TestContent> {

}

测试结果:

查询时带上deptId

其它几个测试可自行验证:新增和修改主要是自动补充 dept_id 字段,查询主要是在 WHERE 条件中追加 dept_id 过滤条件。

posted @ 2022-06-22 17:40  胡小华  阅读(869)  评论(0)  编辑  收藏  举报