如果我的文章对您有帮助,麻烦回复一个赞,给我坚持下去的动力

【工具类】线程安全的滑动时间窗口记录工具类

闲来无事,分享一个工具类,写的不好,轻喷,欢迎指出问题

目标是线程安全无锁高性能的记录滑动时间窗口值

import lombok.Getter;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReferenceArray;

/**
 * 线程安全的滑动时间窗口计数,时间单位:秒
 */
public class SlidingTimeWindow {
    // 时间窗口内的总记录值
    private final AtomicInteger count = new AtomicInteger(0);
    // 每个时间点内的记录值
    private final AtomicReferenceArray<TimeSplit> arr;
    // 时间窗口的大小,单位秒
    private final int interval;
    // 最后一次处理的周期数,主要用于当长时间未被调用时更新数据使用
    private volatile long lastCycle;

    // 窗口时间长度
    public SlidingTimeWindow(int interval) {
        this.interval = interval;
        arr = new AtomicReferenceArray<>(interval);
        this.lastCycle = getCurrentCycle();
    }

    // 窗口时间内的记录值
    public int get() {
        long currentCycle = getCurrentCycle();
        if(currentCycle <= lastCycle) {
            return count.get();
        }
        updateTs(currentCycle);
        return count.get();
    }
    // 窗口时间内的记录值并加1
    public int getAndIncrement() {
        long currentCycle = getCurrentCycle();
        TimeSplit ts = updateTs(currentCycle);
        ts.getCount().getAndIncrement();
        return count.getAndIncrement();
    }

    // 计算所属时间周期
    private long getCurrentCycle() {
        return System.currentTimeMillis() / 1000;
    }

    // 更新时间窗口内的记录值
    private TimeSplit updateTs(long currentCycle) {
        long lastCycleTemp = Math.max(this.lastCycle, currentCycle - interval);
        if(currentCycle > lastCycleTemp) {
            // 更新
            this.lastCycle = currentCycle;
        } else if(currentCycle < lastCycleTemp) {
            // 避免机器发生时间回拨导致的错误
            lastCycleTemp = currentCycle;
        }

        TimeSplit ts = null;
        for(;lastCycleTemp<=currentCycle;lastCycleTemp++) {
            // 依次更新每个时间点的数据
            ts = doUpdateTs(lastCycleTemp);
        }
        return ts;
    }

    /**
     * 更新指定时间点的数据
     * 覆盖已经过期的数据,将过期数据从总记录值中减去
     * @param time
     * @return
     */
    private TimeSplit doUpdateTs(long time) {
        int index = (int)(time % interval);
        TimeSplit ts = arr.get(index);
        while (ts == null || ts.getTime() != time) {
            TimeSplit newTs = new TimeSplit(time,new AtomicInteger(0));
            if(arr.compareAndSet(index,ts,newTs) && ts != null) {
                count.getAndAdd(-ts.getCount().get());
            }
            ts = arr.get(index);
        }
        return ts;
    }

    /**
     * 记录每个时间点的值,
     * 当时间点过期时,用于移除总值中该时间点的记录值
     */
    @Getter
    private class TimeSplit {
        private final long time;
        private final AtomicInteger count;

        public TimeSplit(long time, AtomicInteger count) {
            this.time = time;
            this.count = count;
        }
    }
  // chatgpt帮忙写的测试用例
    public static void main(String[] args) throws InterruptedException {
         int THREAD_POOL_SIZE = 2;
         int TEST_TIME_SECONDS = 10;
         int INTERVAL_SECONDS = 5;
        SlidingTimeWindow stw = new SlidingTimeWindow(INTERVAL_SECONDS);
        ExecutorService executorService = Executors.newFixedThreadPool(THREAD_POOL_SIZE);
        long start = System.currentTimeMillis();
        System.out.println("start" + start);
        for (int i = 0; i < THREAD_POOL_SIZE; i++) {
            executorService.execute(() -> {
                while (System.currentTimeMillis() - start <= TEST_TIME_SECONDS * 1000) {
                    stw.getAndIncrement();
                }
            });
        }

        while (System.currentTimeMillis() - start <= (TEST_TIME_SECONDS+1) * 1000) {
            Thread.sleep(1000);
            System.out.println("waiting");
        }

        executorService.shutdown();
        executorService.awaitTermination(10, TimeUnit.SECONDS);
        int totalCount = 0;
        int expectedCount = stw.get();
        for(int i=0;i<stw.arr.length();i++) {
            totalCount += stw.arr.get(i).getCount().get();
        }


        System.out.println("Total request count: " + totalCount);
        System.out.println("Expected request count: " + expectedCount);
        System.out.println("Difference: " + (expectedCount - totalCount));
        for(int i=1;i<INTERVAL_SECONDS;i++) {
            TimeUnit.SECONDS.sleep(1);
            System.out.println("current:" + stw.get());
        }
    }
}

欢迎大家review代码,多提提意见

posted @ 2023-05-09 15:53  无所事事O_o  阅读(91)  评论(0编辑  收藏  举报