先贴上这两天刚出炉的C++代码。(利用 STL 偷了不少功夫,代码待优化)
Head.h
1 #ifndef HEAD_H 2 #define HEAD_H 3 4 #include "D:\\LiYangGuang\\VSPRO\\MYLSH\\HashTable.h" 5 6 7 #include <iostream> 8 #include <fstream> 9 #include <time.h> 10 #include <cstdlib> 11 #include <vector> 12 #include <map> 13 #include <set> 14 #include <string> 15 16 using namespace std; 17 18 19 void loadData(bool (*data)[128], int n, char *filename); 20 void createTable(HashTable HTSet[], bool data[][128], bool extDat[][n][k] ); 21 void insert(HT HTSet[], bool (*extDat)[n][k]); 22 void standHash(HT HTSet[]); 23 void search(vector<int>& record, bool query[128], HT HTSet[]); 24 /*int getPosition(int V[], std::string s, int N);*/ 25 26 #endif
HashTable.h
// HashTable.h - data structures shared by every stage of the LSH index.
#ifndef HASHTABLE_H
#define HASHTABLE_H

#include <cstddef>  // NULL
#include <string>
#include <vector>

// Compile-time parameters of the index.
enum {
    k = 15,      // number of sampled dimensions (= key length) per table
    l = 1,       // number of hash tables
    n = 587329,  // number of data items
    M = n        // number of slots in each table's second-level hash array
};

// One first-level bucket: all items whose k sampled bits are identical.
typedef struct {
    std::string key;        // the k sampled bits written as '0'/'1' characters
    std::vector<int> elem;  // element's index
} bucket;

// Node of a singly linked chain stored in a second-level hash slot.
struct INT {
    bool used;         // false until a value has been stored in this node
    int val;           // stored value (an index into HashTable::BukSet)
    struct INT *next;  // next node in the chain, or NULL

    INT() : used(false), val(0), next(NULL) {}
};

// One complete hash table.
typedef struct HashTable {
    int R[k];     // k random dimensions
    int RNum[k];  // random numbers less than M
    // string DC; // the contents of k dimensions
    std::vector<bucket> BukSet;
    INT Hash2[M];  // second-level array; each slot heads a chain of INT nodes
} HT;

#endif  // HASHTABLE_H
getPosition.h
#include <string> inline int getPosition(int V[], std::string s, int N) { int position = 0; for(int col = 0; col < k; ++col) { position += V[col] * (s[col] - '0'); position %= M; } return position; }
computeDistance.h
/// Hamming distance between two boolean vectors.
///
/// @param v1  first vector, at least N entries.
/// @param v2  second vector, at least N entries.
/// @param N   number of leading entries to compare.
/// @return    how many of the first N positions differ.
///
/// Note: in a file with `using namespace std;`, a two-argument call
/// `distance(a, b)` selects std::distance instead of this function; always
/// pass N.
inline int distance(bool v1[], bool v2[], int N)
{
    int mismatches = 0;
    int idx = 0;
    while (idx < N) {
        mismatches += v1[idx] ^ v2[idx];
        ++idx;
    }
    return mismatches;
}
main.cpp
#include "Head.h"
#include "D:\\LiYangGuang\\VSPRO\\MYLSH\\computeDistance.h"

using namespace std;

// Number of query vectors read from the query file.
const int MAX_Q = 1000;

HT HTSet[l];             // the l hash tables
bool data[n][128];       // the data set: n vectors of 128 bits
bool extDat[l][n][k];    // per table: every data vector projected onto k bits
bool query[MAX_Q][128];  // the query vectors

