您的位置:首页 > 编程语言

LDA(Linear Discriminant Analysis)线性判别分析算法代码

2014-08-06 11:49 561 查看
做文本聚类分析时,采用 PCA 等降维方法效果都不好,于是决定采用有监督的学习算法 LDA。在网上找到了一份代码,但看不懂它是如何降维的,于是自己改写,代码如下:

package lda;

import java.util.ArrayList;

import java.util.Collections;

import java.util.HashMap;

import java.util.Iterator;

import java.util.Map;

import org.apache.commons.math3.linear.EigenDecomposition;

import org.apache.commons.math3.linear.LUDecomposition;

import org.apache.commons.math3.linear.MatrixUtils;

import org.apache.commons.math3.linear.RealMatrix;

import org.apache.commons.math3.linear.RealVector;

import Jama.Matrix;

/**
 * Linear discriminant analysis (LDA) used both as a supervised dimensionality
 * reduction (projection onto the two leading discriminant directions) and as a
 * classifier (linear discriminant functions per class).
 *
 * <p>Training happens in {@link #LDA(double[][], int[], boolean)}. The
 * projection axes are the eigenvectors of {@code Sw^-1 * Sb}, where {@code Sw}
 * is the within-class scatter matrix and {@code Sb} the between-class scatter
 * matrix. The classifier uses the inverse pooled covariance {@code (Sw / n)^-1}
 * together with the per-class means and priors.
 *
 * <p>Instances are not thread-safe: {@link #getTop2Vector()} and
 * {@link #getxydot(double[])} update internal state.
 */
public class LDA
{
    /** Per-class feature means; row order follows {@link #groupList}. */
    private double[][] groupRataTengah;

    /** Inverse of the pooled within-class covariance matrix, {@code (Sw / n)^-1}. */
    private double[][] kovarianGlobal;

    /** Prior probability of each class; order follows {@link #groupList}. */
    private double[] probabilitas;

    /** Distinct class labels in order of first appearance in the training labels. */
    private ArrayList<Integer> groupList = new ArrayList<Integer>();

    /** Last group predicted by {@code test}. */
    static int hasil;

    /** Last discriminant values computed by {@code test}; f3 is NaN when there is no third class. */
    static double f1, f2, f3;

    /** Real parts of the eigenvalues of {@code Sw^-1 * Sb}; parallel to {@link #eigenvectors}. */
    private double[] eigenvalues = new double[0];

    /** Eigenvectors of {@code Sw^-1 * Sb}; parallel to {@link #eigenvalues}. */
    private RealVector[] eigenvectors = new RealVector[0];

    /** The (up to) two eigenvectors with the largest eigenvalues; entries are null until selected. */
    private RealVector[] _top2vec = new RealVector[2];

    /**
     * Creates an untrained instance. Every query method throws
     * {@link IllegalStateException} on such an instance, except
     * {@link #getTop2Vector()}, which returns an array of two nulls.
     */
    public LDA()
    {
    }

