您的位置:首页 > 其它

实习日记:图像检索算法 LSH 的总结与分析

2014-07-29 17:19 381 查看
先贴上这两天刚出炉的 C++ 代码(借助 STL 省了不少功夫,代码尚待优化)。

Head.h

#ifndef HEAD_H
#define HEAD_H

#include "D:\\LiYangGuang\\VSPRO\\MYLSH\\HashTable.h"

#include <iostream>
#include <fstream>
#include <time.h>
#include <cstdlib>
#include <vector>
#include <map>
#include <set>
#include <string>

using namespace std;

void loadData(bool (*data)[128], int n, char *filename);
void createTable(HashTable HTSet[], bool data[][128], bool extDat[]
[k] );
void insert(HT HTSet[], bool (*extDat)
[k]);
void standHash(HT HTSet[]);
void search(vector<int>& record, bool query[128], HT HTSet[]);
/*int getPosition(int V[], std::string s, int N);*/

#endif


HashTable.h

#include <string>
#include <vector>

// LSH parameters shared by every file that includes this header:
//   k - number of sampled bit positions per table (the length of a bucket key)
//   l - number of hash tables
//   n - number of data vectors (rows read from the data file)
//   M - number of slots in each table's second-level hash array (set equal to n)
enum{ k = 15, l = 1, n = 587329, M = n};

// A first-level bucket: the data vectors whose k sampled bits (the key) are identical.
typedef struct
{
std::string key;        // the k sampled bits as a string of '0'/'1' characters
std::vector<int> elem; // indices (rows of the data array) of the vectors in this bucket
} bucket;

// Node of a second-level hash slot (HashTable::Hash2). Bucket indices that hash
// to the same slot are chained through `next`.
struct INT
{
bool used;          // true once this node stores a bucket index
int val;            // the stored bucket index (into HashTable::BukSet)
struct INT * next;  // next node of the collision chain, NULL at the tail
INT() : used(false), val(0), next(NULL){}
};

// One LSH hash table: k sampled bit positions, the first-level buckets, and a
// second-level hash array (Hash2) that maps a key's position to bucket indices.
// Hash2 holds M nodes, so an HT is large; the program keeps its tables in a
// global array rather than on the stack.
typedef struct HashTable
{
int R[k];          // k random dimensions (bit positions in [0, 128), chosen in createTable)
int RNum[k];   //  random multipliers less than M, used by getPosition()
//string DC;          // the contents of k dimensions
std::vector<bucket> BukSet;  // first-level buckets, one per distinct key
INT Hash2[M];  // second-level hash array, indexed by getPosition(RNum, key, k)
} HT;


getPosition.h

#include <string>
/*
 * Maps a bucket key to a slot of the second-level hash array Hash2.
 * Returns (sum over col < N of V[col] * (s[col] - '0')) mod M, reducing after
 * every term; with every V[col] < M the running value stays below 2*M, so it
 * cannot overflow an int.
 *   V - the table's random multipliers (RNum); at least N entries
 *   s - the key; at least N characters, each '0' or '1'
 *   N - number of key positions to hash (callers pass k); previously ignored
 *       in favour of the global k
 */
inline int getPosition(int V[], const std::string& s, int N)
{
    int position = 0;
    for (int col = 0; col < N; ++col)
        position = (position + V[col] * (s[col] - '0')) % M;
    return position;
}


computeDistance.h

/*
 * Hamming distance between two binary vectors: how many of the first N
 * positions hold different values in v1 and v2.
 */
inline int distance(bool v1[], bool v2[], int N)
{
    int mismatches = 0;
    int i = 0;
    while (i < N)
    {
        mismatches += (v1[i] != v2[i]) ? 1 : 0;
        ++i;
    }
    return mismatches;
}


main.cpp

#include "Head.h"
#include "D:\\LiYangGuang\\VSPRO\\MYLSH\\computeDistance.h"
using namespace std;
// length of sub hashtable, as well the number of elements.
const int MAX_Q = 1000;

