
sklearn GMM BIC model selection

2016-08-13 21:25
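BIC (Bayesian Information Criterion) combines the model's likelihood with its number of free parameters and the number of samples. Among the candidate models, choose the one with the lowest BIC.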
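For reference, the standard definition is below. Here $\hat{L}$ is the maximized likelihood of the model, $k$ is its number of free parameters, and $n$ is the number of samples:

$$\mathrm{BIC} = -2\ln\hat{L} + k\ln n$$

The first term rewards goodness of fit. The second term penalizes complexity, and the penalty grows with both $k$ and $n$. Lower values are better.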

np.infty is an alias for np.inf (positive infinity). It was removed in NumPy 2.0, so new code should use np.inf.
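The example uses infinity as the starting value of a running minimum. Every finite BIC compares smaller than it, so the first model fitted always replaces the initial value. Below is a minimal sketch of that pattern. The helper name running_min and the scores are made up for illustration.

```python
import numpy as np

def running_min(scores):
    """Return (index, value) of the smallest score, starting from +inf."""
    best_idx, best = None, np.inf  # any finite score is < np.inf
    for idx, score in enumerate(scores):
        if score < best:
            best_idx, best = idx, score
    return best_idx, best

print(running_min([3.2, 1.5, 2.7]))
```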

Calling bic(X) on a fitted GMM returns its BIC value directly.
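To see what bic(X) computes, the sketch below rebuilds the value by hand for a full-covariance model and compares it with the library result. It relies on two facts. First, GaussianMixture.score(X) returns the average log-likelihood per sample. Second, a full-covariance GMM has K·d mean parameters, K·d(d+1)/2 covariance parameters and K−1 free mixing weights. The old mixture.GMM class was removed in scikit-learn 0.20, so the sketch uses GaussianMixture. The helpers full_gmm_n_parameters and manual_bic are hypothetical names written for this illustration.

```python
import numpy as np
from sklearn.mixture import GaussianMixture

def full_gmm_n_parameters(n_components, n_features):
    # Hypothetical helper: free parameters of a full-covariance GMM.
    mean_params = n_components * n_features
    cov_params = n_components * n_features * (n_features + 1) // 2  # symmetric matrices
    weight_params = n_components - 1  # mixing weights sum to 1
    return mean_params + cov_params + weight_params

def manual_bic(gmm, X):
    # Hypothetical helper: BIC = -2 ln L + k ln n
    n_samples, n_features = X.shape
    log_likelihood = gmm.score(X) * n_samples  # score() is a per-sample average
    k = full_gmm_n_parameters(gmm.n_components, n_features)
    return -2.0 * log_likelihood + k * np.log(n_samples)

rng = np.random.RandomState(0)
X = np.r_[rng.randn(200, 2), rng.randn(200, 2) + [5, 5]]
gmm = GaussianMixture(n_components=2, covariance_type="full", random_state=0).fit(X)
print(gmm.bic(X), manual_bic(gmm, X))
```

The parameter count applies only to covariance_type="full". The other covariance types have fewer covariance parameters.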

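itertools.cycle creates an iterator that repeats its elements endlessly. zip stops at the shortest of its inputs, so pairing a finite list with a cycle yields exactly one pair per list element.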
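Here is a small demonstration with a shorter color list, so that the wrap-around is visible. zip reads its arguments from left to right. When cv_types runs out, zip stops without pulling another element from the cycle. The cycle object then resumes from where it left off when it is reused.

```python
import itertools

cv_types = ['spherical', 'tied', 'diag', 'full']
color_iter = itertools.cycle(['k', 'r', 'g'])

# zip stops when cv_types is exhausted; the cycle wraps around after 'g'
pairs = list(zip(cv_types, color_iter))
print(pairs)  # [('spherical', 'k'), ('tied', 'r'), ('diag', 'g'), ('full', 'k')]

# The same cycle object remembers its position when reused later
print(next(color_iter))  # r
```

The script reuses one color_iter for both subplots. After the bar chart has taken 'k', 'r', 'g' and 'b', the ellipses in the second subplot therefore start at 'c' rather than at 'k'.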

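The height of the star marker uses a convex combination: bic.min() * 0.97 + 0.03 * bic.max(). This point lies 3% of the way from the lowest BIC toward the highest, just above the winning bar.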
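The arithmetic can be checked with made-up BIC values. In the sketch below, convex_combination is a hypothetical helper. The same formula with t = -0.01 reproduces the script's lower y-limit, bic.min() * 1.01 - .01 * bic.max(). That is an affine extrapolation slightly below the minimum, not a convex combination, because the weight is negative.

```python
import numpy as np

def convex_combination(a, b, t):
    """Return (1 - t) * a + t * b; the result lies between a and b when 0 <= t <= 1."""
    return (1 - t) * a + t * b

bic = np.array([1200.0, 1000.0, 1500.0])  # made-up BIC values

# Star height: weight 0.97 on the minimum, 0.03 on the maximum (about 1015)
star_y = convex_combination(bic.min(), bic.max(), 0.03)
# Lower y-limit: t = -0.01 extrapolates just below the minimum (about 995)
y_low = convex_combination(bic.min(), bic.max(), -0.01)
print(star_y, y_low)
```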

Below is an example of using BIC to select a GMM:
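import itertools

import numpy as np
from scipy import linalg
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from sklearn import mixture

# Two clusters: one stretched and rotated by C, one shifted to (-6, 3)
n_samples = 500
np.random.seed(0)
C = np.array([[0., -0.1], [1.7, 0.4]])
X = np.r_[np.dot(np.random.randn(n_samples, 2), C),
          0.7 * np.random.randn(n_samples, 2) + np.array([-6, 3])]

# Fit one GMM per (covariance type, number of components) and keep the lowest BIC
lowest_bic = np.inf  # np.infty was removed in NumPy 2.0
bic = []
n_components_range = range(1, 7)
cv_types = ['spherical', 'tied', 'diag', 'full']
for cv_type in cv_types:
    for n_components in n_components_range:
        # mixture.GMM was removed in scikit-learn 0.20; GaussianMixture replaces it
        gmm = mixture.GaussianMixture(n_components=n_components,
                                      covariance_type=cv_type, random_state=0)
        gmm.fit(X)
        bic.append(gmm.bic(X))
        if bic[-1] < lowest_bic:
            lowest_bic = bic[-1]
            best_gmm = gmm

bic = np.array(bic)
color_iter = itertools.cycle(['k', 'r', 'g', 'b', 'c', 'm', 'y'])
clf = best_gmm
bars = []

# Top panel: grouped bar chart of BIC, one group per number of components
spl = plt.subplot(2, 1, 1)
for i, (cv_type, color) in enumerate(zip(cv_types, color_iter)):
    xpos = np.array(n_components_range) + 0.2 * (i - 2)
    bars.append(plt.bar(xpos, bic[i * len(n_components_range):
                                  (i + 1) * len(n_components_range)],
                        width=.2, color=color))

plt.xticks(n_components_range)
plt.ylim([bic.min() * 1.01 - .01 * bic.max(), bic.max()])
plt.title('BIC score per model')
# Mark the best model with a star just above its bar
best_type_idx, best_n_idx = divmod(bic.argmin(), len(n_components_range))
xpos = best_n_idx + .65 + .2 * best_type_idx
plt.text(xpos, bic.min() * 0.97 + 0.03 * bic.max(), '*', fontsize=14)
spl.set_xlabel('Number of components')
spl.legend([b[0] for b in bars], cv_types)

# Bottom panel: points colored by predicted component, one ellipse per component
splot = plt.subplot(2, 1, 2)
Y_ = clf.predict(X)

# covariances_ has a different shape for each covariance_type; expand to full 2x2 matrices
if clf.covariance_type == 'full':
    covariances = clf.covariances_
elif clf.covariance_type == 'tied':
    covariances = [clf.covariances_] * clf.n_components
elif clf.covariance_type == 'diag':
    covariances = [np.diag(c) for c in clf.covariances_]
else:  # 'spherical'
    covariances = [np.eye(X.shape[1]) * c for c in clf.covariances_]

for i, (mean, covar, color) in enumerate(zip(clf.means_, covariances, color_iter)):
    if not np.any(Y_ == i):
        continue
    plt.scatter(X[Y_ == i, 0], X[Y_ == i, 1], .8, color=color)

    # eigh returns eigenvalues in ascending order and eigenvectors as the columns of w
    v, w = linalg.eigh(covar)
    angle = np.degrees(np.arctan2(w[1, 0], w[0, 0]))  # direction of the first axis
    v = 2. * np.sqrt(2.) * np.sqrt(v)  # axis lengths from standard deviations, not variances
    ell = Ellipse(mean, v[0], v[1], angle=angle, color=color)
    ell.set_clip_box(splot.bbox)
    ell.set_alpha(.5)
    splot.add_artist(ell)

plt.xlim(-10, 10)
plt.ylim(-3, 6)
plt.xticks(())
plt.yticks(())
plt.title('Selected GMM: %s model, %d components'
          % (clf.covariance_type, clf.n_components))
plt.subplots_adjust(hspace=.35, bottom=.02)
plt.show()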
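If only the selection step is needed and not the plots, the search loop can be wrapped in a function. The sketch below assumes the current GaussianMixture API. The function name select_gmm_by_bic and its defaults are hypothetical choices made for this illustration. It returns the model with the lowest BIC together with every score, keyed by (covariance type, number of components).

```python
import numpy as np
from sklearn.mixture import GaussianMixture

def select_gmm_by_bic(X, n_components_range=range(1, 7),
                      cv_types=('spherical', 'tied', 'diag', 'full'),
                      random_state=0):
    """Hypothetical helper: fit one GMM per (covariance type, n_components)
    and return the model with the lowest BIC plus all BIC scores."""
    scores = {}
    best_gmm, lowest_bic = None, np.inf
    for cv_type in cv_types:
        for n_components in n_components_range:
            gmm = GaussianMixture(n_components=n_components,
                                  covariance_type=cv_type,
                                  random_state=random_state).fit(X)
            score = gmm.bic(X)
            scores[(cv_type, n_components)] = score
            if score < lowest_bic:  # keep the running minimum
                lowest_bic, best_gmm = score, gmm
    return best_gmm, scores

rng = np.random.RandomState(0)
X = np.r_[rng.randn(300, 2), 0.7 * rng.randn(300, 2) + [-6, 3]]
best, scores = select_gmm_by_bic(X)
print(best.covariance_type, best.n_components)
```

Returning the full score dictionary keeps the selection separate from the plotting. The bar chart can still be drawn from scores afterwards.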