导航

基于AC有限状态机的多模匹配算法

Posted on 2014-07-15 18:28  codeape  阅读(2397)  评论(0编辑  收藏  举报

参考链接:http://www.cnblogs.com/zzqcn/p/3525636.html

感谢原文作者。

花了两天半时间实现并测试了算法。

按照上文的思路实现了一遍,可能是原文中有些地方描述的不是特别清楚,导致一开始测试的时候发现了各种匹配遗漏的情况,后经过自己各种努力终于解决了各种遗漏。

同时在实现过程中也遇到了各种小问题,最后都解决了,总结起来主要有四个大坑,自己实现的时候需要注意,四个坑都在代码的注释里面了。

这里的实现虽然不会有遗漏的情况,但会有同一模式串在相同的偏移多次被命中的情况,但无伤大雅,至少没有遗漏不是吗。实际应用中只需对结果做去重就好了。

测试结论:对一个101.3MB的PE,从中随机抽取长度在[16-116)Bytes的模式串16个,分别用memcmp方式和AC自动机方式进行匹配,memcmp方式耗时33秒,AC方式耗时12秒,可见优势还是比较明显的。

代码中如有哪里不对,欢迎一起讨论。

  1 #include <cstdlib>
  2 #include <cstdio>
  3 #include <cstring>
  4 #include <stdint.h>
  5 #include <vector>
  6 #include <map>
  7 #include <queue>
  8 #include <ctime>
  9 
 10 typedef struct ACNode
 11 {
 12     uint64_t        u64Depth;
 13     struct ACNode   *pFail;
 14     std::map<unsigned char, struct ACNode *>    *pmpGotoTab;
 15     struct ACParrent
 16     {
 17         struct ACNode   *pParent;
 18         unsigned char   ucCondition;
 19     } Parent;
 20     bool            bIsMathed;
 21 } AC_NODE, *P_AC_NODE;
 22 
 23 typedef void (__stdcall *P_AC_FOUND_CALLBACK)(const unsigned char *In_pucBuf, uint64_t In_u64EndPos, uint64_t In_u64Len);
 24 
 25 int InitACGoto(const std::vector<const std::vector<unsigned char> *> &In_vctPattern,
 26     std::vector<P_AC_NODE> &Out_vctACNodes)
 27 {
 28     int             iRetVal     = 0;
 29     P_AC_NODE       pRoot       = NULL;
 30     unsigned int    uiPattIdx   = 0;
 31     unsigned int    uiUCharIdx  = 0;
 32     uint16_t        u16Idx      = 0;
 33 
 34     if (In_vctPattern.empty())
 35     {
 36         iRetVal = -1;
 37         goto fun_ret;
 38     }
 39 
 40     pRoot = (P_AC_NODE)calloc(1, sizeof(AC_NODE));
 41     if (pRoot == NULL)
 42     {
 43         iRetVal = -2;
 44         goto fun_ret;
 45     }
 46 
 47     pRoot->pmpGotoTab = new std::map<unsigned char, struct ACNode *>();
 48     for (u16Idx = 0; u16Idx <= 0xff; u16Idx ++)
 49         pRoot->pmpGotoTab->insert(std::pair<unsigned char, struct ACNode *>((unsigned char)u16Idx, pRoot));
 50     Out_vctACNodes.push_back(pRoot);
 51 
 52     for (uiPattIdx = 0; uiPattIdx < In_vctPattern.size(); uiPattIdx ++)
 53     {
 54         P_AC_NODE   pCurNode    = pRoot;
 55         for (uiUCharIdx = 0; uiUCharIdx < In_vctPattern[uiPattIdx]->size(); uiUCharIdx ++)
 56         {
 57             unsigned char   ucCurUChar  = In_vctPattern[uiPattIdx]->at(uiUCharIdx);
 58             if (pCurNode->pmpGotoTab->find(ucCurUChar) == pCurNode->pmpGotoTab->end()
 59                 || (pCurNode->pmpGotoTab->find(ucCurUChar) != pCurNode->pmpGotoTab->end()
 60                 && pCurNode->pmpGotoTab->at(ucCurUChar) == pRoot))
 61             {
 62                 P_AC_NODE   pNode = (P_AC_NODE)calloc(1, sizeof(AC_NODE));
 63                 if (pNode == NULL)
 64                 {
 65                     iRetVal = -3;
 66                     goto fun_ret;
 67                 }
 68 
 69                 pNode->u64Depth = uiUCharIdx + 1;
 70                 pNode->Parent.pParent = pCurNode;
 71                 pNode->Parent.ucCondition = ucCurUChar;
 72                 pNode->pmpGotoTab = new std::map<unsigned char, struct ACNode *>();
 73 
 74                 if (pCurNode->pmpGotoTab->find(ucCurUChar) != pCurNode->pmpGotoTab->end())
 75                     pCurNode->pmpGotoTab->erase(ucCurUChar);
 76                 pCurNode->pmpGotoTab->insert(std::pair<unsigned char, struct ACNode *>(ucCurUChar, pNode));
 77                 pCurNode = pNode;
 78                 Out_vctACNodes.push_back(pNode);
 79             }
 80             else
 81                 pCurNode = pCurNode->pmpGotoTab->at(ucCurUChar);
 82 
 83             if (uiUCharIdx == In_vctPattern[uiPattIdx]->size() - 1)
 84                 pCurNode->bIsMathed = true;
 85         }
 86     }
 87 
 88 fun_ret:
 89     return iRetVal;
 90 }
 91 
 92 int ACFail(std::vector<P_AC_NODE> &Out_vctACNodes)
 93 {
 94     int                     iRetVal = 0;
 95     std::queue<P_AC_NODE>   quNodes;
 96 
 97     if (Out_vctACNodes.empty())
 98     {
 99         iRetVal = -1;
100         goto fun_ret;
101     }
102 
103     quNodes.push(Out_vctACNodes[0]);
104     while (!quNodes.empty())
105     {
106         std::map<unsigned char, struct ACNode *>::iterator  itGoto;
107         P_AC_NODE   pNode = quNodes.front();
108         quNodes.pop();
109         if (pNode->u64Depth <= 1)
110             pNode->pFail = Out_vctACNodes[0];
111         else
112         {
113             P_AC_NODE   pParentFail = pNode->Parent.pParent->pFail;
114             while (pParentFail->pmpGotoTab->find(pNode->Parent.ucCondition) == pParentFail->pmpGotoTab->end())
115                 pParentFail = pParentFail->pFail;
116             pNode->pFail = pParentFail->pmpGotoTab->at(pNode->Parent.ucCondition);
117         }
118         for (itGoto = pNode->pmpGotoTab->begin(); itGoto != pNode->pmpGotoTab->end(); itGoto ++)
119         {
120             if (itGoto->second != Out_vctACNodes[0])
121                 quNodes.push(itGoto->second);
122         }
123     }
124 
125 fun_ret:
126     return iRetVal;
127 }
128 
129 void __stdcall ACFoundCallBack(const unsigned char *In_pucBuf, uint64_t In_u64EndPos, uint64_t In_u64Len)
130 {
131     if (In_pucBuf == NULL || In_u64Len == 0)
132         goto fun_ret;
133 
134     printf("<<<<<<<<<<FUCKOFF:%x\n", In_u64EndPos - In_u64Len);
135 
136 fun_ret:
137     return;
138 }
139 
140 int ACSearch(const P_AC_NODE In_pRoot, const unsigned char *In_pucBuf, uint64_t In_u64BufLen, P_AC_FOUND_CALLBACK In_pfCallBack)
141 {
142     int         iRetVal     = 0;
143     P_AC_NODE   pCurrent    = NULL;
144     uint64_t    u64Idx      = 0;
145 
146     if (In_pRoot == NULL || In_pucBuf == NULL || In_u64BufLen == 0 || In_pfCallBack == NULL)
147     {
148         iRetVal = -1;
149         goto fun_ret;
150     }
151 
152     pCurrent = In_pRoot;
153     for (u64Idx = 0; u64Idx < In_u64BufLen;)
154     {
155         P_AC_NODE   pFail   = NULL;
156         if (pCurrent->pmpGotoTab->find(In_pucBuf[u64Idx]) != pCurrent->pmpGotoTab->end())
157         {
158             pCurrent = pCurrent->pmpGotoTab->at(In_pucBuf[u64Idx]);
159             //坑1,出现匹配失败时不要前进,只在匹配成功时前进
160             u64Idx ++;
161         }
162         else
163             pCurrent = pCurrent->pFail;
164 
165         //坑3,每个节点都需要沿着失配指针一直向上找所有匹配到的结果,而不是
166         //只在匹配成功时才这么做,否则会出现匹配遗漏(形如“abcd”和“bc”这样的特征串并存的情况)
167         pFail = pCurrent->pFail;
168         //坑4,一定要走到根,否则会出现匹配遗漏
169         while (pFail != In_pRoot)
170         {
171             if (pFail->bIsMathed)
172                 In_pfCallBack(In_pucBuf, u64Idx, pFail->u64Depth);
173             pFail = pFail->pFail;
174         }
175         //坑2,不管是否匹配成功,都要判断当前节点状态,因为出现失配后的
176         //转移也有可能转到一个成功匹配的节点上
177         if (pCurrent->bIsMathed)
178             In_pfCallBack(In_pucBuf, u64Idx, pCurrent->u64Depth);
179     }
180 
181 fun_ret:
182     return iRetVal;
183 }
184 
185 void ReleaseACNodes(std::vector<P_AC_NODE> &Out_vctACNodes)
186 {
187     unsigned int    uiIdx   = 0;
188     for (uiIdx = 0; uiIdx < Out_vctACNodes.size(); uiIdx ++)
189     {
190         delete Out_vctACNodes[uiIdx]->pmpGotoTab;
191         free(Out_vctACNodes[uiIdx]);
192     }
193     Out_vctACNodes.clear();
194 }
195 
196 void main(int argc, char **argv)
197 {
198     std::vector<P_AC_NODE>  vctNodes;
199     std::vector<const std::vector<unsigned char> *> vctPatterns;
200     unsigned char   *pucBuf = NULL;
201     FILE            *pf     = NULL;
202     long            lFileSize   = 0;
203     time_t          tACBegin    = {0};
204     double          dMemSec     = 0.0;
205 
206     pf = fopen(argv[1], "rb");
207     fseek(pf, 0, SEEK_END);
208     lFileSize = ftell(pf);
209     fseek(pf, 0, SEEK_SET);
210     pucBuf = (unsigned char *)calloc(lFileSize, 1);
211     fread(pucBuf, 1, lFileSize, pf);
212     fclose(pf);
213     for (int i = 0; i < 1600; i ++)
214     {
215         std::vector<unsigned char>  *pvctPattern = new std::vector<unsigned char>();
216         int iBegin  = rand() % (lFileSize - 128);
217         int iLen    = rand() % 100 + 16;
218         for (int j = 0; j < iLen; j ++)
219             pvctPattern->push_back(pucBuf[j + iBegin]);
220         vctPatterns.push_back(pvctPattern);
221         printf("%x:%u\n", iBegin, iLen);
222         for (long j = 0; j < lFileSize - iLen; j ++)
223         {
224             time_t  tMemBegin   = time(NULL);
225             if (memcmp(pucBuf + iBegin, pucBuf + j, iLen) == 0)
226                 printf(">>>>>>>>>>Off:%x\n", j);
227             dMemSec += difftime(time(NULL), tMemBegin);
228         }
229     }
230 
231     InitACGoto(vctPatterns, vctNodes);
232     ACFail(vctNodes);
233     tACBegin = time(NULL);
234     ACSearch(vctNodes[0], pucBuf, lFileSize, ACFoundCallBack);
235     printf("MemTime::%f\nACTime::%f\n", dMemSec, difftime(time(NULL), tACBegin));
236     ReleaseACNodes(vctNodes);
237     return;
238 }