LDA(Linear Discriminant Analysis,线性判别分析)算法代码
2014-08-06 11:49
561 查看
做文本聚类分析时,采用 PCA 等方法降维效果都不好,于是决定采用有监督的学习算法 LDA。在网上找到一份代码,但看不懂它是如何降维的,于是自己改写,代码如下:
package lda;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import Jama.Matrix;
/**
 * Linear discriminant analysis (LDA).
 *
 * <p>The training constructor computes, from labelled samples, the per-class means,
 * the class priors, the inverse pooled within-class covariance, and the
 * eigen-decomposition of {@code inverse(S_W) * S_B}. The model can then
 * <ul>
 * <li>project a sample onto the two leading discriminant directions
 * ({@link #getTop2Vector()}, {@link #getxydot(double[])}), and</li>
 * <li>classify a sample with linear discriminant functions
 * ({@link #getDiscriminantFunctionValues(double[])}, {@link #predict(double[])}).</li>
 * </ul>
 *
 * <p>The training constructor and {@link #getTop2Vector()} print diagnostic output to
 * {@code System.out}.
 */
public class LDA
{
    /** Per-class feature means, indexed [class][feature]. */
    private double[][] groupRataTengah;
    /** Inverse of the pooled within-class covariance matrix (S_W / sample count). */
    private double[][] kovarianGlobal;
    /** Prior probability of each class: class size / total sample count. */
    private double[] probabilitas;
    /** Distinct class labels in order of first appearance; the position is the class index. */
    private ArrayList<Integer> groupList = new ArrayList<Integer>();
    /** Group predicted by the last call to {@link #test}. */
    static int hasil;
    /** Discriminant values recorded by the last call to {@link #test}. */
    static double f1, f2, f3;
    /** Real parts of the eigenvalues of inverse(S_W) * S_B; parallel to {@link #eigenVectors}. */
    private double[] eigenValues = new double[0];
    /** Eigenvectors of inverse(S_W) * S_B; parallel to {@link #eigenValues}. */
    private RealVector[] eigenVectors = new RealVector[0];
    /** The two eigenvectors with the largest eigenvalues, once selected. */
    private RealVector[] _top2vec = new RealVector[2];

    /** Creates an untrained model. */
    public LDA()
    {
    }

    /**
     * Trains the model.
     *
     * <p>If S_W is singular, the matrix inversion performed by the linear-algebra
     * library fails with that library's own exception.
     *
     * @param d sample matrix, one row per sample and one column per feature; it is copied
     * @param g class label of each sample, aligned with the rows of {@code d}
     * @param p not read by this implementation; kept so existing callers still compile
     * @throws IllegalArgumentException if an argument is null, the lengths differ, there
     *         are no samples or no features, or the rows do not all have the same length
     */
    public LDA(double[][] d, int[] g, boolean p)
    {
        if (d == null || g == null)
        {
            throw new IllegalArgumentException("data and group arrays must not be null");
        }
        if (d.length != g.length)
        {
            throw new IllegalArgumentException(
                    "data has " + d.length + " rows but group has " + g.length + " labels");
        }
        if (d.length == 0 || d[0] == null || d[0].length == 0)
        {
            throw new IllegalArgumentException(
                    "data must contain at least one sample with at least one feature");
        }
        int samples = d.length;
        int features = d[0].length;

        // Defensive copy so later changes to the caller's arrays cannot affect the model.
        double[][] data = new double[samples][];
        for (int i = 0; i < samples; i++)
        {
            if (d[i] == null || d[i].length != features)
            {
                throw new IllegalArgumentException(
                        "row " + i + " does not have " + features + " features");
            }
            data[i] = d[i].clone();
        }

        // Collect the distinct class labels in order of first appearance.
        for (int i = 0; i < samples; i++)
        {
            if (!groupList.contains(g[i]))
            {
                groupList.add(g[i]);
            }
        }
        int classes = groupList.size();

        // Split the samples into one subset per class.
        ArrayList<ArrayList<double[]>> subset = new ArrayList<ArrayList<double[]>>(classes);
        for (int k = 0; k < classes; k++)
        {
            int label = groupList.get(k).intValue();
            ArrayList<double[]> members = new ArrayList<double[]>();
            for (int i = 0; i < samples; i++)
            {
                if (g[i] == label)
                {
                    members.add(data[i]);
                }
            }
            subset.add(members);
        }

        // Per-class means and priors.
        groupRataTengah = new double[classes][features];
        probabilitas = new double[classes];
        for (int k = 0; k < classes; k++)
        {
            ArrayList<double[]> members = subset.get(k);
            probabilitas[k] = members.size() / (double) samples;
            for (int j = 0; j < features; j++)
            {
                groupRataTengah[k][j] = getGroupMean(j, members);
            }
        }

        // Global mean of every feature over all samples.
        double[] rataTengah = new double[features];
        for (int j = 0; j < features; j++)
        {
            rataTengah[j] = getGlobalMean(j, data);
        }

        // Between-class scatter S_B (weighted by class share) and within-class scatter S_W.
        double[][] between = new double[features][features];
        double[][] within = new double[features][features];
        for (int k = 0; k < classes; k++)
        {
            ArrayList<double[]> members = subset.get(k);
            double[] mean = groupRataTengah[k];
            double weight = members.size() / (double) samples;
            for (int i = 0; i < features; i++)
            {
                for (int j = 0; j < features; j++)
                {
                    between[i][j] += weight * (mean[i] - rataTengah[i]) * (mean[j] - rataTengah[j]);
                }
            }
            for (int l = 0; l < members.size(); l++)
            {
                double[] el = members.get(l);
                for (int i = 0; i < features; i++)
                {
                    for (int j = 0; j < features; j++)
                    {
                        within[i][j] += (mean[i] - el[i]) * (mean[j] - el[j]);
                    }
                }
            }
        }

        RealMatrix rswInverse =
                new LUDecomposition(MatrixUtils.createRealMatrix(within)).getSolver().getInverse();
        // Pooled covariance is S_W / samples, so its inverse is samples * inverse(S_W).
        kovarianGlobal = rswInverse.scalarMultiply(samples).getData();

        RealMatrix r = rswInverse.multiply(MatrixUtils.createRealMatrix(between));
        EigenDecomposition en = new EigenDecomposition(r);
        eigenValues = en.getRealEigenvalues();
        eigenVectors = new RealVector[eigenValues.length];
        for (int i = 0; i < eigenValues.length; i++)
        {
            // Stored by index rather than keyed by eigenvalue, so equal eigenvalues
            // do not overwrite each other.
            eigenVectors[i] = en.getEigenvector(i);
            System.out.println(eigenValues[i] + ":");
            System.out.println(eigenVectors[i].toString());
        }
    }

    /**
     * Selects the two eigenvectors with the largest eigenvalues, prints those eigenvalues,
     * and remembers the vectors for {@link #getxydot(double[])}.
     *
     * @return array of length 2, largest eigenvalue first; an entry is {@code null} when
     *         fewer eigenvectors are available (for example on an untrained model)
     */
    public RealVector[] getTop2Vector()
    {
        RealVector[] arrVec = new RealVector[2];
        int[] best = topTwoIndices();
        for (int j = 0; j < best.length; j++)
        {
            if (best[j] >= 0)
            {
                System.out.println(eigenValues[best[j]]);
                arrVec[j] = eigenVectors[best[j]];
            }
        }
        this._top2vec = arrVec;
        return arrVec;
    }

    /**
     * Projects a sample onto the two leading eigenvectors, giving 2-D plane coordinates.
     * The vectors are selected on demand if {@link #getTop2Vector()} was not called yet.
     *
     * @param in sample vector
     * @return the two dot products of {@code in} with the leading eigenvectors
     * @throws IllegalStateException if fewer than two eigenvectors are available
     */
    public double[] getxydot(double[] in)
    {
        if (this._top2vec[0] == null && this._top2vec[1] == null)
        {
            int[] best = topTwoIndices();
            for (int j = 0; j < best.length; j++)
            {
                if (best[j] >= 0)
                {
                    this._top2vec[j] = eigenVectors[best[j]];
                }
            }
        }
        RealVector inv = MatrixUtils.createRealVector(in);
        double[] out = new double[2];
        for (int i = 0; i < 2; i++)
        {
            if (this._top2vec[i] == null)
            {
                throw new IllegalStateException(
                        "fewer than two eigenvectors are available; train the model on at least two features");
            }
            out[i] = this._top2vec[i].dotProduct(inv);
        }
        return out;
    }

