sfunction
/* * Copyright (c) 2011-2022, baomidou (jobob@qq.com). * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.baomidou.mybatisplus.advance.injector; import com.baomidou.mybatisplus.core.conditions.AbstractJoinWrapper; import com.baomidou.mybatisplus.core.conditions.AbstractWrapper; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; import com.baomidou.mybatisplus.core.enums.SqlMethod; import com.baomidou.mybatisplus.core.toolkit.*; import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache; import com.baomidou.mybatisplus.core.toolkit.support.SFunction; import com.baomidou.mybatisplus.extension.conditions.query.BasicJoinQueryWrapper; import com.baomidou.mybatisplus.extension.toolkit.SqlHelper; import lombok.AllArgsConstructor; import lombok.Data; import org.apache.ibatis.session.SqlSession; import org.mybatis.spring.SqlSessionUtils; import java.lang.invoke.CallSite; import java.lang.invoke.LambdaMetafactory; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; import java.lang.reflect.Field; import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; /** * 通过此类去给wrapper添加参数和执行sql返回结果 * * @author wanglei * @since 2022-03-14 */ public class FuntionTools { /** * 可序列化 */ private static final int FLAG_SERIALIZABLE = 1; private static Map<String, SFunction> functionMap = new HashMap<>(); /** * 把动作 * * @param actions 动作 * @param operator 操作符 比如like * @param po po对象 * @param property 属性名 * @param value 值 */ public static void addAction(List<Action> actions, String operator, Object po, String property, Object value) { actions.add(new Action(po.getClass(), property, operator, value, null, null, null)); } /** * 添加动作 * * @param actions actions * @param operator 操作符 比如like * @param po po对象 * @param property 属性名 * @param minValue 最小值 * @param maxValue 最大值 */ public static void addAction(List<Action> actions, String operator, Object po, String property, Object minValue, Object maxValue) { actions.add(new Action(po.getClass(), property, operator, null, minValue, maxValue, null)); } /** * 构造一个查询条件 - 单表查询 * * @param actions 动作集合 * @param po po * @return QueryWrapper */ public static LambdaQueryWrapper buildQueryWrapper(List<Action> actions, Object po) { LambdaQueryWrapper<?> queryWrapper = new LambdaQueryWrapper(); for (Action action : actions) { // 普通的查询只保留他自己的where条件和select,就算他调用了join也忽略掉 if (action.getModelClass().equals(po.getClass()) && !"select".equals(action.getAction())) { SFunction column = getSFunction(action.getModelClass(), action.getProperty()); buildWhere(queryWrapper, action, column); } else if ("select".equals(action.getAction()) && action.getModelClass().equals(po.getClass())) { SFunction[] columns = new SFunction[action.getProperties().length]; int i = 0; for (String property : action.getProperties()) { columns[i++] = getSFunction(action.getModelClass(), property); } queryWrapper.select(columns); } } return queryWrapper; } /** * 构造where条件 * * @param queryWrapper wapper * @param action 动作 * @param column 列名 */ public static void buildWhere(AbstractWrapper queryWrapper, Action action, Object column) { switch (action.getAction()) { case OperatorConstant.EQ: queryWrapper.eq(column, action.getValue()); break; case OperatorConstant.LT: queryWrapper.lt(column, action.getValue()); break; case OperatorConstant.GT: queryWrapper.gt(column, action.getValue()); break; case OperatorConstant.LE: queryWrapper.le(column, action.getValue()); break; case OperatorConstant.GE: queryWrapper.ge(column, action.getValue()); break; case OperatorConstant.NE: queryWrapper.ne(column, action.getValue()); break; case OperatorConstant.LIKE: queryWrapper.like(column, action.getValue()); break; case OperatorConstant.LIKE_LEFT: queryWrapper.likeLeft(column, action.getValue()); break; case OperatorConstant.LIKE_RIGHT: queryWrapper.likeRight(column, action.getValue()); break; case OperatorConstant.NOT_LIKE: queryWrapper.notLike(column, action.getValue()); break; case OperatorConstant.IS_NULL: queryWrapper.isNull(column); break; case OperatorConstant.NOT_NULL: queryWrapper.isNotNull(column); break; case OperatorConstant.IN: queryWrapper.in(column, (Collection) (action.getValue())); break; case OperatorConstant.NOT_IN: queryWrapper.notIn(column, (Collection) (action.getValue())); break; case OperatorConstant.ORDER_BY_ASC: queryWrapper.orderByAsc(column); break; case OperatorConstant.ORDER_BY_DESC: queryWrapper.orderByDesc(column); break; case OperatorConstant.BETWEEN: queryWrapper.between(column, action.getMin(), action.getMax()); break; case OperatorConstant.NOT_BETWEEN: queryWrapper.notBetween(column, action.getMin(), action.getMax()); break; } } /** * 构造一个查询条件 - 多表查询 * * @param actions 动作集合 * @param po po * @return */ public static BasicJoinQueryWrapper buildJoinWrapper(List<Action> actions, Object po) { BasicJoinQueryWrapper queryWrapper = new BasicJoinQueryWrapper(po.getClass()); // 先搞定join,不然添加条件和select会报错 for (Action action : actions) { if (action.getAction().contains(OperatorConstant.JOIN)) { switch (action.getAction()) { case OperatorConstant.JOIN: queryWrapper.innerJoin(action.getModelClass()); break; case OperatorConstant.LEFT_JOIN: queryWrapper.leftJoin(action.getModelClass()); break; } } } for (Action action : actions) { // 不是select和join动作的话,就是where条件 if (!OperatorConstant.SELECT.equals(action.getAction()) && !action.getAction().contains(OperatorConstant.JOIN)) { BasicJoinQueryWrapper.ModelProperty column = new BasicJoinQueryWrapper.ModelProperty(action.getModelClass(), action.getProperty()); buildWhere(queryWrapper, action, column); // select处理 } else if (OperatorConstant.SELECT.equals(action.getAction())) { BasicJoinQueryWrapper.ModelProperty[] columns = new BasicJoinQueryWrapper.ModelProperty[action.getProperties().length]; int i = 0; for (String property : action.getProperties()) { columns[i++] = new BasicJoinQueryWrapper.ModelProperty(action.getModelClass(), property); } queryWrapper.select(columns); } } return queryWrapper; } /** * 判断是否是join * * @param actions 动作集合 * @return true 包含join动作,false 不包含join动作 */ public static boolean isJoin(List<Action> actions) { return actions.stream().filter(action -> action.getAction().contains(OperatorConstant.JOIN)).count() > 0; } /** * join其他的表 * * @param target 目标标的实体类 * @param actions 动作集合 * @param joinType join的类型 * @return 目标类的对象 */ public static <T> T join(Class<T> target, List<Action> actions, String joinType) { try { T result = target.newInstance(); // 拿到关联对象的actions,并且重新赋值 Field field = ReflectionKit.getFieldMap(target).get("actions"); field.setAccessible(true); field.set(result, actions); actions.add(new Action(target, null, joinType, null, null, null, null)); return result; } catch (InstantiationException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } return null; } /** * 回主表 * @param target 目标标的实体类 * @param actions 动作集合 * @return 目标类的对象 */ public static <T> T end(Class<T> target, List<Action> actions) { try { T result = target.newInstance(); // 拿到关联对象的actions,并且重新赋值 Field field = ReflectionKit.getFieldMap(target).get("actions"); field.setAccessible(true); field.set(result, actions); return result; } catch (InstantiationException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } return null; } /** * 查询列表 * * @param actions 动作集合 * @param po po * @return 列表 */ public static List list(List<Action> actions, Object po) { boolean isJoin = isJoin(actions); AbstractWrapper wrapper = isJoin ? buildJoinWrapper(actions, po) : buildQueryWrapper(actions, po); SqlSession sqlSession = sqlSession(po.getClass()); Map<String, Object> map = CollectionUtils.newHashMapWithExpectedSize(1); map.put(Constants.WRAPPER, wrapper); try { return sqlSession.selectList(sqlStatement(SqlMethod.SELECT_LIST, po.getClass()), map); } finally { closeSqlSession(sqlSession, po.getClass()); } } /** * 手动指定查询字段 * * @param actions 动作集合 * @param po po * @param fields 字段 */ public static void addSelect(List<Action> actions, Object po, String... fields) { actions.add(new Action(po.getClass(), null, OperatorConstant.SELECT, null, null, null, fields)); } /** * 获取数据库字段名 * * @param entityClass 实体类 * @param fieldName 属性 * @return 字段名 */ public static String getDBField(Class entityClass, String fieldName) { ColumnCache columnCache = AbstractJoinWrapper.getCache(entityClass, fieldName); if (columnCache != null) { return columnCache.getColumn(); } throw ExceptionUtils.mpe("This class %s is not have field %s ", entityClass.getName(), fieldName); } /** * 使用wrapper进行查询 * * @param actions 动作集合 * @param po po * @return 单个对象 */ public static Object one(List<Action> actions, Object po) { List list = list(actions, po); if (list.size() > 0) { return list.get(0); } return null; } /** * 使用wrapper进行查询 * * @param actions 动作集合 * @param po po * @return 总数 */ public static Long count(List<Action> actions, Object po) { boolean isJoin = isJoin(actions); AbstractWrapper wrapper = isJoin ? buildJoinWrapper(actions, po) : buildQueryWrapper(actions, po); SqlSession sqlSession = sqlSession(po.getClass()); Map<String, Object> map = CollectionUtils.newHashMapWithExpectedSize(1); map.put(Constants.WRAPPER, wrapper); try { return sqlSession.selectOne(sqlStatement(SqlMethod.SELECT_COUNT, po.getClass()), map); } finally { closeSqlSession(sqlSession, po.getClass()); } } /** * actions进行修改返回受影响行数 * * @param actions 动作集合 * @param po po * @return 受影响行数 */ public static Integer update(List<Action> actions, Object po) { AbstractWrapper wrapper = buildQueryWrapper(actions, po); SqlSession sqlSession = sqlSession(po.getClass()); Map<String, Object> map = CollectionUtils.newHashMapWithExpectedSize(1); map.put(Constants.WRAPPER, wrapper); map.put(Constants.ENTITY, po); try { return sqlSession.update(sqlStatement(SqlMethod.UPDATE, po.getClass()), map); } finally { closeSqlSession(sqlSession, po.getClass()); } } /** * 使用actions进行删除返回受影响行数 * * @param actions 动作集合 * @param po po * @return 受影响行数 */ public static Integer delete(List<Action> actions, Object po) { AbstractWrapper wrapper = buildQueryWrapper(actions, po); SqlSession sqlSession = sqlSession(po.getClass()); Map<String, Object> map = CollectionUtils.newHashMapWithExpectedSize(1); map.put(Constants.WRAPPER, wrapper); try { return sqlSession.delete(sqlStatement(SqlMethod.DELETE, po.getClass()), map); } finally { closeSqlSession(sqlSession, po.getClass()); } } protected static SqlSession sqlSession(Class poClass) { return SqlHelper.sqlSession(poClass); } protected static String sqlStatement(SqlMethod sqlMethod, Class poClass) { return sqlStatement(sqlMethod.getMethod(), poClass); } protected static String sqlStatement(String sqlMethod, Class poClass) { return SqlHelper.table(poClass).getSqlStatement(sqlMethod); } protected static void closeSqlSession(SqlSession sqlSession, Class poClass) { SqlSessionUtils.closeSqlSession(sqlSession, GlobalConfigUtils.currentSessionFactory(poClass)); } /** * 获取方法的sfunction * @param entityClass 实体类 * @param fieldName 字段名 * @return sfunction */ public static SFunction getSFunction(Class<?> entityClass, String fieldName) { if (functionMap.containsKey(entityClass.getName() + fieldName)) { return functionMap.get(entityClass.getName() + fieldName); } Field field = getDeclaredField(entityClass, fieldName); if(field == null){ throw ExceptionUtils.mpe("This class %s is not have field %s ", entityClass.getName(), fieldName); } SFunction func = null; final MethodHandles.Lookup lookup = MethodHandles.lookup(); MethodType methodType = MethodType.methodType(field.getType(), entityClass); final CallSite site; String getFunName = "get" + fieldName.substring(0, 1).toUpperCase() + fieldName.substring(1); try { site = LambdaMetafactory.altMetafactory(lookup, "invoke", MethodType.methodType(SFunction.class), methodType, lookup.findVirtual(entityClass, getFunName, MethodType.methodType(field.getType())), methodType, FLAG_SERIALIZABLE); func = (SFunction) site.getTarget().invokeExact(); functionMap.put(entityClass.getName() + field, func); return func; } catch (Throwable e) { throw ExceptionUtils.mpe("This class %s is not have method %s ", entityClass.getName(), getFunName); } } /** * 获取字段 * @param clazz 类 * @param fieldName 字段名 * @return 字段 */ public static Field getDeclaredField(Class<?> clazz, String fieldName) { Field field = null; for (; clazz != Object.class; clazz = clazz.getSuperclass()) { try { field = clazz.getDeclaredField(fieldName); return field; } catch (Exception e) { // 这里甚么都不要做!并且这里的异常必须这样写,不能抛出去。 // 如果这里的异常打印或者往外抛,则就不会执行clazz = clazz.getSuperclass(),最后就不会进入到父类中了 } } return null; } /** * 动作 * * @author wanglei * @since 2022-03-18 */ @Data @AllArgsConstructor public static class Action { /** * po的class */ private Class<?> modelClass; /** * 属性 */ private String property; /** * 动作 */ private String action; /** * 值 */ private Object value; /** * 最小值 */ private Object min; /** * 最大值 */ private Object max; /** * 属性 -用于select */ private String[] properties; } }