MNIST手写数字识别(cnn)&pytorch的框架实现
2020-04-10 07:30
417 查看
MNIST手写数字识别(cnn)&pytorch的框架实现
一、程序主要模块:
- getData:首先要对数据进行处理,得到训练集和测试集;
- CNNnet:定义自己的卷积神经网络;
- train:正式训练的函数;
- test:测试集用的函数;
二、main函数:
大部分名词已经在代码得注释部分说明,所有程序代码在本文中都以呈现,只需要调用main()函数即可:
通常main函数的主框架可以不动。
def main(): args = Arg()#参数类,其中定义了程序需要的参数 #CPU设置种子用于生成随机数,以使得结果是确定的 torch.manual_seed(args.seed) # 加载数据 train_loader, test_loader = getData(args) # 得到卷积神经网络类 model = CNNnet() # 定义损失函数 loss_func = nn.CrossEntropyLoss() # 定义优化方式 opt = torch.optim.Adam(model.parameters(), lr=args.lr) # 正式进入训练和测试 # 其中epoch表示便利训练集的次数 for epoch in range(1,args.epoch+1): train(args, model, train_loader, opt, loss_func, epoch) test(args, model, test_loader, loss_func, epoch)
三、处理数据集
由于本文直接使用了自带的数据集,不需要细化处理。
在使用自定义的数据集时可以继承Dataset类,使用自己的数据集。
def getData(args:Arg): # 数据集的预处理:把所有数据变成tensor类型,自动归一化 # 当然还可以对数据集进行其他预处理 data_tf = torchvision.transforms.Compose( [ torchvision.transforms.ToTensor() # 自动归一化[0.0,1.0] ] ) # 如果本地没有数据集 要先执行该代码 获取数据集 # train_data = mnist.MNIST(Arg.data_path,train=True,transform=data_tf,download=True) train_data = mnist.MNIST(args.data_path, train=True, transform=data_tf, download=False) test_data = mnist.MNIST(args.data_path, train=False, transform=data_tf, download=False) # 获取迭代数据:data.DataLoader(), 把训练集和测试集依次放进去 # Dataloader返回所有的数据,分成了许多批次,一个批次有batch_size大小的数据 train_loader = data.DataLoader(train_data, batch_size=args.batchSize, shuffle=True) # shuffle:是否打乱数据 test_loader = data.DataLoader(test_data, batch_size=args.batchSize, shuffle=True) # shuffle:是否打乱数据 return train_loader, test_loader
四、定义CNN网络
可以再该函数中修改自己的CNN网络或者变成其他的网络,
本文主要是对框架的学习,不必对代码进行细究。
class CNNnet(torch.nn.Module): def __init__(self): # 复制并使用CNNNet的父类的初始化方法,即先运行nn.Module的初始化函数 super(CNNnet, self).__init__() self.conv1 = torch.nn.Sequential( torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1), torch.nn.BatchNorm2d(16), torch.nn.ReLU() ) self.conv2 = torch.nn.Sequential( torch.nn.Conv2d(16, 32, 3, 2, 1), torch.nn.BatchNorm2d(32), torch.nn.ReLU() ) self.conv3 = torch.nn.Sequential( torch.nn.Conv2d(32, 64, 3, 2, 1), torch.nn.BatchNorm2d(64), torch.nn.ReLU() ) self.conv4 = torch.nn.Sequential( torch.nn.Conv2d(64, 64, 2, 2, 0), torch.nn.BatchNorm2d(64), torch.nn.ReLU() ) self.mlp1 = torch.nn.Linear(2 * 2 * 64, 100) self.mlp2 = torch.nn.Linear(100, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = self.conv3(x) x = self.conv4(x) x = self.mlp1(x.view(x.size(0), -1)) x = self.mlp2(x) return x
五、正式训练:
每个train的执行相当于一次epoch,遍历了一次数据集,train函数框架也大体可以不变。
def train(args, model, train_loader, opt, loss_func, epoch): #主要是针对由于model在训练时和评价时 BatchNormalization和Dropout方法模式不同, #训练时要带上model.train(),预测时要带上model.test() model.train() for batch_id,(data, target) in enumerate(train_loader): data, target = Variable(data), Variable(target) # 初始输入,[128,1,28,28],[128] output = model(data) # 最终输出[128,10] loss = loss_func(output, target) # 计算损失 opt.zero_grad() # 清空参与更新参数值 loss.backward() # 反向传播 opt.step() # 参数更新 if batch_id%args.interval==0: print("Train Epoch:{}[{}/{}({:.0f}%)]\tLoss:{:.6f}".format(epoch, batch_id*len(data),len(train_loader.dataset),\ 100.*batch_id/len(train_loader), loss.item()))# 隔一段时间输出一下当前情况,可自定义
损失函数和优化器在train中的使用步骤: 1、获取损失: loss = loss_func(预测值, 真实值) #针对每个batch来说的 2、清空上一步参与更新参数:opt.zero_grad() 3、误差反向传播:loss.backward() 4、更新参数:opt.step()
六、对测试集进行测试
test的框架一般只需要自定义一下自己的评价指标即可
def test(args, model, test_loader, loss_func): model.eval() test_loss = 0 correct = 0 for data, target in test_loader:# 每批次每批次的输入 data, target = Variable(data), Variable(target) # 初始输入,[128,1,28,28],[128] output = model(data) # 最终输出[128,10] test_loss += loss_func(output, target).item()# 损失总和 isRight = torch.max(output, 1)[1].numpy() == target.numpy() correct+=np.sum(isRight!= 0) #所有的正确率 accuracy = correct/len(test_loader.dataset) print('\nTest set: Average loss: {:.4f}, Accuracy: ({:.1f}%)\n'.format(test_loss, accuracy*100))
七、参数类介绍
不包含全部的参数,CNN中有许多参数,因为方便适用于所有网络,本文没有将CNN的参数放进去。
class Arg: def __init__(self): self.batchSize = 64 # 批次大小 self.data_path = '../data/' # 路径自己设定 self.lr = 0.001 #学习率 self.epoch = 20 self.interval = 100 #多少批次的间隔后输出一下当前训练结果 # 其他 self.seed = 1 #随机种子,设置种子用于生成随机数,以使得结果是确定的
八、所需要的包
import torch from torch.utils import data # 获取迭代数据 from torch.autograd import Variable # 获取变量 import torchvision from torchvision.datasets import mnist # 获取数据集 import torch.nn as nn import numpy as np
九、拓展
- 可以使用GPU加速训练,当然代码要改一部分;
- 尝试自定义自己的数据集,数据预处理要做很多工作;
- 网站中学习率可以动态设置,百度下即可;
十、参考文章:
- https://blog.csdn.net/m0_37306360/article/details/79311501
- https://blog.csdn.net/qq_34714751/article/details/85610966
不足之处,请多指正!
- 点赞
- 收藏
- 分享
- 文章举报
相关文章推荐
- 深度学习-CNN卷积神经网络使用TensorFlow框架实现MNIST手写数字识别
- 深度学习-传统神经网络使用TensorFlow框架实现MNIST手写数字识别
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测—daidingdaiding
- 用Tensorflow搭建CNN卷积神经网络,实现MNIST手写数字识别
- TensorFlow之CNN实现MNIST手写数字识别
- Deep Learning-TensorFlow (1) CNN卷积神经网络_MNIST手写数字识别代码实现详解
- Pytorch实现的手写数字mnist识别功能完整示例
- Tensorflow学习笔记(二):利用CNN实现手写数字(mnist)识别
- Android+TensorFlow+CNN+MNIST实现手写数字识别
- CNN实现MNIST手写数字识别
- Deep Learning-TensorFlow (1) CNN卷积神经网络_MNIST手写数字识别代码实现
- 基于Tensorflow框架CNN实现手写数字识别(二):识别数字
- DL之LiR&DNN&CNN:利用LiR、DNN、CNN算法对MNIST手写数字图片(csv)识别数据集实现(10)分类预测
- PyTorch: CNN实战MNIST手写数字识别
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
- Pytorch_cnn_net实现mnist手写识别
- Tensorflow新手教程(1)----CNN实现mnist手写数字识别
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
- Tensoflow+CNN实现简单的mnist手写数字识别
- tensorflow 学习专栏(五):在mnist数据集上使用tensorflow实现临近算法(Nearest-Neighbor)进行手写数字识别