Lamda新论文gcForest测试-手写数字测试识别详解(Kaggle数据集)
2017-10-03 11:47
351 查看
lamda实验室最近研发了新的学习算法gcForest,论文和代码在lamda网站上都有,给出的代码没有注释相当费解(大牛请无视),后又在github上找到民间大神实现的代码,较为简洁易懂,先贴上代码链接:https://github.com/pylablanche/gcForest
简要对gcForest算法处理数据的流程解释一下:
假设我们现在的数据就是kaggle平台上手写数字识别的数据格式,具体格式见https://www.kaggle.com/c/digit-recognizer/data
那么但数据规模是28*28,假设我们设置扫描窗口window为14*14:
先有多粒度扫描过程,这个过程实质上是将原数据特征的一种“放大”与分离处理:
由28-14+1=15,
可知每行数据会切出:15*15=225个窗口,每个窗口14*14的规模
那么原来数据集的每行都变成了225行,即225个小窗口,窗口为14*14的特征块。即每行有14*14列(sliced_X)
这时sliced_Y还是int,被重复了225次,即每行还是对应一个y,指这个数据的正确标签数字。
然后将sliced_X,sliced_Y送给随机森林和完全随机森林训练,然后再用这俩森林对sliced_X跑出结果概率(十维,表示每个手写数字的概率),然后把俩十维文件合并为20维的,然后又把概率矩阵规模重设成了原始数据的行数。
注意这时数据格式是森林预测到的概率! 原始特征已经“看不到”了,即以后训练级联森林用的不是原数据特征,而是原数据经常森林预测出的概率,以概率为输入特征训练接下来的级联森林。这一点非常容易混淆。
重复上面的,把所有window值都跑一遍,整合概率矩阵,最后MutiScanning返回的行数等于初始数据行数,每行都是预测的概率,列数极多。然后将它送入级联森林。
训练级联森林就简单多了,每一层接受上一层的数据,并检验性能,没有提升就停下来。
论文介绍该算法对序列数据和图像数据效果较好,并相对于深度神经网络有一些优势,详见论文。
我下载来理解之后简单进行了测试,采用了kaggle平台上的手写数字的数据集,在参数都使用默认情形下,依然得到了不错的识别率。
这里给出测试方法:(需要先在上面链接下载gcForest代码,与本测试代码放在同目录下)
由于gcForest代码运行时十分耗内存,博主16g的内存不够用,当把kaggle上四万数据全部用上时,扫描窗口window不能设为合适值,window大约14左右合适,但目前内存限制只能设到23以上,并不能完全发挥gcForest的能力,kaggle上提交准确率有98.2%
每次得到的结果可能不同,这是因为其中随机森林有一定随机性,但总体差别不大。
后期调参工作应该可以再提升识别率。
简要对gcForest算法处理数据的流程解释一下:
假设我们现在的数据就是kaggle平台上手写数字识别的数据格式,具体格式见https://www.kaggle.com/c/digit-recognizer/data
那么但数据规模是28*28,假设我们设置扫描窗口window为14*14:
先有多粒度扫描过程,这个过程实质上是将原数据特征的一种“放大”与分离处理:
由28-14+1=15,
可知每行数据会切出:15*15=225个窗口,每个窗口14*14的规模
那么原来数据集的每行都变成了225行,即225个小窗口,窗口为14*14的特征块。即每行有14*14列(sliced_X)
这时sliced_Y还是int,被重复了225次,即每行还是对应一个y,指这个数据的正确标签数字。
然后将sliced_X,sliced_Y送给随机森林和完全随机森林训练,然后再用这俩森林对sliced_X跑出结果概率(十维,表示每个手写数字的概率),然后把俩十维文件合并为20维的,然后又把概率矩阵规模重设成了原始数据的行数。
注意这时数据格式是森林预测到的概率! 原始特征已经“看不到”了,即以后训练级联森林用的不是原数据特征,而是原数据经常森林预测出的概率,以概率为输入特征训练接下来的级联森林。这一点非常容易混淆。
重复上面的,把所有window值都跑一遍,整合概率矩阵,最后MutiScanning返回的行数等于初始数据行数,每行都是预测的概率,列数极多。然后将它送入级联森林。
训练级联森林就简单多了,每一层接受上一层的数据,并检验性能,没有提升就停下来。
论文介绍该算法对序列数据和图像数据效果较好,并相对于深度神经网络有一些优势,详见论文。
我下载来理解之后简单进行了测试,采用了kaggle平台上的手写数字的数据集,在参数都使用默认情形下,依然得到了不错的识别率。
这里给出测试方法:(需要先在上面链接下载gcForest代码,与本测试代码放在同目录下)
# -*-coding:utf-8-*- import pandas as pd import numpy as np from sklearn.cross_validation import train_test_split import GCForest import pandas as pd import time import numpy as np from sklearn.model_selection import train_test_split from GCForest import gcForest from sklearn.metrics import accuracy_score data= pd.read_csv('train.csv') #print(data.shape) ddata=np.array(data) x=ddata[:1400,1:785].copy() #为快速看到测试结果,这里只试用1400条数据,可自行更改 y=ddata[:1400,:1].copy() y=y.flatten() #print(x) #print(y.shape) X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.25,random_state=9) #print(y_train) gcf = GCForest.gcForest(shape_1X=[28,28], window=[14,16], tolerance=0.0, min_samples_mgs=10, min_samples_cascade=7) gcf.fit(X_train, y_train) pred_X = gcf.predict(X_test) #print (pred_X) accuracy = accuracy_score(y_true=y_test, y_pred=pred_X) print ('gcForest accuracy:{}'.format(accuracy))
由于gcForest代码运行时十分耗内存,博主16g的内存不够用,当把kaggle上四万数据全部用上时,扫描窗口window不能设为合适值,window大约14左右合适,但目前内存限制只能设到23以上,并不能完全发挥gcForest的能力,kaggle上提交准确率有98.2%
每次得到的结果可能不同,这是因为其中随机森林有一定随机性,但总体差别不大。
后期调参工作应该可以再提升识别率。
相关文章推荐
- 在Kaggle手写数字数据集上使用Spark MLlib的RandomForest进行手写数字识别
- 在Kaggle手写数字数据集上使用Spark MLlib的朴素贝叶斯模型进行手写数字识别
- 机器学习Tensorflow基于MNIST数据集识别自己的手写数字(读取和测试自己的模型)
- [置顶] java实现基于Mnist数据集的手写数字识别
- caffe示例实现之4在MNIST手写数字数据集上训练与测试LeNet
- 使用tensorflow利用神经网络分类识别MNIST手写数字数据集,转自随心1993
- 用TensorFlow做Kaggle“手写识别”达到98%准确率-详解
- 第二课 深度学习的“hello world”——基于mnist数据集的手写数字识别
- PK/NN/*/SVM:实现手写数字识别(数据集50000张图片)比较3种算法神经网络、灰度平均值、SVM各自的准确率—Jason niu
- 【深度学习】笔记2_caffe自带的第一个例子,Mnist手写数字识别代码,过程,网络详解
- Keras_深度学习_MNIST数据集手写数字识别之各种调参
- Kaggle Digit Recognizer 基于sklearn实现的手写数字识别 for MNIST data
- mnist数据集在caffe(windows)上的训练与测试及对自己手写数字的分类
- 【Python | TensorBoard】用 PCA 可视化 MNIST 手写数字识别数据集
- 基于机器学习多种方法的kaggle竞赛入门之手写数字的图像识别预测
- 学习笔记——《机器学习实战》KNN算法实现 约会网站测试,手写数字识别,代码,注释,错误修改
- Sklearn-手写数字识别 & kaggle
- Kaggle竞赛排名Top 10% —— 手写数字识别Digit Recognizer
- 详解python实现识别手写MNIST数字集的程序
- Tensorflow系列之(二):详解CNN识别MNIST手写数字集