曹工说JDK源码(2)--ConcurrentHashMap的多线程扩容,说白了,就是分段取任务

前言

先预先说明,我这边jdk的代码版本为1.8.0_11,同时,因为我直接在本地jdk源码上进行了部分修改、调试,所以,导致大家看到的我这边贴的代码,和大家的不太一样。

不过,我对源码进行修改、重构时,会保证和原始代码的功能、逻辑严格一致,更多时候,可能只是修改变量名,方便理解。

大家也知道,jdk代码写得实在是比较深奥,变量名经常都是单字符,i,j,k啥的,实在是很难理解,所以,我一般会根据自己的理解,去重命名,为了减轻我们的头脑负担。

至于怎么去修改代码并调试,可以参考我之前的文章:

曹工力荐:调试 jdk 中 rt.jar 包部分的源码(可自由增加注释,修改代码并debug)

文章中,我改过的代码放在:

https://gitee.com/ckl111/jdk-debug

sizeCtl field的初始化

大家知道,concurrentHashMap底层是数组+链表+红黑树,数组的长度假设为n,在hashmap初始化的时候,这个n除了作为数组长度,还会作为另一个关键field的值。

    /**
     * Table initialization and resizing control.  When negative, the
     * table is being initialized or resized: -1 for initialization,
     * else -(1 + the number of active resizing threads).  Otherwise,
     * when table is null, holds the initial table size to use upon
     * creation, or 0 for default. After initialization, holds the
     * next element count value upon which to resize the table.
     */
    private transient volatile int sizeCtl;

该字段非常关键,根据取值不同,有不同的功能。

使用默认构造函数时

    public ConcurrentHashMap() {
    }

此时,sizeCtl被初始化为0.

使用带初始容量的构造函数时

此时,sizeCtl也是32,和容量一致。

使用另一个map来初始化时

    public ConcurrentHashMap(Map<? extends K, ? extends V> m) {
        this.sizeCtl = DEFAULT_CAPACITY;
        putAll(m);
    }

此时,sizeCtl,直接使用了默认值,16.

使用初始容量、负载因子来初始化时

    public ConcurrentHashMap(int initialCapacity, float loadFactor) {
        this(initialCapacity, loadFactor, 1);
    }

这里重载了:

这里,我们传入的负载因子为0.75,这也是默认的负载因子,传入的初始容量为14.

这里面会根据: 1 + 14/0.75 = 19,拿到真正的size,然后根据size,获取到第一个大于19的2的n次方,即32,来作为数组容量,然后sizeCtl也被设置为32.

initTable时,对sizeCtl field的修改

实际上,new一个hashmap的时候,我们并没有创建支撑数组,那,什么时候创建数组呢?是在真正往里面放数据的时候,比如put的时候。

/** Implementation for put and putIfAbsent */
    final V putVal(K key, V value, boolean onlyIfAbsent) {
        if (key == null || value == null) throw new NullPointerException();
        int hash = spread(key.hashCode());

        int binCount = 0;
        ConcurrentHashMapPutResultVO vo = new ConcurrentHashMapPutResultVO();
        vo.setBinCount(0);
        for (Node<K,V>[] tab = table;;) {
            int tableLength;
            // 1
            if (tab == null) {
                tab = initTable();
                continue;
            }
            ...
        }

1处,即会去初始化table。

/**
     * Initializes table, using the size recorded in sizeCtl.
     * 初始化hashmap,使用sizeCtl作为容量
     */
    private final Node<K,V>[] initTable() {
        Node<K,V>[] tab; int sc;
        while ((tab = table) == null || tab.length == 0) {
            sc = sizeCtl;
            if (sc < 0){
                Thread.yield(); // lost initialization race; just spin
                continue;
            }

            /**
             * 走到这里,说明sizeCtl大于0,大于0,代表什么,可以去看下其构造函数,此时,sizeCtl表示
             * capacity的大小。
             * {@link #ConcurrentHashMap(int)}
             *
             * cas修改为-1,如果成功修改为-1,则表示抢到了锁,可以进行初始化
             *
             */
            // 1
            boolean bGotChanceToInit = U.compareAndSwapInt(this, SIZECTL, sc, -1);
            if (bGotChanceToInit) {
                try {
                    tab = table;
                    /**
                     * 如果当前表为空,尚未初始化,则进行初始化,分配空间
                     */
                    if (tab == null || tab.length == 0) {
                        /**
                         * sc大于0,则以sc为准,否则使用默认的容量
                         */
                        int n = (sc > 0) ? sc : DEFAULT_CAPACITY;

                        Node<K, V>[] nt = (Node<K, V>[]) new Node<?, ?>[n];
                        table = tab = nt;
                        /**
                         * n >>> 2,无符号右移2位,则是n的四分之一。
                         * n- n/4,结果为3/4 * n
                         * 则,这里修改sc为 3/4 * n
                         * 比如,默认容量为16,则修改sc为12
                         */
                        // 2
                        sc = n - (n >>> 2);
                    }
                } finally {
                    /**
                     * 修改sizeCtl到field
                     */
                    // 3
                    sizeCtl = sc;
                }
                break;
            }
        }

        return tab;
    }
  • 1处,cas修改sizeCtl为-1,成功了的,获得初始化table的权利
  • 2处,修改局部变量sc为: n - (n >>> 2),也就是修改为 0.75n,假设此时的数组容量为16,那么sc就是16 * 0.75 = 12.
  • 3处,将sc赋值到field: sizeCtl

经过上面的分析,initTable时,这个字段可能有两种取值:

  • -1,有线程正在对该table进行初始化
  • 0.75*数组长度,此时,已经初始化完成

上面说的是,在put的时候去initTable,实际上,这个initTable,也会在以下函数中被调用,其共同点就是,都是往里面放数据的操作:

扩容时机

上面说了很多,目前,我们知道的是,在initTable后,sizeCtl的值,是旧的数组的长度 * 0.75。

接下来,我们看看扩容时机,在put时,会调用putVal,这个函数的大体步骤:

final V putVal(K key, V value, boolean onlyIfAbsent) {
    if (key == null || value == null) throw new NullPointerException();
    // 1
    int hash = spread(key.hashCode());

    int binCount = 0;
    System.out.println("binCount:" + binCount);
    // 2
    ConcurrentHashMapPutResultVO vo = new ConcurrentHashMapPutResultVO();
    vo.setBinCount(0);
    for (Node<K,V>[] tab = table;;) {
        int tableLength;
        // 3
        if (tab == null) {
            tab = initTable();
            continue;
        }
        
        tableLength = tab.length;
        if (tableLength == 0) {
            tab = initTable();
            continue;
        }

        int entryNodeHashCode;
		
        // 4
        int entryNodeIndex = (tableLength - 1) & hash;
        Node<K,V> entryNode = tabAt(tab,entryNodeIndex);

        /**
         * 5 如果我们要放的桶,还是个空的,则直接cas放进去
         */
        if (entryNode == null) {
            Node<K, V> node = new Node<>(hash, key, value, null);

            // no lock when adding to empty bin
            boolean bSuccess = casTabAt(tab, entryNodeIndex, null, node);
            if (bSuccess) {
                break;
            } else {
                /**
                 * 如果没成功,则继续下一轮循环
                 */
                continue;
            }
        }
		
        entryNodeHashCode = entryNode.hash;
        /**
         * 6 如果要放的这个桶,正在迁移,则帮助迁移
         */
        if (entryNodeHashCode == MOVED){
            tab = helpTransfer(tab, entryNode);
            continue;
        }


        /**
         * 7 对entryNode加锁
         */
        V oldVal = null;
        System.out.println("sync");
        synchronized (entryNode) {
            /**
             * 这一行是判断,在我们执行前面的一堆方法的时候,看看entryNodeIndex处的node是否变化
             */
            if (tabAt(tab, entryNodeIndex) != entryNode) {
                continue;
            }

            /**
             * 8 hashCode大于0,说明不是处于迁移状态
             */
            if (entryNodeHashCode >= 0) {
                /**
                 * 9 链表中找到合适的位置并放入
                 */
                findPositionAndPut(key, value, onlyIfAbsent, hash, vo, entryNode);
                binCount = vo.getBinCount();
                oldVal = (V) vo.getOldValue();
            }
            else if (entryNode instanceof TreeBin) {
                ...
            }
        }
		
        System.out.println("binCount:" + binCount);
        // 10
        if (binCount != 0) {
            if (binCount >= TREEIFY_THRESHOLD)
                treeifyBin(tab, entryNodeIndex);
            if (oldVal != null)
                return oldVal;
            break;
        }
    }
    // 11
    addCount(1L, binCount);
    return null;
}
  • 1处,计算key的hashcode

  • 2处,我这边new了一个对象,里面两个字段:

    public class ConcurrentHashMapPutResultVO<V> {
        int binCount;
    
        V oldValue;
    }
    

    其中,oldValue用来存放,如果put进去的key/value,其中key已经存在的话,一般会直接覆盖之前的旧值,这里主要存放之前的旧值,因为我们需要返回旧值。

    binCount,则存放:在找到对应的hash桶之后,在链表中,遍历了多少个元素,该值后面会使用,作为一个标志,当该标志大于0的时候,才去进一步检查,看看是否扩容。

  • 3处,如果table为null,说明table里没有任何一个键值对,数组也还没创建,则初始化table

  • 4处,根据hashcode,和(数组长度 - 1)相与,计算出应该存放的哈希桶在数组中的索引

  • 5处,如果要放的哈希桶,还是空的,则直接cas设置进去,成功则跳出循环,否则重试

  • 6处,如果要放的这个桶,该节点的hashcode为MOVED(一个常量,值为-1),说明有其他线程正在扩容该hashmap,则帮助扩容

  • 7处,对要存放的hash桶的头节点加锁

  • 8处,如果头节点的hashcode大于0,说明是拉了一条链表,则调用子方法(我这边自己抽的),去找到合适的位置并插入到链表

  • 9处,findPositionAndPut,在链表中,找到合适的位置,并插入

  • 10处,在findPositionAndPut函数中,会返回:为了找到合适的位置,遍历了多少个元素,这个值,就是binCount。

    如果这个binCount大于8,则说明遍历了8个元素,则需要转红黑树了。

  • 11处,因为我们新增了一个元素,总数自然要加1,这里面会去增加总数,和检查是否需要扩容。

其中,第9步,因为是自己抽的函数,所以这里贴出来给大家看下:

/**
     * 遍历链表,找到应该放的位置;如果遍历完了还没找到,则放到最后
     * @param key
     * @param value
     * @param onlyIfAbsent
     * @param hash
     * @param vo
     * @param entryNode
     */
    private void findPositionAndPut(K key, V value, boolean onlyIfAbsent, int hash, ConcurrentHashMapPutResultVO vo, Node<K, V> entryNode) {
        vo.setBinCount(1);

        for (Node<K,V> currentIterateNode = entryNode;
                ;
             vo.setBinCount(vo.getBinCount() + 1)) {


            /**
             * 如果当前遍历指向的节点的hash值,与参数中的key的hash值相等,则,
             * 继续判断
             */
            K currentIterateNodeKey = currentIterateNode.key;
            boolean bKeyEqualOrNot = Objects.equals(currentIterateNodeKey, key);
            /**
             * key的hash值相等,且equals比较也相等,则就是我们要找的
             */
            if (currentIterateNode.hash == hash && bKeyEqualOrNot) {
                /**
                 * 获取旧的值
                 */
                vo.setOldValue(currentIterateNode.val);

                /**
                 * 覆盖旧的node的val
                 */
                if (!onlyIfAbsent)
                    currentIterateNode.val = value;
                // 这里直接break跳出循环
                break;
            }

            /**
             * 把当前节点保存起来
             */
            Node<K,V> pred = currentIterateNode;
            /**
             * 获取下一个节点
             */
            currentIterateNode = currentIterateNode.next;
            /**
             * 如果下一个节点为null,说明当前已经是链表的最后一个node了
             */
            if ( currentIterateNode  == null) {
                /**
                 * 则在当前节点后面,挂上新的节点
                 */
                pred.next = new Node<K,V>(hash, key,
                        value, null);
                break;
            }
        }

    }

第11步,也是我们要看的重点:

private final void addCount(long delta, int check) {
        CounterCell[] counterCellsArray = counterCells;
		// 1
        long b = baseCount;
    	// 2
        long newBaseCount = b + delta;

        /**
         * 3 直接cas在baseCount上增加
         */
        boolean bSuccess = U.compareAndSwapLong(this, BASECOUNT, b, newBaseCount);
        if ( counterCellsArray != null ||  !bSuccess) {
			...
            newBaseCount = sumCount();
        }
		
    	// 4
        if (check >= 0) {
            while (true) {

                Node<K,V>[] tab = table;
                Node<K,V>[] nt;
                int n = 0;
                // 5
                int sc =  sizeCtl;
                // 6
                boolean bSumExteedSizeControl = newBaseCount >= (long) sc;
				// 7
                boolean bContinue = bSumExteedSizeControl && tab != null && (n = tab.length) < MAXIMUM_CAPACITY;
                if (bContinue) {
                    int rs = resizeStamp(n);
                    if (sc < 0) {
                        if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                                sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                                transferIndex <= 0)
                            break;
                        if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                            transfer(tab, nt);
                    } else if (U.compareAndSwapInt(this, SIZECTL, sc,
                            (rs << RESIZE_STAMP_SHIFT) + 2))
                        // 8
                        transfer(tab, null);
                    newBaseCount = sumCount();
                } else {
                    break;
                }
            }

        }
    }
  • 1处,baseCount是一个field,存储当前hashmap中,有多少个键值对,你put一次,就一个;remove一次,就减一个。

  • 2处,b + delta,其中,b就是baseCount,是旧的数量;dalta,我们传入的是1,就是要增加的元素数量

    所以,b + delta,得到的,就是经过这次put后,预期的数量

  • 3处,直接cas,修改baseCount这个field为 新值,也就是第二步拿到的值。

  • 4处,这里检查check是否大于0,check,是第二个形参;这个参数,我们外边怎么传的?

    addCount(1L, binCount);

    不就是bincount吗,也就是说,这里检查:我们在put过程中,在链表中遍历了几个元素,如果遍历了至少1个元素,这里要进入下面的逻辑:检查是否要扩容,因为,你binCount大于0,说明可能已经开始出现哈希冲突了。

  • 5处,取field:sizeCtl的值,给局部变量sc

  • 6处,判断当前的新的键值对总数,是否大于sc了;比如容量是16,那么sizeCtl是12,如果此时,hashmap中存放的键值对已经大于等于12了,则要检查是否扩容了

  • 7处,几个组合条件,查看是否要扩容,其中,主要的条件就是第6步的那个。

  • 8处,调用transfer,进行扩容

总结一下,经过前面的第6处,我们知道,如果存放的键值对总数,已经大于等于0.75*哈希桶(也就是底层数组的长度)的数量了,那么,就基本要扩容了。

扩容的大体过程

扩容也是一个相对复杂的过程,这里只说大概,详细的放下讲。

假设,现在底层数组长度,128,也就是128个哈希桶,当存放的键值对数量,大于等于 128 * 0.75的时候,就会开始扩容,扩容的过程,大概是:

  • 申请一个256(容量翻倍)的数组
  • 现在有128个桶,相当于,需要对128个桶进行遍历,遍历每个桶拉出去的链表或红黑树,查看每个键值对,是需要放到新数组的什么位置

这个过程,昨天的博文,画了个图,这里再贴一下。

扩容后:

可是,如果我们要一个个去遍历所有哈希桶,然后遍历对应的链表/红黑树,会不会太慢了?完全是单线程工作啊。

换个思路,我们能不能加快点呢?比如,线程1可以去处理数组的 0 -15这16个桶,16- 31这16个桶,完全可以让线程2去做啊,这样的话,不就多线程了吗,不是就快了吗?

没错,jdk就是这么干的。

jdk维护了一个field,这个field,专门用来存当前可以获取的任务的索引,举个例子:

大家看上图就懂了,一开始,这里假设我们有128个桶,每次每个线程,去拿16个桶来处理。

刚开始的时候,field:transferIndex就等于127,也就是最后一个桶的位置,然后我们要从后往前取,那么,127 到112,刚好就是16个桶,所以,申请任务的时候,就会用cas去更新field为112,则表示,自己取到了112 到127这一个区间的hash桶迁移任务。

如果自始至终,只有一个线程呢,它处理完了112 - 127这一批hash桶后,会继续取下一波任务,96 - 112;以此类推。

如果多线程的话呢,也是类似的,反正都是去尝试cas更新transferIndex的值为任务区间的开始下标的值,成功了,就算任务认领成功了。

多线程,怎么知道需要去帮助扩容呢? 发起扩容的线程,在处理完bucket[k]时,会把老的table中的对应的bucket[k]的头节点,修改为下面这种类型的节点:

    static final class ForwardingNode<K,V> extends Node<K,V> {
        final Node<K,V>[] nextTable;
        ForwardingNode(Node<K,V>[] tab) {
            super(MOVED, null, null, null);
            this.nextTable = tab;
        }
    }

其他线程,在put或者其他操作时,发现头结点变成了这个,就会去协助扩容了。

多线程扩容,和分段取任务的差别?

我个人感觉,差别不大,多线程扩容,就是多线程去获取自己的那一段任务,然后来完成。我这边写了简单的demo,不过感觉还是很有用的,可以帮助我们理解。

import sun.misc.Unsafe;

import java.lang.reflect.Field;
import java.util.concurrent.*;
import java.util.concurrent.locks.LockSupport;

public class ConcurrentTaskFetch {

    /**
     * 空闲任务索引,获取任务时,从该下标开始,往前获取。
     * 比如当前下标为10,表示tasks数组中,0-10这个区间的任务,没人领取
     */
    // 0
    private  volatile int freeTaskIndexForFetch;
	
    // 1
    private static final int TASK_COUNT_PER_FETCH = 16;
	
    // 2
    private String[] tasks = new String[128];

    public static void main(String[] args) {
        ConcurrentTaskFetch fetch = new ConcurrentTaskFetch();
        // 3
        fetch.init();

        ThreadPoolExecutor executor = new ThreadPoolExecutor(10, 10, 60, TimeUnit.SECONDS, new ArrayBlockingQueue<>(100));
        executor.prestartAllCoreThreads();

        CyclicBarrier cyclicBarrier = new CyclicBarrier(10);
		
        // 4
        for (int i = 0; i < 10; i++) {
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    try {
                        cyclicBarrier.await();
                    } catch (InterruptedException | BrokenBarrierException e) {
                        e.printStackTrace();
                    }
					
                    // 5
                    FetchedTaskInfo fetchedTaskInfo = fetch.fetchTask();
                    if (fetchedTaskInfo != null) {
                        System.out.println("thread:" + Thread.currentThread().getName() + ",get task success:" + fetchedTaskInfo);
                        try {
                            TimeUnit.SECONDS.sleep(3);
                        } catch (InterruptedException e) {
                            e.printStackTrace();
                        }

                        System.out.println("thread:" + Thread.currentThread().getName()  +  ", process task finished");
                    }
                }
            });
        }


        LockSupport.park();
    }

    public void init() {
        for (int i = 0; i < 128; i++) {
            tasks[i] = "task" + i;
        }
        freeTaskIndexForFetch = tasks.length;
    }

	// 6
    public FetchedTaskInfo fetchTask() {
        System.out.println("Thread start fetch task:"+Thread.currentThread().getName()+",time: "+System.currentTimeMillis());

        while (true){
			// 6.1
            if (freeTaskIndexForFetch == 0) {
                System.out.println("thread:" + Thread.currentThread().getName() + ",get task failed,there is no task");
                return null;
            }

            /**
             * 6.2 获取当前任务的集合的上界
             */
            int subTaskListEndIndex = this.freeTaskIndexForFetch;

            /**
             * 6.3 获取当前任务的集合的下界
             */
            int subTaskListStartIndex = subTaskListEndIndex > TASK_COUNT_PER_FETCH ?
                    subTaskListEndIndex - TASK_COUNT_PER_FETCH : 0;

            /**
             * 6.4
             * 现在,我们拿到了集合的上下界,即[subTaskListStartIndex,subTaskListEndIndex)
             * 该区间为前开后闭,所以,实际的区间为:
             * [subTaskListStartIndex,subTaskListEndIndex - 1]
             */

            /**
             * 6.5 使用cas,尝试更新{@link freeTaskIndexForFetch} 为 subTaskListStartIndex
             */
            if (U.compareAndSwapInt(this, FREE_TASK_INDEX_FOR_FETCH, subTaskListEndIndex, subTaskListStartIndex)) {
                // 6.6 
                FetchedTaskInfo info = new FetchedTaskInfo();
                info.setStartIndex(subTaskListStartIndex);
                info.setEndIndex(subTaskListEndIndex - 1);


                return info;
            }
        }

    }



    // Unsafe mechanics
    private static final sun.misc.Unsafe U;

    private static final long FREE_TASK_INDEX_FOR_FETCH;

    static {
        try {
//            U = sun.misc.Unsafe.getUnsafe();
            Field f = Unsafe.class.getDeclaredField("theUnsafe");
            f.setAccessible(true);
            U = (Unsafe) f.get(null);
            Class<?> k = ConcurrentTaskFetch.class;
            FREE_TASK_INDEX_FOR_FETCH = U.objectFieldOffset
                    (k.getDeclaredField("freeTaskIndexForFetch"));
        } catch (Exception e) {
            throw new Error(e);
        }
    }


    static class FetchedTaskInfo{
        int startIndex;
        int endIndex;

        public int getStartIndex() {
            return startIndex;
        }

        public void setStartIndex(int startIndex) {
            this.startIndex = startIndex;
        }

        public int getEndIndex() {
            return endIndex;
        }

        public void setEndIndex(int endIndex) {
            this.endIndex = endIndex;
        }

        @Override
        public String toString() {
            return "FetchedTaskInfo{" +
                    "startIndex=" + startIndex +
                    ", endIndex=" + endIndex +
                    '}';
        }
    }
}

  • 0处,定义了一个field,类似于前面的transferIndex

        /**
         * 空闲任务索引,获取任务时,从该下标开始,往前获取。
         * 比如当前下标为10,表示tasks数组中,0-10这个区间的任务,没人领取
         */
        // 0
        private  volatile int freeTaskIndexForFetch;
    
  • 1,定义了每次取多少个任务,这里也是16个

    private static final int TASK_COUNT_PER_FETCH = 16;
    
  • 2,定义任务列表,共128个任务

  • 3,main函数中,进行任务初始化

    public void init() {
        for (int i = 0; i < 128; i++) {
            tasks[i] = "task" + i;
        }
        freeTaskIndexForFetch = tasks.length;
    }
    

    主要初始化任务列表,其次,将freeTaskIndexForFetch 赋值为128,后续取任务,从这个下标开始

  • 4处,启动10个线程,每个线程去执行取任务,按理说,我们128个任务,每个线程取16个,只能有8个线程取到任务,2个线程取不到

  • 5处,线程逻辑里,去获取任务

  • 6处,获取任务的方法定义

  • 6.1 ,如果可获取的任务索引为0了,说明没任务了,直接返回

  • 6.2,获取当前任务的集合的上界

  • 6.3,获取当前任务的集合的下界,减去16就行了

  • 6.4,拿到了集合的上下界,即[subTaskListStartIndex,subTaskListEndIndex)

  • 6.5, 使用cas,更新field为:6.4中的任务下界。

执行效果演示:

可以看到,8个线程取到任务,2个线程没取到。

该思想在内存分配时的应用

其实jvm内存分配时,也是类似的思路,比如,设置堆内存为200m,那这200m是启动时立马从操作系统分配了的。

接下来,就是每次new对象的时候,去这个大内存里,找个小空间,这个过程,也是需要cas去竞争的,比如肯定也有个全局的字段,来表示当前可用内存的索引,比如该索引为100,表示,第100个字节后的空间是可以用的,那我要new个对象,这个对象有3个字段,需要大概30个字节,那我是不是需要把这个索引更新为130。

这中间是多线程的,所以也是要cas操作。

道理都是类似的。

总结

时间仓促,有问题在所难免,欢迎及时指出或加群讨论。

posted @ 2020-06-07 22:45  三国梦回  阅读(1209)  评论(0编辑  收藏  举报