代码改变世界

ThreadLocal 源码分析

2021-07-04 15:20  wang03  阅读(423)  评论(1编辑  收藏  举报

1、ThreadLocal 源码分析

  1. 在多线程开发中,我们经常会使用ThreadLocal来避免共享变量的竞争,提高效率。ThreadLocal底层到底是怎么实现的呢,今天就带大家一起来看看它底层实现。另外也会随便分析下网上讨论比较多的关于ThreadLocal内存泄漏等等究竟是怎么一回事

    我本地的jdk版本是11.0.8,不同版本的jdk,threadLocal源码实现可能有差别,不过大致是一样的。

  2. 首先看下我们一般都是怎么使用ThreadLocal的

​ 这是一段使用threadLocal的demo代码

public class ThreadLocalDemo {
    public final static ThreadLocal<String> threadLocal = new ThreadLocal<>() {
        @Override
        protected String initialValue() {
            return "initValue";
        }
    };

    public static void main(String[] args) throws InterruptedException {
        new Thread(() -> {
            System.out.println("init Value =" + threadLocal.get());
            threadLocal.set("abc");
            System.out.println("执行其他逻辑");
            String str = threadLocal.get();
            System.out.println(str);
            threadLocal.remove();
        }).start();
        Thread.currentThread().join();
    }
}

​ 这段代码比较简单,不用具体说threadLocal的这些方法都是干啥的了。我们直接顺着main方法里面的调用顺序一起去看看这些调用背后都是怎么实现的。

  • 如果直接把类的源码粘上来,做分析,感觉太零散了,看不到方法之间的调用关系,所以,这次就准备按照调用逻辑,一步一步来分析了


  • 首先我们main方法是在第11行调用了threadLocal.get(),这是我们第一次主动调用threadLocal的地方,那我们先从这里进去

    
//这就是ThreadLocal的get方法了,我们泛型参数是String,所以这里的T也就是String了,下面我们一行一行根据我们上面的Demo来分析下这个代码
    public T get() {
        //这一行就不用说了,Thread.currentThread()就是获取当前线程
        Thread t = Thread.currentThread();
        
        //getMap(t)这行是干什么呢?我们这个方法的代码比较简单,我就直接粘贴到下面了 
        
    	//ThreadLocalMap getMap(Thread t) {
        //	return t.threadLocals;
    	//}
        //上面这3行就是getMap(t)调用的代码了,比较简单,获取线程t上的threadLocals属性,这个属性是什么东西呢?我们再看看这个属性
        
        //ThreadLocal.ThreadLocalMap threadLocals = null;
        //上面这一行就是在Thread类中定义的threadLocals属性了,看样子是ThreadLocal类中的内部类ThreadLocalMap的一个实例,具体是啥,我们先不仔细看了,继续看后面代码吧  
        ThreadLocalMap map = getMap(t);
        
        //这里返回的map当前是null,因为Thread的threadLocals,从来没有初始化过
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        //因为map==null,所以会调用到这个setInitialValue这个方法,从这里返回,我们继续看看这个方法吧
        return setInitialValue();
    }

	
    private T setInitialValue() {
        //这里的initialValue方法是protected的,默认返回null,由于我们上面Demo第4行重写了initialValue方法,所以这里的调用就是我们上面的代码,这里的返回应该是上面我们Demo第5行的"initValue"
        T value = initialValue();
        //这几行代码和上面get方法是一致的
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        //map==null
        if (map != null) {
            map.set(this, value);
        } else {
            //会走到这里来,我们去这个createMap方法看看
            createMap(t, value);
        }
        //这里的this是通过匿名继承的ThreadLocal的,不会走到这个instanceof内部去
        if (this instanceof TerminatingThreadLocal) {
            TerminatingThreadLocal.register((TerminatingThreadLocal<?>) this);
        }
        //我们上面的Demo第4行的threadLocal.get()调用最终就会从这里返回,返回值是"initValue"
        return value;
    }
	//这里就是初始化Thread的threadLocals属性了,后面这个属性就不是null了,这里创建了个ThreadLocalMap对象,有两个参数,第一个this就是我们Demo中调用threadLocal.get()方法的对象,也就是Demo第4行的threadLocal对象,firstValue就是上面的字符串"initValue"。
//我们进到ThreadLocalMap构造方法去看看
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }
//
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
        //这里首先会创建一个Entry的数组,INITIAL_CAPACITY = 16;也就是这里创建一个大小是16的Entry数组,Entry是啥,我们下面再看,先看看这个构造方法的其他几行代码
            table = new Entry[INITIAL_CAPACITY];
            //threadLocalHashCode是ThreadLocal的成员变量,
            
            // private final int threadLocalHashCode = nextHashCode();
            //上面是它的定义,从上面可以看到每次创建ThreadLocal对象时就会初始化threadLocalHashCode,它的值是通过静态方法nextHashCode()赋值的,每调用一次nextHashCode返回值就在上次的基础上增加0x61c88647。
            //这里的i就是threadLocalHashCode的值和(INITIAL_CAPACITY - 1)进行与运算,这里的(INITIAL_CAPACITY - 1)值是15,计算结果在0-15之间,作为数组的下标
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            //这里创建个Entry对象,赋值给table中下标为i的
            table[i] = new Entry(firstKey, firstValue);
            //这里的size是table数组中元素的个数
            size = 1;
            //这里是设置threshold = len * 2 / 3;当size的值超过threshold时,table数组就会扩容成原来数组的2倍
            setThreshold(INITIAL_CAPACITY);
        }
