// 多线程处理百万数据 (multithreaded processing of a million rows)
package org.example; import com.alibaba.druid.pool.DruidDataSource; import java.sql.*; import java.util.ArrayList; import java.util.List; import java.util.concurrent.*; public class test4 { public static void main(String[] args) throws ClassNotFoundException, SQLException, InterruptedException { DruidDataSource dataSource = new DruidDataSource(); dataSource.setUrl("jdbc:mysql://localhost:3306/mall"); dataSource.setUsername("root"); dataSource.setPassword("123456"); long start = System.currentTimeMillis(); // 创建连接 Connection conn = dataSource.getConnection(); // 定义每批处理的记录数 int batchSize = 100000; // 获取总记录数 int totalRows = getTotalRows(conn); // 计算批数 int batches = totalRows / batchSize; // 创建线程池 // 创建线程池 ThreadPoolExecutor executor = new ThreadPoolExecutor( 10, // 核心线程数 15, // 最大线程数 60, TimeUnit.SECONDS, // 空闲线程存活时间 new ArrayBlockingQueue<>(50), // 阻塞队列 new ThreadPoolExecutor.CallerRunsPolicy() // 拒绝策略 ); CountDownLatch latch = new CountDownLatch(batches); // 分批处理 for (int i = 0; i < batches; i++) { int offset = i * batchSize; System.out.println("第"+offset+"条"); List<Student> students = null; try { students = queryBatch(offset, batchSize, conn); } catch (SQLException e) { throw new RuntimeException(e); } // 查询一批记录 List<Student> finalStudents = students; executor.submit(() -> { // 处理name finalStudents.forEach(s -> s.setName(s.getName() + "*#")); // 更新这批记录 try { updateBatch(finalStudents,conn); } catch (SQLException e) { throw new RuntimeException(e); } latch.countDown(); }); } latch.await(); long end = System.currentTimeMillis(); long timeTaken = end - start; // 将毫秒转换为秒 double timeInSeconds = timeTaken / 1000.0; System.out.println("Total time taken: " + timeInSeconds/60 + " mins"); } static List<Student> queryBatch(int offset, int size, Connection conn) throws SQLException { // System.out.println("第"+offset+"批"); List<Student> students = new ArrayList<>(); String sql = "SELECT id, name, age, gender, major, gpa FROM students LIMIT ?, ?"; PreparedStatement stmt = 
conn.prepareStatement(sql); stmt.setInt(1, offset); stmt.setInt(2, size); ResultSet rs = stmt.executeQuery(); while (rs.next()) { Student s = new Student(); s.setId(rs.getInt("id")); s.setName(rs.getString("name")); s.setAge(rs.getInt("age")); s.setGender(rs.getString("gender")); s.setMajor(rs.getString("major")); s.setGpa(rs.getDouble("gpa")); students.add(s); } rs.close(); stmt.close(); return students; } // 更新一批记录 static void updateBatch(List<Student> students, Connection conn) throws SQLException { String sql = "UPDATE students SET name = ? WHERE id = ?"; PreparedStatement stmt = conn.prepareStatement(sql); for (Student s : students) { stmt.setString(1, s.getName()); stmt.setInt(2, s.getId()); stmt.addBatch(); } stmt.executeBatch(); stmt.close(); } // 获取总记录数 static int getTotalRows(Connection conn) throws SQLException { String sql = "SELECT COUNT(*) FROM students"; PreparedStatement stmt = conn.prepareStatement(sql); ResultSet rs = stmt.executeQuery(); int totalRows = 0; if(rs.next()) { totalRows = rs.getInt(1); } rs.close(); stmt.close(); return totalRows; } }