[C++]简单即敲即得系统(基于Trie树)
这几日实现了一个基于Trie树的简单即敲即得系统,数据库为MySQL,测试数据为380万条记录。输入数据为UTF-8格式,在Trie树节点中保存的是wchar_t类型的字符:一个汉字在UTF-8编码下通常占3个字节(在GBK编码下占2个字节),如果使用char类型就必须将一个汉字拆成多个节点来保存,这样以后要做模糊搜索什么的就比较麻烦。
工程共包含5个代码文件,分别如下:
Char2W.h : 实现char到wchar_t的转换(即 TypeAheadSearch.cpp 中 #include 的 "Char2W.h")
代码
Trie.h : 定义trie树的结构
代码
#ifndef __TRIE__
#define __TRIE__
#include <stdlib.h>
#include <string.h>
#include <new>
// Node of a wide-character trie.
//
// Each node stores its key character, a growable array of child nodes and a
// growable inverted list of record ids. Both arrays are raw realloc()-managed
// buffers; the struct has no destructor, so they are never released. Child
// nodes are relocated bitwise when their parent's array grows, so TrieNode
// must remain trivially copyable.
struct TrieNode
{
    TrieNode():key(0),kidlen(0),buflen(0),kids(NULL),invlen(0),ibflen(0),invList(NULL){}

    wchar_t key;          // character represented by this node
    int kidlen,buflen;    // children in use / children allocated
    TrieNode* kids;       // child array, searched linearly
    int invlen,ibflen;    // ids in use / ids allocated
    unsigned* invList;    // record ids attached to this node

    // Returns the child keyed by `w`, creating it if it does not exist yet.
    // The returned pointer refers into `kids` and is invalidated by the next
    // Insert() on this same node that has to grow the array.
    // Throws std::bad_alloc if the child array cannot be grown; the existing
    // children are left intact in that case.
    TrieNode* Insert(wchar_t w)
    {
        TrieNode* existing = Search(w);
        if (existing != NULL)
            return existing;

        if (kidlen == buflen)
        {
            const int newCap = buflen * 2 + 1;
            // Size is computed in size_t and the result is checked before it
            // replaces `kids`, so a failed realloc neither overflows nor
            // loses the old buffer.
            void* grown = realloc(kids, sizeof(TrieNode) * static_cast<size_t>(newCap));
            if (grown == NULL)
                throw std::bad_alloc();
            kids = static_cast<TrieNode*>(grown);
            buflen = newCap;
        }

        TrieNode* child = new (&kids[kidlen]) TrieNode();
        child->key = w;
        ++kidlen;
        return child;
    }

    // Appends record `id` to the inverted list. Only the last stored id is
    // compared, so duplicates are suppressed only for consecutive calls with
    // the same id. Throws std::bad_alloc if the list cannot be grown.
    void add2Invlist(unsigned id)
    {
        if (invlen > 0 && invList[invlen - 1] == id)
            return;

        if (invlen == ibflen)
        {
            const int newCap = ibflen * 2 + 1;
            void* grown = realloc(invList, sizeof(unsigned) * static_cast<size_t>(newCap));
            if (grown == NULL)
                throw std::bad_alloc();
            invList = static_cast<unsigned*>(grown);
            ibflen = newCap;
        }
        invList[invlen++] = id;
    }

    // Returns the child keyed by `w`, or NULL when there is none.
    TrieNode* Search(wchar_t w) const
    {
        for (int i = 0; i < kidlen; i++)
            if (kids[i].key == w)
                return &kids[i];
        return NULL;
    }
};
#endif
#define __TRIE__
#include <string.h>
#include <stdlib.h>
// Node of a wide-character trie.
//
// Each node stores its key character, a growable array of child nodes and a
// growable inverted list of record ids. Both arrays are raw realloc()-managed
// buffers; the struct has no destructor, so they are never released. Child
// nodes are relocated bitwise when their parent's array grows, so TrieNode
// must remain trivially copyable.
struct TrieNode
{
    TrieNode():key(0),kidlen(0),buflen(0),kids(NULL),invlen(0),ibflen(0),invList(NULL){}

    wchar_t key;          // character represented by this node
    int kidlen,buflen;    // children in use / children allocated
    TrieNode* kids;       // child array, searched linearly
    int invlen,ibflen;    // ids in use / ids allocated
    unsigned* invList;    // record ids attached to this node

    // Returns the child keyed by `w`, creating it if it does not exist yet.
    // The returned pointer refers into `kids` and is invalidated by the next
    // Insert() on this same node that has to grow the array.
    // Throws std::bad_alloc if the child array cannot be grown; the existing
    // children are left intact in that case.
    TrieNode* Insert(wchar_t w)
    {
        TrieNode* existing = Search(w);
        if (existing != NULL)
            return existing;

        if (kidlen == buflen)
        {
            const int newCap = buflen * 2 + 1;
            // Size is computed in size_t and the result is checked before it
            // replaces `kids`, so a failed realloc neither overflows nor
            // loses the old buffer.
            void* grown = realloc(kids, sizeof(TrieNode) * static_cast<size_t>(newCap));
            if (grown == NULL)
                throw std::bad_alloc();
            kids = static_cast<TrieNode*>(grown);
            buflen = newCap;
        }

        TrieNode* child = new (&kids[kidlen]) TrieNode();
        child->key = w;
        ++kidlen;
        return child;
    }

    // Appends record `id` to the inverted list. Only the last stored id is
    // compared, so duplicates are suppressed only for consecutive calls with
    // the same id. Throws std::bad_alloc if the list cannot be grown.
    void add2Invlist(unsigned id)
    {
        if (invlen > 0 && invList[invlen - 1] == id)
            return;

        if (invlen == ibflen)
        {
            const int newCap = ibflen * 2 + 1;
            void* grown = realloc(invList, sizeof(unsigned) * static_cast<size_t>(newCap));
            if (grown == NULL)
                throw std::bad_alloc();
            invList = static_cast<unsigned*>(grown);
            ibflen = newCap;
        }
        invList[invlen++] = id;
    }

    // Returns the child keyed by `w`, or NULL when there is none.
    TrieNode* Search(wchar_t w) const
    {
        for (int i = 0; i < kidlen; i++)
            if (kids[i].key == w)
                return &kids[i];
        return NULL;
    }
};
#endif
TypeAheadSearch.h :定义搜索结构
代码
#ifndef H_TYPEAHEADSEARCH
#define H_TYPEAHEADSEARCH
#include <vector>
using namespace std;
/// Type-ahead (search-as-you-type) index built from the rows of a MySQL table.
class TypeAheadSearch
{
public:
    /// Connects to a MySQL server and builds the index from table `table`.
    /// @param user    MySQL user name
    /// @param passwd  password for `user`
    /// @param host    server host name or address
    /// @param db      database to select
    /// @param table   table whose rows are read and indexed
    /// @return false when the connection or the query fails, true otherwise
    bool createIndex(const char* user, const char* passwd, const char* host, const char* db, const char* table);

    /// Looks up the space-separated keywords of `query` and stores the ids
    /// of the matching records in `results`.
    /// @param query    keywords separated by single spaces
    /// @param topk     maximum number of results the caller wants; the
    ///                 implementation in TypeAheadSearch.cpp does not
    ///                 truncate `results` itself
    /// @param results  receives the matching record ids
    /// @return true when the search ran (an empty `results` is not an error)
    bool search(const char *query, const int topk, std::vector<unsigned>& results);
};
#endif
#define H_TYPEAHEADSEARCH
#include <vector>
using namespace std;
/// Type-ahead (search-as-you-type) index built from the rows of a MySQL table.
class TypeAheadSearch
{
public:
    /// Connects to a MySQL server and builds the index from table `table`.
    /// @param user    MySQL user name
    /// @param passwd  password for `user`
    /// @param host    server host name or address
    /// @param db      database to select
    /// @param table   table whose rows are read and indexed
    /// @return false when the connection or the query fails, true otherwise
    bool createIndex(const char* user, const char* passwd, const char* host, const char* db, const char* table);

    /// Looks up the space-separated keywords of `query` and stores the ids
    /// of the matching records in `results`.
    /// @param query    keywords separated by single spaces
    /// @param topk     maximum number of results the caller wants; the
    ///                 implementation in TypeAheadSearch.cpp does not
    ///                 truncate `results` itself
    /// @param results  receives the matching record ids
    /// @return true when the search ran (an empty `results` is not an error)
    bool search(const char *query, const int topk, std::vector<unsigned>& results);
};
#endif
TypeAheadSearch.cpp : 实现索引建立及搜索
代码
#include "TypeAheadSearch.h"
#include "Trie.h"
#include "Char2W.h"
#include <Windows.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <mysql.h>
#include <ctime>
#include <iostream>
using namespace std;
// File-level state shared by the index builder and the search routine.
unsigned port=3306;   // TCP port passed to mysql_real_connect
unsigned makeId=1;    // id assigned to the next row read by MakeIndex
MYSQL myCont;         // connection handle used by createIndex
MYSQL_RES *result;    // result set of the "select *" query
MYSQL_ROW sql_row;    // row currently being indexed
MYSQL_FIELD *fd;      // not referenced by the code in this file
time_t Start,End;     // clock() readings used for the timing messages
// Root level of the trie, indexed directly by the first character of a word.
// It needs one slot per 16-bit code unit (0..65535 inclusive); with 65535
// slots the value 0xFFFF indexed one element past the end.
TrieNode firstfloor[65536];
// Converts the NUL-terminated byte string `ch` to a wide-character string.
// The output buffer is malloc()ed with room for strlen(ch)+1 wide characters
// and ownership passes to the caller. The conversion itself is delegated to
// SWIchar2wchar (declared in Char2W.h, not shown here) -- presumably UTF-8 to
// wchar_t with a terminating L'\0'; verify against that header.
wchar_t* ConvChar(char*ch)
{
    const int byteCount = strlen(ch);
    const size_t bufferBytes = (byteCount + 1) * sizeof(wchar_t);
    wchar_t* wide = (wchar_t*)malloc(bufferBytes);
    SWIchar2wchar((const unsigned char*)ch, wide, byteCount);
    return wide;
}
// Debug helper: prints the inverted list of `node` prefixed by `span`, then
// recurses into every child with the prefix extended by one space.
// The child prefix is built in a buffer sized from strlen(span), so deep
// tries no longer overflow a fixed 50-byte stack array.
void showIndex(TrieNode* node,char* span)
{
    printf("%sinvlist: ",span);
    for (int i=0;i<node->invlen;i++)
        printf("%u ",node->invList[i]);
    printf("\n");

    if (node->kidlen<=0)
        return;

    // The prefix is identical for all children, so build it once.
    const size_t spanLen=strlen(span);
    std::vector<char> childSpan(spanLen+2);
    memcpy(&childSpan[0],span,spanLen);
    childSpan[spanLen]=' ';
    childSpan[spanLen+1]='\0';

    for (int i=0;i<node->kidlen;i++)
        showIndex(&(node->kids[i]),&childSpan[0]);
}
// create index function
void MakeIndex(MYSQL_RES *result)
{
int fn=mysql_num_fields(result);
unsigned count=0;
TrieNode* tmpnode=NULL;
wchar_t* word=NULL;
char* ch=NULL;
int i,j;
while(sql_row=mysql_fetch_row(result))
{
for (i=0;i<fn;i++) // visit fields of record
{
if ((int)*sql_row[i]==0) continue; // field is NULL
ch=strtok(sql_row[i]," "); // split item by space
if (ch)
{
word=ConvChar(ch);
tmpnode=&firstfloor[word[0]];
for (j=1;j<wcslen(word);j++)
{
tmpnode=tmpnode->Insert(word[j]); // insert word
}
tmpnode->add2Invlist(makeId); // add recordid to inverted list
while(ch=strtok(NULL," "))
{
word=ConvChar(ch);
tmpnode=&firstfloor[word[0]];
for (j=1;j<wcslen(word);j++)
{
tmpnode=tmpnode->Insert(word[j]);
}
tmpnode->add2Invlist(makeId);
}
}
}
makeId++;
}
}
// Connects to the MySQL server, runs "select * from <table>" and feeds the
// streamed result set to MakeIndex.
//
// Returns false when `table` is NULL, the connection fails, the query fails
// or no result set is produced; true once the index has been built.
// The connection is closed on every path and the result set is freed exactly
// once. The root table is zeroed first; nodes of a previously built index are
// not released and `makeId` is not reset.
// NOTE(review): `table` is concatenated into the SQL text unescaped -- only
// pass trusted identifiers.
bool TypeAheadSearch::createIndex(const char* user, const char* passwd, const char* host, const char* db, const char* table)
{
    if (table==NULL)
    {
        printf("Query failed!\n");
        return false;
    }

    Start=clock();
    if (mysql_init(&myCont)==NULL)
    {
        printf("Connect to DataBase failed!\n");
        return false;
    }
    if (!mysql_real_connect(&myCont,host,user,passwd,db,port,NULL,0)) //connect to mysql
    {
        printf("Connect to DataBase failed!\n");
        mysql_close(&myCont);
        return false;
    }
    printf("Connect to DataBase succeed!\n");
    mysql_set_character_set(&myCont,"UTF8");

    // Build the statement in a buffer sized from the table name instead of a
    // fixed char[100] that long names would overflow.
    static const char selectPrefix[]="select * from ";
    const size_t prefixLen=sizeof(selectPrefix)-1;
    const size_t tableLen=strlen(table);
    std::vector<char> sql(prefixLen+tableLen+1);
    memcpy(&sql[0],selectPrefix,prefixLen);
    memcpy(&sql[0]+prefixLen,table,tableLen+1);

    printf("making index..\n");
    memset(firstfloor,0,sizeof(firstfloor));

    if (mysql_query(&myCont,&sql[0])!=0)
    {
        printf("Query failed!\n");
        mysql_close(&myCont);
        return false;
    }

    result=mysql_use_result(&myCont);
    if (result==NULL)
    {
        // A SELECT that yields no result set is an error, not an empty index.
        printf("Query failed!\n");
        mysql_close(&myCont);
        return false;
    }

    MakeIndex(result);
    mysql_free_result(result);
    result=NULL;   // the global must not keep a dangling pointer

    End=clock();
    double utime=(double)(End-Start)/CLOCKS_PER_SEC;
    printf("makeid = %u\n",makeId);
    printf("makeIndex succeed!\ntime used: %lf seconds\n\n",utime);

    mysql_close(&myCont);
    return true;
}
// Appends to `results` every record id stored in the subtree rooted at
// `node`: all children are visited first (recursively, in array order),
// then the node's own inverted list is appended.
void MakeResult(TrieNode*node,vector<unsigned>& results)
{
    TrieNode* const kidsEnd=node->kids+node->kidlen;
    for (TrieNode* kid=node->kids; kid!=kidsEnd; ++kid)
    {
        MakeResult(kid,results);
    }

    const unsigned* const idsEnd=node->invList+node->invlen;
    for (const unsigned* id=node->invList; id!=idsEnd; ++id)
    {
        results.push_back(*id);
    }
}
// serach procedure
bool TypeAheadSearch::search(const char *querys, const int topk, vector<unsigned>& results)
{
vector<unsigned> rts,bing;
Start=clock();
char query[65535];
strcpy(query,querys);
char*ch=strtok((char*)query," "); // split query by space
if (ch)
{
wchar_t* word=ConvChar(ch);
TrieNode* temp=NULL;
temp=&firstfloor[word[0]];
if (temp==NULL) return true;
for (int i=1;i<wcslen(word);i++)
{
temp=temp->Search(word[i]);
if (temp==NULL)
return true;
}
MakeResult(temp,results);
sort(results.begin(), results.end());
vector<unsigned>::iterator iter = unique(results.begin(), results.end());
results.erase(iter, results.end()); //unique items of results
while(ch=strtok(NULL," "))
{
wchar_t* word=ConvChar(ch);
TrieNode* temp=NULL;
temp=&firstfloor[word[0]];
if (temp==NULL) return true;
for (int i=1;i<wcslen(word);i++)
{
temp=temp->Search(word[i]);
if (temp==NULL)
return true;
}
MakeResult(temp,rts);
sort(rts.begin(), rts.end());
vector<unsigned>::iterator it = unique(rts.begin(), rts.end());
rts.erase(it, rts.end());
set_intersection(results.begin(),results.end(),rts.begin(),rts.end(),back_inserter(bing));
results=bing; // intersection of two key words
rts.clear();
bing.clear();
}
End=clock();
double utime=(double)(End-Start)/CLOCKS_PER_SEC;
printf("search succeed!\ntime used: %lf seconds\n",utime);
}
return true;
}
#include "Trie.h"
#include "Char2W.h"
#include <Windows.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <mysql.h>
#include <ctime>
#include <iostream>
using namespace std;
// File-level state shared by the index builder and the search routine.
unsigned port=3306;   // TCP port passed to mysql_real_connect
unsigned makeId=1;    // id assigned to the next row read by MakeIndex
MYSQL myCont;         // connection handle used by createIndex
MYSQL_RES *result;    // result set of the "select *" query
MYSQL_ROW sql_row;    // row currently being indexed
MYSQL_FIELD *fd;      // not referenced by the code in this file
time_t Start,End;     // clock() readings used for the timing messages
// Root level of the trie, indexed directly by the first character of a word.
// It needs one slot per 16-bit code unit (0..65535 inclusive); with 65535
// slots the value 0xFFFF indexed one element past the end.
TrieNode firstfloor[65536];
// Converts the NUL-terminated byte string `ch` to a wide-character string.
// The output buffer is malloc()ed with room for strlen(ch)+1 wide characters
// and ownership passes to the caller. The conversion itself is delegated to
// SWIchar2wchar (declared in Char2W.h, not shown here) -- presumably UTF-8 to
// wchar_t with a terminating L'\0'; verify against that header.
wchar_t* ConvChar(char*ch)
{
    const int byteCount = strlen(ch);
    const size_t bufferBytes = (byteCount + 1) * sizeof(wchar_t);
    wchar_t* wide = (wchar_t*)malloc(bufferBytes);
    SWIchar2wchar((const unsigned char*)ch, wide, byteCount);
    return wide;
}
// Debug helper: prints the inverted list of `node` prefixed by `span`, then
// recurses into every child with the prefix extended by one space.
// The child prefix is built in a buffer sized from strlen(span), so deep
// tries no longer overflow a fixed 50-byte stack array.
void showIndex(TrieNode* node,char* span)
{
    printf("%sinvlist: ",span);
    for (int i=0;i<node->invlen;i++)
        printf("%u ",node->invList[i]);
    printf("\n");

    if (node->kidlen<=0)
        return;

    // The prefix is identical for all children, so build it once.
    const size_t spanLen=strlen(span);
    std::vector<char> childSpan(spanLen+2);
    memcpy(&childSpan[0],span,spanLen);
    childSpan[spanLen]=' ';
    childSpan[spanLen+1]='\0';

    for (int i=0;i<node->kidlen;i++)
        showIndex(&(node->kids[i]),&childSpan[0]);
}
// create index function
void MakeIndex(MYSQL_RES *result)
{
int fn=mysql_num_fields(result);
unsigned count=0;
TrieNode* tmpnode=NULL;
wchar_t* word=NULL;
char* ch=NULL;
int i,j;
while(sql_row=mysql_fetch_row(result))
{
for (i=0;i<fn;i++) // visit fields of record
{
if ((int)*sql_row[i]==0) continue; // field is NULL
ch=strtok(sql_row[i]," "); // split item by space
if (ch)
{
word=ConvChar(ch);
tmpnode=&firstfloor[word[0]];
for (j=1;j<wcslen(word);j++)
{
tmpnode=tmpnode->Insert(word[j]); // insert word
}
tmpnode->add2Invlist(makeId); // add recordid to inverted list
while(ch=strtok(NULL," "))
{
word=ConvChar(ch);
tmpnode=&firstfloor[word[0]];
for (j=1;j<wcslen(word);j++)
{
tmpnode=tmpnode->Insert(word[j]);
}
tmpnode->add2Invlist(makeId);
}
}
}
makeId++;
}
}
// Connects to the MySQL server, runs "select * from <table>" and feeds the
// streamed result set to MakeIndex.
//
// Returns false when `table` is NULL, the connection fails, the query fails
// or no result set is produced; true once the index has been built.
// The connection is closed on every path and the result set is freed exactly
// once. The root table is zeroed first; nodes of a previously built index are
// not released and `makeId` is not reset.
// NOTE(review): `table` is concatenated into the SQL text unescaped -- only
// pass trusted identifiers.
bool TypeAheadSearch::createIndex(const char* user, const char* passwd, const char* host, const char* db, const char* table)
{
    if (table==NULL)
    {
        printf("Query failed!\n");
        return false;
    }

    Start=clock();
    if (mysql_init(&myCont)==NULL)
    {
        printf("Connect to DataBase failed!\n");
        return false;
    }
    if (!mysql_real_connect(&myCont,host,user,passwd,db,port,NULL,0)) //connect to mysql
    {
        printf("Connect to DataBase failed!\n");
        mysql_close(&myCont);
        return false;
    }
    printf("Connect to DataBase succeed!\n");
    mysql_set_character_set(&myCont,"UTF8");

    // Build the statement in a buffer sized from the table name instead of a
    // fixed char[100] that long names would overflow.
    static const char selectPrefix[]="select * from ";
    const size_t prefixLen=sizeof(selectPrefix)-1;
    const size_t tableLen=strlen(table);
    std::vector<char> sql(prefixLen+tableLen+1);
    memcpy(&sql[0],selectPrefix,prefixLen);
    memcpy(&sql[0]+prefixLen,table,tableLen+1);

    printf("making index..\n");
    memset(firstfloor,0,sizeof(firstfloor));

    if (mysql_query(&myCont,&sql[0])!=0)
    {
        printf("Query failed!\n");
        mysql_close(&myCont);
        return false;
    }

    result=mysql_use_result(&myCont);
    if (result==NULL)
    {
        // A SELECT that yields no result set is an error, not an empty index.
        printf("Query failed!\n");
        mysql_close(&myCont);
        return false;
    }

    MakeIndex(result);
    mysql_free_result(result);
    result=NULL;   // the global must not keep a dangling pointer

    End=clock();
    double utime=(double)(End-Start)/CLOCKS_PER_SEC;
    printf("makeid = %u\n",makeId);
    printf("makeIndex succeed!\ntime used: %lf seconds\n\n",utime);

    mysql_close(&myCont);
    return true;
}
// Appends to `results` every record id stored in the subtree rooted at
// `node`: all children are visited first (recursively, in array order),
// then the node's own inverted list is appended.
void MakeResult(TrieNode*node,vector<unsigned>& results)
{
    TrieNode* const kidsEnd=node->kids+node->kidlen;
    for (TrieNode* kid=node->kids; kid!=kidsEnd; ++kid)
    {
        MakeResult(kid,results);
    }

    const unsigned* const idsEnd=node->invList+node->invlen;
    for (const unsigned* id=node->invList; id!=idsEnd; ++id)
    {
        results.push_back(*id);
    }
}
// serach procedure
bool TypeAheadSearch::search(const char *querys, const int topk, vector<unsigned>& results)
{
vector<unsigned> rts,bing;
Start=clock();
char query[65535];
strcpy(query,querys);
char*ch=strtok((char*)query," "); // split query by space
if (ch)
{
wchar_t* word=ConvChar(ch);
TrieNode* temp=NULL;
temp=&firstfloor[word[0]];
if (temp==NULL) return true;
for (int i=1;i<wcslen(word);i++)
{
temp=temp->Search(word[i]);
if (temp==NULL)
return true;
}
MakeResult(temp,results);
sort(results.begin(), results.end());
vector<unsigned>::iterator iter = unique(results.begin(), results.end());
results.erase(iter, results.end()); //unique items of results
while(ch=strtok(NULL," "))
{
wchar_t* word=ConvChar(ch);
TrieNode* temp=NULL;
temp=&firstfloor[word[0]];
if (temp==NULL) return true;
for (int i=1;i<wcslen(word);i++)
{
temp=temp->Search(word[i]);
if (temp==NULL)
return true;
}
MakeResult(temp,rts);
sort(rts.begin(), rts.end());
vector<unsigned>::iterator it = unique(rts.begin(), rts.end());
rts.erase(it, rts.end());
set_intersection(results.begin(),results.end(),rts.begin(),rts.end(),back_inserter(bing));
results=bing; // intersection of two key words
rts.clear();
bing.clear();
}
End=clock();
double utime=(double)(End-Start)/CLOCKS_PER_SEC;
printf("search succeed!\ntime used: %lf seconds\n",utime);
}
return true;
}
TestCase.cpp : 测试样例
代码
#include "TypeAheadSearch.h"
#include <iostream>
#include <vector>
using namespace std;
// Connection settings that testCase passes to TypeAheadSearch::createIndex.
// The database and table names are placeholders to be replaced before running.
const char user[] = "root"; // MySQL user name
const char pswd[] = "root"; // password for that user
const char host[] = "localhost"; // server host, or "127.0.0.1"
const char db[] = "database_name"; // database to select
const char table[] = "table_name"; // table whose rows are indexed
// Builds the index from the configured table, runs one search for `query`
// and prints the total number of hits followed by at most 100 record ids.
// Stops after the error message when the index cannot be created.
void testCase(const char* query)
{
    // Automatic storage: the searcher was previously allocated with new and
    // never deleted.
    TypeAheadSearch tas;
    if (!tas.createIndex(user,pswd,host,db,table))
    {
        printf("Create Index Error!\n");
        return;
    }

    const int topk=100; // max number of results to print
    std::vector<unsigned int> results;
    if (!tas.search(query,topk,results))
    {
        printf("Search Error!\n");
        return;
    }

    typedef std::vector<unsigned int>::size_type Count;
    const Count total=results.size();
    // size() is not an int; cast to match the format specifier.
    printf("number of results: %lu\n",static_cast<unsigned long>(total));

    const Count limit=static_cast<Count>(topk);
    const Count shown=(total<limit)?total:limit;
    for (Count i=0;i<shown;i++)
        printf("%u\n",results[i]);
}
// Entry point: runs the single test case with the sample query text.
int main()
{
    // Writable array instead of passing a string literal directly: a literal
    // does not convert to a char* parameter in standard C++11 and later.
    char query[]="querys";
    testCase(query);
    return 0;
}
#include <iostream>
#include <vector>
using namespace std;
// Connection settings that testCase passes to TypeAheadSearch::createIndex.
// The database and table names are placeholders to be replaced before running.
const char user[] = "root"; // MySQL user name
const char pswd[] = "root"; // password for that user
const char host[] = "localhost"; // server host, or "127.0.0.1"
const char db[] = "database_name"; // database to select
const char table[] = "table_name"; // table whose rows are indexed
// Builds the index from the configured table, runs one search for `query`
// and prints the total number of hits followed by at most 100 record ids.
// Stops after the error message when the index cannot be created.
void testCase(const char* query)
{
    // Automatic storage: the searcher was previously allocated with new and
    // never deleted.
    TypeAheadSearch tas;
    if (!tas.createIndex(user,pswd,host,db,table))
    {
        printf("Create Index Error!\n");
        return;
    }

    const int topk=100; // max number of results to print
    std::vector<unsigned int> results;
    if (!tas.search(query,topk,results))
    {
        printf("Search Error!\n");
        return;
    }

    typedef std::vector<unsigned int>::size_type Count;
    const Count total=results.size();
    // size() is not an int; cast to match the format specifier.
    printf("number of results: %lu\n",static_cast<unsigned long>(total));

    const Count limit=static_cast<Count>(topk);
    const Count shown=(total<limit)?total:limit;
    for (Count i=0;i<shown;i++)
        printf("%u\n",results[i]);
}
// Entry point: runs the single test case with the sample query text.
int main()
{
    // Writable array instead of passing a string literal directly: a literal
    // does not convert to a char* parameter in standard C++11 and later.
    char query[]="querys";
    testCase(query);
    return 0;
}
测试结果: