无锁栈的实现
// Lock-free LIFO stack whose nodes are reclaimed with split reference counting.
// `head` carries an external count that a thread bumps before dereferencing
// the node it points to. Each node keeps an internal count of references that
// have since been released. Whichever thread brings the combined count to zero
// deletes the node, so a node is never freed while another pop() still reads it.
//
// NOTE(review): std::atomic<CountedNode> stores an int plus a pointer. It is
// lock-free only where the target offers a double-width CAS (check
// head.is_lock_free()), and GCC/Clang may need -latomic to link it. On typical
// 64-bit ABIs CountedNode also contains padding bytes, which compare_exchange
// may include in its comparison before C++20 (P0528) — TODO confirm the target
// toolchain handles this.
template<typename T>
class LockFreeStack
{
private:
    struct Node;

    // A node pointer paired with its external reference count.
    struct CountedNode
    {
        int externalCount = 0;
        Node* ptr = nullptr;
    };

    struct Node
    {
        std::shared_ptr<T> data;          // swapped out by the thread that unlinks the node
        std::atomic<int> internalCount;   // released references minus transferred external ones
        CountedNode next;                 // written in push() before the node is published

        explicit Node(T const& data_):
            data(std::make_shared<T>(data_)),
            internalCount(0)
        {}
    };

    // Initialized explicitly: before C++20 a default-constructed std::atomic is
    // specified to start in an uninitialized state.
    std::atomic<CountedNode> head{CountedNode{}};

    // Atomically increments the external count of the current head, so the
    // caller holds one reference to oldCounter.ptr. On a failed CAS, oldCounter
    // is refreshed with the current head and the loop retries. On success,
    // oldCounter matches what is now stored in head.
    // If the head is (or becomes) empty, it returns without touching the
    // counter. There is no node to protect, and bumping an empty head on every
    // empty pop() would eventually overflow the int.
    void increaseHeadCount(CountedNode& oldCounter)
    {
        CountedNode newCounter;
        do {
            if (oldCounter.ptr == nullptr) {
                return;
            }
            newCounter = oldCounter;
            ++newCounter.externalCount;
        } while (!head.compare_exchange_strong(oldCounter, newCounter,
                                               std::memory_order_acquire,   // pairs with push()'s release
                                               std::memory_order_relaxed));
        oldCounter.externalCount = newCounter.externalCount;
    }

public:
    LockFreeStack() = default;
    LockFreeStack(LockFreeStack const&) = delete;
    LockFreeStack& operator=(LockFreeStack const&) = delete;

    // Drains the stack so every remaining node is freed.
    // Assumes no other thread is still using the stack.
    ~LockFreeStack()
    {
        while (pop() != nullptr) {
        }
    }

    // Pushes a copy of `data`. The new node starts with an external count of 1,
    // which represents the stack's own reference to it.
    void push(T const& data)
    {
        CountedNode newNode;
        newNode.ptr = new Node(data);
        newNode.externalCount = 1;
        newNode.ptr->next = head.load(std::memory_order_relaxed);
        // release: publishes the node's contents to the acquire CAS in
        // increaseHeadCount. On failure the CAS refreshes newNode.ptr->next.
        while (!head.compare_exchange_weak(newNode.ptr->next, newNode,
                                           std::memory_order_release,
                                           std::memory_order_relaxed)) {
        }
    }

    // Pops the top value. Returns an empty shared_ptr if the stack is empty.
    std::shared_ptr<T> pop()
    {
        CountedNode oldHead = head.load(std::memory_order_relaxed);
        for (;;) {
            increaseHeadCount(oldHead);
            Node* const nodePtr = oldHead.ptr;
            if (nodePtr == nullptr) {
                return std::shared_ptr<T>();
            }
            if (head.compare_exchange_strong(oldHead, nodePtr->next,
                                             std::memory_order_relaxed)) {
                // We unlinked the node, so only we touch its data.
                std::shared_ptr<T> result;
                result.swap(nodePtr->data);
                // Move the external references into the internal count. Subtract
                // the initial reference set by push() and the one we just took.
                int const increaseCount = oldHead.externalCount - 2;
                // acq_rel: release hands our swap to whichever thread deletes
                // the node; acquire orders our delete after the other threads'
                // released accesses (their reads of nodePtr->next).
                if (nodePtr->internalCount.fetch_add(increaseCount,
                                                     std::memory_order_acq_rel)
                    == -increaseCount) {
                    delete nodePtr;
                }
                return result;
            }
            // Another thread changed head first: drop the reference we took.
            // acq_rel for the same reason as above; the last releaser deletes.
            if (nodePtr->internalCount.fetch_add(-1, std::memory_order_acq_rel) == 1) {
                delete nodePtr;
            }
        }
    }
};
为了测试其正确性,我用以下代码做了实验:
// Smoke test: t1 alternates push/pop, t2 only pushes, t3 only pops.
// All printing goes through printValue. `std::cout << x << " "` is two separate
// insertions, and when several threads do that at once their characters
// interleave (e.g. "12 " and "34 " can come out as "1234  "). That looks as if
// the stack returned values that were never pushed.
LockFreeStack<int> stack;
std::atomic_flag printBusy = ATOMIC_FLAG_INIT;
auto const printValue = [&printBusy](int value)
{
    // Minimal spin lock: each value and its separator are written as one unit.
    while (printBusy.test_and_set(std::memory_order_acquire)) {
    }
    std::cout << value << " ";
    printBusy.clear(std::memory_order_release);
};
std::thread t1([&]
{
    for (int i = 0; i < 100; ++i){
        if (i % 2 == 0){
            stack.push(i);
        }
        else if (auto const result = stack.pop()){
            printValue(*result);
        }
    }
});
std::thread t2([&]
{
    for (int i = 100; i < 200; ++i){
        stack.push(i);
    }
});
std::thread t3([&]
{
    for (int i = 0; i < 199; ++i){
        if (auto const result = stack.pop()){
            printValue(*result);
        }
    }
});
t1.join();
t2.join();
t3.join();
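结果输出很奇怪,出现了很多上千的数。我一度以为是栈的实现有问题,直到我用 printf 替换了 cout 之后才发现,问题其实出在输出上。`cout << *result << " "` 是两次独立的插入操作,多个线程同时执行时,各自的数字和空格会交错拼接在一起(例如 "12 " 和 "34 " 可能被输出成 "1234  "),看起来就像栈返回了从未压入过的上千的数。而每次 printf 调用会在内部持有流的锁,整条输出不会被其他线程打断。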