C# Barrier 实现
当您需要一组任务并行地运行一连串的阶段,而每一个阶段都必须等到所有任务都完成前一阶段之后才能开始时,您可以通过Barrier实例来同步这一类协同工作。
Barrier初始化后,将等待特定数量的信号到来,这个数量在Barrier初始化时指定。在所指定个数的信号都到来后,Barrier将执行一个指定的动作(post-phase action),这个动作也是在Barrier初始化时指定的(可以不指定)。Barrier在执行完动作后会被重置,然后重新等待特定数量的信号到来,再执行指定动作,如此循环。信号通过成员函数SignalAndWait()来发送:执行SignalAndWait()的Task或线程(最后一个到达者除外,它负责结束本阶段并执行指定动作)将会进入等待;当特定数量的信号都到达、Barrier执行完指定动作并重置后,这些等待中的Task或线程将继续运行。在程序运行过程中,可以通过成员函数AddParticipant()/AddParticipants()和RemoveParticipant()/RemoveParticipants()来增加或者减少需要等待的信号数量。让我们来看看Barrier的实现:
/// <summary>
/// Lets a group of participants work through a sequence of phases: each participant calls
/// SignalAndWait, and no participant proceeds past it until every participant of the current
/// phase has signaled. An optional post-phase action is run once each time a phase completes.
/// </summary>
/// <remarks>
/// NOTE(review): this is an excerpt. Members it references but does not declare here
/// (m_exception, ThrowIfDisposed, GetCurrentTotal, the IDisposable implementation,
/// BarrierPostPhaseException, SR) are defined elsewhere.
/// </remarks>
public class Barrier : IDisposable
{
    // Packed barrier state. It is only replaced through Interlocked.CompareExchange in SetCurrentTotal:
    //   bits 0-14  (TOTAL_MASK):   total participant count, so at most 32767 participants
    //   bit  15:                   unused (dummy)
    //   bits 16-30 (CURRENT_MASK): participants that have already signaled in the current phase
    //   bit  31    (SENSE_MASK):   the sense flag; SetCurrentTotal sets this bit when sense == false
    // The sense is flipped each time a phase completes and selects which event (odd/even) waiters use.
    // NOTE(review): GetCurrentTotal (not shown) presumably decodes exactly what SetCurrentTotal encodes.
    volatile int m_currentTotalCount;
    const int CURRENT_MASK = 0x7FFF0000;
    const int TOTAL_MASK = 0x00007FFF;
    // Bitmask to extract the sense flag
    const int SENSE_MASK = unchecked((int)0x80000000);
    // The maximum participants the barrier can operate = 32767 ( 2 power 15 - 1 )
    const int MAX_PARTICIPANTS = TOTAL_MASK;

    // Current phase number: starts at 0 and is incremented once per completed phase in SetResetEvents.
    // Accessed through CurrentPhaseNumber (Volatile.Read/Write).
    long m_currentPhase;

    // Two events that alternate between phases. Participants of a phase whose sense is true wait on
    // m_evenEvent, those of a phase whose sense is false wait on m_oddEvent (see SignalAndWait).
    ManualResetEventSlim m_oddEvent;
    ManualResetEventSlim m_evenEvent;

    // Execution context captured in the constructor; the post-phase action is run inside it.
    ExecutionContext m_ownerThreadContext;

    // Cached ContextCallback for InvokePostPhaseAction, created lazily in FinishPhase.
    [SecurityCritical]
    private static ContextCallback s_invokePostPhaseAction;

    // User action run once per completed phase; may be null.
    Action<Barrier> m_postPhaseAction;

    // ManagedThreadId of the thread currently running the post-phase action, or 0. Used to reject
    // AddParticipants / RemoveParticipants / SignalAndWait calls made from inside that action.
    int m_actionCallerID;

    /// <summary>Creates a barrier with no post-phase action.</summary>
    /// <param name="participantCount">Initial number of participants (0 to MAX_PARTICIPANTS).</param>
    public Barrier(int participantCount): this(participantCount, null) {}

    /// <summary>Creates a barrier that runs <paramref name="postPhaseAction"/> after every phase.</summary>
    /// <param name="participantCount">Initial number of participants (0 to MAX_PARTICIPANTS).</param>
    /// <param name="postPhaseAction">
    /// Action run by the thread that completes a phase (via FinishPhase); may be null.
    /// </param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="participantCount"/> is negative or greater than MAX_PARTICIPANTS.
    /// </exception>
    public Barrier(int participantCount, Action<Barrier> postPhaseAction)
    {
        if (participantCount < 0 || participantCount > MAX_PARTICIPANTS)
        {
            throw new ArgumentOutOfRangeException("participantCount", participantCount, SR.GetString(SR.Barrier_ctor_ArgumentOutOfRange));
        }

        // current = 0, total = participantCount, sense bit clear (which SetCurrentTotal encodes as sense == true).
        m_currentTotalCount = (int)participantCount;
        m_postPhaseAction = postPhaseAction;

        // Phase 0 starts with sense == true, so its participants wait on m_evenEvent, which starts unset.
        m_oddEvent = new ManualResetEventSlim(true);
        m_evenEvent = new ManualResetEventSlim(false);

        // Capture the context if the post phase action is not null
        if (postPhaseAction != null && !ExecutionContext.IsFlowSuppressed())
        {
            m_ownerThreadContext = ExecutionContext.Capture();
        }

        m_actionCallerID = 0;
    }

    /// <summary>Adds one participant to the barrier.</summary>
    /// <returns>The phase number of the barrier in which the new participant will first participate.</returns>
    /// <exception cref="InvalidOperationException">
    /// Adding one participant would exceed MAX_PARTICIPANTS (the ArgumentOutOfRangeException from
    /// AddParticipants is translated), or the call is made from inside the post-phase action.
    /// </exception>
    public long AddParticipant()
    {
        try
        {
            return AddParticipants(1);
        }
        catch (ArgumentOutOfRangeException)
        {
            throw new InvalidOperationException(SR.GetString(SR.Barrier_AddParticipants_Overflow_ArgumentOutOfRange));
        }
    }

    /// <summary>Adds <paramref name="participantCount"/> participants to the barrier.</summary>
    /// <param name="participantCount">Number of participants to add; must be at least 1.</param>
    /// <returns>
    /// The phase number in which the new participants will first participate: the current phase, or
    /// the next one if the current phase has already completed but its number has not been incremented yet.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="participantCount"/> is less than 1, or the new total would exceed MAX_PARTICIPANTS.
    /// </exception>
    /// <exception cref="InvalidOperationException">Called from inside the post-phase action.</exception>
    public long AddParticipants(int participantCount)
    {
        ThrowIfDisposed();

        if (participantCount < 1 )
        {
            throw new ArgumentOutOfRangeException("participantCount", participantCount, SR.GetString(SR.Barrier_AddParticipants_NonPositive_ArgumentOutOfRange));
        }
        else if (participantCount > MAX_PARTICIPANTS) //overflow
        {
            throw new ArgumentOutOfRangeException("participantCount", SR.GetString(SR.Barrier_AddParticipants_Overflow_ArgumentOutOfRange));
        }

        // Changing the participant count from inside the post-phase action is not allowed.
        if (m_actionCallerID != 0 && Thread.CurrentThread.ManagedThreadId == m_actionCallerID)
        {
            throw new InvalidOperationException(SR.GetString(SR.Barrier_InvalidOperation_CalledFromPHA));
        }

        SpinWait spinner = new SpinWait();
        long newPhase = 0;

        // CAS retry loop: re-read the packed state until our update wins.
        while (true)
        {
            int currentTotal = m_currentTotalCount;
            int total;
            int current;
            bool sense;
            GetCurrentTotal(currentTotal, out current, out total, out sense);

            if (participantCount + total > MAX_PARTICIPANTS) //overflow
            {
                throw new ArgumentOutOfRangeException("participantCount",SR.GetString(SR.Barrier_AddParticipants_Overflow_ArgumentOutOfRange));
            }

            if (SetCurrentTotal(currentTotal, current, total + participantCount, sense))
            {
                // Normally sense == (phase is even). A mismatch means the last participant has already
                // flipped the sense but the phase number has not been incremented yet, so the new
                // participants belong to the next phase.
                long currPhase = CurrentPhaseNumber;
                newPhase = (sense != (currPhase % 2 == 0)) ? currPhase + 1 : currPhase;

                if (newPhase != currPhase)
                {
                    // Wait on the opposite event: it is the one SetResetEvents sets for the completing
                    // phase, so this returns only after FinishPhase (including the post-phase action) is done.
                    if (sense)
                    {
                        m_oddEvent.Wait();
                    }
                    else
                    {
                        m_evenEvent.Wait();
                    }
                }
                else
                {
                    // Same phase: the completing thread of the previous phase may have incremented the phase
                    // number but not yet reset the event this phase waits on. Reset it here so the new
                    // participants' SignalAndWait does not pass straight through.
                    if (sense && m_evenEvent.IsSet)
                        m_evenEvent.Reset();
                    else if (!sense && m_oddEvent.IsSet)
                        m_oddEvent.Reset();
                }
                break;
            }
            spinner.SpinOnce();
        }
        return newPhase;
    }

    /// <summary>Removes one participant from the barrier.</summary>
    public void RemoveParticipant()
    {
        RemoveParticipants(1);
    }

    /// <summary>
    /// Removes <paramref name="participantCount"/> participants. If the remaining (non-zero) participants
    /// have all already signaled in the current phase, this call completes the phase.
    /// </summary>
    /// <param name="participantCount">Number of participants to remove; must be at least 1.</param>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="participantCount"/> is less than 1 or greater than the current total.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// Removing would leave fewer participants than have already signaled in this phase, or the call
    /// is made from inside the post-phase action.
    /// </exception>
    public void RemoveParticipants(int participantCount)
    {
        ThrowIfDisposed();

        if (participantCount < 1)
        {
            throw new ArgumentOutOfRangeException("participantCount", participantCount,SR.GetString(SR.Barrier_RemoveParticipants_NonPositive_ArgumentOutOfRange));
        }

        if (m_actionCallerID != 0 && Thread.CurrentThread.ManagedThreadId == m_actionCallerID)
        {
            throw new InvalidOperationException(SR.GetString(SR.Barrier_InvalidOperation_CalledFromPHA));
        }

        SpinWait spinner = new SpinWait();

        // CAS retry loop: re-read the packed state until our update wins.
        while (true)
        {
            int currentTotal = m_currentTotalCount;
            int total;
            int current;
            bool sense;
            GetCurrentTotal(currentTotal, out current, out total, out sense);

            if (total < participantCount)
            {
                throw new ArgumentOutOfRangeException("participantCount",SR.GetString(SR.Barrier_RemoveParticipants_ArgumentOutOfRange));
            }
            // Participants that already signaled this phase cannot be removed out from under it.
            if (total - participantCount < current)
            {
                throw new InvalidOperationException(SR.GetString(SR.Barrier_RemoveParticipants_InvalidOperation));
            }

            // If the remaining participants = current participants, then finish the current phase
            int remaingParticipants = total - participantCount;
            if (remaingParticipants > 0 && current == remaingParticipants )
            {
                // Reset current to 0 and flip the sense, exactly as the last SignalAndWait caller would.
                if (SetCurrentTotal(currentTotal, 0, total - participantCount, !sense))
                {
                    FinishPhase(sense);
                    break;
                }
            }
            else
            {
                if (SetCurrentTotal(currentTotal, current, total - participantCount, sense))
                {
                    break;
                }
            }
            spinner.SpinOnce();
        }
    }

    /// <summary>Signals arrival and waits, without timeout or cancellation, for all other participants.</summary>
    public void SignalAndWait()
    {
        SignalAndWait(new CancellationToken());
    }

    /// <summary>Signals arrival and waits, without timeout, for all other participants.</summary>
    /// <param name="cancellationToken">Token observed before signaling and while waiting.</param>
    public void SignalAndWait(CancellationToken cancellationToken)
    {
        SignalAndWait(Timeout.Infinite, cancellationToken);
    }

    /// <summary>
    /// Signals that this participant has reached the barrier and waits for all other participants of
    /// the current phase. The last participant to arrive completes the phase itself (via FinishPhase,
    /// which runs the post-phase action) and returns without waiting.
    /// </summary>
    /// <param name="millisecondsTimeout">Maximum wait in milliseconds, or Timeout.Infinite (-1).</param>
    /// <param name="cancellationToken">Token observed before signaling and while waiting.</param>
    /// <returns>
    /// true if the phase completed (even if the timeout expired at the same moment); false if the timeout
    /// expired first, in which case this participant's signal is withdrawn.
    /// </returns>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="millisecondsTimeout"/> is less than -1.</exception>
    /// <exception cref="InvalidOperationException">
    /// The barrier has zero participants, more threads than participants signaled in this phase, or the
    /// call is made from inside the post-phase action.
    /// </exception>
    /// <exception cref="OperationCanceledException">
    /// The token was canceled before signaling, or while waiting before the phase completed (the signal
    /// is withdrawn first in that case).
    /// </exception>
    /// <exception cref="BarrierPostPhaseException">The post-phase action threw.</exception>
    public bool SignalAndWait(int millisecondsTimeout, CancellationToken cancellationToken)
    {
        ThrowIfDisposed();
        cancellationToken.ThrowIfCancellationRequested();

        if (millisecondsTimeout < -1)
        {
            throw new System.ArgumentOutOfRangeException("millisecondsTimeout", millisecondsTimeout,SR.GetString(SR.Barrier_SignalAndWait_ArgumentOutOfRange));
        }

        if (m_actionCallerID != 0 && Thread.CurrentThread.ManagedThreadId == m_actionCallerID)
        {
            throw new InvalidOperationException(SR.GetString(SR.Barrier_InvalidOperation_CalledFromPHA));
        }

        bool sense; // The sense of the barrier *before* the phase associated with this SignalAndWait call completes
        int total;
        int current;
        int currentTotal;
        long phase;
        SpinWait spinner = new SpinWait();

        // Signal phase: atomically increment current, or complete the phase if we are the last arrival.
        while (true)
        {
            currentTotal = m_currentTotalCount;
            GetCurrentTotal(currentTotal, out current, out total, out sense);
            phase = CurrentPhaseNumber;

            // throw if zero participants
            if (total == 0)
            {
                throw new InvalidOperationException(SR.GetString(SR.Barrier_SignalAndWait_InvalidOperation_ZeroTotal));
            }

            // Try to detect if the number of threads for this phase exceeded the total number of participants or not
            // This can be detected if the current is zero which means all participants for that phase has arrived and the phase number is not changed yet
            if (current == 0 && sense != (CurrentPhaseNumber % 2 == 0))
            {
                throw new InvalidOperationException(SR.GetString(SR.Barrier_SignalAndWait_InvalidOperation_ThreadsExceeded));
            }

            //This is the last thread, finish the phase
            if (current + 1 == total)
            {
                // Reset current to 0 and flip the sense, then run the phase completion on this thread.
                if (SetCurrentTotal(currentTotal, 0, total, !sense))
                {
                    FinishPhase(sense);
                    return true;
                }
            }
            else if (SetCurrentTotal(currentTotal, current + 1, total, sense))
            {
                break;
            }

            spinner.SpinOnce();
        }

        // ** Perform the real wait **
        // select the correct event to wait on, based on the current sense.
        ManualResetEventSlim eventToWaitOn = (sense) ? m_evenEvent : m_oddEvent;

        bool waitWasCanceled = false;
        bool waitResult = false;
        try
        {
            waitResult = DiscontinuousWait(eventToWaitOn, millisecondsTimeout, cancellationToken, phase);
        }
        catch (OperationCanceledException )
        {
            // Handled below: the signal must be backed out before the cancellation is reported.
            waitWasCanceled = true;
        }
        catch (ObjectDisposedException)
        // in case a race happens where one of the threads returned from SignalAndWait and the current thread calls Wait on a disposed event
        {
            // make sure the current phase for this thread is already finished, otherwise propagate the exception
            if (phase < CurrentPhaseNumber)
                waitResult = true;
            else
                throw;
        }

        if (!waitResult)
        {
            //reset the spinLock to prepare it for the next loop
            spinner.Reset();

            //If the wait timeout expired and all other thread didn't reach the barrier yet, update the current count back
            while (true)
            {
                bool newSense;
                currentTotal = m_currentTotalCount;
                GetCurrentTotal(currentTotal, out current, out total, out newSense);
                // If the timeout expired and the phase has just finished, return true and this is considered as succeeded SignalAndWait
                //otherwise the timeout expired and the current phase has not been finished yet, return false
                //The phase is finished if the phase member variable is changed (incremented) or the sense has been changed
                // we have to use the statements in the comparison below for two cases:
                // 1- The sense is changed but the last thread didn't update the phase yet
                // 2- The phase is already incremented but the sense flipped twice due to the termination of the next phase
                if (phase < CurrentPhaseNumber || sense != newSense)
                {
                    // The current phase has been finished, but we shouldn't return before the events are set/reset otherwise this thread could start
                    // next phase and the appropriate event has not reset yet which could make it return immediately from the next phase SignalAndWait
                    // before waiting other threads
                    WaitCurrentPhase(eventToWaitOn, phase);
                    Debug.Assert(phase < CurrentPhaseNumber);
                    break;
                }
                //The phase has not been finished yet, try to update the current count.
                if (SetCurrentTotal(currentTotal, current - 1, total, sense))
                {
                    //if here, then the attempt to backout was successful.
                    //throw (a fresh) oce if cancellation woke the wait
                    //or return false if it was the timeout that woke the wait.
                    //
                    if (waitWasCanceled)
                        throw new OperationCanceledException(SR.GetString(SR.Common_OperationCanceled), cancellationToken);
                    else
                        return false;
                }
                spinner.SpinOnce();
            }
        }

        // A failure recorded by FinishPhase is surfaced to every waiting participant too.
        if (m_exception != null)
            throw new BarrierPostPhaseException(m_exception);

        return true;
    }

    /// <summary>
    /// Completes the current phase: runs the post-phase action (if any), then increments the phase
    /// number and sets/resets the events via SetResetEvents so the waiting participants are released.
    /// </summary>
    /// <param name="observedSense">The sense of the phase being completed (the value before it was flipped).</param>
    /// <exception cref="BarrierPostPhaseException">
    /// The post-phase action threw. The events are still set/reset first, so waiters are released.
    /// </exception>
    private void FinishPhase(bool observedSense)
    {
        // Execute the PHA in try/finally block to reset the variables back in case of it threw an exception
        if (m_postPhaseAction != null)
        {
            try
            {
                // Remember which thread runs the action so re-entrant calls from it are rejected.
                m_actionCallerID = Thread.CurrentThread.ManagedThreadId;
                if (m_ownerThreadContext != null)
                {
                    // The captured context is disposed after this run, so keep a copy for the next phase.
                    var currentContext = m_ownerThreadContext;
                    m_ownerThreadContext = m_ownerThreadContext.CreateCopy(); // create a copy for the next run

                    ContextCallback handler = s_invokePostPhaseAction;
                    if (handler == null)
                    {
                        s_invokePostPhaseAction = handler = InvokePostPhaseAction;
                    }
                    ExecutionContext.Run(currentContext, handler, this);
                    currentContext.Dispose();
                }
                else
                {
                    m_postPhaseAction(this);
                }
                m_exception = null; // reset the exception if it was set previously
            }
            catch (Exception ex)
            {
                // Recorded so that waiting participants also throw BarrierPostPhaseException (see SignalAndWait).
                m_exception = ex;
            }
            finally
            {
                m_actionCallerID = 0;
                SetResetEvents(observedSense);
                if(m_exception != null)
                    throw new BarrierPostPhaseException(m_exception);
            }
        }
        else
        {
            SetResetEvents(observedSense);
        }
    }

    /// <summary>
    /// Advances the phase number, then sets the event the completed phase was waiting on and resets the
    /// event the next phase will wait on.
    /// </summary>
    /// <param name="observedSense">The sense of the phase being completed.</param>
    private void SetResetEvents(bool observedSense)
    {
        // Increment the phase count using Volatile class because m_currentPhase is 64 bit long type, that could cause torn write on 32 bit machines
        CurrentPhaseNumber = CurrentPhaseNumber + 1;
        if (observedSense)
        {
            m_oddEvent.Reset();
            m_evenEvent.Set();
        }
        else
        {
            m_evenEvent.Reset();
            m_oddEvent.Set();
        }
    }

    /// <summary>
    /// Waits on <paramref name="currentPhaseEvent"/> in growing slices (100 ms, doubling up to a 10 s cap)
    /// instead of a single long Wait. Between slices it re-checks CurrentPhaseNumber, so a phase change
    /// is noticed even when the event alone would not show it (for example, it was reset again because
    /// the next phase also finished).
    /// </summary>
    /// <param name="currentPhaseEvent">The event for the phase this participant signaled in.</param>
    /// <param name="totalTimeout">Total timeout in milliseconds, or Timeout.Infinite.</param>
    /// <param name="token">Cancellation token passed to each ManualResetEventSlim.Wait call.</param>
    /// <param name="observedPhase">The phase number observed when this participant signaled.</param>
    /// <returns>True if the event is set or the phase number changed, false if the timeout expired.</returns>
    /// <exception cref="OperationCanceledException">Propagated from ManualResetEventSlim.Wait when <paramref name="token"/> is canceled.</exception>
    private bool DiscontinuousWait(ManualResetEventSlim currentPhaseEvent, int totalTimeout, CancellationToken token, long observedPhase)
    {
        int maxWait = 100; // 100 ms
        int waitTimeCeiling = 10000; // 10 seconds
        while (observedPhase == CurrentPhaseNumber)
        {
            // the next wait time, the min of the maxWait and the totalTimeout
            int waitTime = totalTimeout == Timeout.Infinite ? maxWait : Math.Min(maxWait, totalTimeout);

            if (currentPhaseEvent.Wait(waitTime, token)) return true;

            //update the total wait time
            if (totalTimeout != Timeout.Infinite)
            {
                totalTimeout -= waitTime;
                if (totalTimeout <= 0) return false;
            }

            //if the maxwait exceeded 10 seconds then we will stop increasing the maxWait time and keep it 10 seconds, otherwise keep doubling it
            maxWait = maxWait >= waitTimeCeiling ? waitTimeCeiling : Math.Min(maxWait << 1, waitTimeCeiling);
        }

        //if we exited the loop because the observed phase doesn't match the current phase, then we have to spin to make sure
        //the event is set or the next phase is finished
        WaitCurrentPhase(currentPhaseEvent, observedPhase);

        return true;
    }

    /// <summary>
    /// Spins until the completion of <paramref name="observedPhase"/> is fully published, that is, until
    /// its event is set or the phase number has advanced by more than one.
    /// </summary>
    /// <param name="currentPhaseEvent">The event for the observed phase.</param>
    /// <param name="observedPhase">The phase number the caller signaled in.</param>
    private void WaitCurrentPhase(ManualResetEventSlim currentPhaseEvent, long observedPhase)
    {
        //spin until either of these two conditions succeeds
        //1- The event is set
        //2- the phase count is incremented more than one time, this means the next phase is finished as well,
        //but the event will be reset again, so we check the phase count instead
        SpinWait spinner = new SpinWait();
        while (!currentPhaseEvent.IsSet && CurrentPhaseNumber - observedPhase <= 1)
        {
            spinner.SpinOnce();
        }
    }

    /// <summary>ContextCallback target used with ExecutionContext.Run; <paramref name="obj"/> is the Barrier.</summary>
    private static void InvokePostPhaseAction(object obj)
    {
        var thisBarrier = (Barrier)obj;
        thisBarrier.m_postPhaseAction(thisBarrier);
    }

    /// <summary>
    /// Packs <paramref name="current"/>, <paramref name="total"/> and <paramref name="sense"/> into one int
    /// and tries to publish it with a single compare-and-swap.
    /// </summary>
    /// <param name="currentTotal">The packed value the caller previously read (the CAS comparand).</param>
    /// <param name="current">New count of participants that have signaled (stored in bits 16-30).</param>
    /// <param name="total">New total participant count (stored in bits 0-14).</param>
    /// <param name="sense">New sense; stored inverted, so bit 31 is set when sense is false.</param>
    /// <returns>True if m_currentTotalCount still equaled <paramref name="currentTotal"/> and was replaced.</returns>
    private bool SetCurrentTotal(int currentTotal, int current, int total, bool sense)
    {
        int newCurrentTotal = (current <<16) | total;

        if (!sense)
        {
            newCurrentTotal |= SENSE_MASK;
        }

        return Interlocked.CompareExchange(ref m_currentTotalCount, newCurrentTotal, currentTotal) == currentTotal;
    }

    /// <summary>Gets the total number of participants in the barrier (the low 15 bits of the packed state).</summary>
    public int ParticipantCount
    {
        get { return (int)(m_currentTotalCount & TOTAL_MASK); }
    }

    /// <summary>
    /// Gets the number of the barrier's current phase. The setter is used by SetResetEvents to advance
    /// the phase. Volatile access avoids torn reads/writes of the 64-bit value on 32-bit platforms.
    /// </summary>
    public long CurrentPhaseNumber
    {
        // use the new Volatile.Read/Write method because it is cheaper than Interlocked.Read on AMD64 architecture
        get { return Volatile.Read(ref m_currentPhase); }
        internal set { Volatile.Write(ref m_currentPhase, value); }
    }
}
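这里有几个变量需要说明一下。m_currentTotalCount是一个32位整数,被拆成三部分使用:第1–15位(TOTAL_MASK = 0x00007FFF)存放参与者总数,第16位未使用,第17–31位(CURRENT_MASK = 0x7FFF0000)存放当前阶段已经到达的参与者数量,第32位(SENSE_MASK)是sense标志(SetCurrentTotal在sense为false时置位该位)。sense每完成一个阶段就翻转一次,后面用它来决定等待的是ManualResetEventSlim实例m_oddEvent还是m_evenEvent。Barrier的构造函数比较简单:如果指定了postPhaseAction,并且执行上下文的流动没有被禁止(!ExecutionContext.IsFlowSuppressed()),就捕获当前执行上下文【m_ownerThreadContext = ExecutionContext.Capture()】,以便后面在该上下文中调用postPhaseAction。另外,构造函数通过【m_oddEvent = new ManualResetEventSlim(true);m_evenEvent = new ManualResetEventSlim(false);】把m_oddEvent初始化为已触发状态,把m_evenEvent初始化为未触发状态,第0阶段的参与者就在m_evenEvent上等待。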
AddParticipants用于增加参与者总数,RemoveParticipants用于减少参与者总数,二者都是借助SpinWait的自旋和原子操作(Interlocked.CompareExchange)完成的。AddParticipants成功修改总数后,会检查当前阶段是否其实已经结束:如果sense已经翻转但阶段号尚未递增(通常是postPhaseAction还在执行),新参与者属于下一阶段,这时调用ManualResetEventSlim的Wait方法,等待另一个事件被触发;否则,如果本阶段要等待的事件仍处于触发状态,就将其Reset。RemoveParticipants减少参与者后,如果剩余参与者数大于0且等于已到达的数量【current == total】,说明本阶段已经完成,此时翻转sense并调用FinishPhase;否则只是减少total的值。在FinishPhase中,如果构造函数传入了回调postPhaseAction,就需要调用它。如果先前捕获了执行上下文,回调要在该上下文中执行【ExecutionContext.Run(currentContext, handler, this),其中handler就是InvokePostPhaseAction】;否则只是简单的方法调用【m_postPhaseAction(this)】。
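现在我们再来看SignalAndWait方法。它同样借助SpinWait的自旋和原子操作完成,核心操作就是current = current + 1。如果current + 1 == total(即当前线程是最后一个到达的参与者),就把current清零、翻转sense并调用FinishPhase。FinishPhase会执行回调函数(在捕获了执行上下文时,通过s_invokePostPhaseAction缓存的InvokePostPhaseAction委托调用;否则直接调用m_postPhaseAction),然后递增阶段号并发出Set信号。如果调用SignalAndWait后current < total,就继续往下执行,调用DiscontinuousWait方法阻塞当前任务,直到最后一个任务调用SignalAndWait【此时current + 1 == total,它调用FinishPhase方法发出Set信号】。如果等待超时或被取消时本阶段仍未结束,SignalAndWait会把current减回去,然后返回false或抛出OperationCanceledException。有关Barrier的使用,我在一本pdf里面发现了一张比较好的图片: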