
Implementing prediction for Mahout 0.9 Bayes (Mahout only provides trainnb and testnb)

2014-10-11 09:08

1. Overview

Mahout 0.9 only provides the training command trainnb and the testing command testnb for the Naive Bayes model, so you can build a model and evaluate how well it performs, but there is no prediction function. After reading the Mahout source code, I wrote a prediction function for the Mahout Bayes model myself. For how to use Naive Bayes in Mahout 0.9, see http://blog.csdn.net/mach_learn/article/details/39667713

2. Why Mahout has no predict command

Mahout 0.9 serializes and vectorizes the training set and the test set together, and only afterwards splits the vectorized files into a training set and a test set. During vectorization Mahout generates the following files:

Among them, dictionary.file-0 maps each term to an integer id: the key is the term (a word, punctuation mark, etc.) and the value is its id. In frequency.file-0 the key is a term id and the value is the number of documents in which that term appears. The df-count directory holds the document-frequency data, tf-vectors holds the term frequencies of each document, and tfidf-vectors holds, for each document, the term ids and their tf-idf values. tokenized-documents holds the documents after tokenization, and wordcount holds the frequency of each term over the whole corpus.
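
When these sequence files are later dumped to text with seqdumper (section 3 below), every entry becomes one line of the form "Key: <key>: Value: <value>", which is exactly the format the prediction code parses. The values here are invented, purely for illustration:

Key: mahout: Value: 1024        (dictionary.txt: the term "mahout" has id 1024)
Key: 1024: Value: 87            (df-count.txt: term id 1024 occurs in 87 documents)
Key: -1: Value: 35000           (df-count.txt: numDocs, the total number of documents)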

After vectorization, Mahout splits the files in tfidf-vectors into a training set and a test set, usually at an 80/20 ratio. trainnb is then run on the training set to produce the naiveBayesModel.bin model, and testnb evaluates that model on the test set.
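
For reference, this split/train/test step looks roughly like the sketch below, adapted from the 20newsgroups example shipped with Mahout 0.9; the paths are placeholders and the exact flags may differ in your setup:

mahout split -i tfidf-vectors --trainingOutput train-vectors --testOutput test-vectors --randomSelectionPct 20 --overwrite --sequenceFiles -xm sequential
mahout trainnb -i train-vectors -o model -li labelindex -el -ow -c
mahout testnb -i test-vectors -m model -l labelindex -o testing-results -ow -c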

Because this unified vectorization produces a single dictionary file, a document that is vectorized separately with seq2sparse cannot be classified with a naiveBayesModel.bin trained on other data: the two vectorizations end up with different dictionaries.

3. How the prediction function is written

To predict with naiveBayesModel.bin, the data to be predicted must be processed according to the same vectorization conventions that produced the model (that is, it must be aligned with the dictionary and the other files generated when the model's vectors were built). First, map each term of the new document to its id in the dictionary; then look up that id in df-count; then compute the term's tf-idf value (this needs only df-count, numDocs and the term frequency within the new document, where numDocs is the value stored under key -1 in df-count). Feeding these tf-idf values into naiveBayesModel.bin yields a likelihood score for every class, and the class with the largest score is the predicted class. A condensed sketch of this per-term scoring is given below; the full listing follows later.
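
The helper below is only a hypothetical, condensed sketch of that loop (the name scoreDocument and its parameters are mine); the full program later in this post does the same thing inline:

// condensed sketch: score one document (term -> tf map) against every label of the model
static double[] scoreDocument(Map<String, Integer> termCounts,
                              Map<String, String> dictionary,
                              Map<String, String> dfCount,
                              int numDocs,
                              ComplementaryNaiveBayesClassifier classifier,
                              int numLabels) {
    double[] score = new double[numLabels];
    for (Map.Entry<String, Integer> e : termCounts.entrySet()) {
        String id = dictionary.get(e.getKey());        // term -> term id
        if (id == null) continue;                      // term unknown to the model's dictionary: skip it
        int df = Integer.parseInt(dfCount.get(id));    // number of training documents containing the term
        double w = new TFIDF().calculate(e.getValue(), df, 0, numDocs);   // tf-idf weight of the term
        for (int label = 0; label < numLabels; label++) {
            score[label] += w * classifier.getScoreForLabelFeature(label, Integer.parseInt(id));
        }
    }
    return score;   // the predicted class is the index with the largest score
}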

Development environment: Mahout 0.9. The required jar packages are shown in the figure below.



The program needs the model trained by Mahout and the files produced by seq2sparse vectorization. The sequence files must first be converted to text with the mahout seqdumper -i inputfile -o outputfile command. The file layout is as follows:
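
Roughly, the layout the code expects looks like the sketch below; the seq2sparse and trainnb output paths in angle brackets are placeholders for your own locations:

mahout seqdumper -i <seq2sparse-output>/dictionary.file-0 -o model/dictionary.txt
mahout seqdumper -i <seq2sparse-output>/df-count -o model/df-count.txt
mahout seqdumper -i <seq2sparse-output>/wordcount -o model/wordcount.txt
mahout seqdumper -i <trainnb-labelindex> -o model/labelindex.txt

model/
    dictionary.txt, df-count.txt, wordcount.txt, labelindex.txt   (text dumps read at start-up)
    model/    (the directory holding naiveBayesModel.bin produced by trainnb)
    test/     (one tokenized document per file, terms separated by spaces)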



The code is as follows:

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;

import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.vectorizer.TFIDF;

public class BayesPredict extends AbstractJob
{
    // term -> term id, dumped from dictionary.file-0
    public static HashMap<String, String> dictionaryHashMap = new HashMap<>();
    // term id -> document frequency, dumped from df-count (key -1 holds numDocs)
    public static HashMap<String, String> dfcountHashMap = new HashMap<>();
    // term -> count over the whole corpus, dumped from wordcount (loaded but not used by predict)
    public static HashMap<String, String> wordcountHashMap = new HashMap<>();
    // class label -> label index used by the model
    public static HashMap<String, String> labelindexHashMap = new HashMap<>();

    public BayesPredict()
    {
        // load the text files produced by mahout seqdumper (see the layout above)
        readDfCount("model/df-count.txt");
        readDictionary("model/dictionary.txt");
        readLabelIndex("model/labelindex.txt");
        readWordCount("model/wordcount.txt");
    }
    // reads a tokenized document (a single line of space-separated terms) and returns its terms
    public static String[] readFile(String filename)
    {
        File file = new File(filename);
        BufferedReader reader;
        String tempstring = null;
        try
        {
            reader = new BufferedReader(new FileReader(file));
            tempstring = reader.readLine();
            reader.close();
            if(tempstring==null)
                return null;
        }
        catch (IOException e)
        {
            e.printStackTrace();
            return null;
        }
        String[] mess = tempstring.trim().split(" ");
        return mess;
    }
    // dictionary.txt lines look like "Key: <term>: Value: <term id>"
    public static void readDictionary(String fileName)
    {
        File file = new File(fileName);
        BufferedReader reader;
        String tempstring = null;
        try
        {
            reader = new BufferedReader(new FileReader(file));
            while((tempstring = reader.readLine())!=null)
            {
                if(tempstring.startsWith("Key:"))
                {
                    String key = tempstring.substring(tempstring.indexOf(":")+1, tempstring.indexOf("Value")-2);
                    String value = tempstring.substring(tempstring.lastIndexOf(":")+1);
                    dictionaryHashMap.put(key.trim(), value.trim());
                }
            }
            reader.close();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    // df-count.txt lines look like "Key: <term id>: Value: <document frequency>"; key -1 holds numDocs
    public static void readDfCount(String fileName)
    {
        File file = new File(fileName);
        BufferedReader reader;
        String tempstring = null;
        try
        {
            reader = new BufferedReader(new FileReader(file));
            while((tempstring = reader.readLine())!=null)
            {
                if(tempstring.startsWith("Key:"))
                {
                    String key = tempstring.substring(tempstring.indexOf(":")+1, tempstring.indexOf("Value")-2);
                    String value = tempstring.substring(tempstring.lastIndexOf(":")+1);
                    dfcountHashMap.put(key.trim(), value.trim());
                }
            }
            reader.close();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    // wordcount.txt lines look like "Key: <term>: Value: <corpus-wide count>"
    public static void readWordCount(String fileName)
    {
        File file = new File(fileName);
        BufferedReader reader;
        String tempstring = null;
        try
        {
            reader = new BufferedReader(new FileReader(file));
            while((tempstring = reader.readLine())!=null)
            {
                if(tempstring.startsWith("Key:"))
                {
                    String key = tempstring.substring(tempstring.indexOf(":")+1, tempstring.indexOf("Value")-2);
                    String value = tempstring.substring(tempstring.lastIndexOf(":")+1);
                    wordcountHashMap.put(key.trim(), value.trim());
                }
            }
            reader.close();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    // labelindex.txt lines look like "Key: <class label>: Value: <label index>"
    public static void readLabelIndex(String fileName)
    {
        File file = new File(fileName);
        BufferedReader reader;
        String tempstring = null;
        try
        {
            reader = new BufferedReader(new FileReader(file));
            while((tempstring = reader.readLine())!=null)
            {
                if(tempstring.startsWith("Key:"))
                {
                    String key = tempstring.substring(tempstring.indexOf(":")+1, tempstring.indexOf("Value")-2);
                    String value = tempstring.substring(tempstring.lastIndexOf(":")+1);
                    labelindexHashMap.put(key.trim(), value.trim());
                }
            }
            reader.close();
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }
    // builds the term-id -> tf-idf map for one document, using the dictionary and
    // df-count dumped from the vectorization that produced the model
    public static HashMap<Integer, Double> calcTfIdf(String filename)
    {
        String[] words = readFile(filename);
        if(words==null)
            return null;

        HashMap<Integer, Double> tfidfHashMap = new HashMap<Integer, Double>();
        HashMap<String, Integer> wordHashMap = new HashMap<String, Integer>();

        // count the term frequency (tf) of every token in the document
        for(int k=0; k<words.length; k++)
        {
            if(wordHashMap.get(words[k])==null)
            {
                wordHashMap.put(words[k], 1);
            }
            else
            {
                wordHashMap.put(words[k], wordHashMap.get(words[k])+1);
            }
        }

        // df-count stores the total number of documents under the key -1
        int numDocs = Integer.parseInt(dfcountHashMap.get("-1"));

        Iterator iterator = wordHashMap.entrySet().iterator();
        while(iterator.hasNext())
        {
            Map.Entry<String, Integer> entry = (Map.Entry<String, Integer>)iterator.next();
            String key = entry.getKey();
            int tf = entry.getValue();
            // terms that are not in the model's dictionary are simply skipped
            if(dictionaryHashMap.get(key)!=null)
            {
                String idString = dictionaryHashMap.get(key);
                int df = Integer.parseInt(dfcountHashMap.get(idString));
                TFIDF tfidf = new TFIDF();
                double tfidf_value = tfidf.calculate(tf, df, 0, numDocs);
                tfidfHashMap.put(Integer.parseInt(idString), tfidf_value);
            }
        }
        return tfidfHashMap;
    }
    // classifies one tokenized document with the trained (complementary) model
    public String predict(String filename) throws IOException
    {
        HashMap<Integer, Double> tfidfHashMap = calcTfIdf(filename);
        if(tfidfHashMap==null)
            return "file is empty, unknown class";

        // naiveBayesModel.bin produced by trainnb lives under model/model/
        NaiveBayesModel model = NaiveBayesModel.materialize(new Path("model/model/"), getConf());
        ComplementaryNaiveBayesClassifier classifier;
        classifier = new ComplementaryNaiveBayesClassifier(model);

        // accumulate a score per label; the two labels are hard-coded here and
        // must match the indices in labelindex (0 -> fraud-female, 1 -> norm-female)
        double label_1=0;
        double label_2=0;

        Iterator iterator = tfidfHashMap.entrySet().iterator();
        while(iterator.hasNext())
        {
            Map.Entry<Integer, Double> entry = (Map.Entry<Integer, Double>)iterator.next();
            int key = entry.getKey();
            double value = entry.getValue();
            label_1 += value*classifier.getScoreForLabelFeature(0, key);
            label_2 += value*classifier.getScoreForLabelFeature(1, key);
        }
        if(label_1>label_2)
            return "fraud-female";
        else
            return "norm-female";
    }
    @Override
    public int run(String[] arg0) throws Exception {
        // not used; prediction is driven from main()
        return 0;
    }
    public static void main(String[] args)
    {
        long startTime=System.currentTimeMillis();
        BayesPredict bPredict = new BayesPredict();
        try {
            // every file under model/test/ is one tokenized document to classify
            File file = new File("model/test/");
            String[] filenames = file.list();
            int count1 = 0;
            int count2 = 0;
            int count = 0;
            for(int i=0;i<filenames.length;i++)
            {
                String result = bPredict.predict("model/test/"+filenames[i]);
                count++;
                if(result.equals("fraud-female"))
                    count1++;
                else if(result.equals("norm-female"))
                    count2++;
                System.out.println(filenames[i]+":"+result);
            }
            System.out.println("count:"+count);
            System.out.println("count1:"+count1);
            System.out.println("count2:"+count2);
            System.out.println("time:"+(System.currentTimeMillis()-startTime)/1000.0);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}