手写一个AQS实现

1.背景

1.AQS简介
AQS全称为AbstractQueuedSynchronizer(抽象队列同步器)。AQS是一个用来构建锁和其他同步组件的基础框架,
使用AQS可以简单且高效地构造出应用广泛的同步器,例如ReentrantLock、Semaphore、ReentrantReadWriteLock和FutureTask等等。

2.原理
AQS核心思想是,如果被请求的共享资源空闲,则将当前请求资源的线程设置为有效的工作线程,并且将共享资源设置为锁定状态。
如果被请求的共享资源被占用,那么就需要一套线程阻塞等待以及被唤醒时锁分配的机制,这个机制AQS是用CLH队列锁实现的,
即将暂时获取不到锁的线程加入到队列中。

3.CLH队列
CLH(Craig,Landin,and Hagersten)队列是一个虚拟的双向队列(虚拟的双向队列即不存在队列实例,仅存在结点之间的关联关系)。
AQS是将每条请求共享资源的线程封装成一个CLH锁队列的一个结点(Node)来实现锁的分配。

通俗的说就是:当共享资源空闲时,线程直接执行,当共享资源被占用时线程进入队列等待;

图解如下:

2.代码-基础准备

2.1.Node节点对象创建

在MyReentrantLock对象内建立一个Node节点对象,后面作为双向链表的节点;

public class MyReentrantLock {
    /**
     * 双向链表队列节点
     */
    class Node {
        /**
         * -1表示下一个获取资源的线程
         */
        static final int SIGNAL = -1;
        /**
         * 等待获取资源的状态
         */
        volatile int waitStatus;
        /**
         * 前一个节点
         */
        volatile Node prev;
        /**
         * 下一个节点
         */
        volatile Node next;
        /**
         * 节点对应的线程
         */
        volatile Thread thread;

        /**
         * 默认构造方法
         */
        Node() {
        }

        /**
         * 传入线程的构造方法
         *
         * @param thread
         */
        Node(Thread thread) {
            this.thread = thread;
        }

        /**
         * 获取当前节点的前一个节点
         *
         * @return
         * @throws NullPointerException
         */
        public Node predecessor() throws NullPointerException {
            Node p = prev;
            if (p == null)
                throw new NullPointerException();
            else
                return p;
        }
    }
}
View Code

2.2.MyReentrantLock对象的属性设计

public class MyReentrantLock {
    /**
     * 资源资源是否可用
     * 0-资源可用
     * 1-资源不可用,已被其他线程占用
     * 默认为资源可用0
     */
    private int state;
    /**
     * 双向链表-头结点
     */
    private Node head;
    /**
     * 双向链表-尾节点
     */
    private Node tail;
    /**
     * 已经获取到锁的线程
     */
    private Thread exclusiveOwnerThread;
    // 注意get set 方法自动生成,这里不在博客中显示出来
}
View Code

2.3.unsafe对象与属性设计

1.获取unsafe实例对象

public class MyUnsafeAccessor {
    static Unsafe unsafe;
    static {
        // 通过反射获取
        try {
            //  Unsafe 对象中的 private static final Unsafe theUnsafe;
            Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe");
            theUnsafe.setAccessible(true);
            unsafe = (Unsafe) theUnsafe.get(null);
        } catch (NoSuchFieldException e) {
            e.printStackTrace();
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
        // 参看 AtomicInteger对象的获取方式,这样可以吗?不行的,为什么?  @CallerSensitive ,大家可以查阅一下这注解就明白了
        //  unsafe = Unsafe.getUnsafe();
    }
    public static Unsafe getUnsafe() {
        return unsafe;
    }
}
View Code

2.用于cas的操作的属性设计

 /**
     * unsafe执行cas对象
     */
    private static final Unsafe unsafe = MyUnsafeAccessor.getUnsafe();
    private static final long stateOffset;
    private static final long headOffset;
    private static final long tailOffset;
    private static final long waitStatusOffset;

    /**
     * 对象启动的时候就加载
     */
    static {
        try {
            stateOffset = unsafe.objectFieldOffset
                    (MyReentrantLock.class.getDeclaredField("state"));
            headOffset = unsafe.objectFieldOffset
                    (MyReentrantLock.class.getDeclaredField("head"));
            tailOffset = unsafe.objectFieldOffset
                    (MyReentrantLock.class.getDeclaredField("tail"));
            waitStatusOffset = unsafe.objectFieldOffset
                    (MyReentrantLock.Node.class.getDeclaredField("waitStatus"));
        } catch (Exception ex) {
            throw new Error(ex);
        }
    }
View Code

2.4cas实现原子更新方法

这里把后面要用到的资源状态、头结点、尾节点、节点等待状态这4个属性的cas原子更新方法写出来

便于后面直接使用

1.cas修改状态值

    /**
     * cas修改状态值
     *
     * @param expect
     * @param update
     * @return
     */
    public boolean compareAndSetState(int expect, int update) {
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }

2.cas修改尾节点

    /**
     * cas的方式设置尾节点
     *
     * @param pred
     * @param node
     * @return
     */
    public boolean compareAndSetTail(Node pred, Node node) {
        return unsafe.compareAndSwapObject(this, tailOffset, pred, node);
    }

3.cas修改节点等待状态

    /**
     * cas修改节点等待状态
     *
     * @param node
     * @param expect
     * @param update
     * @return
     */
    private static final boolean compareAndSetWaitStatus(Node node, int expect, int update) {
        return unsafe.compareAndSwapInt(node, waitStatusOffset, expect, update);
    }

4.cas设置头结点

    /**
     * cas设置头结点
     *
     * @param update
     * @return
     */
    private final boolean compareAndSetHead(Node update) {
        return unsafe.compareAndSwapObject(this, headOffset, null, update);
    }

3.代码-核心逻辑

1.加锁

    /**
     * 加锁
     */
    public void lock() {
        // compareAndSetState(0, 1) 获取锁
        if (compareAndSetState(0, 1)) {
            // 获取锁成功,设置当前线程为资源拥有者,思考为什么这里不使用cas的方式设置线程拥有者?
            setExclusiveOwnerThread(Thread.currentThread());
        } else {
            // 获取锁失败,入队列,排队,等待,重新获取锁
            acquire(1);
        }
    }

2.获取锁

    /**
     * 获取锁
     *
     * @param arg
     */
    public void acquire(Integer arg) {
        // 将线程加入到队列
        Node node = addWaiter();
        // 队列中排队等待获取锁
        acquireQueued(node, arg);
    }

3.将线程加入到队列

    /**
     * 将没有获取到锁的线程添加到队列
     */
    public Node addWaiter() {
        // 构建一个线程的节点
        Node node = new Node(Thread.currentThread());
        Node pred = tail;
        if (pred != null) {
            node.prev = pred;
            if (compareAndSetTail(pred, node)) {
                pred.next = node;
                return node;
            }
        }
        // 初始化队列,并把新节点作为尾节点
        enq(node);
        return node;
    }

4.初始化队列,并把新节点作为尾节点

    /**
     * 初始化队列,并把新节点作为尾节点
     *
     * @param node
     * @return
     */
    private Node enq(Node node) {
        for (; ; ) {
            Node t = tail;
            if (t == null) {
                // 如果尾节点为空初始化头节点
                if (compareAndSetHead(new Node()))
                    tail = head;
            } else {
                // 构成双向链表,把节点作为尾节点
                node.prev = t;
                if (compareAndSetTail(t, node)) {
                    t.next = node;
                    return t;
                }
            }
        }
    }

5.尝试让队列中的节点获取锁

 /**
     * acquireQueued()用于队列中的线程自旋地以独占且不可中断的方式获取同步状态(acquire),直到拿到锁之后再返回。
     * 该方法的实现分成两部分:如果当前节点已经成为头结点,尝试获取锁(tryAcquire)成功,然后返回;
     * 否则检查当前节点是否应该被park,
     * 然后将该线程park并且检查当前线程是否被可以被中断。
     */
    public boolean acquireQueued(Node node, Integer arg) {
        boolean interrupted = false;
        for (; ; ) {
            Node p = node.predecessor(); // node 节点的前一个节点
            // p == head 表示出节点node外,没有等待的节点
            // tryAcquire(arg) 当前线程  尝试获取锁
            if (p == head && tryAcquire(arg)) {
                // 将刚才已经获取了锁的线程设置为头结点
                setHead(node);
                p.next = null;
                return interrupted;
            }
            // shouldParkAfterFailedAcquire(p, node)检查当前节点是否应该被park
            // parkAndCheckInterrupt()
            if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt()) {
                interrupted = true;
            }
        }
    }

6.尝试获取锁

   /**
     * 尝试获取锁
     *
     * @param arg
     * @return
     */
    public boolean tryAcquire(Integer arg) {
        return nonfairTryAcquire(arg);
    }

    /**
     * acquires  传入的值为1
     * 执行不公平的tryLock。tryAcquire在中实现
     * 子类,但这两个类都需要trylock方法的非空尝试。
     */
    public boolean nonfairTryAcquire(int acquires) {
        final Thread current = Thread.currentThread();
        // 获取资源状态  0-表示资源可用 , 1-表示资源不可用
        int c = getState();
        if (c == 0) {
            // 资源可用,采用cas的方式将0改为1
            if (compareAndSetState(0, acquires)) {
                // 设置当前拥有独占访问权限的线程
                setExclusiveOwnerThread(current);
                // 重试成功
                return true;
            }
        }
        // 检查当前线程是否是 已经拥有锁的线程
        else if (current == getExclusiveOwnerThread()) {
            int nextc = c + acquires;
            if (nextc < 0) // overflow
                throw new Error("Maximum lock count exceeded");
            // 可重入锁,并记录次数 ,如果是第二次进入 state=2
            setState(nextc);
            return true;
        }
        // 尝试获取锁失败
        return false;
    }

7.检查节点是否应该被阻塞

   /**
     * 检查并更新无法获取的节点的状态。
     * 如果线程应该阻塞,则返回true。这是所有采集回路中的主要信号控制。要求pred==node.prev。
     * Node.SIGNAL=-1, waitStatus值,用于指示后续线程需要取消连接
     */
    private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
        int ws = pred.waitStatus;
        if (ws == Node.SIGNAL)
            return true;
        if (ws > 0) {
            do {
                // 当前节点的前一个节点为:前节点的前节点
                // node.prev = pred = pred.prev;
                pred = pred.prev;
                node.prev = pred;
            } while (pred.waitStatus > 0);
            pred.next = node;
        } else {
            compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
        }
        return false;
    }

8.当前线程进入阻塞

    /**
     * 该方法让线程去休息,真正进入等待状态。
     * park()会让当前线程进入waiting状态。
     * 在此状态下,有两种途径可以唤醒该线程:1)被unpark();2)被interrupt()。
     * 需要注意的是,Thread.interrupted()会清除当前线程的中断标记位。
     */
    private final boolean parkAndCheckInterrupt() {
        LockSupport.park(this);
        return Thread.interrupted();
    }

 到此处,加锁的核心代码逻辑已完成!

4.释放锁-代码

1.释放锁

    /**
     * 释放锁
     */
    public boolean unlock() {
        int arg = 1;
        if (tryRelease(arg)) {
            Node h = head;
            if (h != null && h.waitStatus != 0)
                unparkSuccessor(h);
            return true;
        }
        return false;
    }

2.已执行完成的线程释放资源 

    /**
     * 已执行完成的线程,释放资源
     *
     * @param releases
     * @return
     */
    public boolean tryRelease(int releases) {
        int c = getState() - releases;
        if (Thread.currentThread() != getExclusiveOwnerThread())
            throw new IllegalMonitorStateException();
        boolean free = false;
        if (c == 0) {
            free = true;
            setExclusiveOwnerThread(null);
        }
        setState(c);
        return free;
    }

3.唤醒下一个等待执行的线程

 /**
     * 唤醒下一个应该执行的线程
     *
     * @param node
     */
    private void unparkSuccessor(Node node) {
        int ws = node.waitStatus;
        if (ws < 0) {
            compareAndSetWaitStatus(node, ws, 0);
        }
        Node s = node.next;
        if (s == null || s.waitStatus > 0) {
            s = null;
            for (Node t = tail; t != null && t != node; t = t.prev)
                if (t.waitStatus <= 0)
                    s = t;
        }
        if (s != null)
            LockSupport.unpark(s.thread);
    }

5.测试代码

public class Test01 {
    private int num = 0;

    /**
     * 测试自己实现的aqs锁
     *
     * @throws InterruptedException
     */
    @Test
    public void testMyLock() throws InterruptedException {
        MyReentrantLock lock = new MyReentrantLock();
        List<Thread> list = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            Thread thread = new Thread(() -> {
                lock.lock(); // 加锁
                addFive();
                lock.unlock(); // 释放锁
            });
            list.add(thread);
        }
        // 启动线程
        int n = 1;
        for (Thread thread : list) {
            thread.setName("t" + (n++));
            thread.start();
        }
        // 等待线下执行完成
        for (Thread thread : list) {
            thread.join();
            System.out.println(thread.getName() + "已完成");
        }
        System.out.println("num=" + num);
    }

    /**
     * 这是一个将num加5的方法
     */
    public void addFive() {
        try {
            num = num + 1;
            Thread.sleep(5);
            num = num + 1;
            Thread.sleep(5);
            num = num + 1;
            Thread.sleep(5);
            num = num + 1;
            Thread.sleep(5);
            num = num + 1;
            Thread.sleep(5);
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}
View Code

 6.完整的代码

package com.ldp.aqs.v4;

import com.ldp.aqs.MyUnsafeAccessor;
import sun.misc.Unsafe;

import java.util.concurrent.locks.LockSupport;

/**
 * @create 09/11 11:05
 * @description <p>
 * 自定义AQS实现思路:
 * 如果获取资源失败,将线程放入队列里面,并让线程处于阻塞状态,当资源释放后在唤醒阻塞的线程
 * 使用到的技术点
 * 1.双向链表作为队列;
 * 2.unsafe实现cas原子操作;
 * 3.park,unpark实现线程阻塞与唤醒
 * 4.interrupt、isInterrupted、interrupted 实现中断
 * </p>
 */
public class MyReentrantLock {
    /**
     * 资源资源是否可用
     * 0-资源可用
     * 1-资源不可用,已被其他线程占用
     * 默认为资源可用0
     */
    private int state;
    /**
     * 双向链表-头结点
     */
    private Node head;
    /**
     * 双向链表-尾节点
     */
    private Node tail;
    /**
     * 已经获取到锁的线程
     */
    private Thread exclusiveOwnerThread;

    /**
     * unsafe执行cas对象
     */
    private static final Unsafe unsafe = MyUnsafeAccessor.getUnsafe();
    private static final long stateOffset;
    private static final long headOffset;
    private static final long tailOffset;
    private static final long waitStatusOffset;

    /**
     * 对象启动的时候就加载
     */
    static {
        try {
            stateOffset = unsafe.objectFieldOffset
                    (MyReentrantLock.class.getDeclaredField("state"));
            headOffset = unsafe.objectFieldOffset
                    (MyReentrantLock.class.getDeclaredField("head"));
            tailOffset = unsafe.objectFieldOffset
                    (MyReentrantLock.class.getDeclaredField("tail"));
            waitStatusOffset = unsafe.objectFieldOffset
                    (MyReentrantLock.Node.class.getDeclaredField("waitStatus"));
        } catch (Exception ex) {
            throw new Error(ex);
        }
    }

    /**
     * 加锁
     */
    public void lock() {
        // compareAndSetState(0, 1) 获取锁
        if (compareAndSetState(0, 1)) {
            // 获取锁成功,设置当前线程为资源拥有者,思考为什么这里不使用cas的方式设置线程拥有者?
            setExclusiveOwnerThread(Thread.currentThread());
        } else {
            // 获取锁失败,入队列,排队,等待,重新获取锁
            acquire(1);
        }
    }

    /**
     * 获取锁
     *
     * @param arg
     */
    public void acquire(Integer arg) {
        // 将线程加入到队列
        Node node = addWaiter();
        // 队列中排队等待获取锁
        acquireQueued(node, arg);
    }

    /**
     * acquireQueued()用于队列中的线程自旋地以独占且不可中断的方式获取同步状态(acquire),直到拿到锁之后再返回。
     * 该方法的实现分成两部分:如果当前节点已经成为头结点,尝试获取锁(tryAcquire)成功,然后返回;
     * 否则检查当前节点是否应该被park,
     * 然后将该线程park并且检查当前线程是否被可以被中断。
     */
    public boolean acquireQueued(Node node, Integer arg) {
        boolean interrupted = false;
        for (; ; ) {
            Node p = node.predecessor(); // node 节点的前一个节点
            // p == head 表示出节点node外,没有等待的节点
            // tryAcquire(arg) 当前线程  尝试获取锁
            if (p == head && tryAcquire(arg)) {
                // 将刚才已经获取了锁的线程设置为头结点
                setHead(node);
                p.next = null;
                return interrupted;
            }
            // shouldParkAfterFailedAcquire(p, node)检查当前节点是否应该被park
            // parkAndCheckInterrupt()
            if (shouldParkAfterFailedAcquire(p, node) &&
                    parkAndCheckInterrupt()) {
                interrupted = true;
            }
        }
    }

    /**
     * 该方法让线程去休息,真正进入等待状态。
     * park()会让当前线程进入waiting状态。
     * 在此状态下,有两种途径可以唤醒该线程:1)被unpark();2)被interrupt()。
     * 需要注意的是,Thread.interrupted()会清除当前线程的中断标记位。
     */
    private final boolean parkAndCheckInterrupt() {
        LockSupport.park(this);
        return Thread.interrupted();
    }

    /**
     * 检查并更新无法获取的节点的状态。
     * 如果线程应该阻塞,则返回true。这是所有采集回路中的主要信号控制。要求pred==node.prev。
     * Node.SIGNAL=-1, waitStatus值,用于指示后续线程需要取消连接
     */
    private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) {
        int ws = pred.waitStatus;
        if (ws == Node.SIGNAL)
            return true;
        if (ws > 0) {
            do {
                // 当前节点的前一个节点为:前节点的前节点
                // node.prev = pred = pred.prev;
                pred = pred.prev;
                node.prev = pred;
            } while (pred.waitStatus > 0);
            pred.next = node;
        } else {
            compareAndSetWaitStatus(pred, ws, Node.SIGNAL);
        }
        return false;
    }

    /**
     * 尝试获取锁
     *
     * @param arg
     * @return
     */
    public boolean tryAcquire(Integer arg) {
        return nonfairTryAcquire(arg);
    }

    /**
     * acquires  传入的值为1
     * 执行不公平的tryLock。tryAcquire在中实现
     * 子类,但这两个类都需要trylock方法的非空尝试。
     */
    public boolean nonfairTryAcquire(int acquires) {
        final Thread current = Thread.currentThread();
        // 获取资源状态  0-表示资源可用 , 1-表示资源不可用
        int c = getState();
        if (c == 0) {
            // 资源可用,采用cas的方式将0改为1
            if (compareAndSetState(0, acquires)) {
                // 设置当前拥有独占访问权限的线程
                setExclusiveOwnerThread(current);
                // 重试成功
                return true;
            }
        }
        // 检查当前线程是否是 已经拥有锁的线程
        else if (current == getExclusiveOwnerThread()) {
            int nextc = c + acquires;
            if (nextc < 0) // overflow
                throw new Error("Maximum lock count exceeded");
            // 可重入锁,并记录次数 ,如果是第二次进入 state=2
            setState(nextc);
            return true;
        }
        // 尝试获取锁失败
        return false;
    }

    /**
     * 将没有获取到锁的线程添加到队列
     */
    public Node addWaiter() {
        // 构建一个线程的节点
        Node node = new Node(Thread.currentThread());
        Node pred = tail;
        if (pred != null) {
            node.prev = pred;
            if (compareAndSetTail(pred, node)) {
                pred.next = node;
                return node;
            }
        }
        // 初始化队列,并把新节点作为尾节点
        enq(node);
        return node;
    }

    /**
     * 初始化队列,并把新节点作为尾节点
     *
     * @param node
     * @return
     */
    private Node enq(Node node) {
        for (; ; ) {
            Node t = tail;
            if (t == null) {
                // 如果尾节点为空初始化头节点
                if (compareAndSetHead(new Node()))
                    tail = head;
            } else {
                // 构成双向链表,把节点作为尾节点
                node.prev = t;
                if (compareAndSetTail(t, node)) {
                    t.next = node;
                    return t;
                }
            }
        }
    }

    /**
     * 释放锁
     */
    public boolean unlock() {
        int arg = 1;
        if (tryRelease(arg)) {
            Node h = head;
            if (h != null && h.waitStatus != 0)
                unparkSuccessor(h);
            return true;
        }
        return false;
    }

    /**
     * 唤醒下一个应该执行的线程
     *
     * @param node
     */
    private void unparkSuccessor(Node node) {
        int ws = node.waitStatus;
        if (ws < 0) {
            compareAndSetWaitStatus(node, ws, 0);
        }
        Node s = node.next;
        if (s == null || s.waitStatus > 0) {
            s = null;
            for (Node t = tail; t != null && t != node; t = t.prev)
                if (t.waitStatus <= 0)
                    s = t;
        }
        if (s != null)
            LockSupport.unpark(s.thread);
    }

    /**
     * 已执行完成的线程,释放资源
     *
     * @param releases
     * @return
     */
    public boolean tryRelease(int releases) {
        int c = getState() - releases;
        if (Thread.currentThread() != getExclusiveOwnerThread())
            throw new IllegalMonitorStateException();
        boolean free = false;
        if (c == 0) {
            free = true;
            setExclusiveOwnerThread(null);
        }
        setState(c);
        return free;
    }

    /**
     * cas修改状态值
     *
     * @param expect
     * @param update
     * @return
     */
    public boolean compareAndSetState(int expect, int update) {
        // 注意这个你更新的是int基本类型只能使用:compareAndSwapInt 不能使用compareAndSwapObject
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }

    /**
     * cas的方式设置尾节点
     *
     * @param pred
     * @param node
     * @return
     */
    public boolean compareAndSetTail(Node pred, Node node) {
        return unsafe.compareAndSwapObject(this, tailOffset, pred, node);
    }

    /**
     * cas修改节点等待状态
     *
     * @param node
     * @param expect
     * @param update
     * @return
     */
    private static final boolean compareAndSetWaitStatus(Node node, int expect, int update) {
        return unsafe.compareAndSwapInt(node, waitStatusOffset, expect, update);
    }

    /**
     * cas设置头结点
     *
     * @param update
     * @return
     */
    private final boolean compareAndSetHead(Node update) {
        return unsafe.compareAndSwapObject(this, headOffset, null, update);
    }

    public int getState() {
        return state;
    }

    public void setState(int state) {
        this.state = state;
    }

    public Node getHead() {
        return head;
    }

    public void setHead(Node head) {
        this.head = head;
    }

    public Node getTail() {
        return tail;
    }

    public void setTail(Node tail) {
        this.tail = tail;
    }

    public Thread getExclusiveOwnerThread() {
        return exclusiveOwnerThread;
    }

    public void setExclusiveOwnerThread(Thread exclusiveOwnerThread) {
        this.exclusiveOwnerThread = exclusiveOwnerThread;
    }

    /**
     * 双向链表队列节点
     */
    class Node {
        /**
         * -1表示下一个获取资源的线程
         */
        static final int SIGNAL = -1;
        /**
         * 等待获取资源的状态
         */
        volatile int waitStatus;
        /**
         * 前一个节点
         */
        volatile Node prev;
        /**
         * 下一个节点
         */
        volatile Node next;
        /**
         * 节点对应的线程
         */
        volatile Thread thread;

        /**
         * 默认构造方法
         */
        Node() {
        }

        /**
         * 传入线程的构造方法
         *
         * @param thread
         */
        Node(Thread thread) {
            this.thread = thread;
        }

        /**
         * 获取当前节点的前一个节点
         *
         * @return
         * @throws NullPointerException
         */
        public Node predecessor() throws NullPointerException {
            Node p = prev;
            if (p == null)
                throw new NullPointerException();
            else
                return p;
        }
    }
}
View Code

完美!

posted @ 2022-10-08 15:34  李东平|一线码农  阅读(200)  评论(0编辑  收藏  举报