Java 手写一个可重入锁(带注释)
手写一个可重入锁
package com.example.test.juc;
import java.util.*;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.LockSupport;
/**
 * A hand-written reentrant mutual-exclusion lock with the same basic behaviour as
 * {@link java.util.concurrent.locks.ReentrantLock}.
 *
 * <p>{@link #state} is {@link #UNLOCK} while the lock is free and the owner's hold count while it
 * is held; a free lock is taken with a CAS from {@code UNLOCK} to {@code FIRST}. A thread that
 * cannot take the lock immediately appends itself to {@link #waiters} and parks with
 * {@link LockSupport}; among queued threads only the head competes for the lock, and every full
 * release unparks the current head. Every wait loop re-checks its condition after {@code park}
 * returns, so early or spurious wake-ups are harmless, and an {@code unpark} that arrives before
 * the matching {@code park} is not lost because LockSupport keeps the permit.
 *
 * <p>In fair mode {@link #lock()} never overtakes threads that are already queued; in non-fair
 * mode a thread may grab a free lock before queueing. {@link #tryLock()} always grabs a free lock,
 * as {@code ReentrantLock#tryLock()} does.
 *
 * @Author humorchen
 * @Date 2021/10/29
 */
public class MyReentrantLock implements Lock {
    /**
     * State value meaning that no thread holds the lock.
     */
    private static final int UNLOCK = -1;
    /**
     * State value right after a thread takes a free lock (hold count 1).
     */
    private static final int FIRST = 1;
    /**
     * Whether {@link #lock()} queues behind already-waiting threads instead of barging.
     */
    private final boolean fair;
    /**
     * The owning thread, or {@code null} while the lock is free. Only written by the thread that
     * owns (or has just obtained) the lock; volatile so every thread reads the latest value.
     */
    private volatile Thread current;
    /**
     * {@link #UNLOCK} when free, otherwise the owner's hold count. (Unsafe is not available to us,
     * so AtomicInteger supplies the CAS.)
     */
    private final AtomicInteger state = new AtomicInteger(UNLOCK);
    /**
     * Threads blocked in lock acquisition, in arrival order.
     */
    private final Queue<Thread> waiters = new ConcurrentLinkedQueue<>();

    /**
     * How a wait in {@link #waiters} ended.
     */
    private enum Outcome { ACQUIRED, TIMED_OUT, INTERRUPTED }

    /**
     * Creates a non-fair lock.
     */
    public MyReentrantLock() {
        this(false);
    }

    /**
     * Creates a lock with the given fairness policy.
     *
     * @param fair {@code true} to serve waiting threads in arrival order
     */
    public MyReentrantLock(boolean fair) {
        this.fair = fair;
    }

    /**
     * Acquires the lock, blocking until it is available. Re-acquiring a lock the calling thread
     * already holds just increments the hold count. Interrupts received while blocked do not abort
     * the wait; the thread's interrupt status is restored once the lock is held.
     */
    @Override
    public void lock() {
        if (!tryAcquireOnce(!fair)) {
            acquireQueued(false, false, 0L);
        }
    }

    /**
     * Like {@link #lock()}, but gives up if the calling thread is interrupted before or while
     * waiting.
     *
     * @throws InterruptedException if the calling thread is interrupted
     */
    @Override
    public void lockInterruptibly() throws InterruptedException {
        if (Thread.interrupted()) {
            throw new InterruptedException();
        }
        if (!tryAcquireOnce(!fair) && acquireQueued(true, false, 0L) == Outcome.INTERRUPTED) {
            throw new InterruptedException();
        }
    }

    /**
     * Acquires the lock only if it is free (or already held by the calling thread) at the time of
     * the call. A free lock is taken even in fair mode.
     *
     * @return {@code true} if the calling thread now holds the lock
     */
    @Override
    public boolean tryLock() {
        return tryAcquireOnce(true);
    }

