【学习记录】day3 Task3 字符识别模型 (Datawhale 零基础⼊⻔CV)
2020-06-02 05:33
274 查看
代码我有点晕 ,主要是因为卷积已经被我忘了差不多了,然后涉及CNN就(嗯,这货说得是个啥)
让我先转一篇将卷积的文章
https://baijiahao.baidu.com/s?id=1653145909866150049&wfr=spider&for=pc
说实话的,高数上我只是知道怎么算,但是不知为什么。
讲CNN的好像这个讲得更详细一点
传送门
好像理解了一点点吧
import torch torch.manual_seed(0) torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = True import torchvision.models as models import torchvision.transforms as transforms import torchvision.datasets as datasets import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.autograd import Variable from torch.utils.data.dataset import Dataset # 定义模型 class SVHN_Model1(nn.Module): def __init__(self): super(SVHN_Model1, self).__init__() # CNN提取特征模块 self.cnn = nn.Sequential( nn.Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2)), nn.ReLU(), nn.MaxPool2d(2), nn.Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2)), nn.ReLU(), nn.MaxPool2d(2), ) # self.fc1 = nn.Linear(32*3*7, 11) self.fc2 = nn.Linear(32*3*7, 11) self.fc3 = nn.Linear(32*3*7, 11) self.fc4 = nn.Linear(32*3*7, 11) self.fc5 = nn.Linear(32*3*7, 11) self.fc6 = nn.Linear(32*3*7, 11) def forward(self, img): feat = self.cnn(img) feat = feat.view(feat.shape[0], -1) c1 = self.fc1(feat) c2 = self.fc2(feat) c3 = self.fc3(feat) c4 = self.fc4(feat) c5 = self.fc5(feat) c6 = self.fc6(feat) return c1, c2, c3, c4, c5, c6 model = SVHN_Model1()
# 损失函数 criterion = nn.CrossEntropyLoss() # 优化器 optimizer = torch.optim.Adam(model.parameters(), 0.005) loss_plot, c0_plot = [], [] # 迭代10个Epoch for epoch in range(10): for data in train_loader: c0, c1, c2, c3, c4, c5 = model(data[0]) loss = criterion(c0, data[1][:, 0]) + \ criterion(c1, data[1][:, 1]) + \ criterion(c2, data[1][:, 2]) + \ criterion(c3, data[1][:, 3]) + \ criterion(c4, data[1][:, 4]) + \ criterion(c5, data[1][:, 5]) loss /= 6 optimizer.zero_grad() loss.backward() optimizer.step() loss_plot.append(loss.item()) c0_plot.append((c0.argmax(1) == data[1][:, 0]).sum().item()*1.0 / c0.shape[0]) print(epoch)
相关文章推荐
- 【学习记录】day4 Task4 模型训练与验证 (Datawhale 零基础⼊⻔CV)
- Datawhale 零基础入门CV赛事-Task3 字符识别模型
- Datawhale 零基础入门CV - Task 03 字符识别模型
- Datawhale 零基础⼊⻔CV-Task3 字符识别模型
- Datawhale 零基础入门CV赛事-Task3 字符识别模型
- Datawhale 零基础入门CV赛事-Task4 模型训练与验证
- 从零开始实现Unity光照模型_01_标准光照模型与漫反射_技术美术基础学习记录
- 零基础入门CV赛事-字符识别模型
- 零基础⼊⻔CV-Task3 字符识别模型
- Datawhale 零基础入门CV赛事-Task03:字符识别模型
- 3D图形学编程基础-基于Direct3D11-学习记录(二)光照模型的实现
- Datawhale零基础入门cV-Task4模型训练与验证
- Datawhale 零基础入门CV赛事-Task4 模型训练与验证
- Datawhale 零基础入门CV赛事-Task5 模型集成
- Datawhale 零基础入门CV赛事-Task4 模型训练与验证
- 重新整理后的Oracle OAF学习笔记——3.应用构建基础之实现模型
- Javaweb Servlet基础学习记录(4)—重定向与请求转发(请求转发)
- Linux设备驱动模型学习之基础篇--Kobject.txt翻译
- django模型manager学习记录
- C++基础学习之2 - 内存对象模型