// Builds the LSH index from the data file, then runs every query against it
// and prints the elapsed clock() ticks per query and in total.
//
// The globals `data` and `distance` are spelled with a leading `::` because
// `using namespace std;` also makes std::data (C++17) and std::distance
// visible; an unqualified two-argument distance(a, b) call would pick
// std::distance and yield a pointer difference instead of a Hamming distance.
int main(int argc, char *argv[])
{
    /************************************************************************/
    /* Firstly, create the HashTables                                       */
    /************************************************************************/

    // loadData takes a non-const char*, so the paths live in writable arrays.
    char filename[] = "D:\\LiYangGuang\\VSPRO\\MYLSH\\data.txt";
    loadData(::data, n, filename);
    createTable(HTSet, ::data, extDat);
    insert(HTSet, extDat);
    standHash(HTSet);

    /************************************************************************/
    /* Secondly, start the LSH search                                       */
    /************************************************************************/

    char queryFile[] = "D:\\LiYangGuang\\VSPRO\\MYLSH\\query.txt";
    loadData(query, MAX_Q, queryFile);

    clock_t time0 = clock();
    for (int qId = 0; qId < MAX_Q; ++qId)
    {
        vector<int> record;
        clock_t timeA = clock();
        search(record, query[qId], HTSet);

        // Collect the distinct Hamming distances of all candidates.
        set<int> Dis;
        for (size_t i = 0; i < record.size(); ++i)
            Dis.insert(::distance(::data[record[i]], query[qId], 128));

        clock_t timeB = clock();
        cout << "第 " << qId + 1 << " 次查询时间:" << timeB - timeA << endl;
    }
    clock_t time1 = clock();
    cout << "总查询时间:" << time1 - time0 << endl;

    return 0;
}
loadData.cpp
#include <cstddef>
#include <fstream>
#include <iostream>
#include <string>

/// Reads a bit matrix from a text file.
///
/// Line `row` of the file supplies data[row]; character `col` of that line is
/// converted with `(ch - '0') & 1`, so '1' becomes true and '0' becomes false.
///
/// @param data      destination, at least n rows of 128 entries.
/// @param n         number of rows to fill.
/// @param filename  path of the text file to read.
///
/// A file that cannot be opened, a missing line, or a line shorter than 128
/// characters never causes an out-of-range read: the missing positions are
/// set to false and a single diagnostic is written to std::cerr.
void loadData(bool (*data)[128], int n, char* filename)
{
    std::ifstream ifs;
    if (filename != 0)
        ifs.open(filename, std::ios::in);
    if (!ifs.is_open())
        std::cerr << "loadData: cannot open "
                  << (filename != 0 ? filename : "(null)") << '\n';

    int incompleteRows = 0;
    std::string line;
    for (int row = 0; row < n; ++row)
    {
        // A failed getline may leave `line` untouched, so reset it explicitly.
        if (!std::getline(ifs, line))
            line.clear();

        const std::size_t len = line.size();
        if (len < 128)
            ++incompleteRows;

        for (std::size_t col = 0; col < 128; ++col)
            data[row][col] = (col < len) ? (((line[col] - '0') & 1) != 0) : false;
    }

    if (ifs.is_open() && incompleteRows > 0)
        std::cerr << "loadData: " << incompleteRows
                  << " row(s) were missing or shorter than 128 characters\n";
    // The stream closes itself when it goes out of scope.
}
creatTable.cpp
#include "HashTable.h"

#include <cstdlib>  // std::srand, std::rand
#include <ctime>    // std::time

/// Initialises every hash table and projects the data set onto its
/// sampled dimensions.
///
/// For each of the l tables and each of its k key positions this draws
///   - R[pos]:    a dimension index in [0, 128), and
///   - RNum[pos]: a multiplier obtained as rand() % M,
/// then copies that dimension of every data row into extDat.
///
/// @param HTSet   the l tables whose R and RNum arrays are filled in.
/// @param data    n input vectors of 128 bits each.
/// @param extDat  output; extDat[t][item][pos] = data[item][HTSet[t].R[pos]].
///
/// The generator is re-seeded from the current time on every call.
/// NOTE(review): RAND_MAX is only guaranteed to be >= 32767, which is smaller
/// than M, so RNum may not cover the whole range [0, M) on every platform.
void createTable(HT HTSet[], bool data[][128], bool extDat[][n][k] )
{
    std::srand(static_cast<unsigned>(std::time(NULL)));

    for (int tableNum = 0; tableNum < l; ++tableNum)
    {
        HT &table = HTSet[tableNum];

        for (int randNum = 0; randNum < k; ++randNum)
        {
            // Draw order (dimension first, multiplier second) is kept as-is.
            const int dim = std::rand() % 128;
            table.R[randNum] = dim;
            table.RNum[randNum] = std::rand() % M;

            for (int item = 0; item < n; ++item)
                extDat[tableNum][item][randNum] = data[item][dim];
        }
    }
}
insertData.cpp
#include "HashTable.h"

#include <iostream>
#include <map>
#include <string>
#include <utility>

using namespace std;

