sklearn模型调优(判断是否过过拟合及选择参数)
2018-02-04 22:19
441 查看
sklearn模型调优(判断是否过过拟合及选择参数)
这篇博客主要介绍两个方面的东西,其实就是两个函数:1. learning_curve():这个函数主要是用来判断(可视化)模型是否过拟合的,关于过拟合,就不多说了,具体可以看以前的博客:模型选择和改进
2. validation_curve():这个函数主要是用来查看在参数不同的取值下模型的性能
下面通过代码例子来看下这两个函数:
一、learning_curve()
这个函数的官方API为:官方API。部分参数含义为:
参数 | 含义 |
---|---|
estimator | 训练的模型 |
X | 数据集样本(不包括label) |
y | 样本label |
train_sizes | 用于产生learning_curve的样本数量,比如[0.1,0.25,0.5,0.75,1]就是当样本是总样本数量的10%,25%,…100%时产生learning_curve,其实就是对应折线图上那几个点的横坐标(见下图),因为样本数量很多,因此都设置比例,当然你也可以直接设置样本数量,默认是np.linspace(0.1, 1.0, 5)。 |
cv | 交叉验证的折数 |
scoring | 模型性能的评价指标,如(‘accuracy’、‘f1’、”mean_squared_error”等) |
直接看个代码吧:
from sklearn import datasets from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import learning_curve import numpy as np import matplotlib.pyplot as plt (X,y) = datasets.load_digits(return_X_y=True) # print(X[:2,:]) train_sizes,train_score,test_score = learning_curve(RandomForestClassifier(),X,y,train_sizes=[0.1,0.2,0.4,0.6,0.8,1],cv=10,scoring='accuracy') train_error = 1- np.mean(train_score,axis=1) test_error = 1- np.mean(test_score,axis=1) plt.plot(train_sizes,train_error,'o-',color = 'r',label = 'training') plt.plot(train_sizes,test_error,'o-',color = 'g',label = 'testing') plt.legend(loc='best') plt.xlabel('traing examples') plt.ylabel('error') plt.show()
运行结果:
二、validation_curve()
官方的API为:validation_curve(),这个函数的部分重要的参数为:
参数 | 含义 |
---|---|
param_name | 要改变的参数的名字,如果当model为SVC时,改变gamma的值,求最好的那个gamma值 |
param_range | 给定的参数范围 |
代码示例:
from sklearn import datasets from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import validation_curve import numpy as np import matplotlib.pyplot as plt (X,y) = datasets.load_digits(return_X_y=True) # print(X[:2,:]) param_range = [10,20,40,80,160,250] train_score,test_score = validation_curve(RandomForestClassifier(),X,y,param_name='n_estimators',param_range=param_range,cv=10,scoring='accuracy') train_score = np.mean(train_score,axis=1) test_score = np.mean(test_score,axis=1) plt.plot(param_range,train_score,'o-',color = 'r',label = 'training') plt.plot(param_range,test_score,'o-',color = 'g',label = 'testing') plt.legend(loc='best') plt.xlabel('number of tree') plt.ylabel('accuracy') plt.show()
运行结果:
可以看到当树的数量为80-90左右的时候,model的性能最好,因此我们可以把n_estimators设置85,这样model的性能会相对好些。
以上就是learning_curve()和validation_curve()的简介。
相关文章推荐
- Spark机器学习——模型选择与参数调优之交叉验证
- 从a站点跳转到b站点,通过url的参数判断是否让该用户选择身份
- Spark2.0机器学习系列之1:基于Pipeline、交叉验证、ParamMap的模型选择和超参数调优
- 从a站点跳转到b站点,通过url的参数判断是否让该用户选择身份
- 如何判断LSTM模型中的过拟合和欠拟合 By 机器之心2017年10月02日 11:09 判断长短期记忆模型在序列预测问题上是否表现良好可能是一件困难的事。也许你会得到一个不错的模型技术得分,但了解
- Spark2.0机器学习系列之2:基于Pipeline、交叉验证、ParamMap的模型选择和超参数调优
- [Spark2.0]ML 调优:模型选择和超参数调优
- 【Scikit-Learn 中文文档】模型选择:选择估计量及其参数 - 关于科学数据处理的统计学习教程 - scikit-learn 教程 | ApacheCN
- scikit-learn进行模型参数的选择
- 调用接口并且判断是否写日志(用一个参数来控制)
- GridView RadioButton 解决办法(二) -- 判断是否有选择
- (三)、利用命令行参数输入多个参数,判断该数组是否为回文数组
- js replace 全局替换 以表单的方式提交参数 判断是否为ie浏览器 将jquery.qqFace.js表情转换成微信的字符码 手机端省市区联动 新字体引用本地运行可以获得,放到服务器上报404 C#提取html中的汉字 MVC几种找不到资源的解决方式 使用Windows服务定时去执行一个方法的三种方式
- jQuery.isEmptyObject()函数用于判断指定参数是否是一个空对象。
- Shell脚本中判断输入变量或者参数是否为空的方法
- 【mybatis】mybatis中判断数组参数的下标是否为最后一个
- javascript判断函数参数是否传递[比较运算符中的两个等号与三个等号差别]
- 机器学习---过拟合和模型选择
- JavaScript基础 isNaN() 判断是否参数是否为数字
- 判断三个参数是否能构成一个三角形