阅读《java并发编程实战》第十二章 并发程序的测试

要点:

  • 正确性测试
    • 基本的单元测试
    • 对阻塞操作的测试
    • 安全性测试
    • 资源管理的测试
  • 性能测试

本章通过实现一个有界的阻塞队列(基于信号量实现),学习怎么给并发程序做测试。

代码1:BoundedBuffer 类,基于信号量实现的有界阻塞队列。实际生产代码中,应该使用 ArrayBlockingQueue, 或者 LinkedBlockingQueue,而不是自己实现。

public class BoundedBuffer<E> {
    private final E[] items;
    private final Semaphore availableItems, availableSpaces; // why use two semaphore?
    private volatile int takeIndex, putIndex;

    public BoundedBuffer(int capacity) {
        this.items = (E[])new Object[capacity];
        this.availableItems = new Semaphore(0);
        this.availableSpaces = new Semaphore(capacity);
    }

    public void put(E item) throws InterruptedException {
        availableSpaces.acquire();
        doInsert(item);
        availableItems.release();
    }

    public E take() throws InterruptedException {
        availableItems.acquire();
        E ans = doExtract();
        availableSpaces.release();
        return ans;
    }

    public boolean isEmpty() {
        return availableItems.availablePermits() == 0;
    }

    public boolean isFull() {
        return availableSpaces.availablePermits() == 0;
    }

    private synchronized void doInsert(E item) {
        items[putIndex++] = item;
        if (putIndex == items.length) {
            putIndex = 0;
        }
    }

    private synchronized E doExtract() {
        E ans = items[takeIndex];
        items[takeIndex] = null; // help GC, must do it
        if (++takeIndex == items.length) {
            takeIndex = 0;
        }
        return ans;
    }
}

分析:

  • 为什么需要两个Semaphore变量? 通常我们自己实现一个阻塞队列,做法是使用一个Lock,然后生成两个条件队列 (notFull, notEmpty)用来给阻塞的队列排队。而锁本身是为了保证操作队列的原子性(putIndex, takeIndex). 这个程序中,使用synchronized 锁+两个阻塞队列(是Semaphore实现的),原理也是一样的,之所以要两个,一个用来给put线程阻塞排队用的,一个是给take线程阻塞排队的。
  • 为什么synchronized锁,不直接加到take 和 put上,而是在获取Semaphore之后去做的同步?如果是在take方法上加锁,那么锁的范围比较大,而且一旦获取到锁,然后availableItems这个Semaphore又用完了,那么会出先死锁,导致take的线程持有sync锁,然后又阻塞在Semaphore上了,因此不能直接在take上加锁。经过实验,修改sync之后,执行PutTakeTest后,程序会一致阻塞在Semaphore上。

代码2: 基本的单元测试

import static org.junit.jupiter.api.Assertions.*;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

/**
 * Created by xianzhon on 2023/05/28.
 * Purpose:
 */
class BoundedBufferTest {
    BoundedBuffer<Integer> bb;

    @BeforeEach
    void setUp() {
        bb = new BoundedBuffer<>(10);
    }

    // -----------------------------------------
    // 基本的单元测试
    // -----------------------------------------
    @Test
    void testIsEmptyWhenConstructed() {
        assertTrue(bb.isEmpty());
        assertFalse(bb.isFull());
    }

    @Test
    void testFullAfterPuts() throws InterruptedException {
        for (int i = 0; i < 10; i++) {
            bb.put(i);
        }
        assertTrue(bb.isFull());
        assertFalse(bb.isEmpty());
    }

    // -----------------------------------------
    // 对阻塞操作的测试
    // -----------------------------------------
    final int LOCKUP_DETEC_TIMEOUT = 3000;

    @Test
    void testTakeBlocksWhenEmpty() {
        Thread taker = new Thread(() -> {
            try {
                int unused = bb.take();
                fail();
            } catch (InterruptedException e) {
                System.out.println(Thread.currentThread().getName() + " is interrupted!");
            }
        });

        try {
            taker.start();
            Thread.sleep(LOCKUP_DETEC_TIMEOUT);
            taker.interrupt();
            taker.join(LOCKUP_DETEC_TIMEOUT);
            assertFalse(taker.isAlive());
        } catch (InterruptedException e) {
            fail();
        }
    }

    @Test
    void testPutBlocksWhenFull() {
        for (int i = 0; i < 10; i++) {
            try {
                bb.put(i);
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            }
        }
        Thread putter = new Thread(() -> {
            try {
                bb.put(100);
                fail(); // if put thread is not blocked, fail the test
            } catch (InterruptedException e) {
                // when it was interrupted, goes here
            }
        });

        try {
            putter.start();
            Thread.sleep(LOCKUP_DETEC_TIMEOUT);
            putter.interrupt();
            putter.join(LOCKUP_DETEC_TIMEOUT);
            assertFalse(putter.isAlive());
        } catch (InterruptedException e) {
            fail();
        }
    }
}

代码3:安全性测试,使用10个生产者和10个消费者。

// -----------------------------------------
// 安全性测试
// -----------------------------------------
public class PutTakeTest {
    private static final ExecutorService pool = Executors.newCachedThreadPool();
    private final AtomicInteger putSum = new AtomicInteger(0);
    private final AtomicInteger takeSum = new AtomicInteger(0);
    private final CyclicBarrier barrier;
    private final BoundedBuffer<Integer> bb;
    private final int nTrails, nPairs;

    public PutTakeTest(int capacity, int nPairs, int nTrails) {
        this.bb = new BoundedBuffer<>(capacity);
        this.nPairs = nPairs;
        this.nTrails = nTrails;
        this.barrier = new CyclicBarrier(nPairs * 2 + 1); // +1 是主线程本身
    }

    void test() {
        try {
            for (int i = 0; i < nPairs; i++) {
                pool.execute(new Producer());
                pool.execute(new Consumer());
            }
            barrier.await(); // 等待所有线程就绪
            barrier.await(); // 等待所有线程执行完成
            System.out.println("putSum: " + putSum.get() + ", takeSum: " + takeSum.get() + ", isSame: " + (putSum.get() == takeSum.get()));
        } catch (InterruptedException e) {
            throw new RuntimeException(e);
        } catch (BrokenBarrierException e) {
            throw new RuntimeException(e);
        }
    }

    class Producer implements Runnable {

        @Override
        public void run() {
            try {
                int seed = this.hashCode() ^ (int)System.nanoTime();
                int sum = 0;
                barrier.await(); // 线程就绪
                for (int i = nTrails; i > 0; i--) {
                    bb.put(seed);
                    sum += seed;
                    seed = xorShift(seed);
                }
                putSum.getAndAdd(sum);
                barrier.await();  // 线程执行结束
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            } catch (BrokenBarrierException e) {
                throw new RuntimeException(e);
            }
        }
    }

    class Consumer implements Runnable {

        @Override
        public void run() {
            try {
                barrier.await(); // 线程就绪
                int sum = 0;
                for (int i = nTrails; i > 0; i--) {
                    sum += bb.take();
                }
                takeSum.getAndAdd(sum);
                barrier.await();  // 线程执行结束
            } catch (InterruptedException e) {
                throw new RuntimeException(e);
            } catch (BrokenBarrierException e) {
                throw new RuntimeException(e);
            }
        }
    }

    int xorShift(int y) {
        y ^= (y << 6);
        y ^= (y >>> 21);
        y ^= (y << 7);
        return y;
    }

    public static void main(String[] args) {
        new PutTakeTest(10, 10, 100000).test();
        pool.shutdown();
    }
}
// 结果:每次输出的sum不一样,但是都相等。
// putSum: 876698623, takeSum: 876698623, isSame: true

分析:

  • xorShift 是做什么的?它是一个使用XORShift算法的简单随机数生成器,是一种伪随机数生成算法,可以产生具有良好统计特性的数字序列。通过反复调用xorShift方法并使用不同的初始y值,可以生成一个伪随机数序列。生成的数字的质量和随机性取决于XORShift算法的特性。【注意:尽管这个算法可以提供一个相当好的伪随机序列,但它不适用于密码学或需要强随机性的应用。对于这些情况,应使用专门的随机数生成库或类。】
  • 为什么能证明并发正确性?10个线程往队列中put,以及10个线程从队列中取。最后得到的sum是一致的,说明put进去的每个元素,都被take出来了。没有出现数据覆盖,或者丢失。
posted @ 2023-05-28 15:17  编程爱好者-java  阅读(15)  评论(0编辑  收藏  举报