// Maps a bucket key to the index of that bucket inside the current table's
// BukSet; emptied again after each table has been processed.
map<string, int> deRepeat;

/// Element-wise comparison of two boolean arrays.
///
/// @return true when the first n entries of V and V2 are identical.
bool equal(bool V[], bool V2[], int n)
{
    // The index must advance on every iteration; without it the loop never
    // terminates once the first pair of elements matches.
    for (int i = 0; i < n; ++i)
    {
        if (V[i] != V2[i])
            return false;
    }
    return true;
}

/// Appends the first n entries of v to s as '0'/'1' characters.
///
/// @param v  boolean array with at least n entries.
/// @param n  number of entries to convert.
/// @param s  prefix to append to (taken by value).
/// @return   s followed by the n generated characters.
string itoa(bool *v, int n, string s)
{
    for (int i = 0; i < n; ++i)
        s.push_back(static_cast<char>(v[i] + '0'));
    return s;
}

/// Groups the projected data rows of every table into buckets.
///
/// Rows whose k-bit key strings are equal land in the same bucket; buckets
/// are appended to HTSet[t].BukSet in order of first appearance and each
/// bucket lists its member row indices in ascending order.
///
/// Progress is written to stdout: for every item except the first, its index
/// followed by "exist" (joined an existing bucket) or "creat" (opened a new
/// one). Lines end with '\n' rather than std::endl so the stream is not
/// flushed twice per item.
///
/// @param HTSet   the l tables receiving the buckets.
/// @param extDat  extDat[t][item] holds the k sampled bits of item in table t.
void insert(HT HTSet[], bool (*extDat)[n][k])
{
    for (int t = 0; t < l; ++t) /* t: table */
    {
        vector<bucket> &buckets = HTSet[t].BukSet;
        int bktNum = 0;  // index the next new bucket will get in BukSet

        for (int item = 0; item < n; ++item)
        {
            const string key = itoa(extDat[t][item], k, string(""));

            // Single lookup; the iterator is reused for the update.
            map<string, int>::iterator hit = deRepeat.find(key);
            // The first item always opens a bucket without consulting the map.
            const bool known = (item != 0) && (hit != deRepeat.end());

            if (known)
            {
                buckets[hit->second].elem.push_back(item);
            }
            else
            {
                bucket fresh;
                fresh.key = key;
                fresh.elem.push_back(item);
                buckets.push_back(fresh);
                deRepeat.insert(make_pair(key, bktNum++));
            }

            // The first item is processed silently.
            if (item != 0)
                cout << item << '\n' << (known ? "exist" : "creat") << '\n';
        }
        deRepeat.clear();
    }
}
standHash.cpp
#include "HashTable.h"
#include <iostream>
#include "getPosition.h"

/// Registers every bucket of every table in that table's second-level array.
///
/// For bucket b the slot is getPosition(RNum, key, k). An unused slot head
/// stores b directly; otherwise a newly allocated node holding b is linked
/// after the last node of the slot's chain. A completion message is printed
/// once per table.
///
/// @param HTSet  the l tables whose BukSet has already been populated.
void standHash(HT HTSet[])
{
    for (int t = 0; t < l; ++t)
    {
        HT &table = HTSet[t];
        const int bucketCount = table.BukSet.size();

        for (int b = 0; b != bucketCount; ++b)
        {
            const int slot = getPosition(table.RNum, table.BukSet[b].key, k);

            // Advance while the current node is occupied and has a successor.
            INT *node = &table.Hash2[slot];
            for (; node->used && node->next != NULL; node = node->next)
            {
            }

            // Either reuse the free node we stopped on or append a new one.
            INT *target = node;
            if (node->used)
            {
                target = new INT;
                node->next = target;
            }
            target->val = b;
            target->used = true;
        }

        std::cout << "the " << t << "th HashTable has been finished." << std::endl;
    }
}
search.cpp
#include "HashTable.h"
#include "getPosition.h"
#include <string>
#include <vector>

using namespace std;

