使用PyTorch对cifar-10图片分类
2018-03-25 17:36
2715 查看
使用PyTorch对cifar-10图片分类
前言
最近刚学习了PyTorch,主要是在PyTorch主页教程里面学习。不过这个教程是英文的,学习起来比较费劲。因此我自己对PyTorch对cifar-10图片分类这一部分进行了总结,因为光对着代码看很容易乱,所以将整个过程的流程整理出来,方便理解。
程序流程
一、数据预处理
图片转化为Tensor将数据归一化
为训练集、测试集分别创建可迭代的数据集,每次迭代可得到batch_size个输入图片数据。
import torchvision as tv import torchvision.transforms as transforms from torchvision.transforms import ToPILImage show = ToPILImage() # 可以把Tensor转成Image,方便可视化 # 第一次运行程序torchvision会自动下载CIFAR-10数据集, # 大约100M,需花费一定的时间, # 如果已经下载有CIFAR-10,可通过root参数指定 # 定义对数据的预处理 transform = transforms.Compose([ transforms.ToTensor(), # 转为Tensor transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), # 归一化 ]) # 训练集 trainset = tv.datasets.CIFAR10( root='/home/cy/tmp/data/', train=True, download=True, transform=transform) trainloader = t.utils.data.DataLoader( trainset, batch_size=4, shuffle=True, num_workers=2) # 测试集 testset = tv.datasets.CIFAR10( '/home/cy/tmp/data/', train=False, download=True, transform=transform) testloader = t.utils.data.DataLoader( testset, batch_size=4, shuffle=False, num_workers=2) classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
二、定义网络
仿照LetNet网络,创建继承nn.Module的子类,并实现init、forward方法。import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16*5*5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = x.view(x.size()[0], -1) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x net = Net() print(net)
三、定义损失函数和优化器
from torch import optim criterion = nn.CrossEntropyLoss() # 交叉熵损失函数 optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) #定义优化器
四、训练网络
定义epoch训练次数循环读取数据迭代器的训练数据
获得输入的inputs、标签labels
梯度清零
inputs带入网络得到outputs,将outputs与labels比较得带loss
对loss执行反向传播,自动求所有参数的梯度
用优化器把获得的梯度来更新参数
打印loss信息
t.set_num_threads(8) for epoch in range(2): running_loss = 0.0 for i, data in enumerate(trainloader, 0): # 输入数据 inputs, labels = data inputs, labels = Variable(inputs), Variable(labels) # 梯度清零 optimizer.zero_grad() # forward + backward outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() # 更新参数 optimizer.step() # 打印log信息 running_loss += loss.data[0] if i % 2000 == 1999: # 每2000个batch打印一下训练状态 print('[%d, %5d] loss: %.3f' \ % (epoch+1, i+1, running_loss / 2000)) running_loss = 0.0 print('Finished Training')
五、获得准确率
dataiter = iter(testloader) images, labels = dataiter.next() # 一个batch返回4张图片 print('实际的label: ', ' '.join(\ '%08s'%classes[labels[j]] for j in range(4))) show(tv.utils.make_grid(images / 2 - 0.5)).resize((400,100)) # 计算图片在每个类别上的分数 outputs = net(Variable(images)) # 得分最高的那个类 _, predicted = t.max(outputs.data, 1) print('预测结果: ', ' '.join('%5s'\ % classes[predicted[j]] for j in range(4))) correct = 0 # 预测正确的图片数 total = 0 # 总共的图片数 for data in testloader: images, labels = data outputs = net(Variable(images)) _, predicted = t.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum() print('10000张测试集中的准确率为: %d %%' % (100 * correct / total))
相关文章推荐
- 用KNN算法分类CIFAR-10图片数据
- caffe (10) 使用python测试多张图片统计分类结果
- Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练
- Java关键字static、final使用小结Z… 分类: Android开发 2014-05-30 10:58 66人阅读 评论(0) 收藏
- Android中使用JavaMail发送邮件ZZ 分类: Android资源 2014-05-30 10:58 93人阅读 评论(0) 收藏
- Deep Learning-TensorFlow (4) CNN卷积神经网络_CIFAR-10进阶图像分类模型(上)
- 快速导出PDF文件中所有图片(使用Adobe Acrobat 10 )
- 使用caffe中的imagenet对自己的图片进行分类训练(超级详细版)
- 使用CNN神经网络进行图片识别分类
- cifar-10 cnn 分类
- 利用运行时,给UIImageView写一个分类,交换里面的setImage的方法,可以重绘图片,提高内存的利用率(要是没有重绘图片,直接使用系统提供的setImag就会造成占用大量的内存问题)
- 从CIFAR-10手工分类中学到的经验教训Lessons learned from manually classifying CIFAR-10
- Android创建和使用数据库详… 分类: Android数据存储 2014-05-30 10:58 71人阅读 评论(0) 收藏
- 判断图片 & 判断URL , 使用分类实现
- 利用pytorch对CIFAR-10数据集的分类
- TensorFlow学习笔记(七) 使用训练好的inception_v3模型预测分类图片
- 【图片压缩】使用canvas,html5进行图片压缩 分类: canvas 图片压缩 压缩 HTML5 fileReader 2015-03-20 17:14 118人阅读 评论(0) 收藏
- Android创建和使用数据库详… 分类: Android数据存储 2014-05-30 10:58 82人阅读 评论(0) 收藏
- 使用现有的基于caffe训练好的imagenet model进行图片分类
- !!使用Caffe对图片进行训练并分类的简单流程