JDK源码之ThreadLocal 类分析
一 概述
ThreadLocal类提供了线程局部 (thread-local) 变量。这些变量与普通变量不同,每个线程都可以通过其 get 或 set方法来访问自己的独立初始化的变量副本
ThreadLocal 实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联,类里面定义了一个map,key为ThreadLocal,value为值,存储每一个线程的变量
Thread类里面引用了这个内部类的map实例,从而达到线程隔离
二 源码分析
属性
// 获取下一个hashCode值
private final int threadLocalHashCode = nextHashCode();
// 获取下一个hashCode,ThreadLocal 中使用了斐波那契散列法,来保证哈希表的离散度
private static int nextHashCode() {
// 每一次获取值时,加上HASH_INCREMENT为下一次获取的值
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
// 开始为0,每次创建ThreadLocal实例,值都会累加
private static AtomicInteger nextHashCode = new AtomicInteger();
// 加数值: 0x61c88647 = 2^32 * 黄金分割比(0.618),
// 斐波那契数列: 当n趋向于无穷大时,前一项与后一项的比值越来越逼近黄金比,即0.618
private static final int HASH_INCREMENT = 0x61c88647;
//提供给子类初始化值使用
protected T initialValue() {
return null;
}
核心方法
// 通过Supplier函数初始化变量值,使用子类构造器进行初始化
public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
return new ThreadLocal.SuppliedThreadLocal<>(supplier);
}
// ThreadLocal扩展子类,接收一个Supplier函数进行值初始化
static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
private final Supplier<? extends T> supplier;
SuppliedThreadLocal(Supplier<? extends T> supplier) {
this.supplier = Objects.requireNonNull(supplier);
}
// 覆写父类方法
@Override
protected T initialValue() {
return supplier.get();
}
}
public ThreadLocal() {}
// 返回当前线程的变量副本值
public T get() {
Thread t = Thread.currentThread();
// 获取线程实例引用的ThreadLocalMap
ThreadLocal.ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocal.ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
// 值为空,重新初始化
return setInitialValue();
}
ThreadLocal.ThreadLocalMap getMap(Thread t) { return t.threadLocals; }
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocal.ThreadLocalMap map = getMap(t);
// 赋值
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocal.ThreadLocalMap(this, firstValue);
}
// 设置ThreadLocal的值,还是设置的thread实例引用的map
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocal.ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
//清除值,key是ThreadLocal,所以直接使用this
public void remove() {
ThreadLocal.ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
// 调用map的remove方法
m.remove(this);
}
// Thread 类里面创建线程时候初始化init方法调用的,主要是复制参数中的table构建一个新map返回
static ThreadLocal.ThreadLocalMap createInheritedMap(ThreadLocal.ThreadLocalMap parentMap) {
return new ThreadLocal.ThreadLocalMap(parentMap);
}
// 只允许子类调用此方法,ThreadLocal调用直接抛异常
T childValue(T parentValue) { throw new UnsupportedOperationException(); }
三 静态内部类ThreadLocalMap
ThreadLocal中的静态内部类ThreadLocalMap,这个类本质上是一个map,和HashMap之类的实现相似,依然是key-value的形式,其中有一个内部类Entry,其中key可以看做是ThreadLocal实例,
在ThreadLocal中并没有对于ThreadLocalMap的引用,ThreadLocalMap的引用在Thread类中
每个线程在向ThreadLocal里塞值的时候,其实都是向自己所持有的ThreadLocalMap里塞入数据;
读的时候同理,首先从自己线程中取出自己持有的ThreadLocalMap,然后再根据ThreadLocal引用作为key取出value,
基于以上描述,ThreadLocal实现了变量的线程隔离
Entry
/**
* Entry继承WeakReference,(弱引用:当一个对象仅仅被weak reference指向, 而没有任何其他strong reference指向的时候, 如果GC运行, 那么这个对象就会被回收)
* 并且用ThreadLocal作为key.如果key为null
* (entry.get() == null)表示key不再被引用,表示ThreadLocal对象被回收
* 因此这时候entry也可以从table从清除。
*/
static class Entry extends WeakReference<ThreadLocal<?>> {
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
// 初始容量
private static final int INITIAL_CAPACITY = 16;
// 存放数据的数组
private Entry[] table;
// 数组里面entrys的个数,可以用于判断table当前使用量是否超过负因子
private int size = 0;
// 进行扩容的阈值,表使用量大于它的时候进行扩容
private int threshold; // Default to 0
// 设置阈值为参数的三分之二
private void setThreshold(int len) { threshold = len * 2 / 3;}
核心方法
/**
* ThreadLocalMap使用线性探测法来解决哈希冲突,线性探测法的地址增量di = 1, 2, ... , m-1,其中,i为探测次数。
* 该方法一次探测下一个地址,直到有空的地址后插入,若整个空间都找不到空余的地址,则产生溢出。
* 假设当前table长度为16,也就是说如果计算出来key的hash值为14,如果table[14]上已经有值,并且其key与当前key不一致,那么就发生了hash冲突,
* 这个时候将14加1得到15,取table[15]进行判断,这个时候如果还是冲突会继续下一个,如果是最后的则重新回到0,取table[0],以此类推,直到可以插入。
* 可以把table看成一个环形数组
*/
// 获取环形数组的下一个索引
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
// 获取环形数组的上一个索引
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
// 构造器初始化
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
// 初始化table
table = new Entry[INITIAL_CAPACITY];
// 计算索引, & (INITIAL_CAPACITY - 1),这是取模的一种方式,对于2的幂作为模数取模,用此代替%(2^n),目的是均匀分布在数组中
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
// 设置值
table[i] = new Entry(firstKey, firstValue);
size = 1;
// 设置阈值
setThreshold(INITIAL_CAPACITY);
}
/**
* 构建一个包含所有parentMap中Inheritable ThreadLocals的ThreadLocalMap返回
* 该函数只被 createInheritedMap() 调用.即只在Thread类的init方法里面调用(init方法调用了createInheritedMap()),目的是将父线程的变量值复制到当前线程中
*/
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];
// 逐一复制 parentMap 的记录到当前线程中
for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
//此处获取的都是InheritableThreadLocal类,ThreadLocal的子类,用于父线程变量传递的类
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
// 如果是ThreadLocal,就会抛异常
Object value = key.childValue(e.value);
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}
private Entry getEntry(ThreadLocal<?> key) {
// 计算数组下标
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
// e不为空,但是key不相等时候再找
return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
// 线性探测法,循环查找
while (e != null) {
ThreadLocal<?> k = e.get();
if (k == key)
// 找到就返回
return e;
if (k == null)
// 如果key为null,则清除掉无效的entry
expungeStaleEntry(i);
else
// 环形向后扫描
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
private void set(ThreadLocal<?> key, Object value) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
/**
* 根据获取到的索引进行循环,如果当前索引上的table[i]不为空,在没有return的情况下,
* 就使用nextIndex()获取下一个地址,即线性探测法。
*/
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
// key相同则更新value,set成功
e.value = value;
return;
}
/**
* table[i]上的key为空,说明被回收了
* 这个时候说明改table[i]可以重新使用,用新的key-value将其替换,并删除其他无效的entry
*/
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
/**
* Remove the entry for key.
*/
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
//key-value替换 staleSlot位置的值
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
/**
* 根据传入的无效entry的位置(staleSlot),向前扫描,当前已经是无效的,可能前面也有无效的,找到最开始无效的一个进行替换,维护线性探测法
* 一段连续的entry(这里的连续是指一段相邻的entry并且table[i] != null),
* 直到找到一个无效entry,或者扫描完也没找到
*/
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;
/**
* 向后扫描一段连续的entry
*/
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
/**
* 如果找到了key,直接替换,也就是与table[staleSlot]进行替换,此时staleSlot位置已经替换了对应key的值
*/
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
//如果向前查找没有找到无效entry,则更新slotToExpunge为当前值i
if (slotToExpunge == staleSlot)
slotToExpunge = i;
// 此时,staleSlot位置已经设置值了,应该从 slotToExpunge 位置开始往后清除
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
/**
* 如果向前查找没有找到无效entry,并且当前向后扫描的entry无效,则更新slotToExpunge为当前值i
*/
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
/**
* 如果没有找到key,也就是说key之前不存在table中
* 就直接最开始的无效entry——tab[staleSlot]上直接新增即可
*/
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
/**
* slotToExpunge != staleSlot,说明存在其他的无效entry需要进行清理。
*/
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
/**
* 连续段清除
* 根据传入的staleSlot,清理对应的无效entry——table[staleSlot],
* 并且根据当前传入的staleSlot,向后扫描一段连续的entry(这里的连续是指一段相邻的entry并且table[i] != null),
* 对可能存在hash冲突的entry进行rehash,并且清理遇到的无效entry.
* @param staleSlot key为null,需要无效entry所在的table中的索引
* @return 返回下一个为空的solt的索引。
*/
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
Entry e;
int i;
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// key为null,直接清理
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
/**
* 计算出来的索引h,与其现在所在位置的索引——i不一致,置空当前的table[i]
* 从h开始向后线性探测到第一个空的slot,把当前的entry挪过去。
*/
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
/**
* 启发式的扫描清除,扫描次数由传入的参数n决定
* @param i 从i向后开始扫描(不包括i,因为索引为i的Slot肯定为null)
* @param n 控制扫描次数,正常情况下为 log2(n) ,
* 如果找到了无效entry,会将n重置为table的长度len,进行段清除。
* map.set()点用的时候传入的是元素个数,replaceStaleEntry()调用的时候传入的是table的长度len
*/
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
//无符号的右移动,可以用于控制扫描次数在log2(n)
} while ( (n >>>= 1) != 0);
return removed;
}
private void rehash() {
expungeStaleEntries();
/**
* threshold = 2/3 * len
* 所以threshold - threshold / 4 = 1en/2
* 这里主要是因为上面做了一次全清理所以size减小,需要进行判断。
* 判断的时候把阈值调低了。
*/
if (size >= threshold - threshold / 4)
resize();
}
/**
* 扩容,扩大为原来的2倍(这样保证了长度为2的冥)
*/
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
setThreshold(newLen);
size = count;
table = newTab;
}
// 清空全部Entry
private void expungeStaleEntries() {
Entry[] tab = table;
int len = tab.length;
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}