ThreadLocal与ThreadLocalMap源码分析
ThreadLocal类
该类主要用于不同线程存储自己的线程本地变量。本文先通过一个示例简单介绍该类的使用方法,然后从ThreadLocal类的初始化、存储结构、增删数据和hash值计算等几个方面,分析对应源码。采用的版本为jdk1.8。
ThreadLocal-使用方法
ThreadLocal对象可以在多个线程中被使用,通过set()方法设置线程本地变量,通过get()方法获取设置的线程本地变量。我们先通过一个示例简单了解下使用方法:
public static void main(String[] args){
ThreadLocal<String> threadLocal = new ThreadLocal<>();
// 线程1
new Thread(()->{
// 查看是否有初始值
System.out.println("线程1的初始值:"+threadLocal.get());
// 设置线程1的值
threadLocal.set("V1");
// 输出
System.out.println("线程1的值:"+threadLocal.get());
// 等待一段时间,等线程2设置值后再查看线程1的值
try {
Thread.sleep(1000);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println("线程1的值:"+threadLocal.get());
}).start();
// 线程2
new Thread(()->{
// 等待线程1设置初始值
try {
Thread.sleep(500);
} catch (InterruptedException e) {
e.printStackTrace();
}
// 查看线程2的初始值
System.out.println("线程2的值:"+threadLocal.get());
// 设置线程2的值
threadLocal.set("V2");
// 查看线程2的值
System.out.println("线程2的值:"+threadLocal.get());
}).start();
}
由于threadlocal设置的值是在每个线程中都有一个副本的,线程之间不会互相影响。代码运行的结果如下所示:
线程1的初始值:null
线程1的值:V1
线程2的值:null
线程2的值:V2
线程1的值:V1
ThreadLocal-初始化
ThreadLocal类只有一个无参的构造方法,如下所示:
/**
* Creates a thread local variable.
* @see #withInitial(java.util.function.Supplier)
*/
public ThreadLocal() {
}
但其实还有一个带参数的构造方法,不过是它的子类。ThreadLocal中定义了一个内部类SuppliedThreadLocal,为继承自ThreadLocal类的子类。可以通过该类进行给定初始值的初始化,其定义如下:
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
private final Supplier<? extends T> supplier;
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}
@Override
protected T initialValue() {
return supplier.get();
}
}
通过TheadLocal threadLocal = Thread.withInitial(supplier);这样的语句可以进行给定初始值的初始化。在某个线程第一次调用get()方法时,会执行initialValue()方法设置线程变量为传入supplier中的值。
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new SuppliedThreadLocal<>(supplier);
}
ThreadLocal-存储结构
在jdk1.8版本中,使用的是TheadLocalMap这一个容器存储线程本地变量。
该容器的设计思想和HashMap有很多共同之处。比如:内部定义了Entry节点存储键值对(使用ThreadLocal对象作为键);使用一个数组存储entry节点;设定一个阈值,超过阈值时进行扩容;通过键的hash值与数组长度进行&操作确定下标索引等。但也有很多不同之处,具体我们在后续介绍ThreadLocalMap类时再详细分析。
static class ThreadLocalMap {
// Entry节点定义
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
// 存储元素的数组
private Entry[] table;
// 容器内元素数量
private int size = 0;
// 阈值
private int threshold; // Default to 0
// 修改和添加元素
private void set(ThreadLocal<?> key, Object value){
...
}
// 移除元素
private void remove(ThreadLocal<?> key) {
...
}
...
}
ThreadLocal-增删数据
ThreadLocal类提供了get(),set()和remove()方法来操作当前线程的threadlocal变量副本。底层则是基于ThreadLocalMap容器来实现数据操作。
不过要注意的是:ThreadLocal中并没有ThreadLocalMap的成员变量,ThreadLocalMap对象是Thread类中的一个成员,所以需要通过通过当前线程的Thread对象去获取该容器。
每一个线程Thread对象都会有一个map容器,该容器会随着线程的终结而回收。
设置线程本地变量的方法。
public void set(T value) {
// 获取当前线程对应的Thread对象,其是map键值对中的健
Thread t = Thread.currentThread();
// 获取当前线程对象的容器map
ThreadLocalMap map = getMap(t);
// 如果容器不为null,则直接设置元素。否则用线程对象t和value去初始化容器对象
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
// 通过当前线程的线程对象获取容器
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
// 创建map容器,本质是初始化Thread对象的成员变量threadLocals
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
获取线程本地变量的方法。
public T get() {
// 获取当前线程对象
Thread t = Thread.currentThread();
// 获取当前线程对象的容器map
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
// 如果容器不为null且容器内有当前threadlocal对象对应的值,则返回该值
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
// 如果容器为null或者容器内没有当前threadlocal对象绑定的值,则先设置初始值并返回该初始值
return setInitialValue();
}
// 设置初始值。主要分为两步:1.加载和获取初始值;2.在容器中设置该初始值。
// 第二步其实和set(value)方法实现一模一样。
private T setInitialValue() {
// 加载并获取初始值,默认是null。如果是带参初始化的子类SuppliedThreadLocal,会有一个输入初始值。
// 当然也可以继承ThreadLocal类重写该方法设置初始值
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
// 如果容器不为null,则直接设置元素。否则用线程对象t和value去初始化容器对象
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
移除线程本地变量的方法
public void remove() {
// 如果容器不为null就调用容器的移除方法,移除和该threadlocal绑定的变量
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}
ThreadLocal-hash值计算
ThreadLocal的hash值用于ThreadLocalMap容器计算数组下标。类中定义threadLocalHashCode表示其hash值。类中定义了静态方法和静态原子变量计算hash值,也就是说所有的threadLocal对象共用一个增长器。
// 当前ThreadLocal对象的hash值
private final int threadLocalHashCode = nextHashCode();
// 用来计算hash值的原子变量,所有的threadlocal对象共用一个增长器
private static AtomicInteger nextHashCode = new AtomicInteger();
// 魔法数字,使hash散列均匀
private static final int HASH_INCREMENT = 0x61c88647;
// 计算hash值的静态方法
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
我们使用同样的方法定义一个测试类,定义多个不同测试类对象,看看hash值的生成情况。如下所示,可以看到hash值都不同,是共用的一个增长器。
public class Test{
private static final int HASH_INCREMENT = 0x61c88647;
public static AtomicInteger nextHashCode = new AtomicInteger();
public final int nextHashCode = nextHashCode();
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
public static void main(String[] args){
for (int i = 0; i < 5; i++) {
Test test = new Test();
System.out.println(test.nextHashCode);
}
}
// 输出的hash值
0
1640531527
-1013904242
626627285
-2027808484
}
ThreadLocalMap类
ThreadLocalMap类是ThreadLocal的内部类。其作为一个容器,为ThreadLocal提供操作线程本地变量的功能。每一个Thread对象中都会有一个ThreadLocalMap对象实例(成员变量threadLocals,初始值为null)。因为map是Thread对象的非公共成员,不会被并发调用,所以不用考虑并发风险。
后文将从数据存储设计、初始化、增删数据等方面分析对应源码。
ThreadLocalMap-数据存储设计
该map和hashmap类似,使用一个Entry数组来存储节点元素,定义size变量表示当前容器中元素的数量,定义threshold变量用于计算扩容的阈值。
// Entry数组
private Entry[] table;
// 容器内元素个数
private int size = 0;
// 扩容计算用阈值
private int threshold;
不同的是Entry节点为WeakReference类的子类,使用引用字段作为键,将弱引用字段(通常是ThreadLocal对象)和值绑定在一起。使用弱引用是为了使得threadLocal对象可以被回收,(如果将key作为entry的一个成员变量,那线程销毁前,threadLocal对象不会被回收掉,即使该threadLocal对象不再使用)。
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
ThreadLocalMap-初始化
提供了带初始键和初始值的map构造方法,还有一个基于已有map的构造方法(用于ThreadLocal的子类InheritableThreadLocal初始化map容器,目的是将父线程的map传入子线程,会在创建子线程的过程中自动执行)。如下所示:
// 基于初始键值的构造函数
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
// 基于输入键值构建节点
table = new Entry[INITIAL_CAPACITY];
// 根据键的hash值计算所在数组下标
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
// 采用懒加载的方式,只创建一个必要的节点
table[i] = new Entry(firstKey, firstValue);
size = 1;
// 设置阈值为初始长度的2/3,初始长度默认为12,那么阈值为为8
setThreshold(INITIAL_CAPACITY);
}
// 基于已有map的构造函数
private ThreadLocalMap(ThreadLocalMap parentMap) {
// 获取传入map的节点数组
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
// 构造相同长度的数组
table = new Entry[len];
// 深拷贝传入数组中各个节点到当前容器数组
// 注意这里因为采用开放地址解决hash冲突,拷贝后的元素在数组中的位置与原数组不一定相同
for (int j = 0; j < len; j++) {
// 获取数组各个位置上的节点
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
//确保key为InheritableThreadLocal类型,否则抛出UnsupportedOperationException
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
// 根据hash值和数组长度,计算下标
int h = key.threadLocalHashCode & (len - 1);
// 这里采用开放地址的方法解决hash冲突
// 当发生冲突时,就顺延到数组下一位,直到该位置没有元素
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}
ThreadLocalMap-移除元素
这里将移除元素的方法放在前面,是因为其它部分会频繁使用过时节点的移除方法。先理解这部分内容有助于后续理解其他部分。
根据key移除容器元素的方法:
private void remove(ThreadLocal<?> key) {
// 计算索引下标
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
// 从下标i处开始向后寻找是否有key对应节点,直到遇到Null节点
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
// 如果遇到key对应节点,执行移除操作
if (e.get() == key) {
// 移除节点的键(弱引用)
e.clear();
// 移除该过时节点
expungeStaleEntry(i);
return;
}
}
}
移除过时节点的执行方法:
移除过时节点除了将该节点置为null之外,还要对该节点之后的节点进行移动,看看能不能往前找合适的空格转移。
这种方法有点类似jvm垃圾回收算法的标记-整理方法。都是将垃圾清除之后,将剩余元素进行整理,变得更紧凑。这里的整理是需要强制执行的,目的是为了保证开放地址法一定能在连续的非null节点块中找到已有节点。(试想,如果把过时节点移除而不整理,该节点为null,将前后节点分开了。而如果后面有某个节点hash计算的下标在前面的节点块,在查找节点时通过开放地址会找不到该节点)。示意图如下:
private int expungeStaleEntry(int staleSlot) {
// 获取entyy数组和长度
Entry[] tab = table;
int len = tab.length;
// 清除staltSlot节点的值的引用,清除节点的引用
tab[staleSlot].value = null;
tab[staleSlot] = null;
// 容器元素个数-1
size--;
// 清除staleSlot节点后的整理工作
// 将staleSlot索引后的节点计算下标往前插空移动
Entry e;
int i;
// 遍历连续的非null节点,直到遇到null节点
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// case1:如果遍历到的节点是过时节点,将该节点清除,容器元素数量-1
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
// case2:如果遍历到的节点不是过时节点,重新计算下标
int h = k.threadLocalHashCode & (len - 1);
// 当下标不是当前位置时,从hash值计算的下标h处,开放地址往后顺延插空
if (h != i) {
// 先将该节点置为null
tab[i] = null;
// 找到为null的节点,插入节点
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
移除所有过时节点的方法:很简单,全局遍历,移除所有过时节点。
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
// 遇到过时节点,清除整理该节点所在的连续块
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
尝试去扫描一些过时节点并清除节点,如果有节点被清除会返回true。这里只执行了logn次扫描判断,是为了在不扫描和全局扫描之间找到一种平衡,是上面的方法的一个平衡。
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
// 遇到过时节点,清除整理该节点所在连续块
if (e != null && e.get() == null) {
n = len;
removed = true;
// 从该连续块后第一个null节点开始
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
ThreadLocalMap-获取元素
获取容器元素的方法:
// 根据key快速查找entry节点
private Entry getEntry(ThreadLocal<?> key) {
// 通过threadLocal对象(key)的hash值计算数组下标
int i = key.threadLocalHashCode & (table.length - 1);
// 取对应下标元素
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
// 查找不到有两种情况:
// 1.对应下标桶位为空
// 2对应下标桶位元素不是key关联的entry(开放地址解决hash冲突导致的)
return getEntryAfterMiss(key, i, e);
}
// 初次查找失败后再次查找entry节点
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
// 获取entry数组及长度
Entry[] tab = table;
int len = tab.length;
// 如果e为null,说明对应下标桶位为空,找不到key对应的entry
// 如果e不为null,则用解决hash冲突时的方法(顺延数组下一位)一直找下去,直到找到或e为null
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
return e;
// 在寻找的过程中如果节点的key,即ThreadLocal已经被回收(被弱引用的对象可能会被回收)
// 则移除过时的节点,移除过时节点的方法分析见移除元素部分
if (k == null)
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
// 没有找到,返回null
return null;
}
ThreadLocalMap-增加和修改元素
增加和修改容器元素的方法:
这里在根据hash值计算出下标后,由于是开放地址解决hash冲突,会顺序向后遍历直到遇到null或遇到key对应的节点。
这里会出现三种情况:
case1:遍历时找到了key对应节点,这时直接修改节点的值即可;
case2:遍历中遇到了有过时的节点(key被回收的节点);
case3:遍历没有遇到过时的节点,也没有找到key对应节点,说明此时应该插入新节点(用输入键值构造新节点)。因为是增加新元素,所以可以容量会超过阈值。在删除节点后容量如果超过阈值,则要进行扩容操作。
private void set(ThreadLocal<?> key, Object value) {
// 获取数组,计算key对应的数组下标
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
// 从下标i开始,顺序遍历数组(顺着hash冲突开放地址的路径),直到节点为null
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
// 获取遍历到的节点的key
ThreadLocal<?> k = e.get();
// case1:命中key,说明已存在key对应节点,修改value值即可
if (k == key) {
e.value = value;
return;
}
// case2:如果遍历到的节点的key为null,说明该threadLocal对象已经被回收
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
// case3:遍历节点直到null都没有找到对应key,说明map中没有key对应entry
// 则在该位置用输入键和值新建一个entry节点
tab[i] = new Entry(key, value);
int sz = ++size;
// 判断是否清理过时节点后,在判断是否需要扩容
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
case2:增加和修改过程中遇到已经过时的节点的处理。这里的参数staleSlot表示key计算的下标开始往后遇到的第一个过时节点,不管map中有无key对应的节点,该位置之后一定会存入key的节点。这里定义了一个变量slotToExpunge,其含义是左右连续非null的entry块中第一个过时节点(记录该位置是为了后续清除过时节点可以从slotToExpunge处开始)。示意如下:
这步操作有两种情况:
casse2.1:从过时节点staleSlot往后查找遇到key对应节点,则将staleSlot处节点与key对应节点交换。然后清除整理连续块。
casse2.2:没遇到key对应节点,说明map中不存在key对应节点,则新建一个节点填入staleSlot处。然后清除整理连续块。
private void replaceStaleEntry(ThreadLocal<?> key, Object value,int staleSlot) {
// 获取entry数组和长度
Entry[] tab = table;
int len = tab.length;
Entry e;
// 往前移动寻找第一个过时节点(直到遇到null),如果没找到的话说明第一个过时节点为staleslot处节点
// slotToExpunge表示连续块中第一个过时节点
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len); (e = tab[i]) != null; i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;
// 从输入下标staleSlot向后找到第一个出现的key对应的节点或过时的节点(key被回收的节点)
for (int i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// case2.1:如果找到key对应的节点,则用staleSlot处节点和该节点交换,以保持hash表的顺序(hash冲突时顺序向后寻找)
// 交换后的staleSlot节点及其之前的过时节点会被清除
if (k == key) {
// 交换staleSlot处节点和key对应节点
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// 更新slotToExpunge的值,使其保持连续块中第一个过时节点的特性,方便后续清理过时节点。
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 从slotToExpunge开始清除整理连续块
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// 如果遇到过时节点,更新slotToExpunge的值
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// case2.2:没有找到key对应节点,增加新节点并填入staleSlot处
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// 这里如果slotToExpunge=staleSlot,说明连续块中只有一个过时节点,且已经被新建节点填入,就不需要再整理。
// 如果除了原staleSlot处,还有其它过时节点,从slotToExpunge开始清除整理连续块
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
case3:增加元素后可能超过阈值导致的扩容处理
private void rehash() {
// 清除所有过时节点
expungeStaleEntries();
// 在清除所有过时节点后,如果数量超过3/4的阈值,则进行扩容处理
// setThreshold()方法非公有,threshold值一直为数组长度的2/3,所以这里是超过数组长度一半就进行扩容
if (size >= threshold - threshold / 4)
resize();
}
/**
* 双倍扩容
*/
private void resize() {
// 获取旧数组和长度
Entry[] oldTab = table;
int oldLen = oldTab.length;
// 新数组长度为原来的两倍
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
// 遍历原数组元素
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
// 如果为非null节点
if (e != null) {
ThreadLocal<?> k = e.get();
// 如果是过时节点,则将value置为null,可以使得value的实体尽快被回收
if (k == null) {
e.value = null; // Help the GC
} else {
// 如果是正常节点,计算下标,重新填入新数组(开放地址解决hash冲突)
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
// 新数组元素个数+1
count++;
}
}
}
// 重新设置阈值
setThreshold(newLen);
size = count;
// 将变量table指向新数组
table = newTab;
}
ThreadLocalMap-内存泄露问题以及对设计的一些思考
先来聊一聊内存泄漏这个概念。我的理解是有一块内存空间,如果不再被使用但又不能被垃圾回收器回收掉,那么就相当于这块内存少了这块空间,即出现了内存泄露问题。如果内存泄露的空间一直在积累,那么最终会导致可用空间一直减少,最终可能导致程序无法运行。
ThreadLocalMap中也是有可能会出现该问题的,map中entry节点的key为弱引用,如果key没有其它强引用,是会被垃圾收集器回收的。回收之后,map中该节点的value就不会再被使用,但value又被entry节点强引用,不会被回收。这就相当于value这块内存空间发生了泄露。所以能看到在源码中很多方法都进行了清除过时节点的操作,为的就是尽量避免内存泄漏。
在看源码时,一直在思考为什么entry节点的键要采用弱引用的方式。不妨反过来思考,如果entry节点将threadLocal对象作为一个成员变量,而不是采用弱引用的方式,那么entry节点一直对key和value保持着强引用关系,即使threadlocal对象在其它地方都不再使用,该对象也不会被回收。这就会导致entry节点永远不会被回收(只要线程不终结),而且也不能主动去判断是否切断map中threadlocal对象的引用(不知道是否还有其它地方引用到了)。
因为map是Thread对象的一个成员变量,线程不终结,map是不会被回收的,如果发生了内存泄露的问题,可能会一直积累下去,最终导致程序发生异常。而key采用弱引用加之主动的判断过时节点(判断是否过时很简单,看key是否为null即可)并进行清除处理可以最大限度的减少内存泄露的发生。