JDK源码之ThreadLocal 类分析

一 概述

ThreadLocal类提供了线程局部 (thread-local) 变量。这些变量与普通变量不同,每个线程都可以通过其 get 或 set方法来访问自己的独立初始化的变量副本
ThreadLocal 实例通常是类中的 private static 字段,它们希望将状态与某一个线程(例如,用户 ID 或事务 ID)相关联,类里面定义了一个map,key为ThreadLocal,value为值,存储每一个线程的变量
Thread类里面引用了这个内部类的map实例,从而达到线程隔离

二 源码分析

属性

        // 获取下一个hashCode值
        private final int threadLocalHashCode = nextHashCode();

        // 获取下一个hashCode,ThreadLocal 中使用了斐波那契散列法,来保证哈希表的离散度
        private static int nextHashCode() {
            // 每一次获取值时,加上HASH_INCREMENT为下一次获取的值
            return nextHashCode.getAndAdd(HASH_INCREMENT);
        }
        // 开始为0,每次创建ThreadLocal实例,值都会累加
        private static AtomicInteger nextHashCode = new AtomicInteger();

        // 加数值: 0x61c88647 = 2^32 * 黄金分割比(0.618),
        // 斐波那契数列: 当n趋向于无穷大时,前一项与后一项的比值越来越逼近黄金比,即0.618
        private static final int HASH_INCREMENT = 0x61c88647;
        
        //提供给子类初始化值使用
        protected T initialValue() {
            return null;
        }

核心方法

        // 通过Supplier函数初始化变量值,使用子类构造器进行初始化
        public static <S> ThreadLocal<S> withInitial(Supplier<? extends S> supplier) {
            return new ThreadLocal.SuppliedThreadLocal<>(supplier);
        }

        // ThreadLocal扩展子类,接收一个Supplier函数进行值初始化
        static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {
            private final Supplier<? extends T> supplier;
            SuppliedThreadLocal(Supplier<? extends T> supplier) {
                this.supplier = Objects.requireNonNull(supplier);
            }
            // 覆写父类方法
            @Override
            protected T initialValue() {
                return supplier.get();
            }
        }

        public ThreadLocal() {}

        //  返回当前线程的变量副本值
        public T get() {
            Thread t = Thread.currentThread();
            // 获取线程实例引用的ThreadLocalMap
            ThreadLocal.ThreadLocalMap map = getMap(t);
            if (map != null) {
                ThreadLocal.ThreadLocalMap.Entry e = map.getEntry(this);
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    T result = (T)e.value;
                    return result;
                }
            }
            // 值为空,重新初始化
            return setInitialValue();
        }

        ThreadLocal.ThreadLocalMap getMap(Thread t) { return t.threadLocals; }

        private T setInitialValue() {
            T value = initialValue();
            Thread t = Thread.currentThread();
            ThreadLocal.ThreadLocalMap map = getMap(t);
            // 赋值
            if (map != null)
                map.set(this, value);
            else
                createMap(t, value);
            return value;
        }

        void createMap(Thread t, T firstValue) {
            t.threadLocals = new ThreadLocal.ThreadLocalMap(this, firstValue);
        }

        // 设置ThreadLocal的值,还是设置的thread实例引用的map
        public void set(T value) {
            Thread t = Thread.currentThread();
            ThreadLocal.ThreadLocalMap map = getMap(t);
            if (map != null)
                map.set(this, value);
            else
                createMap(t, value);
        }

        //清除值,key是ThreadLocal,所以直接使用this
        public void remove() {
            ThreadLocal.ThreadLocalMap m = getMap(Thread.currentThread());
            if (m != null)
                // 调用map的remove方法
                m.remove(this);
        }

        // Thread 类里面创建线程时候初始化init方法调用的,主要是复制参数中的table构建一个新map返回
        static ThreadLocal.ThreadLocalMap createInheritedMap(ThreadLocal.ThreadLocalMap parentMap) {
            return new ThreadLocal.ThreadLocalMap(parentMap);
        }

        // 只允许子类调用此方法,ThreadLocal调用直接抛异常
        T childValue(T parentValue) { throw new UnsupportedOperationException(); }

三 静态内部类ThreadLocalMap

ThreadLocal中的静态内部类ThreadLocalMap,这个类本质上是一个map,和HashMap之类的实现相似,依然是key-value的形式,其中有一个内部类Entry,其中key可以看做是ThreadLocal实例,
在ThreadLocal中并没有对于ThreadLocalMap的引用,ThreadLocalMap的引用在Thread类中
每个线程在向ThreadLocal里塞值的时候,其实都是向自己所持有的ThreadLocalMap里塞入数据;
读的时候同理,首先从自己线程中取出自己持有的ThreadLocalMap,然后再根据ThreadLocal引用作为key取出value,
基于以上描述,ThreadLocal实现了变量的线程隔离

Entry

            /**
             * Entry继承WeakReference,(弱引用:当一个对象仅仅被weak reference指向, 而没有任何其他strong reference指向的时候, 如果GC运行, 那么这个对象就会被回收)
             * 并且用ThreadLocal作为key.如果key为null
             * (entry.get() == null)表示key不再被引用,表示ThreadLocal对象被回收
             * 因此这时候entry也可以从table从清除。
             */
            static class Entry extends WeakReference<ThreadLocal<?>> {
                Object value;
                Entry(ThreadLocal<?> k, Object v) {
                    super(k);
                    value = v;
                }
            }

            // 初始容量
            private static final int INITIAL_CAPACITY = 16;

            // 存放数据的数组
            private Entry[] table;

            //  数组里面entrys的个数,可以用于判断table当前使用量是否超过负因子
            private int size = 0;

            // 进行扩容的阈值,表使用量大于它的时候进行扩容
            private int threshold; // Default to 0

            // 设置阈值为参数的三分之二
            private void setThreshold(int len) { threshold = len * 2 / 3;}

核心方法

            /**
             * ThreadLocalMap使用线性探测法来解决哈希冲突,线性探测法的地址增量di = 1, 2, ... , m-1,其中,i为探测次数。
             * 该方法一次探测下一个地址,直到有空的地址后插入,若整个空间都找不到空余的地址,则产生溢出。
             * 假设当前table长度为16,也就是说如果计算出来key的hash值为14,如果table[14]上已经有值,并且其key与当前key不一致,那么就发生了hash冲突,
             * 这个时候将14加1得到15,取table[15]进行判断,这个时候如果还是冲突会继续下一个,如果是最后的则重新回到0,取table[0],以此类推,直到可以插入。
             * 可以把table看成一个环形数组
             */
            // 获取环形数组的下一个索引
            private static int nextIndex(int i, int len) {
                return ((i + 1 < len) ? i + 1 : 0);
            }

            // 获取环形数组的上一个索引
            private static int prevIndex(int i, int len) {
                return ((i - 1 >= 0) ? i - 1 : len - 1);
            }

            // 构造器初始化
            ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
                // 初始化table
                table = new Entry[INITIAL_CAPACITY];
                // 计算索引, & (INITIAL_CAPACITY - 1),这是取模的一种方式,对于2的幂作为模数取模,用此代替%(2^n),目的是均匀分布在数组中
                int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
                // 设置值
                table[i] = new Entry(firstKey, firstValue);
                size = 1;
                // 设置阈值
                setThreshold(INITIAL_CAPACITY);
            }

            /**
             * 构建一个包含所有parentMap中Inheritable ThreadLocals的ThreadLocalMap返回
             * 该函数只被 createInheritedMap() 调用.即只在Thread类的init方法里面调用(init方法调用了createInheritedMap()),目的是将父线程的变量值复制到当前线程中
             */
            private ThreadLocalMap(ThreadLocalMap parentMap) {
                Entry[] parentTable = parentMap.table;
                int len = parentTable.length;
                setThreshold(len);
                table = new Entry[len];
                // 逐一复制 parentMap 的记录到当前线程中
                for (int j = 0; j < len; j++) {
                    Entry e = parentTable[j];
                    if (e != null) {
                        //此处获取的都是InheritableThreadLocal类,ThreadLocal的子类,用于父线程变量传递的类
                        ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                        if (key != null) {
                            // 如果是ThreadLocal,就会抛异常
                            Object value = key.childValue(e.value);
                            Entry c = new Entry(key, value);
                            int h = key.threadLocalHashCode & (len - 1);
                            while (table[h] != null)
                                h = nextIndex(h, len);
                            table[h] = c;
                            size++;
                        }
                    }
                }
            }


            private Entry getEntry(ThreadLocal<?> key) {
                // 计算数组下标
                int i = key.threadLocalHashCode & (table.length - 1);
                Entry e = table[i];
                if (e != null && e.get() == key)
                    return e;
                else
                    // e不为空,但是key不相等时候再找
                    return getEntryAfterMiss(key, i, e);
            }

            private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
                Entry[] tab = table;
                int len = tab.length;
                // 线性探测法,循环查找
                while (e != null) {
                    ThreadLocal<?> k = e.get();
                    if (k == key)
                        // 找到就返回
                        return e;
                    if (k == null)
                        // 如果key为null,则清除掉无效的entry
                        expungeStaleEntry(i);
                    else
                        // 环形向后扫描
                        i = nextIndex(i, len);
                    e = tab[i];
                }
                return null;
            }


            private void set(ThreadLocal<?> key, Object value) {
                Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);
                /**
                 * 根据获取到的索引进行循环,如果当前索引上的table[i]不为空,在没有return的情况下,
                 * 就使用nextIndex()获取下一个地址,即线性探测法。
                 */
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {
                    ThreadLocal<?> k = e.get();
                    if (k == key) {
                        // key相同则更新value,set成功
                        e.value = value;
                        return;
                    }
                    /**
                     * table[i]上的key为空,说明被回收了
                     * 这个时候说明改table[i]可以重新使用,用新的key-value将其替换,并删除其他无效的entry
                     */
                    if (k == null) {
                        replaceStaleEntry(key, value, i);
                        return;
                    }
                }

                tab[i] = new Entry(key, value);
                int sz = ++size;
                if (!cleanSomeSlots(i, sz) && sz >= threshold)
                    rehash();
            }

            /**
             * Remove the entry for key.
             */
            private void remove(ThreadLocal<?> key) {
               Entry[] tab = table;
                int len = tab.length;
                int i = key.threadLocalHashCode & (len-1);
                for (Entry e = tab[i];
                     e != null;
                     e = tab[i = nextIndex(i, len)]) {
                    if (e.get() == key) {
                        e.clear();
                        expungeStaleEntry(i);
                        return;
                    }
                }
            }

            //key-value替换 staleSlot位置的值
            private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                                           int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;
                Entry e;
                /**
                 * 根据传入的无效entry的位置(staleSlot),向前扫描,当前已经是无效的,可能前面也有无效的,找到最开始无效的一个进行替换,维护线性探测法
                 * 一段连续的entry(这里的连续是指一段相邻的entry并且table[i] != null),
                 * 直到找到一个无效entry,或者扫描完也没找到
                 */
                int slotToExpunge = staleSlot;
                for (int i = prevIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = prevIndex(i, len))
                    if (e.get() == null)
                        slotToExpunge = i;

                /**
                 * 向后扫描一段连续的entry
                 */
                for (int i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {
                    ThreadLocal<?> k = e.get();

                    /**
                     * 如果找到了key,直接替换,也就是与table[staleSlot]进行替换,此时staleSlot位置已经替换了对应key的值
                     */
                    if (k == key) {
                        e.value = value;
                        tab[i] = tab[staleSlot];
                        tab[staleSlot] = e;

                        //如果向前查找没有找到无效entry,则更新slotToExpunge为当前值i
                        if (slotToExpunge == staleSlot)
                            slotToExpunge = i;
                        // 此时,staleSlot位置已经设置值了,应该从 slotToExpunge 位置开始往后清除
                        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
                        return;
                    }
                    /**
                     * 如果向前查找没有找到无效entry,并且当前向后扫描的entry无效,则更新slotToExpunge为当前值i
                     */
                    if (k == null && slotToExpunge == staleSlot)
                        slotToExpunge = i;
                }

                /**
                 * 如果没有找到key,也就是说key之前不存在table中
                 * 就直接最开始的无效entry——tab[staleSlot]上直接新增即可
                 */
                tab[staleSlot].value = null;
                tab[staleSlot] = new Entry(key, value);
                /**
                 * slotToExpunge != staleSlot,说明存在其他的无效entry需要进行清理。
                 */
                if (slotToExpunge != staleSlot)
                    cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            }

            /**
             * 连续段清除
             * 根据传入的staleSlot,清理对应的无效entry——table[staleSlot],
             * 并且根据当前传入的staleSlot,向后扫描一段连续的entry(这里的连续是指一段相邻的entry并且table[i] != null),
             * 对可能存在hash冲突的entry进行rehash,并且清理遇到的无效entry.
             * @param staleSlot key为null,需要无效entry所在的table中的索引
             * @return 返回下一个为空的solt的索引。
             */
            private int expungeStaleEntry(int staleSlot) {
                Entry[] tab = table;
                int len = tab.length;

                tab[staleSlot].value = null;
                tab[staleSlot] = null;
                size--;

               Entry e;
                int i;
                for (i = nextIndex(staleSlot, len);
                     (e = tab[i]) != null;
                     i = nextIndex(i, len)) {
                    ThreadLocal<?> k = e.get();
                    // key为null,直接清理
                    if (k == null) {
                        e.value = null;
                        tab[i] = null;
                        size--;
                    } else {
                        /**
                         * 计算出来的索引h,与其现在所在位置的索引——i不一致,置空当前的table[i]
                         * 从h开始向后线性探测到第一个空的slot,把当前的entry挪过去。
                         */
                        int h = k.threadLocalHashCode & (len - 1);
                        if (h != i) {
                            tab[i] = null;
                            while (tab[h] != null)
                                h = nextIndex(h, len);
                            tab[h] = e;
                        }
                    }
                }
                return i;
            }

            /**
             * 启发式的扫描清除,扫描次数由传入的参数n决定
             * @param i 从i向后开始扫描(不包括i,因为索引为i的Slot肯定为null)
             * @param n 控制扫描次数,正常情况下为 log2(n) ,
             * 如果找到了无效entry,会将n重置为table的长度len,进行段清除。
             * map.set()点用的时候传入的是元素个数,replaceStaleEntry()调用的时候传入的是table的长度len
             */
            private boolean cleanSomeSlots(int i, int n) {
                boolean removed = false;
                Entry[] tab = table;
                int len = tab.length;
                do {
                    i = nextIndex(i, len);
                   Entry e = tab[i];
                    if (e != null && e.get() == null) {
                        n = len;
                        removed = true;
                        i = expungeStaleEntry(i);
                    }
                //无符号的右移动,可以用于控制扫描次数在log2(n)
                } while ( (n >>>= 1) != 0);
                return removed;
            }

            private void rehash() {
                expungeStaleEntries();
                /**
                 * threshold = 2/3 * len
                 * 所以threshold - threshold / 4 = 1en/2
                 * 这里主要是因为上面做了一次全清理所以size减小,需要进行判断。
                 * 判断的时候把阈值调低了。
                 */
                if (size >= threshold - threshold / 4)
                    resize();
            }

            /**
             * 扩容,扩大为原来的2倍(这样保证了长度为2的冥)
             */
            private void resize() {
                Entry[] oldTab = table;
                int oldLen = oldTab.length;
                int newLen = oldLen * 2;
                Entry[] newTab = new Entry[newLen];
                int count = 0;

                for (int j = 0; j < oldLen; ++j) {
                    Entry e = oldTab[j];
                    if (e != null) {
                        ThreadLocal<?> k = e.get();
                        if (k == null) {
                            e.value = null; // Help the GC
                        } else {
                            int h = k.threadLocalHashCode & (newLen - 1);
                            while (newTab[h] != null)
                                h = nextIndex(h, newLen);
                            newTab[h] = e;
                            count++;
                        }
                    }
                }

                setThreshold(newLen);
                size = count;
                table = newTab;
            }

            // 清空全部Entry
            private void expungeStaleEntries() {
                Entry[] tab = table;
                int len = tab.length;
                for (int j = 0; j < len; j++) {
                    Entry e = tab[j];
                    if (e != null && e.get() == null)
                        expungeStaleEntry(j);
                }
            }
posted @ 2020-02-07 16:18  侯小厨  阅读(273)  评论(0编辑  收藏  举报
Fork me on Gitee