理解ThreadLocal
记得去年学习Spring MVC的时候自己学着写了一个小小的框架,用了一个AppContext来表示应用上下文,每个请求都应该有各自独立的AppContext,里面可以存储一些数据,比如数据库连接Connection等,此时考虑数据库的事务问题,即在一个线程内,一个事务的多个操作拿到的是一个Connection,该如何实现呢?此时就需要使用ThreadLocal来解决。
ThreadLocal介绍
ThreadLocal能干啥?
ThreadLocal是基于线程的一个本地变量的支持类,用户可以将对象与线程绑定,每一个线程都拥有一个自己的对象,例如对于上面的需求来说,可以将AppContext存入到ThreadLocal,代码如下:
public class AppContext {
private static ThreadLocal<AppContext> appContextMap = new ThreadLocal<AppContext>();
private Map<String, Object> objects = new HashMap<String, Object>();
private AppContext() {};
// 部分代码省略
public void clear() {
AppContext context = appContextMap.get();
if (context != null) {
context.objects.clear();
}
context = null;
}
public static AppContext getAppContext() {
AppContext appContext = appContextMap.get();
if (appContext == null) {
appContext = new AppContext();
appContextMap.set(appContext);
}
return appContextMap.get();
}
}
对于数据库的Connection,可以有以下实现
public Class ConnectionManager {
// 创建一个私有静态的并且是与事务相关联的局部线程变量
private static ThreadLocal<Connection> connectionHolder = new ThreadLocal<Connection>;
public static Connection getConnection() {
// 获得线程变量connectionHolder的值conn
Connection conn = connectionHolder.get();
if (conn == null){
// 如果连接为空,则创建连接,另一个工具类,创建连接
conn = DbUtil.getConnection();
// 将局部变量connectionHolder的值设置为conn
connectionHolder.set(conn);
}
return conn;
}
}
ThreadLocal原理分析
ThreadLocal有如下成员变量和方法,如下图所示
其中经常用到的是以下几个方法:
public T get() { }
public void set(T value) { }
public void remove() { }
protected T initialValue() { }
由于ThreadLocal里面需要存值和取值,又需要与线程相关,那么数据存在哪里,用哪种数据结构呢?由于Map可以存储很多类型,这里又不需要对外提供服务,所以这里就用了静态内部类的Map来搞存储,来存储真实的变量实例。
get()流程
那么, ThreadLocal是如何工作的呢?我们先从get方法看起,下面是get方法的源码
public T get() {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return setInitialValue();
}
首先获得当前线程,然后通过getMap(t)方法获取到一个map,map的类型为ThreadLocalMap。接下来根据<key,value>从map中获取Entry,注意这里获取键值对传进去的是this,而不是当前线程t。如果获取成功,则返回value值。如果map为空,则调用setInitialValue方法返回value。
getMap()方法是如何获取到ThreadLocalMap的呢?来看看源码
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
发现是直接获取当前线程的threadLocals成员变量,那么接下来就到Thread类里面去看一下
/* ThreadLocal values pertaining to this thread. This map is maintained
* by the ThreadLocal class. */
ThreadLocal.ThreadLocalMap threadLocals = null;
实际上就是ThreadLocalMap,这个类型是ThreadLocal类的一个内部类,我们来看看ThreadLocalMap内部的Entry类,源码如下:
static class Entry extends WeakReference<ThreadLocal<?>> {
/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
Entry继承自WeakReference,这里弱引用为Map的key,也就是ThreadLocal,弱引用就是只要JVM垃圾回收器发现了它,就会将之回收。
回到get()方法, 如果通过getMap()方法获取的map为空,就会调用setInitialValue() 方法,下面是该方法的源码
private T setInitialValue() {
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
首先调用initialValue() 方法进行初始化value,默认为null,接下来获取当前线程,获取map,判断map是否为空,不为空将ThreadLocal类的对象为key,设定value,为空则创建map,调用createMap(t, value)方法,createMap代码如下:
void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
set(T value)流程
接下来看看set方法如何实现的,下面是源码:
public void set(T value) {
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
首先获取当前线程,然后获取map,判断map是否为空,不为空将ThreadLocal类的对象为key,设定value,为空则创建map,调用createMap(t, value)方法。
至此,我们就可以知道大致知道ThreadLocal的工作流程:
-
Thread类中有一个成员变量属于ThreadLocalMap类(一个定义在ThreadLocal类中的内部类),它是一个Map,它的key是ThreadLocal实例对象。
-
当为ThreadLocal类的对象set值时,首先获得当前线程的ThreadLocalMap类属性,然后以ThreadLocal类的对象为key,设定value。get值时则类似。
一个线程多个ThreadLocal,如何区分?
既然ThreadLocal内部用map存储数据,一个线程可以对应多个ThreadLocal对象,那么这些ThreadLocal对象是如何区分的呢?上面只是大致分析了ThreadLocal的工作原理,并未涉及ThreadLocalMap的存值和取值,接下来我们继续来看源码
/**
* ThreadLocals rely on per-thread linear-probe hash maps attached
* to each thread (Thread.threadLocals and
* inheritableThreadLocals). The ThreadLocal objects act as keys,
* searched via threadLocalHashCode. This is a custom hash code
* (useful only within ThreadLocalMaps) that eliminates collisions
* in the common case where consecutively constructed ThreadLocals
* are used by the same threads, while remaining well-behaved in
* less common cases.
*/
private final int threadLocalHashCode = nextHashCode();
/**
* The next hash code to be given out. Updated atomically. Starts at
* zero.
*/
private static AtomicInteger nextHashCode =
new AtomicInteger();
/**
* The difference between successively generated hash codes - turns
* implicit sequential thread-local IDs into near-optimally spread
* multiplicative hash values for power-of-two-sized tables.
*/
private static final int HASH_INCREMENT = 0x61c88647;
/**
* Returns the next hash code.
*/
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
在ThreadLocal类内部定义了一个final的变量threadLocalHashCode,这个变量是干什么的?看注释,在ThreadLocalMap存储数据时,ThreadLocal对象作为key,通过threadLocalHashCode进行搜索,threadLocalHashCode通过原子类AtomicInteger,提供原子操作,由于nextHashCode为类变量,保证每次生成的hashCode都不一致,每次生成hashCode都会有HASH_INCREMENT的差值。threadLocalHashCode会在ThreadLocalMap中用到,下面继续分析。
前面分析get()流程,对于如何从ThreadLocalMap取数据并未提及,现在看看源码如何实现的:
private Entry getEntry(ThreadLocal<?> key) {
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
if (e != null && e.get() == key)
return e;
else
return getEntryAfterMiss(key, i, e);
}
通过调用ThreadLocalMap的getEntry方法,传入当前ThreadLocal对象,然后获取ThreadLocal的threadLocalHashCode, 然后通过位运算与(&) 将 threadLocalHashCode和ThreadLocal内部存储数据的table的长度减一进行位运算得到i,利用i在table中直接进行搜索。
在ThreadLocalMap如何存值?下面看ThreadLocalMap.set()源码
private void set(ThreadLocal<?> key, Object value) {
// We don't use a fast path as with get() because it is at
// least as common to use set() to create new entries as
// it is to replace existing ones, in which case, a fast
// path would fail more often than not.
Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
if (k == key) {
e.value = value;
return;
}
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
rehash();
}
在ThreadLocalMap.set()方法中,传入当前ThreadLocal对象和要存的值,然后通过位运算与(&) 将 threadLocalHashCode和ThreadLocal内部存储数据的table的长度减一进行位运算得到i,这个i在get()方法已经见过了,完全一样(不一样就出问题啦),接下来开始遍历table,判断有没有相同的key等处理,其实最核心的就是你得去new 一个entry然后设置到table数组中,就是下面这句:
tab[i] = new Entry(key, value);
ThreadLocal会有内存泄露?
看了好多博客,里面提到ThreadLocal会有内存泄露问题,因为从ThreadLocalMap的设计来看,如下图,key被设计成弱引用,一旦JVM进行GC时,这个key就没了,那么与key对应的value还存在ThreadLocalMap,ThreadLocalMap与Entry存在着强引用,GC无法回收,造成内存泄露。
当然,这些都是分析出来的,既然我们考虑到了,那么Josh Bloch 和 Doug Lea肯定也为我们考虑过了,所以这个问题在源码中已经解决了,下面来看看相关源码
在ThreadLocalMap.set()方法中,如果key为null,此时会调用 replaceStaleEntry()方法,在这个方法中进行处理
private void replaceStaleEntry(ThreadLocal<?> key, Object value,
int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
// Back up to check for prior stale entry in current run.
// We clean out whole runs at a time to avoid continual
// incremental rehashing due to garbage collector freeing
// up refs in bunches (i.e., whenever the collector runs).
int slotToExpunge = staleSlot;
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;
// Find either the key or trailing null slot of run, whichever
// occurs first
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
// If we find key, then we need to swap it
// with the stale entry to maintain hash table order.
// The newly stale slot, or any other stale slot
// encountered above it, can then be sent to expungeStaleEntry
// to remove or rehash all of the other entries in run.
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
// Start expunge at preceding stale entry if it exists
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
// If we didn't find stale entry on backward scan, the
// first stale entry seen while scanning for key is the
// first still present in the run.
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// If key not found, put new entry in stale slot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// If there are any other stale entries in run, expunge them
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
其中我们可以看到这段代码:
// If key not found, put new entry in stale slot
tab[staleSlot].value = null;
如果key找不到,那么就将value置为null,help GC。这样问题解决。
当然在resize()方法中也有同样的操作,总之都会进行处理的。
最后,我们可以调用remove()方法将相关数据移除,这个肯定就不会有内存泄露啦。
参考:
https://www.cnblogs.com/xzwblog/p/7227509.html#_label0
https://www.jianshu.com/p/ee8c9dccc953
title: 理解ThreadLocal
tags: [ThreadLocal,java]
author: Mingshan
categories: Java
date: 2018-7-1
本文来自博客园,作者:mingshan,转载请注明原文链接:https://www.cnblogs.com/mingshan/p/17793614.html