您的位置:首页 > 其它

scikit-learn 从入门到精通(五):Unsupervised learning: seeking representations of the data

2016-01-30 12:41 267 查看
#encoding=utf-8
'''
Part 5 - Unsupervised learning: seeking representations of the data.
'''
# ---------------------------------------------------------------------------
# K-means clustering
# ---------------------------------------------------------------------------
from sklearn import cluster, datasets

iris = datasets.load_iris()
X_iris = iris.data
y_iris = iris.target

# n_init=10 pins the historical default explicitly (scikit-learn 1.4 changed
# the default to "auto", and 1.2/1.3 warn about it); random_state makes the
# printed labels reproducible from run to run.
k_means = cluster.KMeans(n_clusters=3, n_init=10, random_state=0)
k_means.fit(X_iris)
# Cluster ids are arbitrary: at best they are a permutation of the true class
# ids, so compare which samples are grouped together, not the numbers.
print(k_means.labels_[::10])
print(y_iris[::10])
# ---------------------------------------------------------------------------
# Application: vector quantization
# ---------------------------------------------------------------------------
import importlib

import numpy as np
import scipy as sp  # module-level alias kept for the rest of the script


def _load_test_image():
    """Return a grayscale SciPy test image as a float64 array.

    ``scipy.misc.lena`` (used by the original tutorial) no longer exists in
    SciPy, and ``scipy.misc.ascent`` was later superseded by
    ``scipy.datasets.ascent`` (which needs the optional ``pooch`` package).
    The known loaders are tried from newest to oldest.

    Returns:
        numpy.ndarray: 2-D float64 image (float so later arithmetic on pixel
        values cannot overflow an integer dtype).

    Raises:
        ImportError: if none of the known loaders is available.
    """
    candidates = (
        ("scipy.datasets", "ascent"),
        ("scipy.misc", "ascent"),
        ("scipy.misc", "lena"),
    )
    for module_name, func_name in candidates:
        try:
            loader = getattr(importlib.import_module(module_name), func_name)
            image = loader()
        except (ImportError, AttributeError):
            # Loader absent in this SciPy version, or its optional
            # dependency is missing: fall through to the next candidate.
            continue
        return np.asarray(image, dtype=np.float64)
    raise ImportError(
        "no test image loader found in scipy.datasets or scipy.misc")


# The name ``lena`` is kept from the original tutorial; on a current SciPy the
# loaded picture is the "ascent" test image instead.
lena = _load_test_image()
# One sample per pixel with a single feature (its gray level): (n_pixels, 1).
X = lena.reshape((-1, 1))
# n_init=1 (from the original): run k-means once.
k_means = cluster.KMeans(n_clusters=5, n_init=1)
k_means.fit(X)
# The 5 cluster centres are the only gray levels kept after quantization.
values = k_means.cluster_centers_.squeeze()
# Index of the assigned gray level for every pixel.
labels = k_means.labels_
# Replace every pixel by its cluster centre.  Fancy indexing gives the same
# result as np.choose(labels, values) without np.choose's cap on the number of
# choices (32 in NumPy 1.x), so it keeps working for larger n_clusters.
lena_compresse = values[labels].reshape(lena.shape)

# ---------------------------------------------------------------------------
# Hierarchical clustering: Ward linkage
# ---------------------------------------------------------------------------
import importlib
import time

import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.feature_extraction.image import grid_to_graph


def _load_test_image():
    """Return a grayscale SciPy test image as a float64 array.

    ``scipy.misc.lena`` (used by the original tutorial) no longer exists in
    SciPy, and ``scipy.misc.ascent`` was later superseded by
    ``scipy.datasets.ascent`` (which needs the optional ``pooch`` package).
    The known loaders are tried from newest to oldest.

    Returns:
        numpy.ndarray: 2-D float64 image (float so later arithmetic on pixel
        values cannot overflow an integer dtype).

    Raises:
        ImportError: if none of the known loaders is available.
    """
    candidates = (
        ("scipy.datasets", "ascent"),
        ("scipy.misc", "ascent"),
        ("scipy.misc", "lena"),
    )
    for module_name, func_name in candidates:
        try:
            loader = getattr(importlib.import_module(module_name), func_name)
            image = loader()
        except (ImportError, AttributeError):
            # Loader absent in this SciPy version, or its optional
            # dependency is missing: fall through to the next candidate.
            continue
        return np.asarray(image, dtype=np.float64)
    raise ImportError(
        "no test image loader found in scipy.datasets or scipy.misc")


lena = _load_test_image()
# Crop to even dimensions so the four 2x2-phase slices below have equal shapes.
lena = lena[: lena.shape[0] // 2 * 2, : lena.shape[1] // 2 * 2]
# Down-sample by 2 along each axis by summing every 2x2 pixel block, which
# makes the clustering problem 4x smaller.  The image is float64, so the sum
# cannot wrap around the way it would on a uint8 image.
lena = lena[::2, ::2] + lena[1::2, ::2] + lena[::2, 1::2] + lena[1::2, 1::2]
# One sample per pixel with a single feature (its value): (n_pixels, 1).
X = np.reshape(lena, (-1, 1))

###############################################################################
# Define the structure A of the data: each pixel is connected to its grid
# neighbours, so Ward merges only spatially adjacent regions.
connectivity = grid_to_graph(*lena.shape)

###############################################################################
# Compute clustering
print("Compute structured hierarchical clustering...")
st = time.perf_counter()  # monotonic clock, suited to measuring intervals
n_clusters = 15  # number of regions
ward = AgglomerativeClustering(
    n_clusters=n_clusters, linkage='ward', connectivity=connectivity
).fit(X)
label = np.reshape(ward.labels_, lena.shape)
print("Elapsed time: ", time.perf_counter() - st)
print("Number of pixels: ", label.size)
print("Number of clusters: ", np.unique(label).size)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: