Rust 异步取消

(&mut Future).await

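直接执行 join_handle.await 会消费(move)JoinHandle;而 (&mut join_handle).await 只是可变借用,不会消费 JoinHandle。这之所以可行,是因为 JoinHandle 实现了 Unpin,而标准库为所有满足 `F: Future + Unpin` 的 `&mut F` 实现了 Future(见下方代码)。参考:https://github.com/tokio-rs/tokio/discussions/4019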

// `join_handle.await` takes its operand by value, so awaiting the handle
// directly moves (consumes) the `JoinHandle`; it cannot be used afterwards.
// The output is presumably `Result<T, JoinError>` (see `Err(JoinError::Cancelled)` below).
impl<T> Future for JoinHandle<T> {
    type Output = super::Result<T>;


    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
      ...
    }
}

// Note: a mutable borrow of any `Future + Unpin` is itself a `Future`.
// Since `(&mut join_handle).await` compiles, `JoinHandle` satisfies this bound,
// so awaiting through `&mut` polls the handle without moving it.
#[stable(feature = "futures_api", since = "1.36.0")]
impl<F: ?Sized + Future + Unpin> Future for &mut F {
    type Output = F::Output;

    // `F: Unpin` is what allows re-pinning the borrowed future with `Pin::new`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        F::poll(Pin::new(&mut **self), cx)
    }
}

impl<T> JoinHandle<T> {
    ...
    // `abort` only needs `&self`, so the handle can be aborted and still be
    // awaited afterwards.
    pub fn abort(&self) {
        // Only shut down when the handle still refers to a task.
        if let Some(raw) = self.raw {
            raw.shutdown();
        }
    }
}

但是,任务的结果被取走之后,不能再次 .await 同一个 JoinHandle(对已完成的 JoinHandle 再次 poll 会 panic):

// Shows that a `JoinHandle` can be awaited through `&mut` once, but polling it
// again after it has produced its result panics.
#[tokio::main]
pub async fn main() -> () {
    let mut join_handle = tokio::spawn(async {});
    join_handle.abort();
    dbg!((&mut join_handle).await); // Err(JoinError::Cancelled)
    // The second `abort` is not where the panic happens; the second await is.
    join_handle.abort();
    dbg!((&mut join_handle).await); // thread 'main' panicked at 'unexpected task state'
}

普通的异步代码块没有实现 Unpin,需要先用 Box::pin(或 tokio::pin!)把它固定住,才能通过 &mut 来 .await:

// Compares awaiting a pinned async block through `&mut` (the block is only
// borrowed and stays usable) with awaiting it by value (the block is moved).
// NOTE: the last line intentionally fails to compile to demonstrate the move.
#[tokio::main]
pub async fn main() -> () {
    // Async blocks are not `Unpin`; pinning them on the heap makes
    // `Pin<Box<_>>` (which is `Unpin`) awaitable through `&mut`.
    let mut block = Box::pin(async {
        tokio::time::sleep(std::time::Duration::from_millis(2000)).await;
        88
    });
    dbg!((&mut block).await); // [src\main.rs:11] (&mut block).await = 88
    block.as_ref(); // still usable: `block` was only borrowed

    dbg!(block.await);
    block.as_ref(); // Compile Error: borrow of moved value: `block`
}

通过 spawn task 实现超时取消

timer.rs

use std::future::Future;

/// Asynchronously waits for `n` milliseconds using tokio's timer.
pub async fn sleep(n: u64) {
    let duration = std::time::Duration::from_millis(n);
    tokio::time::sleep(duration).await
}

/// Runs `task` on a newly spawned tokio task and waits at most `millis`
/// milliseconds for it.
///
/// Returns `Some(output)` if the task finishes in time. Returns `None` if the
/// timeout elapses first, or if the task ends without an output (it panicked
/// or was cancelled); in that case `None` is returned immediately rather than
/// after the full timeout.
///
/// The spawned task is aborted whenever this future completes *or is dropped*,
/// so neither a timed-out task nor one whose caller stopped waiting keeps
/// running in the background.
pub async fn timeout<T>(task: impl Future<Output = T> + Send + 'static, millis: u64) -> Option<T>
where
    T: Send + 'static,
{
    // Aborts the wrapped task on drop. Calling `abort` on a task that already
    // finished does nothing, so this is fine on every exit path.
    struct AbortOnDrop<U>(tokio::task::JoinHandle<U>);

    impl<U> Drop for AbortOnDrop<U> {
        fn drop(&mut self) {
            self.0.abort();
        }
    }

    let mut guard = AbortOnDrop(tokio::spawn(task));
    tokio::select! {
        // Match the whole `Result`: a refutable `Ok(v)` pattern would disable
        // this branch on a `JoinError` and wait out the timeout for nothing.
        joined = &mut guard.0 => joined.ok(),
        // Timed out: `guard` is dropped on return, which aborts the task.
        () = sleep(millis) => None,
    }
}

// or
/// Same as the version above, with the future type named as a generic `T`.
///
/// Runs `task` on a newly spawned tokio task and waits at most `millis`
/// milliseconds. Returns `Some(output)` on time, otherwise `None` (timeout,
/// panic, or cancellation — the latter two without waiting for the timeout).
/// The spawned task is aborted when this future completes or is dropped.
pub async fn timeout<R, T>(task: T, millis: u64) -> Option<R>
where
    R: Send + 'static,
    T: Future<Output = R> + Send + 'static,
{
    // Aborts the wrapped task on drop; a no-op if the task already finished.
    struct AbortOnDrop<U>(tokio::task::JoinHandle<U>);

    impl<U> Drop for AbortOnDrop<U> {
        fn drop(&mut self) {
            self.0.abort();
        }
    }

    let mut guard = AbortOnDrop(tokio::spawn(task));
    tokio::select! {
        // Irrefutable binding so a `JoinError` resolves this branch right away.
        joined = &mut guard.0 => joined.ok(),
        // Timed out: dropping `guard` on return aborts the task.
        () = sleep(millis) => None,
    }
}

测试

mod timer;
use timer::{sleep, timeout};

/// Marker whose destructor announces that the future owning it has been torn
/// down, whether it ran to completion or was cancelled.
#[derive(Debug)]
struct A;

impl Drop for A {
    fn drop(&mut self) {
        let msg = "drop A, 任务结束或被取消";
        println!("{msg}");
    }
}

/// Runs a 5-second job under a 2-second timeout and prints the outcome.
#[tokio::main]
async fn main() {
    let job = async {
        let _a = A;
        println!("开始");
        sleep(5000).await; // simulated workload: 5 seconds
        println!("完成");
        12345
    };
    let outcome = timeout(job, 2000 /* timeout: 2 seconds */).await;
    if let Some(v) = outcome {
        println!("结果:{v}");
    } else {
        println!("超时");
    }
    println!("程序运行结束");
}

运行结果:

开始
drop A, 任务结束或被取消
超时
程序运行结束

我们增加超时时间大于5秒,结果:

开始
完成
drop A, 任务结束或被取消
结果:12345
程序运行结束

能否不使用 spawn

可以直接使用 tokio 官方提供的 tokio::time::timeout:

/// Drop-reporting marker: the message shows exactly when the future that owns
/// it is dropped.
#[derive(Debug)]
struct A;
impl Drop for A {
    fn drop(&mut self) {
        const MESSAGE: &str = "drop A, 任务结束或被取消";
        println!("{MESSAGE}");
    }
}

/// Same experiment with `tokio::time::timeout`: no task is spawned; the wrapped
/// future is dropped when the deadline passes (hence "drop A" before "超时").
#[tokio::main]
async fn main() {
    let limit = std::time::Duration::from_millis(5000);
    let never_finishes = async {
        let _a = A;
        println!("开始");
        std::future::pending::<()>().await; // a future that never resolves
        println!("完成");
        12345
    };
    match tokio::time::timeout(limit, never_finishes).await {
        Err(_elapsed) => println!("超时"),
        Ok(v) => println!("结果:{v}"),
    }
    println!("程序运行结束");
}

结果:

开始
drop A, 任务结束或被取消
超时
程序运行结束

Future 是如何 poll 的?

为了研究 Future 的 poll 原理,我们写一个名为 Delay 的 Future:

use std::{future::Future, task::Poll, pin::Pin};

use tokio::time::Sleep;

/// A future that polls `sleep` first and only starts polling `task` once the
/// timer has elapsed, then resolves to the task's output.
pub struct Delay<R> {
    // Timer that must elapse before `task` is polled at all.
    sleep: Pin<Box<Sleep>>,
    // The wrapped computation, boxed as a trait object so any future type fits.
    task: Pin<Box<dyn Future<Output = R>>>,
}

impl<R> Delay<R> {
    pub fn new<T>(delay: u64, task: T) -> Self 
    where
        T: Future<Output = R> + 'static,
    {
        Self {
            sleep: Box::pin(tokio::time::sleep(std::time::Duration::from_millis(delay))),
            task: Box::pin(task),
        }
    }
}

impl<R> Future for Delay<R> {
    type Output = R;

    /// Polls the timer; while it is pending the task is left untouched.
    /// Once the timer has elapsed, every poll forwards to the wrapped task.
    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
        println!("on poll()");
        if self.sleep.as_mut().poll(cx).is_pending() {
            println!("sleeping");
            return Poll::Pending;
        }
        println!("sleep done");
        match self.task.as_mut().poll(cx) {
            Poll::Ready(v) => {
                println!("task done");
                Poll::Ready(v)
            }
            Poll::Pending => {
                println!("task executing");
                Poll::Pending
            }
        }
    }
}

/// Delays a 5-second computation by another 5 seconds and prints its result.
#[tokio::test]
pub async fn test() -> () {
    let computation = async {
        println!("开始计算");
        tokio::time::sleep(std::time::Duration::from_millis(5000)).await;
        54321
    };
    let r = Delay::new(5000, computation);
    dbg!(r.await);
}

测试结果:

running 1 test
on poll()
sleeping
// <--------- 5秒后
on poll()
sleep done    
开始计算      
task executing
on poll()     
sleep done    
task executing
// <--------- 大概2秒后
on poll()
sleep done    
task executing
// <--------- 大概2秒后
on poll()
sleep done    
task executing
on poll()
sleep done
task done
[src\timer.rs:53] r.await = 54321
test timer::test ... ok

没错,我们延迟了5秒执行任务,并且可以观察到 poll 过程,猜测:当一个子异步任务得到解决,成为Poll::Ready状态后,会再次触发链上的 poll,从而重新评估整个异步任务状态。(但是,为什么task并未取得进展时会触发poll?是tokio::time::sleep的自定义实现吗?)
不过这相当于 { sleep(5000).await; task.await },现在我们提高一下难度,让任务立即执行,但结果必须延迟后返回。
首先想到在 sleeping 时也 poll 一下 task,很遗憾,引发了 panic :

match self.sleep.as_mut().poll(cx) {
            Poll::Pending => {
                println!("sleeping");
                self.task.as_mut().poll(cx); // thread 'main' panicked at '`async fn` resumed after completion', src\main.rs:14:30
                Poll::Pending
            },
...

尝试:

use std::{future::Future, task::Poll, pin::Pin};

use tokio::time::Sleep;

// First attempt at "start the task immediately, but return only after the delay":
// poll `task` while the timer is still running and stash its output in `result`.
// NOTE(review): intentionally unfinished. The `<-----` annotations are not Rust,
// and the `todo!` marks the open problem (moving the value out of `result`).
pub struct Delay<R> {
    sleep: Pin<Box<Sleep>>,
    task: Pin<Box<dyn Future<Output = R>>>,
    // Output of `task` if it finished while the timer was still pending.
    // NOTE(review): moving out of a `Pin<Box<Option<R>>>` is only possible when
    // `Option<R>: Unpin`; storing an unpinned `Option<R>` would allow `Option::take`.
    result: Pin<Box<Option<R>>>,
}

// The `R: Clone` bound was added so the stored output could be cloned out
// instead of moved.
impl<R> Future for Delay<R>
where R: Clone <-----此处申请克隆
{
    type Output = R;

    fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
        println!("on poll()");
        match self.sleep.as_mut().poll(cx) {
            Poll::Pending => {
                println!("sleeping");
                // Keep driving the task while the timer runs.
                // NOTE(review): this polls `task` even when `result` is already
                // `Some`; re-polling a finished async block panics with
                // "`async fn` resumed after completion".
                if let Poll::Ready(v) = self.task.as_mut().poll(cx) {
                    println!("task done");
                    // Stash the output; it may only be returned after the timer elapses.
                    self.result = Box::pin(Some(v));
                    Poll::Pending
                } else {
                    println!("task executing");
                    Poll::Pending
                }
            },
            Poll::Ready(()) => {
                println!("sleep done");
                // If the task already finished, its output is waiting in `result`.
                if self.result.is_some() {
                    todo!("返回result"); <-----此处无法实现移动值
                } else if let Poll::Ready(v) = self.task.as_mut().poll(cx) {
                    println!("task done");
                    Poll::Ready(v)
                } else {
                    println!("task executing");
                    Poll::Pending
                }
            },
        }
    }
}

再试:

use std::{future::Future, task::Poll, pin::Pin};

use tokio::{time::Sleep, sync::Mutex};

/// A future that starts driving `task` immediately but only yields its output
/// once `sleep` has elapsed. In other words, it resolves no earlier than the
/// delay and no earlier than the task itself.
pub struct Delay<R> {
    // Minimum time before the output may be returned.
    sleep: Pin<Box<Sleep>>,
    // The wrapped computation, boxed as a trait object so any future type fits.
    task: Pin<Box<dyn Future<Output = R>>>,
    // Output of `task` once it has finished, held until the timer elapses.
    // A plain `Option` suffices: `Option::take` moves the value out, so no lock
    // (and no `Clone` bound) is needed.
    result: Option<R>,
}

// `result` is never accessed through a `Pin` (there is no pin projection onto
// it), and the other fields are already `Pin<Box<_>>`, so `Delay` may be
// `Unpin` for every `R`. This lets `poll` use `Pin::get_mut`.
impl<R> Unpin for Delay<R> {}

impl<R> Delay<R> {
    /// Creates a future that polls `task` right away and returns its output no
    /// sooner than `delay` milliseconds from now.
    pub fn new<T>(delay: u64, task: T) -> Self
    where
        T: Future<Output = R> + 'static,
    {
        Self {
            sleep: Box::pin(tokio::time::sleep(std::time::Duration::from_millis(delay))),
            task: Box::pin(task),
            result: None,
        }
    }
}

impl<R> Future for Delay<R> {
    type Output = R;

    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        println!("on poll()");

        // Drive the task until it finishes, and never poll it again afterwards:
        // re-polling a completed async block panics.
        if this.result.is_none() {
            match this.task.as_mut().poll(cx) {
                Poll::Pending => println!("task executing"),
                Poll::Ready(v) => {
                    println!("task done");
                    this.result = Some(v);
                }
            }
        }

        // Poll the timer on every call so we are woken when it elapses.
        let slept = match this.sleep.as_mut().poll(cx) {
            Poll::Pending => {
                println!("sleeping");
                false
            }
            Poll::Ready(()) => {
                println!("sleep done");
                true
            }
        };

        // Resolve only once the delay has passed AND the task has finished.
        if !slept {
            return Poll::Pending;
        }
        match this.result.take() {
            Some(v) => Poll::Ready(v),
            None => Poll::Pending,
        }
    }
}

/// Runs a ~1 s computation behind a 2 s delay: the result is only printed
/// after the full 2 seconds.
#[tokio::test]
pub async fn test() -> () {
    let half_second = std::time::Duration::from_millis(500);
    let r = Delay::new(2000, async move {
        tokio::time::sleep(half_second).await;
        println!("开始计算");
        tokio::time::sleep(half_second).await;
        54321
    });
    dbg!(r.await);
}

运行结果:

running 1 test
on poll()     
task executing
sleeping      
on poll()
开始计算
task executing
sleeping
on poll()
task executing
sleeping
on poll()
task executing
sleeping
on poll()
task executing
sleeping
on poll()
task done
sleeping
on poll()
sleeping
on poll()
sleeping
on poll()
sleeping
on poll()
sleep done
[src\timer.rs:71] r.await = 54321
test timer::test ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 2.00s

可以看到,虽然任务只耗时1秒,我们仍然在2秒后才拿到数据。

取消计算密集型任务

作为例子,下面这个任务耗时 10s 左右。它是无法直接 abort 的:循环体内没有任何 .await,任务永远不会把控制权交还给调度器,而 abort 只能在任务下一次让出(yield)时才生效:

    // The loop contains no `.await`, so the task never yields to the scheduler
    // and an `abort()` cannot take effect until the loop has finished.
    // NOTE: in release builds the optimizer may remove this empty loop; the
    // ~10 s runtime is presumably from a debug build.
    let task = async move {
        println!("开始");
        for _ in 0 ..= 800000000i64 {
        }
        println!("结束");
    };

解决办法,穿插一些 sleep(0) 来响应取消请求即可:

/// Spawns a counting loop that awaits `sleep(0)` after every iteration, aborts
/// it after 2 seconds, and reports how far it got.
#[tokio::main]
pub async fn main() -> () {
    let progress = std::sync::Arc::new(tokio::sync::Mutex::new(0i64));
    let (writer, reader) = (progress.clone(), progress.clone());
    let zero = std::time::Duration::from_millis(0);
    let task = async move {
        println!("开始");
        for i in 0..=800000000i64 {
            *writer.lock().await = i; // progress probe
            tokio::time::sleep(zero).await;
        }
        println!("结束");
        tokio::time::sleep(std::time::Duration::from_millis(10000)).await;
    };
    let join_handle = tokio::spawn(task);
    tokio::time::sleep(std::time::Duration::from_millis(2000)).await;
    join_handle.abort();
    println!("调用abort, i = {}", reader.lock().await);
    match join_handle.await {
        Err(a) => println!("中断成功, {:?}", a),
        Ok(()) => println!("任务已经完成"),
    }
}

// 输出
开始
调用abort, i = 1801
中断成功, JoinError::Cancelled

real    0m2.203s
user    0m0.000s
sys     0m0.015s

可以看到对性能还是有非常大的影响(当然,上面代码每次迭代都对锁进行了操作)。我们需要减少让出的次数,比如每隔固定的迭代次数才让出一次:

#[tokio::main]
pub async fn main() -> () {
    let count = std::sync::Arc::new( tokio::sync::Mutex::new(0i64) );
    let count1 = count.clone();
    let count2 = count.clone();
    let task = async move {
        println!("开始");
        for i in 0 ..= 800000000i64 {
            *count1.lock().await = i; // 进度探测
            if i % 800000 == 0 {
                println!("检测中断");
                tokio::time::sleep(std::time::Duration::from_millis(0)).await;
            }
        }
        println!("结束");
        tokio::time::sleep(std::time::Duration::from_millis(10000)).await;
    };
    let join_handle = tokio::spawn(task);
    tokio::time::sleep(std::time::Duration::from_millis(2000)).await;
    join_handle.abort();
    println!("调用abort, i = {}", count2.lock().await);
    match join_handle.await {
        Ok(_) => {
            println!("任务已经完成");
        },
        Err(a) => {
            println!("中断成功, {:?}", a);
        }
    }
}

// 输出
开始
检测中断
检测中断
检测中断
检测中断
调用abort, i = 2449067
中断成功, JoinError::Cancelled

经过测试,sleep(0) 不如 tokio::task::yield_now。并且我们修复了锁的问题:

    let task = async move {
        println!("开始");
        // Lock once and hold the guard for the whole loop; it is released when
        // the task's future is dropped (e.g. after `abort()`).
        let mut v = count1.lock().await;
        for i in 0 ..= 800000000i64 {
            *v = i; // progress probe
            if i % 800000 == 0 {
                // Yield to the scheduler so a pending `abort()` can take effect
                // (this replaces the earlier `sleep(0)` approach).
                tokio::task::yield_now().await;
            }
        }
        println!("结束");
        tokio::time::sleep(std::time::Duration::from_millis(10000)).await;
    };

//
开始
调用abort, i = 153600000
中断成功, JoinError::Cancelled

额,yield_now 的性能又不如对一个新的锁进行 lock().await......

see

https://blog.yoshuawuyts.com/async-cancellation-1/
https://course.rs/advance-practice/select.html

posted @ 2023-05-15 17:12  develon  阅读(112)  评论(2)  编辑  收藏  举报