如何开发一个ORM数据库框架
如何开发一个ORM框架
ORM(Object Relational Mapping)对象关系映射,ORM的数据库框架有hibernate,mybatis。我该如何开发一个类似这样的框架呢?
为什么会有这种疑问?
首先我想快速(注意快速一词)对数据库CURD发现,用hibernate要一堆配置、mybatis更多配置,于是我gitee上随便搜发现一堆,顺手拿一个推荐 300左右start的来用,用了几个小时看了看底层源码,连基本查询都有bug,selectOne竟然是全表扫描的list.get(0),无语了,多数据库兼容也不知道说了啥。 其实性能不性能无所谓,能快速CURD就行啦,但是你有bug我就受不了了,你又不是萌妹子,我不想深入了解你。((小声:)你的架构很糟糕)
于是 final-sql
诞生了:https://gitee.com/lingkang_top/final-sql
性能对比:mybatis、hibernate、final-sql性能对比https://gitee.com/lingkang_top/final-sql/wikis/nature%20%E6%80%A7%E8%83%BD%E5%AF%B9%E6%AF%94hibernate%E3%80%81mybatis
扯远了,回到标题,如何开发一个ORM,首先要有以下认知。
- 连接池原理
- java的jdbc接口基本调用
- 实体对应SQL生成
- 事务处理
- 不同数据库兼容(可以参考方言处理)
基于以上思路,开始开发。
1、连接池+jdbc基本调用
首先创建一个普通springboot项目<spring-boot.version>2.5.12</spring-boot.version>
连接池就自己去查资料吧,我们直接拿一个连接池来用 druid
引入Maven
<!--使用阿里的连接池监控 http://druid.apache.org/-->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid-spring-boot-starter</artifactId>
<version>1.2.8</version>
</dependency>
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
</dependency>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-jdbc</artifactId>
</dependency>
简单配置一下
spring:
datasource:
type: com.alibaba.druid.pool.DruidDataSource #使用阿里数据源druid
driver-class-name: com.mysql.cj.jdbc.Driver
url: jdbc:mysql://127.0.0.1:3306/test?useUnicode=true&characterEncoding=utf-8&useSSL=false&serverTimezone=Asia/Shanghai
username: root
password: 123456
druid 已经帮我们在springboot中自动装配了DruidDataSourceAutoConfigure
直接使用
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
/**
* @author lingkang
* Created by 2022/4/20
*/
@RestController
public class DemoController {
@Autowired // DruidDataSourceAutoConfigure
private DataSource dataSource;
@GetMapping("demo")
public Object demo() throws Exception {
Connection connection = dataSource.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement("select * from user");
ResultSet resultSet = preparedStatement.executeQuery();
String res = "";
while (resultSet.next()) {
for (int i = 1; i < resultSet.getMetaData().getColumnCount(); i++) {
res += resultSet.getMetaData().getColumnName(i) + "=" + resultSet.getObject(i) + ", ";
}
res += "\n";
}
System.out.println(res);
// DataSource是 druid 的实现,不会真正释放连接,而是放回到 druid 中
// 开发框架时一定要释放连接,否则会一直占用不放回连接池,导致数据库连接耗尽
connection.close();
return res;
}
}
上面就实现了基本的jdbc调用,接下来就是实体映射了
2、实体映射
实体映射我们得定义表注解、列注解
@Target({ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
@Inherited
public @interface Table {
String value() default "";// 用于声明表名
}
import java.lang.annotation.*;
/**
* @author lingkang
* Created by 2022/4/11
*/
@Target({ElementType.FIELD})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Column {
String value() default "";// 列名称
}
再用上面的注解创建我们的user表映射
/**
* @author lingkang
* Created by 2022/4/20
*/
@Data
@Table("user")
public class MyUser {
@Column
private Integer id;
@Column
private Integer num;
@Column
private String username;
@Column
private String password;
private Date createTime;
}
接下来将它生成对应的sql,这时有许多数据库sql协议,我们按熟悉的mysql来。
// 生成执行的sql
private <T> Map<String,Object> entityToSelectSql(T entity) throws Exception{
String sql = "select ";
List<String> where=new ArrayList<>();
List<Object> param=new ArrayList<>();
// 获取列名
Field[] declaredFields = entity.getClass().getDeclaredFields();
for (Field field : declaredFields) {
Column column = field.getAnnotation(Column.class);
if (column != null) {
String value = column.value();
if ("".equals(value)) {// 说明没有自定义列名称
// 直接拿对象属性当做列,,,,,,例如做一些驼峰命名处理
value = field.getName();
}
sql += value + ", ";
// 参数条件
field.setAccessible(true);
Object o = field.get(entity);
if (o!=null){
param.add(o);
where.add(value);
}
}
}
sql = sql.substring(0, sql.length() - 2);// 删除后面的 ", "
// 获取表名
Table tableAnn = entity.getClass().getAnnotation(Table.class);
String table = tableAnn.value();
sql+=" from "+table;
// 整理条件
sql+=" where 1=1 "; // 别问为什么 1=1 因为偷懒
for (String w:where){
sql+="and "+w+"=? ";
}
Map<String,Object> map=new HashMap<>();// 偷懒,直接用map
map.put("sql",sql);
map.put("param",param);
return map;
}
// 处理结果
private <T> List<T> handlerResult(ResultSet resultSet,T entity)throws Exception{
List<T> list=new ArrayList<>();
// 获取列名
Class<?> clazz = entity.getClass();
while (resultSet.next()){
T en =(T) clazz.newInstance();// 实例化
for (Field field : clazz.getDeclaredFields()) {
Column column = field.getAnnotation(Column.class);
if (column != null) {
String value = column.value();
if ("".equals(value)) {// 说明没有自定义列名称
// 直接拿对象属性当做列
value = field.getName();
}
Object object = resultSet.getObject(value, field.getType());
Field declaredField = en.getClass().getDeclaredField(value);
declaredField.setAccessible(true);
declaredField.set(en,object);// 设置值
}
}
list.add((T) en);// 结果处理完
}
return list;
}
调用查询
@GetMapping("demo1")
public Object demo1()throws Exception{
MyUser user=new MyUser();
user.setUsername("lingkang");// 条件
// 生成sql
Map<String, Object> map = entityToSelectSql(user);
System.out.println(map.get("sql"));
PreparedStatement statement = dataSource.getConnection().prepareStatement(map.get("sql").toString());
// 添加条件
List<Object> param = (List<Object>) map.get("param");
for (int i=1;i<=param.size();i++){
statement.setObject(i,param.get(i-1));
}
// 执行查询
ResultSet resultSet = statement.executeQuery();
// 处理结果
List<MyUser> myUsers = handlerResult(resultSet, user);
System.out.println(myUsers);// 打印输出
return myUsers;
}
结果如下
3、事务处理
事务处理需要保证每次获取的连接都是同一个,我们用线程变量来实现 ThreadLocal
// 统一获取连接入口
private Connection getConnection() throws Exception {
Connection connection = threadLocal.get();
if (connection == null) {
connection = dataSource.getConnection();
threadLocal.set(connection);
}
return connection;
}
private void begin() throws Exception {
Connection connection = threadLocal.get();
if (connection == null) {
connection=getConnection();
connection.setAutoCommit(false);// 开始事务
threadLocal.set(connection);
}
}
private void commit() throws Exception {
Connection connection = threadLocal.get();
if (connection == null)
throw new Exception("事务未开启!");
connection.commit(); // 提交事务
connection.close();// 几等关闭连接
}
private void rollback() throws Exception {
Connection connection = threadLocal.get();
if (connection == null)
throw new Exception("事务未开启!");
connection.rollback(); // 回滚事务
connection.close();// 几等关闭连接
}
调用
@GetMapping("demo2")
public Object demo2() throws Exception {
try {
begin();// 开启事务
// getConnection() 获取连接
getConnection().prepareStatement("update user set username='lingkang123' where id=6").executeUpdate();
getConnection().prepareStatement("delete from user where id=6").executeUpdate();
if (1 == 1)
throw new Exception("手动抛出一个异常让事务回滚");
commit();// 提交事务
} catch (Exception e) {
rollback();// 异常回滚事务
}
return "ok";
}
执行以上请求后,数据库数据未变动!
若整合spring,可根据spring-tx的TransactionSynchronizationManager
来做到@Transactional
控制!
4、方言
方言是用来处理不同数据库sql协议有所差别的,例如mysql中查询一行是limit 1
informix数据库却是 select first 1
这时我们就可以在上面返回SQL时添加一个处理,根据不同数据库方言对最终SQL进行调整替换SQL语句。
private int dialect=1;// mysql
private String selectOne(String sql){
if (dialect==1){
return sql+" limit 1";
}else if (dialect==2){
return sql.substring(7)+"select first 1 ";
}
throw new RuntimeException("解析数据库类型错误,请自行扩展方言");
}
总结
以上就是如何开发一个ORM基本底层原理了,架构与封装应该根据你的开发经验来。觉得不错,别忘了给我点个start
https://gitee.com/lingkang_top/final-sql
完整代码
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import top.lingkang.finalsql.annotation.Column;
import top.lingkang.finalsql.annotation.Table;
import top.lingkang.yuecommunity.entity.MyUser;
import javax.sql.DataSource;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* @author lingkang
* Created by 2022/4/20
*/
@RestController
public class DemoController {
@Autowired // DruidDataSourceAutoConfigure
private DataSource dataSource;
@GetMapping("demo")
public Object demo() throws Exception {
Connection connection = dataSource.getConnection();
PreparedStatement preparedStatement = connection.prepareStatement("select * from user");
ResultSet resultSet = preparedStatement.executeQuery();
String res = "";
while (resultSet.next()) {
for (int i = 1; i < resultSet.getMetaData().getColumnCount(); i++) {
res += resultSet.getMetaData().getColumnName(i) + "=" + resultSet.getObject(i) + ", ";
}
res += "\n";
}
System.out.println(res);
// DataSource是 druid 的实现,不会真正释放连接,而是放回到 druid 中
// 开发框架时一定要释放连接,否则会一直占用不放回连接池,导致数据库连接耗尽
connection.close();
return res;
}
@GetMapping("demo1")
public Object demo1() throws Exception {
MyUser user = new MyUser();
user.setUsername("lingkang");// 条件
// 生成sql
Map<String, Object> map = entityToSelectSql(user);
System.out.println(map.get("sql"));
PreparedStatement statement = dataSource.getConnection().prepareStatement(map.get("sql").toString());
// 添加条件
List<Object> param = (List<Object>) map.get("param");
for (int i = 1; i <= param.size(); i++) {
statement.setObject(i, param.get(i - 1));
}
// 执行查询
ResultSet resultSet = statement.executeQuery();
// 处理结果
List<MyUser> myUsers = handlerResult(resultSet, user);
System.out.println(myUsers);// 打印输出
return myUsers;
}
ThreadLocal<Connection> threadLocal = new ThreadLocal<>();
@GetMapping("demo2")
public Object demo2() throws Exception {
try {
begin();// 开启事务
// getConnection() 获取连接
getConnection().prepareStatement("update user set username='lingkang123' where id=6").executeUpdate();
getConnection().prepareStatement("delete from user where id=6").executeUpdate();
if (1 == 1)
throw new Exception("手动抛出一个异常让事务回滚");
commit();// 提交事务
} catch (Exception e) {
rollback();// 异常回滚事务
}
return "ok";
}
// 统一获取连接入口
private Connection getConnection() throws Exception {
Connection connection = threadLocal.get();
if (connection == null) {
connection = dataSource.getConnection();
threadLocal.set(connection);
}
return connection;
}
private void begin() throws Exception {
Connection connection = threadLocal.get();
if (connection == null) {
connection=getConnection();
connection.setAutoCommit(false);// 开始事务
threadLocal.set(connection);
}
}
private void commit() throws Exception {
Connection connection = threadLocal.get();
if (connection == null)
throw new Exception("事务未开启!");
connection.commit(); // 提交事务
connection.close();// 几等关闭连接
}
private void rollback() throws Exception {
Connection connection = threadLocal.get();
if (connection == null)
throw new Exception("事务未开启!");
connection.rollback(); // 回滚事务
connection.close();// 几等关闭连接
}
private int dialect=1;// mysql
private String selectOne(String sql){
if (dialect==1){
return sql+" limit 1";
}else if (dialect==2){
return sql.substring(7)+"select first 1 ";
}
throw new RuntimeException("解析数据库类型错误,请自行扩展方言");
}
// 生成执行的sql
private <T> Map<String, Object> entityToSelectSql(T entity) throws Exception {
String sql = "select ";
List<String> where = new ArrayList<>();
List<Object> param = new ArrayList<>();
// 获取列名
Field[] declaredFields = entity.getClass().getDeclaredFields();
for (Field field : declaredFields) {
Column column = field.getAnnotation(Column.class);
if (column != null) {
String value = column.value();
if ("".equals(value)) {// 说明没有自定义列名称
// 直接拿对象属性当做列,,,,,,例如做一些驼峰命名处理
value = field.getName();
}
sql += value + ", ";
// 参数条件
field.setAccessible(true);
Object o = field.get(entity);
if (o != null) {
param.add(o);
where.add(value);
}
}
}
sql = sql.substring(0, sql.length() - 2);// 删除后面的 ", "
// 获取表名
Table tableAnn = entity.getClass().getAnnotation(Table.class);
String table = tableAnn.value();
sql += " from " + table;
// 整理条件
sql += " where 1=1 "; // 别问为什么 1=1 因为偷懒
for (String w : where) {
sql += "and " + w + "=? ";
}
Map<String, Object> map = new HashMap<>();// 偷懒,直接用map
map.put("sql", sql);
map.put("param", param);
return map;
}
// 处理结果
private <T> List<T> handlerResult(ResultSet resultSet, T entity) throws Exception {
List<T> list = new ArrayList<>();
// 获取列名
Class<?> clazz = entity.getClass();
while (resultSet.next()) {
T en = (T) clazz.newInstance();// 实例化
for (Field field : clazz.getDeclaredFields()) {
Column column = field.getAnnotation(Column.class);
if (column != null) {
String value = column.value();
if ("".equals(value)) {// 说明没有自定义列名称
// 直接拿对象属性当做列
value = field.getName();
}
Object object = resultSet.getObject(value, field.getType());
Field declaredField = en.getClass().getDeclaredField(value);
declaredField.setAccessible(true);
declaredField.set(en, object);// 设置值
}
}
list.add((T) en);// 结果处理完
}
return list;
}
}
CREATE TABLE `user` (
`id` int(11) NOT NULL AUTO_INCREMENT,
`num` int(11) DEFAULT NULL,
`username` varchar(255) DEFAULT NULL,
`password` varchar(255) DEFAULT NULL,
`create_time` datetime DEFAULT NULL,
PRIMARY KEY (`id`)
) ENGINE=InnoDB AUTO_INCREMENT=52 DEFAULT CHARSET=utf8;