sfunction

/*
 * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.baomidou.mybatisplus.advance.injector;

import com.baomidou.mybatisplus.core.conditions.AbstractJoinWrapper;
import com.baomidou.mybatisplus.core.conditions.AbstractWrapper;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.core.enums.SqlMethod;
import com.baomidou.mybatisplus.core.toolkit.*;
import com.baomidou.mybatisplus.core.toolkit.support.ColumnCache;
import com.baomidou.mybatisplus.core.toolkit.support.SFunction;
import com.baomidou.mybatisplus.extension.conditions.query.BasicJoinQueryWrapper;
import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.ibatis.session.SqlSession;
import org.mybatis.spring.SqlSessionUtils;

import java.lang.invoke.CallSite;
import java.lang.invoke.LambdaMetafactory;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.Field;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Translates a recorded list of {@link Action}s into MyBatis-Plus wrappers, executes the
 * resulting SQL and returns the results.
 *
 * @author wanglei
 * @since 2022-03-14
 */
public class FuntionTools {

    /**
     * Flag for {@link LambdaMetafactory#altMetafactory} that makes the generated lambda serializable.
     */
    private static final int FLAG_SERIALIZABLE = LambdaMetafactory.FLAG_SERIALIZABLE;

    /**
     * Cache of generated getter lambdas, keyed by entity class name + property name.
     * This is static state shared by every caller of this utility, hence a thread-safe map.
     */
    private static final Map<String, SFunction> functionMap = new ConcurrentHashMap<>();

    /**
     * Records a single-value action (e.g. eq / like / in) against a property of {@code po}.
     *
     * @param actions  action list to append to
     * @param operator operator, e.g. like
     * @param po       entity (PO) object; only its class is recorded
     * @param property property name
     * @param value    comparison value
     */
    public static void addAction(List<Action> actions, String operator, Object po, String property, Object value) {
        actions.add(new Action(po.getClass(), property, operator, value, null, null, null));
    }

    /**
     * Records a range action (e.g. between / notBetween) against a property of {@code po}.
     *
     * @param actions  action list to append to
     * @param operator operator, e.g. between
     * @param po       entity (PO) object; only its class is recorded
     * @param property property name
     * @param minValue lower bound
     * @param maxValue upper bound
     */
    public static void addAction(List<Action> actions, String operator, Object po, String property, Object minValue, Object maxValue) {
        actions.add(new Action(po.getClass(), property, operator, null, minValue, maxValue, null));
    }

    /**
     * Builds a single-table query wrapper.
     * Only actions targeting {@code po}'s own class are applied; join actions are ignored.
     *
     * @param actions recorded actions
     * @param po      entity (PO) object whose class is the main table
     * @return lambda query wrapper holding the where/order/select parts
     */
    public static LambdaQueryWrapper buildQueryWrapper(List<Action> actions, Object po) {
        LambdaQueryWrapper<?> queryWrapper = new LambdaQueryWrapper();
        Class<?> poClass = po.getClass();
        for (Action action : actions) {
            // A plain query keeps only its own where conditions and selects; even if the caller
            // used join, those actions (and actions on other entities) are skipped here.
            if (!poClass.equals(action.getModelClass()) || action.getAction().contains(OperatorConstant.JOIN)) {
                continue;
            }
            if (OperatorConstant.SELECT.equals(action.getAction())) {
                String[] properties = action.getProperties();
                SFunction[] columns = new SFunction[properties.length];
                for (int i = 0; i < properties.length; i++) {
                    columns[i] = getSFunction(poClass, properties[i]);
                }
                queryWrapper.select(columns);
            } else {
                buildWhere(queryWrapper, action, getSFunction(poClass, action.getProperty()));
            }
        }
        return queryWrapper;
    }

