C++:为自定义类添加 STL 迭代器风格的遍历方式

本文仿照 STL 迭代器的遍历风格,为自定义类型实现统一的遍历方式(begin()/end()、++、*、!=)。

1. 需要遍历的基础结构:

 1 struct ConnectionPtr
 2 {
 3     int id_;
 4     int port_;
 5     string addr_;
 6 
 7     
 8     //std::set 需要排序,需要重载<
 9     bool operator <(const ConnectionPtr &ptr)const {
10         return id_ < ptr.id_;
11     }
12 
13     void printPtr() //显示函数
14     {
15         cout << "ConnectionPtr :" << endl;
16         cout << "id = " << id_ << "  port = " << port_ << "  addr_ = " << addr_ << endl;
17         cout << endl;
18     }
19 };

2. 需要实现统一风格遍历的自定义结构(此处仅列出数据成员,完整定义见第 4 节):

1 struct Region_connection
2 {
3     map<int, set<ConnectionPtr>> connections; //自定义结构 
4

3. 迭代器(RegionIterator)的结构:

  1 class RegionIterator
  2 {
  3 public:
  4     int key;
  5     ConnectionPtr* ptr;
  6     map<int, set<ConnectionPtr>>* pconnections;
  7 
  8 
  9     RegionIterator(){
 10         clear();
 11     }
 12 
 13     //赋值
 14     RegionIterator& operator = (const RegionIterator &iter)
 15     {
 16         key = iter.key;
 17         ptr = iter.ptr;
 18     }
 19     //不等于
 20     bool operator != (const RegionIterator &iter)
 21     {
 22         return (key != iter.key) || (ptr != iter.ptr);
 23     }
 24     //等于
 25     bool operator == (const RegionIterator &iter)
 26     {
 27         return (key == iter.key) && (ptr == iter.ptr);
 28     }
 29     //前缀自加
 30     RegionIterator& operator ++ ()
 31     {
 32         RegionIterator itor = next(key, ptr);
 33         this->key = itor.key;
 34         this->ptr = itor.ptr;
 35         return *this;
 36     }
 37     //后缀自加
 38     RegionIterator operator ++ (int)
 39     {
 40         RegionIterator tmp = *this;
 41         RegionIterator itor = next(key, ptr);
 42         this->key = itor.key;
 43         this->ptr = itor.ptr;
 44         return tmp;
 45     }
 46     //取值
 47     ConnectionPtr& operator * ()
 48     {
 49         return *ptr;
 50     }
 51 
 52 private:
 53     void clear()
 54     {
 55         key = -1;
 56         ptr = nullptr;
 57         pconnections = nullptr;
 58     }
 59 
 60     RegionIterator next(int key_tmp, ConnectionPtr* ptr_tmp)
 61     {
 62         assert(pconnections);
 63         RegionIterator region;
 64         auto iter = pconnections->find(key_tmp);
 65         if (iter != pconnections->end()){
 66             const set<ConnectionPtr>& sets = iter->second;
 67             if (sets.size()){
 68                 if (ptr_tmp){
 69                     auto itr = sets.find(*ptr_tmp);
 70                     if (itr != sets.end()){
 71                         if (++itr != sets.end()){
 72                             region.key = key_tmp;
 73                             region.ptr = (ConnectionPtr*)&*(itr);    
 74                             return region;
 75                         }else{
 76                             if (++iter != pconnections->end()){
 77                                 key_tmp = iter->first;
 78                                 return next(key_tmp, nullptr);
 79                             }else{
 80                                 return region;
 81                             }
 82                         }
 83                     }
 84                     else{
 85                         return region; 
 86                     }
 87                 }else{
 88                     region.key = key_tmp;
 89                     region.ptr = (ConnectionPtr*)&(*sets.begin());
 90                     return region;
 91                 }
 92                 
 93             }else{
 94                 assert(ptr_tmp == nullptr);
 95                 key_tmp = (++iter)->first;
 96                 return next(key_tmp, ptr_tmp);
 97             }
 98         }
 99         else{
100             return region;
101         }
102     }
103 };

4. 为实现要求,需要在自定义结构添加部分函数:

 1 struct Region_connection
 2 {
 3     map<int, set<ConnectionPtr>> connections;
 4 
 5     typedef RegionIterator iterator;
 6     iterator begin(){
 7         iterator itor;
 8         itor.pconnections = &connections;
 9         if (connections.size() > 0)
10         {
11             for (auto itr = connections.begin(); itr != connections.end(); itr++)
12             {
13                 const set<ConnectionPtr>& sets = itr->second;
14                 if (sets.size() > 0)
15                 {
16                     auto itr2 = sets.begin();
17                     ConnectionPtr* p = (ConnectionPtr*)&(*itr2);
18                     itor.key = itr->first;
19                     itor.ptr = p;
20                     return itor;
21                 }
22             }
23         }
24         return iterator();
25     }
26 
27     iterator end(){
28         return iterator();
29     }
30 
31     ConnectionPtr& operator[](const RegionIterator& itor){
32         auto connect = connections.find(itor.key);
33         assert(connect != connections.end());
34         const set<ConnectionPtr>& sets = connect->second;
35         auto ptr = sets.find(*itor.ptr);
36         assert(ptr != sets.end());
37         return (ConnectionPtr&)*ptr;
38     }
39 
40     int size(){
41         int size = 0;
42         for (auto itor : connections)
43         {
44             size += itor.second.size();
45         }
46         return size;
47     }
48 
49 };

5. 测试代码:

 1 #include "stdafx.h"
 2 
 3 #include <iostream>
 4 #include <map>
 5 #include <set>
 6 #include <string>
 7 #include <algorithm>
 8 #include <stdlib.h>
 9 #include <time.h>
10 #include "mylterater.h"
11 #include "regiontest.h"
12 #include <time.h>
13 
14 using namespace std;
15 
16 #define random(x,y) (((double)rand()/RAND_MAX)*(y-x)+x)
17 
18 int _tmain(int argc, _TCHAR* argv[])
19 {
20 
21     srand((int)time(0));
22     //构造Region_connection
23     Region_connection region;
24     int n = 0;
25     for (int i = 0; i < random(900,1000); i++)
26     {
27         set<ConnectionPtr> sets;
28         int num = random(80, 90);
29         for (int j = 0; j < num; j++, n++)
30         {
31             ConnectionPtr ptr;
32             ptr.id_ = n;
33             ptr.port_ = 200 + n;
34             sets.insert(ptr);
35         }
36         region.connections.insert(std::make_pair(i, sets));
37     }
38 
39     //遍历打印
40     clock_t starttim, endtim;
41     starttim = clock();
42     for (auto iter = region.begin(); iter != region.end(); iter++)
43     {
44         ConnectionPtr& ptr = region[iter];//*iter;
45         ptr.printPtr();
46     }
47     endtim = clock();
48     cout << "Total time : " << endtim - starttim << " ms" << endl;
49 
50     getchar();
51 
52     return 0;
53 }

 

posted @ 2018-08-01 17:42  漆天初晓  阅读(869)  评论(0)  编辑  收藏  举报