君临-行者无界


A Thorough Guide to MyBatis in One Article

How MyBatis executes

Core classes

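  • SqlSessionFactoryBuilder
    Reads the MyBatis configuration file and creates the SqlSessionFactory object
    
  • DefaultSqlSessionFactory
    The default implementation of the SqlSessionFactory interface; holds the Configuration and hands out SqlSession instances
    
  • Configuration
    Holds the MapperRegistry and caches the mapping from statement id to MappedStatement
    
  • MapperRegistry
    Holds a reference to the Configuration and caches the mapping from each mapper Class to its MapperProxyFactory
    
  • MapperProxyFactory
    Factory that creates mapper proxy objects; its constructor takes the mapper interface, and it caches the mapping from Method to MapperMethod
    
  • MapperProxy
    Implements InvocationHandler; its invoke method calls the SqlSession to carry out the database operation
    
  • MapperMethod
    Holds the SQL command and the method's signature information, and executes the corresponding SQL statement
    
  • DefaultSqlSession
    Holds references to the Configuration and an Executor, and calls the Executor to run SQL
    
  • Executor
    The executor; the default implementation is SimpleExecutor, which initializes the database connection and performs the JDBC operations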
    

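Initialization flow

  • SqlSessionFactoryBuilder loads the MyBatis configuration, creates the Configuration object, and initializes the MapperRegistry and related components
  • Its build method then creates the DefaultSqlSessionFactory.

Invocation flow

  • Call SqlSessionFactory.openSession to obtain a SqlSession

  • Call SqlSession.getMapper to obtain the proxy object; the internal call chain is as follows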

    DefaultSqlSession 
        public <T> T getMapper(Class<T> type) {
            return configuration.<T>getMapper(type, this);
        }
    Configuration
      public <T> T getMapper(Class<T> type, SqlSession sqlSession) {
        return mapperRegistry.getMapper(type, sqlSession);
      }
    MapperRegistry
      public <T> T getMapper(Class<T> type, SqlSession sqlSession) {
        final MapperProxyFactory<T> mapperProxyFactory = (MapperProxyFactory<T>) knownMappers.get(type);
        if (mapperProxyFactory == null) {
          throw new BindingException("Type " + type + " is not known to the MapperRegistry.");
        }
        try {
          return mapperProxyFactory.newInstance(sqlSession);
        } catch (Exception e) {
          throw new BindingException("Error getting mapper instance. Cause: " + e, e);
        }
      }
    MapperProxyFactory
        public T newInstance(SqlSession sqlSession) {
        final MapperProxy<T> mapperProxy = new MapperProxy<T>(sqlSession, mapperInterface, methodCache);
        return newInstance(mapperProxy);
      }
      protected T newInstance(MapperProxy<T> mapperProxy) {
        return (T) Proxy.newProxyInstance(mapperInterface.getClassLoader(), new Class[] { mapperInterface }, 		mapperProxy);
      }  
    
  • Call an interface method on the proxy object; the call chain is as follows

    MapperProxy
      public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        try {
          if (Object.class.equals(method.getDeclaringClass())) {
            return method.invoke(this, args);
          } else if (isDefaultMethod(method)) {
            return invokeDefaultMethod(proxy, method, args);
          }
        } catch (Throwable t) {
          throw ExceptionUtil.unwrapThrowable(t);
        }
        final MapperMethod mapperMethod = cachedMapperMethod(method);
        return mapperMethod.execute(sqlSession, args);
      }
    MapperMethod
    public Object execute(SqlSession sqlSession, Object[] args) {
        Object result;
        switch (command.getType()) {
          case INSERT: {
        	Object param = method.convertArgsToSqlCommandParam(args);
            result = rowCountResult(sqlSession.insert(command.getName(), param));
            break;
          }
          // update, select, and the other cases are omitted
          default:
            throw new BindingException("Unknown execution method for: " + command.getName());
        }
        if (result == null && method.getReturnType().isPrimitive() && !method.returnsVoid()) {
          throw new BindingException("Mapper method '" + command.getName() 
              + " attempted to return null from a method with a primitive return type (" + method.getReturnType() + ").");
        }
        return result;
      }
    DefaultSqlSession
      public int insert(String statement, Object parameter) {
        return update(statement, parameter);
      }
      public int update(String statement, Object parameter) {
        try {
          dirty = true;
          MappedStatement ms = configuration.getMappedStatement(statement);
          return executor.update(ms, wrapCollection(parameter));
        } catch (Exception e) {
          throw ExceptionFactory.wrapException("Error updating database.  Cause: " + e, e);
        } finally {
          ErrorContext.instance().reset();
        }
      }
    Everything after executor.update is plain JDBC logic.
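
The call chain traced here can be condensed into a small runnable sketch that uses only the JDK. It is not MyBatis code: MapperProxyDemo, DemoMapperProxy, UserMapper, and the STATEMENTS map are hypothetical names invented for illustration, and the "execution" step only returns a string instead of touching a database. What it does show is the core idea: an interface with no implementation class gets a JDK dynamic proxy, and the handler turns each call into a statement id lookup.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.HashMap;
import java.util.Map;

public class MapperProxyDemo {

    // Hypothetical mapper interface; there is deliberately no implementation class.
    public interface UserMapper {
        String selectNameById(int id);
    }

    // Stand-in for the statement cache in Configuration: statement id -> SQL text.
    public static final Map<String, String> STATEMENTS = new HashMap<>();

    // Plays the role of MapperProxy: every interface call lands in invoke().
    static class DemoMapperProxy implements InvocationHandler {
        private final Class<?> mapperInterface;

        DemoMapperProxy(Class<?> mapperInterface) {
            this.mapperInterface = mapperInterface;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // Methods inherited from Object (toString, hashCode, ...) are not SQL calls.
            if (Object.class.equals(method.getDeclaringClass())) {
                return method.invoke(this, args);
            }
            // Statement id = interface name + "." + method name.
            String statementId = mapperInterface.getName() + "." + method.getName();
            String sql = STATEMENTS.get(statementId);
            // A real session would hand the statement to an Executor; this sketch only reports it.
            return "would execute [" + sql + "] with arg " + args[0];
        }
    }

    // Plays the role of MapperProxyFactory.newInstance.
    @SuppressWarnings("unchecked")
    public static <T> T getMapper(Class<T> type) {
        return (T) Proxy.newProxyInstance(
                type.getClassLoader(), new Class<?>[] {type}, new DemoMapperProxy(type));
    }

    public static void main(String[] args) {
        STATEMENTS.put(UserMapper.class.getName() + ".selectNameById",
                "select name from t_user where id = ?");
        UserMapper mapper = getMapper(UserMapper.class);
        System.out.println(mapper.selectNameById(7));
    }
}
```

The check on Object.class mirrors the first branch of MapperProxy.invoke above: without it, calling toString() on the proxy would be treated as a mapper method and fail the statement lookup.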
    

Writing a simplified MyBatis by hand

Define the SqlSessionFactoryBuilder, SqlSessionFactory, and Configuration classes

public class SqlSessionFactoryBuilder {

    public SqlSessionFactory build(String fileName) {

        InputStream inputStream = SqlSessionFactoryBuilder.class.getClassLoader().getResourceAsStream(fileName);
        return build(inputStream);
    }

    public SqlSessionFactory build(InputStream inputStream) {
        try {
            Configuration.PROPS.load(inputStream);
        } catch (IOException e) {
            e.printStackTrace();
        }
        return new DefaultSqlSessionFactory(new Configuration());
    }
}
public class DefaultSqlSessionFactory implements SqlSessionFactory {

    private final Configuration configuration;

    public DefaultSqlSessionFactory(Configuration configuration) {
        this.configuration = configuration;
        loadMappersInfo(Configuration.getProperty(Constant.MAPPER_LOCATION).replaceAll("\\.", "/"));
    }

    @Override
    public SqlSession openSession() {
        SqlSession session = new DefaultSqlSession(this.configuration);
        return session;
    }

    private void loadMappersInfo(String dirName) {
        URL resources = DefaultSqlSessionFactory.class.getClassLoader().getResource(dirName);
        File mappersDir = new File(resources.getFile());
        if (mappersDir.isDirectory()) {
            // List every file in the mapper directory
            File[] mappers = mappersDir.listFiles();
            if (CommonUtil.isNotEmpty(mappers)) {
                for (File file : mappers) {
                    // Recurse into subdirectories
                    if (file.isDirectory()) {
                        loadMappersInfo(dirName + "/" + file.getName());

                    } else if (file.getName().endsWith(Constant.MAPPER_FILE_SUFFIX)) {

                        // Only parse XML mapper files
                        XmlUtil.readMapperXml(file, this.configuration);
                    }

                }

            }
        }

    }

}
public class Configuration {
    /**
     * Configuration properties
     */
    public static Properties PROPS = new Properties();

    /**
     * Registry of mapper proxy factories
     */
    protected final MapperRegistry mapperRegistry = new MapperRegistry();

    /**
     * Maps the id of each select/update statement in the mapper files to its SQL and attributes
     **/
    protected final Map<String, MappedStatement> mappedStatements = new HashMap<>();

    /**
     * Get a string property (defaults to the empty string)
     *
     * @param key
     * @return
     */
    public static String getProperty(String key) {
        return getProperty(key, "");
    }

    /**
     * Get a string property (with an explicit default value)
     *
     * @param key
     * @param defaultValue the default value
     * @return
     */
    public static String getProperty(String key, String defaultValue) {
        return PROPS.containsKey(key) ? PROPS.getProperty(key) : defaultValue;
    }

    /**
     * addMapper
     */
    public <T> void addMapper(Class<T> type) {
        this.mapperRegistry.addMapper(type);
    }

    /**
     * getMapper
     */
    public <T> T getMapper(Class<T> type, SqlSession sqlSession) {
        return this.mapperRegistry.getMapper(type, sqlSession);
    }

    /**
     * addMappedStatement
     */
    public void addMappedStatement(String key, MappedStatement mappedStatement) {
        this.mappedStatements.put(key, mappedStatement);
    }

    /**
     * Get a MappedStatement by id
     */
    public MappedStatement getMappedStatement(String id) {
        return this.mappedStatements.get(id);
    }

}

Define the built-in objects that Configuration depends on; MapperProxy is the key to how MyBatis generates proxy objects

public class MapperRegistry {
    /**
     * the knownMappers
     */
    private final Map<Class<?>, MapperProxyFactory<?>> knownMappers = new HashMap<>();

    /**
     * Register a proxy factory
     *
     * @param type
     */
    public <T> void addMapper(Class<T> type) {
        this.knownMappers.put(type, new MapperProxyFactory<T>(type));
    }

    /**
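     * Get a mapper proxy instance from its registered factory
     *
     * @param type
     * @param sqlSession
     * @return
     */
    @SuppressWarnings("unchecked")
    public <T> T getMapper(Class<T> type, SqlSession sqlSession) {
        MapperProxyFactory<T> mapperProxyFactory = (MapperProxyFactory<T>) this.knownMappers.get(type);
        if (mapperProxyFactory == null) {
            // Fail with a clear message instead of a NullPointerException
            throw new IllegalArgumentException("Type " + type + " is not known to the MapperRegistry.");
        }
        return mapperProxyFactory.newInstance(sqlSession);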
    }
}
public class MapperProxyFactory<T> {

    private final Class<T> mapperInterface;

    /**
     * Constructor
     *
     * @param mapperInterface
     */
    public MapperProxyFactory(Class<T> mapperInterface) {
        this.mapperInterface = mapperInterface;
    }

    /**
     * Create a proxy bound to the given sqlSession
     *
     * @param sqlSession
     * @return
     * @see
     */
    public T newInstance(SqlSession sqlSession) {
        MapperProxy<T> mapperProxy = new MapperProxy<T>(sqlSession, this.mapperInterface);
        return newInstance(mapperProxy);
    }

    /**
     * Create the proxy instance backed by the given MapperProxy
     *
     * @param mapperProxy
     * @return
     * @see
     */
    @SuppressWarnings("unchecked")
    protected T newInstance(MapperProxy<T> mapperProxy) {
        return (T) Proxy.newProxyInstance(this.mapperInterface.getClassLoader(), new Class[]{this.mapperInterface},
                mapperProxy);
    }
}
public class MapperProxy<T> implements InvocationHandler, Serializable {

    private static final long serialVersionUID = -7861758496991319661L;

    private final SqlSession sqlSession;

    private final Class<T> mapperInterface;

    /**
     * Constructor
     *
     * @param sqlSession
     * @param mapperInterface
     */
    public MapperProxy(SqlSession sqlSession, Class<T> mapperInterface) {
        this.sqlSession = sqlSession;
        this.mapperInterface = mapperInterface;
    }

    /**
     * The method that actually handles every mapper call
     *
     * @param proxy
     * @param method
     * @param args
     * @return
     * @throws Throwable
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        try {
            if (Object.class.equals(method.getDeclaringClass())) {
                return method.invoke(this, args);
            }
            return this.execute(method, args);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;

    }

    /**
     * Run the operation that matches the SQL command type
     *
     */
    private Object execute(Method method, Object[] args) {
        String statementId = this.mapperInterface.getName() + "." + method.getName();
        MappedStatement ms = this.sqlSession.getConfiguration().getMappedStatement(statementId);

        Object result = null;
        switch (ms.getSqlCommandType()) {
            case SELECT: {
                Class<?> returnType = method.getReturnType();
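                // A collection return type means a multi-row query; otherwise fetch a single record
                if (Collection.class.isAssignableFrom(returnType)) {
                    // The statement id is the mapper's fully qualified class name plus the method name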
                    result = sqlSession.selectList(statementId, args);
                } else {
                    result = sqlSession.selectOne(statementId, args);
                }
                break;
            }
            case UPDATE: {
                sqlSession.update(statementId, args);
                break;
            }
            default: {
                break;
            }
        }
        return result;
    }

}

Define DefaultSqlSession and SimpleExecutor, which carry out the actual database calls

public class DefaultSqlSession implements SqlSession {

    private final Configuration configuration;

    private final Executor executor;

    /**
     * Constructor
     *
     * @param configuration
     */
    public DefaultSqlSession(Configuration configuration) {
        this.configuration = configuration;
        this.executor = new SimpleExecutor(configuration);
    }

    /**
     * Query a single record
     */
    @Override
    public <T> T selectOne(String statementId, Object parameter) {
        List<T> results = this.<T>selectList(statementId, parameter);

        return CommonUtil.isNotEmpty(results) ? results.get(0) : null;
    }

    /**
     * Query multiple records
     *
     * @param statementId the mapper's fully qualified class name plus the method name
     * @param parameter   the parameter list
     * @return
     */
    @Override
    public <E> List<E> selectList(String statementId, Object parameter) {
        MappedStatement mappedStatement = this.configuration.getMappedStatement(statementId);
        return this.executor.<E>doQuery(mappedStatement, parameter);
    }

    /**
     * Update
     *
     * @param statementId
     * @param parameter
     */
    @Override
    public void update(String statementId, Object parameter) {
        MappedStatement mappedStatement = this.configuration.getMappedStatement(statementId);
        this.executor.doUpdate(mappedStatement, parameter);
    }

    @Override
    public void insert(String statementId, Object parameter) {
        //TODO not implemented yet
    }

    /**
     * Get a mapper
     */
    @Override
    public <T> T getMapper(Class<T> type) {
        return configuration.<T>getMapper(type, this);
    }

    /**
     * getConfiguration
     *
     * @return
     */
    @Override
    public Configuration getConfiguration() {
        return this.configuration;
    }

}
public class SimpleExecutor implements Executor {
    /**
     * Database connection
     */
    private static Connection connection;

    static {
        initConnect();
    }

    private Configuration conf;

    /**
     * Constructor
     *
     * @param configuration
     */
    public SimpleExecutor(Configuration configuration) {
        this.conf = configuration;
    }

    /**
     * Initialize the database connection once, statically
     *
     * @return
     */
    private static void initConnect() {

        String driver = Configuration.getProperty(Constant.DB_DRIVER_CONF);
        String url = Configuration.getProperty(Constant.DB_URL_CONF);
        String username = Configuration.getProperty(Constant.DB_USERNAME_CONF);
        String password = Configuration.getProperty(Constant.db_PASSWORD);

        try {
            Class.forName(driver);
            connection = DriverManager.getConnection(url, username, password);
            System.out.println("Database connection established");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Query the database with the given parameters
     *
     * @param ms
     * @param parameter
     * @return
     */
    @Override
    public <E> List<E> doQuery(MappedStatement ms, Object parameter) {

        try {
            //1. Get the database connection
            Connection connection = getConnect();

            //2. Get the MappedStatement, which carries the SQL
            MappedStatement mappedStatement = conf.getMappedStatement(ms.getSqlId());

            //3. Instantiate the StatementHandler
            StatementHandler statementHandler = new SimpleStatementHandler(mappedStatement);

            //4. Obtain a PreparedStatement from the StatementHandler and the connection
            PreparedStatement preparedStatement = statementHandler.prepare(connection);

            //5. Instantiate the ParameterHandler and bind values to the ? placeholders
            ParameterHandler parameterHandler = new DefaultParameterHandler(parameter);
            parameterHandler.setParameters(preparedStatement);

            //6. Execute the SQL and get the ResultSet
            ResultSet resultSet = statementHandler.query(preparedStatement);

            //7. Instantiate the ResultSetHandler, which uses reflection to copy the ResultSet into objects of the target resultType
            ResultSetHandler resultSetHandler = new DefaultResultSetHandler(mappedStatement);
            return resultSetHandler.handleResultSets(resultSet);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

    /**
     * doUpdate
     *
     * @param ms
     * @param parameter
     */
    @Override
    public void doUpdate(MappedStatement ms, Object parameter) {
        try {
            //1. Get the database connection
            Connection connection = getConnect();

            //2. Get the MappedStatement, which carries the SQL
            MappedStatement mappedStatement = conf.getMappedStatement(ms.getSqlId());

            //3. Instantiate the StatementHandler
            StatementHandler statementHandler = new SimpleStatementHandler(mappedStatement);

            //4. Obtain a PreparedStatement from the StatementHandler and the connection
            PreparedStatement preparedStatement = statementHandler.prepare(connection);

            //5. Instantiate the ParameterHandler and bind values to the ? placeholders
            ParameterHandler parameterHandler = new DefaultParameterHandler(parameter);
            parameterHandler.setParameters(preparedStatement);

            //6. Execute the SQL
            statementHandler.update(preparedStatement);

        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * getConnect
     *
     * @return
     * @throws SQLException
     */
    public Connection getConnect() throws SQLException {
        if (null != connection) {
            return connection;
        } else {
            throw new SQLException("Unable to connect to the database; please check the configuration");
        }
    }

}

Define the StatementHandler (SQL preprocessing), ParameterHandler (parameter binding), and ResultSetHandler (result mapping) that the Executor depends on

public class SimpleStatementHandler implements StatementHandler {
    /**
     * Regular expression that matches #{...}
     */
    private static Pattern param_pattern = Pattern.compile("#\\{([^\\{\\}]*)\\}");

    private MappedStatement mappedStatement;

    /**
     * Constructor
     *
     * @param mappedStatement
     */
    public SimpleStatementHandler(MappedStatement mappedStatement) {
        this.mappedStatement = mappedStatement;
    }

    /**
     * Replace each #{} in the SQL with ?; in the real MyBatis source this parsing happens in SqlSourceBuilder
     *
     * @param source
     * @return
     */
    private static String parseSymbol(String source) {
        source = source.trim();
        Matcher matcher = param_pattern.matcher(source);
        return matcher.replaceAll("?");
    }

    @Override
    public PreparedStatement prepare(Connection paramConnection)
            throws SQLException {
        String originalSql = mappedStatement.getSql();

        if (CommonUtil.isNotEmpty(originalSql)) {
            // Replace #{} with ? and precompile the statement, which guards against SQL injection
            return paramConnection.prepareStatement(parseSymbol(originalSql));
        } else {
            throw new SQLException("original sql is null.");
        }
    }

    /**
     * query
     *
     * @param preparedStatement
     * @return
     * @throws SQLException
     */
    @Override
    public ResultSet query(PreparedStatement preparedStatement)
            throws SQLException {
        return preparedStatement.executeQuery();
    }

    /**
     * update
     *
     * @param preparedStatement
     * @throws SQLException
     */
    @Override
    public void update(PreparedStatement preparedStatement)
            throws SQLException {
        preparedStatement.executeUpdate();
    }

}
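
To make the placeholder rewrite concrete, here is a small standalone sketch that applies the same regular expression outside of any JDBC context. PlaceholderDemo and its two methods are hypothetical names used only for this illustration; parameterNames goes one step further than the simplified handler by also collecting the names inside #{...}, which the hand-written version ignores because it binds parameters purely by position.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

public class PlaceholderDemo {
    // Same pattern as the handler: "#{" + anything except braces + "}".
    private static final Pattern PARAM_PATTERN = Pattern.compile("#\\{([^\\{\\}]*)\\}");

    // Replace every #{...} with a JDBC "?" placeholder.
    public static String toJdbcSql(String source) {
        return PARAM_PATTERN.matcher(source.trim()).replaceAll("?");
    }

    // Collect the names inside #{...} in order of appearance.
    public static List<String> parameterNames(String source) {
        List<String> names = new ArrayList<>();
        Matcher matcher = PARAM_PATTERN.matcher(source);
        while (matcher.find()) {
            names.add(matcher.group(1).trim());
        }
        return names;
    }

    public static void main(String[] args) {
        String sql = "select * from t_user where id = #{id} and name = #{name}";
        System.out.println(toJdbcSql(sql));      // select * from t_user where id = ? and name = ?
        System.out.println(parameterNames(sql)); // [id, name]
    }
}
```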
public class DefaultParameterHandler implements ParameterHandler {

    private Object parameter;

    public DefaultParameterHandler(Object parameter) {
        this.parameter = parameter;
    }

    /**
     * Set the SQL parameters on the PreparedStatement
     */
    @Override
    public void setParameters(PreparedStatement ps) {

        try {

            if (null != parameter) {
                if (parameter.getClass().isArray()) {
                    Object[] params = (Object[]) parameter;
                    for (int i = 0; i < params.length; i++) {
                        // The mapper is trusted to pass matching parameter types, so no conversion is done here
                        ps.setObject(i + 1, params[i]);
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}
public class DefaultResultSetHandler implements ResultSetHandler {

    private final MappedStatement mappedStatement;

    /**
     * @param mappedStatement
     */
    public DefaultResultSetHandler(MappedStatement mappedStatement) {
        this.mappedStatement = mappedStatement;
    }

    /**
     * Process the query result, using reflection to populate the returned entity class
     */
    @SuppressWarnings("unchecked")
    @Override
    public <E> List<E> handleResultSets(ResultSet resultSet) {
        try {

            List<E> result = new ArrayList<>();

            if (null == resultSet) {
                return null;
            }

            while (resultSet.next()) {
                // Instantiate the result class via reflection
                Class<?> entityClass = Class.forName(mappedStatement.getResultType());
                E entity = (E) entityClass.getDeclaredConstructor().newInstance();
                Field[] declaredFields = entityClass.getDeclaredFields();

                for (Field field : declaredFields) {
                    // Assign each member field
                    field.setAccessible(true);
                    Class<?> fieldType = field.getType();

                    // Only String and int conversions are implemented for now
                    if (String.class.equals(fieldType)) {
                        field.set(entity, resultSet.getString(field.getName()));
                    } else if (int.class.equals(fieldType) || Integer.class.equals(fieldType)) {
                        field.set(entity, resultSet.getInt(field.getName()));
                    } else {
                        // Other types are set as-is; convert them yourself if needed
                        field.set(entity, resultSet.getObject(field.getName()));
                    }
                }

                result.add(entity);
            }

            return result;
        } catch (Exception e) {
            e.printStackTrace();
        }
        return null;
    }

}
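
The reflective mapping in the result set handler can be tried without a database. The sketch below swaps the ResultSet for a plain Map from column name to value, which is an assumption made purely so the example runs standalone; RowMappingDemo, mapRow, and the User entity are hypothetical names. As in the handler, field names are assumed to match column names exactly.

```java
import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;

public class RowMappingDemo {

    // Hypothetical entity; field names are assumed to equal column names.
    public static class User {
        private int id;
        private String name;

        public int getId() { return id; }
        public String getName() { return name; }
    }

    // Map one "row" (column name -> value) onto a new instance of the given type.
    public static <E> E mapRow(Map<String, Object> row, Class<E> type)
            throws ReflectiveOperationException {
        E entity = type.getDeclaredConstructor().newInstance();
        for (Field field : type.getDeclaredFields()) {
            if (!row.containsKey(field.getName())) {
                continue; // column not selected: leave the field at its default value
            }
            field.setAccessible(true); // private fields are written directly, no setter needed
            field.set(entity, row.get(field.getName()));
        }
        return entity;
    }

    public static void main(String[] args) throws Exception {
        Map<String, Object> row = new HashMap<>();
        row.put("id", 7);
        row.put("name", "alice");
        User user = mapRow(row, User.class);
        System.out.println(user.getId() + " " + user.getName()); // 7 alice
    }
}
```

Unlike the handler, this sketch skips fields that have no matching column instead of failing, which is one possible design choice rather than what the original code does.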

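The related dependencies and utility classes are not listed here; the source code is linked at the end of the article

Advanced MyBatis features

Plugin mechanism

MyBatis uses the chain-of-responsibility pattern and organizes multiple plugins through dynamic proxies; plugins change the default behavior of SQL execution. MyBatis allows plugins to intercept four core objects: Executor, ParameterHandler, ResultSetHandler, and StatementHandler.

Relevant source code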

//Create the parameter handler
  public ParameterHandler newParameterHandler(MappedStatement mappedStatement, Object parameterObject, BoundSql boundSql) {
    //Create the ParameterHandler
    ParameterHandler parameterHandler = mappedStatement.getLang().createParameterHandler(mappedStatement, parameterObject, boundSql);
    //Plugins are applied here
    parameterHandler = (ParameterHandler) interceptorChain.pluginAll(parameterHandler);
    return parameterHandler;
  }

  //Create the result set handler
  public ResultSetHandler newResultSetHandler(Executor executor, MappedStatement mappedStatement, RowBounds rowBounds, ParameterHandler parameterHandler,
      ResultHandler resultHandler, BoundSql boundSql) {
    //Create a DefaultResultSetHandler (the somewhat older 3.1 release created a NestedResultSetHandler or FastResultSetHandler)
    ResultSetHandler resultSetHandler = new DefaultResultSetHandler(executor, mappedStatement, parameterHandler, resultHandler, boundSql, rowBounds);
    //Plugins are applied here
    resultSetHandler = (ResultSetHandler) interceptorChain.pluginAll(resultSetHandler);
    return resultSetHandler;
  }

  //Create the statement handler
  public StatementHandler newStatementHandler(Executor executor, MappedStatement mappedStatement, Object parameterObject, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) {
    //Create the routing statement handler
    StatementHandler statementHandler = new RoutingStatementHandler(executor, mappedStatement, parameterObject, rowBounds, resultHandler, boundSql);
    //Plugins are applied here
    statementHandler = (StatementHandler) interceptorChain.pluginAll(statementHandler);
    return statementHandler;
  }

  public Executor newExecutor(Transaction transaction) {
    return newExecutor(transaction, defaultExecutorType);
  }

  //Create the executor
  public Executor newExecutor(Transaction transaction, ExecutorType executorType) {
    //Decide which executor type to use
    executorType = executorType == null ? defaultExecutorType : executorType;
    //A second guard, in case defaultExecutorType itself was set to null
    executorType = executorType == null ? ExecutorType.SIMPLE : executorType;
    Executor executor;
    //Three simple branches produce one of three executors: BatchExecutor, ReuseExecutor, or SimpleExecutor
    if (ExecutorType.BATCH == executorType) {
      executor = new BatchExecutor(this, transaction);
    } else if (ExecutorType.REUSE == executorType) {
      executor = new ReuseExecutor(this, transaction);
    } else {
      executor = new SimpleExecutor(this, transaction);
    }
    //If caching is enabled (the default), wrap it in a CachingExecutor (decorator pattern), so by default a CachingExecutor is returned
    if (cacheEnabled) {
      executor = new CachingExecutor(executor);
    }
    //Plugins are applied here and can change the Executor's behavior
    executor = (Executor) interceptorChain.pluginAll(executor);
    return executor;
  }

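Each interceptor wraps the target in one more layer of proxy

     /**
     *@param target the object to wrap
     *@return the object after being proxied layer by layer
     */
    public Object pluginAll(Object target) {
        //Call each Interceptor's plugin method in turn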
        for (Interceptor interceptor : interceptors) {
            target = interceptor.plugin(target);
        }
        return target;
    }
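
The effect of this loop is easiest to see by running it. The following sketch reproduces only the wrapping loop with JDK proxies; it does not use MyBatis's Interceptor or Plugin classes. PluginChainDemo, Greeter, and the string labels are hypothetical stand-ins, with each label playing the part of one interceptor that records when it runs.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class PluginChainDemo {

    // Hypothetical target interface standing in for Executor and the other core objects.
    public interface Greeter {
        String greet(String name);
    }

    // Records the order in which the proxy layers run.
    public static final List<String> LOG = new ArrayList<>();

    // Wrap the target in one proxy layer that logs its label and then delegates.
    static Greeter plugin(Greeter target, String label) {
        InvocationHandler handler = (proxy, method, args) -> {
            LOG.add(label);
            return method.invoke(target, args);
        };
        return (Greeter) Proxy.newProxyInstance(
                Greeter.class.getClassLoader(), new Class<?>[] {Greeter.class}, handler);
    }

    // Same loop shape as pluginAll: each layer wraps the result of the previous one.
    public static Greeter pluginAll(Greeter target, List<String> labels) {
        for (String label : labels) {
            target = plugin(target, label);
        }
        return target;
    }

    public static void main(String[] args) {
        Greeter wrapped = pluginAll(name -> "hello " + name, Arrays.asList("first", "second"));
        System.out.println(wrapped.greet("bob")); // hello bob
        System.out.println(LOG);                  // [second, first]
    }
}
```

Because every iteration wraps the previous result, the layer added last ends up outermost and therefore runs first. The same ordering consequence applies to any chain built with this loop shape.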

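The Interceptor interface

public interface Interceptor {

  /**
   * The method that runs the interception logic
   *
   * @param invocation the invocation details
   * @return the result of the invocation
   * @throws Throwable on failure
   */
  Object intercept(Invocation invocation) throws Throwable;

  /**
   * Wrap the target in a proxy
   *
   * @param target
   * @return
   */
  Object plugin(Object target);

  /**
   * Initialize the Interceptor from its configured properties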
   * @param properties
   */
  void setProperties(Properties properties);

}

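Annotate the interceptor and declare its signature

@Intercepts(@Signature(
        type = StatementHandler.class,  // which of the four core objects to intercept
        method = "prepare",             // which method on that object to intercept
        args = {Connection.class, Integer.class}   // parameter types of that method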
))

Hand-written pagination plugin

Passes the paging parameters through a ThreadLocal and intercepts StatementHandler

@Intercepts(@Signature(
        type = StatementHandler.class,
        method = "prepare",
        args = {Connection.class, Integer.class}
))
public class PagePlugin implements Interceptor {
    // Core logic of the plugin
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
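        /**
         * 1. Get the original SQL statement
         * 2. Rewrite the original SQL to add paging: select * from t_user limit 0,3
         * 3. Run a JDBC query to fetch the total count
         */
        // Get the StatementHandler from the invocation
        StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
        // Get the original SQL statement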
        BoundSql boundSql = statementHandler.getBoundSql();
        String sql = boundSql.getSql();

        // Wrap the statementHandler in a MetaObject for reflective property access
        MetaObject metaObject = SystemMetaObject.forObject(statementHandler);

        // spring context.getBean("userBean")
        MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
        // Get the statement id, which ends with the mapper method name, e.g. selectUserByPage
        String mapperMethodName = mappedStatement.getId();
        if (mapperMethodName.matches(".*ByPage")) {
            Page page = PageUtil.getPaingParam();
            //  select * from user;
            String countSql = "select count(0) from (" + sql + ") a";
            System.out.println("count sql: " + countSql);

            // Run the count query over JDBC
            Connection connection = (Connection) invocation.getArgs()[0];
            PreparedStatement countStatement = connection.prepareStatement(countSql);
            ParameterHandler parameterHandler = (ParameterHandler) metaObject.getValue("delegate.parameterHandler");
            parameterHandler.setParameters(countStatement);
            ResultSet rs = countStatement.executeQuery();
            if (rs.next()) {
                page.setTotalNumber(rs.getInt(1));
            }
            rs.close();
            countStatement.close();

            // Rewrite the SQL with a limit clause
            String pageSql = this.generaterPageSql(sql, page);
            System.out.println("paging sql: " + pageSql);

            // Write the rewritten SQL back
            metaObject.setValue("delegate.boundSql.sql", pageSql);

        }
        // Hand control back to MyBatis
        return invocation.proceed();
    }

    // Register this plugin with MyBatis by wrapping the target
    @Override
    public Object plugin(Object target) {
        return Plugin.wrap(target, this);
    }

    // Set properties
    @Override
    public void setProperties(Properties properties) {

    }

    // Build the SQL with a limit clause from the original SQL
    public String generaterPageSql(String sql, Page page) {

        StringBuilder sb = new StringBuilder();
        sb.append(sql);
        sb.append(" limit ").append(page.getStartIndex()).append(" , ").append(page.getTotalSelect());
        return sb.toString();
    }

}
@Data
@NoArgsConstructor
public class Page {
    public Page(int currentPage,int pageSize){
        this.currentPage=currentPage;
        this.pageSize=pageSize;
    }

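    private int totalNumber;// total number of rows in the table
    private int currentPage;// current page number

    private int totalPage;	// total number of pages
    private int pageSize = 3;// page size

    private int startIndex;	// offset of the first row to fetch
    private int totalSelect;// number of rows to fetch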

    public void setTotalNumber(int totalNumber) {
        this.totalNumber = totalNumber;
        // Recompute the derived fields
        this.count();
    }

    public void count() {
        int totalPageTemp = this.totalNumber / this.pageSize;
        int plus = (this.totalNumber % this.pageSize) == 0 ? 0 : 1;
        totalPageTemp = totalPageTemp + plus;
        if (totalPageTemp <= 0) {
            totalPageTemp = 1;
        }
        this.totalPage = totalPageTemp;// total number of pages

        if (this.totalPage < this.currentPage) {
            this.currentPage = this.totalPage;
        }
        if (this.currentPage < 1) {
            this.currentPage = 1;
        }
        this.startIndex = (this.currentPage - 1) * this.pageSize;// offset = number of preceding pages times the page size
        this.totalSelect = this.pageSize;// rows to fetch = page size
    }
}
@Data
public class PageResponse<T> {
    private int totalNumber;
    private int currentPage;
    private int totalPage;
    private int pageSize = 3;
    private T data;

}
public class PageUtil {
    private static final ThreadLocal<Page> LOCAL_PAGE = new ThreadLocal<Page>();

    public static void setPagingParam(int currentPage, int pageSize) {
        Page page = new Page(currentPage, pageSize);
        LOCAL_PAGE.set(page);
    }

    public static void removePagingParam() {
        LOCAL_PAGE.remove();
    }

    public static Page getPaingParam() {
        return LOCAL_PAGE.get();
    }

}
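
The arithmetic in Page.count() is compact enough to check in isolation. The sketch below restates it as three pure functions; PageMathDemo and its method names are hypothetical, and the ceiling-division form is an equivalent rewrite of the divide-then-add-one logic rather than the article's exact code.

```java
public class PageMathDemo {

    // Number of pages needed for totalNumber rows; at least 1, as in Page.count().
    public static int totalPages(int totalNumber, int pageSize) {
        int pages = (totalNumber + pageSize - 1) / pageSize; // ceiling division
        return Math.max(pages, 1);
    }

    // Clamp the requested page into [1, totalPages] and return the zero-based row offset.
    public static int startIndex(int currentPage, int pageSize, int totalNumber) {
        int page = Math.min(Math.max(currentPage, 1), totalPages(totalNumber, pageSize));
        return (page - 1) * pageSize;
    }

    // Append a MySQL-style "limit offset, count" clause.
    public static String pageSql(String sql, int currentPage, int pageSize, int totalNumber) {
        return sql + " limit " + startIndex(currentPage, pageSize, totalNumber) + ", " + pageSize;
    }

    public static void main(String[] args) {
        // 10 rows with 3 per page gives 4 pages; page 2 starts at offset 3.
        System.out.println(totalPages(10, 3));                          // 4
        System.out.println(pageSql("select * from t_user", 2, 3, 10)); // select * from t_user limit 3, 3
        // A page past the end is clamped to the last page (offset 9).
        System.out.println(startIndex(9, 3, 10));                       // 9
    }
}
```

Clamping an out-of-range page to the last page, as Page.count() does, means a stale page number still returns rows rather than an empty result.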

Hand-written plugin for read/write splitting

Built on Spring's dynamic data source and ThreadLocal, intercepting Executor

@Intercepts({// MyBatis execution flow
        @Signature(type = Executor.class, method = "update", args = { MappedStatement.class, Object.class }),
        @Signature(type = Executor.class, method = "query", args = { MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class })
})
@Slf4j
public class DynamicPlugin implements Interceptor {
    private static final Map<String, String> cacheMap = new ConcurrentHashMap<>();

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Object[] objects = invocation.getArgs();
        MappedStatement ms = (MappedStatement) objects[0];

        String dynamicDataSource = null;

        if ((dynamicDataSource = cacheMap.get(ms.getId())) == null) {
            // Read methods
            if (ms.getSqlCommandType().equals(SqlCommandType.SELECT)) { // select * from user;    update insert
                // A selectKey statement fetches the auto-increment key (SELECT LAST_INSERT_ID()), so it must use the primary
                if (ms.getId().contains(SelectKeyGenerator.SELECT_KEY_SUFFIX)) {
                    dynamicDataSource = "write";
                } else {
                    // Load balancing across several read replicas would go here
                    dynamicDataSource = "read";
                }
            } else {
                dynamicDataSource = "write";
            }

            log.info("Method [{}] uses data source [{}], SqlCommandType [{}]", ms.getId(), dynamicDataSource, ms.getSqlCommandType().name());
            // Cache id (method name) -> data source so the next call is a direct hit
            cacheMap.put(ms.getId(), dynamicDataSource);
        }
        // Set the data source for the current thread
        DynamicDataSourceHolder.putDataSource(dynamicDataSource);

        return invocation.proceed();
    }

    @Override
    public Object plugin(Object target) {
        if (target instanceof Executor) {
            return Plugin.wrap(target, this);
        } else {
            return target;
        }
    }

    @Override
    public void setProperties(Properties properties) {
    }
}
public final class DynamicDataSourceHolder {

    // Use a ThreadLocal to record the data source key of the current thread
    private static final ThreadLocal<String> holder = new ThreadLocal<String>();

    public static void putDataSource(String name){
        holder.set(name);
    }

    public static String getDataSource(){
        return holder.get();
    }

    /**
     * Clear the data source key
     */
    public static void clearDataSource() {
        holder.remove();
    }

}
public class DynamicDataSource extends AbstractRoutingDataSource {

    @Override
    protected Object determineCurrentLookupKey() {
      return  DynamicDataSourceHolder.getDataSource();
    }


}
@Configuration
public class DataSourceConfig {
    
    @Value("${spring.datasource.db01.jdbcUrl}")
    private String db01Url;
    @Value("${spring.datasource.db01.username}")
    private String db01Username;
    @Value("${spring.datasource.db01.password}")
    private String db01Password;
    @Value("${spring.datasource.db01.driverClassName}")
    private String db01DiverClassName;

    @Bean("dataSource01")
    public DataSource dataSource01(){
        HikariDataSource dataSource01 = new HikariDataSource();
        dataSource01.setJdbcUrl(db01Url);
        dataSource01.setDriverClassName(db01DiverClassName);
        dataSource01.setUsername(db01Username);
        dataSource01.setPassword(db01Password);
        return dataSource01;
    }

    @Value("${spring.datasource.db02.jdbcUrl}")
    private String db02Url;
    @Value("${spring.datasource.db02.username}")
    private String db02Username;
    @Value("${spring.datasource.db02.password}")
    private String db02Password;
    @Value("${spring.datasource.db02.driverClassName}")
    private String db02DiverClassName;

    @Bean("dataSource02")
    public DataSource dataSource02(){
        HikariDataSource dataSource02 = new HikariDataSource();
        dataSource02.setJdbcUrl(db02Url);
        dataSource02.setDriverClassName(db02DiverClassName);
        dataSource02.setUsername(db02Username);
        dataSource02.setPassword(db02Password);
        return dataSource02;
    }
    @Bean("multipleDataSource")
    public DataSource multipleDataSource(@Qualifier("dataSource01") DataSource dataSource01,
                                         @Qualifier("dataSource02") DataSource dataSource02) {
        Map<Object, Object> datasources = new HashMap<Object, Object>();
        datasources.put("write", dataSource01);
        datasources.put("read", dataSource02);
        DynamicDataSource multipleDataSource = new DynamicDataSource();
        multipleDataSource.setDefaultTargetDataSource(dataSource01);
        multipleDataSource.setTargetDataSources(datasources);
        return multipleDataSource;
    }

}
public class DynamicDataSourceTransactionManager extends DataSourceTransactionManager {
    private static final long serialVersionUID = 1L;

    public DynamicDataSourceTransactionManager(DataSource dataSource){
        super(dataSource);
    }

    /**
     * Routes read-only transactions to the read database and read-write transactions to the write database
     *
     * @param transaction
     * @param definition
     */
    @Override
    protected void doBegin(Object transaction, TransactionDefinition definition) {

        // Choose the data source before super.doBegin obtains the connection
        boolean readOnly = definition.isReadOnly();
        if (readOnly) {
            DynamicDataSourceHolder.putDataSource("read");
        } else {
            DynamicDataSourceHolder.putDataSource("write");
        }
        super.doBegin(transaction, definition);
    }

    /**
     * Clears the data source key from the current thread once the transaction completes
     *
     * @param transaction
     */
    @Override
    protected void doCleanupAfterCompletion(Object transaction) {
        super.doCleanupAfterCompletion(transaction);
        DynamicDataSourceHolder.clearDataSource();
    }
}
@Configuration
@MapperScan("com.example.dao")
@EnableTransactionManagement
public class MybatisConfig implements TransactionManagementConfigurer {

    private static String mybatisConfigPath = "mybatis-config.xml";

    @Autowired
    @Qualifier("multipleDataSource")
    private DataSource multipleDataSource;

    @Bean("sqlSessionFactoryBean")
    public SqlSessionFactory sqlSessionFactoryBean() throws Exception {
        SqlSessionFactoryBean bean = new SqlSessionFactoryBean();
        bean.setDataSource(multipleDataSource);
        bean.setTypeAliasesPackage("com.example.entity");
        bean.setConfigLocation(new ClassPathResource(mybatisConfigPath));
        ResourcePatternResolver resolver = new PathMatchingResourcePatternResolver();
        bean.setMapperLocations(resolver.getResources("classpath*:mapper/*.xml"));
        return bean.getObject();

    }

    @Override
    public PlatformTransactionManager annotationDrivenTransactionManager() {
        return new DynamicDataSourceTransactionManager(multipleDataSource);
    }
}
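
The routing rule used by the interceptor above can be tried out in isolation. The following is a minimal sketch, not the plugin itself: `CommandKind` is a hypothetical stand-in for MyBatis's `SqlCommandType` so the code runs without MyBatis on the classpath, and the statement ids are made up for illustration.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical stand-in for org.apache.ibatis.mapping.SqlCommandType.
enum CommandKind { SELECT, INSERT, UPDATE, DELETE }

final class RoutingSketch {
    // Suffix that MyBatis appends to the id of a generated selectKey statement.
    static final String SELECT_KEY_SUFFIX = "!selectKey";

    // statement id -> data source key, filled lazily on first use
    private static final Map<String, String> CACHE = new ConcurrentHashMap<>();

    static String route(String statementId, CommandKind kind) {
        // computeIfAbsent performs the lookup and the insert in one atomic step
        return CACHE.computeIfAbsent(statementId, id -> {
            boolean plainSelect = kind == CommandKind.SELECT && !id.contains(SELECT_KEY_SUFFIX);
            return plainSelect ? "read" : "write";
        });
    }

    public static void main(String[] args) {
        System.out.println(route("com.example.dao.UserMapper.findById", CommandKind.SELECT));         // read
        System.out.println(route("com.example.dao.UserMapper.insert!selectKey", CommandKind.SELECT)); // write
        System.out.println(route("com.example.dao.UserMapper.insert", CommandKind.INSERT));           // write
    }
}
```

Because the decision is cached per statement id, it is only computed once for each mapper method, which mirrors the role of cacheMap in the interceptor.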

MyBatis second-level cache

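MyBatis enables the first-level cache by default. It is scoped to a single SqlSession, so it is of little use in most real-world scenarios. The second-level cache is not active by default; enable it as follows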

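1. Add the following to the global configuration file (cacheEnabled defaults to true, so this step only makes the setting explicit)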

 <settings>
	<setting name="cacheEnabled" value="true" />
 </settings>

2. Add the following to each mapper.xml that should use the second-level cache

<mapper namespace="com.study.mybatis.mapper.UserMapper">
	<!-- Enable the second-level cache for this mapper's namespace -->
	<cache eviction="LRU" flushInterval="100000" readOnly="true" size="1024"/>
</mapper>
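    <!--eviction: the cache eviction policy. MyBatis currently provides the following policies.
        (1) LRU: least recently used; removes the objects that have gone unused the longest
        (2) FIFO: first in, first out; removes objects in the order they entered the cache
        (3) SOFT: soft reference; removes objects based on garbage collector state and soft-reference rules
        (4) WEAK: weak reference; more aggressively removes objects based on garbage collector state and weak-reference rules. This example uses LRU
        flushInterval: the flush interval in milliseconds; this example flushes every 100 seconds. If it is not set,
        the cache is flushed only when a statement that flushes it (insert, update or delete by default) is executed.
        size: the maximum number of object references the cache may hold, a positive integer. Do not set it too high, or the cache may exhaust memory.
        This example allows 1024 objects
        readOnly: when true, every caller receives the same cached instance, which is fast, but the returned objects must not be modified. When false (the default), callers receive a copy of the cached object made through serialization, which is slower but safe to modify
    -->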
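
To make the LRU policy concrete, here is a minimal sketch of least-recently-used eviction built on the JDK's `LinkedHashMap` in access order. It illustrates the policy only and is not MyBatis's own cache class; the class name `LruSketch` and the capacity of 2 are assumptions chosen for the demo.

```java
import java.util.LinkedHashMap;
import java.util.Map;

final class LruSketch<K, V> extends LinkedHashMap<K, V> {
    private final int capacity;

    LruSketch(int capacity) {
        // accessOrder = true: iteration order runs from least to most recently accessed
        super(16, 0.75f, true);
        this.capacity = capacity;
    }

    @Override
    protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
        // Called after every put; returning true evicts the least recently used entry
        return size() > capacity;
    }

    public static void main(String[] args) {
        LruSketch<String, Integer> cache = new LruSketch<>(2);
        cache.put("a", 1);
        cache.put("b", 2);
        cache.get("a");    // "a" becomes the most recently used entry
        cache.put("c", 3); // capacity exceeded, so "b" is evicted
        System.out.println(cache.keySet()); // [a, c]
    }
}
```

The size attribute of the cache element plays the role of capacity here: once the limit is reached, each new entry pushes out the entry that has gone unused the longest.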

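3. For distributed applications, you can add the mybatis-redis dependency to get a distributed cache backed by Redis

<cache type="org.mybatis.caches.redis.RedisCache" />

4. You can also write a custom cache; MyBatis provides the Cache interface for this purpose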
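
As a rough illustration of step 4, the sketch below shows the shape of a map-backed custom cache. To stay compilable without MyBatis on the classpath it declares a local stand-in interface, `SimpleCache`, which is an assumption that mirrors the core methods of `org.apache.ibatis.cache.Cache`; in a real project the class would implement the MyBatis interface instead. `MapBackedCache` is likewise a hypothetical name.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Local stand-in mirroring the core methods of org.apache.ibatis.cache.Cache (assumption for this sketch).
interface SimpleCache {
    String getId();
    void putObject(Object key, Object value);
    Object getObject(Object key);
    Object removeObject(Object key);
    void clear();
    int getSize();
}

final class MapBackedCache implements SimpleCache {
    private final String id;
    private final Map<Object, Object> store = new ConcurrentHashMap<>();

    // MyBatis creates cache instances through a constructor that takes the cache id (the mapper namespace).
    MapBackedCache(String id) {
        if (id == null) {
            throw new IllegalArgumentException("Cache instances require an id");
        }
        this.id = id;
    }

    @Override public String getId() { return id; }

    @Override public void putObject(Object key, Object value) {
        // ConcurrentHashMap rejects null values, so a null value is treated as "no entry"
        if (value == null) {
            store.remove(key);
        } else {
            store.put(key, value);
        }
    }

    @Override public Object getObject(Object key) { return store.get(key); }
    @Override public Object removeObject(Object key) { return store.remove(key); }
    @Override public void clear() { store.clear(); }
    @Override public int getSize() { return store.size(); }

    public static void main(String[] args) {
        MapBackedCache cache = new MapBackedCache("com.example.dao.UserMapper");
        cache.putObject("user:1", "Alice");
        System.out.println(cache.getObject("user:1")); // Alice
        System.out.println(cache.getSize());           // 1
    }
}
```

A class of this shape would then be referenced from the mapper through the type attribute of the cache element, in the same way as the RedisCache example in step 3.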

Custom MyBatis type handlers

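Type handlers are typically used for uniform conversion of special fields, encryption of sensitive fields, and similar tasks. Usage is as follows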

public class MyTypeHandler implements TypeHandler<String> {

    /**
     * Sets the parameter on the PreparedStatement, encoding the value before it is stored in the database.
     *
     * @param ps        the prepared statement
     * @param i         the 1-based parameter index
     * @param parameter the plain-text value
     * @param jdbcType  the JDBC type of the column
     * @throws SQLException if the parameter cannot be set
     */
    @Override
    public void setParameter(PreparedStatement ps, int i, String parameter, JdbcType jdbcType) throws SQLException {
        // Pass null through unchanged; calling getBytes() on it would throw a NullPointerException
        ps.setString(i, parameter == null ? null : EncryptUtil.encode(parameter.getBytes()));
    }

    // Results can be read by column name, by column index, or from a CallableStatement.
    @Override
    public String getResult(ResultSet rs, String columnName) throws SQLException {
        return decode(rs.getString(columnName));
    }

    @Override
    public String getResult(ResultSet rs, int columnIndex) throws SQLException {
        return decode(rs.getString(columnIndex));
    }

    @Override
    public String getResult(CallableStatement cs, int columnIndex) throws SQLException {
        return decode(cs.getString(columnIndex));
    }

    // Decodes non-empty values; null and empty strings are returned as-is.
    // isEmpty() replaces the original result != "" check, which compared references rather than content.
    private static String decode(String value) {
        if (value == null || value.isEmpty()) {
            return value;
        }
        return EncryptUtil.decode(value.getBytes());
    }
}
public class EncryptUtil {

    // Base64 decode
    public static String decode(byte[] bytes) {
        return new String(Base64.decodeBase64(bytes));
    }

    // Base64 encode
    public static String encode(byte[] bytes) {
        return new String(Base64.encodeBase64(bytes));
    }
}
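
The round trip performed by EncryptUtil can be checked with the JDK alone. The sketch below uses `java.util.Base64` as a stand-in for the commons-codec `Base64` class used above, and the class name `Base64Sketch` is hypothetical. Note that Base64 is a reversible encoding rather than encryption, so a real handler for sensitive fields would need an actual cipher.

```java
import java.nio.charset.StandardCharsets;
import java.util.Base64;

final class Base64Sketch {
    // Encodes plain text to Base64, using an explicit charset instead of the platform default
    static String encode(String plain) {
        return Base64.getEncoder().encodeToString(plain.getBytes(StandardCharsets.UTF_8));
    }

    // Decodes Base64 text back to the original string
    static String decode(String encoded) {
        return new String(Base64.getDecoder().decode(encoded), StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        String stored = encode("secret");
        System.out.println(stored);         // c2VjcmV0
        System.out.println(decode(stored)); // secret
    }
}
```

Anyone who can read the column can decode the value, which is why this only obscures the data.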

Register the type handler in the MyBatis configuration file

  <typeHandlers>
        <typeHandler handler="com.example.typehandler.MyTypeHandler"/>
  </typeHandlers>

Specify the type handler on the fields that need it

  <resultMap id="resultListUser" type="com.example.entity.User" >
	<result column="password" property="password" typeHandler="com.example.typehandler.MyTypeHandler" />
  </resultMap>

  <update id="updateUser" parameterType="com.example.entity.User">
	UPDATE user SET userName=#{userName, typeHandler=com.example.typehandler.MyTypeHandler} WHERE id=#{id}
  </update>

How Spring integrates with MyBatis

The integration steps are as follows

1. Add the dependency

<dependency>
  <groupId>org.mybatis</groupId>
  <artifactId>mybatis-spring</artifactId>
  <version>1.2.2</version>
</dependency>

2. Add the following to the Spring configuration file

<bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
    <!-- Data source -->
    <property name="dataSource" ref="dataSource"/>
    <property name="mapperLocations" value="classpath*:mappers/*Mapper.xml"/>
</bean>
 
<bean class="org.mybatis.spring.mapper.MapperScannerConfigurer">
    <!-- Package to scan; separate multiple packages with commas -->
    <property name="basePackage" value="com.test.bean"/>
    <property name="sqlSessionFactoryBeanName" value="sqlSessionFactory"/>
</bean>

As the configuration shows, there are two core classes: MapperScannerConfigurer and SqlSessionFactoryBean.

MapperScannerConfigurer

MapperScannerConfigurer is declared as follows

public class MapperScannerConfigurer implements BeanDefinitionRegistryPostProcessor, InitializingBean, ApplicationContextAware, BeanNameAware {
	//code omitted
}

BeanDefinitionRegistryPostProcessor is declared as follows

public interface BeanDefinitionRegistryPostProcessor extends BeanFactoryPostProcessor {
	void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException;

}

The key method of MapperScannerConfigurer is postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry)

public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) {
    if (this.processPropertyPlaceHolders) {
      processPropertyPlaceHolders();
    }

    ClassPathMapperScanner scanner = new ClassPathMapperScanner(registry);
    scanner.setAddToConfig(this.addToConfig);
    scanner.setAnnotationClass(this.annotationClass);
    scanner.setMarkerInterface(this.markerInterface);
    scanner.setSqlSessionFactory(this.sqlSessionFactory);
    scanner.setSqlSessionTemplate(this.sqlSessionTemplate);
    scanner.setSqlSessionFactoryBeanName(this.sqlSessionFactoryBeanName);
    scanner.setSqlSessionTemplateBeanName(this.sqlSessionTemplateBeanName);
    scanner.setResourceLoader(this.applicationContext);
    scanner.setBeanNameGenerator(this.nameGenerator);
    scanner.registerFilters();
    scanner.scan(StringUtils.tokenizeToStringArray(this.basePackage, ConfigurableApplicationContext.CONFIG_LOCATION_DELIMITERS));
  }

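This method creates a ClassPathMapperScanner, a scanner from the mybatis-spring jar that extends Spring's ClassPathBeanDefinitionScanner. The scanner does the following:

1. Scans all classes under the basePackage package

2. Wraps each class in a Spring ScannedGenericBeanDefinition (sbd) object

3. Filters the sbd objects, keeping only interfaces

4. Sets properties on each sbd object, such as sqlSessionFactory and the bean class (note that the bean class is set to the class of MapperFactoryBean, which matters later). The code is as follows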

ClassPathMapperScanner
private void processBeanDefinitions(Set<BeanDefinitionHolder> beanDefinitions) {
    GenericBeanDefinition definition;
    for (BeanDefinitionHolder holder : beanDefinitions) {
      definition = (GenericBeanDefinition) holder.getBeanDefinition();
      definition.getConstructorArgumentValues().addGenericArgumentValue(definition.getBeanClassName()); 
      definition.setBeanClass(this.mapperFactoryBean.getClass());
      definition.getPropertyValues().add("addToConfig", this.addToConfig);
      boolean explicitFactoryUsed = false;
      if (StringUtils.hasText(this.sqlSessionFactoryBeanName)) {
        definition.getPropertyValues().add("sqlSessionFactory", new RuntimeBeanReference(this.sqlSessionFactoryBeanName));
        explicitFactoryUsed = true;
      } else if (this.sqlSessionFactory != null) {
        definition.getPropertyValues().add("sqlSessionFactory", this.sqlSessionFactory);
        explicitFactoryUsed = true;
      }

      if (StringUtils.hasText(this.sqlSessionTemplateBeanName)) {
        if (explicitFactoryUsed) {
          logger.warn("Cannot use both: sqlSessionTemplate and sqlSessionFactory together. sqlSessionFactory is ignored.");
        }
        definition.getPropertyValues().add("sqlSessionTemplate", new RuntimeBeanReference(this.sqlSessionTemplateBeanName));
        explicitFactoryUsed = true;
      } else if (this.sqlSessionTemplate != null) {
        if (explicitFactoryUsed) {
          logger.warn("Cannot use both: sqlSessionTemplate and sqlSessionFactory together. sqlSessionFactory is ignored.");
        }
        definition.getPropertyValues().add("sqlSessionTemplate", this.sqlSessionTemplate);
        explicitFactoryUsed = true;
      }

      if (!explicitFactoryUsed) {
        if (logger.isDebugEnabled()) {
          logger.debug("Enabling autowire by type for MapperFactoryBean with name '" + holder.getBeanName() + "'.");
        }
        definition.setAutowireMode(AbstractBeanDefinition.AUTOWIRE_BY_TYPE);
      }
    }
  }

5. Registers the filtered sbd objects in the DefaultListableBeanFactory through the BeanDefinitionRegistry registry, which is the parameter of postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry).

That is the main work done when MapperScannerConfigurer is instantiated. In summary, it scans all mapper interfaces under basePackage, wraps each one in a BeanDefinition, and registers it in Spring's BeanFactory.
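
The key trick in processBeanDefinitions is that each definition first records the mapper interface name as a constructor argument and only then has its bean class replaced with MapperFactoryBean. The sketch below models just that step with plain Java objects. `BeanDefSketch` and `ScannerSketch` are hypothetical stand-ins, not Spring or mybatis-spring classes.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, heavily simplified stand-in for a Spring bean definition.
final class BeanDefSketch {
    String beanClassName;
    final List<String> constructorArgs = new ArrayList<>();

    BeanDefSketch(String beanClassName) {
        this.beanClassName = beanClassName;
    }
}

final class ScannerSketch {
    static final String FACTORY_CLASS = "org.mybatis.spring.mapper.MapperFactoryBean";

    static void process(List<BeanDefSketch> definitions) {
        for (BeanDefSketch definition : definitions) {
            // Order matters: keep the mapper interface name before overwriting the bean class
            definition.constructorArgs.add(definition.beanClassName);
            definition.beanClassName = FACTORY_CLASS;
        }
    }

    public static void main(String[] args) {
        List<BeanDefSketch> definitions = new ArrayList<>();
        definitions.add(new BeanDefSketch("com.example.dao.UserMapper"));
        process(definitions);
        BeanDefSketch d = definitions.get(0);
        System.out.println(d.beanClassName + " <- " + d.constructorArgs);
    }
}
```

After this step the container no longer tries to instantiate the interface itself; it instantiates the factory and passes the interface to it.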

SqlSessionFactoryBean

Class declaration

public class SqlSessionFactoryBean implements FactoryBean<SqlSessionFactory>, InitializingBean, ApplicationListener<ApplicationEvent> {
    //code omitted
}

While this bean is being created, the first method called is afterPropertiesSet, which is defined by the InitializingBean interface.

public void afterPropertiesSet() throws Exception {
    Assert.notNull(this.dataSource, "Property 'dataSource' is required");
    Assert.notNull(this.sqlSessionFactoryBuilder, "Property 'sqlSessionFactoryBuilder' is required");
    this.sqlSessionFactory = this.buildSqlSessionFactory();
}

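The buildSqlSessionFactory() method does a lot of work. In summary, it creates several objects in turn: the MyBatis core class Configuration, the SpringManagedTransactionFactory transaction factory that bridges Spring and MyBatis, the MyBatis Environment, and the MyBatis DefaultSqlSessionFactory. It also parses the MyBatis XML files and stores the parsed results in the Configuration.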

protected SqlSessionFactory buildSqlSessionFactory() throws IOException {
    XMLConfigBuilder xmlConfigBuilder = null;
    Configuration configuration;
    configuration = new Configuration();
    configuration.setVariables(this.configurationProperties);
    //code omitted
    if (this.transactionFactory == null) {
        this.transactionFactory = new SpringManagedTransactionFactory();
    }
 
    Environment environment = new Environment(this.environment, this.transactionFactory, this.dataSource);
    configuration.setEnvironment(environment);
    
    if (!ObjectUtils.isEmpty(this.mapperLocations)) {
        // Parse every mapper XML file and store the result in the Configuration
        for (Resource mapperLocation : this.mapperLocations) {
            if (mapperLocation != null) {
                try {
                    XMLMapperBuilder xmlMapperBuilder = new XMLMapperBuilder(mapperLocation.getInputStream(), configuration, mapperLocation.toString(), configuration.getSqlFragments());
                    xmlMapperBuilder.parse();
                } catch (Exception e) {
                    throw new NestedIOException("Failed to parse mapping resource: '" + mapperLocation + "'", e);
                } finally {
                    ErrorContext.instance().reset();
                }
            }
        }
    }
 
    return this.sqlSessionFactoryBuilder.build(configuration);
}

sqlSessionFactoryBuilder.build(configuration) does something very simple: it creates a DefaultSqlSessionFactory, which holds a reference to the Configuration object.

SqlSessionFactoryBean implements the FactoryBean interface, so Spring calls its getObject method during initialization to obtain the actual bean.

public SqlSessionFactory getObject() throws Exception {
    if (this.sqlSessionFactory == null) {
        this.afterPropertiesSet();
    }
    return this.sqlSessionFactory;
}

As the code shows, it simply returns the SqlSessionFactory instance, first calling afterPropertiesSet to initialize sqlSessionFactory if it is still null.

Instantiation process

Take the following code as an example

@Service
public class UserServiceImpl implements IUserService {
	@Resource
	private UserMapper userMapper;
 
	public UserServiceImpl(){
	}
}

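During initialization, Spring creates the UserServiceImpl instance and then populates its properties. UserMapper, a MyBatis mapper interface, is one of those properties. Spring first looks up the mapper's BeanDefinition by name in the BeanFactory and then reads the bean class from that BeanDefinition. For UserMapper the bean class is MapperFactoryBean, as noted in the analysis above.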

public class MapperFactoryBean<T> extends SqlSessionDaoSupport implements FactoryBean<T> {

  private Class<T> mapperInterface;

  private boolean addToConfig = true;

  public MapperFactoryBean() {
    //intentionally empty 
  }
  
  public MapperFactoryBean(Class<T> mapperInterface) {
    this.mapperInterface = mapperInterface;
  }

  @Override
  protected void checkDaoConfig() {
    super.checkDaoConfig();

    notNull(this.mapperInterface, "Property 'mapperInterface' is required");

    Configuration configuration = getSqlSession().getConfiguration();
    if (this.addToConfig && !configuration.hasMapper(this.mapperInterface)) {
      try {
        configuration.addMapper(this.mapperInterface);
      } catch (Exception e) {
        logger.error("Error while adding the mapper '" + this.mapperInterface + "' to configuration.", e);
        throw new IllegalArgumentException(e);
      } finally {
        ErrorContext.instance().reset();
      }
    }
  }

  @Override
  public T getObject() throws Exception {
    return getSqlSession().getMapper(this.mapperInterface);
  }

  @Override
  public Class<T> getObjectType() {
    return this.mapperInterface;
  }

  @Override
  public boolean isSingleton() {
    return true;
  }

  public void setMapperInterface(Class<T> mapperInterface) {
    this.mapperInterface = mapperInterface;
  }

  public Class<T> getMapperInterface() {
    return mapperInterface;
  }

  public void setAddToConfig(boolean addToConfig) {
    this.addToConfig = addToConfig;
  }

  public boolean isAddToConfig() {
    return addToConfig;
  }
}

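Next, Spring creates the MapperFactoryBean object and then populates its properties. Which properties are set was decided when the BeanDefinition for UserMapper was created, as described in the MapperScannerConfigurer section above. One of those properties refers to the SqlSessionFactoryBean, which triggers the creation of the SqlSessionFactoryBean.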

Once the properties of the MapperFactoryBean are set, its getObject() method is called. What follows is the MyBatis proxy creation process described earlier.
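
The last step, a factory whose getObject() hands back a proxy for the mapper interface, can be sketched with the JDK's dynamic proxy support. This is an illustration under assumptions, not the mybatis-spring code: `DemoUserMapper`, `MapperFactorySketch` and the "executed ..." return value are hypothetical, and a real handler would delegate to the SqlSession instead of returning a string.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

// Hypothetical mapper interface used only for this demo.
interface DemoUserMapper {
    String findNameById(int id);
}

final class MapperFactorySketch<T> {
    private final Class<T> mapperInterface;

    MapperFactorySketch(Class<T> mapperInterface) {
        this.mapperInterface = mapperInterface;
    }

    // Plays the role of getObject(): returns a proxy, not an instance of a hand-written class
    T getObject() {
        Object proxy = Proxy.newProxyInstance(
                mapperInterface.getClassLoader(),
                new Class<?>[] {mapperInterface},
                new Handler(mapperInterface));
        return mapperInterface.cast(proxy);
    }

    private static final class Handler implements InvocationHandler {
        private final Class<?> mapperInterface;

        Handler(Class<?> mapperInterface) {
            this.mapperInterface = mapperInterface;
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            // Methods inherited from Object (toString, hashCode, equals) run on the handler itself
            if (method.getDeclaringClass() == Object.class) {
                return method.invoke(this, args);
            }
            // Interface name plus method name forms the statement id a real handler would execute
            String statementId = mapperInterface.getName() + "." + method.getName();
            return "executed " + statementId;
        }
    }

    public static void main(String[] args) {
        DemoUserMapper mapper = new MapperFactorySketch<>(DemoUserMapper.class).getObject();
        System.out.println(mapper.findNameById(1));
    }
}
```

This is why a field typed as the mapper interface can be injected even though no implementation class exists: the injected object is the proxy produced by the factory.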

Source code

https://gitee.com/junlin1991/mybatis-learn.git

posted on 2022-01-17 20:13  请叫我西毒