您的位置:首页 > 大数据 > 人工智能

人工智能(AI)之朴素贝叶斯(NB)的基本实现

2015-12-17 21:59 579 查看
训练集、测试集下载地址

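具体的公式我就不一一描述了，看下图大概就能理解。整体是基于条件概率实现的：对测试文本 $x$ 和情感 $e$，计算 $P(e \mid x) \propto \sum_{j} P(e \mid d_j) \prod_{w \in x} P(w \mid d_j)$，其中 $d_j$ 为第 $j$ 篇训练文本，$P(w \mid d_j)$ 在词 $w$ 出现于 $d_j$ 时取 $1/|d_j|$，否则用平滑常数代替；最后对六种情感归一化。文章最底下还有一个详细介绍的链接：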



#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <string.h>

#include <fstream>
#include <iostream>
#include <iterator>
#include <map>
#include <queue>
#include <set>
#include <sstream>
#include <string>
#include <vector>

using namespace std;
// Emotion codes: row indices into proba[][] and newproba[][].
#define ANGER 0
#define DISGUST 1
#define FEAR 2
#define JOY 3
#define SAD 4
#define SURPRISE 5
// Smoothing value used in place of a zero probability
// (name is presumably a misspelling of "laplace").
const double lapace = 0.09;

// 300-byte scratch line buffer.
char c[300];
// NOTE(review): q, map1, map2, Str1, Str2, dis_save, K and num1 are not
// referenced anywhere else in this file.
priority_queue<double,vector<double>,greater<double> >q;
map<double,int>map1; // ascending key order
map<double,int, greater<double> >map2; // descending key order; the space in "double> >" is required before C++11
const string Str1 = "train", Str2 = "test";
// Vocabulary; a word's column index in vector_old/vector2 is its position in
// this ordered set.
set<string> sets;
// vector_old[text][column]: true if that vocabulary word occurs in the text
// (filled by vector_out()).
bool vector_old[2000][4000];
// vector2: rows of vector_old divided by the row's word count (filled by compute_dis()).
double vector2[2000][4000];
// proba[emotion][train_row]: value read from <emotion>_train.txt.
double proba[9][2000];
// newproba[emotion][test_row]: normalised scores computed by compute_dis().
double newproba[9][2000];
double dis_save[2000];
double K;
int num1=0;

// Reads up to 246 rows of anger_train.txt into proba[ANGER][0..245].
// Each line is expected to start with a token that is ignored, followed by a
// number; that number is stored. A missing or unparsable number is stored as 0.
void readanger()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/anger_train.txt");
    std::string line;
    int i = 0;
    // Test getline() itself so a read that fails at end of file is never
    // stored as an extra (garbage) row.
    while (i < 246 && std::getline(in, line))
    {
        std::istringstream fields(line);
        std::string skipped;  // first column, not used
        double d = 0;         // stays 0 if the number is missing
        fields >> skipped >> d;
        proba[ANGER][i++] = d;
    }
}

// Reads up to 246 rows of disgust_train.txt into proba[DISGUST][0..245].
// Each line is expected to start with a token that is ignored, followed by a
// number; that number is stored. A missing or unparsable number is stored as 0.
void readdisgust()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/disgust_train.txt");
    std::string line;
    int i = 0;
    // Test getline() itself so a read that fails at end of file is never
    // stored as an extra (garbage) row.
    while (i < 246 && std::getline(in, line))
    {
        std::istringstream fields(line);
        std::string skipped;  // first column, not used
        double d = 0;         // stays 0 if the number is missing
        fields >> skipped >> d;
        proba[DISGUST][i++] = d;
    }
}

// Reads up to 246 rows of fear_train.txt into proba[FEAR][0..245].
// Each line is expected to start with a token that is ignored, followed by a
// number; that number is stored. A missing or unparsable number is stored as 0.
void readfear()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/fear_train.txt");
    std::string line;
    int i = 0;
    // Test getline() itself so a read that fails at end of file is never
    // stored as an extra (garbage) row.
    while (i < 246 && std::getline(in, line))
    {
        std::istringstream fields(line);
        std::string skipped;  // first column, not used
        double d = 0;         // stays 0 if the number is missing
        fields >> skipped >> d;
        proba[FEAR][i++] = d;
    }
}

// Reads up to 246 rows of joy_train.txt into proba[JOY][0..245].
// Each line is expected to start with a token that is ignored, followed by a
// number; that number is stored. A missing or unparsable number is stored as 0.
void readjoy()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/joy_train.txt");
    std::string line;
    int i = 0;
    // Test getline() itself so a read that fails at end of file is never
    // stored as an extra (garbage) row.
    while (i < 246 && std::getline(in, line))
    {
        std::istringstream fields(line);
        std::string skipped;  // first column, not used
        double d = 0;         // stays 0 if the number is missing
        fields >> skipped >> d;
        proba[JOY][i++] = d;
    }
}

// Reads up to 246 rows of sad_train.txt into proba[SAD][0..245].
// Each line is expected to start with a token that is ignored, followed by a
// number; that number is stored. A missing or unparsable number is stored as 0.
void readsad()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/sad_train.txt");
    std::string line;
    int i = 0;
    // Test getline() itself so a read that fails at end of file is never
    // stored as an extra (garbage) row.
    while (i < 246 && std::getline(in, line))
    {
        std::istringstream fields(line);
        std::string skipped;  // first column, not used
        double d = 0;         // stays 0 if the number is missing
        fields >> skipped >> d;
        proba[SAD][i++] = d;
    }
}

// Reads up to 246 rows of surprise_train.txt into proba[SURPRISE][0..245].
// Each line is expected to start with a token that is ignored, followed by a
// number; that number is stored. A missing or unparsable number is stored as 0.
void readsurprise()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/surprise_train.txt");
    std::string line;
    int i = 0;
    // Test getline() itself so a read that fails at end of file is never
    // stored as an extra (garbage) row.
    while (i < 246 && std::getline(in, line))
    {
        std::istringstream fields(line);
        std::string skipped;  // first column, not used
        double d = 0;         // stays 0 if the number is missing
        fields >> skipped >> d;
        proba[SURPRISE][i++] = d;
    }
}

void get_proba()
{
readanger();
readdisgust();
readfear();
readsad();
readjoy();
readsurprise();
}

// Builds the vocabulary in the global `sets` from Dataset_words.txt and writes
// it, one word per line, to anger.txt.
//
// The first line of the input (a header) is skipped. On every other line the
// first token (the text id) is ignored and each remaining token is added.
// A single-space placeholder " " is also inserted once per data line. It is not
// written to anger.txt, but it stays in `sets` so that column positions used by
// vector_out() are unchanged.
void get_word()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Dataset_words.txt");
    std::ofstream out("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/anger.txt");
    if (in && out)
    {
        std::string line;
        bool is_header = true;
        while (std::getline(in, line))
        {
            if (is_header)
            {
                is_header = false;
                continue;
            }
            std::istringstream ss(line);
            std::string token;
            ss >> token;  // text id, not a word
            sets.insert(" ");
            // Loop on the extraction result: a failed read must not re-insert a
            // stale token, which is what the `!ss.eof()` form allowed.
            while (ss >> token)
                sets.insert(token);
        }
    }
    else
    {
        std::cerr << "open in or out file error" << std::endl;
    }

    for (std::set<std::string>::const_iterator it = sets.begin(); it != sets.end(); ++it)
    {
        if (*it != " ")
            out << *it << '\n';
    }
}

// Removes every stop word listed in "Foxstoplist (1).txt" from the vocabulary
// `sets`, and echoes each stop word (one per line) to Foxstoplistout.txt.
void clear_stopwords()
{
    // Read-only stream: std::fstream::open defaults to in|out, which fails on a
    // read-only stoplist file and then silently removed nothing.
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Foxstoplist (1).txt");
    std::ofstream out("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Foxstoplistout.txt");
    if (!in)
    {
        std::cerr << "open stoplist file error" << std::endl;
        return;
    }

    std::string line;
    while (std::getline(in, line))
    {
        std::istringstream ss(line);
        std::string word;
        // Extraction-driven loop: no duplicate/empty token on trailing
        // whitespace or blank lines.
        while (ss >> word)
        {
            out << word << '\n';
            sets.erase(word);  // O(log n) keyed erase; no-op if absent
        }
    }
}

// Fills vector_old with word-presence flags and writes the vocabulary header
// to vector.txt.
//
// Each non-header line of Dataset_words.txt becomes one row (row 0 = first
// data line). The first token of the line (text id) is skipped. For every
// remaining token that is in `sets`, vector_old[row][column] is set, where
// column is the word's position in `sets`. The header written to vector.txt is
// "文本编号 " followed by every vocabulary entry and a space.
void vector_out()
{
    std::ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Dataset_words.txt");
    std::ofstream out("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/vector.txt");

    // word -> column, built once so each token lookup is O(log V) instead of a
    // scan over the whole set.
    std::map<std::string, int> column;
    int next_col = 0;
    for (std::set<std::string>::const_iterator it = sets.begin(); it != sets.end(); ++it)
        column[*it] = next_col++;

    // Capacity of the global matrix; rows/columns beyond it are not written.
    const int max_rows = static_cast<int>(sizeof(vector_old) / sizeof(vector_old[0]));
    const int max_cols = static_cast<int>(sizeof(vector_old[0]) / sizeof(vector_old[0][0]));

    // Vectors are built whenever the input is readable; they are needed by
    // compute_dis() even if vector.txt cannot be created.
    std::string line;
    bool is_header = true;
    int row_num = 0;
    while (std::getline(in, line))
    {
        if (is_header)
        {
            is_header = false;
            continue;
        }
        if (row_num < max_rows)
        {
            std::istringstream ss(line);
            std::string token;
            ss >> token;  // text id, not a word
            while (ss >> token)
            {
                std::map<std::string, int>::const_iterator hit = column.find(token);
                if (hit != column.end() && hit->second < max_cols)
                    vector_old[row_num][hit->second] = true;
            }
        }
        row_num++;
    }
    if (row_num > max_rows)
        std::cerr << "vector_out: only the first " << max_rows << " texts were vectorised" << std::endl;

    out << "文本编号 ";
    for (std::set<std::string>::const_iterator it = sets.begin(); it != sets.end(); ++it)
        out << *it << " ";
}

void compute_dis()
{
for (int i = 0; i < 1246; i++){
double sum = 0;
for (int j = 0; j < sets.size(); j++){
if (vector_old[i][j])
{
sum++;
}
}
for (int j = 0; j < sets.size(); j++){
vector2[i][j] = vector_old[i][j]*1.0/sum;
//out << vector2[i][j] << " ";
}
//out <<endl;
}
double newpro_sum[1009] = {0};
for(int mood_n = 0 ; mood_n < 6 ; mood_n++)
{

for(int i = 0 ; i < 1000 ; i++)
{
double dis_sum = 0;
double pro_sum = 0;
double dis;
int pos;
for(int j = 0 ; j < 246 ; j++)
{
double same_words = 1;
for(int k = 0 ; k < sets.size(); k++)
{
if(vector2[i+246][k] > 0)
{
if(vector2[j][k] == 0)
{
same_words*=lapace;
}
else
{
same_words*=vector2[j][k];
}
}
}
if(proba[j] > 0)
{
pro_sum+=same_words*proba[mood_n][j];
}
else
{
pro_sum+=same_words*lapace;
}
}
newproba[mood_n][i] = pro_sum;
}
}

for(int i = 0 ; i < 1000; i++){
for(int mood_n = 0 ; mood_n < 6 ; mood_n++){
newpro_sum[i]+=newproba[mood_n][i];
}
}

for(int mood_n = 0 ; mood_n < 6 ; mood_n++)
{
for(int i = 0 ; i < 1000 ; i++)
{
//cout << newpro_sum <<endl;
newproba[mood_n][i] = newproba[mood_n][i] / newpro_sum[i];
}
}

cout << "happy" <<endl;
}

// Writes the 1000 normalised scores of each emotion (newproba[emotion][0..999])
// to that emotion's *_predict.txt file, one value per line.
void print()
{
    // Indexed by the emotion codes 0..5: anger, disgust, fear, joy, sad, surprise.
    static const char* const paths[6] = {
        "C:/Users/windowos 7/Desktop/AILab/predict_test/anger_predict.txt",
        "C:/Users/windowos 7/Desktop/AILab/predict_test/disgust_predict.txt",
        "C:/Users/windowos 7/Desktop/AILab/predict_test/fear_predict.txt",
        "C:/Users/windowos 7/Desktop/AILab/predict_test/joy_predict.txt",
        "C:/Users/windowos 7/Desktop/AILab/predict_test/sad_predict.txt",
        "C:/Users/windowos 7/Desktop/AILab/predict_test/surprise_predict.txt"
    };
    for (int mood = 0; mood < 6; mood++)
    {
        std::ofstream f(paths[mood]);
        if (!f)
        {
            // Previously an unopenable file was skipped silently.
            std::cerr << "open out file error: " << paths[mood] << std::endl;
            continue;
        }
        for (int j = 0; j < 1000; j++)
            f << newproba[mood][j] << '\n';
    }
}

// Runs the whole pipeline in order, printing each finished stage's index
// (0..5), then prints the final vocabulary size.
int main()
{
    typedef void (*Stage)();
    const Stage stages[] = {
        get_word,         // 0: build the vocabulary
        clear_stopwords,  // 1: drop stop words from it
        get_proba,        // 2: load training emotion values
        vector_out,       // 3: build word-presence vectors
        compute_dis,      // 4: score the test texts
        print             // 5: write the predictions
    };
    const int stage_count = sizeof(stages) / sizeof(stages[0]);
    for (int step = 0; step < stage_count; ++step)
    {
        stages[step]();
        std::cout << step << std::endl;
    }
    std::cout << sets.size() << std::endl;
    return 0;
}


NB参考网址
内容来自用户分享和网络整理，不保证内容的准确性，如有侵权内容，可联系管理员处理。点击这里给我发消息