手写开源ORM框架介绍

手写开源ORM框架介绍

简介

前段时间利用空闲时间,参照mybatis的基本思路手写了一个ORM框架。一直没有时间去补充相应的文档,现在正好抽时间去整理下。通过思路历程和代码注释,一方面重温下知识,另一方面准备后期去完善这个框架。

传统JDBC连接

参照传统的JDBC连接数据库过程如下,框架所做的事情就是把这些步骤进行封装。

// 1. 注册 JDBC 驱动
Class.forName(JDBC_DRIVER);
// 2. 打开链接
conn = DriverManager.getConnection(DB_URL,USER,PASS);
// 3. 实例化statement
stmt = conn.prepareStatement(sql);
// 4. 填充数据
stmt.setString(1,id);
// 5. 执行Sql连接
ResultSet rs = stmt.executeQuery();
// 6. 结果查询
while(rs.next()){
    // 7. 通过字段检索
    int id  = rs.getInt("id");
    String name = rs.getString("name");
    String url = rs.getString("url");
}

思路历程

上述是执行Java连接数据库通用语句,ORM框架就是对象关系映射框架。我们需要做的重点是步骤4和7,将对象字段填充至数据库,将数据库的内容取出作为对象返回。了解了基本的方法后,开始动手写一个框架了。

数据库驱动

开始,需要构建一个驱动注册类,这个类主要用来对数据库驱动进行构造。详见代码块1

代码块1

数据库驱动类(com.simple.ibatis.dirver)

/**
 * @Author  xiabing
 * @Desc    驱动注册中心,对数据库驱动进行注册的地方
 **/
public class DriverRegister {

    /*
     * mysql的driver类
     * */
    private static final String MYSQLDRIVER = "com.mysql.jdbc.Driver";

    /*
     * 构建driver缓存,存储已经注册了的driver的类型
     * */
    private static final Map<String,Driver> registerDrivers = new ConcurrentHashMap<>();

    /*
    *   初始化,此处是将DriverManager中已经注册了驱动放入自己缓存中。当然,你也可以在这个方法内注册
        常见的数据库驱动,这样后续就可以直接使用,不用自己手动注册了。
    * */
    static {
        Enumeration<Driver> driverEnumeration = DriverManager.getDrivers();
        while (driverEnumeration.hasMoreElements()){
            Driver driver = driverEnumeration.nextElement();
            registerDrivers.put(driver.getClass().getName(),driver);
        }
    }

    /*
     *  加载mysql驱动,此个方法可以写在静态代码块内,代表项目启动时即注册mysql驱动
     * */
    public void loadMySql(){
       if(! registerDrivers.containsKey(MYSQLDRIVER)){
           loadDriver(MYSQLDRIVER);
       }
    }

    /*
     *  加载数据库驱动通用方法,并注册到registerDrivers缓存中,此处是注册数据库驱动的方法
     * */
    public void loadDriver(String driverName){

        Class<?> driverType;
        try {
            // 注册驱动,返回驱动类
            driverType = Class.forName(driverName);
            // 将驱动实例放入驱动缓存中
            registerDrivers.put(driverType.getName(),(Driver)driverType.newInstance());
        }catch (ClassNotFoundException e){
            throw new RuntimeException(e);
        }catch (Exception e){
            throw new RuntimeException(e);
        }
    }
}

驱动注册后,需要建立数据库连接。但是框架如果一个请求建一个连接,那木势必耗费大量的资源。所以该框架引入池化的概念,对连接进行池化管理。详情见代码2

代码块2

数据源类(com.simple.ibatis.datasource)

PoolConnection

/**
 * @Author  xiabing
 * @Desc    连接代理类,是对Connection的一个封装。除了提供基本的连接外,还想记录
            这个连接的连接时间,因为有的连接如果一直连接不释放,那木我可以通过查看
            这个连接已连接的时间,如果超时了,我可以主动释放。
 **/
public class PoolConnection {
    
    // 真实的数据库连接
    public Connection connection;
    
    // 数据开始连接时间
    private Long CheckOutTime;
    
    // 连接的hashCode
    private int hashCode = 0;

    public PoolConnection(Connection connection) {
        this.connection = connection;
        this.hashCode = connection.hashCode();
    }

    public Long getCheckOutTime() {
        return CheckOutTime;
    }

    public void setCheckOutTime(Long checkOutTime) {
        CheckOutTime = checkOutTime;
    }
    
    // 判断两个PoolConnection对象是否相等,其实是判断其中真实连接的hashCode是否相等
    @Override
    public boolean equals(Object obj) {
        if(obj instanceof  PoolConnection){
            return connection.hashCode() ==
                    ((PoolConnection) obj).connection.hashCode();
        }else if(obj instanceof Connection){
            return obj.hashCode() == hashCode;
        }else {
            return false;
        }
    }

    @Override
    public int hashCode() {
        return hashCode;
    }
}

NormalDataSource

/**
 * @Author  xiabing
 * @Desc    普通数据源,这个数据源是用来产生数据库连接的。来一个请求就会建立一个数据库连接,没有池化的概念
 **/
public class NormalDataSource implements DataSource{
    
    // 驱动名称
    private String driverClassName;
    
    // 数据库连接URL
    private String url;
    
    // 数据库用户名
    private String userName;
    
    // 数据库密码
    private String passWord;

    // 驱动注册中心
    private DriverRegister driverRegister = new DriverRegister();

    public NormalDataSource(String driverClassName, String url, String userName, String passWord) {
        // 初始化时将驱动进行注册
        this.driverRegister.loadDriver(driverClassName);
        this.driverClassName = driverClassName;
        this.url = url;
        this.userName = userName;
        this.passWord = passWord;
    }
    
    // 获取数据库连接
    @Override
    public Connection getConnection() throws SQLException {
        return DriverManager.getConnection(url,userName,passWord);

    }
    
    // 移除数据库连接,此方法没有真正移除,只是将连接中未提交的事务进行回滚操作。
    public void removeConnection(Connection connection) throws SQLException{
        if(!connection.getAutoCommit()){
            connection.rollback();
        }
    }
    
    // 获取数据库连接
    @Override
    public Connection getConnection(String username, String password) throws SQLException {
        return DriverManager.getConnection(url,username,password);
    }
    
    // 后续方法因为没有用到,所有没有进行重写了
    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return null;
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return false;
    }

    @Override
    public PrintWriter getLogWriter() throws SQLException {
        return null;
    }

    @Override
    public void setLogWriter(PrintWriter out) throws SQLException {

    }

    @Override
    public void setLoginTimeout(int seconds) throws SQLException {

    }

    @Override
    public int getLoginTimeout() throws SQLException {
        return 0;
    }

    @Override
    public Logger getParentLogger() throws SQLFeatureNotSupportedException {
        return null;
    }

}

PoolDataSource

/**
 * @Author  xiabing
 * @Desc    池化线程池,对连接进行管理
 **/
public class PoolDataSource implements DataSource{

    private Integer maxActiveConnectCount = 10; // 最大活跃线程数,此处可根据实际情况自行设置

    private Integer maxIdleConnectCount = 10; // 最大空闲线程数,此处可根据实际情况自行设置

    private Long maxConnectTime = 30*1000L; // 连接最长使用时间,在自定义连接中配置了连接开始时间,在这里来判断该连接是否超时

    private Integer waitTime = 2000; // 线程wait等待时间

    private NormalDataSource normalDataSource; // 使用normalDataSource来产生连接

    private Queue<PoolConnection> activeConList = new LinkedList<>(); // 存放活跃连接的队列

    private Queue<PoolConnection> idleConList = new LinkedList<>(); // 存放空闲连接的队列

    public PoolDataSource(String driverClassName, String url, String userName, String passWord) {

        this(driverClassName,url,userName,passWord,10,10);

    }

    public PoolDataSource(String driverClassName, String url, String userName, String passWord,Integer maxActiveConnectCount,Integer maxIdleConnectCount) {
        // 初始化normalDataSource,因为normalDataSource已经封装了新建连接的方法
        this.normalDataSource = new NormalDataSource(driverClassName,url,userName,passWord);
        this.maxActiveConnectCount = maxActiveConnectCount;
        this.maxIdleConnectCount = maxIdleConnectCount;

    }
    /**
     * @Desc 获取连接时先从空闲连接列表中获取。若没有,则判断现在活跃连接是否已超过设置的最大活跃连接数,没超过,new一个
     *        若超过,则判断第一个连接是否已超时,若超时,则移除掉在新建。若未超时,则wait()等待。
     **/
    @Override
    public Connection getConnection(){
        Connection connection = null;
        try {
            connection =  doGetPoolConnection().connection;
        }catch (SQLException e){
            throw new RuntimeException(e);
        }
        return connection;
    }

    public void removeConnection(Connection connection){
        PoolConnection poolConnection = new PoolConnection(connection);
        doRemovePoolConnection(poolConnection);
    }
    
    private PoolConnection doGetPoolConnection() throws SQLException{
        PoolConnection connection = null;
        while (connection == null){
            // 加锁
            synchronized (this){
                // 判断是否有空闲连接
                if(idleConList.size() < 1){ 
                    // 判断活跃连接数是否已经超过预设的最大活跃连接数
                    if(activeConList.size() < maxActiveConnectCount){
                        // 如果还可以新建连接,就新建一个连接
                        connection = new PoolConnection(normalDataSource.getConnection());
                    }else {
                        // 走到这一步,说明没有空闲连接,并且活跃连接已经满了,则只能看哪些连接超时,判断第一个连接是否超时
                        PoolConnection poolConnection = activeConList.peek();
                        // 这就是我为啥要给数据库连接加连接时间了
                        if(System.currentTimeMillis() - poolConnection.getCheckOutTime() > maxConnectTime){
                            // 若第一个连接已经超时了,移除第一个活跃连接
                            PoolConnection timeOutConnect = activeConList.poll();
                            if(!timeOutConnect.connection.getAutoCommit()){
                                // 如果该连接设的是非自动提交,则对事物进行回滚操作
                                timeOutConnect.connection.rollback();
                            }
                            // 置为空,让垃圾收集器去收集
                            timeOutConnect.connection.close();
                            timeOutConnect = null;
                            // 新建一个连接
                            connection = new PoolConnection(normalDataSource.getConnection());
                        }else {
                            // 走到这一步,代表所有连接都没有超时,只能等了
                            try{
                                // 于是等一会
                                this.wait(waitTime);
                            }catch (InterruptedException e){
                                // ignore错误,并退出
                                break;  
                            }
                        }
                    }
                }else {
                    // 这一步代表空闲连接队列里面有连接,则直接取出
                    connection = idleConList.poll();
                }
                if(connection != null){
                    // 设置这个连接的连接时间
                    connection.setCheckOutTime(System.currentTimeMillis());
                    // 然后放入活跃连接队列中
                    activeConList.add(connection);
                }
            }
        }

        return connection;
    }
    
    // 移除连接
    private void doRemovePoolConnection(PoolConnection connection){
        // 加锁
        synchronized (this){
            // 移除连接,这里需要重写PoolConnection的equal方法,因为只要真实连接相同,就可以代表这个封装的连接相等了
            activeConList.remove(connection);
            if(idleConList.size() < maxIdleConnectCount){
                // 加入空闲连接队列中
                idleConList.add(connection);
            }
            // 通知其他等待线程,让他们继续获取连接
            this.notifyAll();
        }
    }

    @Override
    public Connection getConnection(String username, String password) throws SQLException {
        return getConnection();
    }
    
    // 后面的方法没有用到就没有重写了
    @Override
    public <T> T unwrap(Class<T> iface) throws SQLException {
        return null;
    }

    @Override
    public boolean isWrapperFor(Class<?> iface) throws SQLException {
        return false;
    }

    @Override
    public PrintWriter getLogWriter() throws SQLException {
        return null;
    }

    @Override
    public void setLogWriter(PrintWriter out) throws SQLException {

    }

    @Override
    public void setLoginTimeout(int seconds) throws SQLException {

    }

    @Override
    public int getLoginTimeout() throws SQLException {
        return 0;
    }

    @Override
    public Logger getParentLogger() throws SQLFeatureNotSupportedException {
        return null;
    }
}

使用数据库连接池获取连接后,接下来就是关键了,生成SQL预处理语句,并进行数据填充。这让我想到了mybatis使用方法。

@Select("select * from users")
List<User> getUsers();

上面是mybatis基于注解的配置方式。除了注解,还支持xml。于是框架准备先使用基于注解的方式进行代码配置。首先,需要知道这些接口类的类名。于是提供一个配置,根据指定包名获取包下所有类的类名详情见代码3

代码块3

方法类(com.simple.ibatis.util)

PackageUtil

/**
 * @Author  xiabing
 * @Desc    获取指定包下所有类的名称
 **/
public class PackageUtil {

    private static final String CLASS_SUFFIX = ".class";
    private static final String CLASS_FILE_PREFIX = "classes"  + File.separator;
    private static final String PACKAGE_SEPARATOR = ".";

    /**
     * 查找包下的所有类的名字
     * @param packageName
     * @param showChildPackageFlag 是否需要显示子包内容
     * @return List集合,内容为类的全名
     */
    public static List<String> getClazzName(String packageName, boolean showChildPackageFlag){
        List<String> classNames = new ArrayList<>();
        String suffixPath = packageName.replaceAll("\\.", "/");
        // 获取类加载器
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        try{
            Enumeration<URL> urls = loader.getResources(suffixPath);
            while(urls.hasMoreElements()) {
                URL url = urls.nextElement();
                if(url != null){
                    if ("file".equals(url.getProtocol())) {
                        String path = url.getPath();
                        classNames.addAll(getAllClassName(new File(path),showChildPackageFlag));
                    }
                }
            }
        }catch (IOException e){
            throw new RuntimeException("load resource is error , resource is "+packageName);
        }
        return classNames;
    }

    
    // 获取所有类的名称
    private static List<String> getAllClassName(File file,boolean flag){

        List<String> classNames = new ArrayList<>();

        if(!file.exists()){
            return classNames;
        }
        if(file.isFile()){
            String path = file.getPath();
            if(path.endsWith(CLASS_SUFFIX)){
                path = path.replace(CLASS_SUFFIX,"");
                String clazzName = path.substring(path.indexOf(CLASS_FILE_PREFIX) + CLASS_FILE_PREFIX.length()).replace(File.separator, PACKAGE_SEPARATOR);
                classNames.add(clazzName);
            }
        }else {
            File[] listFiles = file.listFiles();
            if(listFiles != null && listFiles.length > 0){
                for (File f : listFiles){
                    if(flag) {
                        classNames.addAll(getAllClassName(f, flag));
                    }else {
                        if(f.isFile()){
                            String path = f.getPath();
                            if(path.endsWith(CLASS_SUFFIX)) {
                                path = path.replace(CLASS_SUFFIX, "");
                                String clazzName = path.substring(path.indexOf(CLASS_FILE_PREFIX) + CLASS_FILE_PREFIX.length()).replace(File.separator,PACKAGE_SEPARATOR);
                                classNames.add(clazzName);
                            }
                        }
                    }
                }
            }
        }
        return classNames;
    }
}

在知道接口类存放在哪些类后,我需要解析这些类中的各个方法,包括每个方法的注解,入参,和返回值。同时,我需要使用一个全局配置类,用于存放我解析后生成的对象,并且一些数据源配置信息也可以放在这个全局类中。详情见代码4

代码块4

核心类(com.simple.ibatis.core)

Config

/**
 * @Author  xiabing
 * @Desc    框架的核心配置类
 **/
public class Config {

    // 数据源
    private PoolDataSource dataSource;

    // mapper包地址(即待解析的包名),后续改为List<String>,能同时加载多个mapper包 todo
    private String daoSource;

    // mapper核心文件(存放类解析后的对象)
    private MapperCore mapperCore;

    // 是否启用事务
    private boolean openTransaction;

    // 是否开启缓存
    private boolean openCache;

    public Config(String mapperSource,PoolDataSource dataSource){
        this.dataSource = dataSource;
        this.daoSource = mapperSource;
        this.mapperCore = new MapperCore(this);
    }
    
    // 生成SQL语句执行器
    public SimpleExecutor getExecutor(){
        return new SimpleExecutor(this,this.getDataSource(),openTransaction,openCache);
    }
    
    // 后文都是基本的get,set方法
    public PoolDataSource getDataSource() {
        return dataSource;
    }

    public void setDataSource(PoolDataSource dataSource) {
        this.dataSource = dataSource;
    }

    public String getDaoSource() {
        return daoSource;
    }

    public void setDaoSource(String daoSource) {
        this.daoSource = daoSource;
    }

    public MapperCore getMapperCore() {
        return mapperCore;
    }

    public void setMapperCore(MapperCore mapperCore) {
        this.mapperCore = mapperCore;
    }

    public boolean isOpenTransaction() {
        return openTransaction;
    }

    public void setOpenTransaction(boolean openTransaction) {
        this.openTransaction = openTransaction;
    }

    public boolean isOpenCache() {
        return openCache;
    }

    public void setOpenCache(boolean openCache) {
        this.openCache = openCache;
    }

}

SqlSource

/**
 * @Author  xiabing5
 * @Desc    将sql语句拆分为sql部分和参数部分,因为我们需要使用java对象对数据库预处理对象进行注入。具体看下面例子
 
 * @example select * from users where id = {user.id} and name = {user.name}
 *            -> sql = select * from users where id = ? and name = ?
 *            -> paramResult = {user.id,user.name}
 **/
public class SqlSource {
    /**sql语句,待输入字段替换成?*/
    private String sql;
    /**待输入字段*/
    private List<String> param;
    /**select update insert delete*/
    private Integer executeType;

    public SqlSource(String sql){
        this.param = new ArrayList<>();
        this.sql = sqlInject(this.param,sql);
    }
    
    /**sql注入,对要注入的属性,必须使用{} 包裹,并替换为? 见例子 */
    private String sqlInject(List<String> paramResult, String sql){

        String labelPrefix = "{";

        String labelSuffix = "}";
        // 将语句的所有的{}全部都解析成?,并将{xxx}中的xxx提取出来。
        while (sql.indexOf(labelPrefix) > 0 && sql.indexOf(labelSuffix) > 0){
            String sqlParamName = sql.substring(sql.indexOf(labelPrefix),sql.indexOf(labelSuffix)+1);
            sql = sql.replace(sqlParamName,"?");
            paramResult.add(sqlParamName.replace("{","").replace("}",""));
        }
        return sql;
    }

    public String getSql() {
        return sql;
    }

    public void setSql(String sql) {
        this.sql = sql;
    }

    public List<String> getParam() {
        return param;
    }

    public void setParam(List<String> param) {
        this.param = param;
    }

    public Integer getExecuteType() {
        return executeType;
    }

    public void setExecuteType(Integer executeType) {
        this.executeType = executeType;
    }
}

MapperCore

/**
 * @Author  xiabing5
 * @Desc    mapper方法解析核心类,此处是将接口中的方法进行解析的关键类
 **/
public class MapperCore {
    
    // select方法
    private static final Integer SELECT_TYPE = 1;
    // update方法
    private static final Integer UPDATE_TYPE = 2;
    // delete方法
    private static final Integer DELETE_TYPE = 3;
    // insert方法
    private static final Integer INSERT_TYPE = 4;

    /**mapper文件解析类缓存,解析后都放在里面,避免重复对一个mapper文件进行解析*/
    private static Map<String,MethodDetails> cacheMethodDetails = new ConcurrentHashMap<>();

    /**
     * 全局配置
     * */
    private Config config;


    public MapperCore(Config config){
        this.config = config;
        load(config.getDaoSource());
    }

    /*
     * 加载,解析指定包名下的类
     * */
    private void load(String source){
        /**加载mapper包下的文件*/
        List<String> clazzNames = PackageUtil.getClazzName(source,true);
        try{
            for(String clazz: clazzNames){
                Class<?> nowClazz = java.lang.Class.forName(clazz);
                // 不是接口跳过,只能解析接口
                if(!nowClazz.isInterface()){
                    continue;
                }

                /**接口上没有@Dao跳过。我们对接口类上加了@Dao后,就代表要解析*/
                boolean skip = false;
                Annotation[] annotations = nowClazz.getDeclaredAnnotations();
                for(Annotation annotation:annotations){
                    // 该接口中的注释中是否有带@Dao
                    if(annotation instanceof Dao) {
                        skip = true;
                        break;
                    }
                }
                if(!skip){
                    continue;
                }
                // 调用反射接口,获取所有接口中的方法
                Method[] methods = nowClazz.getDeclaredMethods();
                for( Method method : methods){
                    // 解析方法详情
                    MethodDetails methodDetails = handleParameter(method);
                    // 解析@SELECT()等注解中的SQL内容
                    methodDetails.setSqlSource(handleAnnotation(method));
                    // 解析完成后放入缓存
                    cacheMethodDetails.put(generateStatementId(method),methodDetails);
                }
            }
        }catch (ClassNotFoundException e){
            throw new RuntimeException(" class load error,class is not exist");
        }
    }

    /**
     * 获得方法详情,从方法缓存池中获取,因为每个方法事先已经完成解析
     * */
    public MethodDetails getMethodDetails(Method method){
        String statementId = generateStatementId(method);
        if(cacheMethodDetails.containsKey(statementId)){
            return cacheMethodDetails.get(statementId);
        }
        return new MethodDetails();
    }

