Java 手写AQS同步器(一)
使用自旋 + LockSupport + CAS
具体实现:
package thread.lock;
import sun.misc.Unsafe;
import java.lang.reflect.Field;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.locks.LockSupport;
/**
 * A minimal exclusive lock built from three parts: a CAS-guarded {@code state} word, a queue of
 * waiting threads, and {@link LockSupport} park/unpark. It is a hand-written teaching version of
 * the idea behind {@code AbstractQueuedSynchronizer}.
 *
 * <p>{@code state == 0} means unlocked; {@code state == 1} means locked. A thread that cannot take
 * the lock immediately appends itself to {@code linkedQueue}. After that, only the thread at the
 * head of the queue may compete for the lock, and every {@link #unlock()} unparks that head thread.
 * A newly arriving thread may take the lock without queueing only when the queue is empty.
 *
 * <p>The lock is NOT reentrant: a thread that calls {@link #lock()} while already holding it will
 * block forever.
 */
public class TestLock {
    /** Lock word: 0 = free, 1 = held. Only changed through {@link #compareAndSwapState}. */
    private volatile int state = 0;

    /** Threads waiting for the lock, in arrival order; the head is the next one allowed to try. */
    private final ConcurrentLinkedQueue<Thread> linkedQueue = new ConcurrentLinkedQueue<>();

    /**
     * Thread currently holding the lock, or {@code null}. Volatile so that {@link #unlock()} sees
     * an up-to-date owner even when it is called from a thread that does not hold the lock.
     */
    private volatile Thread lockHolder;

    private static final Unsafe unsafe;

    /** Offset of the {@code state} field from the start of a {@code TestLock} object. */
    private static final long stateOffset;

    static {
        try {
            // Obtain the Unsafe singleton through its private static field "theUnsafe".
            Field field = Unsafe.class.getDeclaredField("theUnsafe");
            field.setAccessible(true);
            unsafe = (Unsafe) field.get(null);
            stateOffset = unsafe.objectFieldOffset(TestLock.class.getDeclaredField("state"));
        } catch (Exception e) {
            // Keep the cause so the real reflection failure is not lost.
            throw new Error("reflect error", e);
        }
    }

    /**
     * Makes one non-blocking attempt to take the lock for the calling thread.
     *
     * @return {@code true} if all of the following hold: the lock was free, the caller was allowed
     *     to take it (queue empty or caller at its head), and the CAS of {@code state} from 0 to 1
     *     succeeded
     */
    private boolean acquire() {
        if (state != 0) {
            return false;
        }
        Thread currentThread = Thread.currentThread();
        // isEmpty() rather than size(): ConcurrentLinkedQueue.size() traverses the whole queue.
        boolean mayAcquire = linkedQueue.isEmpty() || currentThread == linkedQueue.peek();
        if (mayAcquire && compareAndSwapState(0, 1)) {
            lockHolder = currentThread;
            return true;
        }
        return false;
    }

    /**
     * Acquires the lock, parking the calling thread until the lock is available.
     *
     * <p>Interrupts do not abort the wait. If the thread is interrupted while waiting, this method
     * first clears the interrupt flag so that {@link LockSupport#park(Object)} blocks again instead
     * of returning immediately in a busy loop. It re-asserts the flag once the lock is acquired.
     */
    public void lock() {
        if (acquire()) {
            return;
        }
        Thread current = Thread.currentThread();
        linkedQueue.offer(current);
        boolean interrupted = false;
        for (;;) {
            // Only the head of the queue may compete for the lock.
            if (linkedQueue.peek() == current && acquire()) {
                // We are the head and now hold the lock, so poll() removes exactly ourselves.
                linkedQueue.poll();
                if (interrupted) {
                    current.interrupt();
                }
                return;
            }
            // park() may return because of unpark, an interrupt, or spuriously; the loop re-checks.
            LockSupport.park(this);
            if (Thread.interrupted()) {
                interrupted = true;
            }
        }
    }

    /**
     * Releases the lock and unparks the thread at the head of the wait queue, if any.
     *
     * @throws RuntimeException if the calling thread does not hold the lock
     */
    public void unlock() {
        Thread current = Thread.currentThread();
        if (current != lockHolder) {
            throw new RuntimeException("current thread couldn't release lock");
        }
        // Clear the owner BEFORE publishing state == 0. If it is cleared afterwards, another
        // thread can acquire in between and have its freshly written lockHolder wiped out. Its
        // unlock() would then throw and the lock would stay held forever.
        lockHolder = null;
        if (compareAndSwapState(1, 0)) {
            Thread next = linkedQueue.peek();
            if (next != null) {
                LockSupport.unpark(next);
            }
        }
    }

    /**
     * Atomically sets {@code state} to {@code update} if it currently equals {@code expect}.
     * ({@code final} is redundant here: private methods cannot be overridden.)
     *
     * @return {@code true} if the swap happened
     */
    private final boolean compareAndSwapState(int expect, int update) {
        return unsafe.compareAndSwapInt(this, stateOffset, expect, update);
    }
}
测试与使用
package thread;
import thread.lock.TestLock;
/**
 * Demo of {@link TestLock}: 100 threads each try to decrement a shared stock counter that starts
 * at 20. Each thread does this while holding the lock, so exactly 20 decrements should succeed
 * and the final printed count should be 0.
 */
public class CountTest {
    /** Simulated stock kept in Redis; guarded by {@link #lock}. */
    private static int redisCount = 20;

    // To compare against the JDK implementation, swap in java.util.concurrent.locks.ReentrantLock.
    private static final TestLock lock = new TestLock();

    /**
     * Decrements {@link #redisCount} by one if stock remains, and prints the outcome.
     *
     * <p>The lock is released in a {@code finally} block. Otherwise an unexpected exception would
     * leave the lock held and block every other thread.
     */
    private static void tryCreateOrder() {
        lock.lock();
        try {
            Thread thread = Thread.currentThread();
            if (redisCount <= 0) {
                System.out.println(thread.getName() + "库存不足:" + redisCount);
                return;
            }
            try {
                Thread.sleep(2); // simulate network latency of sending a message to Redis
            } catch (InterruptedException e) {
                // Restore the interrupt flag instead of swallowing it.
                Thread.currentThread().interrupt();
            }
            redisCount--;
            System.out.println(thread.getName() + "预减库存成功,当前剩余: " + redisCount);
        } finally {
            lock.unlock();
        }
    }

    /**
     * Starts 100 competing buyer threads, waits for all of them, then prints the remaining stock.
     */
    public static void main(String[] args) throws InterruptedException {
        Thread[] buyers = new Thread[100];
        for (int i = 0; i < buyers.length; i++) {
            buyers[i] = new Thread(CountTest::tryCreateOrder);
            buyers[i].start();
        }
        // Use join() rather than a fixed sleep. It waits for every buyer however slow, and it
        // makes their writes to redisCount visible to this thread (happens-before).
        for (Thread buyer : buyers) {
            buyer.join();
        }
        System.out.println(redisCount);
    }
}