HT HTSet[l];

bool data
[128];
bool extDat[l]
[k];

bool query[MAX_Q][128]; // set the query item to 1000.

int main(int argc, char *argv[])
{
    /************************************************************************/
    /*             Firstly, create the HashTables                           */
    /************************************************************************/
    // Writable arrays rather than char* pointing at string literals: loadData()
    // takes a non-const char*, and binding a literal to char* is ill-formed since C++11.
    char filename[] = "D:\\LiYangGuang\\VSPRO\\MYLSH\\data.txt";
    loadData(data, n, filename);
    createTable(HTSet, data, extDat);
    insert(HTSet, extDat);
    standHash(HTSet);

    /************************************************************************/
    /*              Secondly, start the LSH search                          */
    /************************************************************************/
    char queryFile[] = "D:\\LiYangGuang\\VSPRO\\MYLSH\\query.txt";
    loadData(query, MAX_Q, queryFile);
    clock_t time0 = clock();
    for (int qId = 0; qId < MAX_Q; ++qId)
    {
        vector<int> record;
        clock_t timeA = clock();
        search(record, query[qId], HTSet);
        // Hamming distances of the LSH candidates. The call is qualified and passes
        // the vector length: the old two-argument call distance(a, b) matched
        // std::distance (visible through `using namespace std`) instead of the
        // three-parameter ::distance, and returned a meaningless pointer difference.
        set<int> Dis;
        for (size_t i = 0; i < record.size(); ++i)
            Dis.insert(::distance(data[record[i]], query[qId], 128));
        clock_t timeB = clock();
        // Report seconds (not raw clock ticks).
        cout << "第 " << qId + 1 << " 次查询时间:"
             << static_cast<float>(timeB - timeA) / CLOCKS_PER_SEC << " s" << endl;
    }
    clock_t time1 = clock();
    cout << "总查询时间:" << static_cast<float>(time1 - time0) / CLOCKS_PER_SEC
         << " s" << endl;

    return 0;
}


loadData.cpp

#include <string>
#include <fstream>

/*
 * Loads n rows of 128 binary features from the text file `filename`: one row
 * per line, one character per feature, stored as (c - '0') & 1, so '0' -> false
 * and '1' -> true.
 * A feature with no character behind it is stored as false. This covers lines
 * shorter than 128 characters, files with fewer than n lines, and files that
 * cannot be opened. Previously such features were read past the end of the line.
 */
void loadData(bool (*data)[128], int n, char* filename)
{
    std::ifstream ifs(filename);   // closed automatically at the end of the scope
    for (int row = 0; row < n; ++row)
    {
        std::string line;
        std::getline(ifs, line);   // `line` stays empty once the stream has failed
        for (int col = 0; col < 128; ++col)
        {
            const bool present = static_cast<std::string::size_type>(col) < line.size();
            data[row][col] = present && ((line[col] - '0') & 1);
        }
    }
}


creatTable.cpp

#include "HashTable.h"
#include <ctime>

/*
 * Initialises the random parameters of every hash table and projects the data.
 * For each table, k bit positions R[j] are drawn from [0, 128) (with
 * replacement), each followed by one multiplier RNum[j] = rand() % M that is
 * later used by getPosition(). Every data vector is then projected onto the
 * sampled positions: extDat[t][item][j] = data[item][R[j]].
 * rand() is reseeded from the wall clock, so every run builds different tables.
 */
void createTable(HT HTSet[], bool data[][128], bool extDat[][n][k])
{
    srand((unsigned)time(NULL));
    for (int tableNum = 0; tableNum < l; ++tableNum)
    {
        HT& table = HTSet[tableNum];
        for (int randNum = 0; randNum < k; ++randNum)
        {
            const int dim = rand() % 128;
            table.R[randNum] = dim;
            // NOTE(review): rand() never exceeds RAND_MAX, which is only 32767 on
            // MSVC, so there RNum covers just [0, 32767] rather than all of [0, M).
            table.RNum[randNum] = rand() % M;

            for (int item = 0; item < n; ++item)
                extDat[tableNum][item][randNum] = data[item][dim];
        }
    }
}


insertData.cpp

#include "HashTable.h"
#include <iostream>
#include <map>
using namespace std;

// Scratch map used by insert(): bucket key ('0'/'1' string) -> index of that
// key's bucket in the current table's BukSet. insert() clears it after each table.
map<string, int> deRepeat;
/*
 * Returns true when the first n flags of V and V2 are identical (true for n <= 0).
 * The previous version never advanced its index and looped forever whenever
 * the first flags matched.
 */
bool equal(bool V[], bool V2[], int n)
{
    for (int i = 0; i < n; ++i)
    {
        if (V[i] != V2[i])
            return false;
    }
    return true;
}

/*
 * Appends one character per flag ('1' for true, '0' for false) for the first
 * n entries of v to s, and returns the extended string.
 */
string itoa(bool *v, int n, string s)
{
    int pos = 0;
    while (pos < n)
    {
        s += v[pos] ? '1' : '0';
        ++pos;
    }
    return s;
}

void insert(HT HTSet[], bool (*extDat)
[k])
{
for(int t = 0; t < l; ++ t) /* t: table */
{
int bktNum = 0;
bucket bkt;
bkt.key = string(itoa(extDat[t][0], k, string("")));
bkt.elem.push_back(0);
HTSet[t].BukSet.push_back(bkt);
deRepeat.insert(make_pair(bkt.key, bktNum++)); // 0 为 bucket 的位置
for(int item = 1; item < n; ++item)
{
cout << item << endl;
string key = itoa(extDat[t][item], k, string(""));
//map<string, int>::iterator it = deRepeat.find(key);
if(deRepeat.find(key) != deRepeat.end())
{
HTSet[t].BukSet[deRepeat.find(key)->second].elem.push_back(item);
cout << "exist" << endl;
}
else{
bucket bkt2;
bkt2.key = key;
bkt2.elem.push_back(item);
HTSet[t].BukSet.push_back(bkt2);
deRepeat.insert(make_pair(bkt2.key, bktNum++));
cout << "creat" << endl;
}
}
deRepeat.clear();
}
}


standHash.cpp

#include "HashTable.h"
#include <iostream>
#include "getPosition.h"

/*
 * Second-level hashing. For every table, each bucket index b is stored in the
 * Hash2 slot given by getPosition(RNum, key of bucket b, k). An unused slot takes
 * b directly. An occupied slot gets a new heap-allocated node (new INT) appended
 * to the end of its collision chain.
 * Prints a progress line after each table is finished.
 */
void standHash(HT HTSet[])
{
    for (int t = 0; t < l; ++t)
    {
        HT& table = HTSet[t];
        const int bucketCount = static_cast<int>(table.BukSet.size());
        for (int b = 0; b < bucketCount; ++b)
        {
            INT* slot = &table.Hash2[getPosition(table.RNum, table.BukSet[b].key, k)];
            if (slot->used)
            {
                // Every node after the head is in use; walk to the tail and append.
                while (slot->next != NULL)
                    slot = slot->next;
                slot->next = new INT;
                slot = slot->next;
            }
            slot->val = b;
            slot->used = true;
        }
        std::cout << "the " << t << "th HashTable has been finished." << std::endl;
    }
}


search.cpp

#include "HashTable.h"
#include "getPosition.h"
#include <vector>
using namespace std;

void search(vector<int>& record, bool query[128], HT HTSet[])
{
for(int t = 0; t < l; ++t)
{
string temKey;
int temPos = 0;
for(int c = 0; c < k; ++c)
temKey.push_back(query[HTSet[t].R[c]] + '0');
temPos = getPosition(HTSet[t].RNum, temKey, k);
vector<int> bktId;
INT *p = &HTSet[t].Hash2[temPos];
while(p != NULL && p->used)
{
bktId.push_back(p->val);
p = p->next;
}
for(size_t i = 0; i < bktId.size(); ++i)
{
bucket temB = HTSet[t].BukSet[bktId[i]];
if(temKey == temB.key)
{
for(size_t j = 0; j < temB.elem.size(); ++j)
record.push_back(temB.elem[j]);
}
}
}
}




稍后总结。

代码调整:

main.cpp

#include "Head.h"
#include "D:\\LiYangGuang\\VSPRO\\MYLSH\\MYLSH\\computeDistance.h"
using namespace std;
#pragma warning(disable: 4996)
// length of sub hashtable, as well the number of elements.
const int MAX_Q = 1000;

HT HTSet[l];

bool data
[128];
bool extDat[l]
[k];

bool query[MAX_Q][128]; // set the query item to 1000.

/*
 * Writes "<v>.txt" (v in decimal, with a leading '-' if negative) into FileName.
 * FileName must have room for the digits, ".txt" and the terminating NUL.
 * Uses std::to_string instead of itoa(), which is a non-standard MSVC extension.
 */
void getFileName(int v, char *FileName)
{
    const std::string name = std::to_string(v) + ".txt";
    strcpy(FileName, name.c_str());
}

int main(int argc, char *argv)
{
/************************************************************************/
/*             Firstly, create the HashTables                           */
/************************************************************************/
char *filename = "D:\\LiYangGuang\\VSPRO\\MYLSH\\data.txt";
loadData(data, n, filename);
createTable(HTSet, data, extDat);
insert(HTSet,extDat);
standHash(HTSet);

char *queryFile = "D:\\LiYangGuang\\VSPRO\\MYLSH\\query.txt";
loadData(query, MAX_Q, queryFile);
/************************************************************************/
/*               Secondly, start the linear Search                       */
// 	/************************************************************************/
//
// 	vector<RECORD> record2;
// 	clock_t LineTime1 = clock();
// 	for(int qId = 0; qId < MAX_Q; ++qId)
// 	{
// 		for(int i = 0; i < n; ++i)
// 		{
// 			RECORD tem;
// 			tem.Id = i;
// 			tem.Dis = distance(data[i], query[qId]);
// 			record2.push_back(tem);
// 		}
// 		record2.clear();
// 	}
// 	clock_t LineTime2 = clock();
// 	float LineTime = (float)(LineTime2 - LineTime1) / CLOCKS_PER_SEC;
// 	cout << "全部线性查询时间:" << LineTime << " s," << " 合"
// 		<< LineTime / 60 << " minutes."<< endl;
//
// 	/************************************************************************/
// 	/*              Thirdly, start the LSH search                          */
// 	/************************************************************************/
//
// 	clock_t time0 = clock();
// 	ofstream ofs;
// 	char outFileName[10] = { '\0'};
// 	int K = 1; /// define KNN
// 	getFileName(K, outFileName);
// 	ofs.out(outFileName);
//
// 	for(int qId = 0; qId < MAX_Q; ++qId)
// 	{
// 		vector<RECORD> record;
// 		clock_t timeA = clock();
// 		search(record, query[qId], HTSet, data);
// 		if(getkNN(record,K))
// 		clock_t timeB = clock();
// 		record.clear();
// 		cout << "第 " << qId + 1 << " 次查询时间:" <<
// 			(float)(timeB - timeA) / CLOCKS_PER_SEC << " s" << endl;
// 	}
// 	clock_t time1 = clock();
// 	cout << "总查询时间:" << (float)(time1 - time0) / CLOCKS_PER_SEC
// 		<< " s." << endl;
/************************************************************************/
/*                                                                      */
/************************************************************************/
ofstream ofs;
char outFileName[10] = { '\0'};
int K = 1; /// define KNN
getFileName(K, outFileName);
ofs.open(outFileName, ios::out);
//ofs.precision(3);
float TotalLinearTime, TotalLSHTime;
TotalLinearTime = TotalLSHTime = 0;

float TotalError = 0;
int TotalMiss = 0;

vector<RECORD> record2;
for(int qId = 0; qId < MAX_Q; ++qId)
{
cout << "第 " << qId << " 次查询" << endl;
clock_t LineTime1 = clock();
for(int i = 0; i < n; ++i)
{
RECORD tem;
tem.Id = i;
tem.Dis = computeDistance(data[i], query[qId], 128);
record2.push_back(tem);
}
getkNN(record2); // 利用其对距离排序
clock_t LineTime2 = clock();
float LineTime = (float)(LineTime2 - LineTime1) / CLOCKS_PER_SEC;
TotalLinearTime += LineTime;

/************************************************************************/
/*              Thirdly, start the LSH search                          */
/************************************************************************/

vector<RECORD> record;
clock_t timeA = clock();
search(record, query[qId], HTSet, data);
if(!getkNN(record, K))
{
float queryTime = (float)(clock() - timeA) / CLOCKS_PER_SEC;
TotalLSHTime += queryTime;
ofs << "Miss\t" << "LSH Time: " << queryTime
<< "s\tLinear time: " << LineTime << 's' << endl;
TotalMiss += 1;
}
else{
float queryTime = (float)(clock() - timeA) / CLOCKS_PER_SEC;
TotalLSHTime += queryTime;
float error = 0;
if(record[K-1].Dis == 0)
error = 1;
else
error = (float)record2[K-1].Dis / record[K-1].Dis;
ofs << "Error: " << error << "\tLSH Time: "
<< queryTime << "s\tLinear time: " << LineTime << 's' << endl;
TotalError += error;

}
record.clear();
record2.clear();
}
ofs << "Average errror: " << TotalError / 817 << endl;//recitfy
ofs << "Miss ratio: " << TotalMiss / MAX_Q << endl;
ofs << "Total query time: " << "LSH, " << TotalLSHTime / 3600 << " h; "
<< "Linear, " << TotalLinearTime / 3600 << " h." << endl;
ofs.close();

return 0;

}


