无需手动计算分批下标的批量插入工具类：MySQL 批量插入、跳过主键重复、简化批量插入

使用场景:

  需要批量导入大量 Excel 文件，逐条插入数据较慢，因此改用批量插入；插入过程中跳过主键重复导致的报错。

 

批量插入数据工具类

package com.zlintent.controller;


import cn.hutool.core.util.ReflectUtil;
import cn.hutool.extra.spring.SpringUtil;

import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;
import java.util.function.BiFunction;
import java.util.function.Function;

/**
 * java 8 function 函数 https://www.runoob.com/java/java8-functional-interfaces.html
 *
 * @author liran
 */
/**
 * Splits a list into small chunks and hands each chunk to an insert callback, summing the row
 * counts the callback returns.
 *
 * <p>The callback can be a lambda ({@link Function} / {@link BiFunction}) or a method on a Spring
 * bean resolved by reflection (via hutool's {@code SpringUtil} and {@code ReflectUtil}).
 * The {@code batchInsertTask} variants run the chunks in parallel on a Fork/Join pool.
 *
 * <p>Chunk boundaries are driven by the input list, never by the returned counts. A callback
 * that reports fewer rows than it received (for example MySQL {@code INSERT IGNORE} skipping
 * duplicates) therefore still sees every row exactly once.
 *
 * <p>Java 8 functional interfaces: https://www.runoob.com/java/java8-functional-interfaces.html
 *
 * @author liran
 */
public class BatchInsertUtil<T> {

    /** Number of rows handed to the insert callback per call. */
    private static final int DEFAULT_BATCH_SIZE = 5;

    /** Worker-thread count of the ForkJoinPool created by each {@code batchInsertTask} call. */
    private static final int PARALLELISM = 8;

    /**
     * Sequentially inserts {@code tables} in chunks through a two-argument callback.
     *
     * @param action callback receiving (chunk, other) and returning the number of rows it inserted
     * @param tables rows to insert; an empty list inserts nothing and never calls {@code action}
     * @param other  extra argument forwarded unchanged to every {@code action} call
     * @return the sum of all values returned by {@code action}
     */
    public int batchInsert(BiFunction<List<T>, Integer, Integer> action, List<T> tables, Integer other) {
        return insertInBatches(tables, DEFAULT_BATCH_SIZE, chunk -> action.apply(chunk, other));
    }

    /**
     * Sequentially inserts {@code tables} in chunks through a one-argument callback.
     *
     * @param action callback receiving a chunk and returning the number of rows it inserted
     * @param tables rows to insert; an empty list inserts nothing and never calls {@code action}
     * @return the sum of all values returned by {@code action}
     */
    public static int batchInsert(Function<List, Integer> action, List tables) {
        // Raw List -> List<?> is an unchecked-free widening; lets the generic helper capture the type.
        List<?> items = tables;
        return insertInBatches(items, DEFAULT_BATCH_SIZE, chunk -> action.apply(chunk));
    }

    /**
     * Sequentially inserts {@code tables} in chunks by invoking {@code methodName} on the Spring bean
     * of type {@code beanClass}.
     *
     * @param tables     rows to insert; each chunk is passed as the single method argument
     * @param beanClass  type of the Spring bean to look up with {@code SpringUtil.getBean}
     * @param methodName name of the bean method to invoke; its return value must print as an integer
     * @return the sum of the integers returned by the bean method
     */
    public static int batchInsert(List<?> tables, Class<?> beanClass, String methodName) {
        Object bean = SpringUtil.getBean(beanClass);
        return insertInBatches(tables, DEFAULT_BATCH_SIZE,
                chunk -> toRowCount(ReflectUtil.invoke(bean, methodName, chunk)));
    }

