您的位置:首页 > 其它

数据挖掘笔记-聚类-SpectralClustering-原理与简单实现

2014-09-02 15:42 549 查看


谱聚类(Spectral Clustering, SC)是一种基于图论的聚类方法——将带权无向图划分为两个或两个以上的最优子图,使子图内部尽量相似,而子图间距离尽量距离较远,以达到常见的聚类的目的。其中的最优是指最优目标函数不同,可以是Min
Cut、Nomarlized Cut、Ratio Cut等。谱聚类能够识别任意形状的样本空间且收敛于全局最优解,其基本思想是利用样本数据的相似矩阵(拉普拉斯矩阵)进行特征分解后得到的特征向量进行聚类。


Spectral Clustering 算法步骤:

1)根据数据构造一个Graph,Graph的每一个节点对应一个数据点,将相似的点连接起来,并且边的权重用于表示数据之间的相似度。把这个Graph用邻接矩阵的形式表示出来,记为 W。

2)把W的每一列元素活者行元素加起来得到N个数,把它们放在对角线上(其他地方都是零),组成一个N*N的度矩阵,记为D 。

3)根据度矩阵与邻接矩阵得出拉普拉斯矩阵 L = D - W 。

4)求出拉普拉斯矩阵L的前k个特征值(除非特殊说明,否则“前k个”指按照特征值的大小从小到大的顺序)以及对应的特征向量。

5)把这k个特征(列)向量排列在一起组成一个N*k的矩阵,将其中每一行看作k维空间中的一个向量,并使用 K-Means算法进行聚类。聚类的结果中每一行所属的类别就是原来Graph中的节点亦即最初的N个数据点分别所属的类别。

示例



Spectral Clustering 和传统的聚类方法(如 K-Means等)对比:
1)和 K-Medoids 类似,Spectral
Clustering 只需要数据之间的相似度矩阵就可以了,而不必像K-means那样要求数据必须是 N 维欧氏空间中的向量。Spectral
Clustering 所需要的所有信息都包含在 W 中。不过一般 W 并不总是等于最初的相似度矩阵——回忆一下,W 是我们构造出来的 Graph 的邻接矩阵表示,通常我们在构造 Graph 的时候为了方便进行聚类,更加强到“局部”的连通性,亦即主要考虑把相似的点连接在一起,比如:我们可以设置一个阈值,如果两个点的相似度小于这个阈值,就把他们看作是不连接的。另一种构造
Graph 邻接的方法是将 n 个与节点最相似的点与其连接起来。
2)由于抓住了主要矛盾,忽略了次要的东西,因此比传统的聚类算法更加健壮一些,对于不规则的误差数据不是那么敏感,而且性能也要好一些。许多实验都证明了这一点。事实上,在各种现代聚类算法的比较中,K-means 通常都是作为 baseline 而存在的。实际上
Spectral Clustering 是在用特征向量的元素来表示原来的数据,并在这种“更好的表示形式”上进行 K-Means 。实际上这种“更好的表示形式”是用 Laplacian Eig进行降维的后的结果。而降维的目的正是“抓住主要矛盾,忽略次要的东西”。
3)计算复杂度比 K-means 要小。这个在高维数据上表现尤为明显。例如文本数据,通常排列起来是维度非常高(比如几千或者几万)的稀疏矩阵,对稀疏矩阵求特征值和特征向量有很高效的办法,得到的结果是一些 k 维的向量(通常 k 不会很大),在这些低维的数据上做
K-Means 运算量非常小。但是对于原始数据直接做 K-Means 的话,虽然最初的数据是稀疏矩阵,但是 K-Means 中有一个求 Centroid 的运算,就是求一个平均值:许多稀疏的向量的平均值求出来并不一定还是稀疏向量,事实上,在文本数据里,很多情况下求出来的 Centroid 向量是非常稠密,这时再计算向量之间的距离的时候,运算量就变得非常大,直接导致普通的 K-Means 巨慢无比,而 Spectral Clustering 等工序更多的算法则迅速得多的结果。

Java简单实现代码如下:

public class SpectralClusteringBuilder {

public static int DIMENSION = 30;

public static double THRESHOLD = 0.01;

public Data getInitData() {
Data data = new Data();
try {
String path = SpectralClustering.class.getClassLoader()
.getResource("测试").toURI().getPath();
DocumentSet documentSet = DocumentLoader.loadDocumentSet(path);
List<Document> documents = documentSet.getDocuments();
DocumentUtils.calculateTFIDF_0(documents);
DocumentUtils.calculateSimilarity(documents, new CosineDistance());
Map<String, Map<String, Double>> nmap = new HashMap<String, Map<String, Double>>();
Map<String, String> cmap = new HashMap<String, String>();
for (Document document : documents) {
String name = document.getName();
cmap.put(name, document.getCategory());
Map<String, Double> similarities = nmap.get(name);
if (null == similarities) {
similarities = new HashMap<String, Double>();
nmap.put(name, similarities);
}
for (DocumentSimilarity similarity : document.getSimilarities()) {
if (similarity.getDoc2().getName().equalsIgnoreCase(similarity.getDoc1().getName())) {
similarities.put(similarity.getDoc2().getName(), 0.0);
} else {
similarities.put(similarity.getDoc2().getName(), similarity.getDistance());
}
}
}
String[] docnames = nmap.keySet().toArray(new String[0]);
data.setRow(docnames);
data.setColumn(docnames);
data.setDocnames(docnames);
int len = docnames.length;
double[][] original = new double[len][len];
for (int i = 0; i < len; i++) {
Map<String, Double> similarities = nmap.get(docnames[i]);
for (int j = 0; j < len; j++) {
double distance = similarities.get(docnames[j]);
original[i][j] = distance;
}
}
data.setOriginal(original);
data.setCmap(cmap);
data.setNmap(nmap);
} catch (Exception e) {
e.printStackTrace();
}
return data;
}

/**
* 获取距离阀值在一定范围内的点
* @param data
* @return
*/
public double[][] getWByDistance(Data data) {
Map<String, Map<String, Double>> nmap = data.getNmap();
String[] docnames = data.getDocnames();
int len = docnames.length;
double[][] w = new double[len][len];
for (int i = 0; i < len; i++) {
Map<String, Double> similarities = nmap.get(docnames[i]);
for (int j = 0; j < len; j++) {
double distance = similarities.get(docnames[j]);
w[i][j] = distance < THRESHOLD ? 1 : 0;
}
}
return w;
}

/**
* 获取距离最近的K个点
* @param data
* @return
*/
public double[][] getWByKNearestNeighbors(Data data) {
Map<String, Map<String, Double>> nmap = data.getNmap();
String[] docnames = data.getDocnames();
int len = docnames.length;
double[][] w = new double[len][len];
for (int i = 0; i < len; i++) {
List<Map.Entry<String, Double>> similarities =
new ArrayList<Map.Entry<String, Double>>(nmap.get(docnames[i]).entrySet());
sortSimilarities(similarities, DIMENSION);
for (int j = 0; j < len; j++) {
String name = docnames[j];
boolean flag = false;
for (Map.Entry<String, Double> entry : similarities) {
if (name.equalsIgnoreCase(entry.getKey())) {
flag = true;
break;
}
}
w[i][j] = flag ? 1 : 0;
}
}
return w;
}

/**
* 垂直求和
* @param W
* @return
*/
public double[][] getVerticalD(double[][] W) {
int row = W.length;
int column = W[0].length;
double[][] d = new double[row][column];
for (int j = 0; j < column; j++) {
double sum = 0;
for (int i = 0; i < row; i++) {
sum += W[i][j];
}
d[j][j] = sum;
}
return d;
}

/**
* 水平求和
* @param W
* @return
*/
public double[][] getHorizontalD(double[][] W) {
int row = W.length;
int column = W[0].length;
double[][] d = new double[row][column];
for (int i = 0; i < row; i++) {
double sum = 0;
for (int j = 0; j < column; j++) {
sum += W[i][j];
}
d[i][i] = sum;
}
return d;
}

/**
* 相似度排序,并取前K个,倒叙
* @param similarities
* @param k
*/
public void sortSimilarities(List<Map.Entry<String, Double>> similarities, int k) {
Collections.sort(similarities, new Comparator<Map.Entry<String, Double>>() {
@Override
public int compare(Entry<String, Double> o1,
Entry<String, Double> o2) {
return o2.getValue().compareTo(o1.getValue());
}
});
while (similarities.size() > k) {
similarities.remove(similarities.size() - 1);
}
}

public void print(double[][] values) {
for (int i = 0, il = values.length; i < il; i++) {
for (int j = 0, jl = values[0].length; j < jl; j++) {
System.out.print(values[i][j] + "  ");
}
System.out.println("\n");
}
}

// 随机生成中心点,并生成初始的K个聚类
public List<DataPointCluster> genInitCluster(List<DataPoint> points, int k) {
List<DataPointCluster> clusters = new ArrayList<DataPointCluster>();
Random random = new Random();
Set<String> categories = new HashSet<String>();
while (clusters.size() < k) {
DataPoint center = points.get(random.nextInt(points.size()));
String category = center.getCategory();
if (categories.contains(category))
continue;
categories.add(category);
DataPointCluster cluster = new DataPointCluster();
cluster.setCenter(center);
cluster.getDataPoints().add(center);
clusters.add(cluster);
}
return clusters;
}

// 将点归入到聚类中
public void handleCluster(List<DataPoint> points,
List<DataPointCluster> clusters, int iterNum) {
for (DataPoint point : points) {
DataPointCluster maxCluster = null;
double maxDistance = Integer.MIN_VALUE;
for (DataPointCluster cluster : clusters) {
DataPoint center = cluster.getCenter();
double distance = DistanceUtils.cosine(point.getValues(),
center.getValues());
if (distance > maxDistance) {
maxDistance = distance;
maxCluster = cluster;
}
}
if (null != maxCluster) {
maxCluster.getDataPoints().add(point);
}
}
// 终止条件定义为原中心点与新中心点距离小于一定阀值
// 当然也可以定义为原中心点等于新中心点
boolean flag = true;
for (DataPointCluster cluster : clusters) {
DataPoint center = cluster.getCenter();
DataPoint newCenter = cluster.computeMediodsCenter();
double distance = DistanceUtils.cosine(newCenter.getValues(),
center.getValues());
if (distance > 0.5) {
flag = false;
cluster.setCenter(newCenter);
}
}
if (!flag && iterNum < 25) {
for (DataPointCluster cluster : clusters) {
cluster.getDataPoints().clear();
}
handleCluster(points, clusters, ++iterNum);
}
}

/**
* KMeans方法
* @param dataPoints
*/
public void kmeans(List<DataPoint> dataPoints) {
List<DataPointCluster> clusters = genInitCluster(dataPoints, 4);
handleCluster(dataPoints, clusters, 0);
int success = 0, failure = 0;
for (DataPointCluster cluster : clusters) {
String category = cluster.getCenter().getCategory();
for (DataPoint dataPoint : cluster.getDataPoints()) {
String dpCategory = dataPoint.getCategory();
if (category.equals(dpCategory)) {
success++;
} else {
failure++;
}
}
}
System.out.println("total: " + (success + failure) + " success: "
+ success + " failure: " + failure);
}

public void build() {
Data data = getInitData();
double[][] w = getWByKNearestNeighbors(data);
double[][] d = getHorizontalD(w);
Matrix W = new Matrix(w);
Matrix D = new Matrix(d);
Matrix L = D.minus(W);
EigenvalueDecomposition eig = L.eig();
double[][] v = eig.getV().getArray();
double[][] vs = new double[v.length][DIMENSION];
for (int i = 0, li = v.length; i < li; i++) {
for (int j = 1, lj = DIMENSION; j <= lj; j++) {
vs[i][j-1] = v[i][j];
}
}
Matrix V = new Matrix(vs);
Matrix O = new Matrix(data.getOriginal());
double[][] t = O.times(V).getArray();
List<DataPoint> dataPoints = new ArrayList<DataPoint>();
for (int i = 0; i < t.length; i++) {
DataPoint dataPoint = new DataPoint();
dataPoint.setCategory(data.getCmap().get(data.getColumn()[i]));
dataPoint.setValues(t[i]);
dataPoints.add(dataPoint);
}
for (int n = 0; n < 10; n++) {
kmeans(dataPoints);
}
}

public static void main(String[] args) {
new SpectralClusteringBuilder().build();
}
}

代码托管:https://github.com/fighting-one-piece/repository-datamining.git
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息