Java 手写一个可重入锁(带注释)

手写一个可重入锁

package com.example.test.juc;

import java.util.*;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.LockSupport;

/**
 * @Author humorchen
 * @Date 2021/10/29
 */
public class MyReentrantLock implements Lock {
    /**
     * 没有锁定
     */
    private static final int UNLOCK = -1;
    /**
     * 初次锁定
     */
    private static final int FIRST = 1;


    /**
     * 是否为公平锁
     */
    private boolean fair;
    /**
     * 当前占有锁的线程
     */
    private Thread current;
    /**
     * 状态(由于我们没办法使用内部使用类Unsafe,只好用封装的AtomicInteger)
     */
    private AtomicInteger state = new AtomicInteger(UNLOCK);




    /**
     * 给自己使用
     */
    private Condition condition ;

    abstract class MyAbstractCondition implements Condition{
        /**
         * 对象锁
         */
        protected Object lockObj = new Object();
    }

    /**
     * 非公平的条件
     */
    class MyNonFairCondition extends MyAbstractCondition{
        @Override
        public void await() throws InterruptedException {
            synchronized (lockObj){
                //挂起当前线程,等待其他线程调用该对象的notify唤醒
                lockObj.wait();
            }
        }

        @Override
        public void awaitUninterruptibly() {

        }

        @Override
        public long awaitNanos(long nanosTimeout) throws InterruptedException {
            return 0;
        }

        @Override
        public boolean await(long time, TimeUnit unit) throws InterruptedException {
            return false;
        }

        @Override
        public boolean awaitUntil(Date deadline) throws InterruptedException {
            return false;
        }

        @Override
        public void signal() {
            synchronized (lockObj){
                //唤醒一个等待该对象的线程
                lockObj.notify();
            }
        }

        @Override
        public void signalAll() {
            synchronized (lockObj){
                //唤醒所有等待该对象的线程
                lockObj.notifyAll();
            }
        }
    }
    /**
     * 公平条件类
     */
    class MyFairCondition extends MyAbstractCondition{

        /**
         * 哈希,判断是否有这个线程在排队了
         */
        private HashSet<Thread> threadHashSet = new HashSet<>();
        /**
         * 排队的线程
         */
        private Queue<Thread> threadQueue = new LinkedList<>();

        /**
         * 让当前线程等待
         * @throws InterruptedException
         */
        @Override
        public void await() throws InterruptedException {
            Thread thread = Thread.currentThread();
            //没有正在排队就放进队列排队
            if (!threadHashSet.contains(thread)){
                threadQueue.add(thread);
                threadHashSet.add(thread);
            }
            //把当前线程挂起
            LockSupport.park();
        }

        /**
         * 让线程等待并允许被中断
         */
        @Override
        public void awaitUninterruptibly() {

        }


        /**
         * 至多等待nanosTimeout纳秒
         * @param nanosTimeout
         * @return
         * @throws InterruptedException
         */
        @Override
        public long awaitNanos(long nanosTimeout) throws InterruptedException {
            Thread thread = Thread.currentThread();
            //没有正在排队就放进队列排队
            if (!threadHashSet.contains(thread)){
                threadQueue.add(thread);
                threadHashSet.add(thread);
            }
            //挂起,最多持续一段时间
            LockSupport.parkNanos(nanosTimeout);
            return 0;
        }

        /**
         * 至多等待多久,转换为纳秒再调用纳秒
         * @param time
         * @param unit
         * @return
         * @throws InterruptedException
         */
        @Override
        public boolean await(long time, TimeUnit unit) throws InterruptedException {
            awaitNanos(TimeUnit.NANOSECONDS.convert(time,unit));
            return true;
        }

        /**
         * 等待直到死亡时间
         * @param deadline
         * @return
         * @throws InterruptedException
         */
        @Override
        public boolean awaitUntil(Date deadline) throws InterruptedException {
            awaitNanos(TimeUnit.NANOSECONDS.convert(deadline.getTime() - System.currentTimeMillis(),TimeUnit.MILLISECONDS));
            return false;
        }

        /**
         * 唤起一个正在等待该条件的线程
         */
        @Override
        public void signal() {
            synchronized (lockObj){
                if (threadQueue.size() > 0){
                    //让线程起来
                    Thread thread = threadQueue.poll();
                    threadHashSet.remove(thread);
                    LockSupport.unpark(thread);
                }
            }
        }
        /**
         * 唤起所有正在等待该条件的线程
         */
        @Override
        public void signalAll() {
            synchronized (lockObj){
                //唤醒所有的等待线程
                while (!threadQueue.isEmpty()){
                    Thread thread = threadQueue.poll();
                    LockSupport.unpark(thread);
                }
                threadHashSet.clear();
            }
        }
    }

    public MyReentrantLock(){
        this(false);
    }

    public MyReentrantLock(boolean fair){
        this.fair = fair;
        condition = newCondition();
    }