    /** Returns the indices of the largest and second-largest eigenvalue, or -1 where absent. */
    private int[] topTwoIndices()
    {
        int first = -1;
        int second = -1;
        for (int i = 0; i < eigenValues.length; i++)
        {
            double v = eigenValues[i];
            if (Double.isNaN(v))
            {
                continue;
            }
            if (first < 0 || v > eigenValues[first])
            {
                second = first;
                first = i;
            }
            else if (second < 0 || v > eigenValues[second])
            {
                second = i;
            }
        }
        return new int[] { first, second };
    }

    /** Mean of one feature column over the samples of a single class. */
    private double getGroupMean(int column, ArrayList<double[]> data)
    {
        double[] d = new double[data.size()];
        for (int i = 0; i < d.length; i++)
        {
            d[i] = data.get(i)[column];
        }
        return getMean(d);
    }

    /** Mean of one feature column over all samples. */
    private double getGlobalMean(int column, double data[][])
    {
        double[] d = new double[data.length];
        for (int i = 0; i < d.length; i++)
        {
            d[i] = data[i][column];
        }
        return getMean(d);
    }

    /**
     * Evaluates the linear discriminant function of every class:
     * {@code f_i = mu_i * C^-1 * x - 0.5 * mu_i * C^-1 * mu_i + ln(p_i)}, where
     * {@code mu_i} is the class mean, {@code C^-1} the inverse pooled covariance and
     * {@code p_i} the class prior.
     *
     * @param values sample to score; entries beyond the trained feature count are ignored
     * @return one value per class, in the order the class labels first appeared in training
     * @throws IllegalStateException if the model has not been trained
     * @throws IllegalArgumentException if {@code values} is null or has too few entries
     */
    public double[] getDiscriminantFunctionValues(double[] values)
    {
        if (kovarianGlobal == null || probabilitas == null || groupRataTengah == null)
        {
            throw new IllegalStateException("LDA model has not been trained");
        }
        int features = kovarianGlobal.length;
        if (values == null || values.length < features)
        {
            throw new IllegalArgumentException("expected at least " + features + " feature values");
        }
        double[] function = new double[groupList.size()];
        for (int i = 0; i < function.length; i++)
        {
            double[] tmp = matrixMultiplication(groupRataTengah[i], kovarianGlobal);
            function[i] = matrixMultiplication(tmp, values)
                    - (.5d * matrixMultiplication(tmp, groupRataTengah[i]))
                    + Math.log(probabilitas[i]);
        }
        return function;
    }

    /**
     * Predicts the class of a sample as the one with the largest discriminant value.
     *
     * @param values sample to classify
     * @return the winning class label, or -1 if no discriminant value exceeds negative infinity
     */
    public int predict(double[] values)
    {
        int group = -1;
        double max = Double.NEGATIVE_INFINITY;
        double[] discr = this.getDiscriminantFunctionValues(values);
        for (int i = 0; i < discr.length; i++)
        {
            if (discr[i] > max)
            {
                max = discr[i];
                group = groupList.get(i);
            }
        }
        return group;
    }

    /** Row vector times matrix: {@code c[j] = sum_i matrixA[i] * matrixB[i][j]}. */
    private double[] matrixMultiplication(double[] matrixA, double[][] matrixB)
    {
        int cols = matrixB.length == 0 ? 0 : matrixB[0].length;
        double[] c = new double[cols];
        for (int i = 0; i < matrixA.length; i++)
        {
            for (int j = 0; j < cols; j++)
            {
                c[j] += matrixA[i] * matrixB[i][j];
            }
        }
        return c;
    }

    /** Dot product over the first {@code matrixA.length} entries of both vectors. */
    private double matrixMultiplication(double[] matrixA, double[] matrixB)
    {
        double c = 0d;
        for (int i = 0; i < matrixA.length; i++)
        {
            c += matrixA[i] * matrixB[i];
        }
        return c;
    }

    /**
     * Arithmetic mean.
     *
     * @param values numbers to average
     * @return the mean, or {@code Double.NaN} if {@code values} is null or empty
     */
    public static double getMean(final double[] values)
    {
        if (values == null || values.length == 0)
        {
            return Double.NaN;
        }
        double mean = 0.0d;
        for (int index = 0; index < values.length; index++)
        {
            mean += values[index];
        }
        return mean / (double) values.length;
    }

    /**
     * Trains on a built-in set of seven two-feature samples in two classes, scores the
     * point {@code (a, b)}, prints the result, and stores it in {@link #hasil},
     * {@link #f1}, {@link #f2} and {@link #f3}.
     *
     * @param e not used
     * @param a first feature of the point to classify
     * @param b second feature of the point to classify
     * @param c not used; the built-in training set has only two features
     * @param d not used; the built-in training set has only two features
     */
    public static void test(extraksi_fitur e, double a, double b, double c, double d)
    {
        int[] group = { 1, 1, 1, 1, 2, 2, 2 };
        // One row per sample; previously every sample was written into row 0.
        double[][] data = {
            { 2.95, 6.63 },
            { 2.53, 7.79 },
            { 3.57, 5.65 },
            { 3.16, 5.47 },
            { 2.58, 4.46 },
            { 2.16, 6.22 },
            { 3.27, 3.52 }
        };
        LDA model = new LDA(data, group, true);
        double[] testData = { a, b };
        double[] values = model.getDiscriminantFunctionValues(testData);
        for (int i = 0; i < values.length; i++)
        {
            System.out.println("Discriminant function " + (i + 1) + ": " + values[i]);
        }
        int predicted = model.predict(testData);
        System.out.println("Predicted group: " + predicted);
        hasil = predicted;
        f1 = values[0];
        f2 = values[1];
        // Only two classes are trained here, so there is no third discriminant value.
        f3 = values.length > 2 ? values[2] : Double.NaN;
    }

    /** Demo: trains on seven three-feature samples and prints the 2-D projection of one point. */
    public static void main(String[] args)
    {
        double[][] data = {
            { 2.95, 6.63, 2.34 },
            { 2.53, 7.79, 2.56 },
            { 3.57, 5.65, 2.76 },
            { 3.16, 5.47, 2.36 },
            { 2.58, 4.46, 5.2 },
            { 2.16, 6.22, 5.4 },
            { 3.27, 3.52, 6 }
        };
        int[] group = { 1, 1, 1, 1, 2, 2, 2 };
        LDA lda = new LDA(data, group, true);
        lda.getTop2Vector();
        double[] tt = { 123.0, 23.0, 4 };
        double[] out = lda.getxydot(tt);
        for (int i = 0; i < out.length; i++)
        {
            System.out.println(out[i]);
        }
    }
}
试用效果良好,实验数据为200维矩阵。
package lda;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import Jama.Matrix;
/**
 * Linear discriminant analysis (LDA).
 *
 * <p>The training constructor computes, from labelled samples, the per-class means,
 * the class priors, the inverse pooled within-class covariance, and the
 * eigen-decomposition of {@code inverse(S_W) * S_B}. The model can then
 * <ul>
 * <li>project a sample onto the two leading discriminant directions
 * ({@link #getTop2Vector()}, {@link #getxydot(double[])}), and</li>
 * <li>classify a sample with linear discriminant functions
 * ({@link #getDiscriminantFunctionValues(double[])}, {@link #predict(double[])}).</li>
 * </ul>
 *
 * <p>The training constructor and {@link #getTop2Vector()} print diagnostic output to
 * {@code System.out}.
 */
public class LDA
{
    /** Per-class feature means, indexed [class][feature]. */
    private double[][] groupRataTengah;
    /** Inverse of the pooled within-class covariance matrix (S_W / sample count). */
    private double[][] kovarianGlobal;
    /** Prior probability of each class: class size / total sample count. */
    private double[] probabilitas;
    /** Distinct class labels in order of first appearance; the position is the class index. */
    private ArrayList<Integer> groupList = new ArrayList<Integer>();
    /** Group predicted by the last call to {@link #test}. */
    static int hasil;
    /** Discriminant values recorded by the last call to {@link #test}. */
    static double f1, f2, f3;
    /** Real parts of the eigenvalues of inverse(S_W) * S_B; parallel to {@link #eigenVectors}. */
    private double[] eigenValues = new double[0];
    /** Eigenvectors of inverse(S_W) * S_B; parallel to {@link #eigenValues}. */
    private RealVector[] eigenVectors = new RealVector[0];
    /** The two eigenvectors with the largest eigenvalues, once selected. */
    private RealVector[] _top2vec = new RealVector[2];

    /** Creates an untrained model. */
    public LDA()
    {
    }

