sklearn GMM

本文展示了如何使用sklearn库进行分层交叉验证,并应用GMM(高斯混合模型)进行数据建模。通过StratifiedKFold进行数据划分,使用GMM进行训练和预测,同时绘制了每个类别的高斯椭圆分布,以可视化模型的协方差结构。最终计算并显示了训练和测试的准确性。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

sklearn.cross_validation.StratifiedKFold:
分层交叉验证,使得交叉验证抽到的样本符合原始样本的比例。

类GMM模型,_get_covars 应当返回每一个混合成分的协方差矩阵。

np.linalg.eigh 返回特征值特征向量二元组。
np.arctan2 返回的是两个序列比的弧度值,可以考虑手动转为角度值。

mpl.patches.Ellipse:
有参数: 中心, 宽, 高, 逆时针旋转角度。
由于是椭圆 代码中 180+ 的旋转并没有用。

下面是一个例子:
import matplotlib.pyplot as plt 
import matplotlib as mpl 
import numpy as np 

from sklearn import datasets 
from sklearn.cross_validation import StratifiedKFold 
from sklearn.externals.six.moves import xrange 
from sklearn.mixture import GMM 

def make_ellipses(gmm, ax):
 for n, color in enumerate('rgb'):
  v, w = np.linalg.eigh(gmm._get_covars()[n][:2,:2])
  u = w[0] / np.linalg.norm(w[0])
  angle = np.arctan2(u[1], u[0])
  angle = 180 * angle / np.pi 
  v *= 9
  ell = mpl.patches.Ellipse(gmm.means_[n,:2], v[0], v[1], 180 + angle, color = color)
  ell.set_clip_box(ax.bbox)
  ell.set_alpha(0.5)
  ax.add_artist(ell)

iris = datasets.load_iris()

skf = StratifiedKFold(iris.target, n_folds = 4)
train_index, test_index = next(iter(skf))

X_train = iris.data[train_index]
y_train = iris.target[train_index]
X_test = iris.data[test_index]
y_test = iris.target[test_index]

n_classes = len(np.unique(y_train))

classifiers = dict((covar_type, GMM(n_components = n_classes, covariance_type = covar_type, init_params = 'wc', n_iter = 20)) for covar_type in ['spherical', 'diag', 'tied', 'full'])

n_classifiers = len(classifiers)

plt.figure(figsize = (3 * n_classifiers / 2, 6))
plt.subplots_adjust(bottom = .01, top = 0.95, hspace = .15, wspace = .05, left = .01, right = .99)

for index, (name, classifier) in enumerate(classifiers.items()):
 classifier.means_ = np.array([X_train[y_train == i].mean(axis = 0) for i in xrange(n_classes)])
 classifier.fit(X_train)
 h = plt.subplot(2, n_classifiers / 2, index + 1)
 make_ellipses(classifier, h)

 for n, color in enumerate('rgb'):
  data = iris.data[iris.target == n]
  plt.scatter(data[:,0], data[:,1], 0.8, color = color, label = iris.target_names[n])

 for n, color in enumerate('rgb'):
  data = X_test[y_test == n]
  plt.plot(data[:,0], data[:,1], 'x', color = color)

 y_train_pred = classifier.predict(X_train)
 train_accuracy = np.mean(y_train_pred.ravel() == y_train.ravel()) * 100
 plt.text(0.05, 0.9, 'Train accuracy: %.1f' % train_accuracy, transform = h.transAxes)

 y_test_pred = classifier.predict(X_test)
 test_accuracy = np.mean(y_test_pred.ravel() == y_test.ravel()) * 100
 plt.text(0.05, 0.8, "Test accuracy: %.1f" % test_accuracy, transform = h.transAxes)

 plt.xticks(())
 plt.yticks(())
 plt.title(name)

plt.legend(loc = 'lower right', prop = dict(size = 12))
plt.show()
 



评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值