    /**
     * 获得方法对应的sql语句
     * */
    public SqlSource getStatement(Method method){
        String statementId = generateStatementId(method);
        if(cacheMethodDetails.containsKey(statementId)){
            return cacheMethodDetails.get(statementId).getSqlSource();
        }
        throw new RuntimeException(method + " is not sql");
    }

    /**
     * 获得方法对应的参数名
     * */
    public List<String> getParameterName(Method method){
        String statementId = generateStatementId(method);
        if(cacheMethodDetails.containsKey(statementId)){
            return cacheMethodDetails.get(statementId).getParameterNames();
        }
        return new ArrayList<>();
    }

    /**
     *  获取方法返回类型
     * */
    public Class getReturnType(Method method){
        String statementId = generateStatementId(method);
        if(cacheMethodDetails.containsKey(statementId)){
            return cacheMethodDetails.get(statementId).getReturnType();
        }
        return null;
    }

    /**
     * 获得方法对应的参数类型
     * */
    public Class<?>[] getParameterType(Method method) {
        String statementId = generateStatementId(method);
        if(cacheMethodDetails.containsKey(statementId)){
            return cacheMethodDetails.get(statementId).getParameterTypes();
        }
        return new Class<?>[]{};
    }

    /**
     * 获得方法是SELECT UPDATE DELETE INSERT
     * */
    public Integer getMethodType(Method method){
       String statementId = generateStatementId(method);
        if(cacheMethodDetails.containsKey(statementId)){
            return cacheMethodDetails.get(statementId).getSqlSource().getExecuteType();
        }
        return null;
    }


    /**
     * 获得方法是否返回集合类型list
     * */
    public boolean getHasSet(Method method){
        String statementId = generateStatementId(method);
        if(cacheMethodDetails.containsKey(statementId)){
            return cacheMethodDetails.get(statementId).isHasSet();
        }
        return false;
    }

    /**
     * 解析方法内的注解
     * */
    private MethodDetails handleParameter(Method method){
        
        MethodDetails methodDetails = new MethodDetails();
        
        // 获取方法输入参数数量
        int parameterCount = method.getParameterCount();
        
        // 获取方法输入各参数类型
        Class<?>[] parameterTypes = method.getParameterTypes();
        
        // 获取方法输入各参数名称
        List<String> parameterNames = new ArrayList<>();

        Parameter[] params = method.getParameters();
        for(Parameter parameter:params){
            parameterNames.add(parameter.getName());
        }

        /*
         * 获得方法参数的注解值替代默认值,如果使用了@Param注解,则默认使用注解中的值作为参数名
         * */
        for(int i = 0; i < parameterCount; i++){
            parameterNames.set(i,getParamNameFromAnnotation(method,i,parameterNames.get(i)));
        }

        methodDetails.setParameterTypes(parameterTypes);
        methodDetails.setParameterNames(parameterNames);

        /** 获取方法返回类型*/
        Type methodReturnType = method.getGenericReturnType();
        Class<?> methodReturnClass = method.getReturnType();
        if(methodReturnType instanceof ParameterizedType){
            /** 返回是集合类 目前仅支持List  todo*/
            if(!List.class.equals(methodReturnClass)){
                throw new RuntimeException("now ibatis only support list");
            }
            Type type = ((ParameterizedType) methodReturnType).getActualTypeArguments()[0];
            /** 设置 返回集合中的对象类型 */
            methodDetails.setReturnType((Class<?>) type);   
            /** 标注该方法返回类型是集合类型 */
            methodDetails.setHasSet(true);
        }else {
            /** 不是集合类型,就直接设置返回类型 */
            methodDetails.setReturnType(methodReturnClass);
            /** 标注该方法返回类型不是集合类型 */
            methodDetails.setHasSet(false);
        }

        return methodDetails;
    }

    /**
     * 解析@select,@update等注解,获取SQL语句,并封装成对象
     * */
    private SqlSource handleAnnotation(Method method){
        SqlSource sqlSource = null;
        String sql = null;
        /** 获取方法上所有注解 */
        Annotation[]  annotations = method.getDeclaredAnnotations();
        for(Annotation annotation : annotations){
            /** 如果有@select注解 */
            if(Select.class.isInstance(annotation)){
                Select selectAnnotation = (Select)annotation;
                sql = selectAnnotation.value();
                /** 语句封装成sqlSource对象 */
                sqlSource = new SqlSource(sql);
                /** 设置执行语句类型 */
                sqlSource.setExecuteType(SELECT_TYPE);
                break;
            }else if(Update.class.isInstance(annotation)){
                Update updateAnnotation = (Update)annotation;
                sql = updateAnnotation.value();
                sqlSource = new SqlSource(sql);
                sqlSource.setExecuteType(UPDATE_TYPE);
                break;
            }else if(Delete.class.isInstance(annotation)){
                Delete deleteAnnotation = (Delete) annotation;
                sql = deleteAnnotation.value();
                sqlSource = new SqlSource(sql);
                sqlSource.setExecuteType(DELETE_TYPE);
                break;
            }else if(Insert.class.isInstance(annotation)){
                Insert insertAnnotation = (Insert) annotation;
                sql = insertAnnotation.value();
                sqlSource = new SqlSource(sql);
                sqlSource.setExecuteType(INSERT_TYPE);
                break;
            }
        }
        if(sqlSource == null){
            throw  new RuntimeException("method annotation not null");
        }
        return sqlSource;
    }

    /**
     * 获取@Param注解内容
     * */
    private String getParamNameFromAnnotation(Method method, int i, String paramName) {
        final Object[] paramAnnos = method.getParameterAnnotations()[i];
        for (Object paramAnno : paramAnnos) {
            if (paramAnno instanceof Param) {
                paramName = ((Param) paramAnno).value();
            }
        }
        return paramName;
    }

    /**
     * 生成唯一的statementId
     * */
    private static String generateStatementId(Method method){
        return method.getDeclaringClass().getName() + "." + method.getName();
    }
    
    /**
     * 每个mapper方法的封装类
     * */
    public static class MethodDetails{
        /**方法返回类型,若是集合,则代表集合的对象类,目前集合类仅支持返回List  */
        private Class<?> returnType;

        /**方法返回类型是否是集合*/
        private boolean HasSet;

        /**执行类型,SELECT,UPDATE,DELETE,INSERT*/
        private Integer executeType;

        /**方法输入参数类型集合*/
        private Class<?>[] parameterTypes;

        /**方法输入参数名集合*/
        private List<String> parameterNames;

        /**sql语句集合*/
        private SqlSource sqlSource;

        public Class<?> getReturnType() {
            return returnType;
        }

        public void setReturnType(Class<?> returnType) {
            this.returnType = returnType;
        }

        public boolean isHasSet() {
            return HasSet;
        }

        public void setHasSet(boolean hasSet) {
            HasSet = hasSet;
        }

        public Integer getExecuteType() {
            return executeType;
        }

        public void setExecuteType(Integer executeType) {
            this.executeType = executeType;
        }

        public Class<?>[] getParameterTypes() {
            return parameterTypes;
        }

        public void setParameterTypes(Class<?>[] parameterTypes) {
            this.parameterTypes = parameterTypes;
        }

