Java 手写一个ArrayBlockingQueue(带注释)
手写一个ArrayBlockingQueue
定义类MyArrayBlockingQueue实现Queue接口
由于该类比较简单所以不写太多解释了,代码中有注释的
提示两个地方一个是阻塞的时候其实是让当前线程进入等待,当条件发出信号后再让该线程继续运行,不懂Condition就去看看ReentrantLock里的Condition还有线程的await() notify() notifyAll()
这里直接实现掉了所有方法
如果哪个方法实现有问题,请指出改正。
看完源码手写源码能够增强对底层原理的理解,增加内力。
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.Queue;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.ReentrantLock;
/**
* @Author humorchen
* @Date 2021/10/27
*/
public class MyArrayBlockingQueue<T> implements Queue {
/**
* 存储队列里的元素
*/
private Object[] items;
/**
* 队列的容量
*/
private int capacity;
/**
* 放置元素的下标
*/
private int putIndex = 0;
/**
* 取元素的下标
*/
private int takeIndex = 0;
/**
* 当前队列中元素的数量
*/
private int count = 0;
/**
* 保证线程安全的锁
*/
private ReentrantLock lock = new ReentrantLock();
/**
* 放元素条件
*/
private Condition putCondition = lock.newCondition();
/**
* 取元素的条件
*/
private Condition takeCondition = lock.newCondition();
/**
* 构造函数传入队列容量
* @param capacity
*/
public MyArrayBlockingQueue(int capacity){
if (capacity <= 0){
throw new IllegalArgumentException("capacity less than zero");
}
this.capacity = capacity;
this.items = new Object[capacity];
}
@Override
public boolean add(Object o) {
ReentrantLock lock = this.lock;
try {
//获取锁
lock.lock();
//容量满了,得等移除了一个才能继续插入
while (count == capacity){
putCondition.await();
}
//放入元素
items[putIndex++] = o;
//如果放入元素到了末端了,就把下标移动到0去,就是轮回
if (putIndex == capacity){
putIndex = 0;
}
//元素数量加一
count++;
//可取的限制发一个信号,这样如果有等待取的线程就能得到信号去执行了
takeCondition.signal();
return true;
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return false;
}
@Override
public boolean offer(Object o) {
return add(o);
}
@Override
public T remove() {
return poll();
}
@Override
public T poll() {
ReentrantLock lock = this.lock;
try {
//获得锁
lock.lock();
//如果没有元素那就等到某个线程放入一个元素
while (count == 0){
takeCondition.await();
}
//拿到元素
Object o = items[takeIndex++];
//如果拿元素的下标已经到了末尾了,那么需要把下标重置到0去,不断轮回。
if (takeIndex == capacity){
takeIndex = 0;
}
//元素数量减一
count--;
//给放入限制发一个信号,说明有一个空位可以放了
putCondition.signal();
//返回元素
return (T)o;
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return null;
}
@Override
public T element() {
return peek();
}
@Override
public T peek() {
return (T)items[takeIndex];
}
@Override
public int size() {
return count;
}
@Override
public boolean isEmpty() {
return count == 0;
}
@Override
public boolean contains(Object o) {
ReentrantLock lock = this.lock;
try {
lock.lock();
for (int i = 0, takeIndex = this.takeIndex, size = this.count; i < size ; i++,takeIndex++){
if (items[takeIndex].equals(o)){
return true;
}
if (takeIndex == capacity){
takeIndex = 0;
}
}
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return false;
}
@Override
public Iterator iterator() {
return Arrays.stream(items).iterator();
}
@Override
public Object[] toArray() {
ReentrantLock lock = this.lock;
try {
lock.lock();
Object[] objects = new Object[count];
toArray(objects);
return objects;
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return new Object[0];
}
@Override
public Object[] toArray(Object[] objects) {
ReentrantLock lock = this.lock;
try {
lock.lock();
for (int i = 0, size = this.count, takeIndex = this.takeIndex; i < size; i++){
objects[i] = items[takeIndex++];
if (takeIndex == capacity){
takeIndex = 0;
}
}
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return objects;
}
@Override
public boolean remove(Object o) {
ReentrantLock lock = this.lock;
try {
lock.lock();
for (int i = 0, size = this.count, takeIndex = this.takeIndex; i < size; i++,takeIndex++){
//先找到
if (items[takeIndex].equals(o)){
//找到了,后面的往前面移动一个位置即可
i++;
//往前面放
while (i < size){
items[i-1] = items[i++];
}
//最后一个空位设置为null
items[i] = null;
//元素总数减一
this.count--;
//对putIndex减一
if (putIndex == 0){
putIndex = capacity-1;
}else {
putIndex--;
}
return true;
}
if (takeIndex == capacity){
takeIndex = 0;
}
}
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return false;
}
@Override
public boolean containsAll(Collection c) {
ReentrantLock lock = this.lock;
try {
lock.lock();
for (Object o : c){
boolean found = false;
for (int i = 0, takeIndex = this.takeIndex, size = this.count; i < size ; i++,takeIndex++){
if (items[takeIndex].equals(o)){
found = true;
break;
}
if (takeIndex == capacity){
takeIndex = 0;
}
}
if (!found){
return false;
}
}
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return true;
}
@Override
public boolean addAll(Collection c) {
ReentrantLock lock = this.lock;
try {
//获取锁
lock.lock();
if (c.size() > capacity - count){
throw new RuntimeException("available capacity less than collection size");
}
for (Object o : c){
//容量满了,得等移除了一个才能继续插入
while (count == capacity){
putCondition.await();
}
//放入元素
items[putIndex++] = o;
//如果放入元素到了末端了,就把下标移动到0去,就是轮回
if (putIndex == capacity){
putIndex = 0;
}
//元素数量加一
count++;
//可取的限制发一个信号,这样如果有等待取的线程就能得到信号去执行了
takeCondition.signal();
}
return true;
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return false;
}
@Override
public boolean removeAll(Collection c) {
ReentrantLock lock = this.lock;
try {
lock.lock();
//定义一个新的对象数组
Object[] objects = new Object[capacity];
//新对象数组的元素总数
int newSize = 0;
//遍历当前数组的每个元素
for (int i = 0, size = this.count, takeIndex = this.takeIndex; i < size; i++){
Object o = items[takeIndex++];
//不包含就放入新数组
if (!c.contains(o)){
objects[newSize++] = o;
}
if (takeIndex == capacity){
takeIndex = 0;
}
}
//新数组替换掉旧的数组病重置参数
items = objects;
putIndex = count = newSize;
takeIndex = 0;
return true;
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return false;
}
@Override
public boolean retainAll(Collection c) {
ReentrantLock lock = this.lock;
try {
lock.lock();
//定义一个新的对象数组
Object[] objects = new Object[capacity];
//新对象数组的元素总数
int newSize = 0;
//遍历当前数组的每个元素
for (int i = 0, size = this.count, takeIndex = this.takeIndex; i < size; i++){
Object o = items[takeIndex++];
//要保留就放入新数组
if (c.contains(o)){
objects[newSize++] = o;
}
if (takeIndex == capacity){
takeIndex = 0;
}
}
//新数组替换掉旧的数组病重置参数
items = objects;
putIndex = count = newSize;
takeIndex = 0;
return true;
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
return false;
}
@Override
public void clear() {
ReentrantLock lock = this.lock;
try {
lock.lock();
//清理元素
while (count > 0){
items[takeIndex++] = null;
count--;
if (takeIndex == capacity){
takeIndex = 0;
}
}
//归位
takeIndex = putIndex = 0;
}catch (Exception e){
e.printStackTrace();
}finally {
lock.unlock();
}
}
@Override
public String toString() {
return Arrays.toString(toArray());
}
}
测试类
其中join那边大家可能会有疑问为什么要加这个,是因为执行这个测试函数的时候主函数可能执行完,然后主线程结束了,那么创建的两个子线程就不会完全执行完了,因此我们显式的在主函数中等待这两个线程执行完再结束。
import com.example.test.notify.MyArrayBlockingQueue;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* @Author humorchen
* @Date 2021/10/27
*/
public class MyArrayBlockingQueueTest {
MyArrayBlockingQueue<String> queue = new MyArrayBlockingQueue<>(4);
/**
* 测试添加一个
*/
@Test
public void testAdd() {
queue.add("1");
System.out.println(queue + " size:" + queue.size());
}
/**
* 测试添加一个
*/
@Test
public void testAddAll() {
queue.add("1");
System.out.println(queue + " size:" + queue.size());
System.out.println("添加2,3");
queue.addAll(Arrays.asList("2", "3"));
System.out.println(queue + " size:" + queue.size());
}
/**
* 测试移除全部
*/
@Test
public void testClear() {
queue.add("1");
queue.add("2");
queue.add("3");
System.out.println(queue + " size:" + queue.size());
System.out.println("清理全部");
queue.clear();
System.out.println(queue + " size:" + queue.size());
}
/**
* 测试移除一个
*/
@Test
public void testRemove() {
queue.add("1");
queue.add("2");
queue.add("3");
System.out.println("当前队列:" + queue.toString());
System.out.println("移除2");
queue.remove("2");
System.out.println(queue + " size:" + queue.size());
}
/**
* 测试移除多个
*/
@Test
public void testRemoveAll() {
queue.add("1");
queue.add("2");
queue.add("3");
System.out.println("当前队列:" + queue.toString());
System.out.println("移除2");
queue.removeAll(Arrays.asList("1", "2"));
System.out.println(queue + " size:" + queue.size());
}
/**
* 测试包含判断
*/
@Test
public void testContain() {
queue.add("1");
queue.add("2");
queue.add("3");
System.out.println("当前队列:" + queue.toString());
System.out.println("是否包含2:" + queue.contains("2") + " 是否包含4:" + queue.contains("4"));
}
/**
* 测试包含所有的判断
*/
@Test
public void testContainAll() {
queue.add("1");
queue.add("2");
queue.add("3");
System.out.println("当前队列:" + queue.toString());
System.out.println("是否包含1,2:" + queue.containsAll(Arrays.asList("1", "2")) + " 是否包含1,4:" + queue.containsAll(Arrays.asList("1", "4")));
}
/**
* 测试多线程的阻塞效果
*/
@Test
public void testBlocking() {
Thread t1 = new Thread(() -> {
try {
System.out.println(Thread.currentThread().getName() + " 执行休眠");
Thread.sleep(1500);
System.out.println(Thread.currentThread().getName() + "线程拿出了一个 " + queue.poll() + " 移除后size " + queue.size() + " 移除后队列:" + queue.toString());
} catch (Exception e) {
e.printStackTrace();
}
});
t1.start();
Thread t2 = new Thread(() -> {
queue.add("1");
queue.add("2");
queue.add("3");
queue.add("4");
System.out.println(Thread.currentThread().getName() + "队列:" + queue.toString() + " size:" + queue.size() + " 再次添加一个5");
queue.add("5");
System.out.println(Thread.currentThread().getName() + "添加后" + queue.toString());
System.out.println(Thread.currentThread().getName() + "拿出一个元素:" + queue.poll());
System.out.println(Thread.currentThread().getName() + "size " + queue.size() + " " + queue.toString());
List<String> list = new ArrayList<>();
list.add("1");
});
t2.start();
try {
t1.join();
t2.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
@Test
public void testBlockingPoll(){
Thread t1 = new Thread(()->{
System.out.println(Thread.currentThread().getName()+"开始从队列中获取一个");
System.out.println(Thread.currentThread().getName()+"获取到了:"+queue.poll());
});
t1.start();
Thread t2 = new Thread(()->{
System.out.println(Thread.currentThread().getName()+" 将2秒后给队列中放入一个");
try {
Thread.sleep(2000);
queue.add("1");
System.out.println(Thread.currentThread().getName()+"放入后队列中:"+queue.toString()+" size:"+queue.size());
} catch (InterruptedException e) {
e.printStackTrace();
}
});
t2.start();
try {
t1.join();
t2.join();
} catch (InterruptedException e) {
e.printStackTrace();
}
}
}
本文来自博客园,作者:HumorChen99,转载请注明原文链接:https://www.cnblogs.com/HumorChen/p/18039528