43、ThreadLocal
对于无锁编程,我们已经讲解了:CAS、原子类、累加器,本节我们讲解无锁编程中的最后一个知识点:ThreadLocal
我们知道,共享变量是代码存在线程安全的根本原因之一
在某些特殊的业务场景下,我们可以使用 ThreadLocal 线程局部变量替代共享变量,以实现在不需要加锁的情况下达到线程安全
1、基本用法
在 Java 中,我们可以将变量粗略的划分为两类:类的成员变量、函数内局部变量
- 对于类的成员变量:当多个线程使用同一个对象时,对象中的成员变量就是共享变量,其作用域范围为多个线程均可见,多个线程竞态访问成员变量,就有可能存在线程安全问题
- 对于函数内局部变量:每个线程在执行函数时,会在自己的栈上创建私有的局部变量,因此函数内局部变量的作用域范围为单线程内可见
不仅如此,函数内局部变量仅限函数内可见,不同的函数之间不可以共享局部变量,多个函数共享局部变量,需要通过参数传递的方式来实现,我们举例解释一下
1.1、示例
假设我们现在有这样一个需求:在一个标准的 Controller-Service-Repository 三层结构的后端系统中,我们希望实现一个简单的调用链追踪功能
每个接口请求所对应的所有日志都附带同一个 traceId,这样我们通过 traceId 便可以轻松得到一个接口请求的所有日志,方便通过日志查找代码问题
以下是调用链追踪功能的一种实现方式,traceId 定义为函数内的局部变量,需要通过参数传递给被调用函数使用
public class UserController { private static final Logger logger = LoggerFactory.getLogger(UserController.class); private UserService userService = new UserService(); public long login(String username, String password) { // 创建 traceId String traceId = "[" + System.currentTimeMillis() + "]"; // 所有的日志都带有 traceId logger.trace(traceId + " username=" + username); // ... 省略校验逻辑 ... return userService.login(username, password, traceId); // 传递 traceId } }
上述代码存在的问题很明显,我们需要在每个函数中都定义 traceId 参数,导致非业务代码和业务代码耦合在一起
1.2、ThreadLocal 介绍
为了解决这个问题,JUC 便提供了 ThreadLocal,其作用域范围介于类的成员变量和函数内局部变量之间
它既是线程私有的,又可以在函数之间共享,这样既可以避免线程安全问题,又能避免变量在函数之间不停传递
ThreadLocal 的提供的函数如下所示
- set() 函数用来将变量值存储在当前线程中
- get() 函数用来从当前线程中取出变量值
- remove() 函数用来从当前线程删除变量
- initialValue() 函数是一个 protected 方法,可以在子类中重新实现,用于提供变量的初始值
public class ThreadLocal<T> { protected T initialValue() { return null; } public T get(); public void set(T value); public void remove(); }
1.3、解决
我们使用 ThreadLocal 重新实现调用链追踪功能,对应的代码如下所示
我们实现了一个匿名类,继承自 ThreadLocal,重写了 initialValue() 来提供 threadLocalTraceId 的初始值
如果在调用 get() 函数之前没有调用 set() 函数设置 threadLocalTraceId 的值,ThreadLocal 会调用我们提供的 initialValue() 函数,使用其返回值初始化 threadLocalTraceId
对于详细的处理逻辑,我们待会结合源码来讲解
public class Context { private static final ThreadLocal<String> threadLocalTraceId = new ThreadLocal<String>() { @Override protected String initialValue() { return "[" + System.currentTimeMillis() + "]"; } }; public static void setTraceId(String traceId) { threadLocalTraceId.set(traceId); } public static String getTraceId() { return threadLocalTraceId.get(); } public static void remove() { threadLocalTraceId.remove(); } }
public class UserController { private static final Logger logger = LoggerFactory.getLogger(UserController.class); private UserService userService = new UserService(); public long login(String username, String password) { // 所有的日志都带有 traceId logger.trace(Context.getTraceId() + " username=" + username); // ... 省略校验逻辑 ... return userService.login(username, password); // 通过 Context 传递 traceId } }
2、实现原理
很多人对 ThreadLocal 的用法都感到奇怪,往往会有这样的疑问
使用 ThreadLocal 定义的一个变量怎么存储多个线程的数据?或者说在类中的定义的 ThreadLocal 变量如何 "分身" 到多个线程中使用的?
有这样疑问的主要原因是对 ThreadLocal 的底层实现原理没有了解,为了回答以上疑问,接下来我们就结合源码来了解一下 ThreadLocal 的底层实现原理
2.1、ThreadLocal 源码
ThreadLocal 的代码结构如下所示,从如下代码中,我们可以发现,ThreadLocal 类只定义了读写数据的方法,并没有定义任何成员变量来存储数据
那么 set() 函数写入的数据存储在哪里呢?get() 函数又是从哪里读取数据的呢?
public class ThreadLocal<T> { public ThreadLocal() { } protected T initialValue() { return null; } public T get() { // ... } public void set(T value) { // ... } public void remove() { // ... } public static class ThreadLocalMap { // ... } } public class Thread implements Runnable { // ... ThreadLocal.ThreadLocalMap threadLocals = null; // ... }
2.2、图示
要回答以上问题,我们需要了解 ThreadLocal 的数据存储结构,如下图所示
从图中我们可以发现,实际上数据是存储在线程对应的 Thread 对象的 threadLocals 成员变量中的
threadLocals 成员变量的类型为 ThreadLocalMap,ThreadLocalMap 是 ThreadLocal 的内部类
ThreadLocalMap 类似 HashMap,也是用来存储键值对,其中键为 ThreadLocal 对象,值为 Object 对象
接下来,我们依次讲解一下 ThreadLocal 中的 set()、get()、remove() 三个函数的底层实现原理
2.3、set()
set() 函数的源码如下所示,对 ThreadLocal 的存储结构有了了解之后,我们便很容易看懂 set() 函数的代码逻辑
对于 set() 函数的代码逻辑,我在代码中添加了详细的注释,这里就不再赘述了
public void set(T value) { Thread t = Thread.currentThread(); // 获取当前线程对应的 Thread 对象 ThreadLocalMap map = getMap(t); // 获取 Thread 对象的 threadLocals 成员变量 if (map != null) map.set(this, value); // threadLocals 不为空, 则添加键值对 else createMap(t, value); // threadLocals 为空, 则先创建再添加 } ThreadLocalMap getMap(Thread t) { return t.threadLocals; } void createMap(Thread t, T firstValue) { t.threadLocals = new ThreadLocalMap(this, firstValue); }
2.4、get()
get() 函数的源码如下所示,get() 函数的逻辑也比较简单,我们先获取当前线程的 threadLocals 成员变量
- 如果 threadLocals 不为 null
那么我们在 threadLocals 中查找当前操作的 ThreadLocal 变量对应的数据值,如果查找到对应的数据值,则直接返回 - 如果 threadLocals 为 null,或者没有查找 threadLocal 变量对应的数据值
则调用 initialValue() 方法获取到 threadLocal 变量的初始值,创建 threadLocals 并添加键值对(threadLocal 变量和初始值)
public T get() { Thread t = Thread.currentThread(); // 获取当前线程对应的 Thread 对象 ThreadLocalMap map = getMap(t); // 获取 Thread 对象的 threadLocals 成员变量 if (map != null) { ThreadLocalMap.Entry e = map.getEntry(this); // this 是 ThreadLocal 变量 if (e != null) { @SuppressWarnings("unchecked") T result = (T) e.value; return result; // 获取到对应的数据值 } } // 要么 map 为 null, 要么没有获取到对应的数据值, 执行初始化操作 return setInitialValue(); } private T setInitialValue() { T value = initialValue(); // 默认返回 null, 但可重新实现此函数, 见上述 Context 示例 Thread t = Thread.currentThread(); ThreadLocalMap map = getMap(t); if (map != null) map.set(this, value); // 添加键值对 else createMap(t, value); // 创建 threadLocals return value; }
2.5、remove()
remove() 函数的源码如下所示,相对来说更加简单,对其代码逻辑,我们就不再赘述了
public void remove() { ThreadLocalMap m = getMap(Thread.currentThread()); if (m != null) m.remove(this); }
3、应用场景
前面我们讲到的 ReentrantReadWriteLock,其底层实现也用到了 ThreadLocal
我们一块回忆一下,对于 ReentrantReadWriteLock,AQS 中的 state 同时存储写锁和读锁的加锁情况
- state 的低 16 位存储写锁的加锁情况:值为 0 表示没有加写锁,值为 1 表示已加写锁,值大于 1 表示写锁的重入次数
- state 的高 16 位存储读锁的加锁情况:值为 0 表示没有加读锁,值为 1 表示已加读锁
不过值大于 1 并不表示读锁的重入次数,而是表示读锁总共被获取了多少次(每个线程对读锁重入的次数相加),此值用来最终解锁读锁 - 而每个线程对读锁的重入次数是有用信息,只有重入次数大于 0 时,线程才可以继续重入,那么重入次数在哪里记录呢?
因为重入次数是跟每个线程相关的数据,所以我们就可以使用 ThreadLocal 变量来存储它,对应的源码如下所示
// 以下代码位于 ReentrantReadWriteLock.java 中 static final class HoldCounter { int count = 0; final long tid = getThreadId(Thread.currentThread()); } static final class ThreadLocalHoldCounter extends ThreadLocal<HoldCounter> { public HoldCounter initialValue() { return new HoldCounter(); } } private transient ThreadLocalHoldCounter readHolds;
4、课后思考题
在本节的示例中,我们使用时间戳来生成 traceId,如果我们更改 traceId 的生成方式,使用自增的方式生成 traceId,如下代码所示
那么请问下面的代码是否是线程安全的?如果不是,应该如何修改才能保证其线程安全性?
public class Context { private static int id = 0; private static final ThreadLocal<String> threadLocalTraceId = new ThreadLocal<String>() { @Override protected String initialValue() { id++; return String.valueOf(id); } }; public static void setTraceId(String traceId) { threadLocalTraceId.set(traceId); } public static String getTraceId() { return threadLocalTraceId.get(); } public static void remove() { threadLocalTraceId.remove(); } }
代码非线程安全,因为当两个线程同时调用 Context 上的 getTraceId() 函数时,两个线程会调用 initialValue() 方法初始化 ThreadLocal 的 value 值
而两个线程竞争对共享资源 id 进行非原子的 id++ 自增操作,存在线程安全问题
本文来自博客园,作者:lidongdongdong~,转载请注明原文链接:https://www.cnblogs.com/lidong422339/p/17490839.html
【推荐】国内首个AI IDE,深度理解中文开发场景,立即下载体验Trae
【推荐】编程新体验,更懂你的AI,立即体验豆包MarsCode编程助手
【推荐】抖音旗下AI助手豆包,你的智能百科全书,全免费不限次数
【推荐】轻量又高性能的 SSH 工具 IShell:AI 加持,快人一步