您的位置:首页 > 编程语言 > Python开发

Lamda新论文gcForest测试-手写数字测试识别详解(Kaggle数据集)

2017-10-03 11:47 351 查看
lamda实验室最近研发了新的学习算法gcForest,论文和代码在lamda网站上都有,给出的代码没有注释相当费解(大牛请无视),后又在github上找到民间大神实现的代码,较为简洁易懂,先贴上代码链接:https://github.com/pylablanche/gcForest

简要对gcForest算法处理数据的流程解释一下:

假设我们现在的数据就是kaggle平台上手写数字识别的数据格式,具体格式见https://www.kaggle.com/c/digit-recognizer/data

那么但数据规模是28*28,假设我们设置扫描窗口window为14*14:

先有多粒度扫描过程,这个过程实质上是将原数据特征的一种“放大”与分离处理:

由28-14+1=15,

可知每行数据会切出:15*15=225个窗口,每个窗口14*14的规模

那么原来数据集的每行都变成了225行,即225个小窗口,窗口为14*14的特征块。即每行有14*14列(sliced_X)

这时sliced_Y还是int,被重复了225次,即每行还是对应一个y,指这个数据的正确标签数字。

然后将sliced_X,sliced_Y送给随机森林和完全随机森林训练,然后再用这俩森林对sliced_X跑出结果概率(十维,表示每个手写数字的概率),然后把俩十维文件合并为20维的,然后又把概率矩阵规模重设成了原始数据的行数。

注意这时数据格式是森林预测到的概率! 原始特征已经“看不到”了,即以后训练级联森林用的不是原数据特征,而是原数据经常森林预测出的概率,以概率为输入特征训练接下来的级联森林。这一点非常容易混淆。

重复上面的,把所有window值都跑一遍,整合概率矩阵,最后MutiScanning返回的行数等于初始数据行数,每行都是预测的概率,列数极多。然后将它送入级联森林。

训练级联森林就简单多了,每一层接受上一层的数据,并检验性能,没有提升就停下来。

论文介绍该算法对序列数据和图像数据效果较好,并相对于深度神经网络有一些优势,详见论文。

我下载来理解之后简单进行了测试,采用了kaggle平台上的手写数字的数据集,在参数都使用默认情形下,依然得到了不错的识别率。

这里给出测试方法:(需要先在上面链接下载gcForest代码,与本测试代码放在同目录下)

# -*-coding:utf-8-*-
import pandas as pd
import numpy as np
from sklearn.cross_validation import train_test_split
import GCForest
import pandas as pd
import time
import numpy as np
from sklearn.model_selection import train_test_split
from GCForest import gcForest
from sklearn.metrics import accuracy_score
data= pd.read_csv('train.csv')
#print(data.shape)
ddata=np.array(data)
x=ddata[:1400,1:785].copy()  #为快速看到测试结果,这里只试用1400条数据,可自行更改
y=ddata[:1400,:1].copy()
y=y.flatten()
#print(x)
#print(y.shape)
X_train,X_test,y_train,y_test=train_test_split(x,y,test_size=0.25,random_state=9)
#print(y_train)
gcf = GCForest.gcForest(shape_1X=[28,28], window=[14,16], tolerance=0.0, min_samples_mgs=10, min_samples_cascade=7)
gcf.fit(X_train, y_train)
pred_X = gcf.predict(X_test)
#print (pred_X)
accuracy = accuracy_score(y_true=y_test, y_pred=pred_X)
print ('gcForest accuracy:{}'.format(accuracy))


由于gcForest代码运行时十分耗内存,博主16g的内存不够用,当把kaggle上四万数据全部用上时,扫描窗口window不能设为合适值,window大约14左右合适,但目前内存限制只能设到23以上,并不能完全发挥gcForest的能力,kaggle上提交准确率有98.2%

每次得到的结果可能不同,这是因为其中随机森林有一定随机性,但总体差别不大。

后期调参工作应该可以再提升识别率。
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