MyBatis自定义拦截器实现自动记录操作人信息

1.前言

MyBatis有四大核心对象,分别是Executor,StatementHandler,ParameterHandler和ResultSetHandler。

在很多时候,表中的数据都需要记录插入时间,修改时间,插入人和修改人。若每次都在插入或修改代码中手动设置这些信息,就显得有些冗余。此时可以借助MyBatis提供的拦截器(插件)机制编写自定义拦截器,在需要记录操作人信息的SQL执行前自动补充这些信息,也就是对MyBatis的核心对象进行增强。这里只拦截Executor对象的update方法(insert,update,delete语句都会经过该方法),在SQL执行前动态地为参数对象补充这些字段。

2.实战演练

下面的示例演示了对订单信息进行新增和修改:(Spring Boot版本:2.7.2,mybatis-spring-boot-starter版本:2.2.2,数据库:MySQL)

第一步:创建订单表

-- Demo order table. insert_by / update_by / insert_time / update_time are the
-- audit columns that the MyBatis interceptor fills in before insert/update.
-- order_id is AUTO_INCREMENT so the mapper's useGeneratedKeys="true" insert also
-- works when no id is supplied; explicitly supplied ids are still accepted.
-- utf8mb4 is used because MySQL's "utf8" is the 3-byte utf8mb3 alias, which
-- cannot store 4-byte characters (e.g. emoji).
CREATE TABLE `b_order` (
  `order_id` int NOT NULL AUTO_INCREMENT,
  `order_no` varchar(100) DEFAULT NULL,
  `order_status` int DEFAULT NULL,
  `insert_by` varchar(100) DEFAULT NULL,
  `update_by` varchar(100) DEFAULT NULL,
  `insert_time` datetime DEFAULT NULL,
  `update_time` datetime DEFAULT NULL,
  PRIMARY KEY (`order_id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

第二步:创建基础类,包含基础的操作人信息

package com.zxh.test.entity;

import lombok.Data;
import lombok.Getter;
import lombok.Setter;

import java.util.Date;

/**
 * Base class for entities that carry audit columns: who inserted / last updated
 * the row, and when.
 *
 * <p>Lombok generates a public getter and setter for every field below.
 */
public class BaseFieldDO {
    /** User who inserted the row (column {@code insert_by}). */
    @Getter
    @Setter
    private String insertBy;

    /** User who last updated the row (column {@code update_by}). */
    @Getter
    @Setter
    private String updateBy;

    /** Time the row was inserted (column {@code insert_time}). */
    @Getter
    @Setter
    private Date insertTime;

    /** Time the row was last updated (column {@code update_time}). */
    @Getter
    @Setter
    private Date updateTime;
}

第三步:创建订单类,继承基础类

package com.zxh.test.entity;

import lombok.Data;
import lombok.experimental.Accessors;

@Data
@Accessors(chain = true)
public class Order extends BaseFieldDO {
    private Long orderId;
    private String orderNo;
    private Integer orderStatus;

}

第四步:创建拦截器,实现MyBatis的Interceptor接口

参数说明:

@Intercepts:标识该类是一个拦截器;
@Signature:指明自定义拦截器需要拦截哪一个类型,哪一个方法;
	type:对应四种类型中的一种;
	method:对应接口中的哪个方法;
	args:对应方法的参数类型列表(用于区分可能存在的重载方法)。

具体的处理逻辑需要根据实际情况进行修改。invocation.proceed()表示放行,即继续执行被拦截的方法;在放行之前,拦截器通过反射的方式给参数对象的属性设置值。

虽然代码看起来很多,但我一向喜欢把通用的逻辑抽取成方法,每个方法只实现单一的功能,这样每处的处理逻辑就不会太复杂。

package com.zxh.test.interceptor;

import com.zxh.test.entity.BaseFieldDO;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.util.*;

/**
 * @author zhongyushi
 */
@Component
@Intercepts({@Signature(method = "update", type = Executor.class, args = {MappedStatement.class, Object.class})})
@Slf4j
public class MybatisAuditDataInterceptor implements Interceptor {
    //MapperMethod.ParamMap类型
    private static final String[] MAPPER_METHOD_PARAM_MAP = {"record", "collection"};


    @Override
    public Object intercept(Invocation invocation) throws Throwable {

        Object[] args = invocation.getArgs();
        SqlCommandType sqlCommandType = null;
        for (Object object : args) {
            // 从MappedStatement参数中获取到操作类型
            if (object instanceof MappedStatement) {
                MappedStatement ms = (MappedStatement) object;
                sqlCommandType = ms.getSqlCommandType();
                continue;
            }
            Integer mapContainsKeyIndex = -1;
            // 判断参数是否是BaseFieldDO类型
            // 一个参数
            if (this.setFields(sqlCommandType, object)) continue;

            if (object instanceof MapperMethod.ParamMap) {
                @SuppressWarnings("unchecked")
                MapperMethod.ParamMap<Object> parasMap = (MapperMethod.ParamMap<Object>) object;

                for (int i = 0; i < MAPPER_METHOD_PARAM_MAP.length; i++) {
                    if (parasMap.containsKey(MAPPER_METHOD_PARAM_MAP[i])) {
                        mapContainsKeyIndex = i;
                        break;
                    }
                }
                if (mapContainsKeyIndex == -1) continue;
                Object paraObject = parasMap.get(MAPPER_METHOD_PARAM_MAP[mapContainsKeyIndex]);
                if (mapContainsKeyIndex == 0) {
                    //兼容MyBatis的updateByExampleSelective(record, example);
                    if (this.setFields(sqlCommandType, paraObject)) continue;
                } else if (mapContainsKeyIndex == 1) {
                    //批量操作
                    if (paraObject instanceof ArrayList) {
                        @SuppressWarnings("unchecked")
                        ArrayList<Object> parasList = (ArrayList<Object>) paraObject;
                        //批量新增或修改
                        if (SqlCommandType.INSERT == sqlCommandType) {
                            if (this.setFieldsFromList(parasList)) continue;
                        }
                        if (SqlCommandType.UPDATE == sqlCommandType) {
                            if (this.setFieldsFromList(parasList)) continue;
                        }
                    }
                }
            }
        }
        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
    }

    /**
     * @param sqlCommandType sql命令类型
     * @param paramObj       对象
     * @return
     */
    private Boolean setFields(SqlCommandType sqlCommandType, Object paramObj) {
        Boolean needContinue = false;
        if (paramObj instanceof BaseFieldDO) {
            if (SqlCommandType.INSERT == sqlCommandType) {
                this.setFields(paramObj, 1);
                needContinue = true;
            }
            if (SqlCommandType.UPDATE == sqlCommandType) {
                this.setFields(paramObj, 2);
                needContinue = true;
            }
        }
        return needContinue;

    }

    /**
     * 根据类型设置对象的属性
     *
     * @param parameter
     * @param type      1新增,2修改
     */
    private void setFields(Object parameter, Integer type) {
        //获取操作的用户名,按实际要求获取,这里只做演示
        String loginName = "zhangsan";
        Field[] fields = getAllFields(parameter);
        try {
            for (Field field : fields) {
                if (type == 1) {
                    this.setPropertyByInsert(field, parameter, loginName);
                } else if (type == 2) {
                    this.setPropertyByUpdate(field, parameter, loginName);
                }
            }
        } catch (Exception e) {
            log.error("failed to set fields, 原因:{}", e);
        }
    }

    /**
     * 循环设置集合对象的属性
     *
     * @param parasList
     */
    private Boolean setFieldsFromList(ArrayList<Object> parasList) {
        Boolean needContinue = false;
        if (parasList != null && parasList.size() > 0) {
            needContinue = true;
            parasList.stream().forEach(obj -> {
                if (obj instanceof BaseFieldDO) {
                    this.setFields(obj, 1);
                }
            });
        }
        return needContinue;
    }

    /**
     * 添加时设置属性值
     *
     * @param field
     * @param parameter
     * @param userName
     * @throws InvocationTargetException
     * @throws IllegalAccessException
     */
    private void setPropertyByInsert(Field field, Object parameter, String userName) throws IllegalAccessException {
        Date insertTime = new Date();
        // 注入创建人
        this.setFieldByName(field, "insertBy", parameter, userName);
        //注入创建时间
        this.setFieldByName(field, "insertTime", parameter, insertTime);
        this.setPropertyByUpdate(field, parameter, userName, insertTime);
    }

    /**
     * 修改时设置属性值
     *
     * @param field
     * @param parameter
     * @param userName
     * @throws IllegalAccessException
     */
    private void setPropertyByUpdate(Field field, Object parameter, String userName) throws IllegalAccessException {
        Date updateTime = new Date();
        this.setPropertyByUpdate(field, parameter, userName, updateTime);
    }

    /**
     * 修改时设置属性值
     *
     * @param field
     * @param parameter
     * @param userName
     * @param updateTime
     * @throws IllegalAccessException
     */
    private void setPropertyByUpdate(Field field, Object parameter, String userName, Date updateTime) throws IllegalAccessException {
        // 注入修改人
        this.setFieldByName(field, "updateBy", parameter, userName);
        //注入修改时间
        this.setFieldByName(field, "updateTime", parameter, updateTime);
    }

    /**
     * 通过反射给对象的属性赋值
     *
     * @param field
     * @param fieldName
     * @param parameter
     * @param fieldValue
     * @throws IllegalAccessException
     */
    public void setFieldByName(Field field, String fieldName, Object parameter, Object fieldValue) throws IllegalAccessException {
        if (fieldName.equals(field.getName())) {
            field.setAccessible(true);
            field.set(parameter, fieldValue);
            field.setAccessible(false);
        }
    }

    /**
     * 获取类的所有属性,包括父类
     *
     * @param object
     * @return
     */
    public Field[] getAllFields(Object object) {
        Class<?> clazz = object.getClass();
        List<Field> fieldList = new ArrayList<>();
        while (clazz != null) {
            fieldList.addAll(new ArrayList<>(Arrays.asList(clazz.getDeclaredFields())));
            clazz = clazz.getSuperclass();
        }
        Field[] fields = new Field[fieldList.size()];
        fieldList.toArray(fields);
        return fields;
    }

}

第五步:编写基本的SQL,包括单条新增,根据主键修改,批量新增和批量修改

dao接口如下图(截图未保留;其方法与下方xml中的insert,update,insertBatch,updateBatch一一对应)

对应的xml如下,其他类在此省略

<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="com.zxh.test.dao.OrderDao">

    <!-- Maps b_order columns to Order; each jdbcType matches the column's SQL type. -->
    <resultMap id="BaseResultMap" type="com.zxh.test.entity.Order">
        <!--@Table b_order-->
        <id property="orderId" column="order_id" jdbcType="INTEGER"/>
        <result property="orderNo" column="order_no" jdbcType="VARCHAR"/>
        <result property="orderStatus" column="order_status" jdbcType="INTEGER"/>
        <result property="insertBy" column="insert_by" jdbcType="VARCHAR"/>
        <result property="updateBy" column="update_by" jdbcType="VARCHAR"/>
        <result property="insertTime" column="insert_time" jdbcType="TIMESTAMP"/>
        <result property="updateTime" column="update_time" jdbcType="TIMESTAMP"/>
    </resultMap>


    <!-- Insert one order with all columns; the audit fields are filled by the interceptor. -->
    <insert id="insert" keyProperty="orderId" useGeneratedKeys="true">
        insert into b_order(order_id, order_no, order_status, insert_by, update_by, insert_time, update_time)
        values (#{orderId}, #{orderNo}, #{orderStatus}, #{insertBy}, #{updateBy}, #{insertTime}, #{updateTime})
    </insert>

    <!-- Update the non-null fields of one order by primary key.
         orderStatus is an Integer: only a null check is used, because OGNL treats 0 == ''
         and "orderStatus != ''" would silently skip status 0. -->
    <update id="update">
        update b_order
        <set>
            <if test="orderNo != null">
                order_no = #{orderNo},
            </if>
            <if test="orderStatus != null">
                order_status = #{orderStatus},
            </if>
            <if test="insertBy != null">
                insert_by = #{insertBy},
            </if>
            <if test="updateBy != null and updateBy != ''">
                update_by = #{updateBy},
            </if>
            <if test="insertTime != null">
                insert_time = #{insertTime},
            </if>
            <if test="updateTime != null">
                update_time = #{updateTime},
            </if>
        </set>
        where order_id = #{orderId}
    </update>


    <!-- Insert several orders in one multi-row statement. -->
    <insert id="insertBatch">
        insert into b_order(order_id,order_no, order_status, insert_by, update_by, insert_time, update_time) values
        <foreach collection="list" item="item" separator=",">
            (#{item.orderId},#{item.orderNo}, #{item.orderStatus}, #{item.insertBy}, #{item.updateBy},
            #{item.insertTime}, #{item.updateTime})
        </foreach>
    </insert>

    <!-- Update several orders in one statement using CASE per column.
         Every CASE ends with "else <column>": without ELSE, MySQL's CASE returns NULL for rows
         matched by the IN list but not by any WHEN (item value was null), wiping the column. -->
    <update id="updateBatch">
        update b_order t
        <trim prefix="set" suffixOverrides=",">
            <trim prefix="order_no = case" suffix="else order_no end,">
                <foreach collection="list" item="item">
                    <if test="item.orderNo != null">
                        when t.order_id = #{item.orderId} then #{item.orderNo}
                    </if>
                </foreach>
            </trim>
            <trim prefix="update_by = case" suffix="else update_by end,">
                <foreach collection="list" item="item">
                    <if test="item.updateBy != null">
                        when t.order_id = #{item.orderId} then #{item.updateBy}
                    </if>
                </foreach>
            </trim>
            <trim prefix="update_time = case" suffix="else update_time end,">
                <foreach collection="list" item="item">
                    <if test="item.updateTime != null">
                        when t.order_id = #{item.orderId} then #{item.updateTime}
                    </if>
                </foreach>
            </trim>
        </trim>
        where order_id in
        <foreach collection="list" item="item" open="(" close=")" separator=",">
            #{item.orderId}
        </foreach>
    </update>

</mapper>

第六步:编写测试类进行测试

以上几种方式都可以将操作人信息自动记录到表中,无需手动设置。

posted @ 2022-12-08 18:06  钟小嘿  阅读(1090)  评论(0)  编辑  收藏  举报