【JAVA】ThreadLocal源码分析
ThreadLocal内部是用一张哈希表来存储:
1 static class ThreadLocalMap { 2 static class Entry extends WeakReference<ThreadLocal<?>> { 3 /** The value associated with this ThreadLocal. */ 4 Object value; 5 6 Entry(ThreadLocal<?> k, Object v) { 7 super(k); 8 value = v; 9 } 10 } 11 private static final int INITIAL_CAPACITY = 16; 12 private Entry[] table; 13 private int size = 0; 14 private int threshold; 15 ......
看过HashMap的话就很容易理解上述内容【Java】HashMap源码分析
而在Thread类中有一个ThreadLocalMap 的成员:
1 ThreadLocal.ThreadLocalMap threadLocals = null;
所以不难得出如下关系:
每一个线程都有一张线程私有的Map,存放多个线程本地变量
set()方法:
1 public void set(T value) { 2 Thread t = Thread.currentThread(); 3 ThreadLocalMap map = getMap(t); 4 if (map != null) 5 map.set(this, value); 6 else 7 createMap(t, value); 8 } 9 10 ThreadLocalMap getMap(Thread t) { 11 return t.threadLocals; 12 }
不难看出,先获取当前线程的Thread对象,再得到该Thread对象的ThreadLocalMap 成员map,若map为空,需要先createMap()方法,若不为空,则需要调用map的set()方法
1 void createMap(Thread t, T firstValue) { 2 t.threadLocals = new ThreadLocalMap(this, firstValue); 3 } 4 ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) { 5 table = new Entry[INITIAL_CAPACITY]; 6 int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1); 7 table[i] = new Entry(firstKey, firstValue); 8 size = 1; 9 setThreshold(INITIAL_CAPACITY); 10 } 11 private void setThreshold(int len) { 12 threshold = len * 2 / 3; 13 }
createMap方法会创建一个ThreadLocalMap对象,在ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue)构造方法中,可以看出和HashMap很相似,通过firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1)取模,计算出哈希表的下标,将创建好的Entry对象放入该位置,再根据表长计算阈值,可以看出负载因子是2/3,初始哈希表的大小是16。
1 private void set(ThreadLocal<?> key, Object value) { 2 Entry[] tab = table; 3 int len = tab.length; 4 int i = key.threadLocalHashCode & (len-1); 5 6 for (Entry e = tab[i]; 7 e != null; 8 e = tab[i = nextIndex(i, len)]) { 9 ThreadLocal<?> k = e.get(); 10 11 if (k == key) { 12 e.value = value; 13 return; 14 } 15 16 if (k == null) { 17 replaceStaleEntry(key, value, i); 18 return; 19 } 20 } 21 22 tab[i] = new Entry(key, value); 23 int sz = ++size; 24 if (!cleanSomeSlots(i, sz) && sz >= threshold) 25 rehash(); 26 }
不难看出,通过key.threadLocalHashCode & (len-1)计算出哈希表的下标,判断该位置的Entry是否为null,若为null,则创建Entry对象,将其放入该下标位置;若Entry已存在,则需要解决哈希冲突,重新计算下标。最后size自增,再根据!cleanSomeSlots(i, sz) && sz >= threshold进行判断是否需要进行哈希表的调整。
在解决哈希冲突的上,常用的有开链法、线性探测法和再散列法,HashMap中使用的是开链法,而ThreadLocal使用的是线性探测法,即发生哈希冲突,往后移动到合适位置。
1 private static int nextIndex(int i, int len) { 2 return ((i + 1 < len) ? i + 1 : 0); 3 } 4 private static int prevIndex(int i, int len) { 5 return ((i - 1 >= 0) ? i - 1 : len - 1); 6 }
从这两个操作看出,ThreadLocal中的哈希表是利用了循环数组的方式,进行环形的线性探测
在上述for循环中,会取出该Entry上的ThreadLocal对象(键)进行判断,若相同则直接覆盖,若为null,说明该Entry空间存在但其ThreadLocal对象的指向为null,需要进行调整;若都不成立,则继续循环,重复以上操作。
Entry空间指向存在但ThreadLocal对象的指向为null是因为Entry继承自WeakReference<ThreadLocal<?>>,是弱引用,存在被GC的情况,所以会存在这种情况,视为脏Entry,接下来的操作就是通过replaceStaleEntry进行处理。
1 private void replaceStaleEntry(ThreadLocal<?> key, Object value, 2 int staleSlot) { 3 Entry[] tab = table; 4 int len = tab.length; 5 Entry e; 6 7 int slotToExpunge = staleSlot; 8 for (int i = prevIndex(staleSlot, len); 9 (e = tab[i]) != null; 10 i = prevIndex(i, len)) 11 if (e.get() == null) 12 slotToExpunge = i; 13 14 for (int i = nextIndex(staleSlot, len); 15 (e = tab[i]) != null; 16 i = nextIndex(i, len)) { 17 ThreadLocal<?> k = e.get(); 18 19 if (k == key) { 20 e.value = value; 21 22 tab[i] = tab[staleSlot]; 23 tab[staleSlot] = e; 24 25 if (slotToExpunge == staleSlot) 26 slotToExpunge = i; 27 cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); 28 return; 29 } 30 31 if (k == null && slotToExpunge == staleSlot) 32 slotToExpunge = i; 33 } 34 35 tab[staleSlot].value = null; 36 tab[staleSlot] = new Entry(key, value); 37 38 if (slotToExpunge != staleSlot) 39 cleanSomeSlots(expungeStaleEntry(slotToExpunge), len); 40 }
可以清楚看到第一个for循环前向遍历查找脏Entry,用slotToExpunge保存脏Entry下标;
第二个for循环后向遍历,若遇到ThreadLocal向同,更新value,然后与下标为staleSlot(传入进来的脏Entry)进行交换,接着判断前向查找脏Entry是否存在,slotToExpunge == staleSlot说明的就是前向查找没找到,就更改slotToExpunge的值,然后进行清理操作,结束掉;若后向遍历遇到脏Entry,并且前向没找到,更改slotToExpunge的值,为清理时用,继续循环。
若不存在和ThreadLocal引用相同的Entry,则需要将staleSlot的位置的Entry替换为一个新的Entry对象,tab[staleSlot].value = null是为了GC;
最后根据slotToExpunge来判断前向后向遍历中是否存在脏Entry,若存在还需要进行清理。
其中的expungeStaleEntry方法如下:
1 private int expungeStaleEntry(int staleSlot) { 2 Entry[] tab = table; 3 int len = tab.length; 4 5 // expunge entry at staleSlot 6 tab[staleSlot].value = null; 7 tab[staleSlot] = null; 8 size--; 9 10 // Rehash until we encounter null 11 Entry e; 12 int i; 13 for (i = nextIndex(staleSlot, len); 14 (e = tab[i]) != null; 15 i = nextIndex(i, len)) { 16 ThreadLocal<?> k = e.get(); 17 if (k == null) { 18 e.value = null; 19 tab[i] = null; 20 size--; 21 } else { 22 int h = k.threadLocalHashCode & (len - 1); 23 if (h != i) { 24 tab[i] = null; 25 26 // Unlike Knuth 6.4 Algorithm R, we must scan until 27 // null because multiple entries could have been stale. 28 while (tab[h] != null) 29 h = nextIndex(h, len); 30 tab[h] = e; 31 } 32 } 33 } 34 return i; 35 }
可以看到,先把当前位置的脏Entry清除掉(置为null),size自减。然后从当前位置后向遍历,若遇到脏Entry直接清除,size自减;若不是脏Entry,则需要判断它是否经过哈希冲突的调整的,若调整过,需要将其重新调整,最后返回当前位置为null的table下标;综上,该方法就是后向清除脏Entry,再把调整需要调整的Entry。
在replaceStaleEntry方法中,调用expungeStaleEntry清除掉脏Entry后,还要用cleanSomeSlots方法清除掉返回回来的下标后的脏Entry;
cleanSomeSlots方法:
1 private boolean cleanSomeSlots(int i, int n) { 2 boolean removed = false; 3 Entry[] tab = table; 4 int len = tab.length; 5 do { 6 i = nextIndex(i, len); 7 Entry e = tab[i]; 8 if (e != null && e.get() == null) { 9 n = len; 10 removed = true; 11 i = expungeStaleEntry(i); 12 } 13 } while ( (n >>>= 1) != 0); 14 return removed; 15 }
从下标为i后面的开始后向遍历,遇到脏Entry调用expungeStaleEntry清除掉,令removed为true,i会变为下标为null的位置,继续循环;其中n的用途是控制循环次数,当遇到脏Entry时,会令n等于表长,扩大搜索范围。
在set方法中,最后根据!cleanSomeSlots(i, sz) && sz >= threshold,判断是否清理掉了脏Entry,若清理了什么都不做;若没有清理,还会判断是否达到阈值,进而是否需要rehash操作;
rehash方法:
1 private void rehash() { 2 expungeStaleEntries(); 3 4 // Use lower threshold for doubling to avoid hysteresis 5 if (size >= threshold - threshold / 4) 6 resize(); 7 }
首先调用expungeStaleEntries方法:
1 private void expungeStaleEntries() { 2 Entry[] tab = table; 3 int len = tab.length; 4 for (int j = 0; j < len; j++) { 5 Entry e = tab[j]; 6 if (e != null && e.get() == null) 7 expungeStaleEntry(j); 8 } 9 }
可以看到expungeStaleEntries方法是遍历整个哈希表,通过调用expungeStaleEntry方法清除掉所有脏Entry。
由于清除掉了脏Entry,还需要对size进行判断,看是否达到了阈值的3/4(提前触发resize),来判断是否真的需要resize;
resize方法:
1 private void resize() { 2 Entry[] oldTab = table; 3 int oldLen = oldTab.length; 4 int newLen = oldLen * 2; 5 Entry[] newTab = new Entry[newLen]; 6 int count = 0; 7 8 for (int j = 0; j < oldLen; ++j) { 9 Entry e = oldTab[j]; 10 if (e != null) { 11 ThreadLocal<?> k = e.get(); 12 if (k == null) { 13 e.value = null; // Help the GC 14 } else { 15 int h = k.threadLocalHashCode & (newLen - 1); 16 while (newTab[h] != null) 17 h = nextIndex(h, newLen); 18 newTab[h] = e; 19 count++; 20 } 21 } 22 } 23 24 setThreshold(newLen); 25 size = count; 26 table = newTab; 27 }
刚开始的操作可以清楚的明白,每次扩容的大小都是原来的两倍;然后遍历原表的所有Entry,遇到脏Entry直接赋值null引起帮助GC;遇到有效Entry则需要根据新的表长重新计算下标,再通过线性探测完成新表的填充;填充完毕,计算新的阈值,给size和table赋值,结束操作。
至此,有关set的操作就结束了,还剩下get和remove:
get方法:
1 public T get() { 2 Thread t = Thread.currentThread(); 3 ThreadLocalMap map = getMap(t); 4 if (map != null) { 5 ThreadLocalMap.Entry e = map.getEntry(this); 6 if (e != null) { 7 @SuppressWarnings("unchecked") 8 T result = (T)e.value; 9 return result; 10 } 11 } 12 return setInitialValue(); 13 }
和set一样,先获取当前线程,再根据当前线程获取其ThreadLocalMap成员map;
若map不为null,通过map的getEntry方法得到Entry对象,若Entry不为null则直接返回Entry的value;
若map为null,或者map不为null,但是Entry是null,则都需要调用setInitialValue方法。
getEntry方法:
1 private Entry getEntry(ThreadLocal<?> key) { 2 int i = key.threadLocalHashCode & (table.length - 1); 3 Entry e = table[i]; 4 if (e != null && e.get() == key) 5 return e; 6 else 7 return getEntryAfterMiss(key, i, e); 8 }
根据ThreadLocal定位哈希表的下标,若满足则直接返回,若不是,调用getEntryAfterMiss继续找。
getEntryAfterMiss方法:
1 private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) { 2 Entry[] tab = table; 3 int len = tab.length; 4 5 while (e != null) { 6 ThreadLocal<?> k = e.get(); 7 if (k == key) 8 return e; 9 if (k == null) 10 expungeStaleEntry(i); 11 else 12 i = nextIndex(i, len); 13 e = tab[i]; 14 } 15 return null; 16 }
看以看到这还是一个后向遍历的查找,若是找到则直接返回;若遇到脏Entry需要调用expungeStaleEntry方法清理掉;最后还没找到返回null。
setInitialValue方法:
1 private T setInitialValue() { 2 T value = initialValue(); 3 Thread t = Thread.currentThread(); 4 ThreadLocalMap map = getMap(t); 5 if (map != null) 6 map.set(this, value); 7 else 8 createMap(t, value); 9 return value; 10 }
先调用initialValue方法,该方法需要使用者进行覆盖,否则返回的是null。所以当没有使用set方法时覆盖initialValue方法时还是会调用set方法的,效果是一样的。
1 protected T initialValue() { 2 return null; 3 }
后面的操作就和set方法一样。get方法至此结束。
remove方法:
1 public void remove() { 2 ThreadLocalMap m = getMap(Thread.currentThread()); 3 if (m != null) 4 m.remove(this); 5 }
以当前线程为参数调用getMap方法:
1 ThreadLocalMap getMap(Thread t) { 2 return t.threadLocals; 3 }
若是当前线程的ThreadLocalMap对象不存在,什么都不做,若存在,调用内部的remove方法:
1 private void remove(ThreadLocal<?> key) { 2 Entry[] tab = table; 3 int len = tab.length; 4 int i = key.threadLocalHashCode & (len-1); 5 for (Entry e = tab[i]; 6 e != null; 7 e = tab[i = nextIndex(i, len)]) { 8 if (e.get() == key) { 9 e.clear(); 10 expungeStaleEntry(i); 11 return; 12 } 13 } 14 }
首先根据ThreadLocal找到其对应的的哈希表的下标(不一定是它的下标,会有哈希冲突的可能性),然后开始后向遍历,找到真正的位置,调用clear方法删除掉,顺便还进行脏Entry的清理。
clear方法是Reference类的方法:
1 public void clear() { 2 this.referent = null; 3 }
可以看到仅仅只是令指向变为null,因为Reference是WeakReference的父类,ThreadLocalMap继承自WeakReference<ThreadLocal<?>>,弱引用变为null,就会变成脏Entry,所以就需要expungeStaleEntry对其清理。为什么不令tab[i]直接为null,就是因为在expungeStaleEntry执行时还会清理遇到的脏Entry,这样可以尽可能多的删除掉脏Entry。
ThreadLocal源码分析到此结束。