C# Task: GetAwaiter and ConfigureAwait
Personally, I find Task's GetAwaiter and ConfigureAwait fairly easy to understand. Let's start with how they are implemented:
public class Task<TResult> : Task
{
    // Gets an awaiter used to await this
    public new TaskAwaiter<TResult> GetAwaiter()
    {
        return new TaskAwaiter<TResult>(this);
    }

    // Configures an awaiter used to await this
    public new ConfiguredTaskAwaitable<TResult> ConfigureAwait(bool continueOnCapturedContext)
    {
        return new ConfiguredTaskAwaitable<TResult>(this, continueOnCapturedContext);
    }
}
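To make the two entry points concrete, here is a minimal sketch that drives both awaiters by hand on an already completed task, roughly the fast path that compiler-generated await code takes. The class and method names (AwaiterBasics, ReadViaAwaiter, ReadViaConfiguredAwaiter) are hypothetical and exist only for this illustration.

```csharp
using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

// Hypothetical helper class, not part of the framework.
public static class AwaiterBasics
{
    // Reads the result of an already completed task through TaskAwaiter<int>.
    public static int ReadViaAwaiter(Task<int> task)
    {
        TaskAwaiter<int> awaiter = task.GetAwaiter();
        if (!awaiter.IsCompleted)
            throw new InvalidOperationException("This sketch only handles completed tasks.");
        return awaiter.GetResult();
    }

    // Same idea, but through ConfigureAwait(false), which yields a different awaiter type.
    public static int ReadViaConfiguredAwaiter(Task<int> task)
    {
        ConfiguredTaskAwaitable<int> awaitable = task.ConfigureAwait(false);
        ConfiguredTaskAwaitable<int>.ConfiguredTaskAwaiter awaiter = awaitable.GetAwaiter();
        if (!awaiter.IsCompleted)
            throw new InvalidOperationException("This sketch only handles completed tasks.");
        return awaiter.GetResult();
    }
}
```

For example, AwaiterBasics.ReadViaAwaiter(Task.FromResult(42)) and AwaiterBasics.ReadViaConfiguredAwaiter(Task.FromResult(42)) both go through IsCompleted and GetResult; only the awaiter type differs.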
Now let's look at the implementations of TaskAwaiter<TResult> and ConfiguredTaskAwaitable<TResult>:
public struct TaskAwaiter<TResult> : ICriticalNotifyCompletion
{
    private readonly Task<TResult> m_task;

    internal TaskAwaiter(Task<TResult> task)
    {
        Contract.Requires(task != null, "Constructing an awaiter requires a task to await.");
        m_task = task;
    }

    public bool IsCompleted
    {
        get { return m_task.IsCompleted; }
    }

    public void OnCompleted(Action continuation)
    {
        TaskAwaiter.OnCompletedInternal(m_task, continuation, continueOnCapturedContext:true, flowExecutionContext:true);
    }

    public void UnsafeOnCompleted(Action continuation)
    {
        TaskAwaiter.OnCompletedInternal(m_task, continuation, continueOnCapturedContext:true, flowExecutionContext:false);
    }

    public TResult GetResult()
    {
        TaskAwaiter.ValidateEnd(m_task);
        return m_task.ResultOnSuccess;
    }
}

public struct ConfiguredTaskAwaitable<TResult>
{
    private readonly ConfiguredTaskAwaitable<TResult>.ConfiguredTaskAwaiter m_configuredTaskAwaiter;

    internal ConfiguredTaskAwaitable(Task<TResult> task, bool continueOnCapturedContext)
    {
        m_configuredTaskAwaiter = new ConfiguredTaskAwaitable<TResult>.ConfiguredTaskAwaiter(task, continueOnCapturedContext);
    }

    public ConfiguredTaskAwaitable<TResult>.ConfiguredTaskAwaiter GetAwaiter()
    {
        return m_configuredTaskAwaiter;
    }

    [HostProtection(Synchronization = true, ExternalThreading = true)]
    public struct ConfiguredTaskAwaiter : ICriticalNotifyCompletion
    {
        private readonly Task<TResult> m_task;
        private readonly bool m_continueOnCapturedContext;

        internal ConfiguredTaskAwaiter(Task<TResult> task, bool continueOnCapturedContext)
        {
            Contract.Requires(task != null, "Constructing an awaiter requires a task to await.");
            m_task = task;
            m_continueOnCapturedContext = continueOnCapturedContext;
        }

        public bool IsCompleted
        {
            get { return m_task.IsCompleted; }
        }

        public void OnCompleted(Action continuation)
        {
            TaskAwaiter.OnCompletedInternal(m_task, continuation, m_continueOnCapturedContext, flowExecutionContext:true);
        }

        public void UnsafeOnCompleted(Action continuation)
        {
            TaskAwaiter.OnCompletedInternal(m_task, continuation, m_continueOnCapturedContext, flowExecutionContext:false);
        }

        public TResult GetResult()
        {
            TaskAwaiter.ValidateEnd(m_task);
            return m_task.ResultOnSuccess;
        }
    }
}

public struct TaskAwaiter : ICriticalNotifyCompletion
{
    private readonly Task m_task;

    internal TaskAwaiter(Task task)
    {
        Contract.Requires(task != null, "Constructing an awaiter requires a task to await.");
        m_task = task;
    }

    public void OnCompleted(Action continuation)
    {
        OnCompletedInternal(m_task, continuation, continueOnCapturedContext:true, flowExecutionContext:true);
    }

    public void UnsafeOnCompleted(Action continuation)
    {
        OnCompletedInternal(m_task, continuation, continueOnCapturedContext:true, flowExecutionContext:false);
    }

    internal static void OnCompletedInternal(Task task, Action continuation, bool continueOnCapturedContext, bool flowExecutionContext)
    {
        if (continuation == null) throw new ArgumentNullException("continuation");
        StackCrawlMark stackMark = StackCrawlMark.LookForMyCaller;

        // If TaskWait* ETW events are enabled, trace a beginning event for this await
        // and set up an ending event to be traced when the asynchronous await completes.
        if (TplEtwProvider.Log.IsEnabled() || Task.s_asyncDebuggingEnabled)
        {
            continuation = OutputWaitEtwEvents(task, continuation);
        }

        // Set the continuation onto the awaited task.
        task.SetContinuationForAwait(continuation, continueOnCapturedContext, flowExecutionContext, ref stackMark);
    }

    public void GetResult()
    {
        ValidateEnd(m_task);
    }

    internal static void ValidateEnd(Task task)
    {
        if (task.IsWaitNotificationEnabledOrNotRanToCompletion)
        {
            HandleNonSuccessAndDebuggerNotification(task);
        }
    }

    /// Ensures the task is completed, triggers any necessary debugger breakpoints for completing
    /// the await on the task, and throws an exception if the task did not complete successfully.
    private static void HandleNonSuccessAndDebuggerNotification(Task task)
    {
        if (!task.IsCompleted)
        {
            bool taskCompleted = task.InternalWait(Timeout.Infinite, default(CancellationToken));
            Contract.Assert(taskCompleted, "With an infinite timeout, the task should have always completed.");
        }

        // Now that we're done, alert the debugger if so requested
        task.NotifyDebuggerOfWaitCompletionIfNecessary();

        // And throw an exception if the task is faulted or canceled.
        if (!task.IsRanToCompletion) ThrowForNonSuccess(task);
    }
}
In TaskAwaiter<TResult>, OnCompleted and UnsafeOnCompleted always pass continueOnCapturedContext as true, and GetResult mainly calls TaskAwaiter.ValidateEnd(m_task). In ConfiguredTaskAwaiter, OnCompleted and UnsafeOnCompleted pass m_continueOnCapturedContext instead, which is not necessarily true: it is whatever the caller passed as continueOnCapturedContext to the task's ConfigureAwait. ConfiguredTaskAwaiter's GetResult also calls TaskAwaiter.ValidateEnd(m_task), and so does the GetResult of the non-generic TaskAwaiter. So let's look at ValidateEnd. When task.IsWaitNotificationEnabledOrNotRanToCompletion is true, it calls HandleNonSuccessAndDebuggerNotification, which checks whether the task has completed. If it has not, it calls task.InternalWait(Timeout.Infinite, default(CancellationToken)) to block the calling thread until the task completes, and afterwards it throws if the task faulted or was canceled.
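The blocking and throwing behavior of GetResult can be observed from the outside. The sketch below is only an illustration under my own naming (GetResultDemo and its methods are hypothetical): it blocks on an unfinished task through GetResult, and it compares the exception type surfaced by GetResult with the one surfaced by the Result property. ThrowForNonSuccess is not listed above, so the comparison relies on the documented behavior that GetResult rethrows the task's original exception while Result wraps it in an AggregateException.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper class, not part of the framework.
public static class GetResultDemo
{
    // GetResult on an unfinished task blocks the calling thread until the task completes.
    public static int BlockUntilDone()
    {
        Task<int> task = Task.Run(() =>
        {
            Thread.Sleep(50); // simulate a little work
            return 7;
        });
        return task.GetAwaiter().GetResult();
    }

    // Returns the name of the exception type thrown by GetResult on a faulted task.
    // The configured flag switches between GetAwaiter() and ConfigureAwait(false).GetAwaiter().
    public static string ExceptionTypeFromGetResult(bool configured)
    {
        Task<int> faulted = Task.FromException<int>(new InvalidOperationException("boom"));
        try
        {
            if (configured)
                faulted.ConfigureAwait(false).GetAwaiter().GetResult();
            else
                faulted.GetAwaiter().GetResult();
            return "none";
        }
        catch (Exception ex)
        {
            return ex.GetType().Name;
        }
    }

    // Returns the name of the exception type thrown by the Result property on a faulted task.
    public static string ExceptionTypeFromResult()
    {
        Task<int> faulted = Task.FromException<int>(new InvalidOperationException("boom"));
        try
        {
            int unused = faulted.Result;
            return "none";
        }
        catch (Exception ex)
        {
            return ex.GetType().Name;
        }
    }
}
```

Both awaiter variants report the same exception type, which matches the observation that ConfigureAwait does not change what GetResult does.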
The continueOnCapturedContext argument of Task's ConfigureAwait ends up in TaskAwaiter's OnCompletedInternal method. In other words, the setting only makes a difference when the awaiter's OnCompleted or UnsafeOnCompleted is called; for a plain GetResult call there is no difference. OnCompletedInternal mainly calls task.SetContinuationForAwait(continuation, continueOnCapturedContext, flowExecutionContext, ref stackMark), and Task's SetContinuationForAwait is implemented as follows:
public class Task : IThreadPoolWorkItem, IAsyncResult, IDisposable
{
    internal void SetContinuationForAwait(Action continuationAction, bool continueOnCapturedContext, bool flowExecutionContext, ref StackCrawlMark stackMark)
    {
        Contract.Requires(continuationAction != null);
        TaskContinuation tc = null;

        // If the user wants the continuation to run on the current "context" if there is one...
        if (continueOnCapturedContext)
        {
            var syncCtx = SynchronizationContext.CurrentNoFlow;
            if (syncCtx != null && syncCtx.GetType() != typeof(SynchronizationContext))
            {
                tc = new SynchronizationContextAwaitTaskContinuation(syncCtx, continuationAction, flowExecutionContext, ref stackMark);
            }
            else
            {
                var scheduler = TaskScheduler.InternalCurrent;
                if (scheduler != null && scheduler != TaskScheduler.Default)
                {
                    tc = new TaskSchedulerAwaitTaskContinuation(scheduler, continuationAction, flowExecutionContext, ref stackMark);
                }
            }
        }

        if (tc == null && flowExecutionContext)
        {
            tc = new AwaitTaskContinuation(continuationAction, flowExecutionContext: true, stackMark: ref stackMark);
        }

        if (tc != null)
        {
            if (!AddTaskContinuation(tc, addBeforeOthers: false))
                tc.Run(this, bCanInlineContinuationTask: false);
        }
        else
        {
            Contract.Assert(!flowExecutionContext, "We already determined we're not required to flow context.");
            if (!AddTaskContinuation(continuationAction, addBeforeOthers: false))
                AwaitTaskContinuation.UnsafeScheduleAction(continuationAction, this);
        }
    }
}
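The SynchronizationContext branch above can be exercised with a small experiment. This is a sketch under stated assumptions, not the framework's own test: CountingSyncContext and ContextCaptureDemo are hypothetical names, and the custom context simply counts Post calls and runs each callback inline. The continuation is registered while the custom context is current, and the task is completed after the context has been cleared, so a captured continuation has to be posted back rather than inlined.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical context: counts Post calls and runs each callback inline.
public sealed class CountingSyncContext : SynchronizationContext
{
    private int _postCount;

    public int PostCount
    {
        get { return Volatile.Read(ref _postCount); }
    }

    public override void Post(SendOrPostCallback d, object state)
    {
        Interlocked.Increment(ref _postCount);
        d(state);
    }
}

// Hypothetical helper class, not part of the framework.
public static class ContextCaptureDemo
{
    // Returns how many times the continuation was posted to the custom context.
    public static int CountPosts(bool continueOnCapturedContext)
    {
        var ctx = new CountingSyncContext();
        var tcs = new TaskCompletionSource<int>();
        var ran = new ManualResetEventSlim(false);
        SynchronizationContext previous = SynchronizationContext.Current;
        try
        {
            // The current context is looked up when the continuation is registered.
            SynchronizationContext.SetSynchronizationContext(ctx);
            tcs.Task.ConfigureAwait(continueOnCapturedContext)
                .GetAwaiter()
                .OnCompleted(() => ran.Set());

            // Complete the task with no context installed, so a captured
            // continuation must be posted to ctx instead of running inline.
            SynchronizationContext.SetSynchronizationContext(null);
            tcs.SetResult(1);

            // Wait in case the continuation was queued to the thread pool.
            ran.Wait(TimeSpan.FromSeconds(5));
        }
        finally
        {
            SynchronizationContext.SetSynchronizationContext(previous);
        }
        return ctx.PostCount;
    }
}
```

With continueOnCapturedContext set to true the continuation goes through ctx.Post once; with false the custom context is never consulted, which is the whole effect of ConfigureAwait(false).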
Task's SetContinuationForAwait involves the SynchronizationContextAwaitTaskContinuation, TaskSchedulerAwaitTaskContinuation and AwaitTaskContinuation classes. Their main methods are listed below (I did not step through all of them in the debugger):
/// Task continuation for awaiting with a current synchronization context.
internal sealed class SynchronizationContextAwaitTaskContinuation : AwaitTaskContinuation
{
    private readonly static SendOrPostCallback s_postCallback = state => ((Action)state)(); // can't use InvokeAction as it's SecurityCritical
    private static ContextCallback s_postActionCallback;
    private readonly SynchronizationContext m_syncContext;

    internal SynchronizationContextAwaitTaskContinuation(SynchronizationContext context, Action action, bool flowExecutionContext, ref StackCrawlMark stackMark)
        : base(action, flowExecutionContext, ref stackMark)
    {
        Contract.Assert(context != null);
        m_syncContext = context;
    }

    internal sealed override void Run(Task task, bool canInlineContinuationTask)
    {
        if (canInlineContinuationTask && m_syncContext == SynchronizationContext.CurrentNoFlow)
        {
            RunCallback(GetInvokeActionCallback(), m_action, ref Task.t_currentTask);
        }
        else
        {
            TplEtwProvider etwLog = TplEtwProvider.Log;
            if (etwLog.IsEnabled())
            {
                m_continuationId = Task.NewId();
                etwLog.AwaitTaskContinuationScheduled((task.ExecutingTaskScheduler ?? TaskScheduler.Default).Id, task.Id, m_continuationId);
            }
            RunCallback(GetPostActionCallback(), this, ref Task.t_currentTask);
        }
    }

    private static void PostAction(object state)
    {
        var c = (SynchronizationContextAwaitTaskContinuation)state;
        TplEtwProvider etwLog = TplEtwProvider.Log;
        if (etwLog.TasksSetActivityIds && c.m_continuationId != 0)
        {
            c.m_syncContext.Post(s_postCallback, GetActionLogDelegate(c.m_continuationId, c.m_action));
        }
        else
        {
            c.m_syncContext.Post(s_postCallback, c.m_action); // s_postCallback is manually cached, as the compiler won't in a SecurityCritical method
        }
    }

    private static ContextCallback GetPostActionCallback()
    {
        ContextCallback callback = s_postActionCallback;
        if (callback == null) { s_postActionCallback = callback = PostAction; } // lazily initialize SecurityCritical delegate
        return callback;
    }
}

internal sealed class TaskSchedulerAwaitTaskContinuation : AwaitTaskContinuation
{
    private readonly TaskScheduler m_scheduler;

    internal TaskSchedulerAwaitTaskContinuation(TaskScheduler scheduler, Action action, bool flowExecutionContext, ref StackCrawlMark stackMark)
        : base(action, flowExecutionContext, ref stackMark)
    {
        Contract.Assert(scheduler != null);
        m_scheduler = scheduler;
    }

    internal sealed override void Run(Task ignored, bool canInlineContinuationTask)
    {
        // If we're targeting the default scheduler, we can use the faster path provided by the base class.
        if (m_scheduler == TaskScheduler.Default)
        {
            base.Run(ignored, canInlineContinuationTask);
        }
        else
        {
            bool inlineIfPossible = canInlineContinuationTask &&
                (TaskScheduler.InternalCurrent == m_scheduler || Thread.CurrentThread.IsThreadPoolThread);

            var task = CreateTask(state =>
            {
                try { ((Action)state)(); }
                catch (Exception exc) { ThrowAsyncIfNecessary(exc); }
            }, m_action, m_scheduler);

            if (inlineIfPossible)
            {
                InlineIfPossibleOrElseQueue(task, needsProtection: false);
            }
            else
            {
                try { task.ScheduleAndStart(needsProtection: false); }
                catch (TaskSchedulerException) { } // No further action is necessary, as ScheduleAndStart already transitioned task to faulted
            }
        }
    }
}

internal class AwaitTaskContinuation : TaskContinuation, IThreadPoolWorkItem
{
    private readonly ExecutionContext m_capturedContext;
    protected readonly Action m_action;
    protected int m_continuationId;

    internal AwaitTaskContinuation(Action action, bool flowExecutionContext, ref StackCrawlMark stackMark)
    {
        Contract.Requires(action != null);
        m_action = action;
        if (flowExecutionContext)
        {
            m_capturedContext = ExecutionContext.Capture(ref stackMark, ExecutionContext.CaptureOptions.IgnoreSyncCtx | ExecutionContext.CaptureOptions.OptimizeDefaultCase);
        }
    }

    internal AwaitTaskContinuation(Action action, bool flowExecutionContext)
    {
        Contract.Requires(action != null);
        m_action = action;
        if (flowExecutionContext)
        {
            m_capturedContext = ExecutionContext.FastCapture();
        }
    }

    protected Task CreateTask(Action<object> action, object state, TaskScheduler scheduler)
    {
        Contract.Requires(action != null);
        Contract.Requires(scheduler != null);
        return new Task(
            action, state, null, default(CancellationToken),
            TaskCreationOptions.None, InternalTaskOptions.QueuedByRuntime, scheduler)
        {
            CapturedContext = m_capturedContext
        };
    }

    internal override void Run(Task task, bool canInlineContinuationTask)
    {
        if (canInlineContinuationTask && IsValidLocationForInlining)
        {
            RunCallback(GetInvokeActionCallback(), m_action, ref Task.t_currentTask); // any exceptions from m_action will be handled by s_callbackRunAction
        }
        else
        {
            TplEtwProvider etwLog = TplEtwProvider.Log;
            if (etwLog.IsEnabled())
            {
                m_continuationId = Task.NewId();
                etwLog.AwaitTaskContinuationScheduled((task.ExecutingTaskScheduler ?? TaskScheduler.Default).Id, task.Id, m_continuationId);
            }
            ThreadPool.UnsafeQueueCustomWorkItem(this, forceGlobal: false);
        }
    }

    [SecurityCritical]
    void ExecuteWorkItemHelper()
    {
        var etwLog = TplEtwProvider.Log;
        Guid savedActivityId = Guid.Empty;
        if (etwLog.TasksSetActivityIds && m_continuationId != 0)
        {
            Guid activityId = TplEtwProvider.CreateGuidForTaskID(m_continuationId);
            System.Diagnostics.Tracing.EventSource.SetCurrentThreadActivityId(activityId, out savedActivityId);
        }
        try
        {
            if (m_capturedContext == null)
            {
                m_action();
            }
            else
            {
                try
                {
                    ExecutionContext.Run(m_capturedContext, GetInvokeActionCallback(), m_action, true);
                }
                finally { m_capturedContext.Dispose(); }
            }
        }
        finally
        {
            if (etwLog.TasksSetActivityIds && m_continuationId != 0)
            {
                System.Diagnostics.Tracing.EventSource.SetCurrentThreadActivityId(savedActivityId);
            }
        }
    }

    void IThreadPoolWorkItem.ExecuteWorkItem()
    {
        // inline the fast path
        if (m_capturedContext == null && !TplEtwProvider.Log.IsEnabled())
        {
            m_action();
        }
        else
        {
            ExecuteWorkItemHelper();
        }
    }

    private static ContextCallback s_invokeActionCallback;

    private static void InvokeAction(object state) { ((Action)state)(); }

    protected static ContextCallback GetInvokeActionCallback()
    {
        ContextCallback callback = s_invokeActionCallback;
        if (callback == null) { s_invokeActionCallback = callback = InvokeAction; } // lazily initialize SecurityCritical delegate
        return callback;
    }

    protected void RunCallback(ContextCallback callback, object state, ref Task currentTask)
    {
        Contract.Requires(callback != null);
        Contract.Assert(currentTask == Task.t_currentTask);

        var prevCurrentTask = currentTask;
        try
        {
            if (prevCurrentTask != null) currentTask = null;
            if (m_capturedContext == null) callback(state);
            else ExecutionContext.Run(m_capturedContext, callback, state, true);
        }
        catch (Exception exc) // we explicitly do not request handling of dangerous exceptions like AVs
        {
            ThrowAsyncIfNecessary(exc);
        }
        finally
        {
            if (prevCurrentTask != null) currentTask = prevCurrentTask;
            if (m_capturedContext != null) m_capturedContext.Dispose();
        }
    }
}
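When I debugged this, the TaskContinuation created in SetContinuationForAwait was an AwaitTaskContinuation instance, and its constructor captured the execution context with m_capturedContext = ExecutionContext.FastCapture(). SetContinuationForAwait then called AwaitTaskContinuation's Run method via tc.Run(this, bCanInlineContinuationTask: false); note that this call only happens when AddTaskContinuation returns false, that is, when the task has already completed and the continuation can no longer be attached to it. Because inlining is disallowed on this path, Run is very simple: it calls ThreadPool.UnsafeQueueCustomWorkItem(this, forceGlobal: false), which in effect ends up invoking AwaitTaskContinuation's ExecuteWorkItem on a thread pool thread. ExecuteWorkItem checks whether m_capturedContext exists; if it does, it calls ExecuteWorkItemHelper, which in turn calls ExecutionContext.Run(m_capturedContext, GetInvokeActionCallback(), m_action, true). So AwaitTaskContinuation is fairly simple. I will skip the Run methods of SynchronizationContextAwaitTaskContinuation and TaskSchedulerAwaitTaskContinuation, which are simpler still.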
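The role of the captured ExecutionContext can also be observed with an AsyncLocal value. The following is a rough sketch under my own assumptions (ExecutionContextFlowDemo and its members are hypothetical names): it registers a continuation with either OnCompleted or UnsafeOnCompleted, changes the ambient value afterwards, and reports which value the continuation sees. ConfigureAwait(false) is used so that no SynchronizationContext gets involved.

```csharp
using System;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;

// Hypothetical helper class, not part of the framework.
public static class ExecutionContextFlowDemo
{
    private static readonly AsyncLocal<string> Ambient = new AsyncLocal<string>();

    // Returns the AsyncLocal value observed inside the continuation.
    public static string ObserveInContinuation(bool flowExecutionContext)
    {
        var tcs = new TaskCompletionSource<int>();
        var ran = new ManualResetEventSlim(false);
        string seen = "not run";
        Action continuation = () =>
        {
            seen = Ambient.Value;
            ran.Set();
        };

        try
        {
            Ambient.Value = "registered";
            ConfiguredTaskAwaitable<int>.ConfiguredTaskAwaiter awaiter =
                tcs.Task.ConfigureAwait(false).GetAwaiter();

            // OnCompleted captures the ExecutionContext at this point;
            // UnsafeOnCompleted does not.
            if (flowExecutionContext)
                awaiter.OnCompleted(continuation);
            else
                awaiter.UnsafeOnCompleted(continuation);

            // Change the ambient value before completing the task.
            Ambient.Value = "completed";
            tcs.SetResult(1);

            // Wait in case the continuation was queued to the thread pool.
            ran.Wait(TimeSpan.FromSeconds(5));
            return seen;
        }
        finally
        {
            Ambient.Value = null; // do not leak the value into the caller
        }
    }
}
```

With OnCompleted the continuation runs under the context captured at registration and sees "registered". With UnsafeOnCompleted it sees whatever context the completing code happens to run it in, so it never sees the registration-time value here. Compiler-generated async methods normally take the UnsafeOnCompleted route and flow the context through the async method builder instead.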
Windows technology enthusiast