    /**
     * 锁定
     */
    @Override
    public void lock() {
        Thread thread = Thread.currentThread();
        for (;;){
            if (current != null){
                if (current.equals(thread)){
                    //重入
                    //自增完成,函数返回
                    state.incrementAndGet();
                    return;
                }else {
                    //不是自己的锁,让这个线程等
                    try {
                        condition.await();
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }else {
                //抢占锁
                //尝试锁定
                if (state.compareAndSet(UNLOCK,FIRST)){
                    //锁定成功
                    current = Thread.currentThread();
                    return;
                }else {
                    //锁定失败,被别人抢走了,那继续等吧
                    try {
                        condition.await();
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }
    }

    @Override
    public void lockInterruptibly() throws InterruptedException {

    }

    /**
     * 尝试锁定一次
     * @return
     */
    @Override
    public boolean tryLock() {
        Thread thread = Thread.currentThread();
        //有线程占有了
        if (current != null){
            if (current.equals(thread)){
                //重入
                //自增完成,函数返回
                state.incrementAndGet();
                return true;
            }
        }else {
            //抢占锁
            //尝试锁定
            if (state.compareAndSet(UNLOCK,FIRST)){
                //锁定成功
                current = Thread.currentThread();
                return true;
            }
        }
        return false;
    }

    /**
     * 在一定时间内尝试锁定
     * @param time
     * @param unit
     * @return
     * @throws InterruptedException
     */
    @Override
    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        long expire =System.currentTimeMillis() + TimeUnit.MILLISECONDS.convert(time,unit);
        while (System.currentTimeMillis() < expire){
            if (tryLock()){
                return true;
            }
        }
        return false;
    }

    /**
     * 解锁
     */
    @Override
    public void unlock() {
        Thread thread = Thread.currentThread();
        //有线程占有并且属于当前线程
        if (current != null && current.equals(thread)){
            //开始解锁
            int state = this.state.decrementAndGet();
            //全部解锁,释放掉锁
            if (state == 0){
                current = null;
                this.state.compareAndSet(state,UNLOCK);
                //唤醒一个线程
                condition.signal();
            }
        }
    }

    /**
     * 新条件
     * @return
     */
    @Override
    public Condition newCondition() {
        if (this.fair){
            return new MyFairCondition();
        }
        return new MyNonFairCondition();
    }
}

测试类

package com.example.test;

import com.example.test.juc.MyReentrantLock;
import org.junit.jupiter.api.Test;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

/**
 * @Author humorchen
 * @Date 2021/10/29
 */
public class MyReentrantLockTest {
    int ticket = 1000;

    /**
     * 两个线程卖票测试
     */
    @Test
    public void testLock() {
        MyReentrantLock lock = new MyReentrantLock(true);
        AtomicInteger sell = new AtomicInteger(0);

        Thread t1 = new Thread(() -> {
            for (; ; ) {
                try {
                    System.out.println(Thread.currentThread().getName() + " 尝试获得锁");
                    lock.lock();
                    System.out.println(Thread.currentThread().getName() + " 获得了锁");
                    if (ticket > 0) {
                        ticket--;
                        System.out.println(Thread.currentThread().getName() + " 卖了一张票,剩下" + ticket + "张票");
                        sell.incrementAndGet();
                    } else {
                        return;
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    System.out.println(Thread.currentThread().getName() + " 释放了锁");
                    lock.unlock();
                    try {
                        Thread.sleep(1);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }, "【t1线程】");
        t1.start();

        Thread t2 = new Thread(() -> {
            for (; ; ) {
                try {
                    System.out.println(Thread.currentThread().getName() + " 尝试获得锁");
                    lock.lock();
                    System.out.println(Thread.currentThread().getName() + " 获得了锁");
                    if (ticket > 0) {
                        ticket--;
                        System.out.println(Thread.currentThread().getName() + " 卖了一张票,剩下" + ticket + "张票");
                        sell.incrementAndGet();
                    } else {
                        return;
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    System.out.println(Thread.currentThread().getName() + " 释放了锁");
                    lock.unlock();
                    try {
                        Thread.sleep(1);
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }
        }, "【t2线程】");
        t2.start();

        try {
            //等待两个线程运行结束才结束主线程
            t1.join();
            t2.join();
            System.out.println("最终两个线程总共卖出了" + sell + "张票");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * 测试线程重入
     */
    @Test
    public void testReentrant() {
        MyReentrantLock lock = new MyReentrantLock();
        AtomicInteger sell = new AtomicInteger(0);
        Thread t1 = new Thread(() -> {
            for (; ; ) {
                try {
                    lock.lock();
                    if (ticket > 0) {
                        //卖出一张
                        sellTicket(sell);
                        //假设线程再次锁
                        try {
                            lock.lock();
                            if (ticket > 0) {
                                //再次卖出
                                sellTicket(sell);
                            }
                        } catch (Exception e) {
                            e.printStackTrace();
                        } finally {
                            lock.unlock();
                        }
                    } else {
                        return;
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    lock.unlock();
                }

            }
        });
        t1.start();


        Thread t2 = new Thread(() -> {
            for (; ; ) {
                try {
                    lock.lock();
                    if (ticket > 0) {
                        //卖出一张
                        sellTicket(sell);
                        //假设线程再次锁
                        try {
                            lock.lock();
                            if (ticket > 0) {
                                //再次卖出
                                sellTicket(sell);
                            }
                        } catch (Exception e) {
                            e.printStackTrace();
                        } finally {
                            lock.unlock();
                        }
                    } else {
                        return;
                    }
                } catch (Exception e) {
                    e.printStackTrace();
                } finally {
                    lock.unlock();
                }

            }
        });
        t2.start();
        try {
            //等待两个线程运行结束才结束主线程
            t1.join();
            t2.join();
            System.out.println("两个线程总共卖出" + sell.intValue() + " 张票");
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    private void sellTicket(AtomicInteger sell) {
        ticket--;
        sell.incrementAndGet();
        System.out.println(Thread.currentThread().getName() + " 售出了一张票,剩余:" + ticket + "张票");
    }
}

公平锁
在这里插入图片描述

非公平锁
在这里插入图片描述

posted @ 2021-10-30 11:32  HumorChen99  阅读(6)  评论(0编辑  收藏  举报  来源