Java线程之ThreadLocal实现原理

一、简述

1.1 什么是ThreadLocal?

ThreadLocal类顾名思义可以理解为线程本地变量也有叫线程本地存储。作用是提供线程内的局部变量,也就是说如果定义了一个ThreadLocal,每个线程往这个ThreadLocal中读写是线程隔离,互相之间不会影响的。它提供了一种将可变数据通过每个线程有自己的独立副本从而实现线程封闭的机制。

1.2 实现思路

Thread类有一个类型为ThreadLocal.ThreadLocalMap的实例变量threadLocals,也就是说每个线程有一个自己的ThreadLocalMapThreadLocalMap有自己的独立实现,可以简单地将它的key视作ThreadLocalvalue为代码中放入的值(实际上key并不是ThreadLocal本身,而是它的一个弱引用)。每个线程在往某个ThreadLocal里塞值的时候,都会往自己的ThreadLocalMap里存,读也是以某个ThreadLocal作为引用,在自己的map里找对应的key,从而实现了线程隔离。

2019061906258.png

一个ThreadLocal只能存储一个Object对象,如果需要存储多个Object对象那么就需要多个ThreadLocal

2019061926179.png

二、案例

public class ThreadLocalExample {

    public static class MyRunnable implements Runnable {

        private ThreadLocal<Integer> threadLocal = new ThreadLocal<>();

        @Override
        public void run() {
            threadLocal.set((int) (Math.random() * 100D));

            try {
                Thread.sleep(2000);
            } catch (InterruptedException e) {
            }
            System.out.println(threadLocal.get());
        }
    }

    public static void main(String[] args) {
        MyRunnable sharedRunnableInstance = new MyRunnable();

        Thread thread1 = new Thread(sharedRunnableInstance);
        Thread thread2 = new Thread(sharedRunnableInstance);

        thread1.start();
        thread2.start();

        thread1.join(); //wait for thread 1 to terminate
        thread2.join(); //wait for thread 2 to terminate
    }
}

本示例创建一个传递给两个不同线程的MyRunnable实例。两个线程都执行run()方法,从而在ThreadLocal实例上设置不同的值。如果对set()调用的访问已经同步,并且它不是ThreadLocal对象,则第二个线程将覆盖第一个线程设置的值。但是由于它是一个ThreadLocal对象,因此两个线程无法看到对方的值。所以他们设定并获得不同的值。

三、源码分析

3.1 内部类ThreadLocalMap

在了解ThreadLocal的之前先来理解一下ThreadLocalMap类。

3.1.1 简述

ThreadLocalMapThreadLocal的静态内部类。虽然类名中带有map字样,但是实际上并不是Map接口的子类。ThreadLocalMap本质上是数组。

ThreadLocalMap提供了一种为ThreadLocal定制的高效实现,并且自带一种基于弱引用的垃圾清理机制。每个Thread实例对象都会维护多个ThreadLocalMap对象:

ThreadLocal.ThreadLocalMap threadLocals = null;
ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

在默认情况下,线程对象的ThreadLocalMap对象们都是未初始化的,需要使用createMap()方法去初始化:

void createMap(Thread t, T firstValue) {
    //此处ThreadLocal将自身作为key值存入了map中
    t.threadLocals = new ThreadLocalMap(this, firstValue);
}

可以想到的是,此处是为了提高线程的性能,而设计了一个懒加载(Lazy)的调用模式。但是实际上这是理想情况,对于主线程来说,CollectionsStringCoding等的工具类在jdk加载时期就会调用ThreadLocal,所以ThreadLocalMap肯定会被创建好。

3.1.2 数据结构

static class Entry extends WeakReference<ThreadLocal<?>> {
    //值写入ThreadLocal中
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

EntryThreadLocalMap的节点,继承自WeakReference类表示是弱引用,定义了一个类型为Objectvalue,用于存放ThreadLocal的值。

3.1.3 弱引用与内存泄漏

ThreadLocalMap中使用了弱引用,不过弱引用只是针对key。每个key都弱引用指向threadlocal。当把threadlocal实例置为null以后,没有任何强引用指向threadlocal实例,所以threadlocal将会被gc回收。但是,我们的value却不能回收,因为存在一条从current thread连接过来的强引用。只有当前thread结束以后,current thread就不会存在栈中,强引用断开,Current Thread,Map,value将全部被GC回收。

所以得出一个结论就是只要这个线程对象被gc回收,就不会出现内存泄漏,但在threadLocal设为null和线程结束这段时间不会被回收的,就发生了我们认为的内存泄漏。其实这是一个对概念理解的不一致,也没什么好争论的。最要命的是线程对象不被回收的情况,这就发生了真正意义上的内存泄漏。比如使用线程池的时候,线程结束是不会销毁的,会再次使用的。就可能出现内存泄漏。

下面我们用代码来验证一下,

public class ThreadPoolProblem {

    static ThreadLocal<AtomicInteger> sequencer = new ThreadLocal<AtomicInteger>() {

        @Override
        protected AtomicInteger initialValue() {
            return new AtomicInteger(0);
        }
    };

    static class Task implements Runnable {

        @Override
        public void run() {
            AtomicInteger s = sequencer.get();
            int initial = s.getAndIncrement();
            // 期望初始为0
            System.out.println(initial);
        }
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        executor.execute(new Task());
        executor.execute(new Task());
        executor.execute(new Task());
        executor.shutdown();
    }
}

对于异步任务Task而言,它期望的初始值应该总是0,但运行程序,结果却为:

0
0
1

第三次执行异步任务,结果就不对了,为什么呢?因为线程池中的线程在执行完一个任务,执行下一个任务时,其中的ThreadLocal对象并不会被清空,修改后的值带到了下一个异步任务。

那怎么办呢?有几种思路:

  • 第一次使用ThreadLocal对象时,总是先调用set设置初始值,或者如果ThreaLocal重写了initialValue方法,先调用remove;
  • 使用完ThreadLocal对象后,总是调用其remove方法;
  • 使用自定义的线程池。

3.1.4 类成员变量与相应方法

// 初始容量,必须为2的幂
private static final int INITIAL_CAPACITY = 16;

// Entry数组,大小必须为2的幂
private Entry[] table;

// entry的大小
private int size = 0;

// 重新分配entry大小的阈值,默认为0
private int threshold;

// 设置resize阈值以维持最坏2/3的装载因子
private void setThreshold(int len) {
    threshold = len * 2 / 3;
}

// 环形意义的下一个索引
private static int nextIndex(int i, int len) {
    return ((i + 1 < len) ? i + 1 : 0);
}

// 环形意义的上一个索引
private static int prevIndex(int i, int len) {
    return ((i - 1 >= 0) ? i - 1 : len - 1);
}

ThreadLocalMap维护了一个Entry数组,并且要求Entry的大小必须为2的幂,同时记录entry的个数以及下一次需要扩容的阈值。

20220705221528852.png

ThreadLocalMap维护了Entry环形数组,数组中元素Entry的逻辑上的key为某个ThreadLocal对象(实际上是指向该ThreadLocal对象的弱引用),value为代码中该线程往该ThreadLoacl变量实际塞入的值。

3.1.5 构造函数

/**
 * 构造一个包含firstKey和firstValue的map。
 * ThreadLocalMap是惰性构造的,所以只有当至少要往里面放一个元素的时候才会构建它。
 */
ThreadLocalMap(java.lang.ThreadLocal<?> firstKey, Object firstValue) {
    // 初始化table数组
    table = new Entry[INITIAL_CAPACITY];
    // 用firstKey的threadLocalHashCode与初始大小16取模得到哈希值
    int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
    // 初始化该节点
    table[i] = new Entry(firstKey, firstValue);
    // 设置节点表大小为1
    size = 1;
    // 设定扩容阈值
    setThreshold(INITIAL_CAPACITY);
}

这个构造函数在setget的时候都可能会被间接调用以初始化线程的ThreadLocalMap

3.1.6 getEntry方法

这个方法会被ThreadLocalget方法直接调用,用于获取map中某个ThreadLocal存放的值。

private Entry getEntry(ThreadLocal<?> key) {
    // 根据key这个ThreadLocal的ID来获取索引,也即哈希值
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    // 对应的entry存在且未失效且弱引用指向的ThreadLocal就是key,则命中返回
    if (e != null && e.get() == key) {
        return e;
    } else {
        // 因为用的是线性探测,所以往后找还是有可能能够找到目标Entry的。
        return getEntryAfterMiss(key, i, e);
    }
}

// 调用getEntry未直接命中的时候调用此方法
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;
   
    // 基于线性探测法不断向后探测直到遇到空entry。
    while (e != null) {
        ThreadLocal<?> k = e.get();
        // 找到目标
        if (k == key) {
            return e;
        }
        if (k == null) {
            // 该entry对应的ThreadLocal已经被回收,调用expungeStaleEntry来清理无效的entry
            expungeStaleEntry(i);
        } else {
            // 环形意义下往后面走
            i = nextIndex(i, len);
        }
        e = tab[i];
    }
    return null;
}

/**
 * 这个函数是ThreadLocal中核心清理函数,它做的事情很简单:
 * 就是从staleSlot开始遍历,将无效(弱引用指向对象被回收)清理,即对应entry中的value置为null,将指向这个entry的table[i]置为null,直到扫到空entry。
 * 另外,在过程中还会对非空的entry作rehash。
 * 可以说这个函数的作用就是从staleSlot开始清理连续段中的slot(断开强引用,rehash slot等)
 */
private int expungeStaleEntry(int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;

    // 因为entry对应的ThreadLocal已经被回收,value设为null,显式断开强引用
    tab[staleSlot].value = null;
    // 显式设置该entry为null,以便垃圾回收
    tab[staleSlot] = null;
    size--;

    Entry e;
    int i;
    for (i = nextIndex(staleSlot, len); (e = tab[i]) != null; i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();
        // 清理对应ThreadLocal已经被回收的entry
        if (k == null) {
            e.value = null;
            tab[i] = null;
            size--;
        } else {
            /*
             * 对于还没有被回收的情况,需要做一次rehash。
             * 
             * 如果对应的ThreadLocal的ID对len取模出来的索引h不为当前位置i,
             * 则从h向后线性探测到第一个空的slot,把当前的entry给挪过去。
             */
            int h = k.threadLocalHashCode & (len - 1);
            if (h != i) {
                tab[i] = null;
                
                /*
                 * 在原代码的这里有句注释值得一提,原注释如下:
                 *
                 * Unlike Knuth 6.4 Algorithm R, we must scan until
                 * null because multiple entries could have been stale.
                 *
                 * 这段话提及了Knuth高德纳的著作TAOCP(《计算机程序设计艺术》)的6.4章节(散列)
                 * 中的R算法。R算法描述了如何从使用线性探测的散列表中删除一个元素。
                 * R算法维护了一个上次删除元素的index,当在非空连续段中扫到某个entry的哈希值取模后的索引
                 * 还没有遍历到时,会将该entry挪到index那个位置,并更新当前位置为新的index,
                 * 继续向后扫描直到遇到空的entry。
                 *
                 * ThreadLocalMap因为使用了弱引用,所以其实每个slot的状态有三种也即
                 * 有效(value未回收),无效(value已回收),空(entry==null)。
                 * 正是因为ThreadLocalMap的entry有三种状态,所以不能完全套高德纳原书的R算法。
                 *
                 * 因为expungeStaleEntry函数在扫描过程中还会对无效slot清理将之转为空slot,
                 * 如果直接套用R算法,可能会出现具有相同哈希值的entry之间断开(中间有空entry)。
                 */
                while (tab[h] != null) {
                    h = nextIndex(h, len);
                }
                tab[h] = e;
            }
        }
    }
    // 返回staleSlot之后第一个空的slot索引
    return i;
}

ThreadLocal读一个值可能遇到的情况:根据入参threadLocalthreadLocalHashCode对表容量取模得到index,如果index对应的slot就是要读的threadLocal,则直接返回结果。调用getEntryAfterMiss线性探测,过程中每碰到无效slot,调用expungeStaleEntry进行段清理;如果找到了key,则返回结果entry。没有找到key,返回null

3.1.7 set方法

private void set(ThreadLocal<?> key, Object value) {

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len - 1);
    // 线性探测
    for (Entry e = tab[i]; e != null; e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();
        // 找到对应的entry
        if (k == key) {
            e.value = value;
            return;
        }
        // 替换失效的entry
        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold) {
        rehash();
    }
}

private void replaceStaleEntry(ThreadLocal<?> key, Object value,
                               int staleSlot) {
    Entry[] tab = table;
    int len = tab.length;
    Entry e;

    // 向前扫描,查找最前的一个无效slot
    int slotToExpunge = staleSlot;
    for (int i = prevIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = prevIndex(i, len)) {
        if (e.get() == null) {
            slotToExpunge = i;
        }
    }

    // 向后遍历table
    for (int i = nextIndex(staleSlot, len);
         (e = tab[i]) != null;
         i = nextIndex(i, len)) {
        ThreadLocal<?> k = e.get();

        // 找到了key,将其与无效的slot交换
        if (k == key) {
            // 更新对应slot的value值
            e.value = value;

            tab[i] = tab[staleSlot];
            tab[staleSlot] = e;

            /*
             * 如果在整个扫描过程中(包括函数一开始的向前扫描与i之前的向后扫描)
             * 找到了之前的无效slot则以那个位置作为清理的起点,
             * 否则则以当前的i作为清理起点
             */
            if (slotToExpunge == staleSlot) {
                slotToExpunge = i;
            }
            // 从slotToExpunge开始做一次连续段的清理,再做一次启发式清理
            cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
            return;
        }

        // 如果当前的slot已经无效,并且向前扫描过程中没有无效slot,则更新slotToExpunge为当前位置
        if (k == null && slotToExpunge == staleSlot) {
            slotToExpunge = i;
        }
    }

    // 如果key在table中不存在,则在原地放一个即可
    tab[staleSlot].value = null;
    tab[staleSlot] = new Entry(key, value);

    // 在探测过程中如果发现任何无效slot,则做一次清理(连续段清理+启发式清理)
    if (slotToExpunge != staleSlot) {
        cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
    }
}

/**
 * 启发式地清理slot,
 * i对应entry是非无效(指向的ThreadLocal没被回收,或者entry本身为空)
 * n是用于控制控制扫描次数的
 * 正常情况下如果log n次扫描没有发现无效slot,函数就结束了
 * 但是如果发现了无效的slot,将n置为table的长度len,做一次连续段的清理
 * 再从下一个空的slot开始继续扫描
 * 
 * 这个函数有两处地方会被调用,一处是插入的时候可能会被调用,另外个是在替换无效slot的时候可能会被调用,
 * 区别是前者传入的n为元素个数,后者为table的容量
 */
private boolean cleanSomeSlots(int i, int n) {
    boolean removed = false;
    Entry[] tab = table;
    int len = tab.length;
    do {
        // i在任何情况下自己都不会是一个无效slot,所以从下一个开始判断
        i = nextIndex(i, len);
        Entry e = tab[i];
        if (e != null && e.get() == null) {
            // 扩大扫描控制因子
            n = len;
            removed = true;
            // 清理一个连续段
            i = expungeStaleEntry(i);
        }
    } while ((n >>>= 1) != 0);
    return removed;
}

private void rehash() {
    // 做一次全量清理
    expungeStaleEntries();

    /*
     * 因为做了一次清理,所以size很可能会变小。
     * ThreadLocalMap这里的实现是调低阈值来判断是否需要扩容,
     * threshold默认为len*2/3,所以这里的threshold - threshold / 4相当于len/2
     */
    if (size >= threshold - threshold / 4) {
        resize();
    }
}

/*
 * 做一次全量清理
 */
private void expungeStaleEntries() {
    Entry[] tab = table;
    int len = tab.length;
    for (int j = 0; j < len; j++) {
        Entry e = tab[j];
        if (e != null && e.get() == null) {
            /*
             * 个人觉得这里可以取返回值,如果大于j的话取了用,这样也是可行的。
             * 因为expungeStaleEntry执行过程中是把连续段内所有无效slot都清理了一遍了。
             */
            expungeStaleEntry(j);
        }
    }
}

/**
 * 扩容,因为需要保证table的容量len为2的幂,所以扩容即扩大2倍
 */
private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; 
            } else {
                // 线性探测来存放Entry
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null) {
                    h = nextIndex(h, newLen);
                }
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

ThreadLocalset方法可能会有的情况:

探测过程中slot都不无效,并且顺利找到key所在的slot,直接替换即可。探测过程中发现有无效slot,调用replaceStaleEntry,效果是最终一定会把keyvalue放在这个slot,并且会尽可能清理无效slot
replaceStaleEntry过程中,如果找到了key,则做一个swap把它放到那个无效slot中,value置为新值。
replaceStaleEntry过程中,没有找到key,直接在无效slot原地放entry
探测没有发现key,则在连续段末尾的后一个空位置放上entry,这也是线性探测法的一部分。放完后,做一次启发式清理,如果没清理出去key,并且当前table大小已经超过阈值了,则做一次rehashrehash函数会调用一次全量清理slot方法也即expungeStaleEntries,如果完了之后table大小超过了threshold - threshold / 4,则进行扩容2倍。

3.1.8 remove方法

/**
 * 从map中删除ThreadLocal
 */
private void remove(ThreadLocal<?> key) {
    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len - 1);
    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        if (e.get() == key) {
            // 显式断开弱引用
            e.clear();
            // 进行段清理
            expungeStaleEntry(i);
            return;
        }
    }
}

remove方法相对于getEntryset方法比较简单,直接在table中找key,如果找到了,把弱引用断了做一次段清理。

3.2 set方法

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

ThreadLocalset流程:

  1. 获取当前线程t
  2. 返回当前线程t的成员变量ThreadLocalMap(以下简写map);
  3. map不为null,则更新以当前线程为keyThreadLocalMap,否则创建一个ThreadLocalMap,其中当前线程tkey

3.3 get方法

public T get() {
    // 获取当前线程
    Thread t = Thread.currentThread();
    // 返回当前线程t的成员变量ThreadLocalMap
	ThreadLocalMap map = getMap(t);
    if (map != null) {
        // 获取以当前线程为key的ThreadLocalMap的Entry    
		ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
		    // 返回该Entry的value
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

ThreadLocalget流程:

  1. 获取当前线程t
  2. 返回当前线程t的成员变量ThreadLocalMap(以下简写map);
  3. map不为null,则获取以当前线程为keyThreadLocalMapEntry(以下简写e),如果e不为null,则直接返回该Entryvalue
  4. 如果mapnull或者enull,返回setInitialValue()的值。setInitialValue()调用重写的initialValue()返回新值(如果没有重写initialValue将返回默认值null),并将新值存入当前线程的ThreadLocalMap(如果当前线程没有ThreadLocalMap,会先创建一个)。

3.4 remove方法

public void remove() {
    ThreadLocalMap m = getMap(Thread.currentThread());
    if (m != null)
        m.remove(this);
}

直接调用ThreadLocalMapremove方法。

3.5 主要的代码实现

下面代码是ThreadLocal中比较重要的,还是比较容易看懂的,就不在一一细说

public class ThreadLocal<T> {
    
    /**
     * ThreadLocals依赖于附加的每线程线性探测哈希映射到每个线程(Thread.threadLocals和inheritableThreadLocals)。
     * ThreadLocal对象充当键,通过threadLocalHashCode搜索。
     * 这是一个自定义哈希码(仅在ThreadLocalMaps内有用),可以消除哈希冲突
     * 在连续构造ThreadLocals的常见情况下由相同的线程使用,同时保持良好的行为和异常情况的发生。
     */
    private final int threadLocalHashCode = nextHashCode();

    /**
     * 下一个哈希码将被发出,原子更新,从零开始。
     */
    private static AtomicInteger nextHashCode = new AtomicInteger();

    /**
     * 连续生成的散列码之间的区别,将隐式顺序线程本地ID转换为近乎最佳的散布两倍大小的表的乘法散列值。
     */
    private static final int HASH_INCREMENT = 0x61c88647;

    /**
     * Returns the next hash code.
     */
    private static int nextHashCode() {
        return nextHashCode.getAndAdd(HASH_INCREMENT);
    }

    /**
     * 返回此线程局部变量的当前线程的“初始值”。在线程首次访问带有{@link #get}方法的变量时,将调用此方法,
     * 除非线程先前调用了{@link #set}方法,在这种情况下,initialValue方法不会为该线程调用。
     * 通常,每个线程最多调用一次此方法,但在后续调用{@link #remove}后跟{@link #get}时可能会再次调用此方法。
     * <p>这个实现只是返回{@code null};如果程序员希望线程局部变量的初始值不是{@code null},
     * 则必须对子代码{@CodeLocal}进行子类化,并重写此方法。通常,将使用匿名内部类。
     */
    protected T initialValue() {
        return null;
    }

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        return setInitialValue();
    }

    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

    public void set(T value) {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
    }

    public void remove() {
        ThreadLocalMap m = getMap(Thread.currentThread());
        if (m != null)
            m.remove(this);
    }
    
    /**
     * SuppliedThreadLocal是JDK8新增的内部类,只是扩展了ThreadLocal的初始化值的方法而已,
     * 允许使用JDK8新增的Lambda表达式赋值。需要注意的是,函数式接口Supplier不允许为null。
     */
    static final class SuppliedThreadLocal<T> extends ThreadLocal<T> {

        private final Supplier<? extends T> supplier;

        SuppliedThreadLocal(Supplier<? extends T> supplier) {
            this.supplier = Objects.requireNonNull(supplier);
        }

        @Override
        protected T initialValue() {
            return supplier.get();
        }
    }

    /**
     * ThreadLocalMap是定制的hashMap,仅用于维护当前线程的本地变量值。
     * 仅ThreadLocal类对其有操作权限,是Thread的私有属性。
     * 为避免占用空间较大或生命周期较长的数据常驻于内存引发一系列问题,
     * hash table的key是弱引用WeakReferences。当空间不足时,会清理未被引用的entry。
     */
    static class ThreadLocalMap {

        static class Entry extends WeakReference<ThreadLocal<?>> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal<?> k, Object v) {
                super(k);
                value = v;
            }
        }

        /**
         * 构造一个新的包含初始映射,ThreadLocal映射的新映射,因此我们只在创建至少一个条目时创建一个。
         */
        ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

        /**
         * 从给定的parentMap构造一个包含所有map的新ThreadLocal。仅由createInheritedMap调用。
         *
         * @param parentMap the map associated with parent thread.
         */
        private ThreadLocalMap(ThreadLocalMap parentMap) {
            Entry[] parentTable = parentMap.table;
            int len = parentTable.length;
            setThreshold(len);
            table = new Entry[len];

            for (int j = 0; j < len; j++) {
                Entry e = parentTable[j];
                if (e != null) {
                    @SuppressWarnings("unchecked")
                    ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
                    if (key != null) {
                        Object value = key.childValue(e.value);
                        Entry c = new Entry(key, value);
                        int h = key.threadLocalHashCode & (len - 1);
                        while (table[h] != null)
                            h = nextIndex(h, len);
                        table[h] = c;
                        size++;
                    }
                }
            }
        }
        //...
    }
}

四、应用场景

ThreadLocal适用于如下两种场景:

  1. 每个线程需要有自己单独的实例
  2. 实例需要在多个方法中共享,但不希望被多线程共享

对于第一点,每个线程拥有自己实例,实现它的方式很多。例如可以在线程内部构建一个单独的实例。ThreadLocal可以以非常方便的形式满足该需求。
对于第二点,可以在满足第一点(每个线程有自己的实例)的条件下,通过方法间引用传递的形式实现。ThreadLocal使得代码耦合度更低,且实现更优雅。

场景

场景1:实现单个线程单例以及单个线程上下文信息存储,比如交易id
场景2:实现线程安全,非线程安全的对象使用ThreadLocal之后就会变得线程安全,因为每个线程都会有一个对应的实例,如:数据库连接、Session管理和Spring的事务等

private static final ThreadLocal threadSession = new ThreadLocal();
 
public static Session getSession() throws InfrastructureException {
    Session s = (Session) threadSession.get();
    try {
        if (s == null) {
            s = getSessionFactory().openSession();
            threadSession.set(s);
        }
    } catch (HibernateException ex) {
        throw new InfrastructureException(ex);
    }
    return s;
}

场景3:承载一些线程相关的数据,避免在方法中来回传递参数(数据跨层传递)

每个线程内需要保存类似于全局变量的信息(例如在拦截器中获取的用户信息),可以让不同方法直接使用,避免参数传递的麻烦却不想被多线程共享(因为不同线程获取到的用户信息不一样)。

例如,用ThreadLocal保存一些业务内容(用户权限信息、从用户系统获取到的用户名、用户ID等),这些信息在同一个线程内相同,但是不同的线程使用的业务内容是不相同的。

在线程生命周期内,都通过这个静态ThreadLocal实例的get()方法取得自己set过的那个对象,避免了将这个对象(如user对象)作为参数传递的麻烦。

比如说我们是一个用户系统,那么当一个请求进来的时候,一个线程会负责执行这个请求,然后这个请求就会依次调用service1()service2()service3(),这几个方法可能是分布在不同的类中的。这个例子和存储session有些像。

public class ThreadLocalTest {
    public static void main(String[] args) {
        User user = new User("jack");
        new Service1().service1(user);
    }
}
 
class Service1 {
    public void service1(User user) {
        //给ThreadLocal赋值,后续的服务直接通过ThreadLocal获取就行了。
        UserContextHolder.holder.set(user);
        new Service2().service2();
    }
}
 
class Service2 {
    public void service2() {
        User user = UserContextHolder.holder.get();
        System.out.println("service2拿到的用户:" + user.name);
        new Service3().service3();
    }
}
 
class Service3 {
    public void service3() {
        User user = UserContextHolder.holder.get();
        System.out.println("service3拿到的用户:" + user.name);
        //在整个流程执行完毕后,一定要执行remove
        UserContextHolder.holder.remove();
    }
}
 
class UserContextHolder {
    //创建ThreadLocal保存User对象
    public static ThreadLocal<User> holder = new ThreadLocal<>();
}
 
class User {
    String name;
    public User(String name) {
        this.name = name;
    }
}
 
//执行的结果:
 
//service2拿到的用户:jack
//service3拿到的用户:jack

场景4Spring使用ThreadLocal解决线程安全问题

我们知道在一般情况下,只有无状态的Bean才可以在多线程环境下共享,在Spring中,绝大部分Bean都可以声明为singleton作用域。就是因为Spring对一些Bean(如RequestContextHolderTransactionSynchronizationManagerLocaleContextHolder等)中非线程安全的“状态性对象”采用ThreadLocal进行封装,让它们也成为线程安全的“状态性对象”,因此有状态的Bean就能够以singleton的方式在多线程中正常工作了。

一般的Web应用划分为展现层、服务层和持久层三个层次,在不同的层中编写对应的逻辑,下层通过接口向上层开放功能调用。在一般情况下,从接收请求到返回响应所经过的所有程序调用都同属于一个线程。

这样用户就可以根据需要,将一些非线程安全的变量以ThreadLocal存放,在同一次请求响应的调用线程中,所有对象所访问的同一ThreadLocal变量都是当前线程所绑定的。

五、总结

  1. 每个Thread维护着⼀个ThreadLocalMap的引⽤;
  2. ThreadLocalMapThreadLocal的内部类,⽤Entry来进⾏存储;
  3. 调⽤ThreadLocalset()⽅法时,实际上就是往ThreadLocalMap设置值,keyThreadLocal对象,value是传递进来的对象;
  4. 调⽤ThreadLocalget()⽅法时,实际上就是往ThreadLocalMap获取值,keyThreadLocal对象;
  5. ThreadLocal本身并不存储值,它只是作为⼀个key来让线程从ThreadLocalMap获取value

六、拓展

6.1 局限性

ThreadLocal为解决多线程程序的并发问题提供了一种新的思路。但是ThreadLocal也有局限性,我们来看看阿里规范:

【参考】ThreadLocal无法解决共享对象的更新问题,ThreadLocal对象建议使用static修饰。这个变量是针对一个线程内所有操作共有的,所以设置为静态变量,所有此类实例共享此静态变量,也就是说在类第一次被使用时装载,只分配一块存储空间,所有此类的对象(只要是这个线程内定义的)都可以操控这个变量。

每个线程往ThreadLocal中读写数据是线程隔离,互相之间不会影响的,所以ThreadLocal无法解决共享对象的更新问题!

由于不需要共享信息,自然就不存在竞争问题了,从而保证了某些情况下线程的安全,以及避免了某些情况需要考虑线程安全必须同步带来的性能损失!

这类场景阿里规范里面也提到了:

【强制】SimpleDateFormat是线程不安全的类,一般不要定义为static变量,如果定义为static,必须加锁,或者使用DateUtils工具类。
正例:注意线程安全,使用DateUtils。亦推荐如下处理:

private static finalThreadLocal<DateFormat> df = new ThreadLocal<DateFormat>() {
    @Override
    protected DateFormatinitialValue() {
        return new SimpleDateFormat("yyyy-MM-dd");
    }
};

说明:如果是JDK8的应用,可以使用instant代替DateLocaldatetime代替CalendarDatetimeformatter代替Simpledateformatter,官方给出的解释:simple beautiful strong immutable thread-safe。

6.2 InheritableThreadLocal

InheritableThreadLocal提供了一种父子线程之间的数据共享机制。它的具体实现是在Thread类中除了threadLocals外还有一个inheritableThreadLocals对象。

public class InheritableThreadLocalTest {
    public static void main(String[] args) {
        Thread parentThread = new Thread(() -> {
            ThreadLocal<Integer> threadLocal = new ThreadLocal<>();
            threadLocal.set(1);
            InheritableThreadLocal<Integer> inheritableThreadLocal = new InheritableThreadLocal<>();
            inheritableThreadLocal.set(2);
            new Thread(() -> {
                System.out.println("threadLocal = " + threadLocal.get());
                System.out.println("inheritableThreadLocal = " + inheritableThreadLocal.get());
            }).start();
        }, "父线程");
        parentThread.start();
    }
}

输出:

threadLocal = null
inheritableThreadLocal = 2

6.3 FastThreadLocal

FastThreadLocal的构造方法中,会为当前FastThreadLocal分配一个index,这个index是由一个全局唯一的static类型的AtomInteger产生的,可以保证每个FastThreadLocalindex都不同。

参考文章

posted @ 2022-04-25 15:32  夏尔_717  阅读(191)  评论(0编辑  收藏  举报