    /**
     * Applies one condition/ordering action to {@code queryWrapper}.
     *
     * @param queryWrapper wrapper to modify
     * @param action       action holding the operator and its value(s)
     * @param column       column reference understood by the wrapper (an {@link SFunction} for the
     *                     single-table wrapper, a {@code ModelProperty} for the join wrapper)
     */
    public static void buildWhere(AbstractWrapper queryWrapper, Action action, Object column) {
        switch (action.getAction()) {
            case OperatorConstant.EQ:
                queryWrapper.eq(column, action.getValue());
                break;
            case OperatorConstant.LT:
                queryWrapper.lt(column, action.getValue());
                break;
            case OperatorConstant.GT:
                queryWrapper.gt(column, action.getValue());
                break;
            case OperatorConstant.LE:
                queryWrapper.le(column, action.getValue());
                break;
            case OperatorConstant.GE:
                queryWrapper.ge(column, action.getValue());
                break;
            case OperatorConstant.NE:
                queryWrapper.ne(column, action.getValue());
                break;
            case OperatorConstant.LIKE:
                queryWrapper.like(column, action.getValue());
                break;
            case OperatorConstant.LIKE_LEFT:
                queryWrapper.likeLeft(column, action.getValue());
                break;
            case OperatorConstant.LIKE_RIGHT:
                queryWrapper.likeRight(column, action.getValue());
                break;
            case OperatorConstant.NOT_LIKE:
                queryWrapper.notLike(column, action.getValue());
                break;
            case OperatorConstant.IS_NULL:
                queryWrapper.isNull(column);
                break;
            case OperatorConstant.NOT_NULL:
                queryWrapper.isNotNull(column);
                break;
            case OperatorConstant.IN:
                queryWrapper.in(column, (Collection) (action.getValue()));
                break;
            case OperatorConstant.NOT_IN:
                queryWrapper.notIn(column, (Collection) (action.getValue()));
                break;
            case OperatorConstant.ORDER_BY_ASC:
                queryWrapper.orderByAsc(column);
                break;
            case OperatorConstant.ORDER_BY_DESC:
                queryWrapper.orderByDesc(column);
                break;
            case OperatorConstant.BETWEEN:
                queryWrapper.between(column, action.getMin(), action.getMax());
                break;
            case OperatorConstant.NOT_BETWEEN:
                queryWrapper.notBetween(column, action.getMin(), action.getMax());
                break;
            default:
                // Unrecognised operators are ignored. NOTE(review): for update/delete this
                // silently widens the WHERE clause; consider rejecting unknown operators.
                break;
        }
    }

    /**
     * Builds a multi-table (join) query wrapper.
     *
     * @param actions recorded actions
     * @param po      entity (PO) object whose class is the main table
     * @return join wrapper holding the joins, where/order parts and selects
     */
    public static BasicJoinQueryWrapper buildJoinWrapper(List<Action> actions, Object po) {
        BasicJoinQueryWrapper queryWrapper = new BasicJoinQueryWrapper(po.getClass());
        // Register joins first; adding conditions or selects for a not-yet-joined table fails.
        for (Action action : actions) {
            String operator = action.getAction();
            if (!operator.contains(OperatorConstant.JOIN)) {
                continue;
            }
            switch (operator) {
                case OperatorConstant.JOIN:
                    queryWrapper.innerJoin(action.getModelClass());
                    break;
                case OperatorConstant.LEFT_JOIN:
                    queryWrapper.leftJoin(action.getModelClass());
                    break;
                default:
                    break;
            }
        }
        for (Action action : actions) {
            String operator = action.getAction();
            if (OperatorConstant.SELECT.equals(operator)) {
                String[] properties = action.getProperties();
                BasicJoinQueryWrapper.ModelProperty[] columns = new BasicJoinQueryWrapper.ModelProperty[properties.length];
                for (int i = 0; i < properties.length; i++) {
                    columns[i] = new BasicJoinQueryWrapper.ModelProperty(action.getModelClass(), properties[i]);
                }
                queryWrapper.select(columns);
            } else if (!operator.contains(OperatorConstant.JOIN)) {
                // Anything that is neither select nor join is a where/order condition.
                BasicJoinQueryWrapper.ModelProperty column = new BasicJoinQueryWrapper.ModelProperty(action.getModelClass(), action.getProperty());
                buildWhere(queryWrapper, action, column);
            }
        }
        return queryWrapper;
    }

    /**
     * Tells whether the action list contains any join action.
     *
     * @param actions recorded actions
     * @return true if at least one join action is present, false otherwise
     */
    public static boolean isJoin(List<Action> actions) {
        return actions.stream().anyMatch(action -> action.getAction().contains(OperatorConstant.JOIN));
    }

    /**
     * Joins another table: creates a {@code target} instance sharing the same action list and
     * records a join action for it.
     *
     * @param target   entity class of the joined table
     * @param actions  recorded actions (shared with the returned object)
     * @param joinType join operator to record
     * @return new instance of {@code target}
     * @throws RuntimeException (MyBatis-Plus exception) if {@code target} cannot be instantiated
     *                          or has no {@code actions} field
     */
    public static <T> T join(Class<T> target, List<Action> actions, String joinType) {
        T result = newWithActions(target, actions);
        actions.add(new Action(target, null, joinType, null, null, null, null));
        return result;
    }

    /**
     * Returns to the main table: creates a {@code target} instance sharing the same action list.
     *
     * @param target  entity class of the main table
     * @param actions recorded actions (shared with the returned object)
     * @return new instance of {@code target}
     * @throws RuntimeException (MyBatis-Plus exception) if {@code target} cannot be instantiated
     *                          or has no {@code actions} field
     */
    public static <T> T end(Class<T> target, List<Action> actions) {
        return newWithActions(target, actions);
    }

    /**
     * Instantiates {@code target} via its no-arg constructor and injects {@code actions} into
     * its {@code actions} field.
     */
    private static <T> T newWithActions(Class<T> target, List<Action> actions) {
        Field field = ReflectionKit.getFieldMap(target).get("actions");
        if (field == null) {
            throw ExceptionUtils.mpe("This class %s is not have field %s ", target.getName(), "actions");
        }
        try {
            T result = target.getDeclaredConstructor().newInstance();
            field.setAccessible(true);
            field.set(result, actions);
            return result;
        } catch (ReflectiveOperationException e) {
            throw ExceptionUtils.mpe("Can not create %s and bind its actions", e, target.getName());
        }
    }

    /**
     * Runs a list query.
     *
     * @param actions recorded actions
     * @param po      entity (PO) object whose class is the main table
     * @return matching rows
     */
    public static List list(List<Action> actions, Object po) {
        Class<?> poClass = po.getClass();
        Map<String, Object> param = wrapperParam(buildWrapper(actions, po));
        SqlSession sqlSession = sqlSession(poClass);
        try {
            return sqlSession.selectList(sqlStatement(SqlMethod.SELECT_LIST, poClass), param);
        } finally {
            closeSqlSession(sqlSession, poClass);
        }
    }

    /**
     * Records an explicit select list for {@code po}'s table.
     *
     * @param actions recorded actions
     * @param po      entity (PO) object; only its class is recorded
     * @param fields  property names to select
     */
    public static void addSelect(List<Action> actions, Object po, String... fields) {
        actions.add(new Action(po.getClass(), null, OperatorConstant.SELECT, null, null, null, fields));
    }

    /**
     * Resolves the database column name of a property.
     *
     * @param entityClass entity class
     * @param fieldName   property name
     * @return column name
     * @throws RuntimeException (MyBatis-Plus exception) if no column is cached for the property
     */
    public static String getDBField(Class entityClass, String fieldName) {
        ColumnCache columnCache = AbstractJoinWrapper.getCache(entityClass, fieldName);
        if (columnCache != null) {
            return columnCache.getColumn();
        }
        throw ExceptionUtils.mpe("This class %s is not have field %s ", entityClass.getName(), fieldName);
    }

    /**
     * Runs the query and returns the first row.
     * NOTE(review): this fetches the whole list; callers expecting one row should constrain the query.
     *
     * @param actions recorded actions
     * @param po      entity (PO) object whose class is the main table
     * @return first row, or null when nothing matches
     */
    public static Object one(List<Action> actions, Object po) {
        List list = list(actions, po);
        return list.isEmpty() ? null : list.get(0);
    }

    /**
     * Runs a count query.
     *
     * @param actions recorded actions
     * @param po      entity (PO) object whose class is the main table
     * @return row count
     */
    public static Long count(List<Action> actions, Object po) {
        Class<?> poClass = po.getClass();
        Map<String, Object> param = wrapperParam(buildWrapper(actions, po));
        SqlSession sqlSession = sqlSession(poClass);
        try {
            return sqlSession.selectOne(sqlStatement(SqlMethod.SELECT_COUNT, poClass), param);
        } finally {
            closeSqlSession(sqlSession, poClass);
        }
    }

    /**
     * Updates rows matching the actions with the non-null values of {@code po}.
     * Join actions are not supported here; only the single-table conditions are applied.
     *
     * @param actions recorded actions
     * @param po      entity holding the new values
     * @return affected row count
     */
    public static Integer update(List<Action> actions, Object po) {
        Class<?> poClass = po.getClass();
        Map<String, Object> param = wrapperParam(buildQueryWrapper(actions, po));
        param.put(Constants.ENTITY, po);
        SqlSession sqlSession = sqlSession(poClass);
        try {
            return sqlSession.update(sqlStatement(SqlMethod.UPDATE, poClass), param);
        } finally {
            closeSqlSession(sqlSession, poClass);
        }
    }

    /**
     * Deletes rows matching the actions.
     * Join actions are not supported here; only the single-table conditions are applied.
     *
     * @param actions recorded actions
     * @param po      entity (PO) object whose class is the target table
     * @return affected row count
     */
    public static Integer delete(List<Action> actions, Object po) {
        Class<?> poClass = po.getClass();
        Map<String, Object> param = wrapperParam(buildQueryWrapper(actions, po));
        SqlSession sqlSession = sqlSession(poClass);
        try {
            return sqlSession.delete(sqlStatement(SqlMethod.DELETE, poClass), param);
        } finally {
            closeSqlSession(sqlSession, poClass);
        }
    }

    /** Picks the join wrapper when any join action is present, else the single-table wrapper. */
    private static AbstractWrapper buildWrapper(List<Action> actions, Object po) {
        return isJoin(actions) ? buildJoinWrapper(actions, po) : buildQueryWrapper(actions, po);
    }

    /** Wraps the wrapper into the parameter map expected by the mapped statements. */
    private static Map<String, Object> wrapperParam(AbstractWrapper wrapper) {
        Map<String, Object> map = CollectionUtils.newHashMapWithExpectedSize(2);
        map.put(Constants.WRAPPER, wrapper);
        return map;
    }

    protected static SqlSession sqlSession(Class poClass) {
        return SqlHelper.sqlSession(poClass);
    }

    protected static String sqlStatement(SqlMethod sqlMethod, Class poClass) {
        return sqlStatement(sqlMethod.getMethod(), poClass);
    }

    protected static String sqlStatement(String sqlMethod, Class poClass) {
        return SqlHelper.table(poClass).getSqlStatement(sqlMethod);
    }

    protected static void closeSqlSession(SqlSession sqlSession, Class poClass) {
        SqlSessionUtils.closeSqlSession(sqlSession, GlobalConfigUtils.currentSessionFactory(poClass));
    }

    /**
     * Builds (and caches) a serializable {@link SFunction} equivalent to {@code Entity::getXxx}
     * for the given property.
     *
     * @param entityClass entity class
     * @param fieldName   property name; the getter {@code get + Capitalized(fieldName)} must exist
     * @return getter lambda
     * @throws RuntimeException (MyBatis-Plus exception) if the field or its getter is missing
     */
    public static SFunction getSFunction(Class<?> entityClass, String fieldName) {
        String cacheKey = entityClass.getName() + fieldName;
        SFunction cached = functionMap.get(cacheKey);
        if (cached != null) {
            return cached;
        }
        Field field = getDeclaredField(entityClass, fieldName);
        if (field == null) {
            throw ExceptionUtils.mpe("This class %s is not have field %s ", entityClass.getName(), fieldName);
        }
        // Character.toUpperCase is locale-independent (String.toUpperCase() breaks e.g. "id" in tr_TR).
        String getFunName = "get" + Character.toUpperCase(fieldName.charAt(0)) + fieldName.substring(1);
        MethodHandles.Lookup lookup = MethodHandles.lookup();
        // Concrete signature of the getter viewed as a function: (EntityClass) -> FieldType.
        MethodType getterAsFunction = MethodType.methodType(field.getType(), entityClass);
        try {
            CallSite site = LambdaMetafactory.altMetafactory(lookup,
                    // Implement Function#apply (SFunction's single abstract method) with its erased
                    // signature so the returned SFunction is actually callable.
                    "apply",
                    MethodType.methodType(SFunction.class),
                    MethodType.methodType(Object.class, Object.class),
                    lookup.findVirtual(entityClass, getFunName, MethodType.methodType(field.getType())),
                    getterAsFunction,
                    FLAG_SERIALIZABLE);
            SFunction func = (SFunction) site.getTarget().invokeExact();
            functionMap.put(cacheKey, func);
            return func;
        } catch (Throwable e) {
            throw ExceptionUtils.mpe("This class %s is not have method %s ", e, entityClass.getName(), getFunName);
        }
    }

    /**
     * Finds a declared field on {@code clazz} or any of its superclasses (excluding Object).
     *
     * @param clazz     class to start searching from
     * @param fieldName field name
     * @return the field, or null if not found (or if {@code fieldName} is null)
     */
    public static Field getDeclaredField(Class<?> clazz, String fieldName) {
        if (fieldName == null) {
            return null;
        }
        for (Class<?> current = clazz; current != null && current != Object.class; current = current.getSuperclass()) {
            try {
                return current.getDeclaredField(fieldName);
            } catch (NoSuchFieldException | SecurityException ignored) {
                // Not declared (or not accessible) here: deliberately keep walking up to the superclass.
            }
        }
        return null;
    }

    /**
     * One recorded query action (condition, ordering, select or join).
     *
     * @author wanglei
     * @since 2022-03-18
     */
    @Data
    @AllArgsConstructor
    public static class Action {
        /**
         * Entity (PO) class the action applies to.
         */
        private Class<?> modelClass;
        /**
         * Property name.
         */
        private String property;

        /**
         * Operator, e.g. eq / like / select / a join type.
         */
        private String action;

        /**
         * Single comparison value.
         */
        private Object value;
        /**
         * Lower bound for range operators.
         */
        private Object min;
        /**
         * Upper bound for range operators.
         */
        private Object max;
        /**
         * Property names, used by select actions.
         */
        private String[] properties;

    }
}

 

posted @ 2024-10-27 21:42  findlisa  阅读(5)  评论(0)  编辑  收藏  举报