        public List<String> getParameterNames() {
            return parameterNames;
        }

        public void setParameterNames(List<String> parameterNames) {
            this.parameterNames = parameterNames;
        }

        public SqlSource getSqlSource() {
            return sqlSource;
        }

        public void setSqlSource(SqlSource sqlSource) {
            this.sqlSource = sqlSource;
        }
    }
}

SQL语句已经解析完毕,后面开始执行SQL语句。在之前,我想给执行SQL语句时加点功能,如提供select缓存,事务功能,详情见代码块5,代码块6

代码块5

缓存类(com.simple.ibatis.cache)

Cache

/**
 * @author xiabing
 * @description: 自定义缓存接口,主要提供增删改查功能
 */
public interface Cache {

    /**放入缓存*/
    void putCache(String key,Object val);

    /**获取缓存*/
    Object getCache(String key);

    /**清空缓存*/
    void cleanCache();

    /**获取缓存健数量*/
    int getSize();

    /**移除key的缓存*/
    void removeCache(String key);
}

SimpleCache

/**
 * @author xiabing
 * @description: 缓存的简单实现
 */
public class SimpleCache implements Cache{
    
    /**用HashMap来实现缓存的增删改查*/
    private static Map<String,Object> map = new HashMap<>();


    @Override
    public void putCache(String key, Object val) {
        map.put(key,val);
    }

    @Override
    public Object getCache(String key) {
        return map.get(key);
    }

    @Override
    public void cleanCache() {
        map.clear();
    }

    @Override
    public int getSize() {
        return map.size();
    }

    @Override
    public void removeCache(String key) {
        map.remove(key);
    }
}

LruCache

/**
 * @author xiabing
 * @description: 缓存包装类没,使缓存具有LRU功能
 */
public class LruCache implements Cache{
    
    /**缓存大小*/
    private static Integer cacheSize = 2000;
    
    /**HashMap负载因子,可以自己指定*/
    private static Float loadFactory = 0.75F;
    
    /**真实缓存*/
    private Cache trueCache;
    
    /**使用LinkedHashMap来实现Lru功能*/
    private Map<String,Object> linkedCache;
    
    /**待移除的元素*/
    private static Map.Entry removeEntry;

    public LruCache(Cache trueCache){
        this(cacheSize,loadFactory,trueCache);
    }

    public LruCache(Integer cacheSize, Float loadFactory, Cache trueCache) {
        this.trueCache = trueCache;
        /**初始化LinkedHashMap,并重写removeEldestEntry() 方法,不了解的可以看我之前的博客*/
        this.linkedCache = new LinkedHashMap<String, Object>(cacheSize,loadFactory,true){
            @Override
            protected boolean removeEldestEntry(Map.Entry eldest) {
                // 当缓存容量超过设置的容量时,返回true,并记录待删除的对象
                if(getSize() >  cacheSize){
                    removeEntry = eldest;
                    return true;
                }
                return false;
            }
        };
    }

    // 放入缓存,当待删除的元素不为空时,就在真实缓存中删除该元素
    @Override
    public void putCache(String key, Object val) {
        // 真实缓存中放入
        this.trueCache.putCache(key,val);
        // linkedHashMap中放入,此处放入会调用重写的removeEldestEntry()方法
        this.linkedCache.put(key,val);
        // 当removeEldestEntry()执行后,如果有待删除元素,就开始执行删除操作
        if(removeEntry != null){
            removeCache((String)removeEntry.getKey());
            removeEntry = null;
        }
    }

    @Override
    public Object getCache(String key) {
        // 因为调用linkedHashMap中的get()方法会对链表进行再一次排序。见之前关于缓存的博客
        linkedCache.get(key);
        return trueCache.getCache(key);
    }

    @Override
    public void cleanCache() {
        trueCache.cleanCache();
        linkedCache.clear();
    }

    @Override
    public int getSize() {
        return trueCache.getSize();
    }

    @Override
    public void removeCache(String key) {
        trueCache.removeCache(key);
    }
}

代码块6

事务类(com.simple.ibatis.transaction)

Transaction

/**
 * @Author  xiabing
 * @Desc    封装事务功能
 **/
public interface Transaction {

    /**获取数据库连接*/
    Connection getConnection() throws SQLException;
    
    /**事务提交*/
    void commit() throws SQLException;
    
    /**事务回滚*/
    void rollback() throws SQLException;
    
    /**关闭连接*/
    void close() throws SQLException;
}

SimpleTransaction

/**
 * @Author  xiabing
 * @Desc    事务的简单实现,封装了数据库的事务功能
 **/
public class SimpleTransaction implements Transaction{

    private Connection connection; // 数据库连接
    private PoolDataSource dataSource; // 数据源
    private Integer level = Connection.TRANSACTION_REPEATABLE_READ; // 事务隔离级别
    private Boolean autoCommmit = true; // 是否自动提交

    public SimpleTransaction(PoolDataSource dataSource){
        this(dataSource,null,null);
    }

    public SimpleTransaction(PoolDataSource dataSource, Integer level, Boolean autoCommmit) {
        this.dataSource = dataSource;
        if(level != null){
            this.level = level;
        }
        if(autoCommmit != null){
            this.autoCommmit = autoCommmit;
        }
    }
    
    /**获取数据库连接,并设置该连接的事务隔离级别,是否自动提交*/
    @Override
    public Connection getConnection() throws SQLException{
        if(connection != null){
            return connection;
        }
        this.connection = dataSource.getConnection();

        this.connection.setAutoCommit(autoCommmit);

        this.connection.setTransactionIsolation(level);

        return this.connection;
    }

    @Override
    public void commit() throws SQLException{
        if(this.connection != null){
            this.connection.commit();
        }
    }

    @Override
    public void rollback() throws SQLException{
        if(this.connection != null){
            this.connection.rollback();
        }
    }
    
    /**关闭连接*/
    @Override
    public void close() throws SQLException{
        /**如果该连接设置自动连接为false,则在关闭前进行事务回滚一下*/
        if(autoCommmit != null && !autoCommmit){
            connection.rollback();
        }

        if(connection != null){
            dataSource.removeConnection(connection);
        }

        this.connection = null;
    }
}

TransactionFactory

/**
 * @Author  xiabing
 * @Desc    事务工厂,封装获取事务的功能
 **/
public class TransactionFactory {

    public static SimpleTransaction newTransaction(PoolDataSource poolDataSource, Integer level, Boolean autoCommmit){
        return new SimpleTransaction(poolDataSource,level,autoCommmit);
    }
    
}

代码真正执行SQL的地方,见代码块7

代码块7

语句执行类(com.simple.ibatis.execute)

Executor

/**
 * @Author  xiabing
 * @Desc    执行接口类,后续执行器需要实现此接口
 **/
public interface Executor {

    /**获取Mapper接口代理类,使用了动态代理技术,见实现类分析*/
    <T> T getMapper(Class<T> type);
    
    void commit() throws SQLException;

    void rollback() throws SQLException;

    void close() throws SQLException;
}

SimpleExecutor

/**
 * @Author  xiabing
 * @Desc    简单执行器类,对于Sql的执行和处理都依赖此类
 **/
public class SimpleExecutor implements Executor{

    /**全局配置类*/
    public Config config;
    
    /**mapper核心类,解析mapper接口类*/
    public MapperCore mapperCore;
    
    /**数据库源*/
    public PoolDataSource poolDataSource;
    
    /**事务*/
    public Transaction transaction;
    
    /**缓存*/
    public Cache cache;

    public SimpleExecutor(Config config,PoolDataSource poolDataSource){
        this(config,poolDataSource,false,false);
    }

    public SimpleExecutor(Config config,PoolDataSource poolDataSource,boolean openTransaction,boolean openCache){
        this.config = config;
        this.mapperCore = config.getMapperCore();
        this.poolDataSource = poolDataSource;
        if(openCache){
            // 设置缓存策略为LRU
            this.cache = new LruCache(new SimpleCache());
        }
        if(openTransaction){
            // 关闭自动提交
            this.transaction = TransactionFactory.newTransaction(poolDataSource,Connection.TRANSACTION_REPEATABLE_READ,false);
        }else {
            // 开启自动提交
            this.transaction = TransactionFactory.newTransaction(poolDataSource,null,null);
        }
    }
    
    /**
    * @Desc   获取mapper接口的代理类
              为什么要使用代理类,因为我们写mapper接口类只写了接口,却没有提供它的实现。于是
              使用动态代理机制对这些接口进行代理,在代理类中实现sql执行的方法。此处是参照
              mybatis的设计。
    **/
    @Override
    public <T> T getMapper(Class<T> type){
        /**MapperProxy代理类分析见详情8.1*/
        MapperProxy mapperProxy = new MapperProxy(type,this);
        return (T)mapperProxy.initialization();
    }
    
    /*select 语句执行流程分析**/
    public <E> List<E> select(Method method,Object[] args) throws Exception{
    
        /**PreparedStatementHandle 生成器见9.1 */
        PreparedStatementHandle preparedStatementHandle = new PreparedStatementHandle(mapperCore,transaction,method,args);
        PreparedStatement preparedStatement = preparedStatementHandle.generateStatement();
        
        /**执行查询接口,此处是真实调用sql语句的地方 */
        preparedStatement.executeQuery();
        ResultSet resultSet = preparedStatement.getResultSet();
        
        /**查询方法的返回参数值,若是void类型,就不用返回任务东西 */
        Class returnClass = mapperCore.getReturnType(method);
        if(returnClass == null || void.class.equals(returnClass)){
            return null;
        }else {
            /**ResultSetHandle 结果处理器见9.2 */
            ResultSetHandle resultSetHandle = new ResultSetHandle(returnClass,resultSet);
            return resultSetHandle.handle();
        }
    }
    
    /*update 语句执行流程分析,update,delete,insert都是调用的这个方法**/
    public int update(Method method,Object[] args)throws SQLException{
        PreparedStatementHandle preparedStatementHandle = null;
        PreparedStatement preparedStatement = null;
        Integer count = null;
        try{
            /**PreparedStatementHandle 生成器见9.1 */
            preparedStatementHandle = new PreparedStatementHandle(mapperCore,transaction,method,args);
            preparedStatement = preparedStatementHandle.generateStatement();
            
            /**返回受影响的行数 */
            count =  preparedStatement.executeUpdate();
        }finally {
            if(preparedStatement != null){
                preparedStatement.close();
            }
        }
        return count;
    }
    
    /**后续方法直接调用transaction相关方法*/
    @Override
    public void commit() throws SQLException{
        transaction.commit();
    }

    @Override
    public void rollback() throws SQLException{
        transaction.rollback();
    }

    @Override
    public void close() throws SQLException{
        transaction.close();
    }
}

因为我们mapper接口类只有接口,没有实现。但为什么可以调用这些接口方法?此处参照mybatis源码,自己通过动态代理技术实现接口的动态代理,在代理方法里写我们自己执行Sql的逻辑。代理类见代码块8

代码块8

mapper接口代理类(com.simple.ibatis.execute)

MapperProxy

/**
 * @Author  xiabing
 * @Desc    mapper接口代理类
 **/
public class MapperProxy<T> implements InvocationHandler{
    
    /**要代理的mapper接口类*/
    private Class<T> interfaces;
    
    /**具体的执行器,执行mapper中的一个方法相当于执行一条sql语句*/
    private SimpleExecutor executor;

    public MapperProxy(Class<T> interfaces,SimpleExecutor executor) {
        this.interfaces = interfaces;
        this.executor = executor;
    }
    
    /**反射方法,该方法自定义需要代理的逻辑。不了解可以去看下JAVA动态代理技术*/
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        Object result = null;
        /**object方法直接代理*/
        if(Object.class.equals(method.getDeclaringClass())){
            result = method.invoke(this,args);
        }else {
            // 获取方法是select还是非select方法
            Integer methodType = executor.mapperCore.getMethodType(method);

            if(methodType == null){
                throw new RuntimeException("method is normal sql method");
            }
            if(methodType == 1){
                /**调用执行器的select方法,返回集合类*/
                List<Object> list =  executor.select(method,args);
                result = list;
                /**查看接口的返回是List,还是单个对象*/
                if(!executor.mapperCore.getHasSet(method)){
                    if(list.size() == 0){
                        result = null;
                    }else {
                        /**单个对象就直接取第一个*/
                        result = list.get(0);
                    }
                }

            }else{
                /**返回受影响的行数*/
                Integer count = executor.update(method,args);
                result = count;
            }
        }
        return result;
    }
    
    /**构建代理类*/
    public T initialization(){
        return (T)Proxy.newProxyInstance(interfaces.getClassLoader(),new Class[] { interfaces },this);
    }
}

代码块7中执行SQL其实包含3个过程,生成SQL预处理语句,执行SQL,将SQL结果转为JAVA对象。详情见代码块9

代码块9

各生成器类(com.simple.ibatis.statement)

PreparedStatementHandle

/**
 * @Author  xiabing
 * @Desc    PreparedStatement生成器
 **/
public class PreparedStatementHandle {
    /**
     * 全局核心mapper解析类
     */
    private MapperCore mapperCore;

    /**
     * 待执行的方法
     */
    private Method method;

    private Transaction transaction;

    private Connection connection;

    /**
     * 方法输入参数
     */
    private Object[] args;

    public PreparedStatementHandle(MapperCore mapperCore, Transaction transaction,Method method, Object[] args)throws SQLException {
        this.mapperCore = mapperCore;
        this.method = method;
        this.transaction = transaction;
        this.args = args;
        connection = transaction.getConnection();
    }

    /**
     * @Author  xiabing5
     * @Desc    参数处理核心方法  todo
     **/
    public PreparedStatement generateStatement() throws SQLException{
        // 获取已经解析方法的sqlSource类,已经将待注入字段变为?,后续直接填充就好
        SqlSource sqlSource = mapperCore.getStatement(method);
        // 调用connection方法生成预处理语句
        PreparedStatement preparedStatement = connection.prepareStatement(sqlSource.getSql());
        
        // 获取方法输入参数的类型
        Class<?>[] clazzes = mapperCore.getParameterType(method);
        List<String> paramNames = mapperCore.getParameterName(method);
        List<String> params = sqlSource.getParam();
        // 详见typeInject方法,注入SQL参数
        preparedStatement = typeInject(preparedStatement,clazzes,paramNames,params,args);
        return preparedStatement;
    }

    /**
     * @Author  xiabing
     * @Desc    preparedStatement构建,将参数注入到SQL语句中
     * @Param  preparedStatement 待构建的preparedStatement
     * @Param  clazzes 该方法中参数类型数组
     * @Param  paramNames 该方法中参数名称列表,若有@Param注解,则为此注解的值,默认为类名首字母小写
     * @Param  params 待注入的参数名,如user.name或普通类型如name
     * @Param  args 真实参数值
     **/
    private PreparedStatement typeInject(PreparedStatement preparedStatement,Class<?>[] clazzes,List<String> paramNames,List<String> params,Object[] args)throws SQLException{

        for(int i = 0; i < paramNames.size(); i++){
            // 第i个参数名称
            String paramName = paramNames.get(i);
            // 第i个参数类型
            Class type = clazzes[i];
            if(String.class.equals(type)){
                // 原始SQL中需要注入的参数名中是否有此参数名称
                // example: select * from users where id = {id} and name = {name}  则{id}中的id和name就是如下的params里面内容
                int injectIndex = params.indexOf(paramName);
                /**此处是判断sql中是否有待注入的名称({name})和方法内输入对象名(name)相同,若相同,则直接注入*/
                if(injectIndex >= 0){
                    preparedStatement.setString(injectIndex+1,(String)args[i]);
                }
            }else if(Integer.class.equals(type) || int.class.equals(type)){
                int injectIndex = params.indexOf(paramName);
                if(injectIndex >= 0){
                    preparedStatement.setInt(injectIndex+1,(Integer)args[i]);
                }
            }else if(Float.class.equals(type) || float.class.equals(type)){
                int injectIndex = params.indexOf(paramName);
                if(injectIndex >= 0){
                    preparedStatement.setFloat(injectIndex+1,(Float)args[i]);
                }
            }else {
                /** 若待注入的是对象。example:
                 @SELECT(select * from users where id = {user.id} and name = {user.name})
                 List<User> getUser(User user)
                */
                // 对象工厂,获取对象包装实例,见代码块10
                ObjectWrapper objectWrapper = ObjectWrapperFactory.getInstance(args[i]);
                for(int j = 0; j < params.size(); j++){
                    /**此处是判断对象的属性 如user.name,需要先获取user对象,在调用getName方法获取值*/
                    if((params.get(j).indexOf(paramName)) >= 0 ){
                        try{
                            String paramProperties = params.get(j).substring(params.get(j).indexOf(".")+1);
                            Object object = objectWrapper.getVal(paramProperties);
                            Class childClazz = object.getClass();
                            if(String.class.equals(childClazz)){
                                preparedStatement.setString(j+1,(String)object);
                            }else if(Integer.class.equals(childClazz) || int.class.equals(childClazz)){
                                preparedStatement.setInt(j+1,(Integer)object);
                            }else if(Float.class.equals(childClazz) || float.class.equals(childClazz)){
                                preparedStatement.setFloat(j+1,(Float)object);
                            }else {
                                /**目前不支持对象中包含对象,如dept.user.name  todo*/
                                throw new RuntimeException("now not support object contain object");
                            }
                        }catch (Exception e){
                            throw new RuntimeException(e.getMessage());
                        }
                    }
                }
            }
        }
        return preparedStatement;
    }

    public void closeConnection() throws SQLException{
        transaction.close();
    }
}

ResultSetHandle

/**
 * @Author  xiabing
 * @Desc    ResultSet结果处理器,执行之后返回的json对象,转为java对象的过程
 **/
public class ResultSetHandle {

    /**转换的目标类型*/
    Class<?> typeReturn;

    /**待转换的ResultSet*/
    ResultSet resultSet;

    Boolean hasSet;

    public ResultSetHandle(Class<?> typeReturn,ResultSet resultSet){
        this.resultSet = resultSet;
        this.typeReturn = typeReturn;
    }

    /**
     * ResultSet处理方法,目前仅支持String,int,Float,不支持属性是集合类 todo
     * */
    public <T> List<T> handle() throws Exception{

        List<T> res = new ArrayList<>(resultSet.getRow());
        Object object = null;
        ObjectWrapper objectWrapper = null;
        Set<ClazzWrapper.FiledExpand> filedExpands = null;
        // 返回类型若不是基本数据类型
        if(!TypeUtil.isBaseType(typeReturn)){
            // 生成对象
            object = generateObj(typeReturn);
            // 将对象封装成包装类
            objectWrapper = ObjectWrapperFactory.getInstance(object);

            /** 获取对象属性 */
            filedExpands = objectWrapper.getMapperFiledExpands();
        }

        while (resultSet.next()){
            /** 若返回是基础数据类型,则直接将结果放入List中并返回 */
            if(String.class.equals(typeReturn)){
                String val = resultSet.getString(1);
                if(val != null){
                    res.add((T)val);
                }
            }else if(Integer.class.equals(typeReturn) || int.class.equals(typeReturn)){
                Integer val = resultSet.getInt(1);
                if(val != null){
                    res.add((T)val);
                }
            }else if(Float.class.equals(typeReturn) || float.class.equals(typeReturn)){
                Float val = resultSet.getFloat(1);
                if(val != null){
                    res.add((T)val);
                }
            }else { // 若返回的是对象(如User这种)
                // 查找对象属性,一个个注入
                for(ClazzWrapper.FiledExpand filedExpand:filedExpands){
                    // 如果对象属性是String类型,例如User.name是String类型
                    if(String.class.equals(filedExpand.getType())){
                        // resultSet中获取该属性
                        String val = resultSet.getString(filedExpand.getPropertiesName());
                        if(val != null){
                            // 填充到对象包装类中
                            objectWrapper.setVal(filedExpand.getPropertiesName(),val);
                        }
                    }else if(Integer.class.equals(filedExpand.getType()) || int.class.equals(filedExpand.getType())){
                        Integer val = resultSet.getInt(filedExpand.getPropertiesName());
                        if(val != null){
                            objectWrapper.setVal(filedExpand.getPropertiesName(),val);
                        }
                    }else if(Float.class.equals(filedExpand.getType()) || float.class.equals(filedExpand.getType())){
                        Float val = resultSet.getFloat(filedExpand.getPropertiesName());
                        if(val != null){
                            objectWrapper.setVal(filedExpand.getPropertiesName(),val);
                        }
                    }else {
                        continue;
                    }
                }
                // 后续将对象包装类转为真实对象,放入List中并返回。对象包装类见代码10
                res.add((T)objectWrapper.getRealObject());
            }
        }
        return res;
    }

    // 根据类型,根据反射,实例化对象
    private Object generateObj(Class<?> clazz) throws Exception{
        Constructor[] constructors  = clazz.getConstructors();
        Constructor usedConstructor = null;
        // 获取无参构造器,若对象没有无参构造方法,则失败
        for(Constructor constructor:constructors){
            if(constructor.getParameterCount() == 0){
                usedConstructor = constructor;
                break;
            }
        }
        if(constructors == null) {
            throw new RuntimeException(typeReturn + " is not empty constructor");
        }
        // 利用反射生成实例
        return usedConstructor.newInstance();
    }
}

