xgboost用于文本分类预测,java接口
2017-05-07 21:12
363 查看
周末花了两天时间从安装xgboost到用于文本预测,记录下,首先是把文本分词,去停用词,计算tf-idf值,然后模型训练,模型保存,加载模型,模型预测:
训练模型代码:
package com.meituan.model.xgboost;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Arrays;
import java.util.Map;
import java.util.Map.Entry;
import java.io.File;
import java.io.IOException;
import org.apache.commons.io.FileUtils;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
public class TrainXgboost {

	// Paths to the stock xgboost demo data.
	// NOTE(review): these three fields are not referenced by main() — confirm
	// whether they are still needed before removing.
	private static String path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/";
	private static String trainString = "agaricus.txt.train";
	private static String testString = "agaricus.txt.test";

	/**
	 * Trains a binary-classification gradient-boosted tree model on
	 * libsvm-format data, writes the feature-importance ranking (highest
	 * score first) to xgboost/keyword.txt, and prints the number of rows
	 * predicted for the test matrix.
	 */
	public static void main(String[] args) throws XGBoostError, IOException {
		DMatrix trainMat = new DMatrix("file/train.txt");
		DMatrix testMat = new DMatrix("file/test.txt");

		// Booster hyper-parameters.
		Map<String, Object> params = new HashMap<String, Object>();
		params.put("booster", "gbtree");
		params.put("eta", 0.6); // shrinkage step applied after each boosting round, guards against over-fitting
		params.put("max_depth", 22); // maximum tree depth
		params.put("silent", 0); // 0 = print per-iteration training information, 1 = quiet
		params.put("lambda", 2.5); // L2 regularisation weight
		params.put("min_child_weight", 6);
		// params.put("nthread", 6); // defaults to the maximum available thread count
		params.put("objective", "binary:logistic"); // objective function

		// Data sets whose evaluation metric is reported during training.
		HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
		watches.put("train", trainMat);
		// watches.put("test", testMat);

		int round = 100; // number of boosting iterations
		Booster booster = XGBoost.train(trainMat, params, round, watches, null,
				null);
		// booster.saveModel("xgboost/xgboost.model");

		// Sort features by importance score, descending.
		Map<String, Integer> map = booster.getFeatureScore(null);
		List<Map.Entry<String, Integer>> list = new ArrayList<Map.Entry<String, Integer>>(
				map.entrySet());
		Collections.sort(list, new Comparator<Map.Entry<String, Integer>>() {
			@Override
			public int compare(Entry<String, Integer> o1,
					Entry<String, Integer> o2) {
				// BUG FIX: the original subtracted the values and returned
				// only -1 or 1, never 0 for equal scores — that violates the
				// Comparator contract and can make TimSort throw
				// "Comparison method violates its general contract!".
				return Integer.compare(o2.getValue(), o1.getValue());
			}
		});
		FileUtils.writeLines(new File("xgboost/keyword.txt"), list);

		float[][] result = booster.predict(testMat);
		System.out.println("length is:" + result.length);
	}
}
模型训练比较简单,先看看模型预测写的代码,准备的两个方法,把文本转化为libsvm的形式,再转化为DMatrix:
用于文本预测:
训练模型代码:
package com.meituan.model.xgboost;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Arrays;
import java.util.Map;
import java.util.Map.Entry;
import java.io.File;
import java.io.IOException;
import org.apache.commons.io.FileUtils;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
public class TrainXgboost {

	// Paths to the stock xgboost demo data.
	// NOTE(review): these three fields are not referenced by main() — confirm
	// whether they are still needed before removing.
	private static String path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/";
	private static String trainString = "agaricus.txt.train";
	private static String testString = "agaricus.txt.test";

	/**
	 * Trains a binary-classification gradient-boosted tree model on
	 * libsvm-format data, writes the feature-importance ranking (highest
	 * score first) to xgboost/keyword.txt, and prints the number of rows
	 * predicted for the test matrix.
	 */
	public static void main(String[] args) throws XGBoostError, IOException {
		DMatrix trainMat = new DMatrix("file/train.txt");
		DMatrix testMat = new DMatrix("file/test.txt");

		// Booster hyper-parameters.
		Map<String, Object> params = new HashMap<String, Object>();
		params.put("booster", "gbtree");
		params.put("eta", 0.6); // shrinkage step applied after each boosting round, guards against over-fitting
		params.put("max_depth", 22); // maximum tree depth
		params.put("silent", 0); // 0 = print per-iteration training information, 1 = quiet
		params.put("lambda", 2.5); // L2 regularisation weight
		params.put("min_child_weight", 6);
		// params.put("nthread", 6); // defaults to the maximum available thread count
		params.put("objective", "binary:logistic"); // objective function

		// Data sets whose evaluation metric is reported during training.
		HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
		watches.put("train", trainMat);
		// watches.put("test", testMat);

		int round = 100; // number of boosting iterations
		Booster booster = XGBoost.train(trainMat, params, round, watches, null,
				null);
		// booster.saveModel("xgboost/xgboost.model");

		// Sort features by importance score, descending.
		Map<String, Integer> map = booster.getFeatureScore(null);
		List<Map.Entry<String, Integer>> list = new ArrayList<Map.Entry<String, Integer>>(
				map.entrySet());
		Collections.sort(list, new Comparator<Map.Entry<String, Integer>>() {
			@Override
			public int compare(Entry<String, Integer> o1,
					Entry<String, Integer> o2) {
				// BUG FIX: the original subtracted the values and returned
				// only -1 or 1, never 0 for equal scores — that violates the
				// Comparator contract and can make TimSort throw
				// "Comparison method violates its general contract!".
				return Integer.compare(o2.getValue(), o1.getValue());
			}
		});
		FileUtils.writeLines(new File("xgboost/keyword.txt"), list);

		float[][] result = booster.predict(testMat);
		System.out.println("length is:" + result.length);
	}
}
模型训练比较简单,先看看模型预测写的代码,准备的两个方法,把文本转化为libsvm的形式,再转化为DMatrix:
/**
 * Converts a raw text into a one-row CSR sparse vector of tf-idf weights.
 *
 * Pipeline: lower-case -> custom normalisation -> traditional-to-simplified
 * conversion -> synonym replacement -> word segmentation -> stop-word
 * filtering -> term-frequency counting -> tf-idf against the dictionary.
 *
 * @param content  raw input text; blank input yields null
 * @param termsmap dictionary mapping term -> Terms (feature id + document
 *                 frequency); must contain the special key "documentTotal"
 *                 holding the corpus document count
 * @return a CSRSparseData with sorted column indices and an empty labels
 *         array (labels are unused at prediction time), or null when the
 *         text produces no known features
 */
public static CSRSparseData getSparseData(String content,
		Map<String, Terms> termsmap) {
	if (StringUtils.isBlank(content)) {
		return null;
	}
	// Segment the normalised text and count occurrences of each surviving term.
	Map<String, Long> counts = ToAnalysis
			.parse(WordUtil.replaceAllSynonyms(TextUtil.fan2Jian(WordUtil
					.replaceAll(content.toLowerCase()))))
			.getTerms()
			.stream()
			.map(x -> x.getName())
			.filter(x -> !WordUtil.isStopword(x))
			.collect(Collectors.groupingBy(p -> p, Collectors.counting()));
	// Collectors.groupingBy never returns null, so an emptiness check suffices.
	if (counts.isEmpty()) {
		return null;
	}
	// Total token count of the document — denominator of the tf term.
	int sum = (int) counts.values().stream().mapToLong(Long::longValue).sum();
	// TreeMap keeps feature ids in ascending order, as the CSR layout requires.
	Map<Integer, Double> treemap = new TreeMap<Integer, Double>();
	for (Entry<String, Long> entry : counts.entrySet()) {
		Terms keyword = termsmap.get(entry.getKey());
		if (keyword == null) {
			continue; // term absent from the training vocabulary
		}
		double tf = TFIDF.tf(entry.getValue(), sum);
		double idf = TFIDF.idf(termsmap.get("documentTotal").getFreq(),
				keyword.getFreq());
		treemap.put(keyword.getId(), TFIDF.tfidf(tf, idf));
	}
	if (treemap.isEmpty()) {
		return null;
	}
	CSRSparseData spData = new CSRSparseData();
	List<Float> tdata = new ArrayList<>();
	List<Integer> tindex = new ArrayList<>();
	for (Entry<Integer, Double> e : treemap.entrySet()) {
		// Narrow double -> float via the decimal string representation to
		// avoid surprises from the binary double constructor.
		BigDecimal b = new BigDecimal(Double.toString(e.getValue()));
		tdata.add(b.floatValue());
		tindex.add(e.getKey());
	}
	// Single row: rowHeaders = {0, nnz}; labels stay empty for prediction.
	spData.labels = new float[0];
	spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata.size()]));
	spData.colIndex = ArrayUtils.toPrimitive(tindex
			.toArray(new Integer[tindex.size()]));
	spData.rowHeaders = new long[] { 0L, treemap.size() };
	return spData;
}

/**
 * Scores a single text with a trained booster.
 *
 * @param booster  trained xgboost model (binary:logistic objective)
 * @param content  raw text to classify
 * @param termsmap term dictionary used to vectorise the text
 * @return the predicted probability for the single row, or 0.0 when the
 *         text produces no usable features
 * @throws XGBoostError if prediction fails inside the native library
 */
public static double getClassification(Booster booster, String content,
		Map<String, Terms> termsmap) throws XGBoostError {
	CSRSparseData spData = getSparseData(content, termsmap);
	if (spData == null) {
		return 0.0;
	}
	DMatrix data = new DMatrix(spData.rowHeaders, spData.colIndex,
			spData.data, DMatrix.SparseType.CSR, 0);
	return booster.predict(data)[0][0];
}
用于文本预测:
package com.meituan.model.xgboost;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.util.Map;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import com.meituan.model.libsvm.DocumentTransForm;
import com.meituan.model.libsvm.Terms;

public class Prediction {

	/**
	 * Loads the term dictionary and the trained model, scores one sample
	 * sentence, then evaluates a labelled test file.
	 */
	public static void main(String[] args) throws XGBoostError, IOException {
		Map<String, Terms> termsmap = DocumentTransForm.readmap("file/model");
		Booster booster = XGBoost.loadModel("xgboost/xgboost.model");
		System.out.println(DataLoader.getClassification(booster, "我们在吃饭",
				termsmap));
		test(termsmap, booster, "/Users/shuubiasahi/Desktop/测试文件.csv");
	}

	/**
	 * Runs the model over a tab-separated test file, writing "label,score"
	 * pairs to xgboost/merge.txt and misclassified rows to
	 * xgboost/result.txt.
	 *
	 * @param termsmap term dictionary used to vectorise each row's text
	 * @param booster  trained xgboost model
	 * @param path     tab-separated input; expected columns:
	 *                 flag, ?, category, content — TODO confirm column 1
	 * @throws IOException   on read/write failure
	 * @throws XGBoostError  if prediction fails inside the native library
	 */
	public static void test(Map<String, Terms> termsmap, Booster booster,
			String path) throws IOException, XGBoostError {
		// BUG FIX: the original closed the three streams only at the end of
		// the method, leaking all of them whenever an exception was thrown
		// mid-loop; try-with-resources guarantees closure.
		try (BufferedReader buffer = new BufferedReader(new InputStreamReader(
				new FileInputStream(path)));
				BufferedWriter bufferwrite = new BufferedWriter(
						new OutputStreamWriter(new FileOutputStream(
								"xgboost/merge.txt")));
				BufferedWriter bufferwriteresult = new BufferedWriter(
						new OutputStreamWriter(new FileOutputStream(
								"xgboost/result.txt")))) {
			String line;
			while ((line = buffer.readLine()) != null) {
				String[] lines = line.split("\t");
				if (lines.length < 4) {
					continue; // skip malformed rows instead of throwing ArrayIndexOutOfBoundsException
				}
				// Ground truth is positive only when flag is "1" AND the
				// category is "美食" (food).
				String label = ("1".equalsIgnoreCase(lines[0]) && "美食"
						.equalsIgnoreCase(lines[2])) ? "1" : "0";
				String content = lines[3];
				double p = DataLoader.getClassification(booster, content,
						termsmap);
				if (p > 0) {
					bufferwrite.write(label + "," + p + "\n");
					String prString = p > 0.5 ? "1" : "0";
					if (!label.equals(prString)) {
						// Record misclassified rows for error analysis.
						bufferwriteresult.write(label + "\t" + prString + "\t"
								+ p + "\t" + lines[3] + "\n");
					}
				}
			}
		}
	}
}
相关文章推荐
- 利用xgboost4j下的xgboost分类模型案例
- 基于决策树的分类回归(随机森林,xgboost, gbdt)
- XGBoost:二分类问题
- sklearn 用于文本分类
- Xgboost C++预测模块线程安全修复
- python编写朴素贝叶斯用于文本分类
- XGboost 实战糖尿病预测
- 【NLP】TensorFlow实现CNN用于中文文本分类
- 基于Xgboost的不均衡数据分类
- Sklearn,xgboost机器学习多分类实验
- 用XGBoost做时间序列预测—forecastxgb包
- Kaggle房价预测进阶版/bagging/boosting/AdaBoost/XGBoost
- 利用随机森林,xgboost,logistic回归,预测泰坦尼克号上面的乘客的获救概率
- 朴素贝叶斯(NaiveBayes)针对小数据集中文文本分类预测
- 【NLP】TensorFlow实现CNN用于文本分类(译)
- XGBoost:二分类问题
- XGBoost解决多分类问题
- 贝叶斯分类器用于文本分类: Multinomial Naïve Bayes
- 归一化用于文本分类中的特征向量计算
- XGBoost:二分类问题