[leetcode] 1032: Stream of Characters: Tries&AC自动机

其实这道题好像大部分人都直接用Tries倒序来解,但我觉得AC自动机可能更高效一点(毕竟是在Tries基础上优化的算法如果还不如原始Tries似乎说不过去)。

根据定义写了个原始的在堆上创建树形结构的solution但好像性能并不是很乐观。另外一些用AC解的dalao好像是用一条线性结构存储所有结点再用指针记录树种结点的连接关系,理解起来相对复杂一些但这样性能应该会比原始的树形结构提升很多。

 

Tries

根据设定的字典中所有单词创建的字典树,对输入字符串的匹配过程就是对tries的遍历。如果某个结点为字典中单词的结束字母时则标记为终止结点(终止结点不一定没有子节点)。当遍历走到终止结点时说明从根节点到该终止结点的路径所代表的字符串匹配成功。

 

e.g. dictionary包含"ab", "ba", "aaab", "abab", "baa"。Tries创建为:

 (终止结点标记为红色)

 

AC自动机

AC自动机在tries的基础上给各结点指定fail指针,在当前分支路径上匹配失败(当前结点的子结点没有可以继续匹配的字符值)时可通过fail指针找到其fail结点,把目前的匹配状态更新为该fail结点上,如果仍然匹配则继续转移下一个fail结点。

结点结构:

e.g. 为dictionary包含"ab", "ba", "aaab", "abab", "baa"的tries分配fail指针。

输入字符串为abaa,当从第1个位置开始试图匹配"abab"失败得知只能匹配到"aba"后,转而从第2个位置开始,利用以前的结果知道已经可以成功匹配"ba",直接从"ba"的下一个位置开始匹配。

则走到第3层中间第2个结点a时因为当前结点没有字符值为a的子节点,匹配失败。说明当前只能匹配到"aba",AC自动机会选择转移到对应的fail结点,也就是原先当前结点的fail指针所指向的第2层第2个a,该结点有值为a的子结点,则向后继续匹配,当前结点转移为第3层第3个a,匹配情况更新root->b->a->a所表示的"baa"。

 

fail指针的分配

1) 根节点 和 根节点的所有子结点 的fail指针都指向根结点;

2) 以BFS的顺序遍历tries中每个结点i,看它parent的fail结点是否有与i字符值相同的子结点,如果有则将fail指针指向那个子结点,没有则继续找下一个fail结点。如果一直找到root仍然没有(root也没有与i有相同字符值的子结点)则将fail指针指向root。

 

解题

直接套用AC自动机来构造StreamChecker。每次check当前输入字符c时在AC自动机中沿着fail搜索,如果能找到值为c的终止结点(isEnd==true)则表示能找到。

所以判断逻辑是:

1. 如果沿着fail一直搜索到根结点都找不到(说明整个trie数中都没有)值为c的结点,返回false;

2. 在搜索过程中找到值为c的结点,如果是终止结点则返回true,否则一直继续沿着fail搜索(同样,直到搜索到根节点)仍找不到值为c的终止结点才返回false。

 

直接照着这个定义写了份粗制劣造的code。缺点:1. 性能差,速度慢;2. debug麻烦,另加了辅助打印函数才调正确; 3. 因为结点基本都是在堆上创建的,还要另写一份析构函数确保内存都成功释放。

 

class TrieNode {
public:
    int value;
    TrieNode* children[26]={NULL};
    TrieNode* fail;
    bool isEnd;
    TrieNode() : value(-1), fail(NULL), isEnd(false) {}
    TrieNode(char charValue) : value(charValue-'a'),fail(NULL),isEnd(false){}
    /* used to debug
     TrieNode() : value(-1), fail(NULL), isEnd(false), parent(-1), tier(0){}
     TrieNode(char charValue, TrieNode* pParent) : value(charValue-'a'), fail(NULL), isEnd(false), parent(pParent->value), tier(pParent->tier+1) {}
     int parent;
     int tier;
     void printBasic() const {
     if (tier==0) {
     cout << "root\n";
     }  else {
     cout << "Node: "<<(char)('a'+value)
     << "\nlevel: "<<tier << '\t'
     <<"parent: ";
     if (tier==1)
     cout << "root";
     else
     cout <<(char)('a'+parent);
     cout <<endl;
     //how to print node whose parent is root
     }
     }
     void print() const {
     printBasic();
     cout << "fail-> ";
     fail->printBasic();
     cout << endl;
     }
     */
};

class StreamChecker {
public:
    
    StreamChecker(vector<string>& words):root() {
        buildTrie(words);
        buildFail();
        current=&root;
    }
    
    ~StreamChecker() {
        for(int i=0;i<26;i++) {
            if (root.children[i]!=NULL)
                releaseChildren(root.children[i]);
        }
    }
    
    void releaseChildren(TrieNode* temp) {
        for (int i=0;i<26;i++) {
            if (temp->children[i]!=NULL)
                releaseChildren(temp->children[i]);
        }
        delete temp;
    }
    
    void buildTrie(vector<string>& words) {
        TrieNode* temp_current;
        for (string word : words) {
            temp_current=&root;
            for (char c : word) {
                if (temp_current->children[c-'a']==NULL) //if node not build yet
                    //used to debug
                    //temp_current->children[c-'a']=new TrieNode(c, temp_current);
                    temp_current->children[c-'a']=new TrieNode(c);
                temp_current=temp_current->children[c-'a']; //go to the target node
            }
            temp_current->isEnd=true;
        }
    }
    
    void buildFail() {
        root.fail=&root;
        queue<TrieNode*> process_queue;
        for (int i=0;i<26;i++) {
            if(root.children[i]!=NULL) {
                root.children[i]->fail=&root;
                process_queue.push(root.children[i]);
            }
        }
        while (!process_queue.empty()) {
            TrieNode* temp_current=process_queue.front();
            for (int i=0;i<26;i++) {
                if (temp_current->children[i]!=NULL) {
                    setFailForChild(temp_current, i);
                    process_queue.push(temp_current->children[i]);
                }
            }
            process_queue.pop();
        }
    }
    
    void setFailForChild(TrieNode* parent, int child_index) {
        TrieNode* to_search=parent->fail;
        while (to_search!=&root) {
            if (to_search->children[child_index]!=NULL) {
                parent->children[child_index]->fail=to_search->children[child_index];
                return;
            }
            to_search=to_search->fail;
        }
        if (root.children[child_index]!=NULL) {
            parent->children[child_index]->fail=root.children[child_index];
            return;
        }
        parent->children[child_index]->fail=&root;
    }
    
    bool query(char letter) {
        while (current!=&root) {
            if (current->children[letter-'a']!=NULL) {
                current=current->children[letter-'a'];
                if (current->isEnd==false) {
                    TrieNode* temp=current;
                    while (temp!=&root) {
                        if (temp->isEnd==true)
                            return true;
                        else
                            temp=temp->fail;
                    }
                    return false;
                } else {
                    return true;
                }
            } else {
                current=current->fail;
            }
        }
        if (root.children[letter-'a']!=NULL) {
            current=root.children[letter-'a'];
            return current->isEnd;
        }
        return false;
    }
    
    TrieNode root;
    TrieNode* current;
    
    /* add for debug
     void printTrie() const {
     root.print();
     queue<TrieNode*> nodes;
     for (int i=0;i<26;i++) {
     if (root.children[i]!=NULL)
     nodes.push(root.children[i]);
     }
     while (!nodes.empty()) {
     const TrieNode* current=nodes.front();
     current->print();
     for (int i=0;i<26;i++) {
     if (current->children[i]!=NULL)
     nodes.push(current->children[i]);
     }
     nodes.pop();
     }
     cout << "--end.\n";
     } */
};

int main() {
    string inputs[]={"ab","ba","aaab","abab","baa"};
    vector<string> words(inputs,inputs+5);
    StreamChecker* checker=new StreamChecker(words);
    cout << "000001111100111100011111101110" << endl << endl;
    string testinput="aaaaabababbbababbbbababaaabaaa";
    for (char c:testinput)
        cout << checker->query(c);
    delete checker;
    cout << endl;
}

 

posted @ 2019-06-09 14:24  丹尼尔奥利瓦  阅读(923)  评论(0编辑  收藏  举报