/// Collects the indices of all data items that share the query's bucket.
///
/// For every table the query is projected onto that table's k sampled
/// dimensions, the resulting key is hashed with getPosition, and the chain
/// in that slot is walked. Each chained bucket whose key equals the query
/// key contributes all of its element indices.
///
/// @param record  output; matching item indices are appended (not cleared).
/// @param query   the 128-bit query vector.
/// @param HTSet   the l populated hash tables.
void search(vector<int>& record, bool query[128], HT HTSet[])
{
    for (int t = 0; t < l; ++t)
    {
        HT &table = HTSet[t];

        // Key of the query: its bits at the table's sampled dimensions.
        string temKey;
        temKey.reserve(k);
        for (int c = 0; c < k; ++c)
            temKey.push_back(static_cast<char>(query[table.R[c]] + '0'));

        const int temPos = getPosition(table.RNum, temKey, k);

        // Different keys can collide in one slot, so compare the keys too.
        for (const INT *p = &table.Hash2[temPos]; p != NULL && p->used; p = p->next)
        {
            // Bind by reference: copying the bucket would duplicate its key
            // string and whole element list for every candidate.
            const bucket &candidate = table.BukSet[p->val];
            if (candidate.key != temKey)
                continue;

            record.insert(record.end(), candidate.elem.begin(), candidate.elem.end());
        }
    }
}
稍后总结。
代码调整:
main.cpp
#include "Head.h"
#include "D:\\LiYangGuang\\VSPRO\\MYLSH\\MYLSH\\computeDistance.h"

#include <cstdio>  // sprintf

using namespace std;

#pragma warning(disable: 4996)

// Number of query vectors read from the query file.
const int MAX_Q = 1000;

HT HTSet[l];             // the l hash tables
bool data[n][128];       // the data set: n vectors of 128 bits
bool extDat[l][n][k];    // per table: every data vector projected onto k bits
bool query[MAX_Q][128];  // the query vectors

/// Writes "<v>.txt" into FileName.
///
/// @param v         number used as the file's base name (decimal).
/// @param FileName  destination buffer; the caller must provide room for the
///                  digits of v, the ".txt" suffix and the terminating NUL.
void getFileName(int v, char *FileName)
{
    sprintf(FileName, "%d.txt", v);
}

// Builds the LSH index, then answers every query twice - once by exhaustive
// linear scan and once through the index - and writes per-query and summary
// statistics to "<K>.txt".
//
// RECORD (fields Id and Dis) and getkNN are not declared in this file.
// NOTE(review): getkNN is assumed to sort its argument by distance and, in
// the two-argument form, to report whether K neighbours exist - confirm.
//
// (The earlier stand-alone linear-scan and LSH timing loops, which were kept
// here as commented-out code, are superseded by the combined loop below.)
int main(int argc, char *argv[])
{
    /************************************************************************/
    /* Firstly, create the HashTables                                       */
    /************************************************************************/

    // loadData takes a non-const char*, so the paths live in writable arrays.
    // `::data` avoids a clash with std::data under `using namespace std;`.
    char filename[] = "D:\\LiYangGuang\\VSPRO\\MYLSH\\data.txt";
    loadData(::data, n, filename);
    createTable(HTSet, ::data, extDat);
    insert(HTSet, extDat);
    standHash(HTSet);

    char queryFile[] = "D:\\LiYangGuang\\VSPRO\\MYLSH\\query.txt";
    loadData(query, MAX_Q, queryFile);

    /************************************************************************/
    /* Secondly, run linear search and LSH search side by side              */
    /************************************************************************/

    ofstream ofs;
    char outFileName[16] = { '\0' };  // large enough for any int plus ".txt"
    int K = 1;                        // the K of K-nearest-neighbour
    getFileName(K, outFileName);
    ofs.open(outFileName, ios::out);
    //ofs.precision(3);

    float TotalLinearTime = 0;
    float TotalLSHTime = 0;
    float TotalError = 0;
    int TotalMiss = 0;

    vector<RECORD> record2;  // linear-scan results, reused across queries
    record2.reserve(n);

    for (int qId = 0; qId < MAX_Q; ++qId)
    {
        cout << "第 " << qId << " 次查询" << endl;

        // Baseline: distance from the query to every data item.
        clock_t LineTime1 = clock();
        for (int i = 0; i < n; ++i)
        {
            RECORD tem;
            tem.Id = i;
            tem.Dis = computeDistance(::data[i], query[qId], 128);
            record2.push_back(tem);
        }
        getkNN(record2);  // used here to sort the records by distance
        clock_t LineTime2 = clock();
        float LineTime = (float)(LineTime2 - LineTime1) / CLOCKS_PER_SEC;
        TotalLinearTime += LineTime;

        /********************************************************************/
        /* Thirdly, start the LSH search                                    */
        /********************************************************************/

        vector<RECORD> record;
        clock_t timeA = clock();
        search(record, query[qId], HTSet, ::data);
        const bool found = getkNN(record, K);
        float queryTime = (float)(clock() - timeA) / CLOCKS_PER_SEC;
        TotalLSHTime += queryTime;

        if (!found)
        {
            ofs << "Miss\t" << "LSH Time: " << queryTime << "s\tLinear time: "
                << LineTime << 's' << endl;
            TotalMiss += 1;
        }
        else
        {
            // Ratio of the exact K-th distance to the one LSH returned;
            // an LSH distance of 0 counts as a perfect result.
            float error = 0;
            if (record[K-1].Dis == 0)
                error = 1;
            else
                error = (float)record2[K-1].Dis / record[K-1].Dis;
            ofs << "Error: " << error << "\tLSH Time: " << queryTime
                << "s\tLinear time: " << LineTime << 's' << endl;
            TotalError += error;
        }

        record.clear();
        record2.clear();
    }

    // Average over the queries that actually produced a result, rather than
    // over a hard-coded count.
    const int hits = MAX_Q - TotalMiss;
    ofs << "Average errror: " << (hits > 0 ? TotalError / hits : 0.0f) << endl;
    // Floating-point division; integer division would truncate the ratio.
    ofs << "Miss ratio: " << (float)TotalMiss / MAX_Q << endl;
    ofs << "Total query time: " << "LSH, " << TotalLSHTime / 3600 << " h; "
        << "Linear, " << TotalLinearTime / 3600 << " h." << endl;
    ofs.close();

    return 0;
}
computeDistance.h
// computeDistance.h
#ifndef COMPUTEDISTANCE_H
#define COMPUTEDISTANCE_H

/// Hamming distance between two boolean vectors.
///
/// @param v1  first vector, at least N entries (read only).
/// @param v2  second vector, at least N entries (read only).
/// @param N   number of leading entries to compare; values <= 0 yield 0.
/// @return    how many of the first N positions differ.
inline int computeDistance(const bool v1[], const bool v2[], int N)
{
    int d = 0;
    for (int i = 0; i < N; ++i)
        d += v1[i] ^ v2[i];  // bool operands promote to int: 1 on mismatch
    return d;
}

#endif  // COMPUTEDISTANCE_H
Search.cpp
#include "HashTable.h"
#include "getPosition.h"
#include "computeDistance.h"
#include <string>
#include <vector>

using namespace std;

/// Collects every data item that shares the query's bucket, together with
/// its Hamming distance to the query.
///
/// For every table the query is projected onto that table's k sampled
/// dimensions, the resulting key is hashed with getPosition, and the chain
/// in that slot is walked. Each chained bucket whose key equals the query
/// key contributes one RECORD per member.
///
/// @param record  output; one RECORD (Id = item index, Dis = distance over
///                all 128 bits) is appended per candidate. Not cleared first.
/// @param query   the 128-bit query vector.
/// @param HTSet   the l populated hash tables.
/// @param data    the data set; needed only to compute the distances.
///
/// RECORD is not declared in the headers included here.
/// NOTE(review): presumably it comes from an updated shared header - confirm.
void search(vector<RECORD>& record, bool query[128], HT HTSet[], bool data[][128])
{
    for (int t = 0; t < l; ++t)
    {
        HT &table = HTSet[t];

        // Key of the query: its bits at the table's sampled dimensions.
        string temKey;
        temKey.reserve(k);
        for (int c = 0; c < k; ++c)
            temKey.push_back(static_cast<char>(query[table.R[c]] + '0'));

        const int temPos = getPosition(table.RNum, temKey, k);

        // Different keys can collide in one slot, so compare the keys too.
        for (const INT *p = &table.Hash2[temPos]; p != NULL && p->used; p = p->next)
        {
            // Bind by reference: copying the bucket would duplicate its key
            // string and whole element list for every candidate.
            const bucket &candidate = table.BukSet[p->val];
            if (candidate.key != temKey)
                continue;

            for (size_t j = 0; j < candidate.elem.size(); ++j)
            {
                RECORD temp;
                temp.Id = candidate.elem[j];
                temp.Dis = computeDistance(data[temp.Id], query, 128);
                record.push_back(temp);
            }
        }
    }
}
相关截图: