ThreadLocal源码解析(基于JDK8)

ThreadLocal 是线程的本地变量,也就是不同线程的同一个 ThreadLocal 的 get/set 是独立的。

每个线程 Thread 内部有 ThreadLocalMap,Map 的键是 ThreadLocal<T>,值是泛型 T。

那么,根据线程 Thread 和 ThreadLocal<T> 才能唯一确定一个实际存储的值 T。

1 ThreadLocalMap

ThreadLocalMap 是 ThreadLocal 的静态内部类,是 ThreadLocal 的关键所在。ThreadLocalMap 是一种 HashMap,通过开放寻址法(每次向后移动一位)和环形数组解决哈希冲突,该类没有显式的继承 Map。

ThreadLocalMap 内部存储是一个 Entry[] table,这里的 Entry 是该类的静态内部类,没有向后的指针,也就是每个位置只能放置一个 Entry,而不像 HashMap 桶数组的每个位置可以放一个链表。

static class ThreadLocalMap {

    /**
     * The entries in this hash map extend WeakReference, using
     * its main ref field as the key (which is always a
     * ThreadLocal object).  Note that null keys (i.e. entry.get()
     * == null) mean that the key is no longer referenced, so the
     * entry can be expunged from table.  Such entries are referred to
     * as "stale entries" in the code that follows.
     */
   // 注意到ThreadLocal是弱引用的,而value仍然是强引用
    static class Entry extends WeakReference<ThreadLocal<?>> {
        /** The value associated with this ThreadLocal. */
        Object value;

        Entry(ThreadLocal<?> k, Object v) {
            super(k);
            value = v;
        }
    }

    /**
     * The initial capacity -- MUST be a power of two.
     */
    private static final int INITIAL_CAPACITY = 16;

    /**
     * The table, resized as necessary.
     * table.length MUST always be a power of two.
     */
    private Entry[] table;

    /**
     * table中Entry非空的个数
     */
    private int size = 0;

    /**
     * The next size value at which to resize.
     */
    private int threshold; // Default to 0

    /**
     *阈值为长度的 2/3
     */
    private void setThreshold(int len) {
        threshold = len * 2 / 3;
    }

    /**
     * 增加后取模
     */
    private static int nextIndex(int i, int len) {
        return ((i + 1 < len) ? i + 1 : 0);
    }

    /**
     * 减少后取模
     */
    private static int prevIndex(int i, int len) {
        return ((i - 1 >= 0) ? i - 1 : len - 1);
    }
  ...
}

由于 Entry extends WeakReference<ThreadLocal<?>> ,所以 Entry get() 继承自 Reference,返回相应的泛型引用。在 Entry 内具体指的是返回对应的键 ThreadLocal k。

在后面清理时,称满足e != null && e.get() == null的 Entry e 为 StaleEntry

// WeakReference
public class WeakReference<T> extends Reference<T> {

    /**
     * Creates a new weak reference that refers to the given object.  The new
     * reference is not registered with any queue.
     *
     * @param referent object the new weak reference will refer to
     */
    public WeakReference(T referent) {
        super(referent);
    }
    ...
}

// Reference
public abstract class Reference<T> {
    public T get() {
        return this.referent;
    }
    ...
}

1.1 弱引用问题

以下摘自 https://zhuanlan.zhihu.com/p/158033837


看看官话,为什么要用弱引用。

To help deal with very large and long-lived usages, the hash table entries use WeakReferences for keys.
为了处理非常大和生命周期非常长的线程,哈希表使用弱引用作为 key。

  • 生命周期长:暂时可以想到线程池中的线程

ThreadLocal在没有外部对象强引用时如Thread,发生GC时弱引用Key会被回收,而Value是强引用不会回收,如果创建ThreadLocal的线程一直持续运行如线程池中的线程,那么这个Entry对象中的value就有可能一直得不到回收,发生内存泄露。

  • key 如果使用强引用:引用的ThreadLocal的对象被回收了,但是ThreadLocalMap还持有ThreadLocal的强引用,如果没有手动删除,ThreadLocal不会被回收,导致Entry内存泄漏。
  • key 使用弱引用:引用的ThreadLocal的对象被回收了,由于ThreadLocalMap持有ThreadLocal的弱引用,即使没有手动删除,ThreadLocal也会被回收。value在下一次ThreadLocalMap调用set,get,remove的时候会被清除。

