mybatis拦截器处理

1.自定义注解

package com.hsfw.backyard.biz.security.authority;

import java.lang.annotation.*;

/**
 * Marker annotation for methods whose queries take part in data-permission filtering.
 *
 * <p>It can only be placed on methods and is retained at runtime, so it can be
 * discovered through reflection (e.g. {@code Method#getAnnotation}).
 *
 * @author liucq (created 2018/12/14)
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
public @interface UserPermissionAop {

    /**
     * Optional free-form value attached to the marker; defaults to the empty string.
     *
     * @return the configured value, or {@code ""} when none was given
     */
    String value() default "";
}

2.Utils

package com.hsfw.backyard.biz.security.authority.util;

import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;

/**
 * Small reflection helpers: read/write a (possibly private, possibly inherited) field by name,
 * and invoke a no-argument method given a {@code "fully.qualified.ClassName.methodName"} string.
 *
 * <p>All helpers are best-effort: reflection failures are reported on stderr and turned into a
 * {@code null} result (or a no-op for the setter) instead of being propagated.
 */
public class ReflectUtil {
    /**
     * Reads the value of the named field from the given object via reflection.
     *
     * <p>The field is searched on the object's class and then on each superclass
     * (excluding {@code Object}); access checks are bypassed, so private fields are readable.
     *
     * @param obj       target object; {@code null} yields {@code null}
     * @param fieldName name of the field to read; {@code null} yields {@code null}
     * @return the field's value, or {@code null} if the field does not exist or cannot be read
     */
    public static Object getFieldValue(Object obj, String fieldName) {
        Field field = ReflectUtil.getField(obj, fieldName);
        if (field == null) {
            return null;
        }
        try {
            field.setAccessible(true);
            return field.get(obj);
        } catch (IllegalArgumentException e) {
            // Best-effort helper: report the failure and fall through to the null result.
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * Finds the named field on the object's class or the nearest superclass that declares it.
     *
     * @param obj       target object
     * @param fieldName name of the field to look up
     * @return the declaring {@link Field}, or {@code null} if either argument is {@code null}
     *         or no class in the hierarchy (below {@code Object}) declares the field
     */
    private static Field getField(Object obj, String fieldName) {
        // Guard here so both public accessors are null-safe instead of throwing NPE.
        if (obj == null || fieldName == null) {
            return null;
        }
        for (Class<?> clazz = obj.getClass(); clazz != Object.class; clazz = clazz.getSuperclass()) {
            try {
                return clazz.getDeclaredField(fieldName);
            } catch (NoSuchFieldException ignored) {
                // Not declared at this level; a superclass may still declare it.
            }
        }
        return null;
    }

    /**
     * Sets the named field of the given object to the given string value via reflection.
     *
     * <p>Does nothing when the object is {@code null} or the field cannot be found. If the
     * field's type is not assignable from {@code String}, the failure is reported on stderr
     * and the field is left unchanged.
     *
     * @param obj        target object
     * @param fieldName  name of the field to write
     * @param fieldValue value to store in the field
     */
    public static void setFieldValue(Object obj, String fieldName, String fieldValue) {
        Field field = ReflectUtil.getField(obj, fieldName);
        if (field == null) {
            return;
        }
        try {
            field.setAccessible(true);
            field.set(obj, fieldValue);
        } catch (IllegalArgumentException e) {
            // Best-effort helper: report the failure and leave the field untouched.
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
    }


    /**
     * Instantiates a class and invokes one of its public no-argument methods, both named by a
     * single {@code "fully.qualified.ClassName.methodName"} string.
     *
     * <p>The text before the last {@code '.'} is loaded as the class name and instantiated
     * through its public no-argument constructor; the text after it is the method name.
     *
     * @param path class name and method name joined by {@code '.'}
     * @return the value returned by the invoked method, or {@code null} if {@code path} is
     *         {@code null}, contains no {@code '.'}, or any reflection step fails
     * @author GaoYuan
     */
    public static Object reflectByPath(String path) {
        if (path == null) {
            return null;
        }
        int separator = path.lastIndexOf('.');
        if (separator < 0) {
            // No class/method separator: nothing to resolve.
            return null;
        }
        try {
            String className = path.substring(0, separator);
            String methodName = path.substring(separator + 1);

            Class<?> type = Class.forName(className);
            Constructor<?> constructor = type.getConstructor();
            Object instance = constructor.newInstance();

            // getMethod(name) with no parameter types resolves the public no-arg overload.
            Method method = type.getMethod(methodName);
            return method.invoke(instance);
        } catch (Exception e) {
            // Deliberately broad: this helper never throws, it reports and returns null.
            e.printStackTrace();
        }
        return null;
    }
}
package com.hsfw.backyard.biz.security.authority.util;//package com.foruo.sc.permission.example.util;

import com.hsfw.backyard.biz.security.authority.UserPermissionAop;
import org.apache.ibatis.mapping.MappedStatement;

import java.lang.reflect.Method;


/**
 * Resolves the {@link UserPermissionAop} annotation that belongs to a MyBatis statement.
 */
public class PermissionUtils {

    /**
     * Looks up the {@link UserPermissionAop} annotation for the given mapped statement.
     *
     * <p>The statement id is split at its last {@code '.'}: the left part is loaded as a class
     * name and the right part is matched against the names of that class's public methods.
     * If several same-named methods carry the annotation, the last one encountered wins.
     * Any failure (null statement, id without a dot, class not found, ...) is printed to
     * stderr and results in {@code null}.
     *
     * @param mappedStatement statement whose id has the form {@code <className>.<methodName>}
     * @return the annotation on the matching method, or {@code null} if there is none or the
     *         lookup failed
     */
    public static UserPermissionAop getPermissionByDelegate(MappedStatement mappedStatement) {
        UserPermissionAop found = null;
        try {
            final String statementId = mappedStatement.getId();
            final int lastDot = statementId.lastIndexOf(".");
            final String ownerName = statementId.substring(0, lastDot);
            final String targetName = statementId.substring(lastDot + 1);
            for (Method candidate : Class.forName(ownerName).getMethods()) {
                if (!candidate.getName().equals(targetName)) {
                    continue;
                }
                UserPermissionAop annotation = candidate.getAnnotation(UserPermissionAop.class);
                if (annotation != null) {
                    // No break on purpose: a later annotated overload replaces an earlier one.
                    found = annotation;
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
        return found;
    }
}

3.mybatis

package com.hsfw.backyard.biz.security.authority.mybatis;


import com.hsfw.backyard.biz.ContextHolder;
import com.hsfw.backyard.biz.model.sys.User;
import com.hsfw.backyard.biz.security.authority.UserPermissionAop;
import com.hsfw.backyard.biz.security.authority.util.PermissionUtils;
import com.hsfw.backyard.biz.security.authority.util.ReflectUtil;
import com.hsfw.backyard.dal.mapper.model.UserDataManageDO;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import java.sql.Connection;
import java.util.List;
import java.util.Properties;

/**
 * MyBatis plugin that rewrites the SQL of permission-annotated mapper methods just before the
 * statement is prepared, restricting rows to those operated by the current user or by the
 * users that user manages.
 */
@Intercepts({
        @Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})
})
@Component
public class PrepareInterceptor implements Interceptor {

    private static final Logger log = LoggerFactory.getLogger(PrepareInterceptor.class);

    /** Mapper method name whose SQL gets an extra {@code and ... in (...)} condition appended. */
    private static final String COUNT_METHOD = "countByCondition";

    /** Mapper method name whose SQL is wrapped in an outer filtering select. */
    private static final String PAGE_METHOD = "pageByCondition";

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No configurable properties.
    }


    /**
     * Intercepts {@code StatementHandler.prepare}. When the target is a
     * {@link RoutingStatementHandler}, the bound SQL of its delegate is replaced (via reflection
     * on the {@code sql} field of {@link BoundSql}) with the result of
     * {@link #permissionSql(String, MappedStatement)}; then the invocation proceeds.
     *
     * @param invocation the intercepted call
     * @return whatever the intercepted {@code prepare} call returns
     * @throws Throwable anything thrown by the intercepted call or by the SQL rewriting
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (log.isInfoEnabled()) {
            log.info("进入 PrepareInterceptor 拦截器...");
        }
        if (invocation.getTarget() instanceof RoutingStatementHandler) {
            RoutingStatementHandler handler = (RoutingStatementHandler) invocation.getTarget();
            StatementHandler delegate = (StatementHandler) ReflectUtil.getFieldValue(handler, "delegate");
            // mappedStatement is not declared on the delegate's own class; ReflectUtil walks up
            // the superclass chain to find it.
            MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(delegate, "mappedStatement");
            BoundSql boundSql = delegate.getBoundSql();
            ReflectUtil.setFieldValue(boundSql, "sql", permissionSql(boundSql.getSql(), mappedStatement));
        }
        return invocation.proceed();
    }

    /**
     * Decides whether the statement's SQL needs the data-permission filter and applies it.
     *
     * <p>The SQL is rewritten only when the statement's method name (the part of the statement
     * id after the last {@code '.'}) is {@code countByCondition} or {@code pageByCondition} and
     * the mapper method carries {@link UserPermissionAop}; otherwise it is returned unchanged.
     *
     * @param sql             original SQL of the statement
     * @param mappedStatement statement being prepared
     * @return the filtered SQL, or {@code sql} itself when no filtering applies
     */
    protected String permissionSql(String sql, MappedStatement mappedStatement) {
        String id = mappedStatement.getId();
        String methodName = id.substring(id.lastIndexOf('.') + 1);
        // Method filter (the original author marked this list as still under discussion).
        if (!COUNT_METHOD.equals(methodName) && !PAGE_METHOD.equals(methodName)) {
            return sql;
        }
        // Check the cheap name filter first so the reflective lookup runs only when needed.
        UserPermissionAop permissionAop = PermissionUtils.getPermissionByDelegate(mappedStatement);
        if (permissionAop == null) {
            return sql;
        }
        return getAppendSql(new StringBuilder(sql), methodName);
    }

    /**
     * Builds the permission-filtered SQL for the current user.
     *
     * <p>The allowed id list is the current user's id followed by the ids of all users in
     * {@code user.getUserDatalist()}. For {@code countByCondition} the condition
     * {@code and operate_user_id in (...)} is appended to {@code sbSql} (which is modified in
     * place, so the incoming SQL must already contain a where clause); for any other method
     * the original SQL is wrapped as a sub-select filtered on {@code temp.operate_user_id}.
     *
     * @param sbSql      original SQL
     * @param methodName mapper method name that selects the rewriting strategy
     * @return the rewritten SQL
     */
    public String getAppendSql(StringBuilder sbSql, String methodName) {
        // Current user and the users whose data that user manages.
        User user = ContextHolder.user();
        List<UserDataManageDO> managedUserList = user.getUserDatalist();

        // Strings are immutable, so the ids must be accumulated in a builder; calling
        // String.concat and discarding its result would leave only the current user's id.
        StringBuilder userIds = new StringBuilder(String.valueOf(user.getId()));
        if (managedUserList != null) {
            for (UserDataManageDO managed : managedUserList) {
                userIds.append(",").append(String.valueOf(managed.getManagedUserId()));
            }
        }

        if (COUNT_METHOD.equals(methodName)) {
            sbSql.append(" and operate_user_id in (").append(userIds).append(")  ");
            return sbSql.toString();
        }
        return new StringBuilder("select * from (")
                .append(sbSql)
                .append(" ) temp where temp.operate_user_id in (")
                .append(userIds)
                .append(")  ")
                .toString();
    }


}
package com.hsfw.backyard.biz.security.authority.mybatis;


import com.hsfw.backyard.biz.security.authority.UserPermissionAop;
import com.hsfw.backyard.biz.security.authority.util.PermissionUtils;
import com.hsfw.backyard.biz.security.authority.util.ReflectUtil;
import org.apache.ibatis.executor.resultset.ResultSetHandler;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Properties;


/**
 * MyBatis plugin that post-processes the results of permission-annotated mapper methods by
 * blanking the {@code regionCd} property of every returned row that exposes a public
 * {@code setRegionCd(String)} setter.
 */
@Intercepts({
        @Signature(type = ResultSetHandler.class, method = "handleResultSets", args = {Statement.class})
})
@Component
public class ResultInterceptor implements Interceptor {
    /**
     * Logger.
     */
    private static final Logger log = LoggerFactory.getLogger(ResultInterceptor.class);

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No configurable properties.
    }

    /**
     * Intercepts {@code ResultSetHandler.handleResultSets}. The intercepted call always runs;
     * if the statement's mapper method carries {@link UserPermissionAop} and the result is an
     * {@link ArrayList}, {@code setRegionCd("")} is invoked on each element.
     *
     * <p>Elements that are {@code null} or whose class has no public
     * {@code setRegionCd(String)} method are left untouched, so annotated queries returning
     * other row types (presumably e.g. a numeric count; verify against the mappers) do not
     * fail.
     *
     * @param invocation the intercepted call
     * @return the (possibly modified in place) result of the intercepted call
     * @throws Throwable anything thrown by the intercepted call or by the invoked setter
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        if (log.isInfoEnabled()) {
            log.info("进入 ResultInterceptor 拦截器...");
        }
        ResultSetHandler resultSetHandler = (ResultSetHandler) invocation.getTarget();
        // Read the handler's mappedStatement field reflectively to identify the mapper method.
        MappedStatement mappedStatement = (MappedStatement) ReflectUtil.getFieldValue(resultSetHandler, "mappedStatement");
        UserPermissionAop permissionAop = PermissionUtils.getPermissionByDelegate(mappedStatement);

        // Run the actual result-set handling first; filtering is applied to its output.
        Object result = invocation.proceed();
        if (permissionAop == null || !(result instanceof ArrayList)) {
            return result;
        }

        boolean filtered = false;
        for (Object row : (ArrayList<?>) result) {
            if (row == null) {
                continue;
            }
            Method setter;
            try {
                setter = row.getClass().getMethod("setRegionCd", String.class);
            } catch (NoSuchMethodException ignored) {
                // This row type has no regionCd property to blank; skip it instead of
                // failing the whole query.
                continue;
            }
            setter.invoke(row, "");
            filtered = true;
        }
        // Logged once per result rather than once per row to keep the log readable.
        if (filtered && log.isInfoEnabled()) {
            log.info("数据权限处理【过滤结果】...");
        }
        return result;
    }


}

 

posted @ 2019-01-02 14:34  可乐998  阅读(1199)  评论(0编辑  收藏  举报