您的位置:首页 > 编程语言 > Java开发

机器学习实战朴素贝叶斯的java实现

2015-10-23 00:46 363 查看
package com.haolidong.Bayes;

import java.util.ArrayList;

/**
 * Holds a feature matrix: a list of rows, where each row is a list of string-valued features
 * (for example the words of one document, or a "0"/"1" word-presence vector).
 *
 * @author haolidong
 */
public class Matrix {
    /** The rows of the matrix; each inner list is one sample. Starts out empty. */
    public ArrayList<ArrayList<String>> data = new ArrayList<ArrayList<String>>();

    /** Creates an empty matrix (no rows). */
    public Matrix() {
        // The row list is created by the field initializer above.
    }
}
package com.haolidong.Bayes;

import java.util.ArrayList;

/**
 * A data set: the inherited feature matrix {@code data} plus one class label per row in
 * {@code labels}.
 *
 * @author haolidong
 */
public class CreateDataSet extends Matrix {
    /** Class label for each row of {@code data}, stored as a string (e.g. "0" or "1"). */
    public ArrayList<String> labels;

    /** Creates an empty data set with no rows and no labels. */
    public CreateDataSet() {
        super();
        this.labels = new ArrayList<String>();
    }

    /**
     * Appends the six toy forum posts used by the naive Bayes example to {@code data}, and
     * their labels to {@code labels}. The labels alternate "0", "1", "0", "1", "0", "1", with
     * "1" marking the abusive posts.
     */
    public void initTest() {
        String[][] posts = {
            {"my", "dog", "has", "flea", "problems", "help", "please"},
            {"maybe", "not", "take", "him", "to", "dog", "park", "stupid"},
            {"my", "dalmation", "is", "so", "cute", "I", "love", "him"},
            {"stop", "posting", "stupid", "worthless", "garbage"},
            {"mr", "licks", "ate", "my", "steak", "how", "to", "stop", "him"},
            {"quit", "buying", "worthless", "dog", "food", "stupid"}
        };
        String[] classes = {"0", "1", "0", "1", "0", "1"};

        for (String[] post : posts) {
            ArrayList<String> row = new ArrayList<String>(post.length);
            for (String word : post) {
                row.add(word);
            }
            data.add(row);
        }
        for (String label : classes) {
            labels.add(label);
        }
    }
}
package com.haolidong.Bayes;

import java.util.ArrayList;
/**
 * The trained parameters of the two-class naive Bayes model.
 *
 * <p>{@code p0Vect} and {@code p1Vect} hold one value per vocabulary word for class 0 and
 * class 1 respectively. The trainer fills them with log conditional probabilities.
 * {@code pAbusive} is the fraction of training samples labelled "1" (the positive class).
 *
 * @author haolidong
 */
public class TrainNB0DataSet {
    /** Per-word values for class 0; empty until filled by training. */
    public ArrayList<Double> p0Vect = new ArrayList<Double>();
    /** Per-word values for class 1; empty until filled by training. */
    public ArrayList<Double> p1Vect = new ArrayList<Double>();
    /** Prior probability of class 1; 0.0 until filled by training. */
    public double pAbusive = 0.0;

    /** Creates an empty result: both vectors empty and the prior set to 0.0. */
    public TrainNB0DataSet() {
        // All fields are initialized at their declarations.
    }
}


package com.haolidong.Bayes;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Locale;

public class Bayes {
public static void main(String[] args) {
spamTest();
}
/**
* @param end  从0到end的范围中产生num个不重复的随机数
* @param num  num个随机数
* @return 返回产生的n个随机数
* @author haolidong
* @Description: [从0到end的范围中产生num个不重复的随机数]
*/
public static HashSet<Integer> randomdif(int end,int num){
HashSet<Integer> rndint = new HashSet<Integer>();
rndint.size();
while ( rndint.size() < num ) {
rndint.add((int) (Math.random()*end));
}
return rndint;
}
/**
* @author haolidong
* @Description: [垃圾邮件分类测试]
*/
public static void spamTest(){
ArrayList<String> fullText = new ArrayList<String>();
CreateDataSet DataSet = new CreateDataSet();
for (int i = 1; i < 26; i++) {
ArrayList<String> hamWordList = new ArrayList<String>();
ArrayList<String> spamWordList = new ArrayList<String>();
String hamPath = new String("I:\\machinelearninginaction\\Ch04\\email\\ham\\"+i+".txt");
String spamPath = new String("I:\\machinelearninginaction\\Ch04\\email\\spam\\"+i+".txt");
hamWordList = textParse(spamPath, 2);
DataSet.data.add(hamWordList);
DataSet.labels.add("1");
for (int j = 0; j < hamWordList.size(); j++) {
fullText.add(hamWordList.get(j));
}
spamWordList=textParse(hamPath, 2);
DataSet.data.add(spamWordList);
DataSet.labels.add("0");
for (int j = 0; j < spamWordList.size(); j++) {
fullText.add(spamWordList.get(j));
}
}
//获取词典
HashSet<String> vocabList = new HashSet<String>();
vocabList = createVocabList(DataSet);
HashSet<Integer> rndint = new HashSet<Integer>();
//随机产生10个测试集,其余的为训练集
rndint = randomdif(50,10);
Matrix testMatrix = new Matrix();
Matrix trainMatrix = new Matrix();
ArrayList<String> trainLabels = new ArrayList<String>();
ArrayList<String> testLabels = new ArrayList<String>();
Matrix testMatrixTrans = new Matrix();
Matrix trainMatrixTrans = new Matrix();
for(Integer i:rndint){
testMatrix.data.add(DataSet.data.get(i));
testLabels.add(DataSet.labels.get(i));
}
for (int i = 0; i < DataSet.data.size(); i++) {
if(!rndint.contains(i)){
trainMatrix.data.add(DataSet.data.get(i));
trainLabels.add(DataSet.labels.get(i));
}
}
//转化到0 1矩阵
for (int i = 0; i < trainMatrix.data.size(); i++) {
trainMatrixTrans.data.add(setOfWords2Vec(vocabList,trainMatrix.data.get(i)));
}
for (int i = 0; i < testMatrix.data.size(); i++) {
testMatrixTrans.data.add(setOfWords2Vec(vocabList,testMatrix.data.get(i)));
}
//训练集的训练
TrainNB0DataSet td = new TrainNB0DataSet();
td = trainNB0(trainMatrixTrans,trainLabels);
//对测试集进行测试
int errorCount=0;
for (int i = 0; i < testMatrixTrans.data.size(); i++) {
int num=classifyNB(testMatrixTrans.data.get(i), td.p0Vect, td.p1Vect, td.pAbusive);
System.out.println("the predict:"+num+" , the real:"+testLabels.get(i));
if(num!=Integer.parseInt(testLabels.get(i))){
errorCount++;
}
}
System.out.println("the errorRate is:"+1.0*errorCount/testMatrixTrans.data.size());
}
public static ArrayList<String> textParse(String fileName,int moreThan){
ArrayList<String> strSplitList = new ArrayList<String>();
String s = readFile(fileName);
strSplitList = extractStrlist(s,moreThan);
return strSplitList;

}
/**
* @param fileName  输入的完整文件路径
* @return 所有的文件内容的字符串
* @author haolidong
* @Description: [一行一行读取文件,然后用字符串全部串起来返回,每一行之间使用空格分割]
*/
public static String readFile(String fileName) {
File file = new File(fileName);
BufferedReader reader = null;
String s = new String();
try {
reader = new BufferedReader(new FileR
4000
eader(file));
String tempString = null;
// 一次读入一行,直到读入null为文件结束
while ((tempString = reader.readLine()) != null) {
//加上" "是为了和下面一段的字符进行区分
s=s+tempString+" ";
}
reader.close();
} catch (IOException e) {
e.printStackTrace();
} finally {
if (reader != null) {
try {
reader.close();
} catch (IOException e1) {
}
}
}
return s;
}

/**
* @param inputString 输入的字符串
* @param moreThan    只有超过moreThan的字符串才会被保留
* @return    分割好的数据串
* @author haolidong
* @Description: [读取一个字符串,进行分割,去掉除了字母数字以外的字符数组,而且所有的字符都改成小写]
*/
public static ArrayList<String> extractStrlist(String inputString,int moreThan) {
ArrayList<String> strSplitList = new ArrayList<String>();
String regEx = "\\W*";
String sentence="";
//		String inputString = "This book is the best book on M.L. I have";
String[] predel = inputString.split(regEx);
for (int i = 0; i < predel.length; i++) {
if(predel[i].equals(""))
sentence+=" ";
else
sentence+=predel[i];
}
String[] strSplit=sentence.split(" ");
for (int i = 0; i < strSplit.length; i++) {
if(strSplit[i].length()>moreThan) {
strSplitList.add(strSplit[i].toLowerCase());
}
}
return strSplitList;
}

/**
* @param vec2Classify   需要进行分类的向量
* @param p0Vec          类别0的权值向量
* @param p1Vec          类别1的权值向量
* @param pClass1                            类别1所占的比重
* @return               返回最后的分类结果
* @author haolidong
* @Description: [计算在每一类中最后的概率返回最大的所对应的标签]
*/
public static int classifyNB(ArrayList<String> vec2Classify, ArrayList<Double> p0Vec, ArrayList<Double> p1Vec,
double pClass1) {
double p1 = 0.0;
double p0 = 0.0;
for (int i = 0; i < vec2Classify.size(); i++) {
p1 = p1 + Double.parseDouble(vec2Classify.get(i)) * p1Vec.get(i);
p0 = p0 + Double.parseDouble(vec2Classify.get(i)) * p0Vec.get(i);
}
p1 = p1 + Math.log(pClass1);
p0 = p0 + Math.log(1 - pClass1);
if (p1 > p0)
return 1;
else
return 0;
}

/**
* @param trainMatrix      训练矩阵
* @param trainCategory    训练目录标签
* @return                 返回最后训练结果,包括每一类的特征矩阵以及每一类的比重情况
* @author haolidong
* @Description: [贝叶斯分类的重点函数,数据集的训练,返回特征矩阵和向量]
*/
public static TrainNB0DataSet trainNB0(Matrix trainMatrix, ArrayList<String> trainCategory) {
int numTrainDocs = trainMatrix.data.size();
int numWords = trainMatrix.data.get(0).size();
TrainNB0DataSet resultSet = new TrainNB0DataSet();
ArrayList<Double> p0Num = new ArrayList<Double>();
ArrayList<Double> p1Num = new ArrayList<Double>();
double trainCategorySum = 0.0;
for (int i = 0; i < trainCategory.size(); i++) {
trainCategorySum = trainCategorySum + Double.parseDouble(trainCategory.get(i));
}
resultSet.pAbusive = trainCategorySum / numTrainDocs;
for (int i = 0; i < numWords; i++) {
p0Num.add(1.0);
p1Num.add(1.0);
}
double p0Denom = 2.0;
double p1Denom = 2.0;
for (int i = 0; i < numTrainDocs; i++) {
if (trainCategory.get(i).equals("1")) {
for (int j = 0; j < numWords; j++) {
p1Num.set(j, p1Num.get(j) + Double.parseDouble(trainMatrix.data.get(i).get(j)));
}
} else {
for (int j = 0; j < numWords; j++) {
p0Num.set(j, p0Num.get(j) + Double.parseDouble(trainMatrix.data.get(i).get(j)));
}
}

}
for (int i = 0; i < numWords; i++) {
p0Denom += p0Num.get(i);
p1Denom += p1Num.get(i);
}
p0Denom = p0Denom - numWords;
p1Denom = p1Denom - numWords;
for (int i = 0; i < numWords; i++) {
resultSet.p0Vect.add(Math.log(p0Num.get(i) / p0Denom));
resultSet.p1Vect.add(Math.log(p1Num.get(i) / p1Denom));
}

return resultSet;
}

/**
* @param vocabSet       字典
* @param inputSet       输入数据集
* @return               返回与字典一一对应的数据集
* @author haolidong
* @Description: [生成一个全部为0的字典,把字典中数据集中有的字符串设置为1,其他的设置为0,返回设置完的字典]
*/
public static ArrayList<String> setOfWords2Vec(HashSet<String> vocabSet, ArrayList<String> inputSet) {
ArrayList<String> returnVec = new ArrayList<String>();
boolean flag;
for (String value : vocabSet) {
flag = false;
for (int i = 0; i < inputSet.size(); i++) {
if (inputSet.get(i).equals(value)) {
returnVec.add("1");
flag = true;
break;
}
}
if (flag == false) {
returnVec.add("0");
}
}
return returnVec;
}

/**
* @param dataSet    输入数据集
* @return           字典
* @author haolidong
* @Description: [输入数据集,数据有比较大的重复,然后去掉重复的数据,最后生成字典]
*/
public static HashSet<String> createVocabList(Matrix dataSet) {
HashSet<String> vocabSet = new HashSet<String>();
for (int i = 0; i < dataSet.data.size(); i++) {
for (int j = 0; j < dataSet.data.get(i).size(); j++) {
vocabSet.add(dataSet.data.get(i).get(j));
}
}
return vocabSet;

}

/**
* @author haolidong
* @Description: [对于生成字典功能的测试]
*/
public static void testVocabList() {
CreateDataSet dataSet = new CreateDataSet();
dataSet.initTest();
HashSet<String> vocabSet = new HashSet<String>();
vocabSet = createVocabList(dataSet);
System.out.println(vocabSet);
}

/**
* @author haolidong
* @Description: [对于输入字符集转化成字典的测试]
*/
public static void testWord2Vec() {
CreateDataSet dataSet = new CreateDataSet();
dataSet.initTest();
HashSet<String> vocabSet = new HashSet<String>();
ArrayList<String> returnVec = new ArrayList<String>();
vocabSet = createVocabList(dataSet);
returnVec = setOfWords2Vec(vocabSet, dataSet.data.get(0));
System.out.println(returnVec);
}

/**
* @author haolidong
* @Description: [对于样本训练的测试]
*/
public static void testTrain() {
CreateDataSet dataSet = new CreateDataSet();
Matrix trainMatrix = new Matrix();
dataSet.initTest();
HashSet<String> vocabSet = new HashSet<String>();
vocabSet = createVocabList(dataSet);
for (int i = 0; i < dataSet.data.size(); i++) {
trainMatrix.data.add(setOfWords2Vec(vocabSet, dataSet.data.get(i)));
}
trainNB0(trainMatrix, dataSet.labels);
}
/**
* @author haolidong
* @Description: [对于样本分类的测试]
*/
public static void testingNB() {
CreateDataSet dataSet = new CreateDataSet();
TrainNB0DataSet td = new TrainNB0DataSet();
ArrayList<String> testEntry = new ArrayList<String>();
Matrix trainMatrix = new Matrix();
dataSet.initTest();
HashSet<String> vocabSet = new HashSet<String>();
vocabSet = createVocabList(dataSet);
for (int i = 0; i < dataSet.data.size(); i++) {
trainMatrix.data.add(setOfWords2Vec(vocabSet, dataSet.data.get(i)));
}
td = trainNB0(trainMatrix, dataSet.labels);
testEntry.add("love");
testEntry.add("my");
testEntry.add("dalmation");
testEntry = setOfWords2Vec(vocabSet, testEntry);
System.out.println("classified as:"+classifyNB(testEntry,td.p0Vect,td.p1Vect,td.pAbusive));
testEntry.clear();
testEntry.add("stupid");
testEntry.add("garbage");
testEntry = setOfWords2Vec(vocabSet, testEntry);
System.out.println("classified as:"+classifyNB(testEntry,td.p0Vect,td.p1Vect,td.pAbusive));
}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息