ThreadLocal源码分析

概述

ThreadLocal提供了一种线程安全的数据访问方式,每个线程中都存在一个共享变量副本,从而实现多线程状态下的线程安全。

demo

 public static void main(String[] args) {
        final ThreadLocal<Integer> MAIN = ThreadLocal.withInitial(() -> 100);

        MAIN.set(200);

        new Thread(()->{
            System.out.println(Thread.currentThread().getName() + " MAIN:" + MAIN.get());
        }).start();

        System.out.println("MAIN:" + MAIN.get());
        
        //一定要注意,当ThreadLocal不再使用时,一定要调用remove方法,以免内存泄漏
        MAIN.remove();

        System.out.println("MAIN:" + MAIN.get());

    }

运行之后,打印结果如下:

MAIN:200
Thread-0 MAIN:100

MAIN是在主线程中set的值,可以在主线程中使用get方法获得,但在线程中调用get方法,结果却还是100,这是为什么呢?这个原因其实也是为什么说ThreadLocal能被称之为线程安全的原因。下面我们就通过源码来一探究竟。

关键属性

    
    //表示当前ThreadLocal的hashCode,用于计算当前ThreadLocal在ThreadLocalMap中的索引位置
    private final int threadLocalHashCode = nextHashCode();

    // static+ AtomInteger 保证了在一台机器中每个ThreadLocal的threadLocalHashCode是唯一的
    // 被static修饰十分关键,因为一个线程在处理业务时,ThreadLocalMap会被set多个ThreadLocal,多个     
   // ThreadLocal就依靠着threadLocalHashCode进行区分
    private static AtomicInteger nextHashCode =
        new AtomicInteger();

    // 增量常量
    private static final int HASH_INCREMENT = 0x61c88647;

    //计算 ThreadLocal 的 hashCode 值(就是递增)
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

常用方法

set方法

每个线程的set方法都是串行的,因而不会有线程安全的问题。

public void set(T value) {
        //获取当前线程
        Thread t = Thread.currentThread();
        // 获取ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        // 当前的threadLocalMap之前有设置值,则直接进行设置,否则就初始化
        if (map != null)
            map.set(this, value);
        else
            //初始化threadLocalMap
            createMap(t, value);
    }

//获取线程的threadLocalMap属性
ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }

//初始化 ThreadLocalMap
 void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

get方法

get方法主要是从ThreadLocalMap中取出当前ThreadLocal存储的值。

public T get() {
        //获取当前线程
        Thread t = Thread.currentThread();
        //获取ThreadLocalMap
        ThreadLocalMap map = getMap(t);
        //map不为null时
        if (map != null) {
            //从map中取出Entry,由于ThreadLocalMap在set时解决hash冲突的策略不同,get的逻辑也不同
            ThreadLocalMap.Entry e = map.getEntry(this);
            //entry不为空的话,读取当前ThreadLocal中保存的值
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;、
                    //返回值
                return result;
            }
        }
        //若是map为空 则将当前线程的ThreadLocal初始化 并返回初始值null
        return setInitialValue();
    }


private T setInitialValue() {
        //获取初始值
        T value = initialValue();
       // 获取当前线程
        Thread t = Thread.currentThread();
       // 从当前线程中获取ThreadLocalMap
        ThreadLocalMap map = getMap(t);
       // 如果map不为null的话,直接进行set
        if (map != null)
            map.set(this, value);
        else
            //否则初始化ThreadLocalMap
            createMap(t, value);
     //返回值
        return value;
    }
//直接return null
protected T initialValue() {
        return null;
    }

remove方法

由于ThreadLocal在使用不当时可能存在内存泄漏的场景,因而,在使用完ThreadLocal使用完之后,一定要显示的调用remove方法进行清除。

    public void remove() {
         //获取当前线程绑定的ThreadLocalMap
         ThreadLocalMap m = getMap(Thread.currentThread());
        // map不为null的话  
        if (m != null)
            //从map中移除当前threadLocal对应的K-V
             m.remove(this);
     }

可以看出,无论是set、get还是remove 方法,其底层原理都比较简单,但却都包含一个共性,就是使用到了ThreadLocalMap,那么ThreadLocalMap又是一个什么东东呢?

ThreadLocalMap

ThreadLocalMap是ThreadLocal中的一个静态内部类,其本质上是一个简单的Map结构,key是ThreadLocal类型,value是ThreadLocal保存的值,底层是一个Entry类型数组组成的数据结构。Entry类的结构如下所示:

static class ThreadLocalMap {

        // Entry继承自WeakReference 因而Entry数组中每个Entry节点也是一个弱引用,当没有引用指向时,
        //会被回收
static class Entry extends WeakReference<ThreadLocal<?>> {
            // 当前ThreadLocal关联的值
            Object value;
    
            // WeakReference的引用 referent就是ThreadLocal
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        //初始容量大小
        private static final int INITIAL_CAPACITY = 16;

        // Entry数组
        private Entry[] table;

        //Entry数组大小
        private int size = 0;

        //阈值
        private int threshold; // Default to 0

        //设置阈值方法  可以看出是容量的2/3
        private void setThreshold(int len) {
            threshold = len * 2 / 3;
        }
    
       //计算索引的下一个位置
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }
    
       //计算索引的上一个位置
        private static int prevIndex(int i, int len) {
            return ((i - 1 >= 0) ? i - 1 : len - 1);
        }
    
    //构造方法
     ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

.............
    
}

ThreadLocal实现线程隔离的原理

ThreadLocal是线程安全的,主要是因为ThreadLocalMap是线程Thread的一个属性,如下所示:

threadLocals和inheritableThreadLocals分别是线程的两个属性,因而每个线程的ThreadLocals都是隔离独享的。在Thread的init方法中,父线程在创建子线程的情况下,会拷贝inheritableThreadLocals的值,但不会拷贝threadLocals的值,如下所示:

// Thread中的init方法
private void init(ThreadGroup g, Runnable target, String name,
                      long stackSize, AccessControlContext acc) {
    ...
        //当父线程的inheritableThreadLocals的值不为空时
        // 会把inheritable里面的值全部传递给子线程
        if (parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals =
               //
                ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
     ...
    }


//ThreadLocal中的createInheritedMap方法
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
        return new ThreadLocalMap(parentMap);
    }

在线程创建时,会把父线程的inheritableThreadLocals属性值进行拷贝。

set方法

  private void set(ThreadLocal<?> key, Object value) {
           //保存entry数组
            Entry[] tab = table;
           // 获取数组长度
            int len = tab.length;
           // 计算索引下标位置
            int i = key.threadLocalHashCode & (len-1);
            
      // 整体策略:查看 i 索引位置有没有值,有值的话,索引位置 + 1,直到找到没有值的位置
    // 这种解决 hash 冲突的策略,也导致了其在 get 时查找策略有所不同,体现在 getEntryAfterMiss 中
            for (Entry e = tab[i];
                 e != null;
                 // nextIndex 就是让在不超过数组长度的基础上,把数组的索引位置 + 1
                 e = tab[i = nextIndex(i, len)]) {
                //获取threadLocal
                ThreadLocal<?> k = e.get();
                
                //如果二者相等 直接替换并返回
                if (k == key) {
                    e.value = value;
                    return;
                }
                
                //如果为空 说明当前的threadLocal被清理了,直接替换并返回
                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
      
            //当前i位置没有值的话,直接生成一个Entry
            tab[i] = new Entry(key, value);
            // 维护size
            int sz = ++size;
            // 当数组大小大于等于扩容阈值(数组大小的三分之二)时,进行扩容
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }

需要注意几点的是:

  • 通过递增的AtomicInteger作为ThreadLocal的hashCode的;
  • 通过计算hashCode计算的索引位置i处,如果已经有值的话,会从i开始,通过+1,不断往后寻找,直到找到索引位置为空的地方,把当前ThreadLocal作为key放进去;

getEntry方法

private Entry getEntry(ThreadLocal<?> key) {
            // 计算索引值
            int i = key.threadLocalHashCode & (table.length - 1);
            // 获取索引处的entry
            Entry e = table[i];
           // e 不为空,并且 e 的 ThreadLocal 的内存地址和 key 相同,直接返回,否则就是没有找到,
           //继续通过 getEntryAfterMiss 方法找
            if (e != null && e.get() == key)
                return e;
            else
                // 这个取数据的逻辑,是因为 set 时数组索引位置冲突造成的
                return getEntryAfterMiss(key, i, e);
        }


  private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
           //暂存entry数组
            Entry[] tab = table;
           // 获取数组长度
            int len = tab.length;
      
           //遍历数组
            while (e != null) {
                // 内存地址一样,表示找到了 直接返回即可
                ThreadLocal<?> k = e.get();
                if (k == key)
                    return e;
                //如果为null的话 删除没用的key
                if (k == null)
                    expungeStaleEntry(i);
                else
                    //否计算出下一个索引的位置
                    i = nextIndex(i, len);
                //继续下一次遍历
                e = tab[i];
            }
          //如果最后entry数组遍历结束都没有找到,直接返回null
            return null;
        }

resize方法

        //扩容
        private void resize() {
            //暂存旧的entry数组
            Entry[] oldTab = table;
            //旧数组的容量
            int oldLen = oldTab.length;
            //新数组的容量
            int newLen = oldLen * 2;
            //初始化一个空的新数组
            Entry[] newTab = new Entry[newLen];
            //记录数量
            int count = 0;
            
            //开始复制
            for (int j = 0; j < oldLen; ++j) {
                //旧数组的一个节点
                Entry e = oldTab[j];
                //不为null的话开始进行复制
                if (e != null) {
                    //获取threadLocal
                    ThreadLocal<?> k = e.get();
                    //为null的话直接进行清除
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        //计算threadLocal在新entry中的索引位置
                        int h = k.threadLocalHashCode & (newLen - 1);
                        //如果该位置以及有值了,那么就寻找下一个索引位置,直到为空
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        //将值拷贝到新数组的位置
                        newTab[h] = e;
                        //更新数量
                        count++;
                    }
                }
            }
            
            //重新设置扩容时的阈值,新数组长度的2/3
            setThreshold(newLen);
            //维护size大小
            size = count;
            //将entry数组的引用指向新数组
            table = newTab;
        }

扩容时的逻辑也比较清晰:

  • 扩容时新数组大小为原来的两倍;
  • 扩容时没有线程安全问题,因为ThreadLocalMap是线程本身的一个属性,一个线程同一时刻只能对ThreadLocalMap进行操作,因为同一个线程执行业务逻辑时必然是串行的,那么操作ThreadLocalMap必然也是串行的;

ThreadLocal内存泄漏原因探究

在demo演示中我们提到过,在使用ThreadLocal的过程中,如果使用不当,最后可能会导致内存泄漏的问题,那么原因是什么呢?

先来明确一下内存泄漏的概念:

内存泄漏主要有两种情况,一是堆中申请的空间没有被释放;二是对象已不再被使用,但仍在内存中保留着。

先来看一下ThreadLocal和当前Thread在堆栈中的布局图吧。

注:上面的连接线中,实线代表强引用,虚线代表弱引用。

上面我们已经说过了,ThreadLocalMap中的Entry静态类继承了WeakReference类,它的Key是ThreadLocal对象的弱引用。那什么是弱引用呢?根据《深入理解Java虚拟机 第二版》中的定义:

弱引用是用来描述非必需对象的,被弱引用关联的对象只能生存到下一次垃圾回收发生之前。当垃圾回收器工作时,无论当前内存是否足够,都会回收掉只被弱引用关联的对象。

一般情况下,当我们不再使用threadLocal变量时,会手动将该变量置为null,这样堆中的threadLocal实例对象将不会再被任何强引用所指向,这样垃圾回收器就可以对其进行回收。此时,根据上图我们可知,在垃圾回收之后,ThreadLocalMap中的Entry的key已经变成了null。但是,如果此时线程还存活着继续运行,则key为null,但value指向Object(存在着强引用关系)的Entry对象仍然不会被回收,此时就会发生内存泄漏。当然,如果线程在完成任务之后就结束了生命周期,那么随后ThreadLocalMap和Entry也会随之消亡。但如果使用的是线程池,线程在完成任务之后会回放到线程池中从而继续被复用,那么此时value就会一直存在,导致内存泄漏。

那么问题来了,既然弱引用可能导致内存泄漏,那么改为强引用呢?

还是跟上面一样,一起来分析一下。当手动将threadLocal置为null时,虽然threadLocal Ref—ThreadLocal实例之间没有引用关系,但Entry中key与ThreadLocal之间仍然存在着强引用关系,就会产生key和value都不为null的Entry对象,但是ThreadLocal我们明明已经不需要了,且只有线程一直运行下去,那么threadLocal实例还是无法被回收,这样还是会发生内存泄漏

因而,虽然弱引用同样也会导致内存泄漏的问题,但ThreadLocal的set、get以及remove操作都会清除ThreadLocalMap中Entry数组中key为null的Entry,从而降低出现内存泄漏的风险。

posted @ 2020-07-19 21:29  Reecelin  阅读(236)  评论(0编辑  收藏  举报