ThreadLocal

前言

本文基于 JDK1.8,主要介绍 ThreadLocal 的用法和核心原理。

正文

什么是 ThreadLocal ?

ThreadLocal 从字面意思来理解即是 线程本地变量副本,属于每个线程独有的,不同线程使用同一个 ThreadLocal 对象设置值是互相隔离的,即 A 线程向 ThreadLocal 设置值,B 线程用同一个 ThreadLocal 对象是无法获取到的,只有 A 线程自己可以获取的。

ThreadLocal 用法示例

下面有一段 ThreadLocal 代码示例,首先启动2个线程各自使用了 ThreadLocal 设置了一个字符串,首先线程a设置,然后线程b等待线程a执行完尝试获取,主线程等待线程a和线程b都执行完,也尝试获取。

public class ThreadLocalTest {

    private static final ThreadLocal<String> threadLocal = new ThreadLocal<>();

    public static void main(String[] args) throws InterruptedException {
        // 线程a向 ThreadLocal 设置值,并获取
        Thread a = new Thread(() -> {
            printThreadLocalValue();
            threadLocal.set("A");
            printThreadLocalValue();
        });

        // 线程b等待线程a执行完,再获取 ThreadLocal 中的值
        Thread b = new Thread(() -> {
            // 线程睡眠5秒,等待a线程执行完
            SleepUtils.second(5);
            printThreadLocalValue();
            threadLocal.set("B");
            printThreadLocalValue();
        });
        a.start();
        a.join();
        b.start();
        b.join();

        // 主线程等待线程a和线程b执行完,获取 ThreadLocal 中的值
        printThreadLocalValue();
        threadLocal.set("main");
        printThreadLocalValue();
    }

    private static void printThreadLocalValue() {
        System.out.println(Thread.currentThread().getName() + ": " + threadLocal.get());
    }

}

结果如下所示:

Thread-0: null
Thread-0: A
Thread-1: null
Thread-1: B
main: null
main: main

从上面的代码以及运行结果看出 ThreadLocal 的线程隔离特性,每个线程无法获取到其它线程设置的值,即使是同一个 ThreadLocal 对象。

ThreadLocal#get 方法源码剖析

首先我们看一下 ThreadLocal 中的 get() 方法,代码如下:

public T get() {
  // 获取当前线程
  Thread t = Thread.currentThread();
  // 获取 Thread 的 threadLocals 属性
  ThreadLocalMap map = getMap(t);
  // map 不为空
  if (map != null) {
    // 获取 map 中的节点
    ThreadLocalMap.Entry e = map.getEntry(this);
    if (e != null) {
      @SuppressWarnings("unchecked")
      T result = (T)e.value;
      // 节点不为空直接返回 value
      return result;
    }
  }
  // map 为空 || e 为空,调用初始化方法返回默认值
  return setInitialValue();
}

ThreadLocalMap getMap(Thread t) {
  return t.threadLocals;
}

上面 get() 方法主要流程如下:

  1. 获取当前线程的 ThreadLocalMap
  2. 如果 ThreadLocalMap 为空就调用 setInitialValue() 方法设置默认值并创建 map;不为空就获取 entrykey 对应的 value 值。

下面我们简单看一下 ThreadLocalMapThread 中的定义:

public class Thread implements Runnable {
  
  // 当前线程的 ThreadLocalMap
  ThreadLocal.ThreadLocalMap threadLocals = null;

  // 从父线程继承而来的 ThreadLocalMap,主要用于父子线程 ThreadLocal 变量的传递
  ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
  
  private void init(ThreadGroup g, Runnable target, String name, long stackSize, AccessControlContext acc, boolean inheritThreadLocals) {
    // 省略其它代码...
    
    // 如果 inheritThreadLocals 为true && 父线程 inheritableThreadLocals 不为空
    // 将父线程的 inheritableThreadLocals 值复制到当前的 inheritableThreadLocals 中
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
      this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

    // 省略其它代码...
  }

  // 省略其它代码...
  
}

其中 inheritableThreadLocals 属性会在线程创建时的 init() 方法中判断,如果父线程有值复制到子线程一份;这样就实现了父子进程 ThreadLocal 变量的传递。

这里简单看一下 ThreadLocalMap 的定义:

static class ThreadLocalMap {

  static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
      super(k);
      value = v;
    }
  }
  
  /**
   * The initial capacity -- MUST be a power of two.
   */
  private static final int INITIAL_CAPACITY = 16;

  /**
   * The table, resized as necessary.
   * table.length MUST always be a power of two.
   */
  private Entry[] table;

  /**
   * The number of entries in the table.
   */
  private int size = 0;

  /**
   * The next size value at which to resize.
   */
  private int threshold; // Default to 0
 	
  // 省略其它代码...
}

上面静态内部类 Entry 中的 value 就是 ThreadLocal#set() 方法设置的值,而 k 是当前的 ThreadLocal 对象,作为 WeakReference (弱引用)的 referent 属性。

在 Java 中,对象引用可以为 强引用软引用弱引用虚引用 四种,是 JVM 回收内存判断的重要标准之一。

  • 强引用 StrongReference,一般声明的一个对象,都是强引用。使用场景,比如 String s = new String(); s 就是一个强引用。gc 如果发现一个对象被强引用指向,如果 JVM 空间不足的时候,就算 OOM 也不会回收它。
  • 软引用 SoftReference,当 JVM 空间不够的时候,gc 会先回收软引用的空间。使用场景:适合用于缓存。
  • 弱引用 WeakReference,只要 gc 发现了弱引用,就会回收掉它的空间。使用场景:ThreadLocalMap WeakHashMap 中的 Entry
  • 虚引用 PhantomReference,这个引用在 gc 垃圾回收线程看来,就是没有引用的意思,它的作用是帮助 JVM 管理直接内存 DirectBuffer。经典的使用场景:NIO。

ThreadLocalMap 数据结构如下所示:(图片来自于https://www.jianshu.com/p/2a540903d696

接下来我们接着看 ThreadLocal#setInitialValue() 方法,代码如下:

private T setInitialValue() {
  // 首先调用初始化 value 方法,默认返回 null
  T value = initialValue();
  // 获取当前线程
  Thread t = Thread.currentThread();
  // 获取 Thread 的 threadLocals 属性
  ThreadLocalMap map = getMap(t);
  if (map != null)
    // 不为空,直接设置 value,key 为当前的 ThreadLocal 对象
    map.set(this, value);
  else
    // map 为空,创建并设置 value,最后赋值给 Thread 的 threadLocals 属性
    createMap(t, value);
  return value;
}

protected T initialValue() {
  return null;
}

上面代码中的 initialValue() 方法我们可以在创建 ThreadLocal 对象时重写,来返回一个我们自定义的默认兜底值,如下所示:

private static final ThreadLocal<String> threadLocal = new ThreadLocal() {
  @Override
  protected Object initialValue() {
    return "";
  }
};

// Java8 写法
private static final ThreadLocal<String> threadLocal2 = ThreadLocal.withInitial(() -> "");

下面我们看一下 createMap() 方法:

void createMap(Thread t, T firstValue) {
  t.threadLocals = new ThreadLocalMap(this, firstValue);
}

ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
  // 创建 Entry 数组,默认容量16
  table = new Entry[INITIAL_CAPACITY];
  // 寻址
  int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
  // 创建 Entry 并给 table[i] 赋值
  table[i] = new Entry(firstKey, firstValue);
  size = 1;
  // 设置扩容阈值,计算方式为 len * 2 / 3,这里也就是 16 * 2 / 3 =  10
  setThreshold(INITIAL_CAPACITY);
}

ThreadLocalMap 如果出现 hash 冲突的话使用的是 开放寻址法,即当前位置有元素了就换个位置;HashMap 出现 hash 冲突使用的是 链表法

ThreadLocal#set 方法源码剖析

public void set(T value) {
  // 获取当前线程
  Thread t = Thread.currentThread();
  // 获取 Thread 的 threadLocals 属性
  ThreadLocalMap map = getMap(t);
  if (map != null)
    // 如果 map 不为空,直接调用 set() 方法设置值
    map.set(this, value);
  else
		// 上文分析过,这里不再赘述
    createMap(t, value);
}

在分析过 ThreadLocal#get() 方法后,看 ThreadLocal#set() 方法就变得很简单了;这里还是总结一下主要流程:

  1. 获取当前线程的 ThreadLocalMap
  2. map 不为空,直接调用 set() 方法设置值,key 为当前的 ThreadLocal 对象本身, value 就是传进来的方法参数;map 为空代表第一次设置值,进行 ThreadLocalMap 的初始化,上面说过初始化的默认大小为16,通过构造函数设置 value 并将 map 赋值给线程的 threadLocals 属性。

父子进程传递 ThreadLocal 变量

想要在父子进程之间传递 ThreadLocal 变量的话需要使用 InheritableThreadLocal,下面我们看一个简单的示例:

public class InheritableThreadLocalTest {

    private static final ThreadLocal<String> threadLocal = new InheritableThreadLocal<>();

    public static void main(String[] args) throws InterruptedException {
        threadLocal.set("main");
        printThreadLocalValue();

        Thread a = new Thread(() -> {
            printThreadLocalValue();
            threadLocal.set("a");
            printThreadLocalValue();
        });
        a.start();
        a.join();

        printThreadLocalValue();
    }

    private static void printThreadLocalValue() {
        System.out.println(Thread.currentThread().getName() + ": " + threadLocal.get());
    }

}

打印结果如下:

main: main
Thread-0: main
Thread-0: a
main: main

可以看出 a 线程获取到了 main 线程中 ThreadLocal 中的值,但是 main 线程无法获取到 a 线程设置后的 ThreadLocal 中的值。 上面介绍过 Threadinit() 方法会拷贝父线程的 inheritableThreadLocals 属性到子线程中,下面我们看看 InheritableThreadLocal 是如何实现的。

public class InheritableThreadLocal<T> extends ThreadLocal<T> {

  // 省略其它代码...
  
  ThreadLocalMap getMap(Thread t) {
    return t.inheritableThreadLocals;
  }

  void createMap(Thread t, T firstValue) {
    t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
  }
}

从源码可以看出 InheritableThreadLocal 实现的非常简单,继承了 ThreadLocal,重写其 getMap()createMap() 方法,将操作 ThreadthreadLocals 属性改为 inheritableThreadLocals 即可。后续的 get()set()remove() 等方法都是在对 inheritableThreadLocals 属性来操作的。

使用不当导致的内存泄露

上文中提到 ThreadLocalMap 内部是一个 Entry 数组,Entry 继承自 WearReference,我们简单看一下 Entry 的构造函数,如下:

static class Entry extends WeakReference<ThreadLocal<?>> {
  /** The value associated with this ThreadLocal. */
  Object value;

  Entry(ThreadLocal<?> k, Object v) {
    super(k);
    value = v;
  }
}

public WeakReference(T referent) {
	super(referent);
}

Reference(T referent) {
  this(referent, null);
}

Reference(T referent, ReferenceQueue<? super T> queue) {
  this.referent = referent;
  this.queue = (queue == null) ? ReferenceQueue.NULL : queue;
}

由上面代码可以看出 k 被传递到 WeakReference 的构造函数里面,也就是说 ThreadLocalMap 里面的 keyThreadLocal 对象的弱引用,具体是 referent 变量引用了 ThreadLocal 对象,value 为具体调用 ThreadLocalset() 方法传递的值。

当一个线程调用 ThreadLocalset() 方法设置变量的时候,当前线程的 ThreadLocalMap 里面就会存放一个记录,这个记录的 keyThreadLocal 的引用, value 则为设置的值。如果当前线程一直存在而没有调用 ThreadLocalremove() 方法,并且这时候其它地方还是有对 ThreadLocal 的引用,则当前线程的 ThreadLocalMap 变量里面会存在 ThreadLocal 变量的引用和 value 对象的引用是不会被释放的,因为还是被引用了;但是当 ThreadLocal 没有强依赖了,由于是弱引用所以会在 gc 时被回收,但是对应的 value 还是存在的,就会形成 ThreadLocalMap 里面有 keynullentry。此时就会导致内存泄露。

ThreadLocalset()get() 以及 remove() 方法里面有一些时机会被这些 keynullentry 进行清理,下面看一下 remove() 方法清理的过程:

private void remove(ThreadLocal<?> key) {
  Entry[] tab = table;
  int len = tab.length;
  // 计算 key 所在的 tab 数组的位置
  int i = key.threadLocalHashCode & (len-1);
  // 由于 ThreadLocalMap 才用开放寻址法,所以在 key 不相等时,索引加1继续判断
  for (Entry e = tab[i];
       e != null;
       e = tab[i = nextIndex(i, len)]) {
    if (e.get() == key) {
      // key 相等找到对应的 Entry 了
      e.clear();
      expungeStaleEntry(i);
      return;
    }
  }
}

// 获取指定下标的下一个下标,如果大于等于 table 的长度就从 0 开始
private static int nextIndex(int i, int len) {
  return ((i + 1 < len) ? i + 1 : 0);
}

public void clear() {
  // 将 WeakReference 中的 referent 设置为 null
  this.referent = null;
}

private int expungeStaleEntry(int staleSlot) {
  Entry[] tab = table;
  int len = tab.length;

  // 将指定下标的元素数据设置为 null
  tab[staleSlot].value = null;
  tab[staleSlot] = null;
  size--;

  // Rehash until we encounter null
  Entry e;
  int i;
  // 从指定位置的下一个开始遍历
  for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
    ThreadLocal<?> k = e.get();
    // 如果 key 等于 null,代表已被回收,直接将 value 设置为空,避免内存泄露
    if (k == null) {
      e.value = null;
      tab[i] = null;
      size--;
    } else {
      // 调整元素位置
      int h = k.threadLocalHashCode & (len - 1);
      if (h != i) {
        tab[i] = null;

        // Unlike Knuth 6.4 Algorithm R, we must scan until
        // null because multiple entries could have been stale.
        while (tab[h] != null)
          h = nextIndex(h, len);
        tab[h] = e;
      }
    }
  }
  return i;
}

上面代码便是 ThreadLocal#remove() 方法的全过程,并不复杂,总结起来流程如下:

  1. 首先计算指定 key 所在 table 的位置,判断 key 是否相等,如果不相等由于解决 hash 冲突使用的 开放寻址法,所以增加索引下标线性的去遍历判断直到找到 key 相等的。
  2. 如果找到 key 相等的,将 keyvalue 和指定下标的元素都设置为 null
  3. 然后对 Entry 数组进行一次整理和回收,重新计算当前 key 的位置,如果和当前下标不等说明有元素被清除了,进行 rehash 从重新计算的下标位置 h 作为起始位置直到找到一个没有元素的位置,将元素放入进去。

下面我们用一个例子看一下,手动调用 remove() 方法和不调用有什么区别:

public class ThreadLocalMemoryGiveWayTest {

    public static void main(String[] args) {
        test();
        System.gc();
    }

    private static void test() {
        Content content = new Content();
        ThreadLocal<Content> threadLocal = new ThreadLocal<>();
        threadLocal.set(content);
        threadLocal = null;
    }

    @Data
    static class Content {
        private byte[] data = new byte[5 * 1024 * 1024];
    }

}

在运行前先设置一下 JVM 参数,如下:

-verbose:gc -Xms20M -Xmx20M -Xmn10M -XX:SurvivorRatio=8 -XX:+PrintGCDetails

参数意思不做过多解释了,不知道的同学自行百度。

上面的测试代码大概意思就是,创建一个 ThreadLocal 里面放一个大概有5MB大小的对象,然后将 threadLocal 设置为 null,即没有强引用了,然后再手动触发 gc,看是否会回收这5MB的空间。

首先测试的是没有手动调用 remove() 方法,gc 日志如下:

4BEEE28F-F44F-41C5-BD2F-861AFEB73146.png

可以看到 ThreadLocal 中的 value 被分配在老年代中,gc 后并没有被回收。接下来我们将上面代码中的 test() 方法改成如下所示:

private static void test() {
  Content content = new Content();
  ThreadLocal<Content> threadLocal = new ThreadLocal<>();
  threadLocal.set(content);
  threadLocal.remove();
}

再次运行,gc 日志如下:

33ED8359-6631-45B7-9DCB-86E496863315.png

可以看到对象已经被回收了,所以在实际使用中,当线程还在运行时,并且 ThreadLocal 已经使用完,最好手动调用 remove() 方法,来防止内存泄露,特别是在使用线程池时。

最佳实践

  • 在类中定义 ThreadLocal,用 private static final 修饰。

  • 根据源码分析,由于 ThreadLocal 对象本身会作为 Entrykey 去获取数据,所以最好也要用 final 去修饰。key 通常都是不可变对象,否则当作为 key 的对象发生了变化之后,之前存储的数据将无法 get,造成资源泄露。

  • 在使用完 ThreadLocal 后,及时手动 remove() 无用的 ThreadLocal,防止资源泄露。

  • 在使用线程池时,由于线程池中的线程会被复用,所以也要记得用完后 remove(),防止线程下一次使用时还存储着上一次的信息。

参考

posted @ 2020-09-03 16:16  leisurexi  阅读(203)  评论(0编辑  收藏  举报