Java 手写一个ArrayBlockingQueue(带注释)

手写一个ArrayBlockingQueue

定义类MyArrayBlockingQueue实现Queue接口

由于该类比较简单所以不写太多解释了,代码中有注释的
提示两个地方一个是阻塞的时候其实是让当前线程进入等待,当条件发出信号后再让该线程继续运行,不懂Condition就去看看ReentrantLock里的Condition还有线程的await() notify() notifyAll()
这里直接实现掉了所有方法
如果哪个方法实现有问题,请指出改正。
看完源码手写源码能够增强对底层原理的理解,增加内力。

import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Queue;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;

/**
 * @Author humorchen
 * @Date 2021/10/27
 */
public class MyArrayBlockingQueue<T> implements Queue {
    /**
     * 存储队列里的元素
     */
    private Object[] items;
    /**
     * 队列的容量
     */
    private int capacity;
    /**
     * 放置元素的下标
     */
    private int putIndex = 0;
    /**
     * 取元素的下标
     */
    private int takeIndex = 0;
    /**
     * 当前队列中元素的数量
     */
    private int count = 0;
    /**
     * 保证线程安全的锁
     */
    private ReentrantLock lock = new ReentrantLock();
    /**
     * 放元素条件
     */
    private Condition putCondition = lock.newCondition();
    /**
     * 取元素的条件
     */
    private Condition takeCondition = lock.newCondition();


    /**
     * 构造函数传入队列容量
     * @param capacity
     */
    public MyArrayBlockingQueue(int capacity){
        if (capacity <= 0){
            throw new IllegalArgumentException("capacity less than zero");
        }
        this.capacity = capacity;
        this.items = new Object[capacity];
    }

    
    @Override
    public boolean add(Object o) {
        ReentrantLock lock = this.lock;
        try {
            //获取锁
            lock.lock();
            //容量满了,得等移除了一个才能继续插入
            while (count == capacity){
                putCondition.await();
            }
            //放入元素
            items[putIndex++] = o;
            //如果放入元素到了末端了,就把下标移动到0去,就是轮回
            if (putIndex == capacity){
                putIndex = 0;
            }
            //元素数量加一
            count++;
            //可取的限制发一个信号,这样如果有等待取的线程就能得到信号去执行了
            takeCondition.signal();
            return true;
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return false;
    }

    @Override
    public boolean offer(Object o) {
        return add(o);
    }

    @Override
    public T remove() {
        return poll();
    }

    @Override
    public T poll() {
        ReentrantLock lock = this.lock;
        try {
            //获得锁
            lock.lock();
            //如果没有元素那就等到某个线程放入一个元素
            while (count == 0){
                takeCondition.await();
            }
            //拿到元素
            Object o = items[takeIndex++];
            //如果拿元素的下标已经到了末尾了,那么需要把下标重置到0去,不断轮回。
            if (takeIndex == capacity){
                takeIndex = 0;
            }
            //元素数量减一
            count--;
            //给放入限制发一个信号,说明有一个空位可以放了
            putCondition.signal();
            //返回元素
            return (T)o;
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return null;
    }

    @Override
    public T element() {
        return peek();
    }

    @Override
    public T peek() {
        return (T)items[takeIndex];
    }

    @Override
    public int size() {
        return count;
    }

    @Override
    public boolean isEmpty() {
        return count == 0;
    }

    @Override
    public boolean contains(Object o) {
        ReentrantLock lock  = this.lock;
        try {
            lock.lock();
            for (int i = 0, takeIndex = this.takeIndex, size = this.count; i < size ; i++,takeIndex++){
                if (items[takeIndex].equals(o)){
                    return true;
                }
                if (takeIndex == capacity){
                    takeIndex = 0;
                }
            }
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return false;
    }

    @Override
    public Iterator iterator() {
        return Arrays.stream(items).iterator();
    }

    @Override
    public Object[] toArray() {
        ReentrantLock lock = this.lock;
        try {
            lock.lock();
            Object[] objects = new Object[count];
            toArray(objects);
            return objects;
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return new Object[0];
    }

    @Override
    public Object[] toArray(Object[] objects) {
        ReentrantLock lock = this.lock;
        try {
            lock.lock();

            for (int i = 0, size = this.count, takeIndex = this.takeIndex; i < size; i++){
                objects[i] = items[takeIndex++];
                if (takeIndex == capacity){
                    takeIndex = 0;
                }
            }
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return objects;
    }

    @Override
    public boolean remove(Object o) {
        ReentrantLock lock = this.lock;
        try {
            lock.lock();
            for (int i = 0, size = this.count, takeIndex = this.takeIndex; i < size; i++,takeIndex++){
                //先找到
                if (items[takeIndex].equals(o)){
                    //找到了,后面的往前面移动一个位置即可
                    i++;
                    //往前面放
                    while (i < size){
                        items[i-1] = items[i++];
                    }
                    //最后一个空位设置为null
                    items[i] = null;
                    //元素总数减一
                    this.count--;
                    //对putIndex减一
                    if (putIndex == 0){
                        putIndex = capacity-1;
                    }else {
                        putIndex--;
                    }
                    return true;
                }
                if (takeIndex == capacity){
                    takeIndex = 0;
                }
            }
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return false;
    }

    @Override
    public boolean containsAll(Collection c) {
        ReentrantLock lock  = this.lock;
        try {
            lock.lock();
            for (Object o : c){
                boolean found = false;
                for (int i = 0, takeIndex = this.takeIndex, size = this.count; i < size ; i++,takeIndex++){
                    if (items[takeIndex].equals(o)){
                        found = true;
                        break;
                    }
                    if (takeIndex == capacity){
                        takeIndex = 0;
                    }
                }
                if (!found){
                    return false;
                }
            }
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return true;
    }

    @Override
    public boolean addAll(Collection c) {
        ReentrantLock lock = this.lock;
        try {
            //获取锁
            lock.lock();
            if (c.size() > capacity - count){
                throw new RuntimeException("available capacity less than collection size");
            }
            for (Object o : c){
                //容量满了,得等移除了一个才能继续插入
                while (count == capacity){
                    putCondition.await();
                }
                //放入元素
                items[putIndex++] = o;
                //如果放入元素到了末端了,就把下标移动到0去,就是轮回
                if (putIndex == capacity){
                    putIndex = 0;
                }
                //元素数量加一
                count++;
                //可取的限制发一个信号,这样如果有等待取的线程就能得到信号去执行了
                takeCondition.signal();
            }
            return true;
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return false;
    }

    @Override
    public boolean removeAll(Collection c) {
        ReentrantLock lock = this.lock;
        try {
            lock.lock();
            //定义一个新的对象数组
            Object[] objects = new Object[capacity];
            //新对象数组的元素总数
            int newSize = 0;
            //遍历当前数组的每个元素
            for (int i = 0, size = this.count, takeIndex = this.takeIndex; i < size; i++){
                Object o = items[takeIndex++];
                //不包含就放入新数组
                if (!c.contains(o)){
                    objects[newSize++] = o;
                }
                if (takeIndex == capacity){
                    takeIndex = 0;
                }
            }
            //新数组替换掉旧的数组病重置参数
            items = objects;
            putIndex = count = newSize;
            takeIndex = 0;
            return true;
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return false;
    }

    @Override
    public boolean retainAll(Collection c) {
        ReentrantLock lock = this.lock;
        try {
            lock.lock();
            //定义一个新的对象数组
            Object[] objects = new Object[capacity];
            //新对象数组的元素总数
            int newSize = 0;
            //遍历当前数组的每个元素
            for (int i = 0, size = this.count, takeIndex = this.takeIndex; i < size; i++){
                Object o = items[takeIndex++];
                //要保留就放入新数组
                if (c.contains(o)){
                    objects[newSize++] = o;
                }
                if (takeIndex == capacity){
                    takeIndex = 0;
                }
            }
            //新数组替换掉旧的数组病重置参数
            items = objects;
            putIndex = count = newSize;
            takeIndex = 0;
            return true;
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
        return false;
    }

    @Override
    public void clear() {
        ReentrantLock lock = this.lock;
        try {
            lock.lock();
            //清理元素
            while (count > 0){
                items[takeIndex++] = null;
                count--;
                if (takeIndex == capacity){
                    takeIndex = 0;
                }
            }
            //归位
            takeIndex = putIndex = 0;
        }catch (Exception e){
            e.printStackTrace();
        }finally {
            lock.unlock();
        }
    }

    @Override
    public String toString() {
        return Arrays.toString(toArray());
    }
}

测试类

其中join那边大家可能会有疑问为什么要加这个,是因为执行这个测试函数的时候主函数可能执行完,然后主线程结束了,那么创建的两个子线程就不会完全执行完了,因此我们显式的在主函数中等待这两个线程执行完再结束。

import com.example.test.notify.MyArrayBlockingQueue;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * @Author humorchen
 * @Date 2021/10/27
 */
public class MyArrayBlockingQueueTest {
    MyArrayBlockingQueue<String> queue = new MyArrayBlockingQueue<>(4);

    /**
     * 测试添加一个
     */
    @Test
    public void testAdd() {
        queue.add("1");
        System.out.println(queue + " size:" + queue.size());
    }

    /**
     * 测试添加一个
     */
    @Test
    public void testAddAll() {
        queue.add("1");
        System.out.println(queue + " size:" + queue.size());
        System.out.println("添加2,3");
        queue.addAll(Arrays.asList("2", "3"));
        System.out.println(queue + " size:" + queue.size());
    }


    /**
     * 测试移除全部
     */
    @Test
    public void testClear() {
        queue.add("1");
        queue.add("2");
        queue.add("3");
        System.out.println(queue + " size:" + queue.size());
        System.out.println("清理全部");
        queue.clear();
        System.out.println(queue + " size:" + queue.size());

    }


    /**
     * 测试移除一个
     */
    @Test
    public void testRemove() {
        queue.add("1");
        queue.add("2");
        queue.add("3");
        System.out.println("当前队列:" + queue.toString());
        System.out.println("移除2");
        queue.remove("2");
        System.out.println(queue + " size:" + queue.size());
    }

    /**
     * 测试移除多个
     */
    @Test
    public void testRemoveAll() {
        queue.add("1");
        queue.add("2");
        queue.add("3");
        System.out.println("当前队列:" + queue.toString());
        System.out.println("移除2");
        queue.removeAll(Arrays.asList("1", "2"));
        System.out.println(queue + " size:" + queue.size());
    }


    /**
     * 测试包含判断
     */
    @Test
    public void testContain() {
        queue.add("1");
        queue.add("2");
        queue.add("3");
        System.out.println("当前队列:" + queue.toString());
        System.out.println("是否包含2:" + queue.contains("2") + " 是否包含4:" + queue.contains("4"));
    }

    /**
     * 测试包含所有的判断
     */
    @Test
    public void testContainAll() {
        queue.add("1");
        queue.add("2");
        queue.add("3");
        System.out.println("当前队列:" + queue.toString());
        System.out.println("是否包含1,2:" + queue.containsAll(Arrays.asList("1", "2")) + " 是否包含1,4:" + queue.containsAll(Arrays.asList("1", "4")));
    }

    /**
     * 测试多线程的阻塞效果
     */
    @Test
    public void testBlocking() {
        Thread t1 = new Thread(() -> {
            try {
                System.out.println(Thread.currentThread().getName() + " 执行休眠");
                Thread.sleep(1500);
                System.out.println(Thread.currentThread().getName() + "线程拿出了一个 " + queue.poll() + " 移除后size " + queue.size() + " 移除后队列:" + queue.toString());
            } catch (Exception e) {
                e.printStackTrace();
            }
        });
        t1.start();

        Thread t2 = new Thread(() -> {
            queue.add("1");
            queue.add("2");
            queue.add("3");
            queue.add("4");
            System.out.println(Thread.currentThread().getName() + "队列:" + queue.toString() + " size:" + queue.size() + " 再次添加一个5");
            queue.add("5");
            System.out.println(Thread.currentThread().getName() + "添加后" + queue.toString());
            System.out.println(Thread.currentThread().getName() + "拿出一个元素:" + queue.poll());
            System.out.println(Thread.currentThread().getName() + "size " + queue.size() + " " + queue.toString());

            List<String> list = new ArrayList<>();
            list.add("1");
        });
        t2.start();
        try {
            t1.join();
            t2.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }

    @Test
    public void testBlockingPoll(){
        Thread t1 = new Thread(()->{
            System.out.println(Thread.currentThread().getName()+"开始从队列中获取一个");
            System.out.println(Thread.currentThread().getName()+"获取到了:"+queue.poll());
        });
        t1.start();
        Thread t2 = new Thread(()->{
            System.out.println(Thread.currentThread().getName()+" 将2秒后给队列中放入一个");
            try {
                Thread.sleep(2000);
                queue.add("1");
                System.out.println(Thread.currentThread().getName()+"放入后队列中:"+queue.toString()+" size:"+queue.size());
            } catch (InterruptedException e) {
                e.printStackTrace();
            }
        });
        t2.start();
        try {
            t1.join();
            t2.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }
    }
}

posted @ 2021-10-29 09:51  HumorChen99  阅读(1)  评论(0编辑  收藏  举报  来源