scikit-learn是一个开源的Python语言机器学习工具包。它涵盖了几乎所有主流机器学习算法的实现,并且提供了一致的调用接口。它基于Numpy和SciPy等Python数值计算库,提供了高效的算法实现。总结起来,scikit-learn工具包有以下几个优点:
- 文档齐全:官方文档齐全,更新及时。
- 接口易用:针对所有的算法提供了一致的接口调用规则,不管是KNN、K-Means还是PCA。
- 算法全面:涵盖主流机器学习任务的算法,包括回归算法、分类算法、聚类分析、数据降维处理等。
当然,scikit-learn不支持分布式计算,不适合用来处理超大型数据,但这并不影响scikit-learn作为一个优秀的机器学习工具库这个事实。许多知名的公司,包括Evemote和Spotify都使用scikit-learn作为他们的机器学习应用。
1.scikit-learn示例
回顾前面章节介绍的机器学习应用开发的典型步骤,我们使用scikit-learn来完成一个手写数字识别的例子。这是一个有监督的学习,数据是标记过的手写数字的图片。即通过采集足够多的手写数字样本数据,选择合适的模型,并使用采集到的数据进行模型训练,最后验证手写识别程序的正确性。
(1)数据采集和标记
如果我们从头实现一个手写数字识别程序,需要先采集数据,即让尽量多不同手写习惯的用户,写出从0~9的所有数字,然后把用户写出来的数字进行标记,即用户每写出一个数字,就标记他写出的是哪个数字。
为什么要采集尽量多不同书写习惯的用户写的数字呢?因为只有这样,采集到的数据才有代表性,才能保证最终训练出来的模型的准确性。极端的例子,我们采集的都是习惯写出瘦高形数字的人,那么针对习惯写出矮胖形数字的人写出来的数字,模型的识别成功率就会很低。
所幸我们不需要从头开始这项工作。scikit-learn自带了一些数据集,其中一个是数字识别图片的数据。使用以下代码来加载数据。
from sklearn import dataset
digits = datasets.load_digits()
可以在ipython notebook环境下把数据所表示的图片用matplotlib显示出来:
from sklearn import datasets
from matplotlib import pyplot as plt
digits = datasets.load_digits()
images_and_labels = list(zip(digits.images,digits.target))
plt.figure(figsize=(8,6),dpi=200)
for index,(image,label)in enumerate(images_and_labels[:8]):
plt.subplot(2,4,index+1)
plt.axis('of