进阶之路

首页 新随笔 管理

先贴上这两天刚写好的 C++ 代码。(借助 STL 省了不少功夫,代码仍有待优化)

Head.h

 1 #ifndef HEAD_H
 2 #define HEAD_H
 3 
 4 #include "D:\\LiYangGuang\\VSPRO\\MYLSH\\HashTable.h"
 5 
 6 
 7 #include <iostream>
 8 #include <fstream>
 9 #include <time.h>
10 #include <cstdlib>
11 #include <vector>
12 #include <map>
13 #include <set>
14 #include <string>
15 
16 using namespace std;
17 
18 
19 void loadData(bool (*data)[128], int n, char *filename);  // reads n text lines from filename and stores 128 bits per line into data
20 void createTable(HashTable HTSet[], bool data[][128], bool extDat[][n][k] );  // draws each table's k random dimensions and coefficients, then copies the sampled columns of data into extDat
21 void insert(HT HTSet[], bool (*extDat)[n][k]);  // groups item indices into buckets keyed by their k sampled bits
22 void standHash(HT HTSet[]);  // chains every bucket index into the slot of Hash2 selected by getPosition
23 void search(vector<int>& record, bool query[128], HT HTSet[]);  // appends to record the item indices of buckets whose key equals the query's key
24 /*int getPosition(int V[], std::string s, int N);*/
25 
26 #endif

HashTable.h

#include <string>
#include <vector>

// k: number of sampled dimensions per table (length of HashTable::R and HashTable::RNum).
// l: number of hash tables (length of the HTSet array).
// n: number of data rows (first extent of data and extDat).
// M: number of slots in HashTable::Hash2 and the modulus used by getPosition; fixed equal to n.
enum{ k = 15, l = 1, n = 587329, M = n};

// One group of items that share the same sampled-bit key.
struct bucket
{
    // The key as a string of '0'/'1' characters.
    std::string key;

    // Indices of the items stored under this key.
    std::vector<int> elem;
};

// Node of a singly linked chain of integer values; a freshly constructed
// node is marked unused, holds 0 and has no successor.
struct INT
{
    bool used;   // true once val carries a real entry
    int val;     // stored value
    INT *next;   // following node in the chain, or NULL

    INT()
    {
        used = false;
        val = 0;
        next = NULL;
    }
};

// One hash table: the sampled dimensions, the coefficients of the secondary
// hash, the buckets themselves and the slot array that indexes them.
typedef struct HashTable 
{
    int R[k];          // the k sampled dimension indices (createTable fills them with rand() % 128)
    int RNum[k];   //  k random coefficients less than M, used as the weights in getPosition
    //string DC;          // the contents of k dimensions 
    std::vector<bucket> BukSet;  // buckets in creation order (appended by insert)
    INT Hash2[M];  // slot heads; each chain stores indices into BukSet (filled by standHash)
} HT;

getPosition.h

#include <string>
/**
 * Maps a key made of digit characters to a slot index in [0, M).
 *
 * Computes the sum of V[col] * (s[col] - '0') over the first N characters of
 * s, reducing modulo M after every term.
 *
 * @param V  coefficients, at least N entries.
 * @param s  key string; taken by const reference so that no copy is made per
 *           call (every existing call site still compiles unchanged).
 * @param N  number of leading characters of s to hash. The previous version
 *           ignored this argument and always used the global k; all callers
 *           pass k, so their results are unchanged.
 * @return   the slot index.
 */
inline int getPosition(int V[], const std::string &s, int N)
{
	int position = 0;
	// Never index past the end of a key that is shorter than N.
	const int limit = static_cast<int>(s.size()) < N ? static_cast<int>(s.size()) : N;
	for(int col = 0; col < limit; ++col)
	{
		position += V[col] * (s[col] - '0');
		position %= M;
	}
	return position;
}

 computeDistance.h

// Counts the positions among the first N entries at which v1 and v2 differ.
inline int distance(bool v1[], bool v2[], int N)
{
	int mismatches = 0;
	int idx = 0;
	while(idx < N)
	{
		if(v1[idx] != v2[idx])
			++mismatches;
		++idx;
	}
	return mismatches;
}

 main.cpp

#include "Head.h"
#include "D:\\LiYangGuang\\VSPRO\\MYLSH\\computeDistance.h"
using namespace std;
// Number of query vectors loaded from the query file and searched in main().
const int MAX_Q = 1000; 