    /**
     * Inserts {@code tables} in parallel chunks on a dedicated Fork/Join pool, using a one-argument
     * callback.
     *
     * <p>NOTE(review): the callback runs on pool worker threads, so thread-bound state of the caller
     * (e.g. a Spring-managed transaction) is presumably not visible to it. Confirm before relying
     * on transactional rollback.
     *
     * @param tables rows to insert; an empty list returns 0 without calling {@code action}
     * @param action callback receiving a chunk and returning the number of rows it inserted
     * @return the sum of all values returned by {@code action}
     */
    public static int batchInsertTask(List<?> tables, Function<List, Integer> action) {
        if (tables.isEmpty()) {
            return 0;
        }
        return runInPool(new SaveFunTask(tables, 0, tables.size(), action));
    }

    /**
     * Fork/Join task that halves its range until it is at most {@link #INSERT_ROW} rows, then
     * passes that sub-list to a {@link Function} callback.
     */
    public static class SaveFunTask extends RecursiveTask<Integer> {
        static final int INSERT_ROW = 5;
        List<?> array;
        int start;
        int end;
        Function<List, Integer> action;

        SaveFunTask(List<?> array, int start, int end, Function<List, Integer> action) {
            this.array = array;
            this.start = start;
            this.end = end;
            this.action = action;
        }

        @Override
        protected Integer compute() {
            if (start >= end) {
                // Never hand the callback an empty chunk (an empty multi-row INSERT is invalid SQL).
                return 0;
            }
            if (end - start <= INSERT_ROW) {
                return action.apply(array.subList(start, end));
            }
            // Unsigned shift keeps the midpoint correct even if start + end overflows int.
            int middle = (start + end) >>> 1;
            System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
            SaveFunTask task1 = new SaveFunTask(this.array, start, middle, action);
            SaveFunTask task2 = new SaveFunTask(this.array, middle, end, action);
            invokeAll(task1, task2);
            int subresult1 = task1.join();
            int subresult2 = task2.join();
            int result = subresult1 + subresult2;
            System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
            return result;
        }
    }


    /**
     * Inserts {@code tables} in parallel chunks on a dedicated Fork/Join pool by invoking
     * {@code methodName} on the Spring bean of type {@code beanClass}.
     *
     * @param tables     rows to insert; an empty list returns 0 without invoking the bean method
     * @param beanClass  type of the Spring bean to look up with {@code SpringUtil.getBean}
     * @param methodName name of the bean method to invoke; its return value must print as an integer
     * @return the sum of the integers returned by the bean method
     */
    public static int batchInsertTask(List<?> tables, Class<?> beanClass, String methodName) {
        Object bean = SpringUtil.getBean(beanClass);
        if (tables.isEmpty()) {
            return 0;
        }
        return runInPool(new SaveTask(tables, 0, tables.size(), bean, methodName));
    }

    /**
     * Fork/Join task that halves its range until it is at most {@link #INSERT_ROW} rows, then
     * invokes a bean method reflectively with that sub-list.
     */
    public static class SaveTask extends RecursiveTask<Integer> {
        static final int INSERT_ROW = 5;
        List<?> array;
        int start;
        int end;
        Object bean;
        String method;

        SaveTask(List<?> array, int start, int end, Object bean, String method) {
            this.array = array;
            this.start = start;
            this.end = end;
            this.bean = bean;
            this.method = method;
        }

        @Override
        protected Integer compute() {
            if (start >= end) {
                // Never hand the bean method an empty chunk (an empty multi-row INSERT is invalid SQL).
                return 0;
            }
            if (end - start <= INSERT_ROW) {
                return toRowCount(ReflectUtil.invoke(bean, method, array.subList(start, end)));
            }
            // Unsigned shift keeps the midpoint correct even if start + end overflows int.
            int middle = (start + end) >>> 1;
            System.out.println(String.format("split %d~%d ==> %d~%d, %d~%d", start, end, start, middle, middle, end));
            SaveTask task1 = new SaveTask(this.array, start, middle, bean, method);
            SaveTask task2 = new SaveTask(this.array, middle, end, bean, method);
            invokeAll(task1, task2);
            int subresult1 = task1.join();
            int subresult2 = task2.join();
            int result = subresult1 + subresult2;
            System.out.println("result = " + subresult1 + " + " + subresult2 + " ==> " + result);
            return result;
        }
    }

    /**
     * Feeds {@code items} to {@code inserter} in consecutive, non-empty chunks of at most
     * {@code batchSize} rows and sums the returned counts.
     *
     * <p>Progress is tracked by index, not by the returned counts. A callback that skips rows
     * (returns less) or double-counts (returns more) therefore can neither overrun the list nor
     * cut it short.
     */
    private static <E> int insertInBatches(List<E> items, int batchSize, Function<List<E>, Integer> inserter) {
        int size = items.size();
        int total = 0;
        int start = 0;
        while (start < size) {
            // Math.min on the remaining length avoids overflowing start + batchSize.
            int end = start + Math.min(batchSize, size - start);
            total += inserter.apply(items.subList(start, end));
            start = end;
        }
        return total;
    }

    /** Converts a reflectively returned row count (e.g. Integer or Long) to an int. */
    private static int toRowCount(Object returned) {
        return Integer.parseInt(String.valueOf(returned));
    }

    /** Runs {@code task} on a fresh pool and always shuts the pool down so its threads are released. */
    private static int runInPool(ForkJoinTask<Integer> task) {
        ForkJoinPool pool = new ForkJoinPool(PARALLELISM);
        try {
            return pool.invoke(task);
        } finally {
            pool.shutdown();
        }
    }


}

上面的工具类用到了 hutool 的 SpringUtil，需要把它注册为 Spring Bean（注入到容器中）：

import cn.hutool.extra.spring.SpringUtil;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

/**
 * Spring Boot entry point.
 *
 * <p>Also registers hutool's {@link SpringUtil} as a bean. The reflection-based batch-insert
 * helpers call {@code SpringUtil.getBean(...)}, which needs the application context.
 */
@SpringBootApplication
public class Application {

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

    /**
     * Registers hutool's SpringUtil as a Spring bean.
     *
     * <p>Declared {@code static} so Spring can create the bean without first instantiating this
     * configuration class.
     *
     * <p>NOTE(review): SpringUtil implements {@code BeanFactoryPostProcessor} in recent hutool
     * versions. Spring warns that non-static {@code @Bean} methods returning that type prevent
     * {@code @Autowired}/{@code @PostConstruct} processing in the declaring class.
     * Confirm for the hutool version in use.
     *
     * @return the SpringUtil instance to register
     */
    @Bean
    public static SpringUtil springUtil() {
        return new SpringUtil();
    }
}

 

Service 测试类

import org.springframework.stereotype.Service;

import java.util.List;

/**
 * Stub insert service for the BatchInsertUtil demo.
 *
 * <p>Instead of writing to a database it prints the calling thread id and the chunk size. It then
 * reports every row of the chunk as inserted.
 */
@Service
public class InsertService {

    /**
     * Two-argument callback used with the BiFunction-based helper.
     *
     * @param list  the chunk of rows to "insert"
     * @param other extra argument forwarded by the caller; only printed here
     * @return {@code list.size()}, i.e. every row is reported as inserted
     */
    public Integer applyIn(List<Integer> list, Object other) {
        System.out.println("线程id"+Thread.currentThread().getId());
        System.out.println("多个参数"+ other);
        System.out.println("插入数据"+list.size());
        return list.size();
    }

    /**
     * One-argument callback used with the Function-based and reflection-based helpers.
     *
     * @param list the chunk of rows to "insert"; any element type is accepted
     * @return {@code list.size()}, i.e. every row is reported as inserted
     */
    public Integer applyIn(List<?> list) {
        System.out.println("线程id"+Thread.currentThread().getId());
        System.out.println("插入数据"+list.size());
        return list.size();
    }
}

 

测试方法

       // Build test data: the integers 0..10 (11 rows).
        List<Integer> tables = new ArrayList<>();
        for (int i = 0; i < 11; i++) {
            tables.add(i);
        }
        // Single-threaded insert; type-checked callback (BiFunction) that also receives an extra argument
        System.out.println("------------------------基于 function----------------------------------");
        int test1 = new BatchInsertUtil<Integer>().batchInsert((list, d) -> insertService.applyIn(list, d), tables, 1);

        System.out.println("************************基于 function**********************************");
        // Single-threaded insert; raw-typed callback (no compile-time type check), one argument
        int test2 = BatchInsertUtil.batchInsert((list) -> insertService.applyIn(list), tables);

        System.out.println("-------------------------基于 function---------------------------------");
        // Multi-threaded (Fork/Join) insert; raw-typed callback, one argument
        int test3 = BatchInsertUtil.batchInsertTask(tables,(list) -> insertService.applyIn(list));


        System.out.println("*************************基于 反射*********************************");
        // Single-threaded insert; raw types; the callback is InsertService.applyIn looked up by reflection
        int test4 = BatchInsertUtil.batchInsert(tables,InsertService.class,"applyIn");


        // Every variant should report all 11 rows as inserted (the stub returns the chunk size).
        Assert.isTrue(test1==11 && test2==11 && test3==11 && test4==11 );

 

输出日志

------------------------基于 function----------------------------------
线程id28
多个参数1
插入数据5
线程id28
多个参数1
插入数据5
线程id28
多个参数1
插入数据1
************************基于 function**********************************
线程id28
插入数据5
线程id28
插入数据5
线程id28
插入数据1
-------------------------基于 function---------------------------------
split 0~11 ==> 0~5, 5~11
线程id47
插入数据5
split 5~11 ==> 5~8, 8~11
线程id48
插入数据3
线程id47
插入数据3
result = 3 + 3 ==> 6
result = 5 + 6 ==> 11
*************************基于 反射*********************************
线程id28
插入数据5
线程id28
插入数据5
线程id28
插入数据1

 

说明：SaveTask（以及 SaveFunTask）使用了 Fork/Join 框架。线程池的并行度（即 “ForkJoinPool fjp = new ForkJoinPool(8)” 中的 8）一般要根据 CPU 核数来确定；由于每个并发插入的线程都会占用一个数据库连接，并行度也不宜超过数据库连接池的大小。

 

实际插入数据时，可以把对应的 mapper 方法传给工具类，例如 BatchInsertUtil.batchInsert(list -> mapper.insertList(list), tables)。

 

插入时如果要跳过主键重复或唯一索引冲突导致的报错，可以使用 insert ignore。注意：被忽略的行不会计入返回的影响行数。

  <!--
    Inserts one chunk of rows with a single multi-row INSERT into reference_table.
    INSERT IGNORE: rows whose PRIMARY KEY / UNIQUE value already exists are skipped instead of
    failing the whole statement.
    NOTE(review): per the MySQL manual, IGNORE also downgrades some other errors (e.g. data
    conversion) to warnings. Skipped rows are not counted in the returned affected-row count,
    so callers must not assume the result equals the chunk size.
    NOTE(review): collection="subList" assumes the mapper method parameter is exposed under the
    name "subList" (e.g. via @Param("subList")). Confirm against the mapper interface.
    An empty collection would render an invalid "VALUES" clause, so never call this with an empty list.
  -->
  <insert id="insertList">
        insert ignore into reference_table (
        id, tid, task_year_month, table_type
        )
        VALUES
        <foreach collection="subList" item="item" separator=",">
            (#{item.id,jdbcType=VARCHAR},
            #{item.tid,jdbcType=VARCHAR},
            #{item.taskYearMonth,jdbcType=TIMESTAMP},
            #{item.tableType,jdbcType=TINYINT})
        </foreach>
    </insert>

  



posted @ 2020-12-29 10:59  _Phoenix  阅读(1227)  评论(0编辑  收藏  举报