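How cross-validation works
scikit-learn provides the cross_validation module (removed in scikit-learn 0.20; newer releases expose the same tools from sklearn.model_selection):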
from sklearn import cross_validation
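Here n is the number of samples, and n_folds is the number of folds the data set is split into, which is also the number of validation rounds.
The k_fold object carries everything it needs as soon as it is constructed: from n and n_folds it splits the n sample indices into n_folds folds. Each round uses one fold as the test set and the remaining n_folds-1 folds as the training set, and this is repeated n_folds times so that every fold serves as the test set exactly once.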
>>> from sklearn import cross_validation
>>> k_fold = cross_validation.KFold(n=6, n_folds=3)
>>> for train_indices, test_indices in k_fold:
...     print('Train: %s | test: %s' % (train_indices, test_indices))
Train: [2 3 4 5] | test: [0 1]
Train: [0 1 4 5] | test: [2 3]
Train: [0 1 2 3] | test: [4 5]
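The splitting rule described above can be sketched without scikit-learn. This is only an illustration of the idea, not the library's implementation; the helper name k_fold_indices is hypothetical, and it assumes contiguous, unshuffled folds as in the session above.

```python
def k_fold_indices(n, n_folds):
    """Yield (train_indices, test_indices) for each of the n_folds rounds."""
    indices = list(range(n))
    # When n is not divisible by n_folds, the first n % n_folds folds
    # get one extra sample so that every sample is used exactly once.
    base, extra = divmod(n, n_folds)
    start = 0
    for fold in range(n_folds):
        size = base + (1 if fold < extra else 0)
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, test
        start += size


folds = list(k_fold_indices(n=6, n_folds=3))
for train, test in folds:
    print('Train: %s | test: %s' % (train, test))
```

With n=6 and n_folds=3 this yields the same index groups as the KFold session above, printed as Python lists rather than NumPy arrays.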
Practice
Source code
>>> from sklearn import datasets, svm
>>> digits = datasets.load_digits()
>>> X_digits = digits.data
>>> y_digits = digits.target
>>> svc = svm.SVC(C=1, kernel='linear')
>>>
>>> from sklearn import cross_validation
>>> kfold = cross_validation.KFold(len(X_digits), n_folds=3)
>>> [svc.fit(X_digits[train], y_digits[train]).score(X_digits[test], y_digits[test]) for train, test in kfold]
>>> cross_validation.cross_val_score(svc, X_digits, y_digits, n_jobs=-1)
array([ 0.93489149, 0.95659432, 0.93989983])
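For readers on scikit-learn 0.20 or later, where sklearn.cross_validation no longer exists, the same experiment might look roughly like the sketch below. It assumes the sklearn.model_selection API, in which KFold takes n_splits and produces index pairs through split(); the variable names manual_scores and auto_scores are illustrative only.

```python
from sklearn import datasets, svm
from sklearn.model_selection import KFold, cross_val_score

X_digits, y_digits = datasets.load_digits(return_X_y=True)
svc = svm.SVC(C=1, kernel='linear')

# KFold no longer needs the sample count up front; split() derives it from X.
kfold = KFold(n_splits=3)

# Manual loop: fit on the training folds, score on the held-out fold.
manual_scores = [
    svc.fit(X_digits[train], y_digits[train]).score(X_digits[test], y_digits[test])
    for train, test in kfold.split(X_digits)
]

# Helper: passing the same splitter through cv reuses the same folds.
auto_scores = cross_val_score(svc, X_digits, y_digits, cv=kfold)

print(manual_scores)
print(auto_scores)
```

Passing cv explicitly avoids depending on the default number of folds, which has changed between releases.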
Takeaways
1. How to use cross-validation
2. How to write a for loop inside a single expression (a list comprehension):
[ process(Xi) for Xi in dataset ]
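As a small illustration of that pattern, the sketch below compares an explicit loop with the equivalent list comprehension. The process function and the dataset values are placeholders made up for the example.

```python
def process(x):
    # Hypothetical per-item step; squaring stands in for "fit and score".
    return x * x


dataset = [1, 2, 3]

# Explicit loop: build the result list one item at a time.
results_loop = []
for xi in dataset:
    results_loop.append(process(xi))

# List comprehension: the same computation as a single expression.
results_comp = [process(xi) for xi in dataset]

print(results_comp)  # [1, 4, 9]
```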
Going further
1. Automatic parameter search (grid search)
from sklearn.grid_search import GridSearchCV
This mechanism uses cross-validation to score every candidate in a parameter grid and then reports the best-performing parameter value.
>>> from sklearn.grid_search import GridSearchCV
>>> import numpy as np
>>> Cs = np.logspace(-6, -1, 10)
>>> clf = GridSearchCV(estimator=svc, param_grid=dict(C=Cs),
... n_jobs=-1)
>>> clf.fit(X_digits[:1000], y_digits[:1000])
GridSearchCV(cv=None,...
>>> clf.best_score_
0.925...
>>> clf.best_estimator_.C
0.0077...
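On scikit-learn 0.20 or later, GridSearchCV is imported from sklearn.model_selection instead of sklearn.grid_search. The sketch below repeats the search under that assumption. Setting cv=3 explicitly is a choice made here for the example, so the exact scores need not match the session above.

```python
import numpy as np
from sklearn import datasets, svm
from sklearn.model_selection import GridSearchCV

X_digits, y_digits = datasets.load_digits(return_X_y=True)
svc = svm.SVC(kernel='linear')

# Ten candidate values for C, evenly spaced on a log scale.
Cs = np.logspace(-6, -1, 10)

# Each candidate C is scored with 3-fold cross-validation on the training part.
clf = GridSearchCV(estimator=svc, param_grid={'C': Cs}, cv=3)
clf.fit(X_digits[:1000], y_digits[:1000])

print(clf.best_params_)
print(clf.best_score_)

# By default the best estimator is refit on all the data passed to fit(),
# so the search object can be scored directly on the held-out remainder.
print(clf.score(X_digits[1000:], y_digits[1000:]))
```

Scoring on samples that took no part in the search gives a less optimistic estimate than best_score_, which was itself used to pick C.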
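2. Parameter search built into the estimator
In this case the estimator has cross-validation built in, so it can select its best parameter automatically as part of fitting. The parameter in question is one of the model's own hyperparameters, for example the alpha that sets the strength of the L1 penalty in Lasso.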
>>> from sklearn import linear_model, datasets
>>> lasso = linear_model.LassoCV()
>>> diabetes = datasets.load_diabetes()
>>> X_diabetes = diabetes.data
>>> y_diabetes = diabetes.target
>>> lasso.fit(X_diabetes, y_diabetes)
LassoCV(alphas=None, copy_X=True, cv=None, eps=0.001, fit_intercept=True,
max_iter=1000, n_alphas=100, n_jobs=1, normalize=False, positive=False,
precompute='auto', random_state=None, selection='cyclic', tol=0.0001,
verbose=False)
>>>
>>> lasso.alpha_
0.01229...
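To see what the built-in search does, one can inspect the fitted LassoCV object. The sketch below assumes a recent scikit-learn release and sets cv=3 explicitly, a choice made here for the example, so the selected alpha may differ from the value in the session above.

```python
from sklearn import datasets
from sklearn.linear_model import LassoCV

X_diabetes, y_diabetes = datasets.load_diabetes(return_X_y=True)

# LassoCV builds its own grid of alphas and cross-validates each one.
lasso = LassoCV(cv=3)
lasso.fit(X_diabetes, y_diabetes)

print(lasso.alpha_)           # the alpha chosen by cross-validation
print(lasso.alphas_.shape)    # the grid of candidate alphas that was tried
print(lasso.mse_path_.shape)  # test-set MSE per alpha and per fold
```

After fit() returns, the estimator has already been refit with the selected alpha, so it can be used for prediction directly.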