
Probabilistic Language Models and Their Variants (5): A Java Implementation of LDA Gibbs Sampling

This series of posts introduces common probabilistic language models and their variants, summarizing PLSA, LDA, LDA's variant models, and the corresponding parameter inference methods. The planned contents are as follows:

Part 1: PLSA and the EM Algorithm

Part 2: LDA and Gibbs Sampling

Part 3: LDA Variants - Twitter LDA, TimeUserLDA, ATM, Labeled-LDA, MaxEnt-LDA, etc.

Part 4: A Categorized Summary (Bibliography) of Papers Based on LDA Variants

Part 5: A Java Implementation of LDA Gibbs Sampling

Part 5: A Java Implementation of LDA Gibbs Sampling

In the first two posts of this series we covered PLSA, LDA, and their parameter inference methods in detail, focusing on model formulation and derivation. A scholar once said that research should "reach the sky while standing on the ground": models and theory alone are not enough; they must be backed by solid code and experimental results on real data. This post therefore walks through a Java implementation of LDA Gibbs Sampling and presents the topic modeling results obtained by applying it to the newsgroup18828 document collection.

The project is available on GitHub: https://github.com/yangliuy/LDAGibbsSampling

1 Corpus Preprocessing

To do topic modeling with LDA, the text must first be preprocessed: tokenization, stop-word removal, stemming, noise-word filtering, dropping low-frequency terms, and so on. When the corpus is fairly large, stemming can be skipped. The text is then converted into a term-index representation, because the LDA implementation frequently needs to map between terms and their indices. The Documents class below implements this; it defines an inner Document class that represents a single document in the collection.


package liuyang.nlp.lda.main;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.com.Stopwords;

/**Class for corpus which consists of M documents
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail yangliuyx@gmail.com
 */
public class Documents {

    ArrayList<Document> docs;
    Map<String, Integer> termToIndexMap;
    ArrayList<String> indexToTermMap;
    Map<String, Integer> termCountMap;

    public Documents() {
        docs = new ArrayList<Document>();
        termToIndexMap = new HashMap<String, Integer>();
        indexToTermMap = new ArrayList<String>();
        termCountMap = new HashMap<String, Integer>();
    }

    public void readDocs(String docsPath) {
        for (File docFile : new File(docsPath).listFiles()) {
            Document doc = new Document(docFile.getAbsolutePath(), termToIndexMap, indexToTermMap, termCountMap);
            docs.add(doc);
        }
    }

    public static class Document {

        private String docName;
        int[] docWords;

        public Document(String docName, Map<String, Integer> termToIndexMap, ArrayList<String> indexToTermMap, Map<String, Integer> termCountMap) {
            this.docName = docName;
            //Read file and initialize word index array
            ArrayList<String> docLines = new ArrayList<String>();
            ArrayList<String> words = new ArrayList<String>();
            FileUtil.readLines(docName, docLines);
            for (String line : docLines) {
                FileUtil.tokenizeAndLowerCase(line, words);
            }
            //Remove stop words and noise words
            for (int i = 0; i < words.size(); i++) {
                if (Stopwords.isStopword(words.get(i)) || isNoiseWord(words.get(i))) {
                    words.remove(i);
                    i--;
                }
            }
            //Transfer word to index
            this.docWords = new int[words.size()];
            for (int i = 0; i < words.size(); i++) {
                String word = words.get(i);
                if (!termToIndexMap.containsKey(word)) {
                    int newIndex = termToIndexMap.size();
                    termToIndexMap.put(word, newIndex);
                    indexToTermMap.add(word);
                    termCountMap.put(word, 1);
                    docWords[i] = newIndex;
                } else {
                    docWords[i] = termToIndexMap.get(word);
                    termCountMap.put(word, termCountMap.get(word) + 1);
                }
            }
            words.clear();
        }

        public boolean isNoiseWord(String string) {
            string = string.toLowerCase().trim();
            // filter @xxx and URL
            if (string.matches(".*www\\..*") || string.matches(".*\\.com.*") || string.matches(".*http:.*")) {
                return true;
            }
            Pattern MY_PATTERN = Pattern.compile(".*[a-zA-Z]+.*");
            Matcher m = MY_PATTERN.matcher(string);
            return !m.matches();
        }
    }
}
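One caveat worth noting (my suggestion, not part of the original repository): File.listFiles() returns null when the given path does not exist or is not a directory, and the listing may also contain subdirectories. A slightly more defensive drop-in replacement for readDocs could be sketched like this:

public void readDocs(String docsPath) {
    File[] docFiles = new File(docsPath).listFiles();
    if (docFiles == null) { // path missing or not a directory
        System.err.println("Invalid corpus path: " + docsPath);
        return;
    }
    for (File docFile : docFiles) {
        if (!docFile.isFile()) continue; // skip subdirectories and other non-files
        docs.add(new Document(docFile.getAbsolutePath(),
                termToIndexMap, indexToTermMap, termCountMap));
    }
}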

2 LDA Gibbs Sampling

Once preprocessing is done, we can implement LDA Gibbs Sampling. First we define the required parameters. My implementation hard-codes default values and also supports overriding them from a configuration file; when the configuration file supplies a value, it takes precedence over the default. The overall flow of the algorithm is: initialize the model, run iterative inference that repeatedly resamples the topic assignments and updates the parameters to be estimated, and finally output the parameter estimates at convergence.
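For reference, the counts the sampler maintains map directly onto the standard collapsed Gibbs sampling equations (see Heinrich's "Parameter estimation for text analysis" [2]). Writing $n_{m,k}$ for the number of words in document $m$ assigned to topic $k$ and $n_{k,t}$ for the number of times term $t$ is assigned to topic $k$ (with $n_m$ and $n_k$ their row sums), the full conditional used to resample each topic label, and the point estimates for $\varphi$ and $\theta$, are:

$$p(z_i = k \mid \vec{z}_{\neg i}, \vec{w}) \;\propto\; \frac{n_{k,t}^{\neg i} + \beta}{n_k^{\neg i} + V\beta} \cdot \frac{n_{m,k}^{\neg i} + \alpha}{n_m^{\neg i} + K\alpha}$$

$$\hat{\varphi}_{k,t} = \frac{n_{k,t} + \beta}{n_k + V\beta}, \qquad \hat{\theta}_{m,k} = \frac{n_{m,k} + \alpha}{n_m + K\alpha}$$

The superscript $\neg i$ means the counts exclude the current token's own assignment; in the code below this corresponds to the decrements at the start of sampleTopicZ, and the two point estimates are exactly what updateEstimatedParameters computes.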

The class containing the main function and the configuration parameter parsing is as follows:


package liuyang.nlp.lda.main;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.ConstantConfig;
import liuyang.nlp.lda.conf.PathConfig;

/**Liu Yang's implementation of Gibbs Sampling of LDA
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail yangliuyx@gmail.com
 */
public class LdaGibbsSampling {

    public static class modelparameters {
        float alpha = 0.5f; //usual value is 50 / K
        float beta = 0.1f;  //usual value is 0.1
        int topicNum = 100;
        int iteration = 100;
        int saveStep = 10;
        int beginSaveIters = 50;
    }

    /**Get parameters from the configuration file. If the
     * configuration file contains a value, use it;
     * otherwise the default value in the program is used.
     * @param ldaparameters
     * @param parameterFile
     */
    private static void getParametersFromFile(modelparameters ldaparameters, String parameterFile) {
        ArrayList<String> paramLines = new ArrayList<String>();
        FileUtil.readLines(parameterFile, paramLines);
        for (String line : paramLines) {
            String[] lineParts = line.split("\t");
            switch (parameters.valueOf(lineParts[0])) {
            case alpha:
                ldaparameters.alpha = Float.valueOf(lineParts[1]);
                break;
            case beta:
                ldaparameters.beta = Float.valueOf(lineParts[1]);
                break;
            case topicNum:
                ldaparameters.topicNum = Integer.valueOf(lineParts[1]);
                break;
            case iteration:
                ldaparameters.iteration = Integer.valueOf(lineParts[1]);
                break;
            case saveStep:
                ldaparameters.saveStep = Integer.valueOf(lineParts[1]);
                break;
            case beginSaveIters:
                ldaparameters.beginSaveIters = Integer.valueOf(lineParts[1]);
                break;
            }
        }
    }

    public enum parameters {
        alpha, beta, topicNum, iteration, saveStep, beginSaveIters;
    }

    /**
     * @param args
     * @throws IOException
     */
    public static void main(String[] args) throws IOException {
        String originalDocsPath = PathConfig.ldaDocsPath;
        String resultPath = PathConfig.LdaResultsPath;
        String parameterFile = ConstantConfig.LDAPARAMETERFILE;

        modelparameters ldaparameters = new modelparameters();
        getParametersFromFile(ldaparameters, parameterFile);
        Documents docSet = new Documents();
        docSet.readDocs(originalDocsPath);
        System.out.println("wordMap size " + docSet.termToIndexMap.size());
        FileUtil.mkdir(new File(resultPath));
        LdaModel model = new LdaModel(ldaparameters);
        System.out.println("1 Initialize the model ...");
        model.initializeModel(docSet);
        System.out.println("2 Learning and Saving the model ...");
        model.inferenceModel(docSet);
        System.out.println("3 Output the final model ...");
        model.saveIteratedModel(ldaparameters.iteration, docSet);
        System.out.println("Done!");
    }
}

The LDA model class is implemented as follows:


package liuyang.nlp.lda.main;

import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.PathConfig;

/**Class for Lda model
 * @author yangliu
 * @blog http://blog.csdn.net/yangliuy
 * @mail yangliuyx@gmail.com
 */
public class LdaModel {

    int[][] doc;        //word index array
    int V, K, M;        //vocabulary size, topic number, document number
    int[][] z;          //topic label array
    float alpha;        //doc-topic dirichlet prior parameter
    float beta;         //topic-word dirichlet prior parameter
    int[][] nmk;        //given document m, count times of topic k. M*K
    int[][] nkt;        //given topic k, count times of term t. K*V
    int[] nmkSum;       //Sum for each row in nmk
    int[] nktSum;       //Sum for each row in nkt
    double[][] phi;     //Parameters for topic-word distribution K*V
    double[][] theta;   //Parameters for doc-topic distribution M*K
    int iterations;     //Times of iterations
    int saveStep;       //The number of iterations between two savings
    int beginSaveIters; //Begin save model at this iteration

    public LdaModel(LdaGibbsSampling.modelparameters modelparam) {
        alpha = modelparam.alpha;
        beta = modelparam.beta;
        iterations = modelparam.iteration;
        K = modelparam.topicNum;
        saveStep = modelparam.saveStep;
        beginSaveIters = modelparam.beginSaveIters;
    }

    public void initializeModel(Documents docSet) {
        M = docSet.docs.size();
        V = docSet.termToIndexMap.size();
        nmk = new int[M][K];
        nkt = new int[K][V];
        nmkSum = new int[M];
        nktSum = new int[K];
        phi = new double[K][V];
        theta = new double[M][K];

        //initialize documents index array
        doc = new int[M][];
        for (int m = 0; m < M; m++) {
            //Notice the limit of memory
            int N = docSet.docs.get(m).docWords.length;
            doc[m] = new int[N];
            for (int n = 0; n < N; n++) {
                doc[m][n] = docSet.docs.get(m).docWords[n];
            }
        }

        //initialize topic label z for each word
        z = new int[M][];
        for (int m = 0; m < M; m++) {
            int N = docSet.docs.get(m).docWords.length;
            z[m] = new int[N];
            for (int n = 0; n < N; n++) {
                int initTopic = (int) (Math.random() * K);// From 0 to K - 1
                z[m][n] = initTopic;
                //number of words in doc m assigned to topic initTopic add 1
                nmk[m][initTopic]++;
                //number of terms doc[m][n] assigned to topic initTopic add 1
                nkt[initTopic][doc[m][n]]++;
                //total number of words assigned to topic initTopic add 1
                nktSum[initTopic]++;
            }
            //total number of words in document m is N
            nmkSum[m] = N;
        }
    }

    public void inferenceModel(Documents docSet) throws IOException {
        if (iterations < saveStep + beginSaveIters) {
            System.err.println("Error: the number of iterations should be larger than " + (saveStep + beginSaveIters));
            System.exit(0);
        }
        for (int i = 0; i < iterations; i++) {
            System.out.println("Iteration " + i);
            if ((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)) {
                //Saving the model
                System.out.println("Saving model at iteration " + i + " ... ");
                //Firstly update parameters
                updateEstimatedParameters();
                //Secondly print model variables
                saveIteratedModel(i, docSet);
            }
            //Use Gibbs Sampling to update z[][]
            for (int m = 0; m < M; m++) {
                int N = docSet.docs.get(m).docWords.length;
                for (int n = 0; n < N; n++) {
                    // Sample from p(z_i|z_-i, w)
                    int newTopic = sampleTopicZ(m, n);
                    z[m][n] = newTopic;
                }
            }
        }
    }

    private void updateEstimatedParameters() {
        for (int k = 0; k < K; k++) {
            for (int t = 0; t < V; t++) {
                phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
            }
        }
        for (int m = 0; m < M; m++) {
            for (int k = 0; k < K; k++) {
                theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
            }
        }
    }

    private int sampleTopicZ(int m, int n) {
        // Sample from p(z_i|z_-i, w) using the Gibbs update rule
        //Remove topic label for w_{m,n}
        int oldTopic = z[m][n];
        nmk[m][oldTopic]--;
        nkt[oldTopic][doc[m][n]]--;
        nmkSum[m]--;
        nktSum[oldTopic]--;

        //Compute p(z_i = k|z_-i, w)
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta) * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
        }

        //Sample a new topic label for w_{m, n} like roulette
        //Compute cumulated probability for p
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        double u = Math.random() * p[K - 1]; //p[] is unnormalised
        int newTopic;
        for (newTopic = 0; newTopic < K; newTopic++) {
            if (u < p[newTopic]) {
                break;
            }
        }

        //Add new topic label for w_{m, n}
        nmk[m][newTopic]++;
        nkt[newTopic][doc[m][n]]++;
        nmkSum[m]++;
        nktSum[newTopic]++;
        return newTopic;
    }

    public void saveIteratedModel(int iters, Documents docSet) throws IOException {
        //lda.params lda.phi lda.theta lda.tassign lda.twords
        //lda.params
        String resPath = PathConfig.LdaResultsPath;
        String modelName = "lda_" + iters;
        ArrayList<String> lines = new ArrayList<String>();
        lines.add("alpha = " + alpha);
        lines.add("beta = " + beta);
        lines.add("topicNum = " + K);
        lines.add("docNum = " + M);
        lines.add("termNum = " + V);
        lines.add("iterations = " + iterations);
        lines.add("saveStep = " + saveStep);
        lines.add("beginSaveIters = " + beginSaveIters);
        FileUtil.writeLines(resPath + modelName + ".params", lines);

        //lda.phi K*V
        BufferedWriter writer = new BufferedWriter(new FileWriter(resPath + modelName + ".phi"));
        for (int i = 0; i < K; i++) {
            for (int j = 0; j < V; j++) {
                writer.write(phi[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.theta M*K
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".theta"));
        for (int i = 0; i < M; i++) {
            for (int j = 0; j < K; j++) {
                writer.write(theta[i][j] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.tassign
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".tassign"));
        for (int m = 0; m < M; m++) {
            for (int n = 0; n < doc[m].length; n++) {
                writer.write(doc[m][n] + ":" + z[m][n] + "\t");
            }
            writer.write("\n");
        }
        writer.close();

        //lda.twords phi[][] K*V
        writer = new BufferedWriter(new FileWriter(resPath + modelName + ".twords"));
        int topNum = 20; //Find the top 20 topic words in each topic
        for (int i = 0; i < K; i++) {
            List<Integer> tWordsIndexArray = new ArrayList<Integer>();
            for (int j = 0; j < V; j++) {
                tWordsIndexArray.add(j);
            }
            Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[i]));
            writer.write("topic " + i + "\t:\t");
            for (int t = 0; t < topNum; t++) {
                writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t)) + " " + phi[i][tWordsIndexArray.get(t)] + "\t");
            }
            writer.write("\n");
        }
        writer.close();
    }

    public class TwordsComparable implements Comparator<Integer> {

        public double[] sortProb; // Store probability of each word in topic k

        public TwordsComparable(double[] sortProb) {
            this.sortProb = sortProb;
        }

        @Override
        public int compare(Integer o1, Integer o2) {
            //Sort topic word index according to the probability of each word in topic k
            if (sortProb[o1] > sortProb[o2]) return -1;
            else if (sortProb[o1] < sortProb[o2]) return 1;
            else return 0;
        }
    }
}

Implementation details are covered in the comments in the code; once you understand the LDA Gibbs sampling algorithm, the code above is easy to follow. In fact, excluding the I/O and parameter-parsing code, standard LDA Gibbs sampling fits in fewer than 200 lines. Of course, there is plenty of room for optimization and variation.
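One easy refinement along those lines (my suggestion, not part of the original code): the implementation draws randomness from bare Math.random() calls, so no two runs are repeatable. Routing all draws through a single seeded java.util.Random makes experiments reproducible, which helps when comparing parameter settings:

import java.util.Random;

public class ReproducibleSampling {

    // One shared generator with a fixed seed makes Gibbs runs repeatable.
    private static final Random RNG = new Random(47);

    // Drop-in replacement for (int)(Math.random() * K) in initializeModel
    static int randomTopic(int K) {
        return RNG.nextInt(K); // uniform on 0 .. K - 1
    }

    // Drop-in replacement for Math.random() * p[K - 1] in sampleTopicZ
    static double roulette(double[] cumulativeP) {
        return RNG.nextDouble() * cumulativeP[cumulativeP.length - 1];
    }
}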

The source files under the com and conf directories contain the shared utility functions and the configuration classes, respectively. The complete Java project is on GitHub: https://github.com/yangliuy/LDAGibbsSampling

3 Topic Analysis of the Newsgroup 18828 Corpus with LDA Gibbs Sampling

Below are the results of applying the LDA Gibbs Sampling implementation above to the Newsgroup 18828 corpus for topic analysis. The data used in my experiments has been uploaded to GitHub, so interested readers can download the project and run it directly. I randomly selected 9 directories from the Newsgroup 18828 corpus, picked one document from each, placed them under data\LdaOriginalDocs, and set the model parameters as follows:


alpha 0.5
beta 0.1
topicNum 10
iteration 100
saveStep 10
beginSaveIters 80

That is, alpha is set to 0.5 and beta to 0.1, the number of topics is 10, and the sampler runs for 100 iterations, saving the model every 10 iterations starting from iteration 80. (Note that getParametersFromFile splits each line on a tab character, so the key and value in the parameter file must be tab-separated.)

After 100 Gibbs sampling iterations, the program outputs the top topic words and their corresponding probabilities for each of the 10 topics:

[Figure: screenshot of the top topic words and their probabilities for the 10 topics, taken from the .twords output]
We can see that even though this is unsupervised learning, the topic words LDA discovers make a lot of sense: for example, topic 5 is about religion, topic 6 about astronomy, and topic 7 about computers. The program's output also includes the model parameters in a .params file, the topic-word distribution phi in a .phi file, the doc-topic distribution theta in a .theta file, and the topic label assigned to each word of each document in a .tassign file. Interested readers can download the complete project from GitHub https://github.com/yangliuy/LDAGibbsSampling and run topic analysis on other datasets of their own.
This is a first-cut implementation; if you find any problems or bugs, please get in touch, and I will fix them on GitHub as soon as possible.
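If you want to post-process the saved distributions, for example to cluster documents by their topic mixtures, a minimal sketch for loading a saved .theta file back into memory, assuming the tab-separated layout written by saveIteratedModel, might look like this (ThetaReader is a hypothetical helper, not part of the repository):

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class ThetaReader {
    // Reads an M x K doc-topic matrix written by saveIteratedModel,
    // one tab-separated row of probabilities per document.
    public static double[][] readTheta(String path) throws IOException {
        List<double[]> rows = new ArrayList<double[]>();
        BufferedReader reader = new BufferedReader(new FileReader(path));
        String line;
        while ((line = reader.readLine()) != null) {
            String[] parts = line.trim().split("\t");
            if (parts.length == 0 || parts[0].isEmpty()) continue; // skip blank lines
            double[] row = new double[parts.length];
            for (int k = 0; k < parts.length; k++) {
                row[k] = Double.parseDouble(parts[k]);
            }
            rows.add(row);
        }
        reader.close();
        return rows.toArray(new double[rows.size()][]);
    }
}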

4 References

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.

[2] Gregor Heinrich. Parameter estimation for text analysis. Technical report, 2004.

[3] Wang Yi. Distributed Gibbs Sampling of Latent Topic Models: The Gritty Details. Technical report, 2005.

[4] Wayne Xin Zhao. Note for pLSA and LDA. Technical report, 2011.

[5] Freddy Chong Tat Chua. Dimensionality reduction and clustering of text documents. Technical report, 2009.

[6] JGibbLDA, http://jgibblda.sourceforge.net/

[7] David M. Blei, Andrew Y. Ng, and Michael I. Jordan. Latent Dirichlet Allocation. Journal of Machine Learning Research 3 (March 2003), 993-1022.