只是不愿随波逐流 ...|

lidongdongdong~

园龄:2年7个月粉丝:14关注:8

43、ThreadLocal

内容来自王争 Java 编程之美

对于无锁编程,我们已经讲解了:CAS、原子类、累加器,本节我们讲解无锁编程中的最后一个知识点:ThreadLocal

我们知道,共享变量是代码存在线程安全的根本原因之一
在某些特殊的业务场景下,我们可以使用 ThreadLocal 线程局部变量替代共享变量,以实现在不需要加锁的情况下达到线程安全

1、基本用法

在 Java 中,我们可以将变量粗略的划分为两类:类的成员变量、函数内局部变量

  • 对于类的成员变量:当多个线程使用同一个对象时,对象中的成员变量就是共享变量,其作用域范围为多个线程均可见,多个线程竞态访问成员变量,就有可能存在线程安全问题
  • 对于函数内局部变量:每个线程在执行函数时,会在自己的栈上创建私有的局部变量,因此函数内局部变量的作用域范围为单线程内可见
    不仅如此,函数内局部变量仅限函数内可见,不同的函数之间不可以共享局部变量,多个函数共享局部变量,需要通过参数传递的方式来实现,我们举例解释一下

1.1、示例

假设我们现在有这样一个需求:在一个标准的 Controller-Service-Repository 三层结构的后端系统中,我们希望实现一个简单的调用链追踪功能
每个接口请求所对应的所有日志都附带同一个 traceId,这样我们通过 traceId 便可以轻松得到一个接口请求的所有日志,方便通过日志查找代码问题
以下是调用链追踪功能的一种实现方式,traceId 定义为函数内的局部变量,需要通过参数传递给被调用函数使用

public class UserController {
private static final Logger logger = LoggerFactory.getLogger(UserController.class);
private UserService userService = new UserService();
public long login(String username, String password) {
// 创建 traceId
String traceId = "[" + System.currentTimeMillis() + "]";
// 所有的日志都带有 traceId
logger.trace(traceId + " username=" + username);
// ... 省略校验逻辑 ...
return userService.login(username, password, traceId); // 传递 traceId
}
}

上述代码存在的问题很明显,我们需要在每个函数中都定义 traceId 参数,导致非业务代码和业务代码耦合在一起

1.2、ThreadLocal 介绍

为了解决这个问题,JUC 便提供了 ThreadLocal,其作用域范围介于类的成员变量和函数内局部变量之间
它既是线程私有的,又可以在函数之间共享,这样既可以避免线程安全问题,又能避免变量在函数之间不停传递

ThreadLocal 的提供的函数如下所示

  • set() 函数用来将变量值存储在当前线程中
  • get() 函数用来从当前线程中取出变量值
  • remove() 函数用来从当前线程删除变量
  • initialValue() 函数是一个 protected 方法,可以在子类中重新实现,用于提供变量的初始值
public class ThreadLocal<T> {
protected T initialValue() {
return null;
}
public T get();
public void set(T value);
public void remove();
}

1.3、解决

我们使用 ThreadLocal 重新实现调用链追踪功能,对应的代码如下所示
我们实现了一个匿名类,继承自 ThreadLocal,重写了 initialValue() 来提供 threadLocalTraceId 的初始值
如果在调用 get() 函数之前没有调用 set() 函数设置 threadLocalTraceId 的值,ThreadLocal 会调用我们提供的 initialValue() 函数,使用其返回值初始化 threadLocalTraceId
对于详细的处理逻辑,我们待会结合源码来讲解

public class Context {
private static final ThreadLocal<String> threadLocalTraceId =
new ThreadLocal<String>() {
@Override
protected String initialValue() {
return "[" + System.currentTimeMillis() + "]";
}
};
public static void setTraceId(String traceId) {
threadLocalTraceId.set(traceId);
}
public static String getTraceId() {
return threadLocalTraceId.get();
}
public static void remove() {
threadLocalTraceId.remove();
}
}
public class UserController {
private static final Logger logger = LoggerFactory.getLogger(UserController.class);
private UserService userService = new UserService();
public long login(String username, String password) {
// 所有的日志都带有 traceId
logger.trace(Context.getTraceId() + " username=" + username);
// ... 省略校验逻辑 ...
return userService.login(username, password); // 通过 Context 传递 traceId
}
}

2、实现原理

很多人对 ThreadLocal 的用法都感到奇怪,往往会有这样的疑问
使用 ThreadLocal 定义的一个变量怎么存储多个线程的数据?或者说在类中的定义的 ThreadLocal 变量如何 "分身" 到多个线程中使用的?
有这样疑问的主要原因是对 ThreadLocal 的底层实现原理没有了解,为了回答以上疑问,接下来我们就结合源码来了解一下 ThreadLocal 的底层实现原理

2.1、ThreadLocal 源码

ThreadLocal 的代码结构如下所示,从如下代码中,我们可以发现,ThreadLocal 类只定义了读写数据的方法,并没有定义任何成员变量来存储数据
那么 set() 函数写入的数据存储在哪里呢?get() 函数又是从哪里读取数据的呢?

public class ThreadLocal<T> {
public ThreadLocal() {
}
protected T initialValue() {
return null;
}
public T get() {
// ...
}
public void set(T value) {
// ...
}
public void remove() {
// ...
}
public static class ThreadLocalMap {
// ...
}
}
public class Thread implements Runnable {
// ...
ThreadLocal.ThreadLocalMap threadLocals = null;
// ...
}