    /**
     * Trains the model.
     *
     * <p>Prints every eigenvalue/eigenvector pair of {@code Sw^-1 * Sb} to
     * standard output. Inverting {@code Sw} fails with
     * {@code org.apache.commons.math3.linear.SingularMatrixException} when the
     * within-class scatter is singular (for example, more features than
     * samples, or a constant feature).
     *
     * @param d training samples, one row per sample; all rows must have the same length
     * @param g class label of each sample, aligned with the rows of {@code d}
     * @param p if {@code true}, class priors are estimated from the class
     *          frequencies in {@code g}; otherwise all classes get equal priors
     * @throws IllegalArgumentException if an argument is null, the lengths of
     *         {@code d} and {@code g} differ, {@code d} is empty, or the rows
     *         of {@code d} are null, empty or of unequal length
     */
    public LDA(double[][] d, int[] g, boolean p)
    {
        if (d == null || g == null)
        {
            throw new IllegalArgumentException("data and group labels must not be null");
        }
        if (d.length != g.length)
        {
            throw new IllegalArgumentException(
                    "data has " + d.length + " rows but " + g.length + " group labels were given");
        }
        if (d.length == 0 || d[0] == null || d[0].length == 0)
        {
            throw new IllegalArgumentException("data must contain at least one sample with at least one feature");
        }

        final int n = d.length;
        final int dim = d[0].length;

        // Defensive copy so later changes to the caller's arrays cannot affect the model.
        double[][] data = new double[n][];
        for (int i = 0; i < n; i++)
        {
            if (d[i] == null || d[i].length != dim)
            {
                throw new IllegalArgumentException("row " + i + " does not have " + dim + " features");
            }
            data[i] = d[i].clone();
        }

        // Collect the distinct class labels in order of first appearance.
        for (int i = 0; i < n; i++)
        {
            Integer label = Integer.valueOf(g[i]);
            if (!groupList.contains(label))
            {
                groupList.add(label);
            }
        }
        final int classCount = groupList.size();

        // Split the samples into one subset per class.
        ArrayList<ArrayList<double[]>> subset = new ArrayList<ArrayList<double[]>>();
        for (int k = 0; k < classCount; k++)
        {
            subset.add(new ArrayList<double[]>());
        }
        for (int i = 0; i < n; i++)
        {
            subset.get(groupList.indexOf(Integer.valueOf(g[i]))).add(data[i]);
        }

        // Mean of every feature within every class.
        groupRataTengah = new double[classCount][dim];
        for (int k = 0; k < classCount; k++)
        {
            for (int j = 0; j < dim; j++)
            {
                groupRataTengah[k][j] = getGroupMean(j, subset.get(k));
            }
        }

        // Mean of every feature over all samples.
        double[] rataTengah = new double[dim];
        for (int j = 0; j < dim; j++)
        {
            rataTengah[j] = getGlobalMean(j, data);
        }

        // Between-class scatter: sum_k (n_k / n) * (mu_k - mu)(mu_k - mu)^T.
        double[][] SB = new double[dim][dim];
        for (int k = 0; k < classCount; k++)
        {
            int t = subset.get(k).size();
            for (int i = 0; i < dim; i++)
            {
                double di = groupRataTengah[k][i] - rataTengah[i];
                for (int j = 0; j < dim; j++)
                {
                    double dj = groupRataTengah[k][j] - rataTengah[j];
                    SB[i][j] = SB[i][j] + (t * (di * dj)) / n;
                }
            }
        }

        // Within-class scatter: sum_k sum_{x in class k} (mu_k - x)(mu_k - x)^T.
        double[][] SW = new double[dim][dim];
        for (int k = 0; k < classCount; k++)
        {
            ArrayList<double[]> members = subset.get(k);
            for (int l = 0; l < members.size(); l++)
            {
                double[] row = members.get(l);
                for (int i = 0; i < dim; i++)
                {
                    for (int j = 0; j < dim; j++)
                    {
                        SW[i][j] = SW[i][j]
                                + (groupRataTengah[k][i] - row[i]) * (groupRataTengah[k][j] - row[j]);
                    }
                }
            }
        }

        RealMatrix rswInverse = new LUDecomposition(MatrixUtils.createRealMatrix(SW)).getSolver().getInverse();
        RealMatrix rsb = MatrixUtils.createRealMatrix(SB);

        // Classifier state: the pooled covariance is Sw / n, so its inverse is n * Sw^-1.
        kovarianGlobal = rswInverse.scalarMultiply(n).getData();
        probabilitas = new double[classCount];
        for (int k = 0; k < classCount; k++)
        {
            probabilitas[k] = p ? (double) subset.get(k).size() / n : 1.0d / classCount;
        }

        // Projection state: eigen-decomposition of Sw^-1 * Sb.
        EigenDecomposition en = new EigenDecomposition(rswInverse.multiply(rsb));
        eigenvalues = en.getRealEigenvalues();
        eigenvectors = new RealVector[eigenvalues.length];
        for (int i = 0; i < eigenvalues.length; i++)
        {
            eigenvectors[i] = en.getEigenvector(i);
            System.out.println(eigenvalues[i] + ":");
            System.out.println(eigenvectors[i].toString());
        }
    }

    /**
     * Selects the two eigenvectors with the largest eigenvalues, remembers
     * them for {@link #getxydot(double[])} and prints the selected eigenvalues
     * to standard output.
     *
     * @return an array of length 2 holding the leading eigenvectors in order
     *         of decreasing eigenvalue; trailing entries are {@code null} when
     *         fewer than two eigenvectors are available
     */
    public RealVector[] getTop2Vector()
    {
        RealVector[] arrVec = new RealVector[2];
        int[] order = largestEigenvalueIndices(2);
        for (int j = 0; j < order.length; j++)
        {
            System.out.println(eigenvalues[order[j]]);
            arrVec[j] = eigenvectors[order[j]];
        }
        this._top2vec = arrVec;
        return arrVec;
    }

    /**
     * Projects a point onto the two leading discriminant directions.
     *
     * <p>If {@link #getTop2Vector()} has not been called yet, the directions
     * are selected here (without printing).
     *
     * @param in the point to project; must have as many components as the training data has features
     * @return the two-dimensional coordinates of the point; a coordinate whose
     *         direction does not exist (fewer than two features) is 0.0
     * @throws IllegalStateException if the model has not been trained
     */
    public double[] getxydot(double[] in)
    {
        if (eigenvectors.length == 0)
        {
            throw new IllegalStateException("LDA model has not been trained");
        }
        if (this._top2vec[0] == null)
        {
            int[] order = largestEigenvalueIndices(2);
            RealVector[] top = new RealVector[2];
            for (int j = 0; j < order.length; j++)
            {
                top[j] = eigenvectors[order[j]];
            }
            this._top2vec = top;
        }
        RealVector inv = MatrixUtils.createRealVector(in);
        double[] out = new double[2];
        for (int i = 0; i < 2; i++)
        {
            if (this._top2vec[i] != null)
            {
                out[i] = this._top2vec[i].dotProduct(inv);
            }
        }
        return out;
    }

    /**
     * Returns the indices of the {@code count} largest eigenvalues in order of
     * decreasing value. Ties keep the lower index first; fewer indices are
     * returned when fewer eigenvalues exist.
     */
    private int[] largestEigenvalueIndices(int count)
    {
        int available = Math.min(count, eigenvalues.length);
        int[] picked = new int[available];
        boolean[] used = new boolean[eigenvalues.length];
        for (int slot = 0; slot < available; slot++)
        {
            int best = -1;
            for (int i = 0; i < eigenvalues.length; i++)
            {
                if (!used[i] && (best < 0 || Double.compare(eigenvalues[i], eigenvalues[best]) > 0))
                {
                    best = i;
                }
            }
            used[best] = true;
            picked[slot] = best;
        }
        return picked;
    }

    /** Returns the mean of one feature column over the samples of a single class. */
    private double getGroupMean(int column, ArrayList<double[]> data)
    {
        double[] d = new double[data.size()];
        for (int i = 0; i < data.size(); i++)
        {
            d[i] = data.get(i)[column];
        }
        return getMean(d);
    }

    /** Returns the mean of one feature column over all samples. */
    private double getGlobalMean(int column, double data[][])
    {
        double[] d = new double[data.length];
        for (int i = 0; i < data.length; i++)
        {
            d[i] = data[i][column];
        }
        return getMean(d);
    }

    /**
     * Evaluates the linear discriminant function of every class for a point:
     * {@code f_i = mu_i * C^-1 * x^T - 0.5 * mu_i * C^-1 * mu_i^T + ln(p_i)},
     * where {@code mu_i} is the class mean, {@code C^-1} the inverse pooled
     * covariance and {@code p_i} the class prior.
     *
     * @param values the point to evaluate; must have exactly as many components as the training data has features
     * @return one discriminant value per class, in order of first appearance of the class label in the training labels
     * @throws IllegalStateException if the model has not been trained
     * @throws IllegalArgumentException if {@code values} is null or has the wrong length
     */
    public double[] getDiscriminantFunctionValues(double[] values)
    {
        if (groupRataTengah == null || kovarianGlobal == null || probabilitas == null)
        {
            throw new IllegalStateException("LDA model has not been trained");
        }
        if (values == null || values.length != kovarianGlobal.length)
        {
            throw new IllegalArgumentException("expected " + kovarianGlobal.length + " feature values");
        }
        double[] function = new double[groupList.size()];
        for (int i = 0; i < groupList.size(); i++)
        {
            // tmp = mu_i * C^-1 is shared by the linear and the quadratic term.
            double[] tmp = matrixMultiplication(groupRataTengah[i], kovarianGlobal);
            function[i] = matrixMultiplication(tmp, values)
                    - (.5d * matrixMultiplication(tmp, groupRataTengah[i]))
                    + Math.log(probabilitas[i]);
        }
        return function;
    }

