手写一个AQS实现
1.背景
1.AQS简介
AQS全称为AbstractQueuedSynchronizer(抽象队列同步器)。AQS是一个用来构建锁和其他同步组件的基础框架,
使用AQS可以简单且高效地构造出应用广泛的同步器,例如ReentrantLock、Semaphore、ReentrantReadWriteLock和FutureTask等等。
2.原理
AQS核心思想是,如果被请求的共享资源空闲,则将当前请求资源的线程设置为有效的工作线程,并且将共享资源设置为锁定状态。
如果被请求的共享资源被占用,那么就需要一套线程阻塞等待以及被唤醒时锁分配的机制,这个机制AQS是用CLH队列锁实现的,
即将暂时获取不到锁的线程加入到队列中。
3.CLH队列
CLH(Craig,Landin,and Hagersten)队列是一个虚拟的双向队列(虚拟的双向队列即不存在队列实例,仅存在结点之间的关联关系)。
AQS是将每条请求共享资源的线程封装成一个CLH锁队列的一个结点(Node)来实现锁的分配。
通俗的说就是:当共享资源空闲时,线程直接执行,当共享资源被占用时线程进入队列等待;
图解如下:
2.代码-基础准备
2.1.Node节点对象创建
在MyReentrantLock对象内建立一个Node节点对象,后面作为双向链表的节点;
public class MyReentrantLock { /** * 双向链表队列节点 */ class Node { /** * -1表示下一个获取资源的线程 */ static final int SIGNAL = -1; /** * 等待获取资源的状态 */ volatile int waitStatus; /** * 前一个节点 */ volatile Node prev; /** * 下一个节点 */ volatile Node next; /** * 节点对应的线程 */ volatile Thread thread; /** * 默认构造方法 */ Node() { } /** * 传入线程的构造方法 * * @param thread */ Node(Thread thread) { this.thread = thread; } /** * 获取当前节点的前一个节点 * * @return * @throws NullPointerException */ public Node predecessor() throws NullPointerException { Node p = prev; if (p == null) throw new NullPointerException(); else return p; } } }
2.2.MyReentrantLock对象的属性设计
public class MyReentrantLock { /** * 资源资源是否可用 * 0-资源可用 * 1-资源不可用,已被其他线程占用 * 默认为资源可用0 */ private int state; /** * 双向链表-头结点 */ private Node head; /** * 双向链表-尾节点 */ private Node tail; /** * 已经获取到锁的线程 */ private Thread exclusiveOwnerThread; // 注意get set 方法自动生成,这里不在博客中显示出来 }
2.3.unsafe对象与属性设计
1.获取unsafe实例对象
public class MyUnsafeAccessor { static Unsafe unsafe; static { // 通过反射获取 try { // Unsafe 对象中的 private static final Unsafe theUnsafe; Field theUnsafe = Unsafe.class.getDeclaredField("theUnsafe"); theUnsafe.setAccessible(true); unsafe = (Unsafe) theUnsafe.get(null); } catch (NoSuchFieldException e) { e.printStackTrace(); } catch (IllegalAccessException e) { e.printStackTrace(); } // 参看 AtomicInteger对象的获取方式,这样可以吗?不行的,为什么? @CallerSensitive ,大家可以查阅一下这注解就明白了 // unsafe = Unsafe.getUnsafe(); } public static Unsafe getUnsafe() { return unsafe; } }
2.用于cas的操作的属性设计
/** * unsafe执行cas对象 */ private static final Unsafe unsafe = MyUnsafeAccessor.getUnsafe(); private static final long stateOffset; private static final long headOffset; private static final long tailOffset; private static final long waitStatusOffset; /** * 对象启动的时候就加载 */ static { try { stateOffset = unsafe.objectFieldOffset (MyReentrantLock.class.getDeclaredField("state")); headOffset = unsafe.objectFieldOffset (MyReentrantLock.class.getDeclaredField("head")); tailOffset = unsafe.objectFieldOffset (MyReentrantLock.class.getDeclaredField("tail")); waitStatusOffset = unsafe.objectFieldOffset (MyReentrantLock.Node.class.getDeclaredField("waitStatus")); } catch (Exception ex) { throw new Error(ex); } }
2.4cas实现原子更新方法
这里把后面要用到的资源状态、头结点、尾节点、节点等待状态这4个属性的cas原子更新方法写出来
便于后面直接使用
1.cas修改状态值
/** * cas修改状态值 * * @param expect * @param update * @return */ public boolean compareAndSetState(int expect, int update) { return unsafe.compareAndSwapInt(this, stateOffset, expect, update); }
2.cas修改尾节点
/** * cas的方式设置尾节点 * * @param pred * @param node * @return */ public boolean compareAndSetTail(Node pred, Node node) { return unsafe.compareAndSwapObject(this, tailOffset, pred, node); }
3.cas修改节点等待状态
/** * cas修改节点等待状态 * * @param node * @param expect * @param update * @return */ private static final boolean compareAndSetWaitStatus(Node node, int expect, int update) { return unsafe.compareAndSwapInt(node, waitStatusOffset, expect, update); }
4.cas设置头结点
/** * cas设置头结点 * * @param update * @return */ private final boolean compareAndSetHead(Node update) { return unsafe.compareAndSwapObject(this, headOffset, null, update); }
3.代码-核心逻辑
1.加锁
/** * 加锁 */ public void lock() { // compareAndSetState(0, 1) 获取锁 if (compareAndSetState(0, 1)) { // 获取锁成功,设置当前线程为资源拥有者,思考为什么这里不使用cas的方式设置线程拥有者? setExclusiveOwnerThread(Thread.currentThread()); } else { // 获取锁失败,入队列,排队,等待,重新获取锁 acquire(1); } }
2.获取锁
/** * 获取锁 * * @param arg */ public void acquire(Integer arg) { // 将线程加入到队列 Node node = addWaiter(); // 队列中排队等待获取锁 acquireQueued(node, arg); }
3.将线程加入到队列
/** * 将没有获取到锁的线程添加到队列 */ public Node addWaiter() { // 构建一个线程的节点 Node node = new Node(Thread.currentThread()); Node pred = tail; if (pred != null) { node.prev = pred; if (compareAndSetTail(pred, node)) { pred.next = node; return node; } } // 初始化队列,并把新节点作为尾节点 enq(node); return node; }
4.初始化队列,并把新节点作为尾节点
/** * 初始化队列,并把新节点作为尾节点 * * @param node * @return */ private Node enq(Node node) { for (; ; ) { Node t = tail; if (t == null) { // 如果尾节点为空初始化头节点 if (compareAndSetHead(new Node())) tail = head; } else { // 构成双向链表,把节点作为尾节点 node.prev = t; if (compareAndSetTail(t, node)) { t.next = node; return t; } } } }
5.尝试让队列中的节点获取锁
/** * acquireQueued()用于队列中的线程自旋地以独占且不可中断的方式获取同步状态(acquire),直到拿到锁之后再返回。 * 该方法的实现分成两部分:如果当前节点已经成为头结点,尝试获取锁(tryAcquire)成功,然后返回; * 否则检查当前节点是否应该被park, * 然后将该线程park并且检查当前线程是否被可以被中断。 */ public boolean acquireQueued(Node node, Integer arg) { boolean interrupted = false; for (; ; ) { Node p = node.predecessor(); // node 节点的前一个节点 // p == head 表示出节点node外,没有等待的节点 // tryAcquire(arg) 当前线程 尝试获取锁 if (p == head && tryAcquire(arg)) { // 将刚才已经获取了锁的线程设置为头结点 setHead(node); p.next = null; return interrupted; } // shouldParkAfterFailedAcquire(p, node)检查当前节点是否应该被park // parkAndCheckInterrupt() if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) { interrupted = true; } } }
6.尝试获取锁
/** * 尝试获取锁 * * @param arg * @return */ public boolean tryAcquire(Integer arg) { return nonfairTryAcquire(arg); } /** * acquires 传入的值为1 * 执行不公平的tryLock。tryAcquire在中实现 * 子类,但这两个类都需要trylock方法的非空尝试。 */ public boolean nonfairTryAcquire(int acquires) { final Thread current = Thread.currentThread(); // 获取资源状态 0-表示资源可用 , 1-表示资源不可用 int c = getState(); if (c == 0) { // 资源可用,采用cas的方式将0改为1 if (compareAndSetState(0, acquires)) { // 设置当前拥有独占访问权限的线程 setExclusiveOwnerThread(current); // 重试成功 return true; } } // 检查当前线程是否是 已经拥有锁的线程 else if (current == getExclusiveOwnerThread()) { int nextc = c + acquires; if (nextc < 0) // overflow throw new Error("Maximum lock count exceeded"); // 可重入锁,并记录次数 ,如果是第二次进入 state=2 setState(nextc); return true; } // 尝试获取锁失败 return false; }
7.检查节点是否应该被阻塞
/** * 检查并更新无法获取的节点的状态。 * 如果线程应该阻塞,则返回true。这是所有采集回路中的主要信号控制。要求pred==node.prev。 * Node.SIGNAL=-1, waitStatus值,用于指示后续线程需要取消连接 */ private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) { int ws = pred.waitStatus; if (ws == Node.SIGNAL) return true; if (ws > 0) { do { // 当前节点的前一个节点为:前节点的前节点 // node.prev = pred = pred.prev; pred = pred.prev; node.prev = pred; } while (pred.waitStatus > 0); pred.next = node; } else { compareAndSetWaitStatus(pred, ws, Node.SIGNAL); } return false; }
8.当前线程进入阻塞
/** * 该方法让线程去休息,真正进入等待状态。 * park()会让当前线程进入waiting状态。 * 在此状态下,有两种途径可以唤醒该线程:1)被unpark();2)被interrupt()。 * 需要注意的是,Thread.interrupted()会清除当前线程的中断标记位。 */ private final boolean parkAndCheckInterrupt() { LockSupport.park(this); return Thread.interrupted(); }
到此处,加锁的核心代码逻辑已完成!
4.释放锁-代码
1.释放锁
/** * 释放锁 */ public boolean unlock() { int arg = 1; if (tryRelease(arg)) { Node h = head; if (h != null && h.waitStatus != 0) unparkSuccessor(h); return true; } return false; }
2.已执行完成的线程释放资源
/** * 已执行完成的线程,释放资源 * * @param releases * @return */ public boolean tryRelease(int releases) { int c = getState() - releases; if (Thread.currentThread() != getExclusiveOwnerThread()) throw new IllegalMonitorStateException(); boolean free = false; if (c == 0) { free = true; setExclusiveOwnerThread(null); } setState(c); return free; }
3.唤醒下一个等待执行的线程
/** * 唤醒下一个应该执行的线程 * * @param node */ private void unparkSuccessor(Node node) { int ws = node.waitStatus; if (ws < 0) { compareAndSetWaitStatus(node, ws, 0); } Node s = node.next; if (s == null || s.waitStatus > 0) { s = null; for (Node t = tail; t != null && t != node; t = t.prev) if (t.waitStatus <= 0) s = t; } if (s != null) LockSupport.unpark(s.thread); }
5.测试代码
public class Test01 { private int num = 0; /** * 测试自己实现的aqs锁 * * @throws InterruptedException */ @Test public void testMyLock() throws InterruptedException { MyReentrantLock lock = new MyReentrantLock(); List<Thread> list = new ArrayList<>(); for (int i = 0; i < 100; i++) { Thread thread = new Thread(() -> { lock.lock(); // 加锁 addFive(); lock.unlock(); // 释放锁 }); list.add(thread); } // 启动线程 int n = 1; for (Thread thread : list) { thread.setName("t" + (n++)); thread.start(); } // 等待线下执行完成 for (Thread thread : list) { thread.join(); System.out.println(thread.getName() + "已完成"); } System.out.println("num=" + num); } /** * 这是一个将num加5的方法 */ public void addFive() { try { num = num + 1; Thread.sleep(5); num = num + 1; Thread.sleep(5); num = num + 1; Thread.sleep(5); num = num + 1; Thread.sleep(5); num = num + 1; Thread.sleep(5); } catch (InterruptedException e) { e.printStackTrace(); } } }
6.完整的代码
package com.ldp.aqs.v4; import com.ldp.aqs.MyUnsafeAccessor; import sun.misc.Unsafe; import java.util.concurrent.locks.LockSupport; /** * @create 09/11 11:05 * @description <p> * 自定义AQS实现思路: * 如果获取资源失败,将线程放入队列里面,并让线程处于阻塞状态,当资源释放后在唤醒阻塞的线程 * 使用到的技术点 * 1.双向链表作为队列; * 2.unsafe实现cas原子操作; * 3.park,unpark实现线程阻塞与唤醒 * 4.interrupt、isInterrupted、interrupted 实现中断 * </p> */ public class MyReentrantLock { /** * 资源资源是否可用 * 0-资源可用 * 1-资源不可用,已被其他线程占用 * 默认为资源可用0 */ private int state; /** * 双向链表-头结点 */ private Node head; /** * 双向链表-尾节点 */ private Node tail; /** * 已经获取到锁的线程 */ private Thread exclusiveOwnerThread; /** * unsafe执行cas对象 */ private static final Unsafe unsafe = MyUnsafeAccessor.getUnsafe(); private static final long stateOffset; private static final long headOffset; private static final long tailOffset; private static final long waitStatusOffset; /** * 对象启动的时候就加载 */ static { try { stateOffset = unsafe.objectFieldOffset (MyReentrantLock.class.getDeclaredField("state")); headOffset = unsafe.objectFieldOffset (MyReentrantLock.class.getDeclaredField("head")); tailOffset = unsafe.objectFieldOffset (MyReentrantLock.class.getDeclaredField("tail")); waitStatusOffset = unsafe.objectFieldOffset (MyReentrantLock.Node.class.getDeclaredField("waitStatus")); } catch (Exception ex) { throw new Error(ex); } } /** * 加锁 */ public void lock() { // compareAndSetState(0, 1) 获取锁 if (compareAndSetState(0, 1)) { // 获取锁成功,设置当前线程为资源拥有者,思考为什么这里不使用cas的方式设置线程拥有者? setExclusiveOwnerThread(Thread.currentThread()); } else { // 获取锁失败,入队列,排队,等待,重新获取锁 acquire(1); } } /** * 获取锁 * * @param arg */ public void acquire(Integer arg) { // 将线程加入到队列 Node node = addWaiter(); // 队列中排队等待获取锁 acquireQueued(node, arg); } /** * acquireQueued()用于队列中的线程自旋地以独占且不可中断的方式获取同步状态(acquire),直到拿到锁之后再返回。 * 该方法的实现分成两部分:如果当前节点已经成为头结点,尝试获取锁(tryAcquire)成功,然后返回; * 否则检查当前节点是否应该被park, * 然后将该线程park并且检查当前线程是否被可以被中断。 */ public boolean acquireQueued(Node node, Integer arg) { boolean interrupted = false; for (; ; ) { Node p = node.predecessor(); // node 节点的前一个节点 // p == head 表示出节点node外,没有等待的节点 // tryAcquire(arg) 当前线程 尝试获取锁 if (p == head && tryAcquire(arg)) { // 将刚才已经获取了锁的线程设置为头结点 setHead(node); p.next = null; return interrupted; } // shouldParkAfterFailedAcquire(p, node)检查当前节点是否应该被park // parkAndCheckInterrupt() if (shouldParkAfterFailedAcquire(p, node) && parkAndCheckInterrupt()) { interrupted = true; } } } /** * 该方法让线程去休息,真正进入等待状态。 * park()会让当前线程进入waiting状态。 * 在此状态下,有两种途径可以唤醒该线程:1)被unpark();2)被interrupt()。 * 需要注意的是,Thread.interrupted()会清除当前线程的中断标记位。 */ private final boolean parkAndCheckInterrupt() { LockSupport.park(this); return Thread.interrupted(); } /** * 检查并更新无法获取的节点的状态。 * 如果线程应该阻塞,则返回true。这是所有采集回路中的主要信号控制。要求pred==node.prev。 * Node.SIGNAL=-1, waitStatus值,用于指示后续线程需要取消连接 */ private static boolean shouldParkAfterFailedAcquire(Node pred, Node node) { int ws = pred.waitStatus; if (ws == Node.SIGNAL) return true; if (ws > 0) { do { // 当前节点的前一个节点为:前节点的前节点 // node.prev = pred = pred.prev; pred = pred.prev; node.prev = pred; } while (pred.waitStatus > 0); pred.next = node; } else { compareAndSetWaitStatus(pred, ws, Node.SIGNAL); } return false; } /** * 尝试获取锁 * * @param arg * @return */ public boolean tryAcquire(Integer arg) { return nonfairTryAcquire(arg); } /** * acquires 传入的值为1 * 执行不公平的tryLock。tryAcquire在中实现 * 子类,但这两个类都需要trylock方法的非空尝试。 */ public boolean nonfairTryAcquire(int acquires) { final Thread current = Thread.currentThread(); // 获取资源状态 0-表示资源可用 , 1-表示资源不可用 int c = getState(); if (c == 0) { // 资源可用,采用cas的方式将0改为1 if (compareAndSetState(0, acquires)) { // 设置当前拥有独占访问权限的线程 setExclusiveOwnerThread(current); // 重试成功 return true; } } // 检查当前线程是否是 已经拥有锁的线程 else if (current == getExclusiveOwnerThread()) { int nextc = c + acquires; if (nextc < 0) // overflow throw new Error("Maximum lock count exceeded"); // 可重入锁,并记录次数 ,如果是第二次进入 state=2 setState(nextc); return true; } // 尝试获取锁失败 return false; } /** * 将没有获取到锁的线程添加到队列 */ public Node addWaiter() { // 构建一个线程的节点 Node node = new Node(Thread.currentThread()); Node pred = tail; if (pred != null) { node.prev = pred; if (compareAndSetTail(pred, node)) { pred.next = node; return node; } } // 初始化队列,并把新节点作为尾节点 enq(node); return node; } /** * 初始化队列,并把新节点作为尾节点 * * @param node * @return */ private Node enq(Node node) { for (; ; ) { Node t = tail; if (t == null) { // 如果尾节点为空初始化头节点 if (compareAndSetHead(new Node())) tail = head; } else { // 构成双向链表,把节点作为尾节点 node.prev = t; if (compareAndSetTail(t, node)) { t.next = node; return t; } } } } /** * 释放锁 */ public boolean unlock() { int arg = 1; if (tryRelease(arg)) { Node h = head; if (h != null && h.waitStatus != 0) unparkSuccessor(h); return true; } return false; } /** * 唤醒下一个应该执行的线程 * * @param node */ private void unparkSuccessor(Node node) { int ws = node.waitStatus; if (ws < 0) { compareAndSetWaitStatus(node, ws, 0); } Node s = node.next; if (s == null || s.waitStatus > 0) { s = null; for (Node t = tail; t != null && t != node; t = t.prev) if (t.waitStatus <= 0) s = t; } if (s != null) LockSupport.unpark(s.thread); } /** * 已执行完成的线程,释放资源 * * @param releases * @return */ public boolean tryRelease(int releases) { int c = getState() - releases; if (Thread.currentThread() != getExclusiveOwnerThread()) throw new IllegalMonitorStateException(); boolean free = false; if (c == 0) { free = true; setExclusiveOwnerThread(null); } setState(c); return free; } /** * cas修改状态值 * * @param expect * @param update * @return */ public boolean compareAndSetState(int expect, int update) { // 注意这个你更新的是int基本类型只能使用:compareAndSwapInt 不能使用compareAndSwapObject return unsafe.compareAndSwapInt(this, stateOffset, expect, update); } /** * cas的方式设置尾节点 * * @param pred * @param node * @return */ public boolean compareAndSetTail(Node pred, Node node) { return unsafe.compareAndSwapObject(this, tailOffset, pred, node); } /** * cas修改节点等待状态 * * @param node * @param expect * @param update * @return */ private static final boolean compareAndSetWaitStatus(Node node, int expect, int update) { return unsafe.compareAndSwapInt(node, waitStatusOffset, expect, update); } /** * cas设置头结点 * * @param update * @return */ private final boolean compareAndSetHead(Node update) { return unsafe.compareAndSwapObject(this, headOffset, null, update); } public int getState() { return state; } public void setState(int state) { this.state = state; } public Node getHead() { return head; } public void setHead(Node head) { this.head = head; } public Node getTail() { return tail; } public void setTail(Node tail) { this.tail = tail; } public Thread getExclusiveOwnerThread() { return exclusiveOwnerThread; } public void setExclusiveOwnerThread(Thread exclusiveOwnerThread) { this.exclusiveOwnerThread = exclusiveOwnerThread; } /** * 双向链表队列节点 */ class Node { /** * -1表示下一个获取资源的线程 */ static final int SIGNAL = -1; /** * 等待获取资源的状态 */ volatile int waitStatus; /** * 前一个节点 */ volatile Node prev; /** * 下一个节点 */ volatile Node next; /** * 节点对应的线程 */ volatile Thread thread; /** * 默认构造方法 */ Node() { } /** * 传入线程的构造方法 * * @param thread */ Node(Thread thread) { this.thread = thread; } /** * 获取当前节点的前一个节点 * * @return * @throws NullPointerException */ public Node predecessor() throws NullPointerException { Node p = prev; if (p == null) throw new NullPointerException(); else return p; } } }