如何开发一个ORM数据库框架

如何开发一个ORM框架

ORM(Object Relational Mapping)对象关系映射,ORM的数据库框架有hibernate,mybatis。我该如何开发一个类似这样的框架呢?

为什么会有这种疑问?

首先我想快速(注意快速一词)对数据库CURD发现,用hibernate要一堆配置、mybatis更多配置,于是我gitee上随便搜发现一堆,顺手拿一个推荐 300左右start的来用,用了几个小时看了看底层源码,连基本查询都有bug,selectOne竟然是全表扫描的list.get(0),无语了,多数据库兼容也不知道说了啥。 其实性能不性能无所谓,能快速CURD就行啦,但是你有bug我就受不了了,你又不是萌妹子,我不想深入了解你。((小声:)你的架构很糟糕)

于是 final-sql 诞生了:https://gitee.com/lingkang_top/final-sql
性能对比:mybatis、hibernate、final-sql性能对比https://gitee.com/lingkang_top/final-sql/wikis/nature%20%E6%80%A7%E8%83%BD%E5%AF%B9%E6%AF%94hibernate%E3%80%81mybatis

扯远了,回到标题,如何开发一个ORM,首先要有以下认知。

  • 连接池原理
  • java的jdbc接口基本调用
  • 实体对应SQL生成
  • 事务处理
  • 不同数据库兼容(可以参考方言处理)
    基于以上思路,开始开发。

1、连接池+jdbc基本调用

首先创建一个普通springboot项目<spring-boot.version>2.5.12</spring-boot.version>
连接池就自己去查资料吧,我们直接拿一个连接池来用 druid 引入Maven

 <!--使用阿里的连接池监控 http://druid.apache.org/-->
<dependency>
    <groupId>com.alibaba</groupId>
    <artifactId>druid-spring-boot-starter</artifactId>
    <version>1.2.8</version>
</dependency>
<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework</groupId>
    <artifactId>spring-jdbc</artifactId>
</dependency>

简单配置一下

spring:
  datasource:
    type: com.alibaba.druid.pool.DruidDataSource #使用阿里数据源druid
    driver-class-name: com.mysql.cj.jdbc.Driver
    url: jdbc:mysql://127.0.0.1:3306/test?useUnicode=true&characterEncoding=utf-8&useSSL=false&serverTimezone=Asia/Shanghai
    username: root
    password: 123456

druid 已经帮我们在springboot中自动装配了DruidDataSourceAutoConfigure
直接使用

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;

/**
 * @author lingkang
 * Created by 2022/4/20
 */
@RestController
public class DemoController {

    @Autowired // DruidDataSourceAutoConfigure
    private DataSource dataSource;

    @GetMapping("demo")
    public Object demo() throws Exception {
        Connection connection = dataSource.getConnection();
        PreparedStatement preparedStatement = connection.prepareStatement("select * from user");
        ResultSet resultSet = preparedStatement.executeQuery();
        String res = "";
        while (resultSet.next()) {
            for (int i = 1; i < resultSet.getMetaData().getColumnCount(); i++) {
                res += resultSet.getMetaData().getColumnName(i) + "=" + resultSet.getObject(i) + ", ";
            }
            res += "\n";
        }
        System.out.println(res);

        // DataSource是 druid 的实现,不会真正释放连接,而是放回到 druid 中
        // 开发框架时一定要释放连接,否则会一直占用不放回连接池,导致数据库连接耗尽
        connection.close();
        return res;
    }
}

在这里插入图片描述
上面就实现了基本的jdbc调用,接下来就是实体映射了

2、实体映射

实体映射我们得定义表注解、列注解

@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Inherited
public @interface Table {
    String value() default "";// 用于声明表名
}
import java.lang.annotation.*;

/**
 * @author lingkang
 * Created by 2022/4/11
 */
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Column {
    String value() default "";// 列名称
}

再用上面的注解创建我们的user表映射

/**
 * @author lingkang
 * Created by 2022/4/20
 */
@Data
@Table("user")
public class MyUser {
    @Column
    private Integer id;
    @Column
    private Integer num;
    @Column
    private String username;
    @Column
    private String password;

    private Date createTime;
}

接下来将它生成对应的sql,这时有许多数据库sql协议,我们按熟悉的mysql来。

 // 生成执行的sql
    private <T> Map<String,Object> entityToSelectSql(T entity) throws Exception{
        String sql = "select ";

        List<String> where=new ArrayList<>();
        List<Object> param=new ArrayList<>();
        // 获取列名
        Field[] declaredFields = entity.getClass().getDeclaredFields();
        for (Field field : declaredFields) {
            Column column = field.getAnnotation(Column.class);
            if (column != null) {
                String value = column.value();
                if ("".equals(value)) {// 说明没有自定义列名称
                    // 直接拿对象属性当做列,,,,,,例如做一些驼峰命名处理
                    value = field.getName();
                }
                sql += value + ", ";
                // 参数条件
                field.setAccessible(true);
                Object o = field.get(entity);
                if (o!=null){
                    param.add(o);
                    where.add(value);
                }
            }
        }

        sql = sql.substring(0, sql.length() - 2);// 删除后面的 ", "

        // 获取表名
        Table tableAnn = entity.getClass().getAnnotation(Table.class);
        String table = tableAnn.value();
        sql+=" from "+table;

        // 整理条件
        sql+=" where 1=1 "; // 别问为什么 1=1 因为偷懒
        for (String w:where){
            sql+="and "+w+"=? ";
        }
        Map<String,Object> map=new HashMap<>();// 偷懒,直接用map
        map.put("sql",sql);
        map.put("param",param);
        return map;
    }

    // 处理结果
    private <T> List<T> handlerResult(ResultSet resultSet,T entity)throws Exception{
        List<T> list=new ArrayList<>();
        // 获取列名
        Class<?> clazz = entity.getClass();
        while (resultSet.next()){
            T en =(T) clazz.newInstance();// 实例化
            for (Field field : clazz.getDeclaredFields()) {
                Column column = field.getAnnotation(Column.class);
                if (column != null) {
                    String value = column.value();
                    if ("".equals(value)) {// 说明没有自定义列名称
                        // 直接拿对象属性当做列
                        value = field.getName();
                    }
                    Object object = resultSet.getObject(value, field.getType());
                    Field declaredField = en.getClass().getDeclaredField(value);
                    declaredField.setAccessible(true);
                    declaredField.set(en,object);// 设置值
                }
            }
            list.add((T) en);// 结果处理完
        }
        return list;
    }

调用查询

    @GetMapping("demo1")
    public Object demo1()throws Exception{
        MyUser user=new MyUser();
        user.setUsername("lingkang");// 条件
        // 生成sql
        Map<String, Object> map = entityToSelectSql(user);
        System.out.println(map.get("sql"));
        PreparedStatement statement = dataSource.getConnection().prepareStatement(map.get("sql").toString());
        // 添加条件
        List<Object> param = (List<Object>) map.get("param");
        for (int i=1;i<=param.size();i++){
            statement.setObject(i,param.get(i-1));
        }
        // 执行查询
        ResultSet resultSet = statement.executeQuery();
        // 处理结果
        List<MyUser> myUsers = handlerResult(resultSet, user);
        System.out.println(myUsers);// 打印输出
        return myUsers;
    }

结果如下
在这里插入图片描述

3、事务处理

事务处理需要保证每次获取的连接都是同一个,我们用线程变量来实现 ThreadLocal

// 统一获取连接入口
    private Connection getConnection() throws Exception {
        Connection connection = threadLocal.get();
        if (connection == null) {
            connection = dataSource.getConnection();
            threadLocal.set(connection);
        }
        return connection;
    }

    private void begin() throws Exception {
        Connection connection = threadLocal.get();
        if (connection == null) {
            connection=getConnection();
            connection.setAutoCommit(false);// 开始事务
            threadLocal.set(connection);
        }
    }

    private void commit() throws Exception {
        Connection connection = threadLocal.get();
        if (connection == null)
            throw new Exception("事务未开启!");
        connection.commit(); // 提交事务
        connection.close();// 几等关闭连接
    }

    private void rollback() throws Exception {
        Connection connection = threadLocal.get();
        if (connection == null)
            throw new Exception("事务未开启!");
        connection.rollback(); // 回滚事务
        connection.close();// 几等关闭连接
    }

调用

    @GetMapping("demo2")
    public Object demo2() throws Exception {
        try {
            begin();// 开启事务
            // getConnection() 获取连接
            getConnection().prepareStatement("update user set username='lingkang123' where id=6").executeUpdate();
            getConnection().prepareStatement("delete from user where id=6").executeUpdate();
            if (1 == 1)
                throw new Exception("手动抛出一个异常让事务回滚");
            commit();// 提交事务
        } catch (Exception e) {
            rollback();// 异常回滚事务
        }
        return "ok";
    }

执行以上请求后,数据库数据未变动!
若整合spring,可根据spring-tx的TransactionSynchronizationManager来做到@Transactional 控制!

4、方言

方言是用来处理不同数据库sql协议有所差别的,例如mysql中查询一行是limit 1 informix数据库却是 select first 1
这时我们就可以在上面返回SQL时添加一个处理,根据不同数据库方言对最终SQL进行调整替换SQL语句。

    private int dialect=1;// mysql
    private String selectOne(String sql){
        if (dialect==1){
            return sql+" limit 1";
        }else if (dialect==2){
            return sql.substring(7)+"select first 1 ";
        }
        throw new RuntimeException("解析数据库类型错误,请自行扩展方言");
    }

总结

以上就是如何开发一个ORM基本底层原理了,架构与封装应该根据你的开发经验来。觉得不错,别忘了给我点个start
https://gitee.com/lingkang_top/final-sql

完整代码

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import top.lingkang.finalsql.annotation.Column;
import top.lingkang.finalsql.annotation.Table;
import top.lingkang.yuecommunity.entity.MyUser;

import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * @author lingkang
 * Created by 2022/4/20
 */
@RestController
public class DemoController {

    @Autowired // DruidDataSourceAutoConfigure
    private DataSource dataSource;

    @GetMapping("demo")
    public Object demo() throws Exception {
        Connection connection = dataSource.getConnection();
        PreparedStatement preparedStatement = connection.prepareStatement("select * from user");
        ResultSet resultSet = preparedStatement.executeQuery();
        String res = "";
        while (resultSet.next()) {
            for (int i = 1; i < resultSet.getMetaData().getColumnCount(); i++) {
                res += resultSet.getMetaData().getColumnName(i) + "=" + resultSet.getObject(i) + ", ";
            }
            res += "\n";
        }
        System.out.println(res);

        // DataSource是 druid 的实现,不会真正释放连接,而是放回到 druid 中
        // 开发框架时一定要释放连接,否则会一直占用不放回连接池,导致数据库连接耗尽
        connection.close();
        return res;
    }

    @GetMapping("demo1")
    public Object demo1() throws Exception {
        MyUser user = new MyUser();
        user.setUsername("lingkang");// 条件
        // 生成sql
        Map<String, Object> map = entityToSelectSql(user);
        System.out.println(map.get("sql"));
        PreparedStatement statement = dataSource.getConnection().prepareStatement(map.get("sql").toString());
        // 添加条件
        List<Object> param = (List<Object>) map.get("param");
        for (int i = 1; i <= param.size(); i++) {
            statement.setObject(i, param.get(i - 1));
        }
        // 执行查询
        ResultSet resultSet = statement.executeQuery();
        // 处理结果
        List<MyUser> myUsers = handlerResult(resultSet, user);
        System.out.println(myUsers);// 打印输出
        return myUsers;
    }

    ThreadLocal<Connection> threadLocal = new ThreadLocal<>();

    @GetMapping("demo2")
    public Object demo2() throws Exception {
        try {
            begin();// 开启事务
            // getConnection() 获取连接
            getConnection().prepareStatement("update user set username='lingkang123' where id=6").executeUpdate();
            getConnection().prepareStatement("delete from user where id=6").executeUpdate();
            if (1 == 1)
                throw new Exception("手动抛出一个异常让事务回滚");
            commit();// 提交事务
        } catch (Exception e) {
            rollback();// 异常回滚事务
        }
        return "ok";
    }

    // 统一获取连接入口
    private Connection getConnection() throws Exception {
        Connection connection = threadLocal.get();
        if (connection == null) {
            connection = dataSource.getConnection();
            threadLocal.set(connection);
        }
        return connection;
    }

    private void begin() throws Exception {
        Connection connection = threadLocal.get();
        if (connection == null) {
            connection=getConnection();
            connection.setAutoCommit(false);// 开始事务
            threadLocal.set(connection);
        }
    }

    private void commit() throws Exception {
        Connection connection = threadLocal.get();
        if (connection == null)
            throw new Exception("事务未开启!");
        connection.commit(); // 提交事务
        connection.close();// 几等关闭连接
    }

    private void rollback() throws Exception {
        Connection connection = threadLocal.get();
        if (connection == null)
            throw new Exception("事务未开启!");
        connection.rollback(); // 回滚事务
        connection.close();// 几等关闭连接
    }

    private int dialect=1;// mysql
    private String selectOne(String sql){
        if (dialect==1){
            return sql+" limit 1";
        }else if (dialect==2){
            return sql.substring(7)+"select first 1 ";
        }
        throw new RuntimeException("解析数据库类型错误,请自行扩展方言");
    }

    // 生成执行的sql
    private <T> Map<String, Object> entityToSelectSql(T entity) throws Exception {
        String sql = "select ";

        List<String> where = new ArrayList<>();
        List<Object> param = new ArrayList<>();
        // 获取列名
        Field[] declaredFields = entity.getClass().getDeclaredFields();
        for (Field field : declaredFields) {
            Column column = field.getAnnotation(Column.class);
            if (column != null) {
                String value = column.value();
                if ("".equals(value)) {// 说明没有自定义列名称
                    // 直接拿对象属性当做列,,,,,,例如做一些驼峰命名处理
                    value = field.getName();
                }
                sql += value + ", ";
                // 参数条件
                field.setAccessible(true);
                Object o = field.get(entity);
                if (o != null) {
                    param.add(o);
                    where.add(value);
                }
            }
        }

        sql = sql.substring(0, sql.length() - 2);// 删除后面的 ", "

        // 获取表名
        Table tableAnn = entity.getClass().getAnnotation(Table.class);
        String table = tableAnn.value();
        sql += " from " + table;

        // 整理条件
        sql += " where 1=1 "; // 别问为什么 1=1 因为偷懒
        for (String w : where) {
            sql += "and " + w + "=? ";
        }
        Map<String, Object> map = new HashMap<>();// 偷懒,直接用map
        map.put("sql", sql);
        map.put("param", param);
        return map;
    }

    // 处理结果
    private <T> List<T> handlerResult(ResultSet resultSet, T entity) throws Exception {
        List<T> list = new ArrayList<>();
        // 获取列名
        Class<?> clazz = entity.getClass();
        while (resultSet.next()) {
            T en = (T) clazz.newInstance();// 实例化
            for (Field field : clazz.getDeclaredFields()) {
                Column column = field.getAnnotation(Column.class);
                if (column != null) {
                    String value = column.value();
                    if ("".equals(value)) {// 说明没有自定义列名称
                        // 直接拿对象属性当做列
                        value = field.getName();
                    }
                    Object object = resultSet.getObject(value, field.getType());
                    Field declaredField = en.getClass().getDeclaredField(value);
                    declaredField.setAccessible(true);
                    declaredField.set(en, object);// 设置值
                }
            }
            list.add((T) en);// 结果处理完
        }
        return list;
    }

}
CREATE TABLE `user` (
  `id` int(11) NOT NULL AUTO_INCREMENT,
  `num` int(11) DEFAULT NULL,
  `username` varchar(255) DEFAULT NULL,
  `password` varchar(255) DEFAULT NULL,
  `create_time` datetime DEFAULT NULL,
  PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=52 DEFAULT CHARSET=utf8;
posted @ 2022-09-16 00:08  凌康  阅读(111)  评论(0编辑  收藏  举报