    /**
     * Trains the model.
     *
     * <p>If S_W is singular, the matrix inversion performed by the linear-algebra
     * library fails with that library's own exception.
     *
     * @param d sample matrix, one row per sample and one column per feature; it is copied
     * @param g class label of each sample, aligned with the rows of {@code d}
     * @param p not read by this implementation; kept so existing callers still compile
     * @throws IllegalArgumentException if an argument is null, the lengths differ, there
     *         are no samples or no features, or the rows do not all have the same length
     */
    public LDA(double[][] d, int[] g, boolean p)
    {
        if (d == null || g == null)
        {
            throw new IllegalArgumentException("data and group arrays must not be null");
        }
        if (d.length != g.length)
        {
            throw new IllegalArgumentException(
                    "data has " + d.length + " rows but group has " + g.length + " labels");
        }
        if (d.length == 0 || d[0] == null || d[0].length == 0)
        {
            throw new IllegalArgumentException(
                    "data must contain at least one sample with at least one feature");
        }
        int samples = d.length;
        int features = d[0].length;

        // Defensive copy so later changes to the caller's arrays cannot affect the model.
        double[][] data = new double[samples][];
        for (int i = 0; i < samples; i++)
        {
            if (d[i] == null || d[i].length != features)
            {
                throw new IllegalArgumentException(
                        "row " + i + " does not have " + features + " features");
            }
            data[i] = d[i].clone();
        }

        // Collect the distinct class labels in order of first appearance.
        for (int i = 0; i < samples; i++)
        {
            if (!groupList.contains(g[i]))
            {
                groupList.add(g[i]);
            }
        }
        int classes = groupList.size();

        // Split the samples into one subset per class.
        ArrayList<ArrayList<double[]>> subset = new ArrayList<ArrayList<double[]>>(classes);
        for (int k = 0; k < classes; k++)
        {
            int label = groupList.get(k).intValue();
            ArrayList<double[]> members = new ArrayList<double[]>();
            for (int i = 0; i < samples; i++)
            {
                if (g[i] == label)
                {
                    members.add(data[i]);
                }
            }
            subset.add(members);
        }

        // Per-class means and priors.
        groupRataTengah = new double[classes][features];
        probabilitas = new double[classes];
        for (int k = 0; k < classes; k++)
        {
            ArrayList<double[]> members = subset.get(k);
            probabilitas[k] = members.size() / (double) samples;
            for (int j = 0; j < features; j++)
            {
                groupRataTengah[k][j] = getGroupMean(j, members);
            }
        }

        // Global mean of every feature over all samples.
        double[] rataTengah = new double[features];
        for (int j = 0; j < features; j++)
        {
            rataTengah[j] = getGlobalMean(j, data);
        }

        // Between-class scatter S_B (weighted by class share) and within-class scatter S_W.
        double[][] between = new double[features][features];
        double[][] within = new double[features][features];
        for (int k = 0; k < classes; k++)
        {
            ArrayList<double[]> members = subset.get(k);
            double[] mean = groupRataTengah[k];
            double weight = members.size() / (double) samples;
            for (int i = 0; i < features; i++)
            {
                for (int j = 0; j < features; j++)
                {
                    between[i][j] += weight * (mean[i] - rataTengah[i]) * (mean[j] - rataTengah[j]);
                }
            }
            for (int l = 0; l < members.size(); l++)
            {
                double[] el = members.get(l);
                for (int i = 0; i < features; i++)
                {
                    for (int j = 0; j < features; j++)
                    {
                        within[i][j] += (mean[i] - el[i]) * (mean[j] - el[j]);
                    }
                }
            }
        }

        RealMatrix rswInverse =
                new LUDecomposition(MatrixUtils.createRealMatrix(within)).getSolver().getInverse();
        // Pooled covariance is S_W / samples, so its inverse is samples * inverse(S_W).
        kovarianGlobal = rswInverse.scalarMultiply(samples).getData();

        RealMatrix r = rswInverse.multiply(MatrixUtils.createRealMatrix(between));
        EigenDecomposition en = new EigenDecomposition(r);
        eigenValues = en.getRealEigenvalues();
        eigenVectors = new RealVector[eigenValues.length];
        for (int i = 0; i < eigenValues.length; i++)
        {
            // Stored by index rather than keyed by eigenvalue, so equal eigenvalues
            // do not overwrite each other.
            eigenVectors[i] = en.getEigenvector(i);
            System.out.println(eigenValues[i] + ":");
            System.out.println(eigenVectors[i].toString());
        }
    }

    /**
     * Selects the two eigenvectors with the largest eigenvalues, prints those eigenvalues,
     * and remembers the vectors for {@link #getxydot(double[])}.
     *
     * @return array of length 2, largest eigenvalue first; an entry is {@code null} when
     *         fewer eigenvectors are available (for example on an untrained model)
     */
    public RealVector[] getTop2Vector()
    {
        RealVector[] arrVec = new RealVector[2];
        int[] best = topTwoIndices();
        for (int j = 0; j < best.length; j++)
        {
            if (best[j] >= 0)
            {
                System.out.println(eigenValues[best[j]]);
                arrVec[j] = eigenVectors[best[j]];
            }
        }
        this._top2vec = arrVec;
        return arrVec;
    }

    /**
     * Projects a sample onto the two leading eigenvectors, giving 2-D plane coordinates.
     * The vectors are selected on demand if {@link #getTop2Vector()} was not called yet.
     *
     * @param in sample vector
     * @return the two dot products of {@code in} with the leading eigenvectors
     * @throws IllegalStateException if fewer than two eigenvectors are available
     */
    public double[] getxydot(double[] in)
    {
        if (this._top2vec[0] == null && this._top2vec[1] == null)
        {
            int[] best = topTwoIndices();
            for (int j = 0; j < best.length; j++)
            {
                if (best[j] >= 0)
                {
                    this._top2vec[j] = eigenVectors[best[j]];
                }
            }
        }
        RealVector inv = MatrixUtils.createRealVector(in);
        double[] out = new double[2];
        for (int i = 0; i < 2; i++)
        {
            if (this._top2vec[i] == null)
            {
                throw new IllegalStateException(
                        "fewer than two eigenvectors are available; train the model on at least two features");
            }
            out[i] = this._top2vec[i].dotProduct(inv);
        }
        return out;
    }

    /** Returns the indices of the largest and second-largest eigenvalue, or -1 where absent. */
    private int[] topTwoIndices()
    {
        int first = -1;
        int second = -1;
        for (int i = 0; i < eigenValues.length; i++)
        {
            double v = eigenValues[i];
            if (Double.isNaN(v))
            {
                continue;
            }
            if (first < 0 || v > eigenValues[first])
            {
                second = first;
                first = i;
            }
            else if (second < 0 || v > eigenValues[second])
            {
                second = i;
            }
        }
        return new int[] { first, second };
    }

    /** Mean of one feature column over the samples of a single class. */
    private double getGroupMean(int column, ArrayList<double[]> data)
    {
        double[] d = new double[data.size()];
        for (int i = 0; i < d.length; i++)
        {
            d[i] = data.get(i)[column];
        }
        return getMean(d);
    }

    /** Mean of one feature column over all samples. */
    private double getGlobalMean(int column, double data[][])
    {
        double[] d = new double[data.length];
        for (int i = 0; i < d.length; i++)
        {
            d[i] = data[i][column];
        }
        return getMean(d);
    }

    /**
     * Evaluates the linear discriminant function of every class:
     * {@code f_i = mu_i * C^-1 * x - 0.5 * mu_i * C^-1 * mu_i + ln(p_i)}, where
     * {@code mu_i} is the class mean, {@code C^-1} the inverse pooled covariance and
     * {@code p_i} the class prior.
     *
     * @param values sample to score; entries beyond the trained feature count are ignored
     * @return one value per class, in the order the class labels first appeared in training
     * @throws IllegalStateException if the model has not been trained
     * @throws IllegalArgumentException if {@code values} is null or has too few entries
     */
    public double[] getDiscriminantFunctionValues(double[] values)
    {
        if (kovarianGlobal == null || probabilitas == null || groupRataTengah == null)
        {
            throw new IllegalStateException("LDA model has not been trained");
        }
        int features = kovarianGlobal.length;
        if (values == null || values.length < features)
        {
            throw new IllegalArgumentException("expected at least " + features + " feature values");
        }
        double[] function = new double[groupList.size()];
        for (int i = 0; i < function.length; i++)
        {
            double[] tmp = matrixMultiplication(groupRataTengah[i], kovarianGlobal);
            function[i] = matrixMultiplication(tmp, values)
                    - (.5d * matrixMultiplication(tmp, groupRataTengah[i]))
                    + Math.log(probabilitas[i]);
        }
        return function;
    }

    /**
     * Predicts the class of a sample as the one with the largest discriminant value.
     *
     * @param values sample to classify
     * @return the winning class label, or -1 if no discriminant value exceeds negative infinity
     */
    public int predict(double[] values)
    {
        int group = -1;
        double max = Double.NEGATIVE_INFINITY;
        double[] discr = this.getDiscriminantFunctionValues(values);
        for (int i = 0; i < discr.length; i++)
        {
            if (discr[i] > max)
            {
                max = discr[i];
                group = groupList.get(i);
            }
        }
        return group;
    }

    /** Row vector times matrix: {@code c[j] = sum_i matrixA[i] * matrixB[i][j]}. */
    private double[] matrixMultiplication(double[] matrixA, double[][] matrixB)
    {
        int cols = matrixB.length == 0 ? 0 : matrixB[0].length;
        double[] c = new double[cols];
        for (int i = 0; i < matrixA.length; i++)
        {
            for (int j = 0; j < cols; j++)
            {
                c[j] += matrixA[i] * matrixB[i][j];
            }
        }
        return c;
    }

    /** Dot product over the first {@code matrixA.length} entries of both vectors. */
    private double matrixMultiplication(double[] matrixA, double[] matrixB)
    {
        double c = 0d;
        for (int i = 0; i < matrixA.length; i++)
        {
            c += matrixA[i] * matrixB[i];
        }
        return c;
    }

    /**
     * Arithmetic mean.
     *
     * @param values numbers to average
     * @return the mean, or {@code Double.NaN} if {@code values} is null or empty
     */
    public static double getMean(final double[] values)
    {
        if (values == null || values.length == 0)
        {
            return Double.NaN;
        }
        double mean = 0.0d;
        for (int index = 0; index < values.length; index++)
        {
            mean += values[index];
        }
        return mean / (double) values.length;
    }

    /**
     * Trains on a built-in set of seven two-feature samples in two classes, scores the
     * point {@code (a, b)}, prints the result, and stores it in {@link #hasil},
     * {@link #f1}, {@link #f2} and {@link #f3}.
     *
     * @param e not used
     * @param a first feature of the point to classify
     * @param b second feature of the point to classify
     * @param c not used; the built-in training set has only two features
     * @param d not used; the built-in training set has only two features
     */
    public static void test(extraksi_fitur e, double a, double b, double c, double d)
    {
        int[] group = { 1, 1, 1, 1, 2, 2, 2 };
        // One row per sample; previously every sample was written into row 0.
        double[][] data = {
            { 2.95, 6.63 },
            { 2.53, 7.79 },
            { 3.57, 5.65 },
            { 3.16, 5.47 },
            { 2.58, 4.46 },
            { 2.16, 6.22 },
            { 3.27, 3.52 }
        };
        LDA model = new LDA(data, group, true);
        double[] testData = { a, b };
        double[] values = model.getDiscriminantFunctionValues(testData);
        for (int i = 0; i < values.length; i++)
        {
            System.out.println("Discriminant function " + (i + 1) + ": " + values[i]);
        }
        int predicted = model.predict(testData);
        System.out.println("Predicted group: " + predicted);
        hasil = predicted;
        f1 = values[0];
        f2 = values[1];
        // Only two classes are trained here, so there is no third discriminant value.
        f3 = values.length > 2 ? values[2] : Double.NaN;
    }

    /** Demo: trains on seven three-feature samples and prints the 2-D projection of one point. */
    public static void main(String[] args)
    {
        double[][] data = {
            { 2.95, 6.63, 2.34 },
            { 2.53, 7.79, 2.56 },
            { 3.57, 5.65, 2.76 },
            { 3.16, 5.47, 2.36 },
            { 2.58, 4.46, 5.2 },
            { 2.16, 6.22, 5.4 },
            { 3.27, 3.52, 6 }
        };
        int[] group = { 1, 1, 1, 1, 2, 2, 2 };
        LDA lda = new LDA(data, group, true);
        lda.getTop2Vector();
        double[] tt = { 123.0, 23.0, 4 };
        double[] out = lda.getxydot(tt);
        for (int i = 0; i < out.length; i++)
        {
            System.out.println(out[i]);
        }
    }
}
试用效果良好,实验数据为200维矩阵。
相关文章推荐
- 线性判别分析(Linear Discriminant Analysis, LDA) 算法分析与代码
- [转]线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法初识
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 线性判别分析(Linear Discriminant Analysis, LDA)算法分析
- 【转】线性判别分析(Linear Discriminant Analysis, LDA)算法分析