ThreadLocal 遇上线程池的问题及解决办法
ThreadLocal 称为线程本地存储,一般作为静态域使用,它为每一个使用它的线程提供一个其值(value)的副本。通常对数据库连接(Connection)和事务(Transaction)使用线程本地存储。
可以简单地将 ThreadLocal<T> 理解成一个容器,它将 value 对象存储在 Map<Thread, T> 域中,即使用当前线程为 key 的一个 Map,ThreadLocal 的 get() 方法从 Map 里取与当前线程相关联的 value 对象。ThreadLocal 的真正实现并不是这样的,但是可以简单地这样理解。
线程池中的线程在任务执行完成后会被复用,所以在线程执行完成时,要对 ThreadLocal 进行清理(清除掉与本线程相关联的 value 对象)。不然,被复用的线程去执行新的任务时会使用被上一个线程操作过的 value 对象,从而产生不符合预期的结果。
下面举一个简单的例子来说明:
1 import java.util.concurrent.ExecutorService; 2 import java.util.concurrent.Executors; 3 4 public class ThreadLocalVariableHolder { 5 private static ThreadLocal<Integer> variableHolder = new ThreadLocal<Integer>() { 6 @Override 7 protected Integer initialValue() { 8 return 0; 9 } 10 }; 11 12 public static int getValue() { 13 return variableHolder.get(); 14 } 15 16 public static void remove() { 17 variableHolder.remove(); 18 } 19 20 public static void increment() { 21 variableHolder.set(variableHolder.get() + 1); 22 } 23 24 public static void main(String[] args) { 25 ExecutorService executor = Executors.newCachedThreadPool(); 26 for (int i = 0; i < 5; i++) { 27 executor.execute(() -> { 28 int before = getValue(); 29 increment(); 30 int after = getValue(); 31 System.out.println("before: " + before + ", after: " + after); 32 }); 33 } 34 35 executor.shutdown(); 36 } 37 }
执行结果如下(如果你的执行结果 before 都是 0,after 都是 1 的话,就增加线程池执行的线程个数):
before: 0, after: 1 before: 0, after: 1 before: 0, after: 1 before: 1, after: 2 before: 1, after: 2
既然是为每个线程都提供一个副本,为什么会出现 before 不为 0 的情况呢?
下面追踪每一个执行的线程,将 main 方法修改为如下:
1 public static void main(String[] args) { 2 ExecutorService executor = Executors.newCachedThreadPool(); 3 for (int i = 0; i < 5; i++) { 4 executor.execute(() -> { 5 long threadId = Thread.currentThread().getId(); 6 int before = getValue(); 7 increment(); 8 int after = getValue(); 9 System.out.println("threadId: " + threadId + ", before: " + before + ", after: " + after); 10 }); 11 } 12 13 executor.shutdown(); 14 }
执行结果如下:
threadId: 10, before: 0, after: 1 threadId: 11, before: 0, after: 1 threadId: 12, before: 0, after: 1 threadId: 12, before: 1, after: 2 threadId: 11, before: 1, after: 2
由上面的执行结果可以看出,id 为 11 和 12 的线程被复用。线程池在复用线程执行任务时使用被之前的线程操作过的 value 对象。因此,在每个线程执行完成时,应该清理 ThreadLocal。具体做法如下:
1 public static void main(String[] args) { 2 ExecutorService executor = Executors.newCachedThreadPool(); 3 for (int i = 0; i < 100; i++) { 4 executor.execute(() -> { 5 try { 6 long threadId = Thread.currentThread().getId(); 7 int before = getValue(); 8 increment(); 9 int after = getValue(); 10 System.out.println("threadId: " + threadId + ", before: " + before + ", after: " + after); 11 } finally { 12 // 清理线程本地存储 13 remove(); 14 } 15 }); 16 } 17 18 executor.shutdown(); 19 }