    /**
     * Acquires the lock if it becomes available within the given waiting time. The thread parks
     * while waiting instead of spinning.
     *
     * @param time the maximum time to wait
     * @param unit the unit of {@code time}
     * @return {@code true} if the lock was acquired, {@code false} if the time elapsed first
     * @throws InterruptedException if the calling thread is interrupted before or while waiting
     */
    @Override
    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        if (Thread.interrupted()) {
            throw new InterruptedException();
        }
        if (tryAcquireOnce(!fair)) {
            return true;
        }
        long nanos = unit.toNanos(time);
        if (nanos <= 0L) {
            return false;
        }
        Outcome outcome = acquireQueued(true, true, nanos);
        if (outcome == Outcome.INTERRUPTED) {
            throw new InterruptedException();
        }
        return outcome == Outcome.ACQUIRED;
    }

    /**
     * Releases one hold. When the hold count drops to zero the lock becomes free and the first
     * queued thread is woken.
     *
     * @throws IllegalMonitorStateException if the calling thread does not hold the lock
     */
    @Override
    public void unlock() {
        checkOwner();
        if (state.get() == FIRST) {
            releaseAll();
        } else {
            // Only the owner changes state while the lock is held, so a plain decrement is safe.
            state.decrementAndGet();
        }
    }

    /**
     * Returns a new {@link Condition} bound to this lock. Its await methods must be called while
     * holding the lock; they release it completely while waiting and restore the hold count before
     * returning. Waiters are signalled in arrival order.
     */
    @Override
    public Condition newCondition() {
        return new ConditionObject();
    }

    /**
     * Takes the lock if that is possible right now, without blocking.
     *
     * @param barge whether a free lock may be taken even though other threads are queued for it
     * @return {@code true} if the calling thread now holds the lock
     */
    private boolean tryAcquireOnce(boolean barge) {
        Thread me = Thread.currentThread();
        if (current == me) {
            // Reentry: nobody else can touch state while we own the lock.
            int holds = state.get();
            if (holds == Integer.MAX_VALUE) {
                throw new Error("Maximum lock count exceeded");
            }
            state.set(holds + 1);
            return true;
        }
        if ((barge || waiters.isEmpty()) && state.compareAndSet(UNLOCK, FIRST)) {
            current = me;
            return true;
        }
        return false;
    }

    /**
     * Appends the calling thread to {@link #waiters} and parks until it is at the head of the
     * queue and wins the CAS on {@link #state}.
     *
     * @param interruptible whether an interrupt aborts the wait; if not, interrupts are remembered
     *     and the interrupt status is restored before returning
     * @param timed whether {@code nanos} limits the wait
     * @param nanos the maximum wait when {@code timed}; expected to be positive
     * @return how the wait ended; always {@link Outcome#ACQUIRED} when neither interruptible nor
     *     timed
     */
    private Outcome acquireQueued(boolean interruptible, boolean timed, long nanos) {
        Thread me = Thread.currentThread();
        long deadline = timed ? System.nanoTime() + nanos : 0L;
        boolean interrupted = false;
        Outcome outcome;
        waiters.add(me);
        for (;;) {
            if (waiters.peek() == me && state.compareAndSet(UNLOCK, FIRST)) {
                current = me;
                outcome = Outcome.ACQUIRED;
                break;
            }
            if (timed) {
                long remaining = deadline - System.nanoTime();
                if (remaining <= 0L) {
                    outcome = Outcome.TIMED_OUT;
                    break;
                }
                LockSupport.parkNanos(this, remaining);
            } else {
                LockSupport.park(this);
            }
            if (Thread.interrupted()) {
                if (interruptible) {
                    outcome = Outcome.INTERRUPTED;
                    break;
                }
                interrupted = true;
            }
        }
        waiters.remove(me);
        if (outcome != Outcome.ACQUIRED) {
            // We may have been the head that a release just woke; pass that wake-up on so the
            // next waiter does not stay parked while the lock is free.
            LockSupport.unpark(waiters.peek());
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
        return outcome;
    }

    /**
     * Drops every hold of the owning thread, frees the lock and wakes the first queued thread.
     *
     * @return the hold count that was dropped, so a condition wait can restore it later
     */
    private int releaseAll() {
        int holds = state.get();
        current = null;
        // Free the lock before waking the head so that the woken thread's CAS can succeed.
        state.set(UNLOCK);
        LockSupport.unpark(waiters.peek());
        return holds;
    }

    /**
     * Takes the lock back after a condition wait (not abortable by interrupts) and restores the
     * hold count the thread had before waiting.
     */
    private void reacquire(int holds) {
        if (!tryAcquireOnce(!fair)) {
            acquireQueued(false, false, 0L);
        }
        state.set(holds);
    }

    /**
     * @throws IllegalMonitorStateException if the calling thread does not hold the lock
     */
    private void checkOwner() {
        if (current != Thread.currentThread()) {
            throw new IllegalMonitorStateException("current thread does not hold this lock");
        }
    }

    /**
     * One thread waiting on a {@link Condition}. Its status moves from WAITING to either SIGNALLED
     * (by a signal) or CANCELLED (timeout/interrupt) exactly once, so a signal is never delivered
     * to a waiter that has already given up.
     */
    private static final class Waiter {
        static final int WAITING = 0;
        static final int SIGNALLED = 1;
        static final int CANCELLED = 2;

        final Thread thread = Thread.currentThread();
        final AtomicInteger status = new AtomicInteger(WAITING);
    }

    /**
     * Condition bound to the enclosing lock: a FIFO of {@link Waiter}s that is only accessed while
     * the lock is held.
     */
    private final class ConditionObject implements Condition {
        /**
         * Waiting threads in arrival order; guarded by the enclosing lock.
         */
        private final Deque<Waiter> queue = new ArrayDeque<>();

        @Override
        public void await() throws InterruptedException {
            if (awaitSignal(true, false, 0L)) {
                throw new InterruptedException();
            }
        }

        @Override
        public void awaitUninterruptibly() {
            awaitSignal(false, false, 0L);
        }

        @Override
        public long awaitNanos(long nanosTimeout) throws InterruptedException {
            // Clamp so that a very negative timeout cannot wrap the deadline into the future.
            long deadline = System.nanoTime() + Math.max(nanosTimeout, 0L);
            if (awaitSignal(true, true, deadline)) {
                throw new InterruptedException();
            }
            return deadline - System.nanoTime();
        }

        @Override
        public boolean await(long time, TimeUnit unit) throws InterruptedException {
            // Equivalent form given in the Condition#await(long, TimeUnit) contract.
            return awaitNanos(unit.toNanos(time)) > 0L;
        }

        @Override
        public boolean awaitUntil(Date deadline) throws InterruptedException {
            long millis = deadline.getTime() - System.currentTimeMillis();
            return awaitNanos(TimeUnit.MILLISECONDS.toNanos(millis)) > 0L;
        }

        @Override
        public void signal() {
            checkOwner();
            Waiter waiter;
            while ((waiter = queue.pollFirst()) != null) {
                if (waiter.status.compareAndSet(Waiter.WAITING, Waiter.SIGNALLED)) {
                    LockSupport.unpark(waiter.thread);
                    return;
                }
                // That waiter already gave up (timeout or interrupt); try the next one.
            }
        }

        @Override
        public void signalAll() {
            checkOwner();
            Waiter waiter;
            while ((waiter = queue.pollFirst()) != null) {
                if (waiter.status.compareAndSet(Waiter.WAITING, Waiter.SIGNALLED)) {
                    LockSupport.unpark(waiter.thread);
                }
            }
        }

        /**
         * Common body of every await variant: queue a waiter, release the lock completely, park
         * until signalled (or the deadline passes / an interrupt arrives), then take the lock back
         * with the same hold count.
         *
         * @param interruptible whether an interrupt ends the wait early
         * @param timed whether {@code deadline} applies
         * @param deadline {@link System#nanoTime()} value at which to stop waiting when
         *     {@code timed}
         * @return {@code true} if the caller must throw {@link InterruptedException}, i.e. the
         *     thread was interrupted before being signalled and {@code interruptible} is set
         * @throws IllegalMonitorStateException if the calling thread does not hold the lock
         */
        private boolean awaitSignal(boolean interruptible, boolean timed, long deadline) {
            checkOwner();
            if (interruptible && Thread.interrupted()) {
                return true;
            }
            Waiter node = new Waiter();
            // Enqueue before releasing, so no signal issued after the release can miss us.
            queue.addLast(node);
            int holds = releaseAll();
            boolean interrupted = false;
            boolean gaveUp = false;
            while (node.status.get() == Waiter.WAITING) {
                if (timed) {
                    long remaining = deadline - System.nanoTime();
                    if (remaining <= 0L) {
                        gaveUp = node.status.compareAndSet(Waiter.WAITING, Waiter.CANCELLED);
                        break;
                    }
                    LockSupport.parkNanos(this, remaining);
                } else {
                    LockSupport.park(this);
                }
                if (Thread.interrupted()) {
                    interrupted = true;
                    if (interruptible) {
                        gaveUp = node.status.compareAndSet(Waiter.WAITING, Waiter.CANCELLED);
                        break;
                    }
                }
            }
            reacquire(holds);
            if (gaveUp) {
                // signal() skips cancelled nodes, but this one may still be queued; drop it now
                // that we hold the lock again.
                queue.remove(node);
            }
            if (interrupted && interruptible && gaveUp) {
                return true;
            }
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
            return false;
        }
    }
}
测试类
package com.example.test;
import com.example.test.juc.MyReentrantLock;
import org.junit.jupiter.api.Test;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;
/**
 * Tests for {@link MyReentrantLock}: two threads sell from a shared stock of tickets, and each test
 * fails unless every ticket was sold exactly once. {@code ticket} is only modified while holding
 * the lock under test, so a broken lock shows up as a wrong total.
 *
 * <p>NOTE: the shared {@code ticket} field relies on JUnit Jupiter's default per-method test
 * instance lifecycle to start each test with a full stock.
 *
 * @Author humorchen
 * @Date 2021/10/29
 */
public class MyReentrantLockTest {
    /**
     * Tickets in stock at the start of each test.
     */
    private static final int TOTAL_TICKETS = 1000;

    int ticket = TOTAL_TICKETS;

    /**
     * Two threads sell tickets through a fair lock, logging every lock/unlock step.
     */
    @Test
    public void testLock() throws InterruptedException {
        MyReentrantLock lock = new MyReentrantLock(true);
        AtomicInteger sell = new AtomicInteger(0);
        Runnable seller = () -> sellWithLogging(lock, sell);
        runAndJoin(new Thread(seller, "【t1线程】"), new Thread(seller, "【t2线程】"));
        System.out.println("最终两个线程总共卖出了" + sell + "张票");
        checkAllSold(sell);
    }

    /**
     * Two threads sell tickets while re-acquiring the lock they already hold (reentrancy).
     */
    @Test
    public void testReentrant() throws InterruptedException {
        MyReentrantLock lock = new MyReentrantLock();
        AtomicInteger sell = new AtomicInteger(0);
        Runnable seller = () -> sellReentrantly(lock, sell);
        runAndJoin(new Thread(seller), new Thread(seller));
        System.out.println("两个线程总共卖出" + sell.intValue() + " 张票");
        checkAllSold(sell);
    }

    /**
     * Worker for {@link #testLock()}: repeatedly locks, sells one ticket while any remain, unlocks
     * and pauses 1 ms so the other thread gets a turn.
     */
    private void sellWithLogging(MyReentrantLock lock, AtomicInteger sell) {
        String name = Thread.currentThread().getName();
        for (;;) {
            System.out.println(name + " 尝试获得锁");
            // Lock outside the try: if lock() itself failed there would be nothing to unlock.
            lock.lock();
            try {
                System.out.println(name + " 获得了锁");
                if (ticket <= 0) {
                    return;
                }
                ticket--;
                System.out.println(name + " 卖了一张票,剩下" + ticket + "张票");
                sell.incrementAndGet();
            } finally {
                System.out.println(name + " 释放了锁");
                lock.unlock();
            }
            try {
                Thread.sleep(1);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return;
            }
        }
    }

    /**
     * Worker for {@link #testReentrant()}: sells one ticket, then locks the same lock a second time
     * from the same thread (which must not deadlock) and sells another if any remain.
     */
    private void sellReentrantly(MyReentrantLock lock, AtomicInteger sell) {
        for (;;) {
            lock.lock();
            try {
                if (ticket <= 0) {
                    return;
                }
                sellTicket(sell);
                lock.lock();
                try {
                    if (ticket > 0) {
                        sellTicket(sell);
                    }
                } finally {
                    lock.unlock();
                }
            } finally {
                lock.unlock();
            }
        }
    }

    /**
     * Starts every thread, waits for all of them, and fails if any died with an exception (which
     * would otherwise only be printed and leave the test green).
     */
    private static void runAndJoin(Thread... threads) throws InterruptedException {
        Throwable[] firstFailure = new Throwable[1];
        for (Thread thread : threads) {
            thread.setUncaughtExceptionHandler((t, e) -> {
                synchronized (firstFailure) {
                    if (firstFailure[0] == null) {
                        firstFailure[0] = e;
                    }
                }
            });
            thread.start();
        }
        for (Thread thread : threads) {
            thread.join();
        }
        synchronized (firstFailure) {
            if (firstFailure[0] != null) {
                throw new AssertionError("a worker thread failed", firstFailure[0]);
            }
        }
    }

    /**
     * Fails unless exactly {@link #TOTAL_TICKETS} tickets were sold and none are left.
     */
    private void checkAllSold(AtomicInteger sell) {
        if (sell.get() != TOTAL_TICKETS || ticket != 0) {
            throw new AssertionError("expected " + TOTAL_TICKETS + " sold and 0 left, but sold "
                    + sell.get() + " and " + ticket + " left");
        }
    }

    /**
     * Sells one ticket; must be called while holding the lock under test.
     */
    private void sellTicket(AtomicInteger sell) {
        ticket--;
        sell.incrementAndGet();
        System.out.println(Thread.currentThread().getName() + " 售出了一张票,剩余:" + ticket + "张票");
    }
}
公平锁
非公平锁
本文来自博客园,作者:HumorChen99,转载请注明原文链接:https://www.cnblogs.com/HumorChen/p/18039527