聊聊ThreadLocal

      首先说说与ThreadLocal相识的背景。在项目中,service有一些逻辑处理{如对主从延迟敏感的下单逻辑}需要强制走主库的。就查了查公司数据库框架zebra强制走主库的方式,发现其主要就是在业务线程的context中写入了一个ThreadLocal<Boolean>变量,当要进行SQL路由时,根据此变量来判断是否需要强制走主库。

首先,如上图所示,Java内存模型JMM如上图所示,每个Java通过自己的缓存与主内存相连接。当需要读写某对象或变量时,需要到主内存copy到线程内,操作完再写回主内存。那么,这些非原子性的操作就很容易引起数据不一致的问题。ThreadLocal就是用解决此问题的,顾名思义,ThreadLocal就是线程局部变量的意思,其使用方式与调用普通POJO数据对象无异,即使用get/set来进行赋值或读取。然而其在存储上,却隐含了线程私有存储之意,线程之间相互隔离,互不影响,从而实现线程安全。先来个使用样例,直接给出zebra的强制走主库的工具类实现:

public class LocalContextReadWriteStrategy implements ReadWriteStrategy {

	private static InheritableThreadLocal<Boolean> forceMaster = new InheritableThreadLocal<Boolean>();

	@Override
	public boolean shouldReadFromMaster() {
		Boolean shouldReadFromMaster = forceMaster.get();

		if (shouldReadFromMaster == null || !shouldReadFromMaster) {
			return false;
		} else {
			return true;
		}
	}

	protected static void setReadFromMaster() {
		forceMaster.set(true);
	}

	protected static void clearContext() {
		forceMaster.remove();
	}

	@Override
	public void setGroupDataSourceConfig(GroupDataSourceConfig config) {
	}
}

如上,通过一个ThreadLocal<Boolean>型的线程局部变量forceMaster,通过set()来为其赋值,get()来取值,remove()来实现clear。接下来就来看看这三个方法的实现。ThreadLocal.set()方法实现如下:

    public void set(T value) {
        Thread t = Thread.currentThread();  // 获取当前线程
        ThreadLocalMap map = getMap(t);     // 获取线程内部藏着的map
        if (map != null)
            map.set(this, value);           // 以ThreadLocal对象自身为key,将value存放到map中
        else
            createMap(t, value);
    }
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }
    void createMap(Thread t, T firstValue) {
        t.threadLocals = new ThreadLocalMap(this, firstValue);
    }

 从上面的set方法可以看出:

1. Thread内部有了隐藏的ThreadLocalMap负责存放该线程的ThreadLocal变量,这个是ThreadLocal自定义的map,与HashMap有较大的区别,后续再讨论

2. 第一次调用set的时候,才会创建并初始化这个ThreadLocalMap

3. ThreadLocalMap是以ThreadLocal对象自身作为key的

 ThreadLocal的get方法实现如下:

    public T get() {
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null)
                return (T)e.value;
        }
        return setInitialValue();
    }

ThreadLocal的remove方法如下:

     public void remove() {
         ThreadLocalMap m = getMap(Thread.currentThread());
         if (m != null)
             m.remove(this);
     }

其实,通过上面的几个方法,已经可以看出来了,对于ThreadLocal变量的存取,关键就在于这个ThreadLocalMap了。set时无非是以ThreadLocal对象自身为key将value存放到map中,而get时,就是以ThreadLocal作为key去map中查找。接下来,就重点瞧瞧这个Thread内部偷偷藏着的自定义map--ThreadLocalMap。

ThreadLocalMap是定义在ThreadLocal中的一个静态类,主要作为装载线程局部变量的容器。先来看看其map定义的节点Entry:

        static class Entry extends WeakReference<ThreadLocal> {
            /** The value associated with this ThreadLocal. */
            Object value;

            Entry(ThreadLocal k, Object v) {
                super(k);
                value = v;
            }
        }

很有意思,Entry中只定义了一个value,没再定义key,因为key就是ThreadLocal自身。entry.get()拿到key,即threadLocal,entry.value拿到value。此外,通过super(key),将entry对key的引用定义为了弱引用。这里就要回忆下Java中的强、软、弱、虚四种引用了。仔细思量下就可以理解了,Entry对于Key[即ThreadLocal对象背身]的引用被定义为虚引用,主要是为了能让ThreadLocal在除了Entry之外再无引用时,被顺利的回收掉。因为Thread对于map的引用,map经底层Entry数据对entry的引用都是强引用,这时候,要是entry对ThreadLocal的引用也是强引用的话,那这个ThreadLocal就无法被回收了,也就造成内存泄漏了。

再来看看第一次调用ThreadLocal.set()时创建map的constructor:

        ThreadLocalMap(ThreadLocal firstKey, Object firstValue) {
            table = new Entry[INITIAL_CAPACITY];
            int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
            table[i] = new Entry(firstKey, firstValue);
            size = 1;
            setThreshold(INITIAL_CAPACITY);
        }

这里面就涉及到ThreadLocal的hashcode的计算了:

private static final int HASH_INCREMENT = 0x61c88647;

private final int threadLocalHashCode = nextHashCode();

private static AtomicInteger nextHashCode = new AtomicInteger();

private static int nextHashCode() {
    return nextHashCode.getAndAdd(HASH_INCREMENT);
}

上面的几行代码就已经完成了一个ThreadLocal对象的哈希值的计算了:

1. hash值在thread对象创建时就已经计算好了,后续不再发生变化

2. 以神奇的魔数 0x61c88647 作为递增量,来计算新ThreadLocal对象的hash值(这个魔数在各类SDK中经常被用到,原因是据此递增出来的hash值在2的N次方的长度的链表中分布非常均匀)

再来看看ThreadLocalMap的set方法:

private void set(ThreadLocal key, Object value) {
            Entry[] tab = table;
            int len = tab.length;
            int i = key.threadLocalHashCode & (len-1);

            for (Entry e = tab[i];e != null;e = tab[i = nextIndex(i, len)]) {
                ThreadLocal k = e.get();

                if (k == key) {
                    e.value = value;
                    return;
                }

                if (k == null) {
                    replaceStaleEntry(key, value, i);
                    return;
                }
            }

            tab[i] = new Entry(key, value);
            int sz = ++size;
            if (!cleanSomeSlots(i, sz) && sz >= threshold)
                rehash();
        }
        private static int nextIndex(int i, int len) {
            return ((i + 1 < len) ? i + 1 : 0);
        }  

上面的set方法有几个有意思的地方:

1. hash冲突,其使用了线性探测法而非HashMap中的链表法来解决hash冲突

2. 判空,前面说了,Entry对于key的引用是个弱引用,因此当拿出来Entry之后发现key=null了,说明作为key的ThreadLocal已经被回收了,这个节点也就是个过时的节点了,需要被处理掉了

3. 在Java1.3中,存放ThreadLocal的那个map,使用了HashMap来实现,后来为了性能,变更为自定义map 

至此,已然有所得了~

posted @ 2017-03-05 16:43  Mr.do  阅读(341)  评论(0编辑  收藏  举报