基于 @SelectProvider 注解实现无侵入的通用Dao
项目框架
基于 SpringBoot 2.x 和 mybatis-spring-boot-starter
代码设计
通用Dao
/**
 * Generic CRUD DAO whose SQL is generated at runtime by {@link BaseSqlProvider}.
 * <p>Use it by extending with concrete type arguments, e.g.
 * {@code interface UserDao extends BaseDao<User,Integer>}.
 * <p>Methods taking an entity or a collection reject a null entity / empty collection in the provider.
 *
 * @param <E> entity type; table and column names are derived from the class and field names
 * @param <I> primary-key type
 */
public interface BaseDao<E,I> {
/** Selects the row whose id column equals {@code id}. */
@SelectProvider(type = BaseSqlProvider.class,method = "getById")
E getById(I id);
/** Selects all rows whose columns equal the non-null fields of {@code e} (conditions AND-combined). */
@SelectProvider(type = BaseSqlProvider.class,method = "listByEntity")
List<E> listByEntity(E e);
/** Like {@link #listByEntity} but limited to a single row ({@code LIMIT 1}). */
@SelectProvider(type = BaseSqlProvider.class,method = "getByEntity")
E getByEntity(E e);
/**
 * Selects rows where the column behind {@code lambda} equals {@code val}.
 * {@code lambda} is expected to be a getter method reference such as {@code User::getUsername}.
 * <p>NOTE(review): the provider looks the arguments up by the names "lambda" and "val"; that
 * presumably requires compiling with {@code -parameters} (or adding {@code @Param}) — confirm.
 */
@SelectProvider(type = BaseSqlProvider.class,method = "listByLambdaQuery")
List<E> listByLambdaQuery(GetterFunction<E,?> lambda, Object val);
/**
 * Like {@link #listByLambdaQuery} but the generated SQL is limited to one row.
 * <p>NOTE(review): declared to return a {@code List} (of at most one element) rather than {@code E};
 * the same argument-naming caveat as {@link #listByLambdaQuery} applies.
 */
@SelectProvider(type = BaseSqlProvider.class,method = "getByLambdaQuery")
List<E> getByLambdaQuery(GetterFunction<E,?> lambda, Object val);
/** Selects all rows whose id is contained in {@code collection}. */
@SelectProvider(type = BaseSqlProvider.class,method = "listByIds")
List<E> listByIds(Collection<I> collection);
/** Inserts the non-null fields of {@code e}; {@code @Options} requests the generated key into {@code id}. */
@InsertProvider(type = BaseSqlProvider.class,method = "insert")
@Options(keyProperty="id",useGeneratedKeys=true)
int insert(E e);
/** Inserts all entities with one multi-row INSERT listing every column (null fields included). */
@InsertProvider(type = BaseSqlProvider.class,method = "insertBatch")
@Options(keyProperty="id",useGeneratedKeys=true)
int insertBatch(Collection<E> list);
/** Updates the non-null fields of {@code e} on the row whose id equals {@code e}'s id. */
@UpdateProvider(type = BaseSqlProvider.class,method = "update")
int update(E e);
/**
 * Updates all entities in one statement using a CASE WHEN per column.
 * Null fields are written as NULL rather than skipped.
 */
@UpdateProvider(type = BaseSqlProvider.class,method = "updateBatch")
int updateBatch(Collection<E> list);
/** Deletes the row whose id column equals {@code id}. */
@DeleteProvider(type = BaseSqlProvider.class,method = "deleteById")
int deleteById(I id);
/** Deletes rows whose columns equal the non-null fields of {@code e}. */
@DeleteProvider(type = BaseSqlProvider.class,method = "deleteByEntity")
int deleteByEntity(E e);
/** Deletes all rows whose id is contained in {@code list}. */
@DeleteProvider(type = BaseSqlProvider.class,method = "deleteByIds")
int deleteByIds(Collection<I> list);
/** Counts all rows of the table. */
@SelectProvider(type = BaseSqlProvider.class,method = "countAll")
int countAll();
/** Counts rows whose columns equal the non-null fields of {@code e}. */
@SelectProvider(type = BaseSqlProvider.class,method = "countByEntity")
int countByEntity(E e);
}
通用SQL Provider
//用于缓存和返回通用SQL语句
public class BaseSqlProvider {

    /**
     * Generated SQL, keyed by {@link #cacheKey}. String keys let a statement key be combined with
     * an extra discriminator (the queried field name for the lambda queries) and avoid relying on
     * identity hash codes, which are not guaranteed to be unique.
     * <p>The builders only put class-level information into the SQL text; per-call values are bound
     * later through {@code #{...}} placeholders and dynamic tags, so one entry per statement (and
     * field) suffices. This assumes every call of a statement passes the same entity class.
     */
    private static final Map<String, String> sqlCache = new ConcurrentHashMap<>();

    /** A deferred SQL build step that may throw a checked exception. */
    @FunctionalInterface
    private interface SqlFactory {
        String build() throws Exception;
    }

    /**
     * Builds a cache key unique to the mapper interface and mapper method behind {@code context};
     * {@code suffix} distinguishes variants of the same statement.
     */
    private static String cacheKey(ProviderContext context, String suffix) {
        return context.getMapperType().getName() + '#' + context.getMapperMethod() + suffix;
    }

    /**
     * Returns the SQL cached under {@code key}, building and caching it on first use. Concurrent
     * first calls may both build the (identical) SQL; the first stored value wins.
     */
    private static String cached(String key, SqlFactory factory) throws Exception {
        String value = sqlCache.get(key);
        if (value == null) {
            value = factory.build();
            String previous = sqlCache.putIfAbsent(key, value);
            if (previous != null) {
                value = previous;
            }
        }
        return value;
    }

    /** @throws IllegalArgumentException if {@code entity} is null */
    private static void requireEntity(Object entity) {
        if (entity == null) {
            throw new IllegalArgumentException("entity can not be null!");
        }
    }

    /** @throws IllegalArgumentException with {@code message} if {@code collection} is null or empty */
    private static void requireNotEmpty(Collection<?> collection, String message) {
        if (collection == null || collection.isEmpty()) {
            throw new IllegalArgumentException(message);
        }
    }

    /**
     * Validates the lambda-query arguments and resolves the queried field name.
     * NOTE(review): the keys "val"/"lambda" are the DAO parameter names; they are only available when
     * the mapper is compiled with {@code -parameters} or the parameters carry {@code @Param}.
     */
    @SuppressWarnings("unchecked")
    private static String lambdaFieldName(Map<String, Object> params) {
        if (params.get("val") == null) {
            throw new IllegalArgumentException("value can not be null!");
        }
        GetterFunction<Object, ?> lambda = (GetterFunction<Object, ?>) params.get("lambda");
        if (lambda == null) {
            throw new IllegalArgumentException("lambda can not be null!");
        }
        return lambda.getFieldName(lambda);
    }

    /** SQL selecting a row by id. */
    public String getById(ProviderContext context) {
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.getById(context));
    }

    /** SQL selecting one row matching the non-null fields of {@code object}. */
    public String getByEntity(Object object, ProviderContext context) throws Exception {
        requireEntity(object);
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.getByEntity(object));
    }

    /** SQL selecting rows whose id is in {@code collection}. */
    public String listByIds(Collection<?> collection, ProviderContext context) throws Exception {
        requireNotEmpty(collection, "id list can not be empty!");
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.listByIds(context));
    }

    /** SQL selecting rows matching the non-null fields of {@code object}. */
    public String listByEntity(Object object, ProviderContext context) throws Exception {
        requireEntity(object);
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.listByEntity(object));
    }

    /** SQL selecting rows whose lambda-selected column equals {@code val}; cached per field. */
    public String listByLambdaQuery(Map<String, Object> params, ProviderContext context) throws Exception {
        String fieldName = lambdaFieldName(params);
        return cached(cacheKey(context, "#" + fieldName), () -> BaseSqlBuilder.listByField(fieldName, context));
    }

    /** Single-row variant of {@link #listByLambdaQuery}; cached per field. */
    public String getByLambdaQuery(Map<String, Object> params, ProviderContext context) throws Exception {
        String fieldName = lambdaFieldName(params);
        return cached(cacheKey(context, "#" + fieldName), () -> BaseSqlBuilder.getByField(fieldName, context));
    }

    /** SQL inserting the non-null fields of {@code object}. */
    public String insert(Object object, ProviderContext context) throws Exception {
        requireEntity(object);
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.insert(object));
    }

    /** SQL inserting every entity of {@code collection} in one statement. */
    public String insertBatch(Collection<?> collection, ProviderContext context) throws Exception {
        requireNotEmpty(collection, "entity list can not be empty!");
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.insertBatch(context));
    }

    /** SQL updating the non-null fields of {@code object} by id. */
    public String update(Object object, ProviderContext context) throws Exception {
        requireEntity(object);
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.update(object));
    }

    /** SQL updating every entity of {@code collection} in one statement. */
    public String updateBatch(Collection<?> collection, ProviderContext context) throws Exception {
        requireNotEmpty(collection, "entity list can not be empty!");
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.updateBatch(context));
    }

    /** SQL deleting a row by id. */
    public String deleteById(ProviderContext context) {
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.deleteById(context));
    }

    /** SQL deleting rows matching the non-null fields of {@code object}. */
    public String deleteByEntity(Object object, ProviderContext context) throws Exception {
        requireEntity(object);
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.deleteByEntity(object));
    }

    /** SQL deleting rows whose id is in {@code collection}. */
    public String deleteByIds(Collection<?> collection, ProviderContext context) throws Exception {
        requireNotEmpty(collection, "id list can not be empty!");
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.deleteByIds(context));
    }

    /** SQL counting all rows. */
    public String countAll(ProviderContext context) {
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.countAll(context));
    }

    /** SQL counting rows matching the non-null fields of {@code object}. */
    public String countByEntity(Object object, ProviderContext context) throws Exception {
        requireEntity(object);
        return sqlCache.computeIfAbsent(cacheKey(context, ""), k -> BaseSqlBuilder.countByEntity(object));
    }
}
通用SQL构建类
//生成通用SQL语句
public class BaseSqlBuilder {

    /**
     * Opening {@code <foreach>} rendering {@code (v1,v2,...)} over the single collection argument of
     * a DAO method. MyBatis always binds a lone {@code Collection} argument under the name
     * {@code "collection"}; the {@code "list"} alias only exists for {@link List} arguments, while
     * the DAO methods accept any {@code Collection} (e.g. a {@code Set} of ids).
     */
    private static final String FOREACH_IN =
            "<foreach item=\"item\" collection=\"collection\" separator=\",\" open=\"(\" close=\")\" index=\"index\">";

    /** Builds {@code SELECT c1,c2,... FROM table}. */
    private static String selectFrom(String tableName, List<String> columns) {
        return "SELECT " + String.join(",", columns) + " FROM " + tableName;
    }

    /** Selects every column of the row whose id column equals the {@code id} argument. */
    public static String getById(ProviderContext context) {
        Class<?> eClass = TableEntityMetaData.getEntityType(context);
        List<String> columns = TableEntityMetaData.tableColumns(TableEntityMetaData.entityFields(eClass));
        return selectFrom(TableEntityMetaData.tableName(eClass), columns)
                + " WHERE " + TableEntityMetaData.getIdColumn(eClass) + " = #{id}";
    }

    /** Selects rows whose columns equal the non-null fields of {@code object} (AND-combined). */
    public static String listByEntity(Object object) {
        return selectByEntity(object, "");
    }

    /**
     * Like {@link #listByEntity} but limited to one row ({@code LIMIT} syntax, e.g. MySQL). The
     * clause must sit inside {@code <script>...</script>}: MyBatis parses the script as XML and
     * rejects text after the closing tag.
     */
    public static String getByEntity(Object object) {
        return selectByEntity(object, " LIMIT 1");
    }

    /** Shared body of listByEntity/getByEntity; {@code tail} follows the WHERE clause. */
    private static String selectByEntity(Object object, String tail) {
        Class<?> eClass = object.getClass();
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder("<script> ");
        sql.append(selectFrom(TableEntityMetaData.tableName(eClass), columns));
        sql.append(" <where>");
        whereByEntity(fields, columns, sql);
        sql.append("</where>").append(tail).append("</script>");
        return sql.toString();
    }

    /** Selects rows whose id is contained in the collection argument. */
    public static String listByIds(ProviderContext context) {
        Class<?> eClass = TableEntityMetaData.getEntityType(context);
        List<String> columns = TableEntityMetaData.tableColumns(TableEntityMetaData.entityFields(eClass));
        StringBuilder sql = new StringBuilder("<script> ");
        sql.append(selectFrom(TableEntityMetaData.tableName(eClass), columns));
        sql.append(" WHERE ").append(TableEntityMetaData.getIdColumn(eClass)).append(" IN ");
        sql.append(FOREACH_IN).append("#{item}</foreach></script>");
        return sql.toString();
    }

    /**
     * Selects rows whose column for {@code fieldName} equals the {@code val} argument.
     *
     * @throws IllegalArgumentException if {@code fieldName} is not a field of the entity; this also
     *     keeps arbitrary text out of the concatenated SQL
     */
    public static String listByField(String fieldName, ProviderContext context) throws Exception {
        Class<?> eClass = TableEntityMetaData.getEntityType(context);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        if (!fields.contains(fieldName)) {
            throw new IllegalArgumentException("not exist column '" + fieldName + "'");
        }
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        return selectFrom(TableEntityMetaData.tableName(eClass), columns)
                + " WHERE " + TableEntityMetaData.toLowerCase(fieldName) + " = #{val}";
    }

    /** Single-row variant of {@link #listByField} (plain SQL, so appending LIMIT is safe). */
    public static String getByField(String fieldName, ProviderContext context) throws Exception {
        return listByField(fieldName, context) + " LIMIT 1";
    }

    /** Inserts only the non-null fields of {@code object}. */
    public static String insert(Object object) {
        Class<?> eClass = object.getClass();
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder columnPart = new StringBuilder();
        StringBuilder valuePart = new StringBuilder();
        for (int i = 0; i < fields.size(); i++) {
            String test = "<if test=\"" + fields.get(i) + " != null\">";
            columnPart.append(test).append(columns.get(i)).append(",</if>");
            valuePart.append(test).append("#{").append(fields.get(i)).append("},</if>");
        }
        return "<script> INSERT INTO " + TableEntityMetaData.tableName(eClass)
                + " <trim prefix=\"(\" suffix=\")\" suffixOverrides=\",\">" + columnPart
                + "</trim><trim prefix=\"values (\" suffix=\")\" suffixOverrides=\",\">" + valuePart
                + "</trim></script>";
    }

    /** Multi-row INSERT of every column (null fields included) for each element of the collection. */
    public static String insertBatch(ProviderContext context) {
        Class<?> eClass = TableEntityMetaData.getEntityType(context);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder();
        sql.append("<script> INSERT INTO ").append(TableEntityMetaData.tableName(eClass));
        sql.append("(").append(String.join(", ", columns)).append(") values ");
        sql.append("<foreach item=\"item\" collection=\"collection\" separator=\",\" open=\"\" close=\"\" index=\"index\"> (");
        for (int i = 0; i < fields.size(); i++) {
            if (i > 0) {
                sql.append(", ");
            }
            sql.append("#{item.").append(fields.get(i)).append("}");
        }
        sql.append(")</foreach></script>");
        return sql.toString();
    }

    /** Updates the non-null, non-id fields of {@code object} on the row with the same id. */
    public static String update(Object object) {
        Class<?> eClass = object.getClass();
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        String idField = TableEntityMetaData.getIdField(eClass);
        StringBuilder sql = new StringBuilder("<script> UPDATE ");
        sql.append(TableEntityMetaData.tableName(eClass)).append(" <set>");
        for (int i = 0; i < fields.size(); i++) {
            String field = fields.get(i);
            // skip the primary key by name rather than assuming it sits at index 0
            if (field.equals(idField)) {
                continue;
            }
            sql.append("<if test=\"").append(field).append(" != null\">");
            sql.append(columns.get(i)).append(" = #{").append(field).append("},</if>");
        }
        sql.append("</set> WHERE ").append(TableEntityMetaData.getIdColumn(eClass));
        sql.append(" = #{").append(idField).append("} </script>");
        return sql.toString();
    }

    /**
     * Updates every entity of the collection in one statement, one {@code CASE WHEN id = ...} per
     * non-id column. Null field values are written as NULL (no {@code <if>} filtering).
     */
    public static String updateBatch(ProviderContext context) {
        Class<?> eClass = TableEntityMetaData.getEntityType(context);
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        String idField = TableEntityMetaData.getIdField(eClass);
        String idColumn = TableEntityMetaData.getIdColumn(eClass);
        StringBuilder sql = new StringBuilder("<script> UPDATE ");
        sql.append(TableEntityMetaData.tableName(eClass)).append(" <trim prefix=\"set\" suffixOverrides=\",\">");
        for (int i = 0; i < fields.size(); i++) {
            if (fields.get(i).equals(idField)) {
                continue;
            }
            sql.append("<trim prefix=\"").append(columns.get(i)).append(" = case\" suffix=\"end,\">");
            sql.append("<foreach collection=\"collection\" item=\"item\" index=\"index\">");
            sql.append("when ").append(idColumn);
            sql.append(" = #{item.").append(idField).append("} then #{item.").append(fields.get(i)).append("}");
            sql.append("</foreach></trim>");
        }
        sql.append("</trim> WHERE ").append(idColumn).append(" IN ");
        sql.append("<foreach collection=\"collection\" index=\"index\" item=\"item\" separator=\",\" open=\"(\" close=\")\">");
        sql.append("#{item.").append(idField).append("} </foreach></script>");
        return sql.toString();
    }

    /** Deletes the row whose id column equals the {@code id} argument. */
    public static String deleteById(ProviderContext context) {
        Class<?> eClass = TableEntityMetaData.getEntityType(context);
        return "DELETE FROM " + TableEntityMetaData.tableName(eClass)
                + " WHERE " + TableEntityMetaData.getIdColumn(eClass) + " = #{id}";
    }

    /**
     * Deletes rows whose columns equal the non-null fields of {@code object}. An entity whose fields
     * are all null deletes nothing instead of emptying the whole table.
     */
    public static String deleteByEntity(Object object) {
        Class<?> eClass = object.getClass();
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder("<script> DELETE FROM ");
        sql.append(TableEntityMetaData.tableName(eClass)).append(" <where> ");
        whereByEntity(fields, columns, sql);
        matchNothingWithoutFilter(fields, sql);
        sql.append("</where></script>");
        return sql.toString();
    }

    /** Deletes rows whose id is contained in the collection argument. */
    public static String deleteByIds(ProviderContext context) {
        Class<?> eClass = TableEntityMetaData.getEntityType(context);
        StringBuilder sql = new StringBuilder("<script> DELETE FROM ");
        sql.append(TableEntityMetaData.tableName(eClass)).append(" WHERE ")
                .append(TableEntityMetaData.getIdColumn(eClass)).append(" IN ");
        sql.append(FOREACH_IN).append("#{item}</foreach></script>");
        return sql.toString();
    }

    /** Counts all rows of the entity's table. */
    public static String countAll(ProviderContext context) {
        Class<?> eClass = TableEntityMetaData.getEntityType(context);
        return "SELECT COUNT(*) FROM " + TableEntityMetaData.tableName(eClass);
    }

    /** Counts rows whose columns equal the non-null fields of {@code object}. */
    public static String countByEntity(Object object) {
        Class<?> eClass = object.getClass();
        List<String> fields = TableEntityMetaData.entityFields(eClass);
        List<String> columns = TableEntityMetaData.tableColumns(fields);
        StringBuilder sql = new StringBuilder("<script> SELECT COUNT(*) FROM ");
        sql.append(TableEntityMetaData.tableName(eClass)).append(" <where> ");
        whereByEntity(fields, columns, sql);
        sql.append("</where></script>");
        return sql.toString();
    }

    /** Appends {@code <if test="f != null">and col = #{f}</if>} for each field; {@code <where>} drops the leading "and". */
    private static void whereByEntity(List<String> fields, List<String> columns, StringBuilder sql) {
        for (int i = 0; i < fields.size(); i++) {
            sql.append("<if test=\"").append(fields.get(i)).append(" != null\">");
            sql.append("and ").append(columns.get(i));
            sql.append(" = #{").append(fields.get(i)).append("}</if>");
        }
    }

    /**
     * Appends a condition that matches no rows when every field is null (or there are no fields),
     * so an empty filter cannot turn into an unconditional statement.
     */
    private static void matchNothingWithoutFilter(List<String> fields, StringBuilder sql) {
        if (fields.isEmpty()) {
            sql.append("AND 1 = 0");
            return;
        }
        sql.append("<if test=\"");
        for (int i = 0; i < fields.size(); i++) {
            if (i > 0) {
                sql.append(" and ");
            }
            sql.append(fields.get(i)).append(" == null");
        }
        sql.append("\">AND 1 = 0</if>");
    }
}
表实体元数据工具类
//通过ProviderContext和entity实体对象获取表和实体元数据信息
//通过实体类型获取表名和列名,因此数据库和实体必须遵循下划线转驼峰规则:
//表名为实体类名的下划线形式(如 UserInfo -> user_info),表列名必须全小写、多单词以下划线分割,实体属性必须为驼峰命名;
//主键列名与主键属性名均固定为 id
public class TableEntityMetaData {

    /** Matches each ASCII upper-case letter; compiled once instead of on every conversion. */
    private static final java.util.regex.Pattern UPPER_CASE_LETTER = java.util.regex.Pattern.compile("[A-Z]");

    /**
     * Resolves the entity class of the mapper behind {@code context}.
     * <p>Uses the first type argument of the first parameterized super-interface whose first
     * argument is a concrete class (e.g. {@code User} for {@code UserDao extends BaseDao<User,Integer>}).
     * If no direct super-interface qualifies, inherited interfaces are searched as well.
     *
     * @throws IllegalStateException if no such type argument can be found
     */
    public static Class getEntityType(ProviderContext context) {
        Class<?> mapperType = context.getMapperType();
        Class entityType = findEntityType(mapperType);
        if (entityType == null) {
            throw new IllegalStateException("cannot resolve entity type of mapper " + mapperType.getName());
        }
        return entityType;
    }

    /** Depth-first search for the entity type argument; returns null when none is found. */
    private static Class findEntityType(Class<?> type) {
        java.lang.reflect.Type[] interfaces = type.getGenericInterfaces();
        for (java.lang.reflect.Type candidate : interfaces) {
            if (candidate instanceof ParameterizedType) {
                java.lang.reflect.Type argument = ((ParameterizedType) candidate).getActualTypeArguments()[0];
                if (argument instanceof Class) {
                    return (Class) argument;
                }
            }
        }
        for (java.lang.reflect.Type candidate : interfaces) {
            if (candidate instanceof Class) {
                Class found = findEntityType((Class<?>) candidate);
                if (found != null) {
                    return found;
                }
            }
        }
        return null;
    }

    /** Primary-key column name; fixed to "id" for every entity. */
    public static String getIdColumn(Class eClass) {
        return "id";
    }

    /** Primary-key property name; fixed to "id" for every entity. */
    public static String getIdField(Class eClass) {
        return "id";
    }

    /** Table name: the entity's simple class name converted to snake_case. */
    public static String tableName(Class eClass) {
        return toLowerCase(eClass.getSimpleName());
    }

    /**
     * Names of the persistent fields declared by {@code eClass}, with the id field moved first.
     * <p>Static fields (e.g. {@code serialVersionUID}) and synthetic fields (e.g. those injected by
     * instrumentation tools) are not columns and are skipped. Inherited fields are not included.
     * Order otherwise follows {@link Class#getDeclaredFields()}, which the JDK does not guarantee.
     */
    public static List<String> entityFields(Class eClass) {
        String idField = getIdField(eClass);
        Field[] declared = eClass.getDeclaredFields();
        List<String> entityFields = new ArrayList<>(declared.length);
        for (Field field : declared) {
            if (java.lang.reflect.Modifier.isStatic(field.getModifiers()) || field.isSynthetic()) {
                continue;
            }
            String name = field.getName();
            if (name.equals(idField)) {
                entityFields.add(0, name);
            } else {
                entityFields.add(name);
            }
        }
        return entityFields;
    }

    /** Column names for {@code entityFields}, index-aligned with the input list. */
    public static List<String> tableColumns(List<String> entityFields) {
        List<String> tableColumns = new ArrayList<>(entityFields.size());
        for (String field : entityFields) {
            tableColumns.add(toLowerCase(field));
        }
        return tableColumns;
    }

    /**
     * Converts camelCase/PascalCase to snake_case ("gmtCreate" -> "gmt_create", "User" -> "user").
     * Uses {@link java.util.Locale#ROOT} so the result does not depend on the default locale
     * (e.g. Turkish would otherwise turn "I" into a dotless "ı").
     */
    public static String toLowerCase(String camelStr) {
        String lowerCase = UPPER_CASE_LETTER.matcher(camelStr).replaceAll("_$0").toLowerCase(java.util.Locale.ROOT);
        if (lowerCase.startsWith("_")) {
            lowerCase = lowerCase.substring(1);
        }
        return lowerCase;
    }
}
lambda query function 接口
@FunctionalInterface
public interface GetterFunction<T,R> extends Serializable,Function<T,R> {

    /**
     * Resolves the JavaBeans property name referenced by a getter method reference,
     * e.g. {@code User::getUserName} -> "userName", {@code Bean::isActive} -> "active".
     * <p>Works by reading the {@link SerializedLambda} produced by the compiler-generated
     * {@code writeReplace} method of serializable lambdas. Only a leading "get"/"is" prefix followed
     * by an upper-case letter is removed, so names containing "get" later ({@code getTarget}) stay
     * intact. A non-method-reference lambda yields its synthetic method name.
     *
     * @param func the getter method reference to inspect
     * @return the decapitalized property name
     * @throws RuntimeException wrapping any reflection failure
     */
    default String getFieldName(GetterFunction<T,?> func) {
        try {
            Method writeReplace = func.getClass().getDeclaredMethod("writeReplace");
            writeReplace.setAccessible(true);
            SerializedLambda serializedLambda = (SerializedLambda) writeReplace.invoke(func);
            String name = serializedLambda.getImplMethodName();
            int prefixLength = name.startsWith("get") ? 3 : name.startsWith("is") ? 2 : 0;
            // strip the prefix only when it really is a bean prefix ("getX"/"isX"), never from the middle
            if (prefixLength > 0 && name.length() > prefixLength
                    && Character.isUpperCase(name.charAt(prefixLength))) {
                name = name.substring(prefixLength);
            }
            return Introspector.decapitalize(name);
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    }
}
实体类
/**
 * Example entity handled by the generic DAO (table name derived from the class name: user).
 * Field names are converted from camelCase to snake_case column names (e.g. gmtCreate -> gmt_create),
 * and the primary-key field must be named id. Accessors are supplied by Lombok's @Data.
 */
@Data//lombok
public class User {
/**
 * Primary key, auto-increment.
 */
private Integer id;
private String username;
private String password;
/**
 * Record creation time; defaults to the current time.
 */
private Date gmtCreate;
/**
 * Record modification time; defaults to the current time.
 */
private Date gmtModified;
}
具体Dao
/**
 * DAO for {@link User}. All generic CRUD methods are inherited from {@link BaseDao}; custom
 * statements can be declared here, with their SQL in an XML mapper under resources/mapper.
 */
public interface UserDao extends BaseDao<User,Integer> {
}
yml mybatis配置
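一定要开启下划线转驼峰设置(map-underscore-to-camel-case):通用 SELECT 语句直接查询下划线风格的列名(如 gmt_create)且不加别名,只有开启该设置,查询结果才能映射到驼峰风格的实体属性(如 gmtCreate)。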
mybatis:
  # XML mapper files holding hand-written SQL (resources/mapper/**)
  mapper-locations: classpath*:mapper/**/*.xml
  configuration:
    # Required: the generated SELECTs list snake_case columns without aliases,
    # so results only map onto camelCase entity properties when this is enabled.
    map-underscore-to-camel-case: true
使用
- 自己写的SQL可以放在resources/mapper路径里的Mapper.xml中
- 对应dao方法则放在具体的dao中
不积跬步无以至千里