曹工说JDK源码(2)--ConcurrentHashMap的多线程扩容,说白了,就是分段取任务
前言
先预先说明,我这边jdk的代码版本为1.8.0_11,同时,因为我直接在本地jdk源码上进行了部分修改、调试,所以,导致大家看到的我这边贴的代码,和大家的不太一样。
不过,我对源码进行修改、重构时,会保证和原始代码的功能、逻辑严格一致,更多时候,可能只是修改变量名,方便理解。
大家也知道,jdk代码写得实在是比较深奥,变量名经常都是单字符,i,j,k啥的,实在是很难理解,所以,我一般会根据自己的理解,去重命名,为了减轻我们的头脑负担。
至于怎么去修改代码并调试,可以参考我之前的文章:
曹工力荐:调试 jdk 中 rt.jar 包部分的源码(可自由增加注释,修改代码并debug)
文章中,我改过的代码放在:
https://gitee.com/ckl111/jdk-debug
sizeCtl field的初始化
大家知道,concurrentHashMap底层是数组+链表+红黑树,数组的长度假设为n,在hashmap初始化的时候,这个n除了作为数组长度,还会作为另一个关键field的值。
/**
* Table initialization and resizing control. When negative, the
* table is being initialized or resized: -1 for initialization,
* else -(1 + the number of active resizing threads). Otherwise,
* when table is null, holds the initial table size to use upon
* creation, or 0 for default. After initialization, holds the
* next element count value upon which to resize the table.
*/
private transient volatile int sizeCtl;
该字段非常关键,根据取值不同,有不同的功能。
使用默认构造函数时
public ConcurrentHashMap() {
}
此时,sizeCtl被初始化为0.
使用带初始容量的构造函数时
此时,sizeCtl也是32,和容量一致。
使用另一个map来初始化时
public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
this.sizeCtl = DEFAULT_CAPACITY;
putAll(m);
}
此时,sizeCtl,直接使用了默认值,16.
使用初始容量、负载因子来初始化时
public ConcurrentHashMap(int initialCapacity, float loadFactor) {
this(initialCapacity, loadFactor, 1);
}
这里重载了:
这里,我们传入的负载因子为0.75,这也是默认的负载因子,传入的初始容量为14.
这里面会根据: 1 + 14/0.75 = 19,拿到真正的size,然后根据size,获取到第一个大于19的2的n次方,即32,来作为数组容量,然后sizeCtl也被设置为32.
initTable时,对sizeCtl field的修改
实际上,new一个hashmap的时候,我们并没有创建支撑数组,那,什么时候创建数组呢?是在真正往里面放数据的时候,比如put的时候。
/** Implementation for put and putIfAbsent */
final V putVal(K key, V value, boolean onlyIfAbsent) {
if (key == null || value == null) throw new NullPointerException();
int hash = spread(key.hashCode());
int binCount = 0;
ConcurrentHashMapPutResultVO vo = new ConcurrentHashMapPutResultVO();
vo.setBinCount(0);
for (Node<K,V>[] tab = table;;) {
int tableLength;
// 1
if (tab == null) {
tab = initTable();
continue;
}
...
}
1处,即会去初始化table。
/**
* Initializes table, using the size recorded in sizeCtl.
* 初始化hashmap,使用sizeCtl作为容量
*/
private final Node<K,V>[] initTable() {
Node<K,V>[] tab; int sc;
while ((tab = table) == null || tab.length == 0) {
sc = sizeCtl;
if (sc < 0){
Thread.yield(); // lost initialization race; just spin
continue;
}
/**
* 走到这里,说明sizeCtl大于0,大于0,代表什么,可以去看下其构造函数,此时,sizeCtl表示
* capacity的大小。
* {@link #ConcurrentHashMap(int)}
*
* cas修改为-1,如果成功修改为-1,则表示抢到了锁,可以进行初始化
*
*/
// 1
boolean bGotChanceToInit = U.compareAndSwapInt(this, SIZECTL, sc, -1);
if (bGotChanceToInit) {
try {
tab = table;
/**
* 如果当前表为空,尚未初始化,则进行初始化,分配空间
*/
if (tab == null || tab.length == 0) {
/**
* sc大于0,则以sc为准,否则使用默认的容量
*/
int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
Node<K, V>[] nt = (Node<K, V>[]) new Node<?, ?>[n];
table = tab = nt;
/**
* n >>> 2,无符号右移2位,则是n的四分之一。
* n- n/4,结果为3/4 * n
* 则,这里修改sc为 3/4 * n
* 比如,默认容量为16,则修改sc为12
*/
// 2
sc = n - (n >>> 2);
}
} finally {
/**
* 修改sizeCtl到field
*/
// 3
sizeCtl = sc;
}
break;
}
}
return tab;
}
- 1处,cas修改sizeCtl为-1,成功了的,获得初始化table的权利
- 2处,修改局部变量sc为: n - (n >>> 2),也就是修改为 0.75n,假设此时的数组容量为16,那么sc就是16 * 0.75 = 12.
- 3处,将sc赋值到field: sizeCtl
经过上面的分析,initTable时,这个字段可能有两种取值:
- -1,有线程正在对该table进行初始化
- 0.75*数组长度,此时,已经初始化完成
上面说的是,在put的时候去initTable,实际上,这个initTable,也会在以下函数中被调用,其共同点就是,都是往里面放数据的操作:
扩容时机
上面说了很多,目前,我们知道的是,在initTable后,sizeCtl的值,是旧的数组的长度 * 0.75。
接下来,我们看看扩容时机,在put时,会调用putVal,这个函数的大体步骤:
final V putVal(K key, V value, boolean onlyIfAbsent) {
if (key == null || value == null) throw new NullPointerException();
// 1
int hash = spread(key.hashCode());
int binCount = 0;
System.out.println("binCount:" + binCount);
// 2
ConcurrentHashMapPutResultVO vo = new ConcurrentHashMapPutResultVO();
vo.setBinCount(0);
for (Node<K,V>[] tab = table;;) {
int tableLength;
// 3
if (tab == null) {
tab = initTable();
continue;
}
tableLength = tab.length;
if (tableLength == 0) {
tab = initTable();
continue;
}
int entryNodeHashCode;
// 4
int entryNodeIndex = (tableLength - 1) & hash;
Node<K,V> entryNode = tabAt(tab,entryNodeIndex);
/**
* 5 如果我们要放的桶,还是个空的,则直接cas放进去
*/
if (entryNode == null) {
Node<K, V> node = new Node<>(hash, key, value, null);
// no lock when adding to empty bin
boolean bSuccess = casTabAt(tab, entryNodeIndex, null, node);
if (bSuccess) {
break;
} else {
/**
* 如果没成功,则继续下一轮循环
*/
continue;
}
}
entryNodeHashCode = entryNode.hash;
/**
* 6 如果要放的这个桶,正在迁移,则帮助迁移
*/
if (entryNodeHashCode == MOVED){
tab = helpTransfer(tab, entryNode);
continue;
}
/**
* 7 对entryNode加锁
*/
V oldVal = null;
System.out.println("sync");
synchronized (entryNode) {
/**
* 这一行是判断,在我们执行前面的一堆方法的时候,看看entryNodeIndex处的node是否变化
*/
if (tabAt(tab, entryNodeIndex) != entryNode) {
continue;
}
/**
* 8 hashCode大于0,说明不是处于迁移状态
*/
if (entryNodeHashCode >= 0) {
/**
* 9 链表中找到合适的位置并放入
*/
findPositionAndPut(key, value, onlyIfAbsent, hash, vo, entryNode);
binCount = vo.getBinCount();
oldVal = (V) vo.getOldValue();
}
else if (entryNode instanceof TreeBin) {
...
}
}
System.out.println("binCount:" + binCount);
// 10
if (binCount != 0) {
if (binCount >= TREEIFY_THRESHOLD)
treeifyBin(tab, entryNodeIndex);
if (oldVal != null)
return oldVal;
break;
}
}
// 11
addCount(1L, binCount);
return null;
}
-
1处,计算key的hashcode
-
2处,我这边new了一个对象,里面两个字段:
public class ConcurrentHashMapPutResultVO<V> { int binCount; V oldValue; }
其中,oldValue用来存放,如果put进去的key/value,其中key已经存在的话,一般会直接覆盖之前的旧值,这里主要存放之前的旧值,因为我们需要返回旧值。
binCount,则存放:在找到对应的hash桶之后,在链表中,遍历了多少个元素,该值后面会使用,作为一个标志,当该标志大于0的时候,才去进一步检查,看看是否扩容。
-
3处,如果table为null,说明table里没有任何一个键值对,数组也还没创建,则初始化table
-
4处,根据hashcode,和(数组长度 - 1)相与,计算出应该存放的哈希桶在数组中的索引
-
5处,如果要放的哈希桶,还是空的,则直接cas设置进去,成功则跳出循环,否则重试
-
6处,如果要放的这个桶,该节点的hashcode为MOVED(一个常量,值为-1),说明有其他线程正在扩容该hashmap,则帮助扩容
-
7处,对要存放的hash桶的头节点加锁
-
8处,如果头节点的hashcode大于0,说明是拉了一条链表,则调用子方法(我这边自己抽的),去找到合适的位置并插入到链表
-
9处,findPositionAndPut,在链表中,找到合适的位置,并插入
-
10处,在findPositionAndPut函数中,会返回:为了找到合适的位置,遍历了多少个元素,这个值,就是binCount。
如果这个binCount大于8,则说明遍历了8个元素,则需要转红黑树了。
-
11处,因为我们新增了一个元素,总数自然要加1,这里面会去增加总数,和检查是否需要扩容。
其中,第9步,因为是自己抽的函数,所以这里贴出来给大家看下:
/**
* 遍历链表,找到应该放的位置;如果遍历完了还没找到,则放到最后
* @param key
* @param value
* @param onlyIfAbsent
* @param hash
* @param vo
* @param entryNode
*/
private void findPositionAndPut(K key, V value, boolean onlyIfAbsent, int hash, ConcurrentHashMapPutResultVO vo, Node<K, V> entryNode) {
vo.setBinCount(1);
for (Node<K,V> currentIterateNode = entryNode;
;
vo.setBinCount(vo.getBinCount() + 1)) {
/**
* 如果当前遍历指向的节点的hash值,与参数中的key的hash值相等,则,
* 继续判断
*/
K currentIterateNodeKey = currentIterateNode.key;
boolean bKeyEqualOrNot = Objects.equals(currentIterateNodeKey, key);
/**
* key的hash值相等,且equals比较也相等,则就是我们要找的
*/
if (currentIterateNode.hash == hash && bKeyEqualOrNot) {
/**
* 获取旧的值
*/
vo.setOldValue(currentIterateNode.val);
/**
* 覆盖旧的node的val
*/
if (!onlyIfAbsent)
currentIterateNode.val = value;
// 这里直接break跳出循环
break;
}
/**
* 把当前节点保存起来
*/
Node<K,V> pred = currentIterateNode;
/**
* 获取下一个节点
*/
currentIterateNode = currentIterateNode.next;
/**
* 如果下一个节点为null,说明当前已经是链表的最后一个node了
*/
if ( currentIterateNode == null) {
/**
* 则在当前节点后面,挂上新的节点
*/
pred.next = new Node<K,V>(hash, key,
value, null);
break;
}
}
}
第11步,也是我们要看的重点:
private final void addCount(long delta, int check) {
CounterCell[] counterCellsArray = counterCells;
// 1
long b = baseCount;
// 2
long newBaseCount = b + delta;
/**
* 3 直接cas在baseCount上增加
*/
boolean bSuccess = U.compareAndSwapLong(this, BASECOUNT, b, newBaseCount);
if ( counterCellsArray != null || !bSuccess) {
...
newBaseCount = sumCount();
}
// 4
if (check >= 0) {
while (true) {
Node<K,V>[] tab = table;
Node<K,V>[] nt;
int n = 0;
// 5
int sc = sizeCtl;
// 6
boolean bSumExteedSizeControl = newBaseCount >= (long) sc;
// 7
boolean bContinue = bSumExteedSizeControl && tab != null && (n = tab.length) < MAXIMUM_CAPACITY;
if (bContinue) {
int rs = resizeStamp(n);
if (sc < 0) {
if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
transferIndex <= 0)
break;
if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
transfer(tab, nt);
} else if (U.compareAndSwapInt(this, SIZECTL, sc,
(rs << RESIZE_STAMP_SHIFT) + 2))
// 8
transfer(tab, null);
newBaseCount = sumCount();
} else {
break;
}
}
}
}
-
1处,baseCount是一个field,存储当前hashmap中,有多少个键值对,你put一次,就一个;remove一次,就减一个。
-
2处,b + delta,其中,b就是baseCount,是旧的数量;dalta,我们传入的是1,就是要增加的元素数量
所以,b + delta,得到的,就是经过这次put后,预期的数量
-
3处,直接cas,修改baseCount这个field为 新值,也就是第二步拿到的值。
-
4处,这里检查check是否大于0,check,是第二个形参;这个参数,我们外边怎么传的?
addCount(1L, binCount);
不就是bincount吗,也就是说,这里检查:我们在put过程中,在链表中遍历了几个元素,如果遍历了至少1个元素,这里要进入下面的逻辑:检查是否要扩容,因为,你binCount大于0,说明可能已经开始出现哈希冲突了。
-
5处,取field:sizeCtl的值,给局部变量sc
-
6处,判断当前的新的键值对总数,是否大于sc了;比如容量是16,那么sizeCtl是12,如果此时,hashmap中存放的键值对已经大于等于12了,则要检查是否扩容了
-
7处,几个组合条件,查看是否要扩容,其中,主要的条件就是第6步的那个。
-
8处,调用transfer,进行扩容
总结一下,经过前面的第6处,我们知道,如果存放的键值对总数,已经大于等于0.75*哈希桶(也就是底层数组的长度)的数量了,那么,就基本要扩容了。
扩容的大体过程
扩容也是一个相对复杂的过程,这里只说大概,详细的放下讲。
假设,现在底层数组长度,128,也就是128个哈希桶,当存放的键值对数量,大于等于 128 * 0.75的时候,就会开始扩容,扩容的过程,大概是:
- 申请一个256(容量翻倍)的数组
- 现在有128个桶,相当于,需要对128个桶进行遍历,遍历每个桶拉出去的链表或红黑树,查看每个键值对,是需要放到新数组的什么位置
这个过程,昨天的博文,画了个图,这里再贴一下。
扩容后:
可是,如果我们要一个个去遍历所有哈希桶,然后遍历对应的链表/红黑树,会不会太慢了?完全是单线程工作啊。
换个思路,我们能不能加快点呢?比如,线程1可以去处理数组的 0 -15这16个桶,16- 31这16个桶,完全可以让线程2去做啊,这样的话,不就多线程了吗,不是就快了吗?
没错,jdk就是这么干的。
jdk维护了一个field,这个field,专门用来存当前可以获取的任务的索引,举个例子:
大家看上图就懂了,一开始,这里假设我们有128个桶,每次每个线程,去拿16个桶来处理。
刚开始的时候,field:transferIndex就等于127,也就是最后一个桶的位置,然后我们要从后往前取,那么,127 到112,刚好就是16个桶,所以,申请任务的时候,就会用cas去更新field为112,则表示,自己取到了112 到127这一个区间的hash桶迁移任务。
如果自始至终,只有一个线程呢,它处理完了112 - 127这一批hash桶后,会继续取下一波任务,96 - 112;以此类推。
如果多线程的话呢,也是类似的,反正都是去尝试cas更新transferIndex的值为任务区间的开始下标的值,成功了,就算任务认领成功了。
多线程,怎么知道需要去帮助扩容呢? 发起扩容的线程,在处理完bucket[k]时,会把老的table中的对应的bucket[k]的头节点,修改为下面这种类型的节点:
static final class ForwardingNode<K,V> extends Node<K,V> {
final Node<K,V>[] nextTable;
ForwardingNode(Node<K,V>[] tab) {
super(MOVED, null, null, null);
this.nextTable = tab;
}
}
其他线程,在put或者其他操作时,发现头结点变成了这个,就会去协助扩容了。
多线程扩容,和分段取任务的差别?
我个人感觉,差别不大,多线程扩容,就是多线程去获取自己的那一段任务,然后来完成。我这边写了简单的demo,不过感觉还是很有用的,可以帮助我们理解。
import sun.misc.Unsafe;
import java.lang.reflect.Field;
import java.util.concurrent.*;
import java.util.concurrent.locks.LockSupport;
public class ConcurrentTaskFetch {
/**
* 空闲任务索引,获取任务时,从该下标开始,往前获取。
* 比如当前下标为10,表示tasks数组中,0-10这个区间的任务,没人领取
*/
// 0
private volatile int freeTaskIndexForFetch;
// 1
private static final int TASK_COUNT_PER_FETCH = 16;
// 2
private String[] tasks = new String[128];
public static void main(String[] args) {
ConcurrentTaskFetch fetch = new ConcurrentTaskFetch();
// 3
fetch.init();
ThreadPoolExecutor executor = new ThreadPoolExecutor(10, 10, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(100));
executor.prestartAllCoreThreads();
CyclicBarrier cyclicBarrier = new CyclicBarrier(10);
// 4
for (int i = 0; i < 10; i++) {
executor.execute(new Runnable() {
@Override
public void run() {
try {
cyclicBarrier.await();
} catch (InterruptedException | BrokenBarrierException e) {
e.printStackTrace();
}
// 5
FetchedTaskInfo fetchedTaskInfo = fetch.fetchTask();
if (fetchedTaskInfo != null) {
System.out.println("thread:" + Thread.currentThread().getName() + ",get task success:" + fetchedTaskInfo);
try {
TimeUnit.SECONDS.sleep(3);
} catch (InterruptedException e) {
e.printStackTrace();
}
System.out.println("thread:" + Thread.currentThread().getName() + ", process task finished");
}
}
});
}
LockSupport.park();
}
public void init() {
for (int i = 0; i < 128; i++) {
tasks[i] = "task" + i;
}
freeTaskIndexForFetch = tasks.length;
}
// 6
public FetchedTaskInfo fetchTask() {
System.out.println("Thread start fetch task:"+Thread.currentThread().getName()+",time: "+System.currentTimeMillis());
while (true){
// 6.1
if (freeTaskIndexForFetch == 0) {
System.out.println("thread:" + Thread.currentThread().getName() + ",get task failed,there is no task");
return null;
}
/**
* 6.2 获取当前任务的集合的上界
*/
int subTaskListEndIndex = this.freeTaskIndexForFetch;
/**
* 6.3 获取当前任务的集合的下界
*/
int subTaskListStartIndex = subTaskListEndIndex > TASK_COUNT_PER_FETCH ?
subTaskListEndIndex - TASK_COUNT_PER_FETCH : 0;
/**
* 6.4
* 现在,我们拿到了集合的上下界,即[subTaskListStartIndex,subTaskListEndIndex)
* 该区间为前开后闭,所以,实际的区间为:
* [subTaskListStartIndex,subTaskListEndIndex - 1]
*/
/**
* 6.5 使用cas,尝试更新{@link freeTaskIndexForFetch} 为 subTaskListStartIndex
*/
if (U.compareAndSwapInt(this, FREE_TASK_INDEX_FOR_FETCH, subTaskListEndIndex, subTaskListStartIndex)) {
// 6.6
FetchedTaskInfo info = new FetchedTaskInfo();
info.setStartIndex(subTaskListStartIndex);
info.setEndIndex(subTaskListEndIndex - 1);
return info;
}
}
}
// Unsafe mechanics
private static final sun.misc.Unsafe U;
private static final long FREE_TASK_INDEX_FOR_FETCH;
static {
try {
// U = sun.misc.Unsafe.getUnsafe();
Field f = Unsafe.class.getDeclaredField("theUnsafe");
f.setAccessible(true);
U = (Unsafe) f.get(null);
Class<?> k = ConcurrentTaskFetch.class;
FREE_TASK_INDEX_FOR_FETCH = U.objectFieldOffset
(k.getDeclaredField("freeTaskIndexForFetch"));
} catch (Exception e) {
throw new Error(e);
}
}
static class FetchedTaskInfo{
int startIndex;
int endIndex;
public int getStartIndex() {
return startIndex;
}
public void setStartIndex(int startIndex) {
this.startIndex = startIndex;
}
public int getEndIndex() {
return endIndex;
}
public void setEndIndex(int endIndex) {
this.endIndex = endIndex;
}
@Override
public String toString() {
return "FetchedTaskInfo{" +
"startIndex=" + startIndex +
", endIndex=" + endIndex +
'}';
}
}
}
-
0处,定义了一个field,类似于前面的transferIndex
/** * 空闲任务索引,获取任务时,从该下标开始,往前获取。 * 比如当前下标为10,表示tasks数组中,0-10这个区间的任务,没人领取 */ // 0 private volatile int freeTaskIndexForFetch;
-
1,定义了每次取多少个任务,这里也是16个
private static final int TASK_COUNT_PER_FETCH = 16;
-
2,定义任务列表,共128个任务
-
3,main函数中,进行任务初始化
public void init() { for (int i = 0; i < 128; i++) { tasks[i] = "task" + i; } freeTaskIndexForFetch = tasks.length; }
主要初始化任务列表,其次,将freeTaskIndexForFetch 赋值为128,后续取任务,从这个下标开始
-
4处,启动10个线程,每个线程去执行取任务,按理说,我们128个任务,每个线程取16个,只能有8个线程取到任务,2个线程取不到
-
5处,线程逻辑里,去获取任务
-
6处,获取任务的方法定义
-
6.1 ,如果可获取的任务索引为0了,说明没任务了,直接返回
-
6.2,获取当前任务的集合的上界
-
6.3,获取当前任务的集合的下界,减去16就行了
-
6.4,拿到了集合的上下界,即[subTaskListStartIndex,subTaskListEndIndex)
-
6.5, 使用cas,更新field为:6.4中的任务下界。
执行效果演示:
可以看到,8个线程取到任务,2个线程没取到。
该思想在内存分配时的应用
其实jvm内存分配时,也是类似的思路,比如,设置堆内存为200m,那这200m是启动时立马从操作系统分配了的。
接下来,就是每次new对象的时候,去这个大内存里,找个小空间,这个过程,也是需要cas去竞争的,比如肯定也有个全局的字段,来表示当前可用内存的索引,比如该索引为100,表示,第100个字节后的空间是可以用的,那我要new个对象,这个对象有3个字段,需要大概30个字节,那我是不是需要把这个索引更新为130。
这中间是多线程的,所以也是要cas操作。
道理都是类似的。
总结
时间仓促,有问题在所难免,欢迎及时指出或加群讨论。