您的位置:首页 > 其它

An introduction to machine learning with scikit-learn

2015-01-23 13:36 543 查看
scikit-learn 是一个基于SciPy和Numpy的开源机器学习模块,包括分类、回归、聚类的一系列算法,而且有详细的文档,是边学边练的绝佳教材,本文将通过一个简单的例子向大家展示如何使用scikit-learn。这个例子是关于手写识别的,就是给了一个手写的数字,让机器来识别它是几。首先来介绍一下数据集,在这个例子中,所谓的数据集就是一张张手写数字的图片,每张图片有8*8个像素,在训练的时候会将每张图片的这64个像素点排列成一个特征向量,所以也可以认为是这一个个特征向量组成了数据集,同时数据集里还包含target
value,就是每张图片对应的真实数字。scikit-learn 为了方便我们学习已经把这个数据集准备好了, 我们只需载入一下即可:

from sklearn import datasets
digits = datasets.load_digits()
大家不妨打印出来看看:

digits.data[0]
digits.target[0]




前者对应的就是数据集中第一个图片的特征向量,后者就是这张图片对应的数字。接下来我们就可以选择一个模型,然后训练它,最后再用训练好的模型来识别新的手写数字。这里选择了SVM模型,如下:

from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100.)
训练模型也很简单,一个函数就可以搞定,fit,如下:

clf.fit(digits.data[:-1], digits.target[:-1])
这里用了数据集中除最后一个数据外的所有数据,这个函数执行完后我们的模型就训练好了,然后我们用这个模型来预测数据集中最后一个数据(特征向量)对应的真实数字,如下:

clf.predict(digits.data[-1])
预测出来的数字是8,那么到底对不对呢,我们看一下最后一个数字的手写图片:



大家就见仁见智啦
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: