C# Task WaitAll和WaitAny
Task 有静态方法WaitAll和WaitAny,主要用于等待其他Task完成后做一些事情,先看看其实现部分吧:
// Excerpt of Task containing only the WaitAll / WaitAny entry points and the helpers they use.
// Members referenced but not shown here (e.g. the (tasks, timeout) overloads, AddToList,
// WrappedTryRunInline, AddExceptionsForCompletedTask, AddCompletionAction, RemoveContinuation)
// are defined outside this excerpt.
public class Task : IThreadPoolWorkItem, IAsyncResult, IDisposable
{
    // Waits for all of the provided Task objects to complete execution.
    // Forwards to the (tasks, millisecondsTimeout) overload with Timeout.Infinite.
    public static void WaitAll(params Task[] tasks)
    {
        WaitAll(tasks, Timeout.Infinite);
    }

    // Waits for any of the provided Task objects to complete execution.
    // Returns the index of the completed task in the tasks array argument.
    // Forwards to the (tasks, millisecondsTimeout) overload with Timeout.Infinite and asserts that
    // the result is not -1 unless the array is empty.
    public static int WaitAny(params Task[] tasks)
    {
        int waitResult = WaitAny(tasks, Timeout.Infinite);
        Contract.Assert(tasks.Length == 0 || waitResult != -1, "expected wait to succeed");
        return waitResult;
    }

    // Waits for every task in the array, up to millisecondsTimeout, observing cancellationToken.
    //
    // Returns: true if all of the Task instances completed execution within the allotted time; otherwise, false.
    //
    // Throws:
    //   ArgumentNullException        - tasks is null.
    //   ArgumentOutOfRangeException  - millisecondsTimeout is less than -1.
    //   ArgumentException            - an element of tasks is null.
    //   (via ThrowIfCancellationRequested) - cancellation was already requested on entry, or only
    //                                  cancellations (no faults) were seen and the token is cancelled.
    //   AggregateException           - the wait completed and at least one task was Faulted or Canceled.
    public static bool WaitAll(Task[] tasks, int millisecondsTimeout, CancellationToken cancellationToken)
    {
        if (tasks == null) { throw new ArgumentNullException("tasks"); }
        if (millisecondsTimeout < -1) { throw new ArgumentOutOfRangeException("millisecondsTimeout"); }
        Contract.EndContractBlock();
        cancellationToken.ThrowIfCancellationRequested(); // early check before we make any allocations

        // All three lists are created lazily by AddToList, so they stay null when nothing is added.
        List<Exception> exceptions = null;
        List<Task> waitedOnTaskList = null;
        List<Task> notificationTasks = null;

        // If any of the waited-upon tasks end as Faulted or Canceled, set these to true.
        bool exceptionSeen = false, cancellationSeen = false;
        bool returnValue = true;

        // Collects incomplete tasks in "waitedOnTaskList". The array is walked from the last element to the first.
        for (int i = tasks.Length - 1; i >= 0; i--)
        {
            Task task = tasks[i];
            if (task == null) { throw new ArgumentException(Environment.GetResourceString("Task_WaitMulti_NullTask"), "tasks"); }
            bool taskIsCompleted = task.IsCompleted;
            if (!taskIsCompleted)
            {
                // try inlining the task only if we have an infinite timeout and an empty cancellation token
                if (millisecondsTimeout != Timeout.Infinite || cancellationToken.CanBeCanceled)
                {
                    AddToList(task, ref waitedOnTaskList, initSize: tasks.Length);
                }
                else
                {
                    // We are eligible for inlining. If it doesn't work, we'll do a full wait.
                    taskIsCompleted = task.WrappedTryRunInline() && task.IsCompleted; // A successful TryRunInline doesn't guarantee completion
                    if (!taskIsCompleted) AddToList(task, ref waitedOnTaskList, initSize: tasks.Length);
                }
            }
            // Tasks that are already finished (or finished by inlining) are inspected immediately.
            if (taskIsCompleted)
            {
                if (task.IsFaulted) exceptionSeen = true;
                else if (task.IsCanceled) cancellationSeen = true;
                if (task.IsWaitNotificationEnabled) AddToList(task, ref notificationTasks, initSize: 1);
            }
        }

        if (waitedOnTaskList != null)
        {
            // Block waiting for the tasks to complete.
            returnValue = WaitAllBlockingCore(waitedOnTaskList, millisecondsTimeout, cancellationToken);
            // If the wait didn't time out, ensure exceptions are propagated, and if a debugger is
            // attached and one of these tasks requires it, that we notify the debugger of a wait completion.
            if (returnValue)
            {
                foreach (var task in waitedOnTaskList)
                {
                    if (task.IsFaulted) exceptionSeen = true;
                    else if (task.IsCanceled) cancellationSeen = true;
                    if (task.IsWaitNotificationEnabled) AddToList(task, ref notificationTasks, initSize: 1);
                }
            }
            // Keeps the caller's array reachable until the blocking wait above has finished.
            // NOTE(review): the motivation is not stated in this excerpt - confirm against the upstream source.
            GC.KeepAlive(tasks);
        }

        // Stops at the first task whose NotifyDebuggerOfWaitCompletionIfNecessary() returns true.
        if (returnValue && notificationTasks != null)
        {
            foreach (var task in notificationTasks)
            {
                if (task.NotifyDebuggerOfWaitCompletionIfNecessary()) break;
            }
        }

        // If one or more threw exceptions, aggregate and throw them.
        if (returnValue && (exceptionSeen || cancellationSeen))
        {
            // Only cancellations were seen: give the caller's token the first chance to throw.
            if (!exceptionSeen) cancellationToken.ThrowIfCancellationRequested();
            // Now gather up and throw all of the exceptions.
            foreach (var task in tasks) AddExceptionsForCompletedTask(ref exceptions, task);
            Contract.Assert(exceptions != null, "Should have seen at least one exception");
            throw new AggregateException(exceptions);
        }
        return returnValue;
    }

    // Waits for any one task in the array to complete, up to millisecondsTimeout, observing cancellationToken.
    //
    // Returns: the index in tasks of a completed task, or -1 if none was found
    //          (the array is empty, or the wait on the promise returned false).
    //
    // Throws:
    //   ArgumentNullException        - tasks is null.
    //   ArgumentOutOfRangeException  - millisecondsTimeout is less than -1.
    //   ArgumentException            - an element of tasks is null.
    //   (via ThrowIfCancellationRequested) - cancellation was already requested on entry.
    public static int WaitAny(Task[] tasks, int millisecondsTimeout, CancellationToken cancellationToken)
    {
        if (tasks == null) { throw new ArgumentNullException("tasks"); }
        if (millisecondsTimeout < -1) { throw new ArgumentOutOfRangeException("millisecondsTimeout"); }
        Contract.EndContractBlock();
        cancellationToken.ThrowIfCancellationRequested(); // early check before we make any allocations

        int signaledTaskIndex = -1;
        // Validates every element for null and remembers the index of the first already-completed task.
        // The loop does not exit early, so null checking covers the whole array.
        for (int taskIndex = 0; taskIndex < tasks.Length; taskIndex++)
        {
            Task task = tasks[taskIndex];
            if (task == null) { throw new ArgumentException(Environment.GetResourceString("Task_WaitMulti_NullTask"), "tasks"); }
            if (signaledTaskIndex == -1 && task.IsCompleted) { signaledTaskIndex = taskIndex; }
        }

        // Nothing was complete yet: wait on a promise task whose Result is the first task to finish.
        if (signaledTaskIndex == -1 && tasks.Length != 0)
        {
            Task<Task> firstCompleted = TaskFactory.CommonCWAnyLogic(tasks);
            bool waitCompleted = firstCompleted.Wait(millisecondsTimeout, cancellationToken);
            if (waitCompleted)
            {
                Contract.Assert(firstCompleted.Status == TaskStatus.RanToCompletion);
                // Map the winning task back to its position in the caller's array.
                signaledTaskIndex = Array.IndexOf(tasks, firstCompleted.Result);
                Contract.Assert(signaledTaskIndex >= 0);
            }
        }
        // Keeps the caller's array reachable until this point.
        // NOTE(review): the motivation is not stated in this excerpt - confirm against the upstream source.
        GC.KeepAlive(tasks);
        return signaledTaskIndex;
    }

    // Performs a blocking WaitAll on the vetted list of tasks.
    // Registers one shared SetOnCountdownMres as a completion action on every task and blocks on it.
    // Returns: true if all of the tasks completed; otherwise, false.
    private static bool WaitAllBlockingCore(List<Task> tasks, int millisecondsTimeout, CancellationToken cancellationToken)
    {
        Contract.Assert(tasks != null, "Expected a non-null list of tasks");
        Contract.Assert(tasks.Count > 0, "Expected at least one task");
        bool waitCompleted = false;
        // NOTE(review): mres derives from ManualResetEventSlim and is never disposed in this method -
        // presumably deliberate; confirm against the upstream source before changing.
        var mres = new SetOnCountdownMres(tasks.Count);
        try
        {
            foreach (var task in tasks)
            {
                task.AddCompletionAction(mres, addBeforeOthers: true);
            }
            waitCompleted = mres.Wait(millisecondsTimeout, cancellationToken);
        }
        finally
        {
            // Runs when Wait returned false or threw (waitCompleted is still false in both cases):
            // detach the event from every task that has not completed yet.
            if (!waitCompleted)
            {
                foreach (var task in tasks)
                {
                    if (!task.IsCompleted) task.RemoveContinuation(mres);
                }
            }
        }
        return waitCompleted;
    }

    // A ManualResetEventSlim that is also a task completion action: it is Set once Invoke has been
    // called as many times as the count passed to the constructor.
    private sealed class SetOnCountdownMres : ManualResetEventSlim, ITaskCompletionAction
    {
        // Number of Invoke calls still outstanding before the event is set.
        private int _count;

        // count: number of Invoke calls to wait for; asserted to be greater than 0.
        internal SetOnCountdownMres(int count)
        {
            Contract.Assert(count > 0, "Expected count > 0");
            _count = count;
        }

        // Atomically decrements the counter and sets the event when it reaches zero.
        // completingTask is not used by this implementation.
        public void Invoke(Task completingTask)
        {
            if (Interlocked.Decrement(ref _count) == 0) Set();
            Contract.Assert(_count >= 0, "Count should never go below 0");
        }
    }
}
我们首先看看WaitAll方法:它依次检查Task数组中的每个Task实例是否已经完成,如果没有完成就把该Task添加到waitedOnTaskList集合中;如果waitedOnTaskList集合中有元素,就调用WaitAllBlockingCore来实现真正的等待。等待完毕后需要检查notificationTasks集合是否有元素,如果有则依次调用Task的NotifyDebuggerOfWaitCompletionIfNecessary方法。WaitAllBlockingCore实现阻塞依靠的是SetOnCountdownMres实例【和CountdownEvent思路一样,每次调用Invoke的时候,就把计数器_count减1,当_count==0时就调用Set方法】。在WaitAllBlockingCore方法退出前,如果等待没有成功完成(超时或被取消,即waitCompleted为false),需要检查每个Task是否已完成,对尚未完成的Task要移除其上注册的SetOnCountdownMres实例【if (!task.IsCompleted) task.RemoveContinuation(mres);】。SetOnCountdownMres的Invoke方法是在Task的FinishContinuations方法中调用的【 ITaskCompletionAction singleTaskCompletionAction = continuationObject as ITaskCompletionAction; singleTaskCompletionAction.Invoke(this);注意FinishContinuations方法是在FinishStageThree中调用】。另外请注意里面的GC.KeepAlive(tasks)。
现在我们来看看WaitAny方法的实现:首先循环Task[],检查每个元素是否为null,并记录第一个已经完成的Task的索引;如果找到了已完成的Task,循环结束后直接返回该索引,否则我们调用Task<Task> firstCompleted = TaskFactory.CommonCWAnyLogic(tasks);返回一个Task,然后调用该Task的Wait方法【bool waitCompleted = firstCompleted.Wait(millisecondsTimeout, cancellationToken);】。让我们来看看CommonCWAnyLogic的实现:
// Excerpt of TaskFactory containing only the "first task to complete wins" helper used by Task.WaitAny.
public class TaskFactory
{
    // Builds a promise task that completes with the first task in 'tasks' to finish.
    //
    // tasks:   the candidate tasks; must be non-null (Contract.Requires) and contain no null elements.
    // Returns: a Task<Task> whose result is set to the first completing task.
    // Throws:  ArgumentException if any element of tasks is null.
    internal static Task<Task> CommonCWAnyLogic(IList<Task> tasks)
    {
        Contract.Requires(tasks != null);
        var promise = new CompleteOnInvokePromise(tasks);
        // Once true, the rest of the loop only validates elements for null.
        bool checkArgsOnly = false;
        int numTasks = tasks.Count;
        for(int i=0; i<numTasks; i++)
        {
            var task = tasks[i];
            if (task == null) throw new ArgumentException(Environment.GetResourceString("Task_MultiTaskContinuation_NullTask"), "tasks");
            if (checkArgsOnly) continue;

            // If the promise has already completed, don't bother with checking any more tasks.
            if (promise.IsCompleted) { checkArgsOnly = true; }
            // If a task has already completed, complete the promise.
            else if (task.IsCompleted) { promise.Invoke(task); checkArgsOnly = true; }
            // Otherwise, add the completion action and keep going.
            else task.AddCompletionAction(promise);
        }
        return promise;
    }

    // A Task<Task> that is also a task completion action: the first Invoke call sets its result to the
    // completing task and then removes this promise from the tasks that have not completed.
    internal sealed class CompleteOnInvokePromise : Task<Task>, ITaskCompletionAction
    {
        private IList<Task> _tasks; // must track this for cleanup
        // 0 until the first Invoke wins the CompareExchange below, then 1.
        private int m_firstTaskAlreadyCompleted;

        // tasks: the collection this promise is registered on; kept so Invoke can unregister later.
        public CompleteOnInvokePromise(IList<Task> tasks) : base()
        {
            Contract.Requires(tasks != null, "Expected non-null collection of tasks");
            _tasks = tasks;
            if (AsyncCausalityTracer.LoggingOn) AsyncCausalityTracer.TraceOperationCreation(CausalityTraceLevel.Required, this.Id, "TaskFactory.ContinueWhenAny", 0);
            if (Task.s_asyncDebuggingEnabled) { AddToActiveTasks(this); }
        }

        // Called when one of the registered tasks completes (or directly by CommonCWAnyLogic for an
        // already-completed task). Only the first caller does any work; later calls are no-ops.
        public void Invoke(Task completingTask)
        {
            // Atomic 0 -> 1 transition: exactly one caller gets past this check.
            if (Interlocked.CompareExchange(ref m_firstTaskAlreadyCompleted, 1, 0) == 0)
            {
                if (AsyncCausalityTracer.LoggingOn)
                {
                    AsyncCausalityTracer.TraceOperationRelation(CausalityTraceLevel.Important, this.Id, CausalityRelation.Choice);
                    AsyncCausalityTracer.TraceOperationCompletion(CausalityTraceLevel.Required, this.Id, AsyncCausalityStatus.Completed);
                }
                if (Task.s_asyncDebuggingEnabled) { RemoveFromActiveTasks(this.Id); }

                // Completes this promise with the winning task as its result.
                bool success = TrySetResult(completingTask);
                Contract.Assert(success, "Only one task should have gotten to this point, and thus this must be successful.");

                // Unregister this promise from the tasks that are still running.
                var tasks = _tasks;
                int numTasks = tasks.Count;
                for (int i = 0; i < numTasks; i++)
                {
                    var task = tasks[i];
                    if (task != null && // if an element was erroneously nulled out concurrently, just skip it; worst case is we don't remove a continuation
                        !task.IsCompleted) task.RemoveContinuation(this);
                }
                // Drop the reference to the task collection now that cleanup is done.
                _tasks = null;
            }
        }
    }
}
CommonCWAnyLogic首先实例化CompleteOnInvokePromise【var promise = new CompleteOnInvokePromise(tasks)】,然后在循环中检查promise是否已经完成、每个Task是否已经完成,否则就把promise作为该Task的Continue Task【这里可以理解为每个Task都有一个相同的Continue Task】。而CompleteOnInvokePromise自己的等待发生在WaitAny中的firstCompleted.Wait(millisecondsTimeout, cancellationToken)调用里。当其中一个Task完成后,会在该Task的FinishContinuations方法中调用CompleteOnInvokePromise的Invoke【一旦触发后就需要移除其他task上的CompleteOnInvokePromise,如这里的task.RemoveContinuation(this)】。在CompleteOnInvokePromise的Invoke方法中我们调用TrySetResult(completingTask)方法,其实现如下:
// Excerpt of Task<TResult> containing only TrySetResult.
public class Task<TResult> : Task
{
    // Attempts to complete this task successfully with the given result.
    //
    // result:  the value stored in m_result when the transition succeeds.
    // Returns: true if this call completed the task; false if the task was already completed
    //          or AtomicStateUpdate did not succeed.
    internal bool TrySetResult(TResult result)
    {
        if (IsCompleted) return false;
        Contract.Assert(m_action == null, "Task<T>.TrySetResult(): non-null m_action");

        // Try to set TASK_STATE_COMPLETION_RESERVED.
        // NOTE(review): the second argument presumably lists flags that must not already be set for the
        // update to succeed - AtomicStateUpdate is not visible in this excerpt, confirm there.
        if (AtomicStateUpdate(TASK_STATE_COMPLETION_RESERVED, TASK_STATE_COMPLETION_RESERVED | TASK_STATE_RAN_TO_COMPLETION | TASK_STATE_FAULTED | TASK_STATE_CANCELED))
        {
            // The result is stored before the RAN_TO_COMPLETION flag is written.
            m_result = result;
            Interlocked.Exchange(ref m_stateFlags, m_stateFlags | TASK_STATE_RAN_TO_COMPLETION);

            // Signal completion on the contingent properties, if any were allocated.
            var cp = m_contingentProperties;
            if (cp != null) cp.SetCompleted();

            // Final completion stage (defined outside this excerpt).
            FinishStageThree();
            return true;
        }
        return false;
    }
}
这里的TrySetResult方法里面调用FinishStageThree方法,以保证Task后面的Continue Task的执行。