    /**
     * Predicts the class of a point as the class with the largest discriminant value.
     *
     * @param values the point to classify; see {@link #getDiscriminantFunctionValues(double[])}
     * @return the label of the winning class, or -1 if no discriminant value
     *         exceeds negative infinity (for example, all values are NaN)
     */
    public int predict(double[] values)
    {
        int group = -1;
        double max = Double.NEGATIVE_INFINITY;
        double[] discr = this.getDiscriminantFunctionValues(values);
        for (int i = 0; i < discr.length; i++)
        {
            if (discr[i] > max)
            {
                max = discr[i];
                group = groupList.get(i);
            }
        }
        return group;
    }

    /**
     * Multiplies a row vector by a matrix: {@code c[j] = sum_i matrixA[i] * matrixB[i][j]}.
     * The matrix must have {@code matrixA.length} rows.
     */
    private double[] matrixMultiplication(double[] matrixA, double[][] matrixB)
    {
        int columns = matrixB.length == 0 ? 0 : matrixB[0].length;
        double c[] = new double[columns];
        for (int i = 0; i < matrixA.length; i++)
        {
            for (int j = 0; j < columns; j++)
            {
                c[j] += matrixA[i] * matrixB[i][j];
            }
        }
        return c;
    }

    /** Returns the dot product of two vectors, iterating over the length of {@code matrixA}. */
    private double matrixMultiplication(double[] matrixA, double[] matrixB)
    {
        double c = 0d;
        for (int i = 0; i < matrixA.length; i++)
        {
            c += matrixA[i] * matrixB[i];
        }
        return c;
    }

    /**
     * Returns the arithmetic mean of the given values.
     *
     * @param values the values to average
     * @return the mean, or {@link Double#NaN} if {@code values} is null or empty
     */
    public static double getMean(final double[] values)
    {
        if (values == null || values.length == 0)
        {
            return Double.NaN;
        }
        double mean = 0.0d;
        for (int index = 0; index < values.length; index++)
        {
            mean += values[index];
        }
        return mean / (double) values.length;
    }

    /**
     * Trains on a built-in sample set of seven two-feature points in two
     * classes, classifies the point {@code (a, b)}, prints the results and
     * stores them in {@link #hasil}, {@link #f1}, {@link #f2} and {@link #f3}.
     *
     * @param e unused; kept for signature compatibility
     * @param a first feature of the point to classify
     * @param b second feature of the point to classify
     * @param c unused: the built-in training data has only two features
     * @param d unused: the built-in training data has only two features
     */
    public static void test(extraksi_fitur e, double a, double b, double c, double d)
    {
        int[] group = { 1, 1, 1, 1, 2, 2, 2 };
        // 7 samples, 2 features each.
        double[][] data = {
            { 2.95, 6.63 },
            { 2.53, 7.79 },
            { 3.57, 5.65 },
            { 3.16, 5.47 },
            { 2.58, 4.46 },
            { 2.16, 6.22 },
            { 3.27, 3.52 }
        };

        LDA test = new LDA(data, group, true);
        double[] testData = { a, b };

        double[] values = test.getDiscriminantFunctionValues(testData);
        for (int i = 0; i < values.length; i++)
        {
            System.out.println("Discriminant function " + (i + 1) + ": " + values[i]);
        }

        hasil = test.predict(testData);
        System.out.println("Predicted group: " + hasil);

        f1 = values[0];
        f2 = values[1];
        // The sample set has only two classes, so there is no third discriminant value.
        f3 = values.length > 2 ? values[2] : Double.NaN;
    }

    /**
     * Demo: trains on seven three-feature samples in two classes and prints
     * the two-dimensional projection of one point.
     */
    public static void main(String[] args)
    {
        // 7 samples, 3 features each.
        double[][] data = {
            { 2.95, 6.63, 2.34 },
            { 2.53, 7.79, 2.56 },
            { 3.57, 5.65, 2.76 },
            { 3.16, 5.47, 2.36 },
            { 2.58, 4.46, 5.2 },
            { 2.16, 6.22, 5.4 },
            { 3.27, 3.52, 6 }
        };
        int[] group = { 1, 1, 1, 1, 2, 2, 2 };

        LDA lda = new LDA(data, group, true);
        lda.getTop2Vector();

        double[] tt = { 123.0, 23.0, 4 };
        double[] out = lda.getxydot(tt);
        for (int i = 0; i < out.length; i++)
        {
            System.out.println(out[i]);
        }
    }
}

试用效果良好,实验数据为200维矩阵。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: