C# ConcurrentBag的实现原理
一、前言
笔者最近在做一个项目,项目中为了提升吞吐量,使用了消息队列,中间实现了生产消费模式,在生产消费者模式中需要有一个集合,来存储生产者所生产的物品,笔者使用了最常见的List<T>
集合类型。
由于生产者线程有很多个,消费者线程也有很多个,所以不可避免的就产生了线程同步的问题。开始笔者是使用lock
关键字,进行线程同步,但是性能并不是特别理想,然后有网友说可以使用SynchronizedList<T>
来代替使用List<T>
达到线程安全的目的。于是笔者就替换成了SynchronizedList<T>
,但是发现性能依旧糟糕,于是查看了SynchronizedList<T>
的源代码,发现它就是简单的在List<T>
提供的API的基础上加了lock
,所以性能基本与笔者实现方式相差无几。
最后笔者找到了解决的方案,使用ConcurrentBag<T>
类来实现,性能有很大的改观,于是笔者查看了ConcurrentBag<T>
的源代码,实现非常精妙,特此在这记录一下。
二、ConcurrentBag类
ConcurrentBag<T>
实现了IProducerConsumerCollection<T>
接口,该接口主要用于生产者消费者模式下,可见该类基本就是为生产消费者模式定制的。然后还实现了常规的IReadOnlyCollection<T>
类,实现了该类就需要实现IEnumerable<T>、IEnumerable、 ICollection
类。
ConcurrentBag<T>
对外提供的方法没有List<T>
那么多,但是同样有Enumerable
实现的扩展方法。类本身提供的方法如下所示。
名称 | 说明 |
---|---|
Add | 将对象添加到 ConcurrentBag |
CopyTo | 从指定数组索引开始,将 ConcurrentBag |
Equals(Object) | 确定指定的 Object 是否等于当前的 Object。 (继承自 Object。) |
Finalize | 允许对象在“垃圾回收”回收之前尝试释放资源并执行其他清理操作。 (继承自 Object。) |
GetEnumerator | 返回循环访问 ConcurrentBag |
GetHashCode | 用作特定类型的哈希函数。 (继承自 Object。) |
GetType | 获取当前实例的 Type。 (继承自 Object。) |
MemberwiseClone | 创建当前 Object 的浅表副本。 (继承自 Object。) |
ToArray | 将 ConcurrentBag |
ToString | 返回表示当前对象的字符串。 (继承自 Object。) |
TryPeek | 尝试从 ConcurrentBag |
TryTake | 尝试从 ConcurrentBag |
三、 ConcurrentBag线程安全实现原理
1. ConcurrentBag的私有字段
ConcurrentBag
线程安全实现主要是通过它的数据存储的结构和细颗粒度的锁。
public class ConcurrentBag<T> : IProducerConsumerCollection<T>, IReadOnlyCollection<T>
{
// ThreadLocalList对象包含每个线程的数据
ThreadLocal<ThreadLocalList> m_locals;
// 这个头指针和尾指针指向中的第一个和最后一个本地列表,这些本地列表分散在不同线程中
// 允许在线程局部对象上枚举
volatile ThreadLocalList m_headList, m_tailList;
// 这个标志是告知操作线程必须同步操作
// 在GlobalListsLock 锁中 设置
bool m_needSync;
}
首选我们来看它声明的私有字段,其中需要注意的是集合的数据是存放在ThreadLocal
线程本地存储中的。也就是说访问它的每个线程会维护一个自己的集合数据列表,一个集合中的数据可能会存放在不同线程的本地存储空间中,所以如果线程访问自己本地存储的对象,那么是没有问题的,这就是实现线程安全的第一层,使用线程本地存储数据。
然后可以看到ThreadLocalList m_headList, m_tailList;
这个是存放着本地列表对象的头指针和尾指针,通过这两个指针,我们就可以通过遍历的方式来访问所有本地列表。它使用volatile
修饰,不允许线程进行本地缓存,每个线程的读写都是直接操作在共享内存上,这就保证了变量始终具有一致性。任何线程在任何时间进行读写操作均是最新值。对于volatile
修饰符,感谢我是攻城狮指出描述错误。
最后又定义了一个标志,这个标志告知操作线程必须进行同步操作,这是实现了一个细颗粒度的锁,因为只有在几个条件满足的情况下才需要进行线程同步。
2. 用于数据存储的ThreadLocalList类
接下来我们来看一下ThreadLocalList
类的构造,该类就是实际存储了数据的位置。实际上它是使用双向链表这种结构进行数据存储。
[Serializable]
// 构造了双向链表的节点
internal class Node
{
public Node(T value)
{
m_value = value;
}
public readonly T m_value;
public Node m_next;
public Node m_prev;
}
/// <summary>
/// 集合操作类型
/// </summary>
internal enum ListOperation
{
None,
Add,
Take
};
/// <summary>
/// 线程锁定的类
/// </summary>
internal class ThreadLocalList
{
// 双向链表的头结点 如果为null那么表示链表为空
internal volatile Node m_head;
// 双向链表的尾节点
private volatile Node m_tail;
// 定义当前对List进行操作的种类
// 与前面的 ListOperation 相对应
internal volatile int m_currentOp;
// 这个列表元素的计数
private int m_count;
// The stealing count
// 这个不是特别理解 好像是在本地列表中 删除某个Node 以后的计数
internal int m_stealCount;
// 下一个列表 可能会在其它线程中
internal volatile ThreadLocalList m_nextList;
// 设定锁定是否已进行
internal bool m_lockTaken;
// The owner thread for this list
internal Thread m_ownerThread;
// 列表的版本,只有当列表从空变为非空统计是底层
internal volatile int m_version;
/// <summary>
/// ThreadLocalList 构造器
/// </summary>
/// <param name="ownerThread">拥有这个集合的线程</param>
internal ThreadLocalList(Thread ownerThread)
{
m_ownerThread = ownerThread;
}
/// <summary>
/// 添加一个新的item到链表首部
/// </summary>
/// <param name="item">The item to add.</param>
/// <param name="updateCount">是否更新计数.</param>
internal void Add(T item, bool updateCount)
{
checked
{
m_count++;
}
Node node = new Node(item);
if (m_head == null)
{
Debug.Assert(m_tail == null);
m_head = node;
m_tail = node;
m_version++; // 因为进行初始化了,所以将空状态改为非空状态
}
else
{
// 使用头插法 将新的元素插入链表
node.m_next = m_head;
m_head.m_prev = node;
m_head = node;
}
if (updateCount) // 更新计数以避免此添加同步时溢出
{
m_count = m_count - m_stealCount;
m_stealCount = 0;
}
}
/// <summary>
/// 从列表的头部删除一个item
/// </summary>
/// <param name="result">The removed item</param>
internal void Remove(out T result)
{
// 双向链表删除头结点数据的流程
Debug.Assert(m_head != null);
Node head = m_head;
m_head = m_head.m_next;
if (m_head != null)
{
m_head.m_prev = null;
}
else
{
m_tail = null;
}
m_count--;
result = head.m_value;
}
/// <summary>
/// 返回列表头部的元素
/// </summary>
/// <param name="result">the peeked item</param>
/// <returns>True if succeeded, false otherwise</returns>
internal bool Peek(out T result)
{
Node head = m_head;
if (head != null)
{
result = head.m_value;
return true;
}
result = default(T);
return false;
}
/// <summary>
/// 从列表的尾部获取一个item
/// </summary>
/// <param name="result">the removed item</param>
/// <param name="remove">remove or peek flag</param>
internal void Steal(out T result, bool remove)
{
Node tail = m_tail;
Debug.Assert(tail != null);
if (remove) // Take operation
{
m_tail = m_tail.m_prev;
if (m_tail != null)
{
m_tail.m_next = null;
}
else
{
m_head = null;
}
// Increment the steal count
m_stealCount++;
}
result = tail.m_value;
}
/// <summary>
/// 获取总计列表计数, 它不是线程安全的, 如果同时调用它, 则可能提供不正确的计数
/// </summary>
internal int Count
{
get
{
return m_count - m_stealCount;
}
}
}
从上面的代码中我们可以更加验证之前的观点,就是ConcurentBag<T>
在一个线程中存储数据时,使用的是双向链表,ThreadLocalList
实现了一组对链表增删改查的方法。
3. ConcurrentBag实现新增元素
接下来我们看一看ConcurentBag<T>
是如何新增元素的。
/// <summary>
/// 尝试获取无主列表,无主列表是指线程已经被暂停或者终止,但是集合中的部分数据还存储在那里
/// 这是避免内存泄漏的方法
/// </summary>
/// <returns></returns>
private ThreadLocalList GetUnownedList()
{
//此时必须持有全局锁
Contract.Assert(Monitor.IsEntered(GlobalListsLock));
// 从头线程列表开始枚举 找到那些已经被关闭的线程
// 将它所在的列表对象 返回
ThreadLocalList currentList = m_headList;
while (currentList != null)
{
if (currentList.m_ownerThread.ThreadState == System.Threading.ThreadState.Stopped)
{
currentList.m_ownerThread = Thread.CurrentThread; // the caller should acquire a lock to make this line thread safe
return currentList;
}
currentList = currentList.m_nextList;
}
return null;
}
/// <summary>
/// 本地帮助方法,通过线程对象检索线程线程本地列表
/// </summary>
/// <param name="forceCreate">如果列表不存在,那么创建新列表</param>
/// <returns>The local list object</returns>
private ThreadLocalList GetThreadList(bool forceCreate)
{
ThreadLocalList list = m_locals.Value;
if (list != null)
{
return list;
}
else if (forceCreate)
{
// 获取用于更新操作的 m_tailList 锁
lock (GlobalListsLock)
{
// 如果头列表等于空,那么说明集合中还没有元素
// 直接创建一个新的
if (m_headList == null)
{
list = new ThreadLocalList(Thread.CurrentThread);
m_headList = list;
m_tailList = list;
}
else
{
// ConcurrentBag内的数据是以双向链表的形式分散存储在各个线程的本地区域中
// 通过下面这个方法 可以找到那些存储有数据 但是已经被停止的线程
// 然后将已停止线程的数据 移交到当前线程管理
list = GetUnownedList();
// 如果没有 那么就新建一个列表 然后更新尾指针的位置
if (list == null)
{
list = new ThreadLocalList(Thread.CurrentThread);
m_tailList.m_nextList = list;
m_tailList = list;
}
}
m_locals.Value = list;
}
}
else
{
return null;
}
Debug.Assert(list != null);
return list;
}
/// <summary>
/// Adds an object to the <see cref="ConcurrentBag{T}"/>.
/// </summary>
/// <param name="item">The object to be added to the
/// <see cref="ConcurrentBag{T}"/>. The value can be a null reference
/// (Nothing in Visual Basic) for reference types.</param>
public void Add(T item)
{
// 获取该线程的本地列表, 如果此线程不存在, 则创建一个新列表 (第一次调用 add)
ThreadLocalList list = GetThreadList(true);
// 实际的数据添加操作 在AddInternal中执行
AddInternal(list, item);
}
/// <summary>
/// </summary>
/// <param name="list"></param>
/// <param name="item"></param>
private void AddInternal(ThreadLocalList list, T item)
{
bool lockTaken = false;
try
{
#pragma warning disable 0420
Interlocked.Exchange(ref list.m_currentOp, (int)ListOperation.Add);
#pragma warning restore 0420
// 同步案例:
// 如果列表计数小于两个, 因为是双向链表的关系 为了避免与任何窃取线程发生冲突 必须获取锁
// 如果设置了 m_needSync, 这意味着有一个线程需要冻结包 也必须获取锁
if (list.Count < 2 || m_needSync)
{
// 将其重置为None 以避免与窃取线程的死锁
list.m_currentOp = (int)ListOperation.None;
// 锁定当前对象
Monitor.Enter(list, ref lockTaken);
}
// 调用 ThreadLocalList.Add方法 将数据添加到双向链表中
// 如果已经锁定 那么说明线程安全 可以更新Count 计数
list.Add(item, lockTaken);
}
finally
{
list.m_currentOp = (int)ListOperation.None;
if (lockTaken)
{
Monitor.Exit(list);
}
}
}
从上面代码中,我们可以很清楚的知道Add()
方法是如何运行的,其中的关键就是GetThreadList()
方法,通过该方法可以获取当前线程的数据存储列表对象,假如不存在数据存储列表,它会自动创建或者通过GetUnownedList()
方法来寻找那些被停止但是还存储有数据列表的线程,然后将数据列表返回给当前线程中,防止了内存泄漏。
在数据添加的过程中,实现了细颗粒度的lock
同步锁,所以性能会很高。删除和其它操作与新增类似,本文不再赘述。
4. ConcurrentBag 如何实现迭代器模式
看完上面的代码后,我很好奇ConcurrentBag<T>
是如何实现IEnumerator
来实现迭代访问的,因为ConcurrentBag<T>
是通过分散在不同线程中的ThreadLocalList
来存储数据的,那么在实现迭代器模式时,过程会比较复杂。
后面再查看了源码之后,发现ConcurrentBag<T>
为了实现迭代器模式,将分在不同线程中的数据全都存到一个List<T>
集合中,然后返回了该副本的迭代器。所以每次访问迭代器,它都会新建一个List<T>
的副本,这样虽然浪费了一定的存储空间,但是逻辑上更加简单了。
/// <summary>
/// 本地帮助器方法释放所有本地列表锁
/// </summary>
private void ReleaseAllLocks()
{
// 该方法用于在执行线程同步以后 释放掉所有本地锁
// 通过遍历每个线程中存储的 ThreadLocalList对象 释放所占用的锁
ThreadLocalList currentList = m_headList;
while (currentList != null)
{
if (currentList.m_lockTaken)
{
currentList.m_lockTaken = false;
Monitor.Exit(currentList);
}
currentList = currentList.m_nextList;
}
}
/// <summary>
/// 从冻结状态解冻包的本地帮助器方法
/// </summary>
/// <param name="lockTaken">The lock taken result from the Freeze method</param>
private void UnfreezeBag(bool lockTaken)
{
// 首先释放掉 每个线程中 本地变量的锁
// 然后释放全局锁
ReleaseAllLocks();
m_needSync = false;
if (lockTaken)
{
Monitor.Exit(GlobalListsLock);
}
}
/// <summary>
/// 本地帮助器函数等待所有未同步的操作
/// </summary>
private void WaitAllOperations()
{
Contract.Assert(Monitor.IsEntered(GlobalListsLock));
ThreadLocalList currentList = m_headList;
// 自旋等待 等待其它操作完成
while (currentList != null)
{
if (currentList.m_currentOp != (int)ListOperation.None)
{
SpinWait spinner = new SpinWait();
// 有其它线程进行操作时,会将cuurentOp 设置成 正在操作的枚举
while (currentList.m_currentOp != (int)ListOperation.None)
{
spinner.SpinOnce();
}
}
currentList = currentList.m_nextList;
}
}
/// <summary>
/// 本地帮助器方法获取所有本地列表锁
/// </summary>
private void AcquireAllLocks()
{
Contract.Assert(Monitor.IsEntered(GlobalListsLock));
bool lockTaken = false;
ThreadLocalList currentList = m_headList;
// 遍历每个线程的ThreadLocalList 然后获取对应ThreadLocalList的锁
while (currentList != null)
{
// 尝试/最后 bllock 以避免在获取锁和设置所采取的标志之间的线程港口
try
{
Monitor.Enter(currentList, ref lockTaken);
}
finally
{
if (lockTaken)
{
currentList.m_lockTaken = true;
lockTaken = false;
}
}
currentList = currentList.m_nextList;
}
}
/// <summary>
/// Local helper method to freeze all bag operations, it
/// 1- Acquire the global lock to prevent any other thread to freeze the bag, and also new new thread can be added
/// to the dictionary
/// 2- Then Acquire all local lists locks to prevent steal and synchronized operations
/// 3- Wait for all un-synchronized operations to be done
/// </summary>
/// <param name="lockTaken">Retrieve the lock taken result for the global lock, to be passed to Unfreeze method</param>
private void FreezeBag(ref bool lockTaken)
{
Contract.Assert(!Monitor.IsEntered(GlobalListsLock));
// 全局锁定可安全地防止多线程调用计数和损坏 m_needSync
Monitor.Enter(GlobalListsLock, ref lockTaken);
// 这将强制同步任何将来的添加/执行操作
m_needSync = true;
// 获取所有列表的锁
AcquireAllLocks();
// 等待所有操作完成
WaitAllOperations();
}
/// <summary>
/// 本地帮助器函数返回列表中的包项, 这主要由 CopyTo 和 ToArray 使用。
/// 这不是线程安全, 应该被称为冻结/解冻袋块
/// 本方法是私有的 只有使用 Freeze/UnFreeze之后才是安全的
/// </summary>
/// <returns>List the contains the bag items</returns>
private List<T> ToList()
{
Contract.Assert(Monitor.IsEntered(GlobalListsLock));
// 创建一个新的List
List<T> list = new List<T>();
ThreadLocalList currentList = m_headList;
// 遍历每个线程中的ThreadLocalList 将里面的Node的数据 添加到list中
while (currentList != null)
{
Node currentNode = currentList.m_head;
while (currentNode != null)
{
list.Add(currentNode.m_value);
currentNode = currentNode.m_next;
}
currentList = currentList.m_nextList;
}
return list;
}
/// <summary>
/// Returns an enumerator that iterates through the <see
/// cref="ConcurrentBag{T}"/>.
/// </summary>
/// <returns>An enumerator for the contents of the <see
/// cref="ConcurrentBag{T}"/>.</returns>
/// <remarks>
/// The enumeration represents a moment-in-time snapshot of the contents
/// of the bag. It does not reflect any updates to the collection after
/// <see cref="GetEnumerator"/> was called. The enumerator is safe to use
/// concurrently with reads from and writes to the bag.
/// </remarks>
public IEnumerator<T> GetEnumerator()
{
// Short path if the bag is empty
if (m_headList == null)
return new List<T>().GetEnumerator(); // empty list
bool lockTaken = false;
try
{
// 首先冻结整个 ConcurrentBag集合
FreezeBag(ref lockTaken);
// 然后ToList 再拿到 List的 IEnumerator
return ToList().GetEnumerator();
}
finally
{
UnfreezeBag(lockTaken);
}
}
由上面的代码可知道,为了获取迭代器对象,总共进行了三步主要的操作。
- 使用
FreezeBag()
方法,冻结整个ConcurrentBag<T>
集合。因为需要生成集合的List<T>
副本,生成副本期间不能有其它线程更改损坏数据。- 将
ConcurrrentBag<T>
生成List<T>
副本。因为ConcurrentBag<T>
存储数据的方式比较特殊,直接实现迭代器模式困难,考虑到线程安全和逻辑,最佳的办法是生成一个副本。- 完成以上操作以后,就可以使用
UnfreezeBag()
方法解冻整个集合。
那么FreezeBag()
方法是如何来冻结整个集合的呢?也是分为三步走。
- 首先获取全局锁,通过
Monitor.Enter(GlobalListsLock, ref lockTaken);
这样一条语句,这样其它线程就不能冻结集合。- 然后获取所有线程中
ThreadLocalList
的锁,通过`AcquireAllLocks()方法来遍历获取。这样其它线程就不能对它进行操作损坏数据。- 等待已经进入了操作流程线程结束,通过
WaitAllOperations()
方法来实现,该方法会遍历每一个ThreadLocalList
对象的m_currentOp
属性,确保全部处于None
操作。
完成以上流程后,那么就是真正的冻结了整个ConcurrentBag<T>
集合,要解冻的话也类似。在此不再赘述。
四、总结
下面给出一张图,描述了ConcurrentBag<T>
是如何存储数据的。通过每个线程中的ThreadLocal
来实现线程本地存储,每个线程中都有这样的结构,互不干扰。然后每个线程中的m_headList
总是指向ConcurrentBag<T>
的第一个列表,m_tailList
指向最后一个列表。列表与列表之间通过m_locals
下的 m_nextList
相连,构成一个单链表。
数据存储在每个线程的m_locals
中,通过Node
类构成一个双向链表。
PS: 要注意m_tailList
和m_headList
并不是存储在ThreadLocal
中,而是所有的线程共享一份。
以上就是有关ConcurrentBag<T>
类的实现,笔者的一些记录和解析。
笔者水平有限,如果错误欢迎各位批评指正!
附上ConcurrentBag<T>
源码地址:戳一戳