//这里我们看看Entry对象
//这就是Entry的源码了,更简单。继承了WeakReference对象,这里会把构造方法的ThreadLocal入参包装成弱引用。具体啥是弱引用,下面粘贴上一段《深入理解java虚拟机》上面的描述。放到我们这里来说,就是ThreadLocal变量在没有其他地方引用(只有在Entry这里有引用),当下次垃圾回收的时候,就会被回收掉
        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;
			//
            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

下面是《深入理解java虚拟机》上面的关于引用的描述


到这里上面Demo第4行的源码就全部看完了,下面简单总结下执行逻辑。

  1. 首先进入ThreadLocal的get方法,获取当前线程的threadLocals变量,我们这个变量没有初始化过,所以这个变量为空,继续执行setInitialValue()方法,并从这里返回
  2. 在setInitialValue方法中首先调用我们Demo中重写ThreadLocal的initialValue方法,获取返回值。
  3. 继续调用createMap方法,
    • 创建ThreadLocalMap对象,内部创建Entry数组,将threadLocal对象包装成弱引用及initialValue方法的返回值创建Entry对象,填充到Entry数组数组中
    • 将创建的ThreadLocalMap对象赋值给当前线程的threadLocals变量
  4. 将2中获取的值返回

现在我们看下Demo中的第12行threadLocal.set("abc")

//这就是ThreadLocal的set方法了,和get方法差不多
    public void set(T value) {
        //获取当前线程
        Thread t = Thread.currentThread();
        //获取当前线程的threadLocals的属性,在上面get方法已经设置过这个属性了 ,所以这里不为空了
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            //map不为null,就会走到这里了,这里的this就是我们Demo中继承ThreadLocal的匿名内部类,value就是"abc",下面我们重点去看下这个方法
            map.set(this, value);
        } else {
            createMap(t, value);
        }
    }
   
   //这个方法是ThreadLocal的内部类ThreadLocalMap的。
   private void set(ThreadLocal<?> key, Object value) {

            // We don't use a fast path as with get() because it is at
            // least as common to use set() to create new entries as
            // it is to replace existing ones, in which case, a fast
            // path would fail more often than not.
			//这个是ThreadLocalMap构造方法创建的table,是Entry数组
            Entry[] tab = table;
            int len = tab.length;
       		//这个是获取根据key计算在数组中的下标,考虑到有可能两个key计算出来的是同一个i,所以数组中下标i不一定就是我们需要的key,会从当前i的下标向后遍历。同样根据key获取值的时候,也会有类似情况
            int i = key.threadLocalHashCode & (len-1);
			//获取下标i的值Entry,进行遍历,直到entry.key==我们的入参key或者entry==null或者entry.key==null(这个场景就是key其他地方没有引用了,只有Entry有对应的弱引用,在下次垃圾回收后,entry.key就会==null)的情况下,结束循环
       
       //注意:这里不会出现数组中所有下标都满了,且entry.k!=key的场景,因为下面的后面会判断size>=threshold,进行扩容
            for (Entry e = tab[i];
                 e != null;
                 //考虑到可能会有冲突,也就是上面说的两个key计算出来的是同一个i,所以key有可能不存在对应下标i的位置
                 //nextIndex(i, len)向后获取下一个位置是循环的,如果达到i==len-1;这时i=0;就会从数组头元素开始,后面几个方法说的向前遍历,向后遍历都是类似的
                 e = tab[i = nextIndex(i, len)]) {
                //获取Entry中的key,这里是个弱引用,通过get方法获取弱引用的实际对象。
                ThreadLocal<?> k = e.get();
				//找到了对应的key,就设置新值,返回
                if (k == key) {
                    e.value = value;
                    return;
                }
				//k==null,这个场景说了,就是外部没有对ThreadLocal对象的其他引用了(强引用),GC释放后,k就是null,jiu就会走到这里
                if (k == null) {
                    //这个方法会从当前下标i开始,删除过期的entry,重新设置新的Entry到i的位置
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }
			//走到这里,这时下标i对应的Entry已经是null了
            tab[i] = new Entry(key, value);
       		//数组中元素个数+1
            int sz = ++size;
       		//这里会移除一些过期的entry,判断sz>= threshold,如果成立,就扩容数组
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }


//这个方法主要是

 private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                       int staleSlot) {
     //这里的for循环都不会陷入死循环,因为上面有 sz >= threshold
     	//获取Entry数组
            Entry[] tab = table;
            int len = tab.length;
            Entry e;

            // Back up to check for prior stale entry in current run.
            // We clean out whole runs at a time to avoid continual
            // incremental rehashing due to garbage collector freeing
            // up refs in bunches (i.e., whenever the collector runs).
            int slotToExpunge = staleSlot;
     //从当前入参staleSlot开始向前遍历,直到下标i对应的Entry==null
            for (int i = prevIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = prevIndex(i, len))
                //随着i向前遍历,slotToExpunge的值会逐步更新,直到staleSlot和它之前第一个tab[i]==null之间,e.get()==null的下标,如果threadLocal对象外部一直有引用,那e.get()==null,就不会为null,也不会走到这个if分支
                if (e.get() == null)
                    slotToExpunge = i;

            // Find either the key or trailing null slot of run, whichever
            // occurs first
     //从当前入参staleSlot开始向后遍历,直到下标i对应的Entry==null
            for (int i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                //获取到对应的Entry的threaLocal变量
                ThreadLocal<?> k = e.get();

                // If we find key, then we need to swap it
                // with the stale entry to maintain hash table order.
                // The newly stale slot, or any other stale slot
                // encountered above it, can then be sent to expungeStaleEntry
                // to remove or rehash all of the other entries in run.
                //如果k==key,这时就是更新value就可以了
                if (k == key) {
                    e.value = value;
					//这里会进行元素交换,把找到的Entry换到入参staleSlot的位置
                    //注意当前staleSlot元素位置是过期的,需要清理的,交换后i下标位置的元素就是需要清理的
                    tab[i] = tab[staleSlot];
                    tab[staleSlot] = e;

                    // Start expunge at preceding stale entry if it exists
                    //这个条件成立的话,需要第一个for循环中if分支没有进入,也就是不存在staleSlot和它之前第一个tab[i]==null之间,e.get()==null
                    if (slotToExpunge == staleSlot)
                        //走到这里说明staleSlot之前到tab[i]==null之间没有无效元素,我们上面进行了元素交换,这时i就是第一个过期的元素了
                        slotToExpunge = i;
                    //清理无效的元素,slotToExpunge就是staleSlot之前到tab[i]==null到之后到tab[i]==null这段元素之间第一个过期元素的下标位置
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                    return;
                }

                // If we didn't find stale entry on backward scan, the
                // first stale entry seen while scanning for key is the
                // first still present in the run.
                //k==null说明当前数组已经有entry失效了,slotToExpunge == staleSlot说明staleSlot之前没有失效的,这时就要清理后面的过期元素
                if (k == null && slotToExpunge == staleSlot)
                    slotToExpunge = i;
            }

            // If key not found, put new entry in stale slot
     		//入参的staleSlot下标已经个个过期的值,将value设置为null,重新赋值个新的entry
            tab[staleSlot].value = null;
            tab[staleSlot] = new Entry(key, value);

            // If there are any other stale entries in run, expunge them
            if (slotToExpunge != staleSlot)
                //走到这里说明staleSlot向前遍历或者向后遍历中出现了k==null,这时需要清理过期的entry
                cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
        }
//从staleSlot位置开始,直到数组中entry==null。清理在这过程中数组过期的元素,并调整那些元素有效,但是下标位置不正确的元素位置
        private int expungeStaleEntry(int staleSlot) {
            Entry[] tab = table;
            int len = tab.length;

            // expunge entry at staleSlot
            //清空对应staleSlot对应的元素,size-1
            tab[staleSlot].value = null;
            tab[staleSlot] = null;
            size--;

            // Rehash until we encounter null
            Entry e;
            int i;
            //在staleSlot开始向后遍历,直到数组中entry==null
            for (i = nextIndex(staleSlot, len);
                 (e = tab[i]) != null;
                 i = nextIndex(i, len)) {
                ThreadLocal<?> k = e.get();
                //如果k==null,就清空元素,size-1
                if (k == null) {
                    e.value = null;
                    tab[i] = null;
                    size--;
                } else {
                    //h是算法计算出来元素应该在数组中的下标位置
                    int h = k.threadLocalHashCode & (len - 1);
                    //i是元素的直接存储下标,由于碰撞的原因元组有可能不是存在算法计算出来的下标位置
                    if (h != i) {
                        //h!=i说明元素根据算法出来的下标和实际存储下标不一致,这时由于我们清空了一些过期元素,这时就需要重新调整这些有效的,算法计算出来的下标,和实际存储下标不一致的元素位置
                        tab[i] = null;
						
                        // Unlike Knuth 6.4 Algorithm R, we must scan until
                        // null because multiple entries could have been stale.
                        //从h开始向后遍历,找到数组中的null位置,将当前下标i位置上的元素设置到对应位置
                        while (tab[h] != null)
                            h = nextIndex(h, len);
                        tab[h] = e;
                    }
                }
            }
            //这里的i就是for循环中退出条件,staleSlot开始向后遍历,数组中第一个entry==null的位置
            return i;
        }
//从i开始清理无效的元素     
//注意下标i的位置不是无效元素,要不下标i位置==null,要不是有效元素
		private boolean cleanSomeSlots(int i, int n) {
            boolean removed = false;
            Entry[] tab = table;
            int len = tab.length;
            do {
                //获取i后面的位置
                i = nextIndex(i, len);
                Entry e = tab[i];
                //如果i位置的元素是过期的,就执行清理
                if (e != null && e.get() == null) {
                    n = len;
                    removed = true;
                    //这里上面说的清理过期元素,调整位置不正确的有效元素的位置
                    i = expungeStaleEntry(i);
                }
                //这里的n>>>=1主要是扫描次数,应该是出于效率的考虑
            } while ( (n >>>= 1) != 0);
            return removed;
        }
//这就是扩容了      
		private void rehash() {
            //这里首先清理数组中所有的过期元素,同时会调整不正确元素的下标
            expungeStaleEntries();

            // Use lower threshold for doubling to avoid hysteresis
            //由于我们上面会清理过期元素,所以size有可能变小,有可能就不需要扩容了,所以这里重新判断是否需要扩容,如果需要就进行扩容
            if (size >= threshold - threshold / 4)
                resize();
        }

//这里就是扩容了,数组长度变成原来的2倍,重新设置threshold和size
        private void resize() {
            Entry[] oldTab = table;
            int oldLen = oldTab.length;
            //长度设置成原来的2被
            int newLen = oldLen * 2;
            Entry[] newTab = new Entry[newLen];
            int count = 0;

            for (Entry e : oldTab) {
                if (e != null) {
                    ThreadLocal<?> k = e.get();
                    //继续清理扩容过程中过期的元素
                    if (k == null) {
                        e.value = null; // Help the GC
                    } else {
                        //在新数组中设置元素位置
                        int h = k.threadLocalHashCode & (newLen - 1);
                        while (newTab[h] != null)
                            h = nextIndex(h, newLen);
                        newTab[h] = e;
                        count++;
                    }
                }
            }

            setThreshold(newLen);
            size = count;
            table = newTab;
        }

上面就是Demo第12行 threadLocal.set("abc")方法涉及的所有代码了 。

下面总结下:

1.获取当前线程的threadLocals变量,如果不存在就调用createMap(t, value),创建并设置值,我们Demo第11行threadLocal.get()就走的这里

2.如果当前线程的threadLocals变量存在,就调用map.set(this, value)设置值

3.在获取值的过程中首先根据调用者threadLocal对象计算出应该存储在数组中的下标,

  • 如果当前下标对应的数组元素是null,就新生成Entry元素放入数组下标i的位置,数组size+1,判断是否需要扩容,如果需要就对数组进行扩容
  • 如果当前下标对应的数组元素不是null,就获取对应位置的元素进行判断,如果entry.key == 我们调用的threalocal对象,就更新value返回。如果entry.key==null,说明元素释放了就调用replaceStaleEntry进行处理。如果这两种情况都不是,那就继续从当前下标开始向后遍历,继续判断处理

最后看下Demo中第16行threadLocal.remove();

//获取当前线程的threadLocals变量,如果之前调用过get,set方法,那这时这个变量就不是null了
     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null) {
             //走到这里看看移除元素的方法,这里的this,是我们方法的调用者,也就是threadLocal实例对象
             m.remove(this);
         }
     }
        //这个方法就比较简单了,计算threadLocal下标,从这个位置开始向后遍历,对应元素(这里使用的是key==,所以找到的肯定是同一个对象),将Entry中引用的threadLocal对象清空,再执行清理过期元素的动作
        
        private void remove(ThreadLocal<?> key) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);
            for (Entry e = tab[i];
                 e != null;
                 e = tab[i = nextIndex(i, len)]) {
                if (e.get() == key) {
                    e.clear();
                    expungeStaleEntry(i);
                    return;
                }
            }
        }


上面remove方法也讲完了,这个比较简单。就是找到t.threadLocals中对应的元素Entry,调用e.clear()清理掉entry引用的threadLocal变量(这时,通过e.get()获取的threadLocal元素就是null了),然后调用expungeStaleEntry执行过期元素的清理。


2、日常使用注意

上面就是我们日常使用threadLocal中方法的源码了 ,通过上面代码对于调用的内部细节我们也基本看到了,下面说一些日常使用过程中需要注意的地方。

它们的结构简单画个图,就是下面的样子了

文章开始的Demo只是一些日常使用,通过上面的源码阅读,我们其实还是可以看到一些其他的使用方式和注意地方

  • 我们一般是在多个线程中是使用同一个threadLocal对象,其实我们也可以在一个线程中使用多个threadLocal对象(同样多个线程也就可以使用多个threadLocal对象),就像下面这样
       //创建多个ThreadLocal对象
		ThreadLocal<String> threadLocal0 = new ThreadLocal<>();
        ThreadLocal<String> threadLocal1 = new ThreadLocal<>();
        new Thread(() -> {
            //在一个线程中使用多个threadLocal对象
            //通过上面源码我们看到threadLocal对象只是用来寻找当前线程中threadLocals变量中数组的位置,并读取或者设置值。由于查找元素下标用的是==,所以无论怎么找到的都是同一个对象实例
            threadLocal0.set("th0");
            threadLocal1.set("th1");
            System.out.println(threadLocal1.get());
            System.out.println(threadLocal0.get());
        }).start();
        Thread.currentThread().join();
  • 内存释放相关

    上面的源码分析部分,我们看到了存储到threadLocals中Entry的threadLocal变量是弱引用了,而弱引用在下次gc的时候就会被清理,threadLocal被清理了,上面的清理过期元素的方法就会把对应Entry进行清理。但是这里有个前提,threadLocal只有弱引用的时候才会被清理,如果有强引用存在,就不会被清理。

    看下面的例子

        ThreadLocal<String> threadLocal0 = new ThreadLocal<>();
        new Thread(() -> {
         	//我们第一行定义的地方,threadLocal对象是个强引用,只要这个强引用存在,threadLocals中对应的Entry就不会被清理
            threadLocal0.set("th0");
            //这个也很好理解,如果被清理掉了,那我们这里的get方法就获取不到值了,我们见过之前set后,对应get获取不到值的情况吗,肯定没有么
            System.out.println(threadLocal0.get());
            //所以我们想要清理存放的元素Entry就需要调用threadLocal0.remove()进行清理,或者是调用 threadLocal0=null;方法,这样当前 threadLocal0就只有线程threadLocals中对应的弱引用,这时,gc才会清理掉entry.k,对应元素才会在清理元素方法中清理掉
        }).start();
        Thread.currentThread().join();

另外由于ThreadLocal的操作都是基于当前线程的threadLocals变量的,如果当前线程不存在了,被清理掉了 ,只要threadLocals变量和内部Entry的key和value,如果没有其他地方进行引用,也都会被gc清理掉。

以上就是关于ThreadLocal的全部内容了。


另外父线程和子线程之间传递参数可以通过inheritableThreadLocals,这个变量的设置实在Thread的构造方法中,和threadLocals一样,都是ThreadLocal.ThreadLocalMap的实例。这个用的不多,这里就不说了。