在Kaggle手写数字数据集上使用Spark MLlib的RandomForest进行手写数字识别
2016-05-12 20:36
579 查看
昨天我使用Spark MLlib的朴素贝叶斯进行手写数字识别,准确率在0.83左右,今天使用了
首先来说说
numTrees:随机森林中树的数目。增大这个数值可以减小预测的方差,提高预测试验的准确性,训练时间会线性地随之增长。
maxDepth:随机森林中每棵树的深度。增加这个值可以是模型更具表征性和更强大,然而训练也更耗时,更容易过拟合。
在这次的训练过程中,我就是反复调整上面两个参数来提升预测的准确性。首先来设定一下一些参数的初始值。
第一次我将树的数目设定为3,每棵树深度为4。下面开始训练模型:
与使用朴素贝叶斯时评估准确率方式一样,我使用训练数据来计算准确率:
下面是每次对上面所说到的两个参数进行调整后得到的准确率:
可以发现,准确率在
把训练出来的结果上传到Kaggle上,得到的准确率为
RandomForest来训练模型,并进行了参数调优。
首先来说说
RandomForest训练分类器时使用到的一些参数:
numTrees:随机森林中树的数目。增大这个数值可以减小预测的方差,提高预测试验的准确性,训练时间会线性地随之增长。
maxDepth:随机森林中每棵树的深度。增加这个值可以是模型更具表征性和更强大,然而训练也更耗时,更容易过拟合。
在这次的训练过程中,我就是反复调整上面两个参数来提升预测的准确性。首先来设定一下一些参数的初始值。
val numClasses = 10 val categoricalFeaturesInfo = Map[Int, Int]() val numTrees = 3 val featureSubsetStrategy = "auto" val impurity = "gini" val maxDepth = 4 val maxBins = 32
第一次我将树的数目设定为3,每棵树深度为4。下面开始训练模型:
val randomForestModel = RandomForest.trainClassifier(data, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
与使用朴素贝叶斯时评估准确率方式一样,我使用训练数据来计算准确率:
val nbTotalCorrect = data.map { point => if (randomForestModel.predict(point.features) == point.label) 1 else 0 }.sum val numData = data.count() println(numData) //42000 val nbAccuracy = nbTotalCorrect / numData
下面是每次对上面所说到的两个参数进行调整后得到的准确率:
//numTree=3,maxDepth=4,准确率:0.5507619047619048 //numTree=4,maxDepth=5,准确率:0.7023095238095238 //numTree=5,maxDepth=6,准确率:0.693595238095238 //numTree=6,maxDepth=7,准确率:0.8426428571428571 //numTree=7,maxDepth=8,准确率:0.879452380952381 //numTree=8,maxDepth=9,准确率:0.9105714285714286 //numTree=9,maxDepth=10,准确率:0.9446428571428571 //numTree=10,maxDepth=11,准确率:0.9611428571428572 //numTree=11,maxDepth=12,准确率:0.9765952380952381 //numTree=12,maxDepth=13,准确率:0.9859523809523809 //numTree=13,maxDepth=14,准确率:0.9928333333333333 //numTree=14,maxDepth=15,准确率:0.9955 //numTree=15,maxDepth=16,准确率:0.9972857142857143 //numTree=16,maxDepth=17,准确率:0.9979285714285714 //numTree=17,maxDepth=18,准确率:0.9983809523809524 //numTree=18,maxDepth=19,准确率:0.9989285714285714 //numTree=19,maxDepth=20,准确率:0.9989523809523809 //numTree=20,maxDepth=21,准确率:0.999 //numTree=21,maxDepth=22,准确率:0.9994761904761905 //numTree=22,maxDepth=23,准确率:0.9994761904761905 //numTree=23,maxDepth=24,准确率:0.9997619047619047 //numTree=24,maxDepth=25,准确率:0.9997857142857143 //numTree=25,maxDepth=26,准确率:0.9998333333333334 //numTree=29,maxDepth=30,准确率:0.9999523809523809
可以发现,准确率在
numTree=11,maxDepth=12附近开始收敛到0.999。这次得到的准确率要比上次使用朴素贝叶斯训练得出的准确率(0.826)要高出许多。现在开始对测试数据进行预测,使用的参数是
numTree=29,maxDepth=30:
val predictions = randomForestModel.predict(features).map { p => p.toInt }
把训练出来的结果上传到Kaggle上,得到的准确率为
0.95929,经过我的四次参数调整,得到的最高的准确率是
0.96586,设置的参数是:
numTree=55,maxDepth=30,当我将参数改为
numTree=70,maxDepth=30时,准确率有所下降,为
0.96271,看来这个时候出现过拟合了。不过准确率能从昨天的0.83提高到0.96还是挺兴奋的,我还会继续尝试使用其他方式进行手写数字识别,不知何时能达到1.
相关文章推荐
- 面试题 1
- 实验三进程调度
- 第06篇 MEF部件的生命周期(PartCreationPolicy)
- linux相关
- Linux基础之:curl工具的使用
- delphi概念性学习(三)
- python datetime模块详解
- 简单的cp程序
- C++ 类的静态成员详细讲解
- Max Sum
- IO流的读写操作
- 多重背包问题II
- JBoss调用Webservice出现org.jboss.ws.core.jaxws.spi.ProviderImple not found错误
- Linux 设备驱动框架
- 会计眼中的“借”“贷”
- LDA
- scripts/mysql_install_db报错原因分析和解决
- java 中的==,equals(),hashCode()
- iOS App 跳转到 AppStore
- Linux监控cpu以及内存使用情况之top命令