代码块9中大量用到了对象增强,类增强。使用这些原因是 在编写框架代码时,因为并不知道每个属性名,不可能把每个get或者set方法写死,只能通过将对象封装成包装类。然后利用反射,要填充某个对象,就调用setVal(属性名),获取对象类似。详情见代码块10

代码块10

类增强(com.simple.ibatis.reflect)

ClazzWrapper

/**
 * @Author  xiabing
 * @Desc    clazz解析类
 **/
public class ClazzWrapper {
    /**
     * 待解析类
     * */
    private Class<?> clazz;
    /**
     * 该类存储的属性名
     * */
    private Set<String> propertiesSet = new HashSet<>();

    /**
     * 该类存储的属性名及属性类
     * */
    private Set<FiledExpand> filedExpandSet = new HashSet<>();

    /**
     * 该类存储的get方法。key为属性名,value为getxxx方法
     * */
    private Map<String,Method> getterMethodMap = new HashMap<>();
    /**
     * 该类存储的set方法。key为属性名,value为setxxx方法
     * */
    private Map<String,Method> setterMethodMap = new HashMap<>();
    /**
     * 缓存,避免对同一个类多次解析
     * */
    private static Map<String,ClazzWrapper> clazzWrapperMap = new ConcurrentHashMap<>();

    public ClazzWrapper(Class clazz){
        this.clazz = clazz;
        // 对类进行解析,如果已经解析了,则不用二次解析
        if(!clazzWrapperMap.containsKey(clazz.getName())){
            // 获取该类的所有属性
            Field[] fields = clazz.getDeclaredFields();
            for(Field field : fields){
                // 获取属性的名称,属性类型
                FiledExpand filedExpand = new FiledExpand(field.getName(),field.getType());
                filedExpandSet.add(filedExpand);
                propertiesSet.add(field.getName());
            }
            // 获取该类的方法
            Method[] methods = clazz.getMethods();
            for(Method method:methods){
                String name = method.getName();
                // 如果是get方法
                if(name.startsWith("get")){
                    name = name.substring(3);
                    if (name.length() == 1 || (name.length() > 1 && !Character.isUpperCase(name.charAt(1)))) {
                        name = name.substring(0, 1).toLowerCase(Locale.ENGLISH) + name.substring(1);
                    }
                    // 放入get方法缓存中
                    if(propertiesSet.contains(name)){
                        getterMethodMap.put(name,method);
                    }
                }else if(name.startsWith("set")){
                    name = name.substring(3);
                    if (name.length() == 1 || (name.length() > 1 && !Character.isUpperCase(name.charAt(1)))) {
                        name = name.substring(0, 1).toLowerCase(Locale.ENGLISH) + name.substring(1);
                    }
                    if(propertiesSet.contains(name)){
                        setterMethodMap.put(name,method);
                    }
                }
                else {
                    continue;
                }
            }
            // 放入缓存,代表该类已经完成解析
            clazzWrapperMap.put(clazz.getName(),this);
        }
    }
    
    // 查找该类是否包含指定属性的get方法
    public boolean hasGetter(String properties){
        ClazzWrapper clazzWrapper = clazzWrapperMap.get(clazz.getName());

        return clazzWrapper.getterMethodMap.containsKey(properties);
    }

    // 查找该类是否包含指定属性的set方法
    public boolean hasSetter(String properties){
        ClazzWrapper clazzWrapper = clazzWrapperMap.get(clazz.getName());

        return clazzWrapper.setterMethodMap.containsKey(properties);
    }

    // 获取该类指定属性的set方法
    public Method getSetterMethod(String properties){
        if(!hasSetter(properties)){
           throw new RuntimeException("properties " + properties + " is not set method") ;
        }
        ClazzWrapper clazzWrapper = clazzWrapperMap.get(clazz.getName());
        return clazzWrapper.setterMethodMap.get(properties);
    }

    // 获取该类指定属性的get方法
    public Method getGetterMethod(String properties){
        if(!hasGetter(properties)){
            throw new RuntimeException("properties " + properties + " is not get method") ;
        }
        ClazzWrapper clazzWrapper = clazzWrapperMap.get(clazz.getName());
        return clazzWrapper.getterMethodMap.get(properties);
    }

    // 获取该类所有属性
    public Set<String> getProperties(){

        ClazzWrapper clazzWrapper = clazzWrapperMap.get(clazz.getName());

        return clazzWrapper.propertiesSet;
    }

    // 获取该类所有属性增强
    public Set<FiledExpand> getFiledExpandSet(){
        ClazzWrapper clazzWrapper = clazzWrapperMap.get(clazz.getName());

        return clazzWrapper.filedExpandSet;
    }

    /**
     *  属性增强类
     * */
    public static class FiledExpand{

        // 属性名称
        String propertiesName;

        // 属性类型
        Class type;

        public FiledExpand() {
        }

        public FiledExpand(String propertiesName, Class type) {
            this.propertiesName = propertiesName;
            this.type = type;
        }

        public String getPropertiesName() {
            return propertiesName;
        }

        public void setPropertiesName(String propertiesName) {
            this.propertiesName = propertiesName;
        }

        public Class getType() {
            return type;
        }

        public void setType(Class type) {
            this.type = type;
        }

        @Override
        public int hashCode() {
            return propertiesName.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            if(obj instanceof FiledExpand){
                return ((FiledExpand) obj).propertiesName.equals(propertiesName);
            }
            return false;
        }
    }
}

ObjectWrapper

/**
 * @Author  xiabing
 * @Desc    对象增强,封装了get,set方法
 **/
public class ObjectWrapper {
    // 真实对象
    private Object realObject;
    // 该对象的类的增强
    private ClazzWrapper clazzWrapper;

    public ObjectWrapper(Object realObject){
        this.realObject = realObject;
        this.clazzWrapper = new ClazzWrapper(realObject.getClass());
    }

    // 调用对象指定属性的get方法
    public Object getVal(String property) throws Exception{

        return clazzWrapper.getGetterMethod(property).invoke(realObject,null);
    }

    // 调用对象指定属性的set方法
    public void setVal(String property,Object value) throws Exception{

        clazzWrapper.getSetterMethod(property).invoke(realObject,value);
    }

    public Set<String> getProperties(){

        return clazzWrapper.getProperties();
    }

    public Set<ClazzWrapper.FiledExpand> getMapperFiledExpands(){

        return clazzWrapper.getFiledExpandSet();
    }

    public Object getRealObject(){

        return realObject;

    }
}

ObjectWrapperFactory

/**
 * @Author  xiabing5
 * @Desc    对象增强类。使用对象增强是因为我们不知道每个对象的getXXX方法,只能利用反射获取对象的属性和方法,然后
            统一提供getVal或者setVal来获取或者设置指定属性,这就是对象的增强的重要性
 **/
public class ObjectWrapperFactory {

    public static ObjectWrapper getInstance(Object o){
        return new ObjectWrapper(o);
    }

}

总结感悟

上述是我在写这个框架时的一些思路历程,代码演示请见我github。通过自己手写源码,对mybatis也有了个深刻的见解。当然有很多问题,后面也会进一步去完善,如果你对这个开源项目感兴趣,不妨star一下呀!

posted @ 2020-12-17 15:36  超人小冰  阅读(1377)  评论(1编辑  收藏  举报