2.2、图示

要回答以上问题,我们需要了解 ThreadLocal 的数据存储结构,如下图所示
从图中我们可以发现,实际上数据是存储在线程对应的 Thread 对象的 threadLocals 成员变量中的
threadLocals 成员变量的类型为 ThreadLocalMap,ThreadLocalMap 是 ThreadLocal 的内部类
ThreadLocalMap 类似 HashMap,也是用来存储键值对,其中键为 ThreadLocal 对象,值为 Object 对象
image

接下来,我们依次讲解一下 ThreadLocal 中的 set()、get()、remove() 三个函数的底层实现原理

2.3、set()

set() 函数的源码如下所示,对 ThreadLocal 的存储结构有了了解之后,我们便很容易看懂 set() 函数的代码逻辑
对于 set() 函数的代码逻辑,我在代码中添加了详细的注释,这里就不再赘述了

public void set(T value) {
Thread t = Thread.currentThread(); // 获取当前线程对应的 Thread 对象
ThreadLocalMap map = getMap(t); // 获取 Thread 对象的 threadLocals 成员变量
if (map != null) map.set(this, value); // threadLocals 不为空, 则添加键值对
else createMap(t, value); // threadLocals 为空, 则先创建再添加
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

2.4、get()

get() 函数的源码如下所示,get() 函数的逻辑也比较简单,我们先获取当前线程的 threadLocals 成员变量

  • 如果 threadLocals 不为 null
    那么我们在 threadLocals 中查找当前操作的 ThreadLocal 变量对应的数据值,如果查找到对应的数据值,则直接返回
  • 如果 threadLocals 为 null,或者没有查找 threadLocal 变量对应的数据值
    则调用 initialValue() 方法获取到 threadLocal 变量的初始值,创建 threadLocals 并添加键值对(threadLocal 变量和初始值)
public T get() {
Thread t = Thread.currentThread(); // 获取当前线程对应的 Thread 对象
ThreadLocalMap map = getMap(t); // 获取 Thread 对象的 threadLocals 成员变量
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this); // this 是 ThreadLocal 变量
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T) e.value;
return result; // 获取到对应的数据值
}
}
// 要么 map 为 null, 要么没有获取到对应的数据值, 执行初始化操作
return setInitialValue();
}
private T setInitialValue() {
T value = initialValue(); // 默认返回 null, 但可重新实现此函数, 见上述 Context 示例
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) map.set(this, value); // 添加键值对
else createMap(t, value); // 创建 threadLocals
return value;
}

2.5、remove()

remove() 函数的源码如下所示,相对来说更加简单,对其代码逻辑,我们就不再赘述了

public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null) m.remove(this);
}

3、应用场景

前面我们讲到的 ReentrantReadWriteLock,其底层实现也用到了 ThreadLocal
我们一块回忆一下,对于 ReentrantReadWriteLock,AQS 中的 state 同时存储写锁和读锁的加锁情况

  • state 的低 16 位存储写锁的加锁情况:值为 0 表示没有加写锁,值为 1 表示已加写锁,值大于 1 表示写锁的重入次数
  • state 的高 16 位存储读锁的加锁情况:值为 0 表示没有加读锁,值为 1 表示已加读锁
    不过值大于 1 并不表示读锁的重入次数,而是表示读锁总共被获取了多少次(每个线程对读锁重入的次数相加),此值用来最终解锁读锁
  • 而每个线程对读锁的重入次数是有用信息,只有重入次数大于 0 时,线程才可以继续重入,那么重入次数在哪里记录呢?
    因为重入次数是跟每个线程相关的数据,所以我们就可以使用 ThreadLocal 变量来存储它,对应的源码如下所示
// 以下代码位于 ReentrantReadWriteLock.java 中
static final class HoldCounter {
int count = 0;
final long tid = getThreadId(Thread.currentThread());
}
static final class ThreadLocalHoldCounter extends ThreadLocal<HoldCounter> {
public HoldCounter initialValue() {
return new HoldCounter();
}
}
private transient ThreadLocalHoldCounter readHolds;

4、课后思考题

在本节的示例中,我们使用时间戳来生成 traceId,如果我们更改 traceId 的生成方式,使用自增的方式生成 traceId,如下代码所示
那么请问下面的代码是否是线程安全的?如果不是,应该如何修改才能保证其线程安全性?

public class Context {
private static int id = 0;
private static final ThreadLocal<String> threadLocalTraceId =
new ThreadLocal<String>() {
@Override
protected String initialValue() {
id++;
return String.valueOf(id);
}
};
public static void setTraceId(String traceId) {
threadLocalTraceId.set(traceId);
}
public static String getTraceId() {
return threadLocalTraceId.get();
}
public static void remove() {
threadLocalTraceId.remove();
}
}

代码非线程安全,因为当两个线程同时调用 Context 上的 getTraceId() 函数时,两个线程会调用 initialValue() 方法初始化 ThreadLocal 的 value 值
而两个线程竞争对共享资源 id 进行非原子的 id++ 自增操作,存在线程安全问题

posted @   lidongdongdong~  阅读(34)  评论(0编辑  收藏  举报
点击右上角即可分享
微信分享提示
评论
收藏
关注
推荐
深色
回顶
展开