// The l hash tables.
HT HTSet[l];

// The n data vectors, 128 boolean components each (filled by loadData).
bool data[n][128];
// For every table, the k sampled components of each data vector (filled by createTable).
bool extDat[l][n][k];

bool query[MAX_Q][128]; // the MAX_Q query vectors, 128 boolean components each

int main(int argc, char *argv)
{
	/************************************************************************/
	/*             Firstly, create the HashTables                           */
	/************************************************************************/
	char *filename = "D:\\LiYangGuang\\VSPRO\\MYLSH\\data.txt";
	loadData(data, n, filename);
	createTable(HTSet, data, extDat);
	insert(HTSet,extDat);
	standHash(HTSet);

	/************************************************************************/
	/*              Secondly, start the LSH search                          */
	/************************************************************************/

	char *queryFile = "D:\\LiYangGuang\\VSPRO\\MYLSH\\query.txt";
	loadData(query, MAX_Q, queryFile);
	clock_t time0 = clock();
	for(int qId = 0; qId < MAX_Q; ++qId)
	{
		vector<int> record;
		clock_t timeA = clock();
		search(record, query[qId], HTSet);
		set<int> Dis;
		for(size_t i = 0; i < record.size(); ++i)
			Dis.insert(distance(data[record[i]], query[qId]));
		clock_t timeB = clock();
		cout << "第 " << qId + 1 << " 次查询时间:" << timeB - timeA << endl;
	}
	clock_t time1 = clock();
	cout << "总查询时间:" << time1 - time0 << endl;


    return 0;

}

 loadData.cpp

#include <string>
#include <fstream>

/**
 * Reads n text lines from filename and stores, for each of the first 128
 * characters of a line, the low bit of (character - '0') into data[row].
 *
 * The previous version indexed 128 characters into every line without
 * checking that the file opened, that the read succeeded, or that the line
 * was long enough. Now any character that is not available (file could not
 * be opened, fewer than n lines, line shorter than 128 characters) is stored
 * as false. The function returns void, so a failure cannot be reported to
 * the caller through the existing interface.
 *
 * @param data      destination, at least n rows of 128 entries.
 * @param n         number of rows to fill.
 * @param filename  path of the text file to read.
 */
void loadData(bool (*data)[128], int n, char* filename)
{
	std::ifstream ifs(filename, std::ios::in);
	std::string line;
	for(int row = 0; row < n; ++row)
	{
		// A failed read can leave `line` holding the previous row's text,
		// so drop it explicitly.
		if(!std::getline(ifs, line))
			line.clear();

		const std::string::size_type available = line.size();
		for(int col = 0; col < 128; ++col)
		{
			const bool present = static_cast<std::string::size_type>(col) < available;
			data[row][col] = present && (((line[col] - '0') & 1) != 0);
		}
	}
	// ifs is closed by its destructor.
}

 creatTable.cpp

#include "HashTable.h"
#include <ctime>

// Seeds the C random generator from the current time, then for every table
// draws k (dimension, coefficient) pairs and copies the chosen dimension of
// each of the n data rows into extDat.
void createTable(HT HTSet[], bool data[][128], bool extDat[][n][k] )
{
	srand((unsigned)time(NULL));

	int tableNum = 0;
	while(tableNum < l)
	{
		HT &table = HTSet[tableNum];
		for(int slot = 0; slot != k; ++slot)
		{
			// Dimension first, coefficient second: the draw order is kept.
			const int dim = rand() % 128;
			table.R[slot] = dim;
			table.RNum[slot] = rand() % M;

			for(int row = 0; row != n; ++row)
				extDat[tableNum][row][slot] = data[row][dim];
		}
		++tableNum;
	}
}

insertData.cpp

#include "HashTable.h"
#include <iostream>
#include <map>
using namespace std;

// Maps a bucket key to the index of that bucket within the table currently
// being filled; insert() clears it after finishing each table.
map<string, int> deRepeat;
/**
 * Tells whether the first n entries of V and V2 are pairwise equal.
 *
 * The previous loop never advanced its index, so it spun forever whenever
 * the first pair matched.
 *
 * @return true when no differing position exists among the first n entries
 *         (including n <= 0), false otherwise.
 */
bool equal(bool V[], bool V2[], int n)
{
	for(int i = 0; i < n; ++i)
	{
		if(V[i] != V2[i])
			return false;
	}
	return true;
}

// Appends one character per entry of v ('1' for true, '0' for false) to s,
// for the first n entries, and returns the extended string.
std::string itoa(bool *v, int n, std::string s)
{
	int i = 0;
	while(i < n)
	{
		s.append(1, v[i] ? '1' : '0');
		++i;
	}
	return s;
}

/**
 * Groups the n items of every table into buckets keyed by their k sampled
 * bits: an item whose key was already seen is appended to that bucket,
 * otherwise a new bucket is created for it.
 *
 * Changes over the previous version:
 *  - one map lookup per item instead of two;
 *  - the trace lines end with '\n' instead of std::endl, so the stream is no
 *    longer flushed twice per item; the printed text is unchanged, and the
 *    first item still produces no trace;
 *  - a new bucket's index is taken from its real position in BukSet instead
 *    of a counter that assumed BukSet was empty on entry;
 *  - the first item is handled by the same loop as all others.
 *
 * @param HTSet   the l hash tables whose BukSet is filled.
 * @param extDat  per table, the k sampled bits of each of the n items.
 */
void insert(HT HTSet[], bool (*extDat)[n][k])
{
	for(int t = 0; t < l; ++t) /* t: table */
	{
		std::vector<bucket> &buckets = HTSet[t].BukSet;
		for(int item = 0; item < n; ++item)
		{
			const string key = itoa(extDat[t][item], k, string(""));
			const map<string, int>::iterator hit = deRepeat.find(key);
			const bool known = hit != deRepeat.end();
			if(known)
			{
				buckets[hit->second].elem.push_back(item);
			}
			else
			{
				bucket fresh;
				fresh.key = key;
				fresh.elem.push_back(item);
				// The map value is the position the bucket is about to occupy.
				deRepeat.insert(make_pair(key, static_cast<int>(buckets.size())));
				buckets.push_back(fresh);
			}
			if(item > 0)
				cout << item << '\n' << (known ? "exist" : "creat") << '\n';
		}
		deRepeat.clear();
	}
}

 standHash.cpp

#include "HashTable.h"
#include <iostream>
#include "getPosition.h"

// For every table, registers the index of each bucket in the Hash2 slot
// chosen by getPosition for the bucket's key. An unused node takes the
// index directly; otherwise a new node is linked after the last node reached.
void standHash(HT HTSet[])
{
	for(int t = 0; t < l; ++t)
	{
		HT &table = HTSet[t];
		const int bucketCount = table.BukSet.size();
		for(int b = 0; b != bucketCount; ++b)
		{
			const int slot = getPosition(table.RNum, table.BukSet[b].key, k);

			// Advance while the current node is occupied and has a successor.
			INT *node = &table.Hash2[slot];
			for(; node->used && node->next != NULL; node = node->next)
			{
			}

			if(!node->used)
			{
				node->val = b;
				node->used = true;
				continue;
			}

			INT *appended = new INT;
			appended->val = b;
			appended->used = true;
			node->next = appended;
		}
		std::cout << "the " << t << "th HashTable has been finished." << std::endl;
	}
}

 search.cpp

#include "HashTable.h"
#include "getPosition.h"
#include <vector>
using namespace std;

/**
 * Collects the indices of all stored items whose sampled-bit key equals the
 * key of the query, over all l tables.
 *
 * For each table the query key is built from the table's k sampled
 * dimensions, the slot chain selected by getPosition is walked, and the
 * elements of every bucket whose key matches are appended to record.
 *
 * Changes over the previous version: buckets are inspected through a const
 * reference instead of being deep-copied (key and element vector) for every
 * candidate, and the chain is walked directly instead of first copying its
 * values into a temporary vector. The order of the appended indices is
 * unchanged.
 *
 * @param record  output; matching item indices are appended.
 * @param query   the 128-component query vector.
 * @param HTSet   the l hash tables.
 */
void search(vector<int>& record, bool query[128], HT HTSet[])
{
	for(int t = 0; t < l; ++t)
	{
		HT &table = HTSet[t];

		string temKey;
		temKey.reserve(k);
		for(int c = 0; c < k; ++c)
			temKey.push_back(query[table.R[c]] + '0');

		const int temPos = getPosition(table.RNum, temKey, k);
		for(const INT *p = &table.Hash2[temPos]; p != NULL && p->used; p = p->next)
		{
			const bucket &candidate = table.BukSet[p->val];
			// Different keys can share a slot, so compare the full key.
			if(temKey == candidate.key)
				record.insert(record.end(), candidate.elem.begin(), candidate.elem.end());
		}
	}
}

 

 

 

 稍后总结。

代码调整:

main.cpp

#include "Head.h"
#include "D:\\LiYangGuang\\VSPRO\\MYLSH\\MYLSH\\computeDistance.h"
#include <cstdio>
using namespace std;
#pragma warning(disable: 4996)
// Number of query vectors loaded from the query file and searched in main().
const int MAX_Q = 1000; 

// The l hash tables.
HT HTSet[l];

// The n data vectors, 128 boolean components each (filled by loadData).
bool data[n][128];
// For every table, the k sampled components of each data vector (filled by createTable).
bool extDat[l][n][k];

bool query[MAX_Q][128]; // the MAX_Q query vectors, 128 boolean components each

/**
 * Writes the decimal representation of v followed by ".txt" into FileName.
 *
 * Uses the standard std::sprintf instead of the non-standard
 * itoa(int, char*, int) plus strcat, so the function also builds on
 * toolchains that do not provide itoa.
 *
 * @param v         number that forms the file name stem.
 * @param FileName  destination buffer supplied by the caller; it must have
 *                  room for the digits, an optional sign, ".txt" and the
 *                  terminating NUL.
 */
void getFileName(int v, char *FileName)
{
	std::sprintf(FileName, "%d.txt", v);
}



int main(int argc, char *argv)
{
	/************************************************************************/
	/*             Firstly, create the HashTables                           */
	/************************************************************************/
	char *filename = "D:\\LiYangGuang\\VSPRO\\MYLSH\\data.txt";
	loadData(data, n, filename);
	createTable(HTSet, data, extDat);
	insert(HTSet,extDat);
	standHash(HTSet);

	char *queryFile = "D:\\LiYangGuang\\VSPRO\\MYLSH\\query.txt";
	loadData(query, MAX_Q, queryFile);
	/************************************************************************/
	/*               Secondly, start the linear Search                       */
// 	/************************************************************************/
// 
// 	vector<RECORD> record2;
// 	clock_t LineTime1 = clock();
// 	for(int qId = 0; qId < MAX_Q; ++qId)
// 	{
// 		for(int i = 0; i < n; ++i)
// 		{
// 			RECORD tem;
// 			tem.Id = i;
// 			tem.Dis = distance(data[i], query[qId]);
// 			record2.push_back(tem);
// 		}
// 		record2.clear();
// 	}
// 	clock_t LineTime2 = clock();
// 	float LineTime = (float)(LineTime2 - LineTime1) / CLOCKS_PER_SEC;
// 	cout << "全部线性查询时间:" << LineTime << " s," << " 合"
// 		<< LineTime / 60 << " minutes."<< endl;
// 
// 	/************************************************************************/
// 	/*              Thirdly, start the LSH search                          */
// 	/************************************************************************/
// 
// 	clock_t time0 = clock();
// 	ofstream ofs;
// 	char outFileName[10] = { '\0'};
// 	int K = 1; /// define KNN
// 	getFileName(K, outFileName);
// 	ofs.out(outFileName);
// 
// 	for(int qId = 0; qId < MAX_Q; ++qId)
// 	{
// 		vector<RECORD> record;
// 		clock_t timeA = clock();
// 		search(record, query[qId], HTSet, data);
// 		if(getkNN(record,K))
// 		clock_t timeB = clock();
// 		record.clear();
// 		cout << "第 " << qId + 1 << " 次查询时间:" << 
// 			(float)(timeB - timeA) / CLOCKS_PER_SEC << " s" << endl;
// 	}
// 	clock_t time1 = clock();
// 	cout << "总查询时间:" << (float)(time1 - time0) / CLOCKS_PER_SEC 
// 		<< " s." << endl;
/************************************************************************/
/*                                                                      */
/************************************************************************/
	ofstream ofs;
	char outFileName[10] = { '\0'};
	int K = 1; /// define KNN
	getFileName(K, outFileName);
	ofs.open(outFileName, ios::out);
	//ofs.precision(3);
	float TotalLinearTime, TotalLSHTime;
	TotalLinearTime = TotalLSHTime = 0;

	float TotalError = 0;
	int TotalMiss = 0;


	vector<RECORD> record2;
	for(int qId = 0; qId < MAX_Q; ++qId)
	{
		cout << "第 " << qId << " 次查询" << endl;
		clock_t LineTime1 = clock();
		for(int i = 0; i < n; ++i)
		{
			RECORD tem;
			tem.Id = i;
			tem.Dis = computeDistance(data[i], query[qId], 128);
			record2.push_back(tem);
		}
	   getkNN(record2); // 利用其对距离排序
	   clock_t LineTime2 = clock();
	   float LineTime = (float)(LineTime2 - LineTime1) / CLOCKS_PER_SEC;
	   TotalLinearTime += LineTime;

	/************************************************************************/
	/*              Thirdly, start the LSH search                          */
	/************************************************************************/

		vector<RECORD> record;
		clock_t timeA = clock();
		search(record, query[qId], HTSet, data);
		if(!getkNN(record, K)) 
		{
			float queryTime = (float)(clock() - timeA) / CLOCKS_PER_SEC;
			TotalLSHTime += queryTime;
			ofs << "Miss\t" << "LSH Time: " << queryTime 
				<< "s\tLinear time: " << LineTime << 's' << endl;
			TotalMiss += 1;
		}
		else{
			float queryTime = (float)(clock() - timeA) / CLOCKS_PER_SEC;
			TotalLSHTime += queryTime;
			float error = 0;
			if(record[K-1].Dis == 0)
				error = 1;
			else
				error = (float)record2[K-1].Dis / record[K-1].Dis;
			ofs << "Error: " << error << "\tLSH Time: " 
				<< queryTime << "s\tLinear time: " << LineTime << 's' << endl;
			TotalError += error;

		}
		record.clear();
		record2.clear();
	}
	ofs << "Average errror: " << TotalError / 817 << endl;//recitfy
	ofs << "Miss ratio: " << TotalMiss / MAX_Q << endl;
	ofs << "Total query time: " << "LSH, " << TotalLSHTime / 3600 << " h; "
		<< "Linear, " << TotalLinearTime / 3600 << " h." << endl;
	ofs.close();


	return 0;

}

 computeDistance.h

