MyBatis 拦截并修改 SQL(数据库双写)

需求:在不停机的情况下迁移数据库,采用数据库双写方案,即先将数据同时写入新、老两个库(也可以在数据库层面做主从复制,但运维和 DBA 无法配合,因此该方案被否决);
方案:在原 Mapper 接口上添加新注解(@DoubleWrite),启动时扫描带该注解的 Mapper;执行写操作时通过 MyBatis 拦截器获取完整 SQL,再通过 Pulsar 发送消息,由消费端执行该 SQL 写入新数据库,从而完成数据同步;
MyBatis 拦截器
package com.zhaopin.zhiq.doublewrite;

import com.alibaba.fastjson.JSONObject;
import com.zhaopin.platzqaserver.pulsar.PulsarProducer;
import com.zhaopin.platzqaserver.utils.LogUtil;
import com.zhaopin.platzqaserver.utils.SpringUtils;
import com.zhaopin.zhiq.doublewrite.constant.DoubleWriteConstant;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.*;

import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * MyBatis plugin implementing the "double write" side of the database migration.
 *
 * <p>It wraps {@code Executor.update}, which MyBatis uses for INSERT/UPDATE/DELETE.
 * Mirroring happens only after the original write has succeeded and affected at least one row,
 * and only for statements of a mapper registered in {@link DoubleWritePrepareMapper}.
 * Such statements are rendered as complete SQL with parameter values inlined. They are then
 * published to {@code DoubleWriteConstant.DB_DOUBLE_WRITE_TOPIC} via {@code PulsarProducer},
 * so a consumer can replay them against the new database.
 *
 * <p>Any failure while building or publishing the message is logged. It never changes the
 * outcome of the original write.
 */
@Intercepts({@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})})
public class DoubleWriteInterceptor implements Interceptor {

    /** ParamMap key MyBatis uses for the first mapper-method argument. */
    private static final String FIRST_PARAM_KEY = "param1";

    /**
     * ParamMap key under which MyBatis exposes a single, un-annotated collection argument.
     * NOTE(review): this wrapping is MyBatis-version dependent; confirm against the version in use.
     */
    private static final String COLLECTION_PARAM_KEY = "collection";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // Run the original write first. Keys generated via useGeneratedKeys are copied onto the
        // parameter object during this call, and nothing should be mirrored if no row changed.
        // NOTE(review): with ExecutorType.BATCH, rows (and generated keys) are only flushed later;
        // confirm this interceptor is not used with batch executors.
        Object result = invocation.proceed();
        if (result instanceof Integer && (Integer) result == 0) {
            return result;
        }
        String sql = "";
        try {
            Object[] args = invocation.getArgs();
            MappedStatement ms = (MappedStatement) args[0];
            if (!hasMapperName(ms)) {
                return result;
            }
            Object parameter = args[1];
            // NOTE(review): the bound SQL is rebuilt after the write. For dynamic SQL that tests
            // e.g. `id != null`, it may now differ from what was executed. Confirm for such mappers.
            sql = DoubleWriteUtil.getCompleteSql(ms.getConfiguration(), ms.getBoundSql(parameter));
            String replaySql = SqlCommandType.INSERT.equals(ms.getSqlCommandType())
                    ? addPrimaryKey(sql, parameter)
                    : sql;
            if (replaySql == null || replaySql.isEmpty()) {
                // Publishing an empty statement would silently lose this write on the new database.
                throw new IllegalStateException("could not add primary key to insert of " + ms.getId());
            }
            JSONObject message = new JSONObject();
            message.put("sql", replaySql);
            message.put("sqlCommandType", ms.getSqlCommandType());
            PulsarProducer.send(DoubleWriteConstant.DB_DOUBLE_WRITE_TOPIC, message.toJSONString());
        } catch (Exception e) {
            LogUtil.error("failed to double write error , sql " + sql, e);
        }
        return result;
    }

    /**
     * Makes an INSERT carry the primary key assigned by the original database, so both
     * databases store the row under the same {@code id}.
     *
     * @param sql       complete INSERT statement (parameters already inlined)
     * @param parameter the parameter object that was passed to {@code Executor.update}
     * @return one of three results:
     *         <ul>
     *           <li>the statement with an explicit {@code id} column;</li>
     *           <li>{@code sql} unchanged when the parameter is a ParamMap whose first
     *               argument is not a list;</li>
     *           <li>{@code ""} when a single-row insert could not be rewritten.</li>
     *         </ul>
     */
    private static String addPrimaryKey(String sql, Object parameter) {
        if (parameter instanceof MapperMethod.ParamMap) {
            Object argument = firstArgument((MapperMethod.ParamMap<?>) parameter);
            if (argument instanceof List) {
                // Batch insert: one id per element. Assumes the VALUES tuples were generated from
                // this list in order (e.g. by <foreach>).
                List<?> rows = (List<?>) argument;
                List<Long> ids = new ArrayList<>(rows.size());
                for (Object row : rows) {
                    ids.add(Long.parseLong(String.valueOf(DoubleWriteUtil.getFieldValue(row, "id"))));
                }
                return DoubleWriteUtil.createDoubleWriteBatchSql(sql, ids);
            }
            return sql;
        }
        return DoubleWriteUtil.createDoubleWriteSimpleSql(sql, DoubleWriteUtil.getFieldValue(parameter, "id"));
    }

    /**
     * Returns the first mapper argument held by {@code parameterMap}, or {@code null} when
     * neither known key is present.
     * ParamMap.get throws for unknown keys, so each key is probed with containsKey first.
     */
    private static Object firstArgument(MapperMethod.ParamMap<?> parameterMap) {
        if (parameterMap.containsKey(FIRST_PARAM_KEY)) {
            return parameterMap.get(FIRST_PARAM_KEY);
        }
        if (parameterMap.containsKey(COLLECTION_PARAM_KEY)) {
            return parameterMap.get(COLLECTION_PARAM_KEY);
        }
        return null;
    }

    /**
     * Checks whether the intercepted statement belongs to a mapper that must be double-written.
     *
     * <p>The statement id has the form {@code <mapper interface name>.<method>}. The part before
     * the last dot is looked up among the mapper names collected by {@link DoubleWritePrepareMapper}.
     *
     * @return {@code false} when no mapper names are available or the id has no namespace
     */
    private boolean hasMapperName(MappedStatement ms) {
        DoubleWritePrepareMapper prepareMapper =
                SpringUtils.getBean("doubleWritePrepareMapper", DoubleWritePrepareMapper.class);
        Set<String> mapperNames = prepareMapper == null ? null : prepareMapper.getMapperNames();
        if (mapperNames == null || mapperNames.isEmpty()) {
            return false;
        }
        String statementId = ms.getId();
        int lastDot = statementId.lastIndexOf('.');
        return lastDot > 0 && mapperNames.contains(statementId.substring(0, lastDot));
    }

    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    @Override
    public void setProperties(Properties properties) {
        // No configurable properties.
    }

}

获取完整 SQL(将参数值填入原始 SQL)并改写
package com.zhaopin.zhiq.doublewrite;

import com.zhaopin.platzqaserver.utils.LogUtil;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.type.TypeHandlerRegistry;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.text.DateFormat;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.time.temporal.TemporalAccessor;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Helpers for replaying a MyBatis statement against the new database.
 *
 * <p>They turn an executed statement into self-contained SQL, with parameter values inlined as
 * literals. They also add an explicit primary key to INSERT statements.
 */
public class DoubleWriteUtil {

    /** First {@code VALUES (} of a statement: case-insensitive, any whitespace before the parenthesis. */
    private static final Pattern VALUES_WITH_PAREN = Pattern.compile("(?i)\\bvalues\\s*\\(");

    /** First {@code VALUES} keyword of a statement, case-insensitive. */
    private static final Pattern VALUES_KEYWORD = Pattern.compile("(?i)\\bvalues\\b");

    /**
     * Locale-independent timestamp literal format, rendered in the JVM default time zone.
     * DateTimeFormatter is immutable and thread-safe, so one shared instance suffices.
     */
    private static final DateTimeFormatter DATE_TIME_FORMATTER =
            DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS").withZone(ZoneId.systemDefault());

    private DoubleWriteUtil() {
    }

    /**
     * Adds an explicit {@code id} column and value to a single-row INSERT.
     *
     * <p>{@code id, } is inserted after the first {@code (} of the statement, which is assumed to
     * open the column list. The id value is inserted as the first value of the first
     * {@code VALUES (...)} tuple.
     *
     * <p>Only that first VALUES clause is modified. A later {@code VALUES (col)} (e.g. MySQL's
     * {@code ON DUPLICATE KEY UPDATE}) is left untouched.
     *
     * @param oldSql complete INSERT statement with parameters already inlined
     * @param id     primary key value, written via string concatenation ({@code null} becomes {@code null})
     * @return the rewritten statement, or {@code ""} when no VALUES clause is found
     */
    public static String createDoubleWriteSimpleSql(String oldSql, Object id) {
        String withIdColumn = oldSql.replaceFirst("\\(", "(id, ");
        Matcher values = VALUES_WITH_PAREN.matcher(withIdColumn);
        if (!values.find()) {
            return "";
        }
        int valuesStart = values.end();
        return withIdColumn.substring(0, valuesStart) + id + ", " + withIdColumn.substring(valuesStart);
    }

    /**
     * Adds an explicit {@code id} column to a multi-row INSERT, and the matching id to each tuple.
     *
     * <p>{@code id, } is inserted after the first {@code (} of the statement (the column list).
     * After the first VALUES keyword, the i-th top-level {@code (} receives {@code ids.get(i)}
     * as its first value.
     *
     * <p>Parentheses nested inside a tuple (function calls) and characters inside single-quoted
     * literals are skipped. A literal containing {@code )} therefore cannot split a tuple. Once
     * every id is used, the rest of the statement (e.g. {@code ON CONFLICT (...)}) is copied unchanged.
     *
     * @param oldSql complete INSERT statement with parameters already inlined
     * @param ids    primary keys in tuple order
     * @return the rewritten statement
     * @throws IllegalArgumentException if the statement has no VALUES keyword
     */
    public static String createDoubleWriteBatchSql(String oldSql, List<Long> ids) {
        String withIdColumn = oldSql.replaceFirst("\\(", "(id, ");
        Matcher values = VALUES_KEYWORD.matcher(withIdColumn);
        if (!values.find()) {
            throw new IllegalArgumentException("no VALUES clause in batch insert: " + oldSql);
        }
        StringBuilder sql = new StringBuilder(withIdColumn.length() + ids.size() * 21);
        sql.append(withIdColumn, 0, values.end());
        int nextId = 0;
        int depth = 0;
        boolean inLiteral = false;
        for (int i = values.end(); i < withIdColumn.length(); i++) {
            char c = withIdColumn.charAt(i);
            sql.append(c);
            if (inLiteral) {
                // An escaped '' closes and immediately reopens the literal, which nets out correctly.
                inLiteral = c != '\'';
            } else if (c == '\'') {
                inLiteral = true;
            } else if (c == '(') {
                if (depth == 0 && nextId < ids.size()) {
                    sql.append(ids.get(nextId++)).append(", ");
                }
                depth++;
            } else if (c == ')') {
                depth--;
            }
        }
        return sql.toString();
    }

    /**
     * Renders the statement as complete SQL, with every {@code ?} placeholder replaced by a
     * SQL literal of its bound value.
     *
     * <p>Values are resolved in the same order MyBatis' DefaultParameterHandler uses:
     * <ol>
     *   <li>additional parameters (e.g. {@code <foreach>} items);</li>
     *   <li>then {@code null} when there is no parameter object;</li>
     *   <li>then the parameter object itself, when it has its own type handler;</li>
     *   <li>otherwise the named property of the parameter object.</li>
     * </ol>
     *
     * <p>Placeholders are filled strictly left to right by position. An inlined value that itself
     * contains {@code ?} is therefore never mistaken for a later placeholder.
     * NOTE(review): a {@code ?} inside a literal of the SQL template itself is still treated as a
     * placeholder.
     *
     * @param configuration MyBatis configuration (type handlers, meta objects)
     * @param boundSql      bound SQL of the executed statement
     * @return the SQL with runs of whitespace collapsed to single spaces and parameters inlined
     */
    public static String getCompleteSql(Configuration configuration, BoundSql boundSql) {
        // Collapse newlines/tabs/indentation of the SQL template (not of the inlined values).
        String sql = boundSql.getSql().replaceAll("[\\s]+", " ");
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        if (parameterMappings == null || parameterMappings.isEmpty()) {
            return sql;
        }
        Object parameterObject = boundSql.getParameterObject();
        TypeHandlerRegistry typeHandlerRegistry = configuration.getTypeHandlerRegistry();
        MetaObject metaObject = null;
        StringBuilder completeSql = new StringBuilder(sql.length() + 16 * parameterMappings.size());
        int cursor = 0;
        for (ParameterMapping parameterMapping : parameterMappings) {
            int placeholder = sql.indexOf('?', cursor);
            if (placeholder < 0) {
                break;
            }
            completeSql.append(sql, cursor, placeholder);
            String propertyName = parameterMapping.getProperty();
            Object value;
            if (boundSql.hasAdditionalParameter(propertyName)) {
                value = boundSql.getAdditionalParameter(propertyName);
            } else if (parameterObject == null) {
                value = null;
            } else if (typeHandlerRegistry.hasTypeHandler(parameterObject.getClass())) {
                value = parameterObject;
            } else {
                if (metaObject == null) {
                    metaObject = configuration.newMetaObject(parameterObject);
                }
                value = metaObject.getValue(propertyName);
            }
            completeSql.append(getParameterValue(value));
            cursor = placeholder + 1;
        }
        completeSql.append(sql, cursor, sql.length());
        return completeSql.toString();
    }

    /**
     * Formats one bound value as a SQL literal.
     *
     * <ul>
     *   <li>{@code null} becomes {@code NULL}.</li>
     *   <li>Dates use {@link #DATE_TIME_FORMATTER}.</li>
     *   <li>Enums use {@code name()}.
     *       NOTE(review): this matches MyBatis' default EnumTypeHandler; an ordinal or custom
     *       handler would need different output.</li>
     *   <li>Text and java.time values are single-quoted.</li>
     *   <li>Everything else uses {@code toString()}.</li>
     * </ul>
     *
     * NOTE(review): backslashes are not escaped. That is correct for standard-conforming
     * strings (PostgreSQL default), not for MySQL's default sql_mode.
     */
    private static String getParameterValue(Object obj) {
        if (obj == null) {
            return "NULL";
        }
        if (obj instanceof Date) {
            // getTime() works for every Date subclass, including java.sql.Date (whose toInstant() throws).
            return quote(DATE_TIME_FORMATTER.format(Instant.ofEpochMilli(((Date) obj).getTime())));
        }
        if (obj instanceof Enum) {
            return quote(((Enum<?>) obj).name());
        }
        if (obj instanceof CharSequence || obj instanceof Character || obj instanceof TemporalAccessor) {
            return quote(obj.toString());
        }
        return obj.toString();
    }

    /** Wraps {@code text} in single quotes, doubling embedded quotes per the SQL standard. */
    private static String quote(String text) {
        return "'" + text.replace("'", "''") + "'";
    }

    /**
     * Reads a property of {@code parameter} by reflection.
     *
     * <p>The field may be declared on the parameter's class or any superclass. Its public
     * {@code get<FieldName>} getter is used when present; otherwise the field is read directly,
     * e.g. for boolean fields exposed through {@code is<FieldName>}.
     *
     * @return the value, or {@code null} when {@code parameter} is null, the field does not exist,
     *         or reflection fails (the failure is logged)
     */
    public static Object getFieldValue(Object parameter, String fieldName) {
        if (parameter == null || fieldName == null || fieldName.isEmpty()) {
            return null;
        }
        try {
            Field field = findField(parameter.getClass(), fieldName);
            if (field == null) {
                return null;
            }
            try {
                Method getter = field.getDeclaringClass().getMethod("get" + captureName(fieldName));
                return getter.invoke(parameter);
            } catch (NoSuchMethodException e) {
                field.setAccessible(true);
                return field.get(parameter);
            }
        } catch (Exception e) {
            LogUtil.error("failed to get", e);
        }
        return null;
    }

    /** Finds the field named {@code fieldName} on {@code clazz} or its superclasses; {@code null} if absent. */
    private static Field findField(Class<?> clazz, String fieldName) {
        for (Class<?> current = clazz; current != null; current = current.getSuperclass()) {
            for (Field field : current.getDeclaredFields()) {
                if (fieldName.equals(field.getName())) {
                    return field;
                }
            }
        }
        return null;
    }

    /** Upper-cases the first character if it is an ASCII lower-case letter ({@code "id"} becomes {@code "Id"}). */
    private static String captureName(String str) {
        if (str.isEmpty()) {
            return str;
        }
        char[] cs = str.toCharArray();
        cs[0] -= (cs[0] > 96 && cs[0] < 123) ? 32 : 0;
        return String.valueOf(cs);
    }

}

获取被双写注解修饰的Repository
package com.zhaopin.zhiq.doublewrite;

import com.zhaopin.platzqaserver.utils.LogUtil;
import com.zhaopin.zhiq.annotation.DoubleWrite;
import org.mybatis.spring.mapper.MapperFactoryBean;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanPostProcessor;
import org.springframework.stereotype.Component;

import java.lang.annotation.Annotation;
import java.util.*;

/**
 * 获取被双写注解修饰的Repository
 */
@Component
public class DoubleWritePrepareMapper implements BeanPostProcessor {

    /**
     * key: Repository 接口的Name
     * value: Repository 接口的Class对象
     */
    private Map<String, Class<?>> mappers;

    @Override
    public Object postProcessBeforeInitialization(Object bean, String beanName) throws BeansException {
        return bean;
    }

    @Override
    public Object postProcessAfterInitialization(Object bean, String beanName) throws BeansException {
        if (mappers == null) {
            mappers = new HashMap<>(32);
        }
        try {
            MapperFactoryBean mapper;
            if (bean instanceof MapperFactoryBean) {
                mapper = (MapperFactoryBean) bean;
                Class mapperInterface = mapper.getMapperInterface();
                Annotation annotation = mapperInterface.getDeclaredAnnotation(DoubleWrite.class);
                if (annotation != null) {
                    mappers.put(mapperInterface.getName(), mapperInterface);
                }
            }
        } catch (Exception e) {
            LogUtil.error("failed to initialize double read mappers : ", e);
        }
        return bean;
    }

    public Set<String> getMapperNames() {
        if (this.mappers == null) {
            return null;
        }
        return this.mappers.keySet();
    }
}
原有的 Mapper 示例(标注了 @DoubleWrite 的 Repository)
/**
 * Repository for the {@code zhiq_identity_favor} table.
 *
 * <p>{@code @ZhiqUserDBRepository} binds it to the original data source. {@code @DoubleWrite}
 * marks it so that its writes are also mirrored to the new database.
 */
@ZhiqUserDBRepository // original data source
@DoubleWrite // writes must also go to the new database
public interface IdentityFavorRepository {

    /**
     * Inserts a favor relation, ignoring duplicates of {@code (uiid, favored_uiid)}.
     * The generated primary key is written back to {@code identityFavor.id}.
     *
     * @return {@code true} if a row was inserted, {@code false} if the conflict clause skipped it
     */
    @Insert("insert into zhiq_identity_favor (uiid, uid, favored_uiid, favored_uid) " +
            "values (#{uiid}, #{uid}, #{favoredUiid}, #{favoredUid}) " +
            "ON CONFLICT (uiid, favored_uiid) DO NOTHING")
    @Options(useGeneratedKeys = true, keyProperty="id", keyColumn = "id")
    boolean insert(IdentityFavor identityFavor);
}
posted @ 2022-02-11 17:06  748573200000  阅读(1020)  评论(0)  编辑  收藏  举报