computeDistance.h

/*
 * Hamming distance of two binary vectors: the count of indices in [0, N)
 * where v1 and v2 disagree.
 */
inline int computeDistance(bool v1[], bool v2[], int N)
{
    int differing = 0;
    for (int pos = 0; pos < N; ++pos)
    {
        if (v1[pos] != v2[pos])
            ++differing;
    }
    return differing;
}


Search.cpp

#include "HashTable.h"
#include "getPosition.h"
#include "computeDistance.h"
#include <vector>
using namespace std;

/***    加入 data 项是为了计算距离  ***/
void search(vector<RECORD>& record, bool query[128], HT HTSet[], bool data[][128])
{
for(int t = 0; t < l; ++t)
{
string temKey;
int temPos = 0;
for(int c = 0; c < k; ++c)
temKey.push_back(query[HTSet[t].R[c]] + '0');
temPos = getPosition(HTSet[t].RNum, temKey, k);
vector<int> bktId;
INT *p = &HTSet[t].Hash2[temPos];
while(p != NULL && p->used)
{
bktId.push_back(p->val);
p = p->next;
}
for(size_t i = 0; i < bktId.size(); ++i)
{
bucket temB = HTSet[t].BukSet[bktId[i]];
if(temKey == temB.key)
{
for(size_t j = 0; j < temB.elem.size(); ++j)
{
RECORD temp;
temp.Id = temB.elem[j];
temp.Dis = computeDistance(data[temp.Id], query, 128);
record.push_back(temp);
}

}
}
}
}




相关截图:



内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: