您的位置:首页 > 编程语言 > Java开发

xgboost用于文本分类预测,java接口

2017-05-07 21:12 363 查看
周末花了两天时间从安装xgboost到用于文本预测,记录一下。流程是:先对文本分词、去停用词、计算tf-idf值,然后进行模型训练、模型保存、加载模型、模型预测:

训练模型代码:

package com.meituan.model.xgboost;

import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.ArrayList;
import java.util.List;
import java.util.Arrays;
import java.util.Map;
import java.util.Map.Entry;
import java.io.File;
import java.io.IOException;

import org.apache.commons.io.FileUtils;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

/**
 * Trains a binary-classification xgboost model on libsvm-format text features,
 * dumps the per-feature importance scores to a file, and runs a sanity-check
 * prediction on the held-out test matrix.
 */
public class TrainXgboost {

	// Sample-data locations from the xgboost repo. NOTE(review): these fields
	// are never read by main(); kept only as a pointer to the demo data.
	private static String path = "/Users/shuubiasahi/Documents/workspace/xgboost/demo/data/";
	private static String trainString = "agaricus.txt.train";
	private static String testString = "agaricus.txt.test";

	public static void main(String[] args) throws XGBoostError, IOException {

		// Both files are expected in libsvm format: "label idx:val idx:val ...".
		DMatrix trainMat = new DMatrix("file/train.txt");
		DMatrix testMat = new DMatrix("file/test.txt");

		// Booster hyper-parameters.
		Map<String, Object> params = new HashMap<String, Object>();
		params.put("booster", "gbtree");
		params.put("eta", 0.6); // shrinkage step applied to new feature weights after each boosting round, to limit overfitting
		params.put("max_depth", 22); // maximum tree depth
		params.put("silent", 0); // 0 = print per-iteration training info, 1 = quiet
		params.put("lambda", 2.5); // L2 regularization term on weights
		params.put("min_child_weight", 6);
		// params.put("nthread", 6); // defaults to the maximum available threads when unset
		params.put("objective", "binary:logistic"); // objective function

		// Evaluation sets watched during training (printed each round when silent=0).
		HashMap<String, DMatrix> watches = new HashMap<String, DMatrix>();
		watches.put("train", trainMat);
		// watches.put("test", testMat);

		// Number of boosting rounds.
		int round = 100;

		Booster booster = XGBoost.train(trainMat, params, round, watches, null,
				null);

		// booster.saveModel("xgboost/xgboost.model");

		// Feature name -> importance score (split count).
		Map<String, Integer> map = booster.getFeatureScore(null);

		List<Map.Entry<String, Integer>> list = new ArrayList<Map.Entry<String, Integer>>(
				map.entrySet());

		// Sort descending by score. Integer.compare fulfils the Comparator
		// contract; the original comparator returned 1 for equal values and so
		// was not symmetric, which can make TimSort throw "Comparison method
		// violates its general contract!".
		Collections.sort(list, new Comparator<Map.Entry<String, Integer>>() {
			@Override
			public int compare(Entry<String, Integer> o1,
					Entry<String, Integer> o2) {
				return Integer.compare(o2.getValue(), o1.getValue());
			}
		});
		FileUtils.writeLines(new File("xgboost/keyword.txt"), list);

		// One probability row per test instance (binary:logistic => single column).
		float[][] result = booster.predict(testMat);
		System.out.println("length is:" + result.length);
	}

}


模型训练比较简单。先看看模型预测的代码,准备了两个方法:先把文本转化为libsvm的稀疏形式,再转化为DMatrix:

/**
 * Converts raw text into a CSR sparse row of tf-idf features for xgboost.
 *
 * <p>Pipeline: lowercase, custom replacement, traditional-to-simplified
 * conversion, synonym replacement, ansj segmentation, stop-word filtering,
 * term counting, then tf-idf per vocabulary term.
 *
 * @param content  raw document text
 * @param termsmap vocabulary: term -> Terms (feature id + document frequency);
 *                 must also contain the special key "documentTotal" whose
 *                 freq is the corpus document count used for idf
 * @return a single-row CSRSparseData, or {@code null} when the text is blank,
 *         produces no tokens, or contains no in-vocabulary terms
 */
public static CSRSparseData getSparseData(String content, Map<String, Terms> termsmap) {
	if (StringUtils.isBlank(content)) {
		return null;
	}
	// Normalize + segment + drop stop words, then count occurrences per term.
	Map<String, Long> maps = ToAnalysis
			.parse(WordUtil.replaceAllSynonyms(TextUtil.fan2Jian(WordUtil
					.replaceAll(content.toLowerCase()))))
			.getTerms()
			.stream()
			.map(x -> x.getName())
			.filter(x -> !WordUtil.isStopword(x))
			.collect(Collectors.groupingBy(p -> p, Collectors.counting()));
	if (maps == null || maps.isEmpty()) {
		return null;
	}
	// Total token count of the document: the denominator for tf.
	int sum = (int) maps.values().stream().mapToLong(Long::longValue).sum();

	// TreeMap keeps feature ids in ascending order, as the CSR column-index
	// array must be sorted within a row.
	Map<Integer, Double> treemap = new TreeMap<Integer, Double>();

	for (Entry<String, Long> entry : maps.entrySet()) {
		Terms keyword = termsmap.get(entry.getKey());
		if (keyword == null) {
			// Out-of-vocabulary term: no feature id, skip.
			continue;
		}
		// tf is computed only for in-vocabulary terms (the original computed
		// it before the null check and wasted the work on skipped terms).
		double tf = TFIDF.tf(entry.getValue(), sum);
		double idf = TFIDF.idf(termsmap.get("documentTotal").getFreq(),
				keyword.getFreq());
		treemap.put(keyword.getId(), TFIDF.tfidf(tf, idf));
	}

	if (treemap.isEmpty()) {
		return null;
	}

	CSRSparseData spData = new CSRSparseData();
	List<Float> tlabels = new ArrayList<>();
	List<Float> tdata = new ArrayList<>();
	List<Long> theaders = new ArrayList<>();
	List<Integer> tindex = new ArrayList<>();
	// Single row: starts at offset 0 and holds treemap.size() non-zero entries.
	theaders.add(0L);
	theaders.add((long) treemap.size());

	for (Entry<Integer, Double> entry : treemap.entrySet()) {
		// Route double -> float through BigDecimal(String) to avoid picking up
		// extra binary-representation digits from the double literal.
		BigDecimal b = new BigDecimal(Double.toString(entry.getValue()));
		tdata.add(b.floatValue());
		tindex.add(Integer.valueOf(entry.getKey()));
	}

	// No labels at prediction time: spData.labels stays an empty array.
	spData.labels = ArrayUtils.toPrimitive(tlabels
			.toArray(new Float[tlabels.size()]));
	spData.data = ArrayUtils.toPrimitive(tdata.toArray(new Float[tdata
			.size()]));
	spData.colIndex = ArrayUtils.toPrimitive(tindex
			.toArray(new Integer[tindex.size()]));
	spData.rowHeaders = ArrayUtils.toPrimitive(theaders
			.toArray(new Long[theaders.size()]));

	return spData;
}

/**
 * Scores a single document with the trained booster.
 *
 * @param booster  trained xgboost model
 * @param content  raw document text
 * @param termsmap tf-idf vocabulary (see {@code getSparseData})
 * @return predicted probability for the positive class, or 0.0 when the text
 *         yields no usable features (blank / all terms out of vocabulary)
 * @throws XGBoostError if DMatrix construction or prediction fails
 */
public static double getClassification(Booster booster, String content,
		Map<String, Terms> termsmap) throws XGBoostError {
	CSRSparseData spData = getSparseData(content, termsmap);
	if (spData == null) {
		return 0.0;
	}
	DMatrix data = new DMatrix(spData.rowHeaders, spData.colIndex,
			spData.data, DMatrix.SparseType.CSR, 0);
	try {
		// Single row in, single probability out.
		return booster.predict(data)[0][0];
	} finally {
		// DMatrix wraps native memory; free it explicitly instead of relying
		// on finalization — the original leaked it on every call.
		data.dispose();
	}
}


用于文本预测:

package com.meituan.model.xgboost;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.Map;

import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;

import com.meituan.model.libsvm.DocumentTransForm;
import com.meituan.model.libsvm.Terms;

/**
 * Loads the saved xgboost model plus the tf-idf vocabulary, scores one sample
 * sentence, then evaluates the model over a tab-separated test file.
 */
public class Prediction {

	public static void main(String[] args) throws XGBoostError, IOException {
		Map<String, Terms> termsmap = DocumentTransForm.readmap("file/model");
		Booster booster = XGBoost.loadModel("xgboost/xgboost.model");
		System.out.println(DataLoader.getClassification(booster, "我们在吃饭", termsmap));
		test(termsmap, booster, "/Users/shuubiasahi/Desktop/测试文件.csv");
	}

	/**
	 * Scores every row of a tab-separated file and writes two outputs:
	 * "xgboost/merge.txt" with {@code label,probability} per scored row, and
	 * "xgboost/result.txt" with the misclassified rows (label, prediction,
	 * probability, text).
	 *
	 * <p>Expected input columns: [0]=flag, [2]=category, [3]=text; a row is
	 * labeled "1" only when flag is "1" and category is "美食".
	 *
	 * @param termsmap tf-idf vocabulary
	 * @param booster  trained xgboost model
	 * @param path     tab-separated input file
	 * @throws IOException   on read/write failure
	 * @throws XGBoostError  on prediction failure
	 */
	public static void test(Map<String, Terms> termsmap, Booster booster,
			String path) throws IOException, XGBoostError {
		// try-with-resources closes all three streams even when prediction
		// throws — the original leaked them on any exception. Charsets are
		// pinned to UTF-8: the original used the platform default, which
		// breaks the "美食" comparison on non-UTF-8 machines.
		try (BufferedReader buffer = new BufferedReader(new InputStreamReader(
						new FileInputStream(path), StandardCharsets.UTF_8));
				BufferedWriter bufferwrite = new BufferedWriter(
						new OutputStreamWriter(new FileOutputStream(
								"xgboost/merge.txt"), StandardCharsets.UTF_8));
				BufferedWriter bufferwriteresult = new BufferedWriter(
						new OutputStreamWriter(new FileOutputStream(
								"xgboost/result.txt"), StandardCharsets.UTF_8))) {

			String line;
			while ((line = buffer.readLine()) != null) {
				String[] fields = line.split("\t");
				if (fields.length < 4) {
					// Malformed row: the original threw
					// ArrayIndexOutOfBoundsException here.
					continue;
				}
				String label = ("1".equalsIgnoreCase(fields[0])
						&& "美食".equalsIgnoreCase(fields[2])) ? "1" : "0";
				String content = fields[3];
				double p = DataLoader.getClassification(booster, content, termsmap);
				// p == 0 means "no usable features" — skip those rows.
				if (p > 0) {
					bufferwrite.write(label + "," + p + "\n");
					String predicted = p > 0.5 ? "1" : "0";
					if (!label.equals(predicted)) {
						bufferwriteresult.write(label + "\t" + predicted + "\t"
								+ p + "\t" + fields[3] + "\n");
					}
				}
			}
		}
	}
}
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: