Implementing Mahout 0.9 Bayes Prediction (Mahout Provides Only trainnb and testnb)
2014-10-11 09:08
1. Overview
Mahout 0.9 provides only the trainnb (training) and testnb (testing) commands for its Bayes model: you can build a model and measure how well it performs, but there is no prediction function. After reading through the Mahout source code, I implemented a prediction function for the Mahout Bayes model myself. For how to use Mahout 0.9 Bayes, see http://blog.csdn.net/mach_learn/article/details/39667713

2. Why Mahout Does Not Support predict
Mahout 0.9 serializes and vectorizes the training and test sets together, and only afterwards splits the vectorized files into a training set and a test set. Vectorization produces the following files:

- dictionary.file-0 maps each token to an integer index; the key is a word (or punctuation mark, etc.) and the value is its index.
- frequency.file-0 uses the index as key; the value is the number of documents in which the corresponding word appears.
- df-count holds the document-frequency data.
- tf-vectors holds each document's term frequencies.
- tfidf-vectors holds, for each document, the term indices and their TF-IDF values.
- tokenized-documents holds the documents after tokenization.
- wordcount holds each word's frequency over the entire corpus.
After vectorization, Mahout splits the files in tfidf-vectors into a training set and a test set, usually at an 80-20 ratio. trainnb then trains on the training set to produce the naiveBayesModel.bin model, and testnb evaluates that model against the test set.
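For context, the usual end-to-end command sequence in Mahout 0.9 looks roughly like the sketch below, modeled on the standard 20newsgroups example (paths are illustrative; -c selects the complementary model, which is what the prediction code later in this post assumes):

```
mahout seq2sparse -i docs-seq -o docs-vectors -lnorm -nv -wt tfidf
mahout split -i docs-vectors/tfidf-vectors \
  --trainingOutput train-vectors --testOutput test-vectors \
  --randomSelectionPct 20 --overwrite --sequenceFiles -xm sequential
mahout trainnb -i train-vectors -el -o model -li labelindex -ow -c
mahout testnb -i test-vectors -m model -l labelindex -ow -o test-results -c
```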
Because the whole corpus is vectorized in one pass, there is a single shared dictionary file. This is why data vectorized separately by its own seq2sparse run cannot be fed to a naiveBayesModel.bin trained on other data: the two vectorizations have different dictionaries.
3. How the Mahout Prediction Function Is Written
To predict with naiveBayesModel.bin, the input data must be processed according to the same vectorization standard the model was built with (that is, it must be aligned with the dictionary and the other files produced during vectorization). First, look up each token of the input in the dictionary; next, use the token's index to fetch its document frequency from df-count; then compute the token's TF-IDF value (this needs only the df-count entry, numDocs, and the token's term frequency in the input, where numDocs is the value stored under key -1 in df-count). Feeding the resulting TF-IDF vector into naiveBayesModel.bin yields a likelihood score for each class, and the class with the highest score is the prediction. The programming environment is Mahout 0.9; the required jar packages are shown in the figure below.
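To make the TF-IDF step concrete, here is a minimal sketch (with made-up counts) of the weight computation used later in the code, via Mahout's org.apache.mahout.vectorizer.TFIDF class; under the hood the weight is roughly sqrt(tf) * (log(numDocs / (df + 1)) + 1):

```java
import org.apache.mahout.vectorizer.TFIDF;

public class TfIdfExample {
    public static void main(String[] args) {
        int tf = 3;          // the term occurs 3 times in the document to classify
        int df = 120;        // df-count value stored under the term's dictionary index
        int numDocs = 10000; // df-count value stored under the special key -1
        // The length argument is not used by this weighting, so 0 is passed,
        // exactly as the prediction code below does.
        double weight = new TFIDF().calculate(tf, df, 0, numDocs);
        System.out.println("tf-idf weight: " + weight);
    }
}
```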
The program needs both the model trained by Mahout and the files produced by seq2sparse vectorization. The seq2sparse output files must first be converted from sequence files to text with mahout seqdumper -i inputfile -o outputfile. The file structure is shown below:
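The original figure is not reproduced here, but the layout the code expects can be reconstructed from its hard-coded paths (an assumption worth checking against your own setup). Note also that every seqdumper record is printed as `Key: <key>: Value: <value>`, which is exactly the pattern the parsing code matches on:

```
model/
├── dictionary.txt   seqdumper dump of dictionary.file-0
├── df-count.txt     seqdumper dump of df-count (key -1 holds numDocs)
├── wordcount.txt    seqdumper dump of wordcount
├── labelindex.txt   seqdumper dump of labelindex
├── model/           trainnb output directory containing naiveBayesModel.bin
└── test/            one plain-text, pre-tokenized file per document to classify
```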
The code is as follows:
```java
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.classifier.naivebayes.ComplementaryNaiveBayesClassifier;
import org.apache.mahout.classifier.naivebayes.NaiveBayesModel;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.vectorizer.TFIDF;

public class BayesPredict extends AbstractJob {

    public static HashMap<String, String> dictionaryHashMap = new HashMap<>();
    public static HashMap<String, String> dfcountHashMap = new HashMap<>();
    public static HashMap<String, String> wordcountHashMap = new HashMap<>();
    public static HashMap<String, String> labelindexHashMap = new HashMap<>();

    public BayesPredict() {
        // Load the text dumps produced by mahout seqdumper (see above).
        readDfCount("model/df-count.txt");
        readDictionary("model/dictionary.txt");
        readLabelIndex("model/labelindex.txt");
        readWordCount("model/wordcount.txt");
    }

    // Reads the first line of the document to classify; the file is expected
    // to be already tokenized, with tokens separated by single spaces.
    public static String[] readFile(String filename) {
        String line = null;
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(filename)))) {
            line = reader.readLine();
        } catch (IOException e) {
            e.printStackTrace();
        }
        if (line == null) {
            return null; // empty or unreadable file
        }
        return line.trim().split(" ");
    }

    // Parses a seqdumper text dump in which every record is printed as
    // "Key: <key>: Value: <value>" and stores the pairs in the given map.
    private static void readSeqDump(String fileName, Map<String, String> map) {
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.startsWith("Key:")) {
                    String key = line.substring(line.indexOf(":") + 1, line.indexOf("Value") - 2);
                    String value = line.substring(line.lastIndexOf(":") + 1);
                    map.put(key.trim(), value.trim());
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static void readDictionary(String fileName) { readSeqDump(fileName, dictionaryHashMap); }
    public static void readDfCount(String fileName)    { readSeqDump(fileName, dfcountHashMap); }
    public static void readWordCount(String fileName)  { readSeqDump(fileName, wordcountHashMap); }
    public static void readLabelIndex(String fileName) { readSeqDump(fileName, labelindexHashMap); }

    // Computes the tf-idf weight of every token of the document that appears
    // in the model's dictionary, keyed by the token's dictionary index.
    public static HashMap<Integer, Double> calcTfIdf(String filename) {
        String[] words = readFile(filename);
        if (words == null) {
            return null;
        }
        // Term frequencies within this document.
        HashMap<String, Integer> wordHashMap = new HashMap<>();
        for (String word : words) {
            Integer oldCount = wordHashMap.get(word);
            wordHashMap.put(word, oldCount == null ? 1 : oldCount + 1);
        }
        // The total document count is stored in df-count under the key -1.
        int numDocs = Integer.parseInt(dfcountHashMap.get("-1"));
        HashMap<Integer, Double> tfidfHashMap = new HashMap<>();
        TFIDF tfidf = new TFIDF();
        for (Map.Entry<String, Integer> entry : wordHashMap.entrySet()) {
            int tf = entry.getValue();
            String idString = dictionaryHashMap.get(entry.getKey());
            if (idString != null) { // tokens missing from the dictionary are skipped
                int df = Integer.parseInt(dfcountHashMap.get(idString));
                tfidfHashMap.put(Integer.parseInt(idString), tfidf.calculate(tf, df, 0, numDocs));
            }
        }
        return tfidfHashMap;
    }

    // Scores the document against both labels of the complementary naive Bayes
    // model and returns the label with the higher score. (Materializing the
    // model once up front would be faster when classifying many files.)
    public String predict(String filename) throws IOException {
        HashMap<Integer, Double> tfidfHashMap = calcTfIdf(filename);
        if (tfidfHashMap == null) {
            return "file is empty, unknown class";
        }
        NaiveBayesModel model = NaiveBayesModel.materialize(new Path("model/model/"), getConf());
        ComplementaryNaiveBayesClassifier classifier = new ComplementaryNaiveBayesClassifier(model);
        double label_1 = 0;
        double label_2 = 0;
        for (Map.Entry<Integer, Double> entry : tfidfHashMap.entrySet()) {
            // Accumulate tf-idf weight times the per-feature, per-label score.
            label_1 += entry.getValue() * classifier.getScoreForLabelFeature(0, entry.getKey());
            label_2 += entry.getValue() * classifier.getScoreForLabelFeature(1, entry.getKey());
        }
        return label_1 > label_2 ? "fraud-female" : "norm-female";
    }

    @Override
    public int run(String[] arg0) throws Exception {
        // Not used; prediction is driven from main().
        return 0;
    }

    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        BayesPredict bPredict = new BayesPredict();
        // AbstractJob's configuration is null when the class is not launched
        // through ToolRunner, so set one explicitly before loading the model.
        bPredict.setConf(new Configuration());
        try {
            File file = new File("model/test/");
            String[] filenames = file.list();
            int count1 = 0;
            int count2 = 0;
            int count = 0;
            for (int i = 0; i < filenames.length; i++) {
                String result = bPredict.predict("model/test/" + filenames[i]);
                count++;
                if (result.equals("fraud-female")) {
                    count1++;
                } else if (result.equals("norm-female")) {
                    count2++;
                }
                System.out.println(filenames[i] + ":" + result);
            }
            System.out.println("count:" + count);
            System.out.println("count1:" + count1);
            System.out.println("count2:" + count2);
            System.out.println("time:" + (System.currentTimeMillis() - startTime) / 1000.0);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
```
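The two labels and their names are hard-coded in predict above. As a rough sketch of how this could be generalized to any number of classes (hypothetical, not from the original post), a method like the following could be added to BayesPredict, using the label-name-to-index pairs loaded into labelindexHashMap:

```java
// Hypothetical generalization: score every label found in the labelindex
// dump and return the name of the best-scoring one.
public String predictAnyLabel(String filename) throws IOException {
    HashMap<Integer, Double> tfidf = calcTfIdf(filename);
    if (tfidf == null) {
        return "file is empty, unknown class";
    }
    NaiveBayesModel model = NaiveBayesModel.materialize(new Path("model/model/"), getConf());
    ComplementaryNaiveBayesClassifier classifier = new ComplementaryNaiveBayesClassifier(model);
    String bestLabel = null;
    double bestScore = Double.NEGATIVE_INFINITY;
    for (Map.Entry<String, String> label : labelindexHashMap.entrySet()) {
        int labelIndex = Integer.parseInt(label.getValue());
        double score = 0;
        for (Map.Entry<Integer, Double> term : tfidf.entrySet()) {
            score += term.getValue() * classifier.getScoreForLabelFeature(labelIndex, term.getKey());
        }
        if (score > bestScore) {
            bestScore = score;
            bestLabel = label.getKey();
        }
    }
    return bestLabel;
}
```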