ThreadLocal 源码分析
2021-07-04 15:20 wang03 阅读(423) 评论(1) 编辑 收藏 举报1、ThreadLocal 源码分析
-
在多线程开发中,我们经常会使用ThreadLocal来避免共享变量的竞争,提高效率。ThreadLocal底层到底是怎么实现的呢,今天就带大家一起来看看它底层实现。另外也会随便分析下网上讨论比较多的关于ThreadLocal内存泄漏等等究竟是怎么一回事
我本地的jdk版本是11.0.8,不同版本的jdk,threadLocal源码实现可能有差别,不过大致是一样的。
-
首先看下我们一般都是怎么使用ThreadLocal的
这是一段使用threadLocal的demo代码
public class ThreadLocalDemo {
public final static ThreadLocal<String> threadLocal = new ThreadLocal<>() {
@Override
protected String initialValue() {
return "initValue";
}
};
public static void main(String[] args) throws InterruptedException {
new Thread(() -> {
System.out.println("init Value =" + threadLocal.get());
threadLocal.set("abc");
System.out.println("执行其他逻辑");
String str = threadLocal.get();
System.out.println(str);
threadLocal.remove();
}).start();
Thread.currentThread().join();
}
}
这段代码比较简单,不用具体说threadLocal的这些方法都是干啥的了。我们直接顺着main方法里面的调用顺序一起去看看这些调用背后都是怎么实现的。
-
如果直接把类的源码粘上来,做分析,感觉太零散了,看不到方法之间的调用关系,所以,这次就准备按照调用逻辑,一步一步来分析了
-
首先我们main方法是在第11行调用了threadLocal.get(),这是我们第一次主动调用threadLocal的地方,那我们先从这里进去
//这就是ThreadLocal的get方法了,我们泛型参数是String,所以这里的T也就是String了,下面我们一行一行根据我们上面的Demo来分析下这个代码
public T get() {
//这一行就不用说了,Thread.currentThread()就是获取当前线程
Thread t = Thread.currentThread();
//getMap(t)这行是干什么呢?我们这个方法的代码比较简单,我就直接粘贴到下面了
//ThreadLocalMap getMap(Thread t) {
// return t.threadLocals;
//}
//上面这3行就是getMap(t)调用的代码了,比较简单,获取线程t上的threadLocals属性,这个属性是什么东西呢?我们再看看这个属性
//ThreadLocal.ThreadLocalMap threadLocals = null;
//上面这一行就是在Thread类中定义的threadLocals属性了,看样子是ThreadLocal类中的内部类ThreadLocalMap的一个实例,具体是啥,我们先不仔细看了,继续看后面代码吧
ThreadLocalMap map = getMap(t);
//这里返回的map当前是null,因为Thread的threadLocals,从来没有初始化过
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
//因为map==null,所以会调用到这个setInitialValue这个方法,从这里返回,我们继续看看这个方法吧
return setInitialValue();
}
private T setInitialValue() {
//这里的initialValue方法是protected的,默认返回null,由于我们上面Demo第4行重写了initialValue方法,所以这里的调用就是我们上面的代码,这里的返回应该是上面我们Demo第5行的"initValue"
T value = initialValue();
//这几行代码和上面get方法是一致的
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
//map==null
if (map != null) {
map.set(this, value);
} else {
//会走到这里来,我们去这个createMap方法看看
createMap(t, value);
}
//这里的this是通过匿名继承的ThreadLocal的,不会走到这个instanceof内部去
if (this instanceof TerminatingThreadLocal) {
TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
}
//我们上面的Demo第4行的threadLocal.get()调用最终就会从这里返回,返回值是"initValue"
return value;
}
//这里就是初始化Thread的threadLocals属性了,后面这个属性就不是null了,这里创建了个ThreadLocalMap对象,有两个参数,第一个this就是我们Demo中调用threadLocal.get()方法的对象,也就是Demo第4行的threadLocal对象,firstValue就是上面的字符串"initValue"。
//我们进到ThreadLocalMap构造方法去看看
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
//
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
//这里首先会创建一个Entry的数组,INITIAL_CAPACITY = 16;也就是这里创建一个大小是16的Entry数组,Entry是啥,我们下面再看,先看看这个构造方法的其他几行代码
table = new Entry[INITIAL_CAPACITY];
//threadLocalHashCode是ThreadLocal的成员变量,
// private final int threadLocalHashCode = nextHashCode();
//上面是它的定义,从上面可以看到每次创建ThreadLocal对象时就会初始化threadLocalHashCode,它的值是通过静态方法nextHashCode()赋值的,每调用一次nextHashCode返回值就在上次的基础上增加0x61c88647。
//这里的i就是threadLocalHashCode的值和(INITIAL_CAPACITY - 1)进行与运算,这里的(INITIAL_CAPACITY - 1)值是15,计算结果在0-15之间,作为数组的下标
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
//这里创建个Entry对象,赋值给table中下标为i的
table[i] = new Entry(firstKey, firstValue);
//这里的size是table数组中元素的个数
size = 1;
//这里是设置threshold = len * 2 / 3;当size的值超过threshold时,table数组就会扩容成原来数组的2倍
setThreshold(INITIAL_CAPACITY);
}
//这里我们看看Entry对象
//这就是Entry的源码了,更简单。继承了WeakReference对象,这里会把构造方法的ThreadLocal入参包装成弱引用。具体啥是弱引用,下面粘贴上一段《深入理解java虚拟机》上面的描述。放到我们这里来说,就是ThreadLocal变量在没有其他地方引用(只有在Entry这里有引用),当下次垃圾回收的时候,就会被回收掉
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
//
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
下面是《深入理解java虚拟机》上面的关于引用的描述
到这里上面Demo第4行的源码就全部看完了,下面简单总结下执行逻辑。
- 首先进入ThreadLocal的get方法,获取当前线程的threadLocals变量,我们这个变量没有初始化过,所以这个变量为空,继续执行setInitialValue()方法,并从这里返回
- 在setInitialValue方法中首先调用我们Demo中重写ThreadLocal的initialValue方法,获取返回值。
- 继续调用createMap方法,
- 创建ThreadLocalMap对象,内部创建Entry数组,将threadLocal对象包装成弱引用及initialValue方法的返回值创建Entry对象,填充到Entry数组数组中
- 将创建的ThreadLocalMap对象赋值给当前线程的threadLocals变量
- 将2中获取的值返回
现在我们看下Demo中的第12行threadLocal.set("abc")
//这就是ThreadLocal的set方法了,和get方法差不多
public void set(T value) {
//获取当前线程
Thread t = Thread.currentThread();
//获取当前线程的threadLocals的属性,在上面get方法已经设置过这个属性了 ,所以这里不为空了
ThreadLocalMap map = getMap(t);
if (map != null) {
//map不为null,就会走到这里了,这里的this就是我们Demo中继承ThreadLocal的匿名内部类,value就是"abc",下面我们重点去看下这个方法
map.set(this, value);
} else {
createMap(t, value);
}
}
//这个方法是ThreadLocal的内部类ThreadLocalMap的。
private void set(ThreadLocal<?> key, Object value) {
// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.
//这个是ThreadLocalMap构造方法创建的table,是Entry数组
Entry[] tab = table;
int len = tab.length;
//这个是获取根据key计算在数组中的下标,考虑到有可能两个key计算出来的是同一个i,所以数组中下标i不一定就是我们需要的key,会从当前i的下标向后遍历。同样根据key获取值的时候,也会有类似情况
int i = key.threadLocalHashCode & (len-1);
//获取下标i的值Entry,进行遍历,直到entry.key==我们的入参key或者entry==null或者entry.key==null(这个场景就是key其他地方没有引用了,只有Entry有对应的弱引用,在下次垃圾回收后,entry.key就会==null)的情况下,结束循环
//注意:这里不会出现数组中所有下标都满了,且entry.k!=key的场景,因为下面的后面会判断size>=threshold,进行扩容
for (Entry e = tab[i];
e != null;
//考虑到可能会有冲突,也就是上面说的两个key计算出来的是同一个i,所以key有可能不存在对应下标i的位置
//nextIndex(i, len)向后获取下一个位置是循环的,如果达到i==len-1;这时i=0;就会从数组头元素开始,后面几个方法说的向前遍历,向后遍历都是类似的
e = tab[i = nextIndex(i, len)]) {
//获取Entry中的key,这里是个弱引用,通过get方法获取弱引用的实际对象。
ThreadLocal<?> k = e.get();
//找到了对应的key,就设置新值,返回
if (k == key) {
e.value = value;
return;
}
//k==null,这个场景说了,就是外部没有对ThreadLocal对象的其他引用了(强引用),GC释放后,k就是null,jiu就会走到这里
if (k == null) {
//这个方法会从当前下标i开始,删除过期的entry,重新设置新的Entry到i的位置
replaceStaleEntry(key, value, i);
return;
}
}
//走到这里,这时下标i对应的Entry已经是null了
tab[i] = new Entry(key, value);
//数组中元素个数+1
int sz = ++size;
//这里会移除一些过期的entry,判断sz>= threshold,如果成立,就扩容数组
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
//这个方法主要是
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
//这里的for循环都不会陷入死循环,因为上面有 sz >= threshold
//获取Entry数组
Entry[] tab = table;
int len = tab.length;
Entry e;
// Back up to check for prior stale entry in current run.
// We clean out whole runs at a time to avoid continual
// incremental rehashing due to garbage collector freeing
// up refs in bunches (i.e., whenever the collector runs).
int slotToExpunge = staleSlot;
//从当前入参staleSlot开始向前遍历,直到下标i对应的Entry==null
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
//随着i向前遍历,slotToExpunge的值会逐步更新,直到staleSlot和它之前第一个tab[i]==null之间,e.get()==null的下标,如果threadLocal对象外部一直有引用,那e.get()==null,就不会为null,也不会走到这个if分支
if (e.get() == null)
slotToExpunge = i;
// Find either the key or trailing null slot of run, whichever
// occurs first
//从当前入参staleSlot开始向后遍历,直到下标i对应的Entry==null
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
//获取到对应的Entry的threaLocal变量
ThreadLocal<?> k = e.get();
// If we find key, then we need to swap it
// with the stale entry to maintain hash table order.
// The newly stale slot, or any other stale slot
// encountered above it, can then be sent to expungeStaleEntry
// to remove or rehash all of the other entries in run.
//如果k==key,这时就是更新value就可以了
if (k == key) {
e.value = value;
//这里会进行元素交换,把找到的Entry换到入参staleSlot的位置
//注意当前staleSlot元素位置是过期的,需要清理的,交换后i下标位置的元素就是需要清理的
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// Start expunge at preceding stale entry if it exists
//这个条件成立的话,需要第一个for循环中if分支没有进入,也就是不存在staleSlot和它之前第一个tab[i]==null之间,e.get()==null
if (slotToExpunge == staleSlot)
//走到这里说明staleSlot之前到tab[i]==null之间没有无效元素,我们上面进行了元素交换,这时i就是第一个过期的元素了
slotToExpunge = i;
//清理无效的元素,slotToExpunge就是staleSlot之前到tab[i]==null到之后到tab[i]==null这段元素之间第一个过期元素的下标位置
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// If we didn't find stale entry on backward scan, the
// first stale entry seen while scanning for key is the
// first still present in the run.
//k==null说明当前数组已经有entry失效了,slotToExpunge == staleSlot说明staleSlot之前没有失效的,这时就要清理后面的过期元素
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// If key not found, put new entry in stale slot
//入参的staleSlot下标已经个个过期的值,将value设置为null,重新赋值个新的entry
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// If there are any other stale entries in run, expunge them
if (slotToExpunge != staleSlot)
//走到这里说明staleSlot向前遍历或者向后遍历中出现了k==null,这时需要清理过期的entry
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
//从staleSlot位置开始,直到数组中entry==null。清理在这过程中数组过期的元素,并调整那些元素有效,但是下标位置不正确的元素位置
private int expungeStaleEntry(int staleSlot) {
Entry[] tab = table;
int len = tab.length;
// expunge entry at staleSlot
//清空对应staleSlot对应的元素,size-1
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
// Rehash until we encounter null
Entry e;
int i;
//在staleSlot开始向后遍历,直到数组中entry==null
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
//如果k==null,就清空元素,size-1
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
//h是算法计算出来元素应该在数组中的下标位置
int h = k.threadLocalHashCode & (len - 1);
//i是元素的直接存储下标,由于碰撞的原因元组有可能不是存在算法计算出来的下标位置
if (h != i) {
//h!=i说明元素根据算法出来的下标和实际存储下标不一致,这时由于我们清空了一些过期元素,这时就需要重新调整这些有效的,算法计算出来的下标,和实际存储下标不一致的元素位置
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
//从h开始向后遍历,找到数组中的null位置,将当前下标i位置上的元素设置到对应位置
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
//这里的i就是for循环中退出条件,staleSlot开始向后遍历,数组中第一个entry==null的位置
return i;
}
//从i开始清理无效的元素
//注意下标i的位置不是无效元素,要不下标i位置==null,要不是有效元素
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
//获取i后面的位置
i = nextIndex(i, len);
Entry e = tab[i];
//如果i位置的元素是过期的,就执行清理
if (e != null && e.get() == null) {
n = len;
removed = true;
//这里上面说的清理过期元素,调整位置不正确的有效元素的位置
i = expungeStaleEntry(i);
}
//这里的n>>>=1主要是扫描次数,应该是出于效率的考虑
} while ( (n >>>= 1) != 0);
return removed;
}
//这就是扩容了
private void rehash() {
//这里首先清理数组中所有的过期元素,同时会调整不正确元素的下标
expungeStaleEntries();
// Use lower threshold for doubling to avoid hysteresis
//由于我们上面会清理过期元素,所以size有可能变小,有可能就不需要扩容了,所以这里重新判断是否需要扩容,如果需要就进行扩容
if (size >= threshold - threshold / 4)
resize();
}
//这里就是扩容了,数组长度变成原来的2倍,重新设置threshold和size
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
//长度设置成原来的2被
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
for (Entry e : oldTab) {
if (e != null) {
ThreadLocal<?> k = e.get();
//继续清理扩容过程中过期的元素
if (k == null) {
e.value = null; // Help the GC
} else {
//在新数组中设置元素位置
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
setThreshold(newLen);
size = count;
table = newTab;
}
上面就是Demo第12行 threadLocal.set("abc")方法涉及的所有代码了 。
下面总结下:
1.获取当前线程的threadLocals变量,如果不存在就调用createMap(t, value),创建并设置值,我们Demo第11行threadLocal.get()就走的这里
2.如果当前线程的threadLocals变量存在,就调用map.set(this, value)设置值
3.在获取值的过程中首先根据调用者threadLocal对象计算出应该存储在数组中的下标,
- 如果当前下标对应的数组元素是null,就新生成Entry元素放入数组下标i的位置,数组size+1,判断是否需要扩容,如果需要就对数组进行扩容
- 如果当前下标对应的数组元素不是null,就获取对应位置的元素进行判断,如果entry.key == 我们调用的threalocal对象,就更新value返回。如果entry.key==null,说明元素释放了就调用replaceStaleEntry进行处理。如果这两种情况都不是,那就继续从当前下标开始向后遍历,继续判断处理
最后看下Demo中第16行threadLocal.remove();
//获取当前线程的threadLocals变量,如果之前调用过get,set方法,那这时这个变量就不是null了
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null) {
//走到这里看看移除元素的方法,这里的this,是我们方法的调用者,也就是threadLocal实例对象
m.remove(this);
}
}
//这个方法就比较简单了,计算threadLocal下标,从这个位置开始向后遍历,对应元素(这里使用的是key==,所以找到的肯定是同一个对象),将Entry中引用的threadLocal对象清空,再执行清理过期元素的动作
private void remove(ThreadLocal<?> key) {
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
e.clear();
expungeStaleEntry(i);
return;
}
}
}
上面remove方法也讲完了,这个比较简单。就是找到t.threadLocals中对应的元素Entry,调用e.clear()清理掉entry引用的threadLocal变量(这时,通过e.get()获取的threadLocal元素就是null了),然后调用expungeStaleEntry执行过期元素的清理。
2、日常使用注意
上面就是我们日常使用threadLocal中方法的源码了 ,通过上面代码对于调用的内部细节我们也基本看到了,下面说一些日常使用过程中需要注意的地方。
它们的结构简单画个图,就是下面的样子了
文章开始的Demo只是一些日常使用,通过上面的源码阅读,我们其实还是可以看到一些其他的使用方式和注意地方
- 我们一般是在多个线程中是使用同一个threadLocal对象,其实我们也可以在一个线程中使用多个threadLocal对象(同样多个线程也就可以使用多个threadLocal对象),就像下面这样
//创建多个ThreadLocal对象
ThreadLocal<String> threadLocal0 = new ThreadLocal<>();
ThreadLocal<String> threadLocal1 = new ThreadLocal<>();
new Thread(() -> {
//在一个线程中使用多个threadLocal对象
//通过上面源码我们看到threadLocal对象只是用来寻找当前线程中threadLocals变量中数组的位置,并读取或者设置值。由于查找元素下标用的是==,所以无论怎么找到的都是同一个对象实例
threadLocal0.set("th0");
threadLocal1.set("th1");
System.out.println(threadLocal1.get());
System.out.println(threadLocal0.get());
}).start();
Thread.currentThread().join();
-
内存释放相关
上面的源码分析部分,我们看到了存储到threadLocals中Entry的threadLocal变量是弱引用了,而弱引用在下次gc的时候就会被清理,threadLocal被清理了,上面的清理过期元素的方法就会把对应Entry进行清理。但是这里有个前提,threadLocal只有弱引用的时候才会被清理,如果有强引用存在,就不会被清理。
看下面的例子
ThreadLocal<String> threadLocal0 = new ThreadLocal<>();
new Thread(() -> {
//我们第一行定义的地方,threadLocal对象是个强引用,只要这个强引用存在,threadLocals中对应的Entry就不会被清理
threadLocal0.set("th0");
//这个也很好理解,如果被清理掉了,那我们这里的get方法就获取不到值了,我们见过之前set后,对应get获取不到值的情况吗,肯定没有么
System.out.println(threadLocal0.get());
//所以我们想要清理存放的元素Entry就需要调用threadLocal0.remove()进行清理,或者是调用 threadLocal0=null;方法,这样当前 threadLocal0就只有线程threadLocals中对应的弱引用,这时,gc才会清理掉entry.k,对应元素才会在清理元素方法中清理掉
}).start();
Thread.currentThread().join();
另外由于ThreadLocal的操作都是基于当前线程的threadLocals变量的,如果当前线程不存在了,被清理掉了 ,只要threadLocals变量和内部Entry的key和value,如果没有其他地方进行引用,也都会被gc清理掉。
以上就是关于ThreadLocal的全部内容了。
另外父线程和子线程之间传递参数可以通过inheritableThreadLocals,这个变量的设置实在Thread的构造方法中,和threadLocals一样,都是ThreadLocal.ThreadLocalMap的实例。这个用的不多,这里就不说了。