// Counts the positions among the first N entries at which v1 and v2 differ.
inline int computeDistance(bool v1[], bool v2[], int N)
{
	int mismatches = 0;
	int idx = 0;
	while(idx < N)
	{
		if(v1[idx] != v2[idx])
			++mismatches;
		++idx;
	}
	return mismatches;
}

 Search.cpp

#include "HashTable.h"
#include "getPosition.h"
#include "computeDistance.h"
#include <vector>
using namespace std;

/**
 * Collects, over all l tables, a RECORD (item index plus distance to the
 * query) for every stored item whose sampled-bit key equals the key of the
 * query.
 *
 * The data argument is needed only to compute those distances.
 *
 * Changes over the previous version: buckets are inspected through a const
 * reference instead of being deep-copied (key and element vector) for every
 * candidate, and the slot chain is walked directly instead of first copying
 * its values into a temporary vector. The order of the appended records is
 * unchanged.
 *
 * @param record  output; one RECORD per matching item is appended.
 * @param query   the 128-component query vector.
 * @param HTSet   the l hash tables.
 * @param data    the stored vectors, indexed by item id.
 */
void search(vector<RECORD>& record, bool query[128], HT HTSet[], bool data[][128])
{
	for(int t = 0; t < l; ++t)
	{
		HT &table = HTSet[t];

		string temKey;
		temKey.reserve(k);
		for(int c = 0; c < k; ++c)
			temKey.push_back(query[table.R[c]] + '0');

		const int temPos = getPosition(table.RNum, temKey, k);
		for(const INT *p = &table.Hash2[temPos]; p != NULL && p->used; p = p->next)
		{
			const bucket &candidate = table.BukSet[p->val];
			// Different keys can share a slot, so compare the full key.
			if(temKey != candidate.key)
				continue;

			for(size_t j = 0; j < candidate.elem.size(); ++j)
			{
				RECORD temp;
				temp.Id = candidate.elem[j];
				temp.Dis = computeDistance(data[temp.Id], query, 128);
				record.push_back(temp);
			}
		}
	}
}

 

相关截图:

 

posted on 2014-07-29 17:19  进阶之路  阅读(1060)  评论(0)  编辑  收藏  举报