C# Task的GetAwaiter和ConfigureAwait

C# Task的GetAwaiter和ConfigureAwait

个人感觉Task的GetAwaiter和ConfigureAwait都比较好理解,首先看看它们的实现:

复制代码

public class Task<TResult> : Task
{
    /// <summary>
    /// Returns the awaiter used by a plain <c>await task</c>. That awaiter always passes
    /// <c>continueOnCapturedContext: true</c> when registering continuations.
    /// </summary>
    public new TaskAwaiter<TResult> GetAwaiter() => new TaskAwaiter<TResult>(this);

    /// <summary>
    /// Wraps this task in an awaitable whose awaiter forwards
    /// <paramref name="continueOnCapturedContext"/> when registering continuations,
    /// instead of always resuming on the captured context.
    /// </summary>
    /// <param name="continueOnCapturedContext">
    /// <c>true</c> to resume on the captured context; <c>false</c> to let the continuation run anywhere.
    /// </param>
    public new ConfiguredTaskAwaitable<TResult> ConfigureAwait(bool continueOnCapturedContext) =>
        new ConfiguredTaskAwaitable<TResult>(this, continueOnCapturedContext);
}

复制代码

现在我们来看看TaskAwaiter<TResult>和ConfiguredTaskAwaitable<TResult>的实现:

复制代码

/// <summary>
/// Awaiter handed out by <c>Task&lt;TResult&gt;.GetAwaiter()</c>. Every continuation it registers
/// asks to resume on the captured context (<c>continueOnCapturedContext: true</c>).
/// </summary>
public struct TaskAwaiter<TResult> : ICriticalNotifyCompletion
{
    // The awaited task; the constructor contract requires it to be non-null.
    private readonly Task<TResult> m_task;

    internal TaskAwaiter(Task<TResult> task)
    {
        Contract.Requires(task != null, "Constructing an awaiter requires a task to await.");
        m_task = task;
    }

    /// <summary>Mirrors <c>Task.IsCompleted</c> of the awaited task.</summary>
    public bool IsCompleted => m_task.IsCompleted;

    /// <summary>Registers <paramref name="continuation"/> and flows the ExecutionContext to it.</summary>
    public void OnCompleted(Action continuation) =>
        TaskAwaiter.OnCompletedInternal(m_task, continuation, continueOnCapturedContext: true, flowExecutionContext: true);

    /// <summary>Registers <paramref name="continuation"/> without flowing the ExecutionContext.</summary>
    public void UnsafeOnCompleted(Action continuation) =>
        TaskAwaiter.OnCompletedInternal(m_task, continuation, continueOnCapturedContext: true, flowExecutionContext: false);

    /// <summary>
    /// Validates how the task ended through <c>TaskAwaiter.ValidateEnd</c>. That may block until
    /// completion and throws if the task did not run to completion. Then returns the task's result.
    /// </summary>
    public TResult GetResult()
    {
        TaskAwaiter.ValidateEnd(m_task);
        return m_task.ResultOnSuccess;
    }
}


/// <summary>
/// Awaitable returned by <c>Task&lt;TResult&gt;.ConfigureAwait</c>. It only stores a pre-built
/// <see cref="ConfiguredTaskAwaiter"/> that remembers the caller's ConfigureAwait choice.
/// </summary>
public struct ConfiguredTaskAwaitable<TResult>
{
    // Built once by the constructor; GetAwaiter returns a copy of it (it is a struct).
    private readonly ConfiguredTaskAwaitable<TResult>.ConfiguredTaskAwaiter m_configuredTaskAwaiter;

    internal ConfiguredTaskAwaitable(Task<TResult> task, bool continueOnCapturedContext)
    {
        m_configuredTaskAwaiter = new ConfiguredTaskAwaiter(task, continueOnCapturedContext);
    }

    /// <summary>Returns the awaiter that carries the ConfigureAwait flag.</summary>
    public ConfiguredTaskAwaitable<TResult>.ConfiguredTaskAwaiter GetAwaiter() => m_configuredTaskAwaiter;

    /// <summary>
    /// Awaiter that passes the caller-supplied <c>continueOnCapturedContext</c> flag to
    /// <c>TaskAwaiter.OnCompletedInternal</c>, rather than always passing <c>true</c>.
    /// </summary>
    [HostProtection(Synchronization = true, ExternalThreading = true)]
    public struct ConfiguredTaskAwaiter : ICriticalNotifyCompletion
    {
        // The awaited task; the constructor contract requires it to be non-null.
        private readonly Task<TResult> m_task;
        // The argument that was passed to ConfigureAwait.
        private readonly bool m_continueOnCapturedContext;

        internal ConfiguredTaskAwaiter(Task<TResult> task, bool continueOnCapturedContext)
        {
            Contract.Requires(task != null, "Constructing an awaiter requires a task to await.");
            m_task = task;
            m_continueOnCapturedContext = continueOnCapturedContext;
        }

        /// <summary>Mirrors <c>Task.IsCompleted</c> of the awaited task.</summary>
        public bool IsCompleted => m_task.IsCompleted;

        /// <summary>Registers <paramref name="continuation"/> using the stored flag and flows the ExecutionContext.</summary>
        public void OnCompleted(Action continuation) =>
            TaskAwaiter.OnCompletedInternal(m_task, continuation, m_continueOnCapturedContext, flowExecutionContext: true);

        /// <summary>Registers <paramref name="continuation"/> using the stored flag without flowing the ExecutionContext.</summary>
        public void UnsafeOnCompleted(Action continuation) =>
            TaskAwaiter.OnCompletedInternal(m_task, continuation, m_continueOnCapturedContext, flowExecutionContext: false);

        /// <summary>
        /// Same as the unconfigured awaiter: the ConfigureAwait flag plays no role here.
        /// Validates the outcome and returns the result.
        /// </summary>
        public TResult GetResult()
        {
            TaskAwaiter.ValidateEnd(m_task);
            return m_task.ResultOnSuccess;
        }
    }
}

/// <summary>
/// Awaiter for the non-generic <c>Task</c>. Besides its own OnCompleted/UnsafeOnCompleted/GetResult,
/// it hosts the static helpers OnCompletedInternal and ValidateEnd that the generic awaiters call into.
/// NOTE(review): excerpt only. The await pattern also needs an IsCompleted property, which is not shown here.
/// </summary>
public struct TaskAwaiter : ICriticalNotifyCompletion
{
    // The awaited task; the constructor contract requires it to be non-null.
    private readonly Task m_task;
    internal TaskAwaiter(Task task)
    {
        Contract.Requires(task != null, "Constructing an awaiter requires a task to await.");
        m_task = task;
    }
    
    /// <summary>
    /// Registers <paramref name="continuation"/> to resume on the captured context, flowing the ExecutionContext.
    /// </summary>
    public void OnCompleted(Action continuation)
    {
        OnCompletedInternal(m_task, continuation, continueOnCapturedContext:true, flowExecutionContext:true);
    }

    /// <summary>
    /// Like OnCompleted, but does not flow the ExecutionContext (flowExecutionContext: false).
    /// </summary>
    public void UnsafeOnCompleted(Action continuation)
    {
        OnCompletedInternal(m_task, continuation, continueOnCapturedContext:true, flowExecutionContext:false);
    }
    
    /// <summary>
    /// Shared implementation behind every awaiter's OnCompleted/UnsafeOnCompleted. It validates the
    /// continuation, optionally wraps it for ETW / async-debugging tracing, and hands it to
    /// <c>Task.SetContinuationForAwait</c>.
    /// </summary>
    /// <param name="task">The task being awaited.</param>
    /// <param name="continuation">Callback to run once the task completes.</param>
    /// <param name="continueOnCapturedContext">
    /// Whether to resume on the captured SynchronizationContext/TaskScheduler. This is the ConfigureAwait argument.
    /// </param>
    /// <param name="flowExecutionContext">Whether the ExecutionContext should flow to the continuation.</param>
    /// <exception cref="ArgumentNullException"><paramref name="continuation"/> is null.</exception>
    internal static void OnCompletedInternal(Task task, Action continuation, bool continueOnCapturedContext, bool flowExecutionContext)
    {
        if (continuation == null) throw new ArgumentNullException("continuation");
        // Stack-walk marker passed by ref down to SetContinuationForAwait.
        // NOTE(review): it presumably identifies this method's caller during the stack walk,
        // so do not move this declaration into a helper method.
        StackCrawlMark stackMark = StackCrawlMark.LookForMyCaller;

        // If TaskWait* ETW events are enabled, trace a beginning event for this await
        // and set up an ending event to be traced when the asynchronous await completes.
        if ( TplEtwProvider.Log.IsEnabled() || Task.s_asyncDebuggingEnabled)
        {
            continuation = OutputWaitEtwEvents(task, continuation);
        }

        // Set the continuation onto the awaited task.
        task.SetContinuationForAwait(continuation, continueOnCapturedContext, flowExecutionContext, ref stackMark);
    }
        
    /// <summary>
    /// Ends the await. Blocks until the task completes if necessary, and throws if it faulted or was canceled.
    /// </summary>
    public void GetResult()
    {
        ValidateEnd(m_task);
    }

    /// <summary>
    /// Fast check shared by all awaiters' GetResult. It only enters the slow path when
    /// IsWaitNotificationEnabledOrNotRanToCompletion is true, i.e. (per its name) when a debugger
    /// wait notification is enabled or the task did not run to completion.
    /// </summary>
    internal static void ValidateEnd(Task task)
    {
        if (task.IsWaitNotificationEnabledOrNotRanToCompletion)
        {
            HandleNonSuccessAndDebuggerNotification(task);
        }
    }
    
    /// Ensures the task is completed, triggers any necessary debugger breakpoints for completing 
    /// the await on the task, and throws an exception if the task did not complete successfully.
    /// Note: if the task is not yet complete, the calling thread blocks (infinite timeout, no cancellation).
    private static void HandleNonSuccessAndDebuggerNotification(Task task)
    {    
        if (!task.IsCompleted)
        {
            // Synchronous wait on the calling thread.
            bool taskCompleted = task.InternalWait(Timeout.Infinite, default(CancellationToken));
            Contract.Assert(taskCompleted, "With an infinite timeout, the task should have always completed.");
        }
        
        // Now that we're done, alert the debugger if so requested
        task.NotifyDebuggerOfWaitCompletionIfNecessary();
        
        // And throw an exception if the task is faulted or canceled.
        if (!task.IsRanToCompletion) ThrowForNonSuccess(task);
    }
}

复制代码

TaskAwaiter<TResult>的OnCompleted和UnsafeOnCompleted方法传给OnCompletedInternal的参数continueOnCapturedContext固定为true,GetResult主要是调用TaskAwaiter.ValidateEnd(m_task)方法。ConfiguredTaskAwaiter的OnCompleted和UnsafeOnCompleted方法传入的m_continueOnCapturedContext则不一定是true,它就是外部调用task.ConfigureAwait时传入的参数continueOnCapturedContext;ConfiguredTaskAwaiter的GetResult同样是调用TaskAwaiter.ValidateEnd(m_task)方法。下面看看TaskAwaiter的ValidateEnd方法(TaskAwaiter自己的GetResult也是调用ValidateEnd)。只有当task没有成功完成(或者开启了调试器的等待通知)时,ValidateEnd才会调用HandleNonSuccessAndDebuggerNotification方法。HandleNonSuccessAndDebuggerNotification先检查task是否已完成,如果没有完成,就调用task.InternalWait(Timeout.Infinite, default(CancellationToken))阻塞当前线程,直到task完成;随后通知调试器,如果task处于Faulted或Canceled状态则抛出异常。注意:使用await时,编译器生成的代码只会在task完成之后才调用GetResult,所以这里的阻塞只会在直接调用task.GetAwaiter().GetResult()这类写法时发生。

Task的ConfigureAwait中的参数continueOnCapturedContext最后传递到了TaskAwaiter的OnCompletedInternal方法。也就是说,只有在调用Awaiter的OnCompleted和UnsafeOnCompleted方法(即注册延续)时才有区别,调用GetResult是没有区别的。换句话说,ConfigureAwait(false)只影响await之后的代码在哪里继续执行(是否回到捕获的SynchronizationContext/TaskScheduler),不影响task的结果和异常。OnCompletedInternal方法主要是调用task.SetContinuationForAwait(continuation, continueOnCapturedContext, flowExecutionContext, ref stackMark);方法,task的SetContinuationForAwait实现如下:

复制代码

public class Task : IThreadPoolWorkItem, IAsyncResult, IDisposable
{
  /// <summary>
  /// Registers an await continuation on this task and decides how it will be run:
  /// posted to the captured SynchronizationContext, scheduled on the captured non-default
  /// TaskScheduler, run through an ExecutionContext-flowing AwaitTaskContinuation, or stored
  /// as a bare Action when nothing needs to be captured.
  /// </summary>
  /// <param name="continuationAction">The continuation to run when the task completes.</param>
  /// <param name="continueOnCapturedContext">The ConfigureAwait argument; false skips the context/scheduler lookup.</param>
  /// <param name="flowExecutionContext">Whether the ExecutionContext must flow to the continuation.</param>
  /// <param name="stackMark">Stack-walk marker forwarded to the continuation constructors.</param>
  internal void SetContinuationForAwait(Action continuationAction, bool continueOnCapturedContext, bool flowExecutionContext, ref StackCrawlMark stackMark)
    {
        Contract.Requires(continuationAction != null);
        TaskContinuation tc = null;

        // If the user wants the continuation to run on the current "context" if there is one...
        if (continueOnCapturedContext)
        {
            var syncCtx = SynchronizationContext.CurrentNoFlow;
            // A context whose runtime type is exactly the base SynchronizationContext is treated
            // as "no context"; only derived context types get posted to.
            if (syncCtx != null && syncCtx.GetType() != typeof(SynchronizationContext))
            {
                tc = new SynchronizationContextAwaitTaskContinuation(syncCtx, continuationAction, flowExecutionContext, ref stackMark);
            }
            else
            {                
                // No usable sync context: fall back to the current scheduler, but only if it is not the default one.
                var scheduler = TaskScheduler.InternalCurrent;
                if (scheduler != null && scheduler != TaskScheduler.Default)
                {
                    tc = new TaskSchedulerAwaitTaskContinuation(scheduler, continuationAction, flowExecutionContext, ref stackMark);
                }
            }
        }

        // No context/scheduler was captured (or ConfigureAwait(false) was used), but the
        // ExecutionContext still has to flow, so wrap the action in a plain AwaitTaskContinuation.
        if (tc == null && flowExecutionContext)
        {
            tc = new AwaitTaskContinuation(continuationAction, flowExecutionContext: true, stackMark: ref stackMark);
        }

        if (tc != null)
        {
            // If the continuation cannot be attached (NOTE(review): presumably because the task has
            // already completed), run it immediately, with inlining disallowed.
            if (!AddTaskContinuation(tc, addBeforeOthers: false))
                tc.Run(this, bCanInlineContinuationTask: false);
        }
        else
        {
            // Cheapest path: nothing to capture or flow, so attach the raw Action.
            Contract.Assert(!flowExecutionContext, "We already determined we're not required to flow context.");
            if (!AddTaskContinuation(continuationAction, addBeforeOthers: false))
                AwaitTaskContinuation.UnsafeScheduleAction(continuationAction, this);
        }
    }
}

复制代码

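Task的SetContinuationForAwait方法会根据情况选择不同的延续类型:
- 如果continueOnCapturedContext为true,并且当前SynchronizationContext不是基类SynchronizationContext本身,就使用SynchronizationContextAwaitTaskContinuation。
- 否则,如果当前TaskScheduler不是默认调度器,就使用TaskSchedulerAwaitTaskContinuation。
- 如果上面都不满足但需要流动ExecutionContext,就使用AwaitTaskContinuation。
- 否则直接把Action注册到task上。

这里涉及SynchronizationContextAwaitTaskContinuation、TaskSchedulerAwaitTaskContinuation和AwaitTaskContinuation三个类,它们的主要方法如下(这部分我没有去调试):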

复制代码

   /// Task continuation for awaiting with a current synchronization context.
    /// Created by Task.SetContinuationForAwait when a derived SynchronizationContext was captured;
    /// the continuation is either invoked inline (already on that context) or posted to it.
    internal sealed class SynchronizationContextAwaitTaskContinuation : AwaitTaskContinuation
    {
        // SendOrPostCallback handed to m_syncContext.Post; it invokes the Action passed as state.
        private readonly static SendOrPostCallback s_postCallback = state => ((Action)state)(); // can't use InvokeAction as it's SecurityCritical
        // Lazily created by GetPostActionCallback.
        private static ContextCallback s_postActionCallback;
        // The context captured at await time; asserted non-null in the constructor.
        private readonly SynchronizationContext m_syncContext;

        internal SynchronizationContextAwaitTaskContinuation(SynchronizationContext context, Action action, bool flowExecutionContext, ref StackCrawlMark stackMark) :base(action, flowExecutionContext, ref stackMark)
        {
            Contract.Assert(context != null);
            m_syncContext = context;
        }

        /// <summary>
        /// Runs the continuation. If inlining is allowed and the current SynchronizationContext is the
        /// captured one, the action is invoked directly. Otherwise PostAction runs (under the captured
        /// ExecutionContext, if any) and posts the action to the captured context.
        /// </summary>
        internal sealed override void Run(Task task, bool canInlineContinuationTask)
        {
            if (canInlineContinuationTask && m_syncContext == SynchronizationContext.CurrentNoFlow)
            {
                RunCallback(GetInvokeActionCallback(), m_action, ref Task.t_currentTask);
            }
            else
            {
                // Assign an id and emit a "scheduled" ETW event only when tracing is enabled.
                TplEtwProvider etwLog = TplEtwProvider.Log;
                if (etwLog.IsEnabled())
                {
                    m_continuationId = Task.NewId();
                    etwLog.AwaitTaskContinuationScheduled((task.ExecutingTaskScheduler ?? TaskScheduler.Default).Id, task.Id, m_continuationId);
                }
                RunCallback(GetPostActionCallback(), this, ref Task.t_currentTask);
            }
        }

        /// <summary>
        /// Posts the continuation's action to the captured SynchronizationContext. When ETW activity
        /// ids are on and an id was assigned in Run, a logging wrapper delegate is posted instead.
        /// </summary>
        private static void PostAction(object state)
        {
            var c = (SynchronizationContextAwaitTaskContinuation)state;

            TplEtwProvider etwLog = TplEtwProvider.Log;
            if (etwLog.TasksSetActivityIds && c.m_continuationId != 0)
            {
                c.m_syncContext.Post(s_postCallback, GetActionLogDelegate(c.m_continuationId, c.m_action));
            }
            else
            {
                c.m_syncContext.Post(s_postCallback, c.m_action); // s_postCallback is manually cached, as the compiler won't in a SecurityCritical method
            }
        }

        /// <summary>
        /// Returns the cached PostAction delegate, creating it on first use. There is no locking, so
        /// concurrent first calls may each create a delegate; all of them wrap the same PostAction method.
        /// </summary>
        private static ContextCallback GetPostActionCallback()
        {
            ContextCallback callback = s_postActionCallback;
            if (callback == null) { s_postActionCallback = callback = PostAction; } // lazily initialize SecurityCritical delegate
            return callback;
        }
    }
    
    /// <summary>
    /// Await continuation bound to a specific TaskScheduler captured at await time.
    /// For the default scheduler it defers to the base class; otherwise it wraps the
    /// action in a Task targeting that scheduler.
    /// </summary>
    internal sealed class TaskSchedulerAwaitTaskContinuation : AwaitTaskContinuation
    {
        // Scheduler the continuation should run on; asserted non-null in the constructor.
        private readonly TaskScheduler m_scheduler;

        internal TaskSchedulerAwaitTaskContinuation(TaskScheduler scheduler, Action action, bool flowExecutionContext, ref StackCrawlMark stackMark)
            : base(action, flowExecutionContext, ref stackMark)
        {
            Contract.Assert(scheduler != null);
            m_scheduler = scheduler;
        }

        /// <summary>
        /// Runs the continuation on m_scheduler. The action is wrapped in a Task that either
        /// tries to run inline or is queued to the scheduler.
        /// </summary>
        internal sealed override void Run(Task ignored, bool canInlineContinuationTask)
        {
            // The default scheduler is served by the base class's faster path.
            if (m_scheduler == TaskScheduler.Default)
            {
                base.Run(ignored, canInlineContinuationTask);
                return;
            }

            // Inlining is only attempted when the caller allows it AND we are already on the
            // target scheduler or on a thread-pool thread.
            bool tryInline = canInlineContinuationTask &&
                (TaskScheduler.InternalCurrent == m_scheduler || Thread.CurrentThread.IsThreadPoolThread);

            // Exceptions from the continuation are caught here and handed to ThrowAsyncIfNecessary.
            Action<object> invokeGuarded = state =>
            {
                try { ((Action)state)(); }
                catch (Exception exc) { ThrowAsyncIfNecessary(exc); }
            };
            Task wrapper = CreateTask(invokeGuarded, m_action, m_scheduler);

            if (tryInline)
            {
                InlineIfPossibleOrElseQueue(wrapper, needsProtection: false);
                return;
            }

            try { wrapper.ScheduleAndStart(needsProtection: false); }
            catch (TaskSchedulerException) { } // ScheduleAndStart has already moved the task to Faulted; nothing more to do
        }
    }
    /// <summary>
    /// Base await continuation. It optionally captures an ExecutionContext at construction and runs its
    /// action either inline or by queueing itself to the thread pool (it is an IThreadPoolWorkItem).
    /// </summary>
    internal class AwaitTaskContinuation : TaskContinuation, IThreadPoolWorkItem
    {    
        // ExecutionContext captured at construction when flowExecutionContext is true; otherwise null.
        // It is disposed after the continuation has run (see RunCallback / ExecuteWorkItemHelper).
        private readonly ExecutionContext m_capturedContext;
        // The user continuation to invoke.
        protected readonly Action m_action;
        // ETW correlation id; only assigned when ETW is enabled at scheduling time (0 otherwise).
        protected int m_continuationId;

        /// <summary>
        /// Constructor used by Task.SetContinuationForAwait. It captures the ExecutionContext relative
        /// to <paramref name="stackMark"/>, ignoring the SynchronizationContext (CaptureOptions.IgnoreSyncCtx).
        /// NOTE(review): OptimizeDefaultCase presumably lets Capture return null for a default context. Confirm.
        /// </summary>
        internal AwaitTaskContinuation(Action action, bool flowExecutionContext, ref StackCrawlMark stackMark)
        {
            Contract.Requires(action != null);
            m_action = action;
            if (flowExecutionContext)
            {
                m_capturedContext = ExecutionContext.Capture(ref stackMark, ExecutionContext.CaptureOptions.IgnoreSyncCtx | ExecutionContext.CaptureOptions.OptimizeDefaultCase);
            }
        }

        /// <summary>
        /// Overload without a stack mark; captures the ExecutionContext via ExecutionContext.FastCapture().
        /// </summary>
        internal AwaitTaskContinuation(Action action, bool flowExecutionContext)
        {
            Contract.Requires(action != null);
            m_action = action;
            if (flowExecutionContext)
            {
                m_capturedContext = ExecutionContext.FastCapture();
            }
        }

        /// <summary>
        /// Creates (but does not start) a Task bound to <paramref name="scheduler"/> that runs
        /// <paramref name="action"/>(<paramref name="state"/>) under this continuation's captured ExecutionContext.
        /// </summary>
        protected Task CreateTask(Action<object> action, object state, TaskScheduler scheduler)
        {
            Contract.Requires(action != null);
            Contract.Requires(scheduler != null);

            return new Task( action, state, null, default(CancellationToken), TaskCreationOptions.None, InternalTaskOptions.QueuedByRuntime, scheduler)  
            { 
                CapturedContext = m_capturedContext 
            };
        }

        /// <summary>
        /// Runs the continuation. It is invoked inline when the caller allows it and
        /// IsValidLocationForInlining holds; otherwise this object is queued to the thread pool,
        /// which later calls IThreadPoolWorkItem.ExecuteWorkItem.
        /// </summary>
        internal override void Run(Task task, bool canInlineContinuationTask)
        {

            if (canInlineContinuationTask && IsValidLocationForInlining)
            {
                RunCallback(GetInvokeActionCallback(), m_action, ref Task.t_currentTask); // any exceptions from m_action will be handled by s_callbackRunAction
            }
            else
            {
                // Assign an id and emit a "scheduled" ETW event only when tracing is enabled.
                TplEtwProvider etwLog = TplEtwProvider.Log;
                if (etwLog.IsEnabled())
                {
                    m_continuationId = Task.NewId();
                    etwLog.AwaitTaskContinuationScheduled((task.ExecutingTaskScheduler ?? TaskScheduler.Default).Id, task.Id, m_continuationId);
                }
                ThreadPool.UnsafeQueueCustomWorkItem(this, forceGlobal: false);
           }
        }

        /// <summary>
        /// Slow path of ExecuteWorkItem. It sets the ETW activity id (if enabled and an id was assigned),
        /// runs m_action directly or under the captured ExecutionContext (disposing that context
        /// afterwards), and finally restores the previous activity id.
        /// </summary>
        [SecurityCritical]
        void ExecuteWorkItemHelper()
        {
            var etwLog = TplEtwProvider.Log;
            Guid savedActivityId = Guid.Empty;
            if (etwLog.TasksSetActivityIds && m_continuationId != 0)
            {
                Guid activityId = TplEtwProvider.CreateGuidForTaskID(m_continuationId);
                System.Diagnostics.Tracing.EventSource.SetCurrentThreadActivityId(activityId, out savedActivityId);
            }
            try
            {
                if (m_capturedContext == null)
                {
                    m_action();
                }
                else
                {
                    try
                    {
                        ExecutionContext.Run(m_capturedContext, GetInvokeActionCallback(), m_action, true);
                    }
                    finally { m_capturedContext.Dispose(); }
                }
            }
            finally
            {
                if (etwLog.TasksSetActivityIds && m_continuationId != 0)
                {
                    System.Diagnostics.Tracing.EventSource.SetCurrentThreadActivityId(savedActivityId);
                }
            }
        }
        /// <summary>
        /// Thread-pool entry point. It invokes m_action directly only when no ExecutionContext was
        /// captured AND ETW is disabled; every other case goes through ExecuteWorkItemHelper.
        /// </summary>
        void IThreadPoolWorkItem.ExecuteWorkItem()
        {
            // inline the fast path
            if (m_capturedContext == null && !TplEtwProvider.Log.IsEnabled())
            {
                m_action();
            }
            else
            {
                ExecuteWorkItemHelper();
            }
        }

        // Lazily created by GetInvokeActionCallback.
        private static ContextCallback s_invokeActionCallback;

        // ContextCallback body: invokes the Action passed as state.
        private static void InvokeAction(object state) { ((Action)state)(); }

        /// <summary>
        /// Returns the cached InvokeAction delegate, creating it on first use (no locking; a race can
        /// only create duplicate delegates to the same method).
        /// </summary>
        protected static ContextCallback GetInvokeActionCallback()
        {
            ContextCallback callback = s_invokeActionCallback;
            if (callback == null) { s_invokeActionCallback = callback = InvokeAction; } // lazily initialize SecurityCritical delegate
            return callback;
        }

        /// <summary>
        /// Invokes <paramref name="callback"/>(<paramref name="state"/>) inline. While it runs,
        /// <paramref name="currentTask"/> (passed by ref; asserted to be Task.t_currentTask) is cleared.
        /// The callback runs under the captured ExecutionContext if there is one. Exceptions are handed
        /// to ThrowAsyncIfNecessary. Afterwards the current task is restored and the captured context is disposed.
        /// </summary>
        protected void RunCallback(ContextCallback callback, object state, ref Task currentTask)
        {
            Contract.Requires(callback != null);
            Contract.Assert(currentTask == Task.t_currentTask);

            var prevCurrentTask = currentTask;
            try
            {
                if (prevCurrentTask != null) currentTask = null;
                if (m_capturedContext == null) callback(state);
                else ExecutionContext.Run(m_capturedContext, callback, state, true);
            }
            catch (Exception exc) // we explicitly do not request handling of dangerous exceptions like AVs
            {
                ThrowAsyncIfNecessary(exc);
            }
            finally
            {
                if (prevCurrentTask != null) currentTask = prevCurrentTask;
                if (m_capturedContext != null) m_capturedContext.Dispose();
            }
        }

    }

复制代码

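 我在调试的时候,SetContinuationForAwait方法中创建的TaskContinuation是AwaitTaskContinuation实例。这说明当时要么continueOnCapturedContext为false,要么既没有自定义的SynchronizationContext,也没有非默认的TaskScheduler。

按照上面的代码,SetContinuationForAwait调用的是带StackCrawlMark参数的构造函数,它通过【m_capturedContext = ExecutionContext.Capture(ref stackMark, ...)】来捕获上下文;另一个重载才使用ExecutionContext.FastCapture()。这里捕获的是ExecutionContext(执行上下文),由于CaptureOptions.IgnoreSyncCtx,它不包含SynchronizationContext。

如果AddTaskContinuation返回false(一般意味着task已经完成,无法再挂上延续),Task的SetContinuationForAwait会立即调用AwaitTaskContinuation的Run方法【tc.Run(this, bCanInlineContinuationTask: false)】;否则延续会先登记到task上,等task完成时再执行。

AwaitTaskContinuation的Run实现非常简单。由于这里不允许内联,它会调用ThreadPool.UnsafeQueueCustomWorkItem(this, forceGlobal: false)把自己放入线程池队列,线程池随后调用AwaitTaskContinuation的ExecuteWorkItem方法。ExecuteWorkItem在m_capturedContext为null且ETW未启用时直接执行m_action,否则调用ExecuteWorkItemHelper。ExecuteWorkItemHelper再通过ExecutionContext.Run(m_capturedContext, GetInvokeActionCallback(), m_action, true)在捕获的执行上下文中执行m_action。

AwaitTaskContinuation是不是比较简单?SynchronizationContextAwaitTaskContinuation和TaskSchedulerAwaitTaskContinuation的Run方法就忽略吧,更简单。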

windows技术爱好者

posted @ 2019-12-29 13:59  grj001  阅读(558)  评论(0编辑  收藏  举报