由于上述原因,下面不断使用 expungeStaleEntry 完成 value/entry 的清除。

1.2 构造器

有两种构造器,第一种是当前线程的 ThreadLocalMap threadLocals 为空,向里面放第一个 ThreadLocal 和值的情况;第二种是使用父线程的 ThreadLocalMap inheritableThreadLocals 来构造当前子线程的 inheritableThreadLocals。具体解释可以见 2.1 和 2.2。

/**
 * Construct a new map initially containing (firstKey, firstValue).
 * ThreadLocalMaps are constructed lazily, so we only create
 * one when we have at least one entry to put in it.
 */
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    table = new Entry[INITIAL_CAPACITY];
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    table[i] = new Entry(firstKey, firstValue);
    size = 1;
    setThreshold(INITIAL_CAPACITY);
}

/**
 * Construct a new map including all Inheritable ThreadLocals
 * from given parent map. Called only by createInheritedMap.
 *
 * @param parentMap the map associated with parent thread.
 */
private ThreadLocalMap(ThreadLocalMap parentMap) {
    Entry[] parentTable = parentMap.table;
    int len = parentTable.length;
    setThreshold(len);
    table = new Entry[len];

    for (int j = 0; j < len; j++) {
        Entry e = parentTable[j];
        if (e != null) {
            @SuppressWarnings("unchecked")
            ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
            if (key != null) {
                // 放入的值是key.childValue(e.value)
                Object value = key.childValue(e.value);
                Entry c = new Entry(key, value);
                int h = key.threadLocalHashCode & (len - 1);
                // 出现冲突向后找,直到找到一个可以插入的位置
                while (table[h] != null)
                    h = nextIndex(h, len);
                table[h] = c;
                size++;
            }
        }
    }
}

1.3 set

该方法相当于 HashMap 中 set 和 put,如果 map 中有对应的 key,替换 value;没有则插入。

在 for 循环中,从开始计算出的 i 向后查找直到 tab 在当前位置的 Entry 为 null 才退出。后续方法中的 for 退出逻辑均一致。

结果有三种情况

① 在 for 中,当前位置 key 已存在,替换相应的值即可。

② 在 for 中,当前位置 key 为 null,使用 replaceStaleEntry 实现删除 i 处的 Entry 和插入(key,value)的操作。

③ 退出 for,从 i 处检查 \(log_{2}sz\) 个元素,尝试清除其中的 StaleEntry,若未实现清除则检查 sz 是否超过阈值,是则执行 rehash。

private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
  	// 找到key初始位置
    int i = key.threadLocalHashCode & (len-1);
    // 执行循环向后查找,直到当前检查的e为null,退出
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
				// 找到对应的k,替换value即可
        if (k == key) {
            e.value = value;
            return;
        }
				// 如果对应位置k为null,通过replaceStaleEntry 实现清除和插入的操作
      	// 注意,在该位置之后,可能存在和当前 key相同的 k,这个也要考虑到。
      	// 插入位置为i
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
  	//由于set只能放入一个,如果前一个条件为true,则至少清除了一个,
  	//加一减一相当于没变,不需要检查是否超过阈值以及 rehash
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

1.3.1 replaceStaleEntry

对于 staleSlot 处的 staleEntry,执行的是放入 (key,value);对于其他的位置(如果存在)即 slotToExpunge 处的 staleEntry,执行清除。

该方法首先要做的就是将 (key,value) 正确放入 staleSlot,如果后续没有出现 key,直接在 staleSlot 处新建一个 Entry即可;如果出现了,进行相应的 Entry 交换和修改。另外, 作者还希望顺带检查一下除了 staleSlot,是否还有其他位置的 staleEntry,有的话执行清除操作。


slotToExpunge 表示找到的 staleEntry 的位置,初始为 staleSlot,如果找到了其他位置也是 staleEntry,就修改slotToExpunge。最后会判断 slotToExpunge 是否发生修改,如果修改,说明存在其他的 staleEntry,执行 expungeStaleEntrycleanSomeSlots 完成相应的清除操作。可以用三元表达式来描述: slotToExpunge = 是否存在其他的staleEntry ? 最靠前的位置 : staleSlot

具体来说:

查找是否其他的存在 staleEntry,分为两步,第一步是向前查找,找到则修改;第二步是向后查找, 如果找到了新的 staleEntry,不要急着修改,会判断前面是否修改过了,修改过了则不修改,保证存在其他的 staleEntry 时一定存储最靠前的位置,对于后续清除更有帮助。

在第二步向后查找的时候,如果找到了相同的 key,则需要进行交换,保证 tab[staleSlot] 修改为 (key,value) 的 entry,且 tab[i] 变成了 staleEntry。对于不存在相同键的情况,放在退出循环后处理。

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        if (e.get() == null)
            slotToExpunge = i;

    // Find either the key or trailing null slot of run, whichever
    // occurs first
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEntry
        // to remove or rehash all of the other entries in run.
      	// 这个if和后一个if可以交换顺序,交换后和我的描述一致,便于理解
        if (k == key) {
          	// 以下几步实现了staleSlot和i处Entry的互换,
          	// 最后结果是staleSlot存储的e的键和值是传入的参数,
          	// i处的Entry键为null,变成了 staleEntry。
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }

    // If key not found, put new entry in stale slot
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // If there are any other stale entries in run, expunge them
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

1.3.2 expungeStaleEntry

该方法将 staleSlot 处清空,然后顺带着向后检查,如果后面是 staleEntry,清空;如果后面是非 staleEntry 的,重新计算位置,进行相应的移动。返回满足 tab[i] 为 null 的 i。

在重新计算的代码中,旧位置为 i,新位置为 h。① len 改变,则 h 会和 i 不同; ② len 不变,但是 i 并不是上一次计算出的位置,而是计算出的位置又向后移动了一些,这样 h 和 i 也可以不一样。

private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // expunge entry at staleSlot
  	// 清空 staleSlot
    tab[staleSlot].value = null;
    tab[staleSlot] = null;
    size--;

    // Rehash until we encounter null
    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
      	// 清除其他的 staleEntry
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
          	// 重新计算,如果不同,则进行相应的移动。
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;

                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                while (tab[h] != null)
                    h = nextIndex(h, len);
                tab[h] = e;
            }
        }
    }
    return i;
}

1.3.3 cleanSomeSlots

从 i 向后依次检查\(log_{2}n\)个元素,如果是 staleEntry 则清除。清除至少一个返回 true,否则是 false。

private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            n = len;
            removed = true;
            i = expungeStaleEntry(i);
        }
    } while ( (n >>>= 1) != 0);
    return removed;
}

1.3.4 rehash

执行 expungeStaleEntries 来检查并清除全部 staleEntry,而后判断 size 是否达到数组一半长度,是执行扩容 resize。

计算可得 \(3/4 * threshold = 3/4*2/3*len=1/2*len\)

private void rehash() {
    expungeStaleEntries();

    // Use lower threshold for doubling to avoid hysteresis
  	// 达到一半长度执行扩容
    if (size >= threshold - threshold / 4)
        resize();
}

1.3.5 expungeStaleEntries

检查并清除全部的 StaleEntry。

private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null)
            expungeStaleEntry(j);
    }
}

1.3.6 resize

Map 中的 table 扩容为原来的两倍。

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

1.4 get/remove

getEntry 只比较计算出的位置 i 是否满足条件,其他情况交给 getEntryAfterMiss

getEntryAfterMiss 在循环中处理三种情况,一个是 key 相等,直接返回;第二是 k 为 null,说明是 staleEntry,执行删除;第三种则移动位置。 如果最后没有找到,则返回 null。

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

在 remove 中,不断查找,找到同样的 key 后,删除,并执行 expungeStaleEntry 来清除该位置。

clear 是 Reference 内部的方法,将 referent 置为 null,在这里指的是 e 的键 ThreadLocal。

private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            e.clear();
            expungeStaleEntry(i);
            return;
        }
    }
}
// Reference
public void clear() {
    this.referent = null;
}

2 ThreadLocal

2.1 childValue

其中 childValue 方法在 ThreadLocal 内抛出异常,在其子类 InheritableThreadLocal 直接返回输入。

该方法只用于线程的 inheritableThreadLocals,用于构建子线程的 inheritableThreadLocals

父线程的 ThreadLocalMap 来构造当前子线程的 inheritableThreadLocals。

本文不介绍 InheritableThreadLocal,需要的可自行查看 https://www.cnblogs.com/hama1993/p/10400265.html

// ThreadLocal
T childValue(T parentValue) {
        throw new UnsupportedOperationException();
    }
// InheritableThreadLocal
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    /**
     * Computes the child's initial value for this inheritable thread-local
     * variable as a function of the parent's value at the time the child
     * thread is created.  This method is called from within the parent
     * thread before the child is started.
     * <p>
     * This method merely returns its input argument, and should be overridden
     * if a different behavior is desired.
     *
     * @param parentValue the parent thread's value
     * @return the child thread's initial value
     */
    // 返回输入
    protected T childValue(T parentValue) {
        return parentValue;
    }
}

2.2 threadLocalHashCode

每次一个新的 ThreadLocal 产生,threadLocalHashCode = 上一个 ThreadLocal 的 threadLocalHashCode + HASH_INCREMENT。这样做可以在 len 为2的幂时,key.threadLocalHashCode & (len - 1)取得尽可能均匀。

也就是说,这里的 hashCode 与 key 的值无关,与 key 是第几个 ThreadLocal 有关。

// ThreadLocal 类
/**
 * ThreadLocals rely on per-thread linear-probe hash maps attached
 * to each thread (Thread.threadLocals and
 * inheritableThreadLocals).  The ThreadLocal objects act as keys,
 * searched via threadLocalHashCode.  This is a custom hash code
 * (useful only within ThreadLocalMaps) that eliminates collisions
 * in the common case where consecutively constructed ThreadLocals
 * are used by the same threads, while remaining well-behaved in
 * less common cases.
 */
private final int threadLocalHashCode = nextHashCode();

/**
 * The next hash code to be given out. Updated atomically. Starts at
 * zero.
 */
// 以下三个均为static
private static AtomicInteger nextHashCode =
    new AtomicInteger();

/**
 * The difference between successively generated hash codes - turns
 * implicit sequential thread-local IDs into near-optimally spread
 * multiplicative hash values for power-of-two-sized tables.
 */
private static final int HASH_INCREMENT = 0x61c88647;

/**
 * Returns the next hash code.
 */
private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

2.3 其他

由于其他的 ThreadLocal 的方法通常是当前线程 Thread.currentThread() 内部的 ThreadLocal.ThreadLocalMap threadLocals 执行相应方法所做的操作,这里不再细述,仅举几例供参考。

//初始值为null,通常使用时需要重写
protected T initialValue() {
    return null;
}

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

private T setInitialValue() {
    T value = initialValue();
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
    return value;
}  

// 调用ThreadLocalMap的getEntry
// 如果不存在则设置初值
public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}
void createMap(Thread t, T firstValue) {
    t.threadLocals = new ThreadLocalMap(this, firstValue);
} 
// 调用ThreadLocalMap的set
public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}
// 调用ThreadLocalMap的remove
public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

3 ThreadLocal 的简单使用

ThreadLocal 在使用的时候,要求设置为 static 的,防止在多个线程内初始化导致算出的 hashCode 不同。

在线程内可以显式的调用 set 将对应的 ThreadLocal 插入;如果没有显式调用 set ,等到调用 get 时将 initialValue 给插入,该种情况下需要重写 initialValue,否则插入的值为 null。

public class ThreadLocalTest {
  	// 设置为 static
    private static ThreadLocal<String> sThreadLocal = new ThreadLocal<String>(){
        @Override
        protected String initialValue() {
            return "该线程使用 initialValue";
        }
    };
    public static void main(String args[]) {

        // 调用set,以set为准
        sThreadLocal.set("这是在主线程中");
        System.out.println("线程名字:" + Thread.currentThread().getName() + "---" + sThreadLocal.get());

        //线程a
        new Thread(()->{
            // 调用set,以set为准
            sThreadLocal.set("这是在线程a中");
            System.out.println("线程名字:" + Thread.currentThread().getName() + "---" + sThreadLocal.get());
        }).start();

        //线程b
        new Thread(()->{
          	// 直接get,则调用 initialValue。
            System.out.println("线程名字:" + Thread.currentThread().getName() + "---" + sThreadLocal.get())
            ;},"线程b").start();
    }
}

结果

线程名字:main---这是在主线程中
线程名字:Thread-0---这是在线程a中
线程名字:线程b---该线程使用 initialValue
posted @ 2021-04-24 21:26  Java与大数据进阶  阅读(132)  评论(0编辑  收藏  举报