您的位置:首页 > 其它

一种常见数据挖掘算法（SPRINT 算法）的简单实现

2011-07-16 02:07 417 查看
#include <cstdlib>
#include <math.h>

#include <algorithm>
#include <fstream>
#include <iostream>
#include <list>
#include <map>
#include <set>
#include <string>
#include <vector>
using namespace std;
// Split thresholds keyed by attribute column index (attrelem::attridx).
// Written while searching for the best split, read back when the split is applied.
std::map<int, float> xmap;
// One entry of an attribute list: the value of a single attribute for one
// training row, tagged with that row's class label and row id.
struct attrelem
{
    int attridx;        // column index of the attribute this value belongs to
    float attrval;      // numeric attribute value (parsed with atof)
    string classlabel;  // class label of the row this value comes from
    int    rid;         // row id: index of the row in the training set

    // attrval used to be left uninitialized; every field now has a defined value.
    attrelem():attridx(0),attrval(0.0f),rid(0){}
};
// Two attribute-list entries compare equal when their value, class label and
// row id all match; attridx is deliberately not part of the comparison.
bool operator==(const attrelem &lhs,const attrelem &rhs)
{
    if (lhs.attrval != rhs.attrval)
        return false;
    if (lhs.classlabel != rhs.classlabel)
        return false;
    return lhs.rid == rhs.rid;
}

// A node of the decision tree. A non-empty category marks the node as a leaf;
// otherwise rows whose value in column splittattridx is <= x go to `left`,
// and all other rows go to `right`.
typedef struct treenode
{
    string category;                        // class label; non-empty only for leaves
    vector< vector<attrelem> > vecattrlist; // attribute lists seen at this node
    float x;                                // split threshold
    int splittattridx;                      // column index of the split attribute, -1 if unset
    treenode *left;
    treenode *right;

    // Initializers follow declaration order (no -Wreorder warning).
    // x used to be left uninitialized.
    treenode():x(0.0f),splittattridx(-1),left(NULL),right(NULL){}
} *node_ptr;
// Midpoint of two attribute values, used as a candidate split threshold.
float getx(float x1,float x2)
{
    const float sum = x1 + x2;
    return sum / 2.0;
}
// Removes every ' ' character from str, in place. Other whitespace (tabs, '\r')
// is left untouched.
void trimspace(string & str)
{
    // Erase-remove idiom. The previous hand-written loop advanced the iterator
    // after erase(), so it skipped the character following each removed space
    // ("a  b" became "a b"). It also stepped past end() when the string ended
    // in a space, which is undefined behavior.
    str.erase(std::remove(str.begin(), str.end(), ' '), str.end());
}
// Splits str on every occurrence of splitter and appends each field to vecout
// (existing contents are kept). Empty fields are preserved, so "" produces one
// empty field and "a," produces "a" and "".
void getsplitstring(string str,char splitter,vector<string> & vecout)
{
    string::size_type start = 0;
    for (;;)
    {
        const string::size_type hit = str.find(splitter, start);
        if (hit == string::npos)
        {
            vecout.push_back(str.substr(start));
            return;
        }
        vecout.push_back(str.substr(start, hit - start));
        start = hit + 1;
    }
}
// Returns x multiplied by itself.
float square(float x)
{
    const float value = x;
    return value * value;
}
// Returns the index of the attribute list that contains the element `ret`
// points to (compared by address), or -1 if no list contains it.
int getisplitter(attrelem *ret,vector< vector<attrelem> > &vecattrlist)
{
    for (size_t list = 0; list < vecattrlist.size(); ++list)
    {
        vector<attrelem> &entries = vecattrlist[list];
        for (size_t pos = 0; pos < entries.size(); ++pos)
        {
            if (&entries[pos] == ret)
                return static_cast<int>(list);
        }
    }
    return -1;
}
float getgini(float x,vector<attrelem> & attrlist)
{

vector<string> lessequalval;
vector<string>  greaterlval;
for (int i = 0;i<attrlist.size();i++)
{
if (attrlist[i].attrval<=x)
{
lessequalval.push_back(attrlist[i].classlabel);

}
else
{
greaterlval.push_back(attrlist[i].classlabel);
}
}

map<string,int> catecnt;
for (int i=0;i<lessequalval.size();i++)
{
map<string,int>::iterator it = catecnt.find(lessequalval[i]);
if (it == catecnt.end())
{
catecnt.insert(make_pair(lessequalval[i],0));
catecnt[lessequalval[i]]++;
}
else
{
catecnt[lessequalval[i]]++;
}
}

float lessequalgini = 0.0;
for (map<string,int>::iterator it = catecnt.begin();
it != catecnt.end();it++)
{
lessequalgini = lessequalgini + square((float)it->second/(float)lessequalval.size());
}
lessequalgini = 1 - lessequalgini;

catecnt.clear();
for (int i=0;i<greaterlval.size();i++)
{
map<string,int>::iterator it = catecnt.find(greaterlval[i]);
if (it == catecnt.end())
{
catecnt.insert(make_pair(greaterlval[i],0));
catecnt[greaterlval[i]]++;
}
else
{
catecnt[greaterlval[i]]++;
}
}

float greatergini = 0.0;
for (map<string,int>::iterator it = catecnt.begin();
it != catecnt.end();it++)
{
greatergini = greatergini + square((float)it->second/(float)greaterlval.size());
}
greatergini = 1 - greatergini;

float gini = ((float)lessequalval.size()/(float)attrlist.size()) * lessequalgini + ((float)greaterlval.size()/(float)attrlist.size()) * greatergini;
return gini;

}

int getmingini(vector< vector<attrelem> > &vecattrlist)
{

int isplitter = -1;
float mingini = getgini(getx(vecattrlist[0][0].attrval,vecattrlist[0][1].attrval),vecattrlist[0]);
isplitter = 0;

for (int i = 0;i<vecattrlist.size();i++)
{

for (int j=0;j<vecattrlist[i].size() -1 ;j++)
{
if (vecattrlist[i][j].attrval != vecattrlist[i][j+1].attrval)
{
float x = getx(vecattrlist[i][j].attrval,vecattrlist[i][j+1].attrval);
float gini = getgini(x,vecattrlist[i]);
if (mingini > gini)
{
mingini = gini;
isplitter = i;
xmap[vecattrlist[i][j].attridx] = x;
}
}
}

}
return isplitter;

}
bool isnodepure(vector<attrelem> & attrlist)
{

bool pure = true;
string classlabel = attrlist[0].classlabel;
for (int i = 0;i<attrlist.size();i++)
{
if (classlabel != attrlist[i].classlabel)
{
pure = false;
break;
}

}
return pure;
}

// Returns an iterator to the attribute list at index isplitter, or
// vecattrlist.end() when the index is out of range. The old version fell off
// the end of the function without a return in that case, which is undefined
// behavior.
vector< vector<attrelem> >::iterator getit(int isplitter, vector< vector<attrelem> > &vecattrlist)
{
    if (isplitter < 0 || static_cast<size_t>(isplitter) >= vecattrlist.size())
        return vecattrlist.end();
    return vecattrlist.begin() + isplitter;
}
int splitattrlist(vector< vector<attrelem> > &vecattrlist,vector< vector<attrelem> > &left,vector< vector<attrelem> > &right)
{

int isplitter = getmingini(vecattrlist);
int ret = vecattrlist[isplitter][0].attridx;

while(left.empty() || right.empty() || left[0].empty() || right[0].empty())
{

if (!left.empty() && !right.empty() && (!left[0].empty() || !right[0].empty()) )
{
vecattrlist.erase(getit(isplitter,vecattrlist));

isplitter = getmingini(vecattrlist);
}
left.clear();
right.clear();
for (int i = 0;i<vecattrlist.size();i++)
{
vector<attrelem> attrleft, attrright;
if (i == isplitter)
continue;

float x = xmap[vecattrlist[isplitter][0].attridx];
for (int j = 0;j<vecattrlist[i].size();j++)
{
if (vecattrlist[isplitter][j].attrval <= x /*&& vecattrlist[isplitter][j].rid == vecattrlist[i][j].rid*/)
{
attrleft.push_back(vecattrlist[i][j]);
}
else if (vecattrlist[isplitter][j].attrval > x /*&& vecattrlist[isplitter][j].rid == vecattrlist[i][j].rid*/)
attrright.push_back(vecattrlist[i][j]);
}

left.push_back(attrleft);
right.push_back(attrright);

}
}
return ret;

}

void buildtree(vector< vector<attrelem> > &vecattrlist,node_ptr *tree)
{
if (vecattrlist.size() == 0)
return;
if (isnodepure(vecattrlist[0]))
{
(*tree) = new treenode();
(*tree)->category = vecattrlist[0][0].classlabel;
(*tree)->left = NULL;
(*tree)->right = NULL;
return;

}

vector< vector<attrelem> >left,right;
vector< vector<attrelem> > vecret;

(*tree) = new treenode();
(*tree)->vecattrlist = vecattrlist;
(*tree)->splittattridx = splitattrlist(vecattrlist,left,right);
(*tree)->x = xmap[(*tree)->splittattridx];

buildtree(left,&(*tree)->left);

buildtree(right,&(*tree)->right);

}
// Returns true if vallist contains an element that compares == to val.
bool isvalinlist(float val,vector<float> & vallist)
{
    return std::find(vallist.begin(), vallist.end(), val) != vallist.end();
}
// Classifies one data row (already split into fields) by walking the tree
// from `tree` down to a leaf, and returns the leaf's class label.
//
// At each internal node the field in column splittattridx is parsed with atof.
// Values <= x go left; all others go right.
//
// Returns "" for:
// - a NULL tree or a missing child;
// - a row too short to contain the split column (this used to index data out
//   of range, e.g. for a blank input line).
//
// The row is now taken by const reference instead of being copied at every level.
string gettype(const vector<string> &data,node_ptr tree)
{
    while (tree)
    {
        if (!tree->category.empty())
            return tree->category;

        if (tree->splittattridx < 0 ||
            static_cast<size_t>(tree->splittattridx) >= data.size())
            return "";

        const float val = atof(data[tree->splittattridx].c_str());
        tree = (val <= tree->x) ? tree->left : tree->right;
    }
    return "";
}

// Prints the tree to stdout in pre-order, one node per line, indented by one
// space per level.
// Internal nodes print "<split column + 1>:<threshold>" and then their left
// and right subtrees. Leaves print "Class:<label>".
void outputtree(treenode *root,int level)
{
    if (root == NULL)
        return;

    for (int pad = level; pad > 0; --pad)
        cout << " ";

    if (!root->category.empty())
    {
        cout << "Class:" << root->category << endl;
        return;
    }

    cout << root->splittattridx + 1 << ":<" << root->x << ">" << endl;
    outputtree(root->left, level + 1);
    outputtree(root->right, level + 1);
}
// Strict weak ordering of attribute-list entries by attribute value only.
bool lessthan(const attrelem & left,const attrelem & right)
{
    return left.attrval < right.attrval;
}
// Usage: prog <test file> <train file>
//
// Trains the tree on the comma-separated training file. Columns 0..15 are
// numeric attributes and column 16 is the class label. Then it prints one
// predicted class per non-blank line of the test file.
//
// Exit status: 0 on success, 1 on error. The old code returned false (0) on
// errors and true (1) on success.
int main(int argc, char* argv[])
{
    const int attrcount = 16;   // attribute columns 0..attrcount-1
    const int classcol = 16;    // column holding the class label

    if (argc < 3)
    {
        cout<<"Usage: "<<argv[0]<<" [filename].test [filename].train"<<endl;
        return 1;
    }

    ifstream input(argv[2],ios_base::in);
    if (!input)
    {
        cerr<<"Error: open file failed."<<endl;
        return 1;
    }

    // std::getline has no fixed line-length limit. The old 512-byte buffer
    // set failbit on longer lines and silently stopped reading.
    vector< vector<string> > vectrain;
    string line;
    while (getline(input, line))
    {
        if (!line.empty() && line[line.size() - 1] == '\r')
            line.erase(line.size() - 1);   // tolerate CRLF files
        trimspace(line);
        if (line.empty())
            continue;
        vector<string> vecout;
        getsplitstring(line,',',vecout);
        // Rows without a class column would be indexed out of range below.
        if (vecout.size() <= static_cast<size_t>(classcol))
            continue;
        vectrain.push_back(vecout);
    }

    if (vectrain.empty())
    {
        cerr<<"Error: no usable training rows."<<endl;
        return 1;
    }

    // One attribute list per column, all in the same row order
    // (splitattrlist relies on that order).
    vector< vector<attrelem> > vecattrlist;
    for (int j = 0; j < attrcount; j++)
    {
        vector<attrelem> attrlist;
        for (size_t i = 0; i < vectrain.size(); i++)
        {
            attrelem elem;
            elem.attrval = atof(vectrain[i][j].c_str());
            elem.classlabel = vectrain[i][classcol];
            elem.rid = static_cast<int>(i);
            elem.attridx = j;
            attrlist.push_back(elem);
        }
        vecattrlist.push_back(attrlist);
    }

    node_ptr root = NULL;
    buildtree(vecattrlist,&root);

    //outputtree(root,0);

    ifstream zooinput(argv[1]);
    if (!zooinput)
    {
        cerr<<"Error: open file failed."<<endl;
        return 1;
    }

    while (getline(zooinput, line))
    {
        if (!line.empty() && line[line.size() - 1] == '\r')
            line.erase(line.size() - 1);
        trimspace(line);
        if (line.empty())
            continue;   // blank lines used to be classified out of range
        vector<string> vecout;
        getsplitstring(line,',',vecout);

        string type = gettype(vecout,root);
        cout<<type<<endl;
    }
    return 0;
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: