【ThreadLocal】ThreadLocal的实现机制和原理

1  前言

这节我们看下 ThreadLocal ,这个东西大家应该不陌生,经常在一些同步优化中会使用到它。很多地方叫线程本地变量,ThreadLocal为变量在每个线程中都创建了一个副本,那么每个线程可以访问自己内部的副本变量。也就是对于同一个ThreadLocal,每个线程通过get、set、remove接口操作只会影响自身线程的数据,不会干扰其他线程中的数据。常见的比如我们的登录信息是不是用到了,AOP里的 AOPContext等,那么这节我们就来看看它的实现机制。

2  类图

说到 ThreadLocal ,涉及的相关类我们先简单介绍下:

  • ThreadLocal 核心类,本地线程类,他里边有两个子类:SuppliedThreadLocal这个主要是传的表达式相当于延迟初始化,当调用get的时候才会获取值;ThreadLocalMap这个就是用来存放数据管理数据的,会依附在 Thread 里;
  • Thread 线程类  内部拥有ThreadLocalMap变量,有两个,threadLocals 负责正常存放某个本地线程变量,inheritableThreadLocals负责父子线程传递的;
  • InheritableThreadLocal 用于父子线程的传递

 我们看下类图:

3  源码分析

3.1  测试代码

public class TestThreadLocal {

    // Person类
    static class Person {
        private String name;
        Person(String name) {
            this.name = name;
        }
        @Override
        public String toString() {
            return "Person{" +
                    "name='" + name + '\'' +
                    '}';
        }
    }

    public static void main(String[] args) {
        Person person = new Person("狗子");
        ThreadLocal<Person> personThreadLocal = new ThreadLocal<>();
        personThreadLocal.set(person);
        System.out.println(personThreadLocal.get());
    }
}

可以看到我们测试的代码,创建 ThreadLocal 对象,设置值以及获取,那么接下来我们就来看看 set、get方法。

3.2  set 方法

// TheradLocal
public void set(T value) {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 获取当前线程中的 map 集合
    ThreadLocalMap map = getMap(t);
    if (map != null)
        /**
         * 如果 map 不为空,就设置值
         * 可以看到 key 就是我们的 ThreadLocal本身 说明什么?
         * 也就是说我们的每个线程中只会存放一个 ThreadLocal对象对应的值
         */
        map.set(this, value);
    else
        /**
         * map 为空的话 就要创建并赋予 value 初值
         */
        createMap(t, value);
}
// 其实就是 Thread 中的 threadLocals 变量
ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}
// Therad
public class Thread implements Runnable {
    /* ThreadLocal values pertaining to this thread. This map is maintained
     * by the ThreadLocal class. */
    ThreadLocal.ThreadLocalMap threadLocals = null;

    /*
     * InheritableThreadLocal values pertaining to this thread. This map is
     * maintained by the InheritableThreadLocal class.
     */
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;
}

可以看到我们的 set 方法,首先会获取当前线程的 threadLocals,如果为空的话就会创建并赋予初值,不为空的话就会把当前值设置进去,那我们先看一下 createMap 方法看看是如何初始化并赋值的。

3.2.1  createMap 方法

// ThreadLocal
void createMap(Thread t, T firstValue) {
    // 就是创建 ThreadLocalMap 对象,并赋值给线程 t
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

那我们继续看下 ThreaLocalMap 的实例化:

// ThreadLocalMap
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
    /**
     * 实例化 table 数组,初始大小为:16
     * private static final int INITIAL_CAPACITY = 16;
     */
    table = new Entry[INITIAL_CAPACITY];
    /**
     * 就是根据 ThreadLocal的哈希值和数组的长度取余
     * 那你有没有疑问 取余的话,会不会不同的ThreadLocal取余一样不就冲突了?
     * 其实 threadLocalHashCode 是递增的,这个你们可以看下我就不在这里看了哈,源码就在上边 用了atomic的增长的
     */
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // 设置值 key->我们的ThreadLocal对象 value->就是我们要存放的值
    table[i] = new Entry(firstKey, firstValue);
    // 因为创建并放置了一个键值对所以长度 1
    size = 1;
    // 扩容标志 也就是到达三分之二的时候要进行重新计算扩容
    setThreshold(INITIAL_CAPACITY);
}

整体上就是实例化我们的ThreaLocalMap ,里边实际上就是有一个 Entry数组来维护,初始化长度为 16,扩容标志是三分之二的水平线。那我们再来简单看下我们的Entry:

// ThreadLocalMap的内部类
// 可以看到我们的 Entry 是继承的弱引用,也就是线程中的 ThreadLocalMap 不会影响 ThreadLocal的回收
static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    // 我们存放的值
    Object value;
    // k 就是我们的 ThreadLocal对象,弱引用  v就是我们要存放的值
    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

可以看到 ThreaLocalMap 弱引用于我们的 ThreadLocal 对象,也就是线程中的 ThreaLocalMap 不会影响我 ThreadLocal 对象的回收。

3.2.2  ThreaLocalMap 中的 set 方法

我们继续看下当线程中的 ThreadLocalMap不为空的话,是如何存放的,看之前首先我们猜猜应该有哪些,我们看到实例化的时候有扩容标志,那么是不是会在 set 的时候来进行扩容的判断呢?是不是我们来看看:

/**
 * @param key 我们的 ThreadLocal 对象
 * @param value 就是我们要存放的值
 */
private void set(ThreadLocal<?> key, Object value) {
    // 我们的 Entry 数组
    Entry[] tab = table;
    // 数组的长度
    int len = tab.length;
    // 取余,判断当前 ThreadLocal 应该存放的索引位置
    int i = key.threadLocalHashCode & (len-1);
    /**
     * 获取索引下的 Entry 键值对 这里为什么用循环呢?
     * threadLocalHashCode 递增的值是 nextHashCode.getAndAdd(HASH_INCREMENT) 0x61c88647
     * 并不是按1增长的 所以可能还是会冲突吧 
     * 所以冲突了就会加1 nextIndex => ((i + 1 < len) ? i + 1 : 0)
     * 循环的作用:
     * 1、处理索引冲突
     * 2、处理脏的ThreadLocal
     */
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        // 获取到我们的 ThreadLocal 对象
        ThreadLocal<?> k = e.get();
        // 当发现对应的 ThreadLocal 已经存在某个值,则进行替换并返回
        if (k == key) {
            e.value = value;
            return;
        }
        // 因为 ThreadLocal 是弱引用的,
        // 那么当 ThreadLocal 被回收时,这里就会为空 
        // 那么说明我们当前 Entry中的ThreadLocal是脏的了,进行脏处理
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }
    // 索引位置的 Entry为空,那么创建 Entry 并设置 value
    tab[i] = new Entry(key, value);
    // 长度++
    int sz = ++size;
    /**
     * cleanSomeSlots 清理脏的 ThreadLocal 发现好严谨奥 
     * 返回为true 说明有脏的 ThreadLocal
     * 为false的情况下,并且大小到达了我们的增长标志进行重新hash并扩容
     */
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

可以看到里边除了常规的替换值设置值,有两个很重要的操作就是清理脏 ThreadLocal - replaceStaleEntry 以及扩容- reHash 判断,其实 cleanSomeSlots 也是清理脏 ThreadLocal,调用的 expungeStaleEntry,我们先来看下 cleanSomeSlots :

3.2.2.1  cleanSomeSlots 

// ThreadLocalMap
private boolean cleanSomeSlots(int i, int n) {
    // 有效清理标志
    boolean removed = false;
    // 我们的 Entry 数组
    Entry[] tab = table;
    // 数组长度
    int len = tab.length;
    do {
        // 下一个索引
        i = nextIndex(i, len);
        // 获取索引位置的 Entry
        Entry e = tab[i];
        // 判断当前的 Entry 的 ThreadLocal 是否为空
        if (e != null && e.get() == null) {
            // TheradLocal为空 说明是脏的
            n = len;
            // 更新有效清理标志
            removed = true;
            // 清理脏 TheradLocal
            i = expungeStaleEntry(i);
        }
        // 无符号右移 相当于除以2 相当于折半判断并不是全都遍历一遍
    } while ( (n >>>= 1) != 0);
    return removed;
}

那我们继续跟进去看下 expungeStaleEntry,是如何清理的:

3.2.2.2  expungeStaleEntry

// ThreadLocalMap
// staleSlot 脏ThreadLocal索引
private int expungeStaleEntry(int staleSlot) {
    // 我们的 Entry 数组
    Entry[] tab = table;
    // 数组的长度
    int len = tab.length;
    // expunge entry at staleSlot
    // 因为我们的 table 是强引用 value的 这里先把 value的强引用释放
    tab[staleSlot].value = null;
    // 把自身的占位释放
    tab[staleSlot] = null;
    // 长度--
    size--;
    // Rehash until we encounter null
    Entry e;
    int i;
    // 遍历下一个索引位置的 Entry 是否为空
    for (i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        // 不为空的话判断 ThreadLocal 是否是脏的
        ThreadLocal<?> k = e.get();
        // 是脏的话 释放引用 长度--
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            // ThreadLocal 不脏的情况下 因为前边释放了一个位置 这里再根据长度取余
            int h = k.threadLocalHashCode & (len - 1);
            // 新的索引位置 不等于原来旧的索引位置的话
            if (h != i) {
                // 旧位置置为空
                tab[i] = null;
                // Unlike Knuth 6.4 Algorithm R, we must scan until
                // null because multiple entries could have been stale.
                // 判断新位置是否为空,不为空的找一个空的位置出来
                while (tab[h] != null)
                    h = nextIndex(h, len);
                // 设置到新位置去
                tab[h] = e;
            }
            /**
             * 我在想为什么要这么做呢? 为什么还要遍历其它的不为空呃 给他们重新变换位置呢?
             * 是为了收缩?还是就是单纯的要释放其它脏的ThreadLocal呢
             */
        }
    }
    return i;
}

释放的主要操作就是 tab[i].value = null 释放强引用,并把 tab[i] = null, 其实释放就完事了,但是它还会接着遍历判断下一个索引位置的 Entry 是否也是脏的,是的话也会继续清理,不是的话会调整新位置,直到索引位置的 Entry 是空的话,结束循环完事。这么做的道理可能就是更好的去为 GC做准备释放一些无效的引用,防止内存泄漏嘛。

3.2.2.3  replaceStaleEntry

/**
 * ThreadLocalMap
 * @param key 脏的ThreadLocal
 * @param value 要存放的值
 * @param staleSlot 脏索引
 */
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    // 我们的数组
    Entry[] tab = table;
    // 数组的长度
    int len = tab.length;
    Entry e;
    // Back up to check for prior stale entry in current run.
    // We clean out whole runs at a time to avoid continual
    // incremental rehashing due to garbage collector freeing
    // up refs in bunches (i.e., whenever the collector runs).
    // slotToExpunge 要去释放的索引位置
    int slotToExpunge = staleSlot;
    // 往前找一个Entry不空的
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len))
        // 发现也是脏的
        if (e.get() == null)
            // 更新要释放的索引位置
            slotToExpunge = i;
    // Find either the key or trailing null slot of run, whichever
    // occurs first
    // 又往后找一个 Entry不为空的
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        // 获取当前不为空的 Entry 的 ThreadLocal
        ThreadLocal<?> k = e.get();
        // If we find key, then we need to swap it
        // with the stale entry to maintain hash table order.
        // The newly stale slot, or any other stale slot
        // encountered above it, can then be sent to expungeStaleEn
        // to remove or rehash all of the other entries in run.
        // 发现就是我们参数里的脏ThreadLocal
        if (k == key) {
            e.value = value;
            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;
            // Start expunge at preceding stale entry if it exists
            if (slotToExpunge == staleSlot)
                slotToExpunge = i;
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }
        // If we didn't find stale entry on backward scan, the
        // first stale entry seen while scanning for key is the
        // first still present in the run.
        if (k == null && slotToExpunge == staleSlot)
            slotToExpunge = i;
    }
    // If key not found, put new entry in stale slot
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);
    // If there are any other stale entries in run, expunge them
    if (slotToExpunge != staleSlot)
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}

replaceStaleEntry 这个方法说实话有点邪门,没怎么看懂,它这是先往前找一找脏ThreadLocal,然后又往后遍历,我搜了搜别人的理解,也是众说纷纭,暂时保留,总之我们知道这个也是清理脏ThreadLocal的。那我们看见有三个方法来清理脏ThreadLocal的我们看下三者的区别:

  • expungeStaleEntry 这个肯定是会清理掉至少一个的,返回值返回的是下一个 空Entry的索引位置;
  • replaceStaleEntry 这个方法从我的理解上我觉得时空清理至少一个的;
  • cleanSomeSlots 这个是会尝试去清理脏ThreadLocal,有可能没清一个,有可能清了。

 set 方法除了清理还有一项重要的是判断是否需要扩容,我们来看下:

3.2.2.4  reHash

// ThreadLocalMap
private void rehash() {
    // 还是会调用 expungeStaleEntry 进行清理脏ThreadLocal
    expungeStaleEntries();
    // Use lower threshold for doubling to avoid hysteresis
    // 判断是否达到 四分之三 超过了就重新扩容
    if (size >= threshold - threshold / 4)
        resize();
}
// ThreadLocalMap
private void resize() {
    // 我们的数组
    Entry[] oldTab = table;
    // 数组的长度
    int oldLen = oldTab.length;
    // 新数组的长度,可以看到是2倍扩容
    int newLen = oldLen * 2;
    // 创建新的数组
    Entry[] newTab = new Entry[newLen];
    int count = 0;
    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            // 脏的ThreadLocal就被丢弃了
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                // 重新计算位置
                int h = k.threadLocalHashCode & (newLen - 1);
                // 冲突的话重新获取就是往后放
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                // 设置到新数组中
                newTab[h] = e;
                // 计数器++
                count++;
            }
        }
    }
    // 设置扩容标志
    setThreshold(newLen);
    // 设置回去
    size = count;
    table = newTab;
}

其实我们默认的扩容水平线是:threshold = len * 2 / 3,也就是到达三分之二的时候会进入 reHash 方法,而 reHash会进行脏ThreadLocal的清理,清理后发现长度达到了 四分之三的话,就会扩容,扩容后的长度是原长度的2倍。

3.3  get 方法

看完 set 的方法,我们再来看下 get 的方法:

// ThreadLocal
public T get() {
    // 当前线程
    Thread t = Thread.currentThread();
    // 当前线程中的 ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    // 不为空的话,就从 ThreadLocalMap 取
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        // 取出的值不为空的话,进行强制转换
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    // ThreadLocalMap 为空或者值为空的话,就从 setInitialValue 中获取
    return setInitialValue();
}
// ThreadLocal
private T setInitialValue() {
    // 从 initialValue 方法中获取默认的值 默认的实现是返回 null
    T value = initialValue();
    // 当前线程
    Thread t = Thread.currentThread();
    // 获取到线程的 ThreadLocalMap
    ThreadLocalMap map = getMap(t);
    // 不为空的话 set 设置进去
    if (map != null)
        map.set(this, value);
    else
        // 创建 ThreadLocalMap 并初始化值
        createMap(t, value);
    return value;
}

3.3.1  ThreadLocalMap 的 getEntry 方法

当线程中的ThreadLocalMap 不为空的情况下,会调用 ThreadLocalMap 中的 getEntry来获取,我们看下:

// ThreadLocalMap
private Entry getEntry(ThreadLocal<?> key) {
    // 获取当前 ThreadLocal 的索引位置
    int i = key.threadLocalHashCode & (table.length - 1);
    // 获取到当前位置的 Entry
    Entry e = table[i];
    // 不等于空的话,并且等于当前的 ThreadLocal 对象
    if (e != null && e.get() == key)
        return e;
    else
        // 什么清空走这个? 1、脏ThreadLocal 2、冲突了
        return getEntryAfterMiss(key, i, e);
}
// ThreadLocalMap
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    // 当前的数组
    Entry[] tab = table;
    // 数组的长度
    int len = tab.length;
    // 循环直到 Entry 为空了 退出循环
    while (e != null) {
        // 获取 Entry 的 key
        ThreadLocal<?> k = e.get();
        // 等于当前的 ThreadLocal 直接返回
        if (k == key)
            return e;
        // 脏了就清除掉继续循环
        if (k == null)
            expungeStaleEntry(i);
        else
            // 下一个索引位置继续
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

3.4  remove 方法

我们最后再来看一个remove方法:

// ThreadLocal
public void remove() {
    // 获取到当前线程的 ThreadLocalMap
    ThreadLocalMap m = getMap(Thread.currentThread());
    // 不为空的话 调用ThreadLocalMap 的 remove
    if (m != null)
        m.remove(this);
}

可以看到调用了 ThreadLocalMap 的 remove 方法,那么我们进去看看:

3.4.1  ThreadLocalMap 的 remove 方法

// ThreadLocalMap
private void remove(ThreadLocal<?> key) {
    // 当前的数组
    Entry[] tab = table;
    // 数组长度
    int len = tab.length;
    // 获取当前 ThreadLocal 对应的索引位置
    int i = key.threadLocalHashCode & (len-1);
    /**
     * 这里为什么要循环呢?
     * 因为取余会得到的可能不是当前 ThreadLocal 的
     * 因为会冲突 冲突后会往后存放,所以循环知道找到和当前的 ThreadLocal相等的
     */
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            // reference 的 弱引用置空
            e.clear();
            // 清理掉当前的 ThreadLocal 
            expungeStaleEntry(i);
            return;
        }
    }
}

remove方法比较简单,回收掉当前的 ThreadLocal占用的 Entry,其实也会进行其它 脏ThreadLocal的清理。

4  思考

4.1  大致的引用关系图

这里简单画了一下各个类型之间的引用关系,方便大家理解哈:

4.2  为什么重新扩容或者放置新的 ThreadLocal的时候都不加锁呢?也就是ThreadLocal为什么是无锁的呢?

你想从始至终我们的 ThreadLocalMap都是依附在当前线程中的,这些东西都是你自己的东西,没别的线程跟你抢,相当于 ThreadLocal对象都是对每个线程的一个副本,自己用自己的。

4.3  常说的 ThreadLocal的内存泄漏是怎么回事?

内存泄漏是什么意思:就是我们的垃圾回收不掉已经不用的空间,不再用到的内存,没有及时释放,就叫做内存泄漏。就会一直滞留在jvm中,导致可能会OOM,但是我们可以看到 set、get其实都会对脏的 ThreadLocal 进行 value强引用的解除,那么怎么还会有泄漏呢?难道是如果数据初始化好之后,一直不调用get、set等方法,这样Entry就一直不能回收,导致内存泄漏的么?

另一个理解Threadlocal本身的存活时间就比较长,当我们的线程池线程复用起来的话,每个线程往里边set完不主动remove的话是不是就内存泄漏了,一直得不到释放。

所以使用ThreadLocal会发生 内存泄漏的前提条件:

(1)线程长时间运行而没有被销毁。 线程池中的Thread实例很容易满足此条件。

(2)ThreadLocal引用被设置为null,且后续在同一Thread实例的执行期间,没有发生对其他 ThreadLocal实例的get、set或remove操作。

4.4  ThreadLocalMap中的 key 为什么是弱引用?

首先弱引用是当仅有弱引用(WeakReference)指向的对象,只能生存到下一次垃圾回收之前。换句话说,当GC发生时,不管内存够不够,仅有弱引用所指向的对象都会被回收。而拥有强引用 指向的对象,则不会被直接回收。

由于ThreadLocalMap中Entry的 Key 使用了弱引用,在下次GC发生时,就可以使那些没有被其他强引用指向、仅被Entry的Key 所指向的ThreadLocal实例能被顺利回收。并且,在Entry的Key引用被回收 之后,其Entry的Key值变为null。后续当ThreadLocal的 get 、 set 或 remove 被调用时, ThreadLocalMap的内部代码会清除这些Key为null的Entry,从而完成相应的内存释放。

举个例子:

public void funcA() {
    //创建一个线程本地变量
    ThreadLocal local = new ThreadLocal<Integer>(); 
    //设置值
    local.set(100);   
    //获取值
    local.get();  
    //函数末尾
}

当线程执行完funcA方法后,funcA的方法栈帧将被销毁,强引用 local 的值也就没有了,但此时线程 的ThreadLocalMap里的对应的Entry的 Key 引用还指向了 ThreadLocal 实例。若Entry的 Key 引用是 强引用, 就会导致Key引用指向的ThreadLocal实例、及其Value值都不能被GC回收,这将造成严重的内存泄露.

5  小结

本节我们浏览了一下 ThreadLocal 的机制和原理,看了其核心的具体实现,有理解不对的地方欢迎指正哈。

posted @ 2023-03-17 07:23  酷酷-  阅读(98)  评论(0编辑  收藏  举报