基于MATLAB的神经网络进行手写体数字识别(含鼠绘GUI / 数据集:MNIST)
基本介绍
- 软件:Matlab R2018b
- 数据集:MNIST手写体数字数据集
- 网络:自建简单网络
数据准备
MNIST数据集还挺有名的,这里就不过多介绍了。数据集本身读取格式官网有给,怎么转换成图片格式网上也有很多,这里不再赘述。
官网:http://yann.lecun.com/exdb/mnist/
训练集包含60000个示例,测试集包含10000个示例。
测试集的前5000个示例来自原始的NIST训练集。 最后的5000个来自原始的NIST测试集。 前5000个比后5000个更干净点,识别起来更容易。
当然为了方便使用MATLAB,这里给出程序缺省的数据集:
链接:https://pan.baidu.com/s/1VItI8MdUa-oBhWjKUUB72w
提取码:tgv9
CSDN地址:https://download.csdn.net/download/garker/12413315
每一个数字都包含1000张图片,每张图片大小均为28×28×1,1代表单通道,即灰度图。
神经网络组建
因为数据集本身特征并不多,所以不需要动用常用的神经网络,这里给出一个官方的结构形式。一共有15层。
这里可以看出,三层卷积,三层归一化,是相当简单的CNN网络结构了,可以当作CNN结构的入门学习好好钻研学习。
在MATLAB中的建构代码如下:
layers = [ imageInputLayer([28 28 1]) convolution2dLayer(3,8,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,16,'Padding','same') batchNormalizationLayer reluLayer maxPooling2dLayer(2,'Stride',2) convolution2dLayer(3,32,'Padding','same') batchNormalizationLayer reluLayer fullyConnectedLayer(10) softmaxLayer classificationLayer];
这其中,各层的参数如下:
convolution2dLayer
参数 | 值 | 含义 |
---|---|---|
FilterSize | 3,3 | 卷积核尺寸 |
NumFilter | 8 | 卷积核数量 |
Padding | ‘same’ | new_height = new_width = W / S (结果向上取整) |
(W×W的输入矩阵,F×F的卷积核,步长为S=1)
BatchNormalizationLayer
归一化层采用默认数据
maxPooling2dLayer
参数 | 值 | 含义 |
---|---|---|
PoolSize | 2,2 | 池化尺寸 |
Stride | 2,2 | 步长 |
fullyConnectedLayer
全连接层输出为10(0-9共10个数字)
训练神经网络
imds = imageDatastore('train_dataset', ... 'IncludeSubfolders',true,'LabelSource','foldernames'); %导入数据 [imdsTrain,imdsValidation] = splitEachLabel(imds,0.8,'randomize'); %分割数据集与测试集 options = trainingOptions('sgdm', ... 'InitialLearnRate',0.01, ... 'MaxEpochs',5, ... 'Shuffle','every-epoch', ... 'ValidationData',imdsValidation, ... 'ValidationFrequency',30, ... 'Verbose',false, ... 'Plots','training-progress'); %设置训练参数 net = trainNetwork(imdsTrain,layers,options); %训练神经网络
这里可以看出来基本上第三个世代就已经训练差不多了,最后的accuracy也能达到99.80%。
测试数据集
YPred = classify(net,imdsValidation); YValidation = imdsValidation.Labels; accuracy = sum(YPred == YValidation)/numel(YValidation); figure; perm = randperm(10000,20); for i = 1:20 subplot(4,5,i); s = classify(net,imread(imds.Files{perm(i)})); imshow(imds.Files{perm(i)});title(string(s)); end
随机挑出来20个看看效果,没什么大问题:
鼠绘输入识别的GUI
GUI的代码编写不算难,直接回调函数里面编写也比较方便。这里着重讲一下鼠绘的问题,网上查了很多资料也踩了不少坑,这里按处理顺序把比较坑的细节都放一下:
鼠绘区域
红色区域里面只有axes1是有实际作用的,为了美观我把X、Y轴颜色改成了背景的灰色以达到隐藏的效果。此外,还需要把X、Y轴的XLimMode、YLimMode设置为manual,其主要作用是锁住它们,不然在鼠绘的时候每一笔都会飘。
此外,对该区域的鼠绘效果显示代码如下:
figure1_WindowButtonDownFcn
unction figure1_WindowButtonDownFcn(hObject, eventdata, handles) % hObject handle to figure1 (see GCBO) % eventdata reserved - to be defined in a future version of MATLAB % handles structure with handles and user data (see GUIDATA) global draw_enable; global x; global y; draw_enable=1; if draw_enable position=get(gca,'currentpoint'); x(1)=position(1); y(1)=position(3); end
figure1_WindowButtonMotionFcn
function figure1_WindowButtonMotionFcn(hObject, eventdata, handles) % hObject handle to figure1 (see GCBO) % eventdata reserved - to be defined in a future version of MATLAB % handles structure with handles and user data (see GUIDATA) global draw_enable; global x; global y; if draw_enable position=get(gca,'currentpoint'); x(2)=position(1); y(2)=position(3); h1 = line(x,y,'EraseMode','xor','LineWidth',5,'color','black'); x(1)=x(2); y(1)=y(2); end
figure1_WindowButtonUpFcn
function figure1_WindowButtonUpFcn(hObject, eventdata, handles) % hObject handle to figure1 (see GCBO) % eventdata reserved - to be defined in a future version of MATLAB % handles structure with handles and user data (see GUIDATA) global draw_enable draw_enable=0;
特别特别需要注意的是,这三个回调函数都是在整个GUI默认的整体面板上来的,也就是figure1。具体找到这个回调函数的如下图所示:
没错,就是点击GUI编辑面板空白区域!
识别
识别按钮的回调函数很简单这里就不赘述了,需要特别提醒的是:
从绘制区域直接得到的并不是可直接使用图像数据,这里直接保存到默认目录一份正好也做备份用;
再者,保存好的图像的手写数据部分是深色的,背景部分是浅色的,这与我们之前的训练数据是不符的,直接用来识别肯定不会出现正确的答案,所以把这个数据读取之后再取反色,部分代码如下:
h=getframe(handles.axes1); imwrite(h.cdata,'output.jpg','jpg'); img = imread('output.jpg'); img = imresize(img,[28,28]); img = rgb2gray(img); img = 255 - img; %取反色
结论
“0-9”这十个数字逐一写了一遍感觉问题不大,但是千万别因为鼠绘区域大懒省事儿把数字写的很小这会影响到识别结果,如果实在感觉控制不好,可以在GUI编辑界面把整个界面改成按比例,这样实际使用的时候可以等比例把界面拉小,鼠绘更方便一些。
- 利用tensorflow一步一步实现基于MNIST 数据集进行手写数字识别的神经网络,逻辑回归
- 深度学习笔记——TensorFlow学习笔记(三)使用TensorFlow实现的神经网络进行MNIST手写体数字识别
- 使用tensorflow利用神经网络分类识别MNIST手写数字数据集,转自随心1993
- 使用逻辑回归方法(softmax regression)识别MNIST手写体数字、使用CNN神经网络识别MNIST手写体数字、使用tensorboard可视化训练过程数据
- 【深度学习·笔记一】基于Matlab的已训练神经网络Alexnet进行图像识别
- 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST
- 神经网络与深度学习 使用Python实现基于梯度下降算法的神经网络和自制仿MNIST数据集的手写数字分类可视化程序 web版本
- tensorflow 全连接神经网络 MNIST手写体数字识别
- TensorFlow学习笔记(1):使用softmax对手写体数字(MNIST数据集)进行识别
- 经典神经网络进行MNIST手写数字识别系列(一):ALEXNET
- 全连神经网络的经典实战--MNIST手写体数字识别
- 基于Tensorflow框架的卷集神经网络手写体数字识别
- 神经网络——实现MNIST数据集的手写数字识别
- 神经网络与深度学习 1.6 使用Python实现基于梯度下降算法的神经网络和MNIST数据集的手写数字分类程序
- 构建一个简单的神经网络(又名多层感知器)来对MNIST数字数据集进行分类--学习笔记
- 基于MNIST数据集的手写数字识别应用开发实践
- DL之DNN:自定义2层神经网络TwoLayerNet模型(计算梯度两种方法)算法对MNIST数据集进行训练、预测
- 用简单的CNN网络实现MNIST数据集的识别:实现模型保存与调用模型进行测试
- 用matlab实现神经网络识别数字
- tensorflow 学习专栏(五):在mnist数据集上使用tensorflow实现临近算法(Nearest-Neighbor)进行手写数字识别