Lock Free 之 Epoch Based Reclamation
Epoch Based Reclamation
Epoch Based Reclamation
算法参考文档:
http://www.cs.toronto.edu/~tomhart/papers/tomhart_thesis.pdf
https://aturon.github.io/blog/2015/08/27/epoch/#epoch-based-reclamation
为什么要 Epoch Based Reclamation
设想我们要在无 gc 的语言中实现一种无锁但支持并发的数据结构(可能是个map、也可能是个 array):
- 线程 A 想替换该数据结构中的某个数据节点,线程 A 使用 newNode 原子替换了oldNode,并同时释放了oldNode,防止了内存泄露,这一切看起来没什么问题;
- 在线程 A 的上述操作中,若巧好有另一个线程 B 在线程 A 替换前便进行了读取(读到 oldNode 指针),但稍后线程 A 释放了 oldNode,那么线程 B 此时持有的便是一个悬空指针(一旦使用就 coredump 了)。
在拥有 gc 的语言中,如 java, golang,不会出现这种问题。但是对于 c/c++,则没有一种很好的方法应对这问题,除非 非gc语言也有一种延迟回收内存的机制,于是Epoch Based Reclamation作为其中一种方法应运而生。
基本概念
如下图,表示一个节点的删除包含 逻辑删除(delete) 和 物理删除(free) 两个过程:
- 逻辑删除(delete),一个节点在被逻辑删除之时可能会有其他线程正在访问它,逻辑删除不回收内存空间;
- 物理删除(free / reclaim),物理删除之后不会再被线程访问到,会将对应的内存空间回收;
- Grace Period,记时间段 T=[t1, t2],如果 t1 之前逻辑删除的节点,都可以在 t2 之后安全的回收,那么称 T = [t1, t2] 是一个 Grace Period。
算法原理
- 维护了一个全局的 epoch (取值为0、1、2),epoch 的每个取值都对应一个 retire list(存放逻辑删除后待回收的指针);
- 为每个线程维护一个局部的 active flag 和 epoch (取值自然也为0、1、2);
- 线程进入临界区前会设置 active flag = true,并设置自己的局部 epoch 为全局 epoch 的值,离开临界区时设置 active flag = false;
- 线程删除数据时,先放入对应的 retire_list [global_epoch](线程局部 epoch 等于 global epoch);
- 全局 epoch (假设为 E)的增长规则,若所有活跃线程的 epoch 是否都等于 E 时,置换 E = (E + 1) % 3;
全局 epoch 取值 | 局部 epoch 可能取值 | grace period | 可回收的 restire_list |
0 | 2、0 | [1, 0] | 1 |
1 | 0、1 | [2, 1] | 2 |
2 | 1、2 | [0, 2] | 0 |
前面说到 当全局 epoch = E 时, 活跃线程的 epoch 只能是 E 或 E - 1,当最后一个 e = E -1 的线程完成了临界区操作,也即所有活跃线程中 epoch 计数都等于E的时候,也就是说 E - 2(也即 E + 1)对应的回收队列 retireList 中的节点再也没有任何一个线程访问了,故而retireList[E - 2] 中的数据可以真正的释放了。如此做后,全局计数E = E + 1
看个例子的推演,假设就两个线程 A、B,一个共享数据 N,初始数据为 N0。A 不断更改 N,B 则不断读取 N:
epoch |
t0 |
t1 |
t2 |
t3 |
全局 |
0 -> 1 |
1 -> 2 |
2 -> 0 |
0 |
写线程 A |
0,N0 -> N1 |
1,N1 -> N2 |
2,N2 -> N3 |
0,N3 -> N4 |
读线程 B |
0,读到 N0 |
1,读到 N1 |
2,读到 N2 |
0,读到 N3 |
retire_list[0] |
[N0] |
[N0] |
[N0] -> gc -> [ ] |
[N3 ] |
retire_list[1] |
[ ] |
[N1] |
[N1] |
[N1] -> gc -> [ ] |
retire_list[1] |
[ ] |
[ ] -> gc |
[N2] |
[N2] |
实现
一个完整的实现例子:
1 #ifndef DEMO_SMR_H 2 #define DEMO_SMR_H 3 4 #include <atomic> 5 #include <vector> 6 #include <string> 7 #include <iostream> 8 #include <chrono> 9 #include <thread> 10 11 #define CACHELINE_SIZE 64 12 #define MAX_THREAD_NUM 503 13 #define FETCH_AND_ADD(address,offset) __sync_fetch_and_add(address,offset) 14 #define CPU_BARRIER() __asm__ __volatile__("mfence": : :"memory") 15 16 struct alignas(64) ReadIndicator { 17 void arrive(void) { 18 counter.fetch_add(1, std::memory_order_seq_cst); 19 } 20 21 void depart(void) { 22 counter.fetch_sub(1, std::memory_order_release); 23 } 24 25 bool empty(void) { 26 return counter.load(std::memory_order_seq_cst) == 0; 27 } 28 29 private: 30 std::atomic<uint64_t> counter{0}; 31 }; 32 33 struct alignas(64) ReadIndicatorGuard { 34 explicit ReadIndicatorGuard(ReadIndicator& inst) : indicator(inst) { 35 indicator.arrive(); 36 } 37 38 ~ReadIndicatorGuard() { indicator.depart(); } 39 40 private: 41 ReadIndicator& indicator; 42 }; 43 44 template <typename T> 45 class SMRManagerBase { 46 protected: 47 using RetireListType = std::vector<T*>; 48 static constexpr int EBR_CYCLE = 3; 49 50 public: 51 SMRManagerBase() { 52 type_name = typeid(T).name(); 53 } 54 55 virtual ~SMRManagerBase() { 56 for (int i = 0; i < EBR_CYCLE; ++i) { 57 for (T* t : retire_lists_[i]) { 58 delete t; 59 } 60 } 61 } 62 63 /* 读者访问临界资源前,首先调用这个函数 */ 64 virtual void reader_enter() = 0; 65 virtual void reader_leave() = 0; 66 virtual int smr_type() const = 0; 67 68 int32_t zombie_cnt() const { return zombie_cnt_; } 69 int32_t reclaim_cnt() const { return reclaim_cnt_; } 70 71 // 写者回收资源。待回收的资源可能不会立马被回收 72 int reclaim(T* const p) { return writer_reclaim_batch(p); } 73 74 // 批量回收内存 75 int reclaim(const std::vector<T*>& values) { 76 return writer_reclaim_batch(values); 77 } 78 79 int reclaim() { return writer_reclaim_batch(nullptr); } 80 81 int fast_reclaim(int64_t interval_ts, int64_t times) { 82 int count = 0; 83 while (times-- > 0) { 84 std::this_thread::sleep_for(std::chrono::microseconds(interval_ts)); 85 count += reclaim(); 86 } 87 return count; 88 } 89 90 int fast_reclaim(T* const p, int64_t interval_ts = 400, int64_t times = 3) { 91 int count = 0; 92 count += reclaim(p); 93 count += fast_reclaim(interval_ts, times); 94 return count; 95 } 96 97 // 物理删除一个 retire_list 98 int writer_free(int epoch) { 99 int count = 0; 100 RetireListType& retire_list = retire_lists_[epoch]; 101 for (auto& retire_pointer : retire_list) { 102 if (retire_pointer != nullptr) { 103 ++count; 104 delete retire_pointer; 105 retire_pointer = nullptr; 106 } 107 } // end for 108 109 if (retire_list.size() > 1000) { 110 retire_list.clear(); 111 retire_list.shrink_to_fit(); 112 } 113 zombie_cnt_ -= count; 114 reclaim_cnt_ += count; 115 return count; 116 } 117 118 virtual ReadIndicator& get_read_indicator() = 0; 119 120 ReadIndicatorGuard* get_read_guard() { 121 return new ReadIndicatorGuard(get_read_indicator()); 122 } 123 124 protected: 125 virtual int32_t writer_gc() = 0; 126 virtual int32_t get_epoch() = 0; 127 128 // 写者回收资源。待回收的资源可能不会立马被回收。 129 int writer_reclaim_batch(T* const p) { 130 writer_record(p); 131 return writer_gc(); 132 } 133 134 //批量回收内存,提高效率 135 int writer_reclaim_batch(const std::vector<T*>& values) { 136 writer_record(values); 137 return writer_gc(); 138 } 139 140 /* 逻辑删除,将 p 写到 retire_list 中 */ 141 void writer_record(T* const p) { 142 if (p == nullptr) { 143 return; 144 } 145 146 RetireListType& retire_list = retire_lists_[get_epoch()]; 147 bool found_vacant = false; 148 for (auto& retire_pointer : retire_list) { 149 if (retire_pointer == nullptr) { 150 retire_pointer = p; 151 found_vacant = true; 152 break; 153 } 154 } 155 if (!found_vacant) { 156 retire_list.push_back(p); 157 } 158 159 ++zombie_cnt_; 160 } 161 162 void writer_record(const std::vector<T*>& reclaim_nodes) { 163 if (reclaim_nodes.empty()) { 164 return; 165 } 166 RetireListType& retire_list = retire_lists_[get_epoch()]; 167 168 int32_t rn_index = 0; 169 int32_t store_index = 0; 170 int cnt = 0; 171 while (rn_index < reclaim_nodes.size()) { 172 // 尝试找到一个空的位置存储一下 173 while (store_index < retire_list.size() && 174 retire_list[store_index] != nullptr) { 175 ++store_index; 176 } 177 178 // 如果没有空的位置,说明 retire_list 需要扩容 179 if (store_index >= retire_list.size()) { 180 break; 181 } 182 183 // 找到空位置的 184 retire_list[store_index++] = reclaim_nodes[rn_index++]; 185 ++cnt; 186 } 187 188 if (rn_index < reclaim_nodes.size()) { 189 int remains = reclaim_nodes.size() - rn_index; 190 retire_list.reserve(retire_list.size() + remains + 10); 191 192 while (rn_index < reclaim_nodes.size()) { 193 retire_list.push_back(reclaim_nodes[rn_index++]); 194 ++cnt; 195 } 196 } 197 198 zombie_cnt_ += reclaim_nodes.size(); 199 std::cout << "reclaim node:" << reclaim_nodes.size() << ", cnt:" << cnt << std::endl; 200 } 201 202 protected: 203 RetireListType retire_lists_[3]; 204 205 int32_t zombie_cnt_ = 0; // 逻辑删除的计数器 206 int32_t reclaim_cnt_ = 0; // 物理删除的计数器 207 std::string type_name; 208 }; 209 210 template <typename T> 211 // SMR is short for Safety Memory Reclamation 212 class SMRManager : public SMRManagerBase<T> { 213 using Base = typename SMRManager::SMRManagerBase; 214 215 public: 216 using Base::fast_reclaim; 217 using Base::reclaim; 218 using Base::reclaim_cnt; 219 using Base::writer_free; 220 using Base::zombie_cnt; 221 using Base::type_name; 222 223 // 这个构造需要在单线程下进行 224 SMRManager() { 225 global_epoch_.epoch_ = 0; 226 local_epoches_ = new EpochType[MAX_THREAD_NUM]; 227 active_flags_ = new ActiveFlagType[MAX_THREAD_NUM]; 228 for (int i = 0; i < MAX_THREAD_NUM; ++i) { 229 local_epoches_[i].epoch_ = 0; 230 active_flags_[i].active_ = false; 231 } 232 } 233 234 ~SMRManager() { 235 delete[] local_epoches_; 236 delete[] active_flags_; 237 } 238 239 SMRManager(const SMRManager& rhs) = delete; 240 SMRManager& operator=(const SMRManager& rhs) = delete; 241 242 /*读者访问临界资源前,首先调用这个函数*/ 243 void reader_enter() override { 244 active_flags_[get_thread_id()].active_ = true; 245 CPU_BARRIER(); 246 local_epoches_[get_thread_id()].epoch_ = global_epoch_.epoch_; 247 CPU_BARRIER(); 248 } 249 250 /*读者离开临界区后,调用这个函数*/ 251 void reader_leave() override { 252 CPU_BARRIER(); 253 active_flags_[get_thread_id()].active_ = false; 254 } 255 256 ReadIndicator& get_read_indicator() override { 257 static ReadIndicator indicator; 258 std::cout << "unexpected call, type_name:" << type_name << std::endl; 259 return indicator; 260 } 261 262 int smr_type() const override { return 0; } 263 264 private: 265 int32_t get_epoch() override { 266 return global_epoch_.epoch_; 267 } 268 269 int writer_gc() override { 270 for (int i = 0; i < MAX_THREAD_NUM; i++) { 271 if (active_flags_[i].active_ && local_epoches_[i].epoch_ != global_epoch_.epoch_) { 272 // 此时有活跃读线程 lag 了 epoch 273 return writer_free((global_epoch_.epoch_ + 1) % 3); 274 } 275 } 276 // 所有的活跃读线程的 epoch 都升上来了 277 global_epoch_.epoch_ = (global_epoch_.epoch_ + 1) % 3; 278 CPU_BARRIER(); // 这个 memory barrier 的使用可以做一些优化 279 return writer_free((global_epoch_.epoch_ + 1) % 3); 280 } 281 282 int get_thread_id() { 283 static __thread int tid = -1; 284 if (tid == -1) { 285 tid = FETCH_AND_ADD(&thread_count_, 1); 286 if (tid >= MAX_THREAD_NUM) { 287 // abort(); 288 } 289 } 290 return tid; 291 } 292 293 private: 294 struct EpochType { 295 volatile int epoch_; 296 } __attribute__((aligned(CACHELINE_SIZE))); 297 298 struct ActiveFlagType { 299 volatile bool active_; 300 } __attribute__((aligned(CACHELINE_SIZE))); // 非常重要,保证原子,且避免 false sharing 301 302 private: 303 EpochType global_epoch_; // 全局 epoch 304 EpochType* local_epoches_; // 数组,每个线程占一个元素 305 ActiveFlagType* active_flags_; // 数组,每个线程占一个元素 306 307 int32_t zombie_cnt_ = 0; 308 int32_t reclaim_cnt_ = 0; 309 static int thread_count_; 310 }; 311 312 template <typename T> 313 int SMRManager<T>::thread_count_ = 0; 314 315 316 317 #endif //DEMO_SMR_H
SMR
SMR(Safety Memory Reclamation)是基于 Epoch Based Reclamation 的延迟 gc 技术,下面有个使用 SMR 实现线程安全的 map 例子:
1 template <typename TKey, typename TValue> 2 class ConcurrentMap { 3 public: 4 using Key = TKey; 5 using Value = TValue; 6 7 ConcurrentMap() { 8 lazy_batch_.reserve(kLazyBatchSize); 9 smr_.reset(new SMRManager<Value>()); 10 } 11 12 ~ConcurrentMap() { 13 data_.clear(); 14 } 15 16 void Acquire() { 17 smr_->reader_enter(); 18 } 19 void Release() { 20 smr_->reader_leave(); 21 } 22 ReadIndicatorGuard *get_read_guard() { 23 return smr_->get_read_guard(); 24 } 25 int SmrType() const { 26 return smr_->smr_type(); 27 } 28 void SetSmrType() { 29 smr_.reset(new SMRManager<Value>()); 30 } 31 32 size_t Size() const { 33 return data_.size(); 34 } 35 size_t Capacity() const { 36 return data_.capacity(); 37 } 38 int32_t zombie_cnt() const { 39 return smr_->zombie_cnt(); 40 } 41 int32_t reclaim_cnt() const { 42 return smr_->reclaim_cnt(); 43 } 44 45 const Value *Get(const Key &key) const { 46 const Value *node = nullptr; 47 auto it = data_.find(key); 48 if (it != data_.end()) { 49 node = it->second; 50 } 51 return node; 52 } 53 54 void lazy_reclaim(Value *&node) { 55 if (!node) { 56 return; 57 } 58 59 lazy_batch_.push_back(node); 60 if (lazy_batch_.size() >= kLazyBatchSize) { 61 smr_->reclaim(lazy_batch_); 62 lazy_batch_.clear(); 63 } 64 } 65 // 删除 key 66 int Delete(const Key &key) { 67 auto it = data_.find(key); 68 if (it != data_.end()) { 69 lazy_reclaim(it->second); 70 data_.erase(it); 71 return 1; 72 } 73 74 return 0; 75 } 76 77 // 新增或更新 key 78 int Insert(const Key &key, const Value &value) { 79 Value *node = new Value(value); 80 Value *tmp = nullptr; 81 82 auto it = data_.find(key); 83 // key 存在 84 if (it != data_.end()) { 85 Value *&v = it->second; 86 if (v->equal(*node)) { 87 // value 相等,不更新直接退出 88 delete node; 89 } else { 90 // value 不相等,删除老的 value 对象 91 tmp = v; 92 v = node; 93 lazy_reclaim(tmp); 94 } 95 } else { 96 data_.emplace(key, node); 97 } 98 return 0; 99 } 100 101 private: 102 const int32_t kLazyBatchSize = 1000; 103 std::vector<Value *> lazy_batch_; 104 105 std::map<Key, Value *> data_; 106 std::unique_ptr<SMRManagerBase<Value>> smr_; 107 };