
Extracting topics from Chinese documents with Spark's LDA algorithm

2017-07-22 08:46
This article describes an approach to extracting document topics with the LDA algorithm in Spark ML; it does not cover the LDA algorithm itself. For the algorithm, see http://www.aboutyun.com/thread-20130-1-1.html.

Clustering Chinese text and extracting topics with LDA roughly involves the following steps (a compact sketch chaining these stages as a Spark ML Pipeline follows the list):

1. Segment the Chinese text with a Chinese word segmenter; here the open-source IK Analyzer is used.

2. Remove stop words from the segmented word list to produce a filtered word list.

3. Convert each document into a vector with a document-to-vector tool (CountVectorizer).

4. Run LDA on the vectors; once training finishes, extract the topic details and the topic distribution of each document.
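
Steps 2-4 map directly onto standard Spark ML stages (StopWordsRemover, CountVectorizer, LDA). The full code below fits each stage by hand; as a rough alternative sketch only, assuming Spark 2.x and the same column names used later in the article, the stages could also be chained with a Pipeline:

import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

// Chain stop-word removal, count vectorization and LDA into one Pipeline.
// "words" is assumed to be an array<string> column of segmented tokens.
StopWordsRemover remover = new StopWordsRemover()
        .setInputCol("words").setOutputCol("filtered");
CountVectorizer cv = new CountVectorizer()
        .setInputCol("filtered").setOutputCol("features");
LDA lda = new LDA().setK(5).setMaxIter(20);

Pipeline pipeline = new Pipeline()
        .setStages(new PipelineStage[] { remover, cv, lda });
PipelineModel model = pipeline.fit(documentDF);      // documentDF: Dataset<Row> with a "words" column
Dataset<Row> result = model.transform(documentDF);   // adds the "topicDistribution" column

Note that with a Pipeline, the CountVectorizer vocabulary needed later to translate term indices back into words has to be fetched from the fitted stage, e.g. ((CountVectorizerModel) model.stages()[1]).vocabulary().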

The complete code is as follows:

import java.io.StringReader;

// IK Analyzer core classes (package may differ slightly by IK version)
import org.wltea.analyzer.core.IKSegmenter;
import org.wltea.analyzer.core.Lexeme;

public class IkAnalyzerTool {

    // Segments a line of Chinese text with IK and returns the tokens joined by single spaces.
    public String call(String line) throws Exception {
        StringReader sr = new StringReader(line);
        // The second argument enables IK's smart segmentation mode.
        IKSegmenter ik = new IKSegmenter(sr, true);
        Lexeme lex = null;
        StringBuffer sb = new StringBuffer();
        while ((lex = ik.next()) != null) {
            sb.append(lex.getLexemeText());
            sb.append(" ");
        }
        return sb.toString();
    }

    public static void main(String[] args) throws Exception {
        IkAnalyzerTool a = new IkAnalyzerTool();
        System.out.println(a.call("我是中国人"));
    }
}
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.ml.clustering.LDA;
import org.apache.spark.ml.clustering.LDAModel;
import org.apache.spark.ml.feature.CountVectorizer;
import org.apache.spark.ml.feature.CountVectorizerModel;
import org.apache.spark.ml.feature.StopWordsRemover;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import scala.Tuple2;
import scala.collection.Iterator;
import scala.collection.mutable.WrappedArray;

import com.googsoft.spark.ml.ik.IkAnalyzerTool;

public class MyCluster {

    public static void main(String[] args) {

        // Initialize the Spark session
        SparkSession spark = SparkSession
                .builder()
                .appName("mylda")
                .getOrCreate();

        // Load the raw documents
        JavaRDD<Tuple2<String, String>> files = spark.sparkContext()
                .wholeTextFiles("hdfs://mycluster/ml/edata", 1).toJavaRDD();
        List<Row> rows = files.map(new Function<Tuple2<String, String>, Row>() {
            @Override
            public Row call(Tuple2<String, String> v1) throws Exception {
                // Segment each document with IK and split the result into a word list
                IkAnalyzerTool it = new IkAnalyzerTool();
                return RowFactory.create(v1._1, Arrays.asList(it.call(v1._2).split(" ")));
            }
        }).collect();
        StructType schema = new StructType(new StructField[] {
                new StructField("fpath", DataTypes.StringType, false, Metadata.empty()),
                new StructField("words", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
        });
        Dataset<Row> documentDF = spark.createDataFrame(rows, schema);

        // Convert the file path into a numeric document id
        StringIndexer indexer = new StringIndexer()
                .setInputCol("fpath")
                .setOutputCol("docid");
        Dataset<Row> indexed = indexer.fit(documentDF).transform(documentDF);

        // Filter out stop words to improve accuracy
        String[] stopWords = (String[]) spark.read().textFile("hdfs://mycluster/ml/stopwords/chinese.txt").collect();
        StopWordsRemover remover = new StopWordsRemover()
                .setInputCol("words")
                .setOutputCol("filtered")
                .setStopWords(stopWords);
        Dataset<Row> filtered = remover.transform(indexed);

        // Convert the filtered word lists into term-count vectors with CountVectorizer
        CountVectorizer cv = new CountVectorizer().setInputCol("filtered").setOutputCol("features");
        CountVectorizerModel cvmodel = cv.fit(filtered);
        Dataset<Row> cvResult = cvmodel.transform(filtered);
        // Keep the vocabulary built during vectorization, to map term indices back to words
        final String[] vocabulary = cvmodel.vocabulary();

        // Train the LDA model to extract the topics
        LDA lda = new LDA().setK(5).setMaxIter(20);
        LDAModel ldaModel = lda.fit(cvResult);
        double ll = ldaModel.logLikelihood(cvResult);
        double lp = ldaModel.logPerplexity(cvResult);
        System.out.println("The lower bound on the log likelihood of the entire corpus: " + ll);
        // Perplexity is the usual evaluation metric for an LDA model: the lower, the better
        System.out.println("The upper bound on perplexity: " + lp);

        // Map each topic's term indices back to the Chinese terms
        JavaRDD<Row> topics = ldaModel.describeTopics(30).toJavaRDD();
        List<Row> t1 = topics.map(new Function<Row, Row>() {
            @Override
            public Row call(Row row) throws Exception {
                int topic = row.getAs(0);
                WrappedArray<Integer> terms = row.getAs(1);
                List<String> termsed = new ArrayList<String>();
                Iterator<Integer> it = terms.iterator();
                while (it.hasNext()) {
                    int indice = it.next();
                    termsed.add(vocabulary[indice]);
                }
                WrappedArray<Double> termWeights = row.getAs(2);
                return RowFactory.create(topic, termsed.toArray(), termWeights);
            }
        }).collect();
        StructType topicschema = new StructType(new StructField[] {
                new StructField("topic", DataTypes.IntegerType, false, Metadata.empty()),
                new StructField("terms", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()),
                new StructField("termWeights", DataTypes.createArrayType(DataTypes.DoubleType), false, Metadata.empty())
        });
        Dataset<Row> topicdatas = spark.createDataFrame(t1, topicschema);
        topicdatas.show(false);

        // Per-document topic distribution, written out as JSON
        Dataset<Row> transformed = ldaModel.transform(cvResult);
        Dataset<Row> finalset = transformed.select("docid", "topicDistribution");
        finalset.write().json("hdfs://mycluster/ml/result");
        spark.stop();
    }
}
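
If you also want the single most likely topic per document, one option is to pick the index of the largest value in the topicDistribution vector before writing the result out. A minimal sketch, assuming Spark 2.x where org.apache.spark.ml.linalg.Vector exposes argmax(), to be added inside main just before finalset.write():

import org.apache.spark.ml.linalg.Vector;

// For each document, keep the docid and the index of the topic with the
// highest probability in its topicDistribution vector.
List<Row> dominant = finalset.toJavaRDD().map(new Function<Row, Row>() {
    @Override
    public Row call(Row row) throws Exception {
        double docid = row.getAs("docid");            // StringIndexer produces a double index
        Vector dist = row.getAs("topicDistribution");
        return RowFactory.create(docid, dist.argmax());
    }
}).collect();

Each resulting Row pairs the numeric document id with its dominant topic index; it could then be wrapped in a DataFrame with a schema built the same way as the ones above.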