
MXNet Official Documentation Tutorials (2): Handwritten Digit Recognition with Convolutional Neural Networks

2016-11-23 22:07
I had originally planned to start translating the section on computation graphs, but right after the previous post went out, MXNet updated its tutorial documentation (the pain!), replacing the handwritten digit recognition example from the last post with a detailed tutorial. So this post keeps up with the times and translates that freshly updated tutorial instead. Since images currently cannot be uploaded to this blog, the related figures can be viewed on the original site: Handwritten Digit Recognition.

This tutorial guides you through an applied example of computer vision classification: recognizing handwritten digits with an artificial neural network.

 

Loading Data

We first need to fetch the MNIST data, a commonly used dataset for handwritten digit recognition. Every image in the dataset has been rescaled to a 28×28 grayscale image (pixel values ranging from 0 to 255). The following code downloads the images and their corresponding labels and loads them into numpy.

import numpy as np
import os
import urllib
import gzip
import struct

def download_data(url, force_download=True):
    fname = url.split("/")[-1]
    if force_download or not os.path.exists(fname):
        urllib.urlretrieve(url, fname)
    return fname

def read_data(label_url, image_url):
    # parse the IDX-format label and image files
    with gzip.open(download_data(label_url)) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.fromstring(flbl.read(), dtype=np.int8)
    with gzip.open(download_data(image_url), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return (label, image)

path = 'http://yann.lecun.com/exdb/mnist/'
(train_lbl, train_img) = read_data(
    path + 'train-labels-idx1-ubyte.gz', path + 'train-images-idx3-ubyte.gz')
(val_lbl, val_img) = read_data(
    path + 't10k-labels-idx1-ubyte.gz', path + 't10k-images-idx3-ubyte.gz')
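A quick sanity check of what was just loaded (translator's addition, not part of the original tutorial; the MNIST training set holds 60,000 examples and the test set 10,000):

print(train_img.shape)  # (60000, 28, 28)
print(train_lbl.shape)  # (60000,)
print(val_img.shape)    # (10000, 28, 28)
print(val_lbl.shape)    # (10000,)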

We plot the first 10 images and print their corresponding labels:

%matplotlib inline
import matplotlib.pyplot as plt

for i in range(10):
    plt.subplot(1, 10, i + 1)
    plt.imshow(train_img[i], cmap='Greys_r')
    plt.axis('off')
plt.show()
print('label: %s' % (train_lbl[0:10],))


label: [5 0 4 1 9 2 1 3 1 4]


Next we create data iterators for MXNet. Like a Python iterator, a data iterator returns a batch of data on each next() call, here a batch of images together with their corresponding labels. The images are stored in a 4-D array of shape (batch_size, num_channels, width, height). For the MNIST dataset there is only one color channel, and both width and height are 28. In addition, we usually shuffle the images used for training, which speeds up training.

import mxnet as mx

def to4d(img):
    # reshape to 4-D (batch, channel, width, height) and scale pixels to [0, 1]
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32) / 255

batch_size = 100
train_iter = mx.io.NDArrayIter(to4d(train_img), train_lbl, batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(to4d(val_img), val_lbl, batch_size)
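To see what the iterator yields, here is a quick sketch (translator's addition, not part of the original tutorial; it assumes the standard DataBatch interface of mx.io.NDArrayIter):

batch = train_iter.next()       # one batch of data
print(batch.data[0].shape)      # (100, 1, 28, 28)
print(batch.label[0].shape)     # (100,)
train_iter.reset()              # rewind before training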


Multilayer Perceptron

A multilayer perceptron contains several fully-connected layers. A fully-connected layer maps an n×m input matrix X to an n×k output matrix Y, where k is often called the hidden size. The layer has two parameters, the m×k weight matrix W and the 1×k bias vector b, and computes the output as

Y = XW + b

with b broadcast across the n rows. The output of a fully-connected layer is usually fed into an activation layer, which applies an element-wise function. A famous choice is the sigmoid function f(x) = 1/(1+e^(-x)); nowadays people also use a simpler function called relu, f(x) = max(0, x).

The last fully-connected layer usually has a hidden size equal to the number of classes in the dataset. Finally we stack a softmax layer on top, which maps its input into probability-like scores. Again assume the input X has size n×m and let x_i be its i-th row; then the j-th entry of the i-th output row is

y_ij = e^(x_ij) / (e^(x_i1) + ... + e^(x_im))
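To make these formulas concrete, here is a minimal numpy sketch (translator's addition, not part of the original tutorial) of a fully-connected layer followed by relu and softmax:

import numpy as np

def fullc(X, W, b):
    # X: (n, m), W: (m, k), b: (k,)  ->  Y = XW + b of shape (n, k)
    return X.dot(W) + b

def relu(X):
    return np.maximum(0, X)

def softmax(X):
    # subtract the row-wise max for numerical stability; each row sums to 1
    e = np.exp(X - X.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

X = np.random.rand(2, 5)   # 2 examples with 5 features each
Y = softmax(relu(fullc(X, np.random.rand(5, 3), np.zeros(3))))
print(Y.sum(axis=1))       # [ 1.  1.]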

Defining a multilayer perceptron in MXNet is straightforward, as shown below.

# Create a placeholder variable for the input data
data = mx.sym.Variable('data')
# Flatten the data from 4-D shape (batch_size, num_channel, width, height)
# into 2-D (batch_size, num_channel*width*height)
data = mx.sym.Flatten(data=data)

# The first fully-connected layer
fc1 = mx.sym.FullyConnected(data=data, name='fc1', num_hidden=128)
# Apply relu to the output of the first fully-connected layer
act1 = mx.sym.Activation(data=fc1, name='relu1', act_type="relu")

# The second fully-connected layer and the according activation function
fc2 = mx.sym.FullyConnected(data=act1, name='fc2', num_hidden=64)
act2 = mx.sym.Activation(data=fc2, name='relu2', act_type="relu")

# The third fully-connected layer, note that the hidden size should be 10, which is the number of unique digits
fc3 = mx.sym.FullyConnected(data=act2, name='fc3', num_hidden=10)
# The softmax and loss layer
mlp = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

# We visualize the network structure with output size (the batch_size is ignored.)
shape = {"data": (batch_size, 1, 28, 28)}
mx.viz.plot_network(symbol=mlp, shape=shape)


Now that both the network definition and the data iterator are ready, we can start training:

import logging
logging.getLogger().setLevel(logging.DEBUG)

model = mx.model.FeedForward(
    symbol = mlp,        # network structure
    num_epoch = 10,      # number of data passes for training
    learning_rate = 0.1  # learning rate of SGD
)
model.fit(
    X=train_iter,        # training data
    eval_data=val_iter,  # validation data
    batch_end_callback = mx.callback.Speedometer(batch_size, 200)  # output progress for each 200 data batches
)

INFO:root:Start training with [cpu(0)]
INFO:root:Epoch[0] Batch [200]  Speed: 26279.17 samples/sec Train-accuracy=0.111550
INFO:root:Epoch[0] Batch [400]  Speed: 27424.98 samples/sec Train-accuracy=0.111000
INFO:root:Epoch[0] Batch [600]  Speed: 27094.87 samples/sec Train-accuracy=0.133200
INFO:root:Epoch[0] Resetting Data Iterator
INFO:root:Epoch[0] Time cost=2.320
INFO:root:Epoch[0] Validation-accuracy=0.276800
INFO:root:Epoch[1] Batch [200]  Speed: 17739.48 samples/sec Train-accuracy=0.412650
INFO:root:Epoch[1] Batch [400]  Speed: 18869.69 samples/sec Train-accuracy=0.753500
INFO:root:Epoch[1] Batch [600]  Speed: 25618.04 samples/sec Train-accuracy=0.828750
INFO:root:Epoch[1] Resetting Data Iterator
INFO:root:Epoch[1] Time cost=2.988
INFO:root:Epoch[1] Validation-accuracy=0.854400
INFO:root:Epoch[2] Batch [200]  Speed: 21532.09 samples/sec Train-accuracy=0.859750
INFO:root:Epoch[2] Batch [400]  Speed: 27919.08 samples/sec Train-accuracy=0.888700
INFO:root:Epoch[2] Batch [600]  Speed: 26810.95 samples/sec Train-accuracy=0.905550
INFO:root:Epoch[2] Resetting Data Iterator
INFO:root:Epoch[2] Time cost=2.408
INFO:root:Epoch[2] Validation-accuracy=0.916300
INFO:root:Epoch[3] Batch [200]  Speed: 28097.98 samples/sec Train-accuracy=0.917300
INFO:root:Epoch[3] Batch [400]  Speed: 27490.20 samples/sec Train-accuracy=0.925850
INFO:root:Epoch[3] Batch [600]  Speed: 27937.45 samples/sec Train-accuracy=0.934900
INFO:root:Epoch[3] Resetting Data Iterator
INFO:root:Epoch[3] Time cost=2.167
INFO:root:Epoch[3] Validation-accuracy=0.938400
INFO:root:Epoch[4] Batch [200]  Speed: 26948.04 samples/sec Train-accuracy=0.942450
INFO:root:Epoch[4] Batch [400]  Speed: 24250.66 samples/sec Train-accuracy=0.943200
INFO:root:Epoch[4] Batch [600]  Speed: 22772.67 samples/sec Train-accuracy=0.951550
INFO:root:Epoch[4] Resetting Data Iterator
INFO:root:Epoch[4] Time cost=2.456
INFO:root:Epoch[4] Validation-accuracy=0.951500
INFO:root:Epoch[5] Batch [200]  Speed: 27313.59 samples/sec Train-accuracy=0.955500
INFO:root:Epoch[5] Batch [400]  Speed: 28061.48 samples/sec Train-accuracy=0.955100
INFO:root:Epoch[5] Batch [600]  Speed: 26730.32 samples/sec Train-accuracy=0.960500
INFO:root:Epoch[5] Resetting Data Iterator
INFO:root:Epoch[5] Time cost=2.206
INFO:root:Epoch[5] Validation-accuracy=0.956300
INFO:root:Epoch[6] Batch [200]  Speed: 28440.23 samples/sec Train-accuracy=0.962700
INFO:root:Epoch[6] Batch [400]  Speed: 28832.82 samples/sec Train-accuracy=0.962700
INFO:root:Epoch[6] Batch [600]  Speed: 27814.78 samples/sec Train-accuracy=0.967150
INFO:root:Epoch[6] Resetting Data Iterator
INFO:root:Epoch[6] Time cost=2.131
INFO:root:Epoch[6] Validation-accuracy=0.960300
INFO:root:Epoch[7] Batch [200]  Speed: 20942.23 samples/sec Train-accuracy=0.967550
INFO:root:Epoch[7] Batch [400]  Speed: 22264.85 samples/sec Train-accuracy=0.967750
INFO:root:Epoch[7] Batch [600]  Speed: 21294.69 samples/sec Train-accuracy=0.971500
INFO:root:Epoch[7] Resetting Data Iterator
INFO:root:Epoch[7] Time cost=2.805
INFO:root:Epoch[7] Validation-accuracy=0.961400
INFO:root:Epoch[8] Batch [200]  Speed: 17870.55 samples/sec Train-accuracy=0.972550
INFO:root:Epoch[8] Batch [400]  Speed: 11526.75 samples/sec Train-accuracy=0.971600
INFO:root:Epoch[8] Batch [600]  Speed: 15082.47 samples/sec Train-accuracy=0.974500
INFO:root:Epoch[8] Resetting Data Iterator
INFO:root:Epoch[8] Time cost=4.197
INFO:root:Epoch[8] Validation-accuracy=0.963000
INFO:root:Epoch[9] Batch [200]  Speed: 10139.52 samples/sec Train-accuracy=0.976000
INFO:root:Epoch[9] Batch [400]  Speed: 10321.69 samples/sec Train-accuracy=0.975550
INFO:root:Epoch[9] Batch [600]  Speed: 10820.23 samples/sec Train-accuracy=0.977750
INFO:root:Epoch[9] Resetting Data Iterator
INFO:root:Epoch[9] Time cost=5.777
INFO:root:Epoch[9] Validation-accuracy=0.964100


After training completes, we can test the model on a single image.

plt.imshow(val_img[0], cmap='Greys_r')
plt.axis('off')
plt.show()
prob = model.predict(val_img[0:1].astype(np.float32) / 255)[0]
print('Classified as %d with probability %f' % (prob.argmax(), max(prob)))


Classified as 7 with probability 0.999781


We can also evaluate the accuracy by feeding in a data iterator.

print('Validation accuracy: %f%%' % (model.score(val_iter) * 100,))

Validation accuracy: 96.410000%
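As a cross-check (translator's addition, not in the original tutorial), the same number can be computed by hand from model.predict, reusing the to4d helper defined earlier:

pred = model.predict(to4d(val_img)).argmax(axis=1)  # predicted digit per image
print('Manual validation accuracy: %f%%' % ((pred == val_lbl).mean() * 100,))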


We can even recognize digits drawn in an input box.

from IPython.display import HTML
import cv2
import numpy as np
from mnist_demo import html, script

def classify(img):
    # strip the data-URL prefix and decode the base64-encoded PNG
    img = img[len('data:image/png;base64,'):].decode('base64')
    img = cv2.imdecode(np.fromstring(img, np.uint8), -1)
    # use the alpha channel and resize to the 28x28 network input
    img = cv2.resize(img[:, :, 3], (28, 28))
    img = img.astype(np.float32).reshape((1, 1, 28, 28)) / 255.0
    return model.predict(img)[0].argmax()

'''
To see the model in action, run the demo notebook at
https://github.com/dmlc/mxnet-notebooks/blob/master/python/tutorials/mnist.ipynb.
'''
HTML(html + script)


Convolutional Neural Networks

Note that the fully-connected layers above simply flatten the image into a vector during training, ignoring the spatial information that pixels carry in the horizontal and vertical dimensions. The convolutional layer overcomes this weakness by using a more structured weight W: instead of a plain matrix multiplication, it computes the output with a 2-D convolution.

We can also use multiple feature maps, each with its own weight matrix, to extract different features.
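As an illustration (translator's addition, not part of the tutorial), a naive numpy sketch of the 2-D cross-correlation that deep learning frameworks call "convolution", with no padding and stride 1:

import numpy as np

def conv2d(img, kernel):
    # img: (H, W), kernel: (kh, kw)  ->  out: (H-kh+1, W-kw+1)
    H, W = img.shape
    kh, kw = kernel.shape
    out = np.zeros((H - kh + 1, W - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # each output pixel is the sum of an image patch
            # weighted element-wise by the kernel
            out[i, j] = (img[i:i + kh, j:j + kw] * kernel).sum()
    return out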

 

Besides the convolutional layer, the other major change in a convolutional neural network is the addition of pooling layers. A pooling layer reduces an n×m patch of the image (n×m is usually called the kernel size) to a single value, making the network less sensitive to the exact spatial location. (Translator's note: this also helps avoid overfitting.)
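And a matching sketch of 2×2 max pooling with stride 2 (translator's addition), which is exactly the configuration the LeNet definition below uses:

import numpy as np

def max_pool_2x2(img):
    # img: (H, W) with even H and W  ->  out: (H//2, W//2);
    # each output value is the max over one 2x2 block
    H, W = img.shape
    return img.reshape(H // 2, 2, W // 2, 2).max(axis=(1, 3))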

data = mx.symbol.Variable('data')
# first conv layer
conv1 = mx.sym.Convolution(data=data, kernel=(5,5), num_filter=20)
tanh1 = mx.sym.Activation(data=conv1, act_type="tanh")
pool1 = mx.sym.Pooling(data=tanh1, pool_type="max", kernel=(2,2), stride=(2,2))
# second conv layer
conv2 = mx.sym.Convolution(data=pool1, kernel=(5,5), num_filter=50)
tanh2 = mx.sym.Activation(data=conv2, act_type="tanh")
pool2 = mx.sym.Pooling(data=tanh2, pool_type="max", kernel=(2,2), stride=(2,2))
# first fullc layer
flatten = mx.sym.Flatten(data=pool2)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)
tanh3 = mx.sym.Activation(data=fc1, act_type="tanh")
# second fullc
fc2 = mx.sym.FullyConnected(data=tanh3, num_hidden=10)
# softmax loss
lenet = mx.sym.SoftmaxOutput(data=fc2, name='softmax')


Note that the LeNet model above is more complex than the multilayer perceptron, so we use a GPU instead of the CPU for training.
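If your machine has no GPU, a small fallback sketch (translator's addition; it probes GPU 0 by forcing a tiny allocation and catching the failure):

try:
    mx.nd.zeros((1,), ctx=mx.gpu(0)).asnumpy()  # force a sync to surface errors
    ctx = mx.gpu(0)
except Exception:
    ctx = mx.cpu(0)
# then pass ctx below instead of the hard-coded mx.gpu(0)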

model = mx.model.FeedForward(
    ctx = mx.gpu(0),     # use GPU 0 for training, others are same as before
    symbol = lenet,
    num_epoch = 10,
    learning_rate = 0.1)
model.fit(
    X=train_iter,
    eval_data=val_iter,
    batch_end_callback = mx.callback.Speedometer(batch_size, 200)
)


 

INFO:root:Start training with [gpu(0)]
INFO:root:Epoch[0] Batch [200]  Speed: 14804.86 samples/sec Train-accuracy=0.111500
INFO:root:Epoch[0] Batch [400]  Speed: 14294.26 samples/sec Train-accuracy=0.111000
INFO:root:Epoch[0] Batch [600]  Speed: 14273.05 samples/sec Train-accuracy=0.113600
INFO:root:Epoch[0] Resetting Data Iterator
INFO:root:Epoch[0] Time cost=4.446
INFO:root:Epoch[0] Validation-accuracy=0.113500
INFO:root:Epoch[1] Batch [200]  Speed: 14332.64 samples/sec Train-accuracy=0.141350
INFO:root:Epoch[1] Batch [400]  Speed: 14785.42 samples/sec Train-accuracy=0.777650
INFO:root:Epoch[1] Batch [600]  Speed: 14796.36 samples/sec Train-accuracy=0.914550
INFO:root:Epoch[1] Resetting Data Iterator
INFO:root:Epoch[1] Time cost=4.105
INFO:root:Epoch[1] Validation-accuracy=0.937700
INFO:root:Epoch[2] Batch [200]  Speed: 14877.08 samples/sec Train-accuracy=0.941850
INFO:root:Epoch[2] Batch [400]  Speed: 14806.53 samples/sec Train-accuracy=0.955900
INFO:root:Epoch[2] Batch [600]  Speed: 14844.79 samples/sec Train-accuracy=0.965200
INFO:root:Epoch[2] Resetting Data Iterator
INFO:root:Epoch[2] Time cost=4.048
INFO:root:Epoch[2] Validation-accuracy=0.971200
INFO:root:Epoch[3] Batch [200]  Speed: 14873.95 samples/sec Train-accuracy=0.971150
INFO:root:Epoch[3] Batch [400]  Speed: 14793.99 samples/sec Train-accuracy=0.972400
INFO:root:Epoch[3] Batch [600]  Speed: 14806.52 samples/sec Train-accuracy=0.976600
INFO:root:Epoch[3] Resetting Data Iterator
INFO:root:Epoch[3] Time cost=4.052
INFO:root:Epoch[3] Validation-accuracy=0.980600
INFO:root:Epoch[4] Batch [200]  Speed: 14428.12 samples/sec Train-accuracy=0.979100
INFO:root:Epoch[4] Batch [400]  Speed: 14298.85 samples/sec Train-accuracy=0.979550
INFO:root:Epoch[4] Batch [600]  Speed: 14618.55 samples/sec Train-accuracy=0.982400
INFO:root:Epoch[4] Resetting Data Iterator
INFO:root:Epoch[4] Time cost=4.158
INFO:root:Epoch[4] Validation-accuracy=0.983300
INFO:root:Epoch[5] Batch [200]  Speed: 14919.47 samples/sec Train-accuracy=0.983700
INFO:root:Epoch[5] Batch [400]  Speed: 14809.71 samples/sec Train-accuracy=0.984050
INFO:root:Epoch[5] Batch [600]  Speed: 14550.25 samples/sec Train-accuracy=0.986250
INFO:root:Epoch[5] Resetting Data Iterator
INFO:root:Epoch[5] Time cost=4.071
INFO:root:Epoch[5] Validation-accuracy=0.985100
INFO:root:Epoch[6] Batch [200]  Speed: 14363.59 samples/sec Train-accuracy=0.986500
INFO:root:Epoch[6] Batch [400]  Speed: 14629.87 samples/sec Train-accuracy=0.986950
INFO:root:Epoch[6] Batch [600]  Speed: 14842.83 samples/sec Train-accuracy=0.988700
INFO:root:Epoch[6] Resetting Data Iterator
INFO:root:Epoch[6] Time cost=4.113
INFO:root:Epoch[6] Validation-accuracy=0.985300
INFO:root:Epoch[7] Batch [200]  Speed: 14863.48 samples/sec Train-accuracy=0.988950
INFO:root:Epoch[7] Batch [400]  Speed: 14824.65 samples/sec Train-accuracy=0.988800
INFO:root:Epoch[7] Batch [600]  Speed: 14278.57 samples/sec Train-accuracy=0.990350
INFO:root:Epoch[7] Resetting Data Iterator
INFO:root:Epoch[7] Time cost=4.102
INFO:root:Epoch[7] Validation-accuracy=0.986400
INFO:root:Epoch[8] Batch [200]  Speed: 14875.69 samples/sec Train-accuracy=0.990300
INFO:root:Epoch[8] Batch [400]  Speed: 14833.44 samples/sec Train-accuracy=0.990750
INFO:root:Epoch[8] Batch [600]  Speed: 14804.53 samples/sec Train-accuracy=0.992250
INFO:root:Epoch[8] Resetting Data Iterator
INFO:root:Epoch[8] Time cost=4.049
INFO:root:Epoch[8] Validation-accuracy=0.987200
INFO:root:Epoch[9] Batch [200]  Speed: 14864.23 samples/sec Train-accuracy=0.992000
INFO:root:Epoch[9] Batch [400]  Speed: 14699.46 samples/sec Train-accuracy=0.991650
INFO:root:Epoch[9] Batch [600]  Speed: 14853.07 samples/sec Train-accuracy=0.992800
INFO:root:Epoch[9] Resetting Data Iterator
INFO:root:Epoch[9] Time cost=4.058
INFO:root:Epoch[9] Validation-accuracy=0.987800


Note that with the same hyperparameters, the LeNet model reaches a validation accuracy of 98.7%, higher than the multilayer perceptron's 96.4%.