python 基于信号量和Task的协程池

import asyncio
from asyncio import Semaphore, Task, events
from inspect import iscoroutinefunction
from typing import Any, Optional, Union

from loguru import logger


class CustomTask:

    def __init__(self, task: Task):
        self._task = task

    async def result(self) -> Any:
        """
        获取task结果
        会异步阻塞,会报错
        """
        res = await self._task
        return res

    async def exception(
        self,
    ) -> Optional[Union[Exception, BaseException, asyncio.exceptions.CancelledError]]:
        """
        获取task异常
        会异步阻塞,无异常则返回None
        """
        try:
            await self._task
            return None
        except asyncio.exceptions.CancelledError as e:
            return asyncio.exceptions.CancelledError("任务已被取消")
        except Exception as e:
            return e
        except BaseException as e:
            return e

    def done(self) -> bool:
        """
        检查task是否完成,不阻塞
        """
        return self._task.done()

    def cancel(self) -> bool:
        """
        取消task,不阻塞
        返回True表示取消信号设置成功,对应的协程会在下次await时抛出CancelledError,这个异常可以在exception方法中获取
        返回False表示信号设置失败,可能的原因有协程已完成(包括正常和异常结束)或者已被取消
        """
        return self._task.cancel()


class CoroutinePool:
    def __init__(self, max_size: int):
        assert max_size > 0, "max_size需大于0"
        self.max_size = max_size
        self._semaphore: Optional[Semaphore] = None
        self._loop = None
        self._binding_loop()

    def _binding_loop(self):
        """
        绑定事件循环,强制要求在已有事件循环中初始化资源
        """
        try:
            self._loop = events.get_running_loop()
        except RuntimeError:
            raise RuntimeError(
                "未找到正在运行的事件循环,此类需在协程函数中完成实例化."
            )
        self._semaphore = asyncio.Semaphore(self.max_size)

    async def _wrapper(self, func, *args, **kwargs):
        """
        包装外来任务,加入信号量限制
        """
        if self._semaphore is None:
            raise RuntimeError("池信号量初始化异常")
        async with self._semaphore:
            return await func(*args, **kwargs)

    def submit(self, func, *args, **kwargs) -> CustomTask:
        """
        提交任务
        """
        if not iscoroutinefunction(func):
            raise ValueError("只能提交协程函数")
        if self._loop is None:
            raise RuntimeError("未绑定事件循环")
        if id(self._loop) != id(asyncio.get_event_loop()):
            raise RuntimeError("不可跨事件循环使用池")
        wrap_coroutine = self._wrapper(func, *args, **kwargs)
        task = asyncio.create_task(wrap_coroutine)
        return CustomTask(task)


# ------------------------------------------------------------测试--------------------------------------------------------------------#


async def test(x):
    logger.info(f"开始执行任务{x}")
    await asyncio.sleep(1)
    logger.info(f"任务{x}执行完毕")
    return x


async def start():
    pool = CoroutinePool(1)
    task = pool.submit(test, 1)
    logger.info("任务已提交")
    await asyncio.sleep(0.2)
    logger.info(f"任务是否完成:{task.done()}")
    e = await task.exception()
    logger.info(f"任务是否完成:{task.done()}")
    logger.info(f"任务异常信息:{e}")
    res = await task.result()
    logger.info(f"任务结果:{res}")


async def start1():
    pool = CoroutinePool(5)
    tasks = [pool.submit(test, i) for i in range(10)]
    logger.info("任务已提交")
    for task in tasks:
        res = await task.result()
        logger.info(f"任务结果:{res}")


async def start2():
    pool = CoroutinePool(1)
    task = pool.submit(test, 1)
    logger.info("任务已提交")
    await asyncio.sleep(0.2)
    task.cancel()
    e = await task.exception()
    print(e)


if __name__ == "__main__":
    # asyncio.run(start())
    # asyncio.run(start1())
    asyncio.run(start2())

这个协程池可能有很多没有考虑到的使用情况,尤其是在多线程使用场景中,可能会出现事件循环管理引起的问题,建议仅在单线程下使用此协程池。

posted @ 2025-04-03 15:45  CJTARRR  阅读(6)  评论(0)    收藏  举报