您的位置:首页 > 其它

Datawhale 零基础入门CV赛事-Task03:字符识别模型

2020-06-02 05:32 477 查看

Datawhale 零基础入门CV赛事-Task03:字符识别模型

学习目标

学习CNN基础知识和原理
使用Pytorch框架构建CNN模型,并完成训练

字符识别模型

CNN介绍

CNN,它的全称是Convolutional Neural Networks,卷积神经网络,是一类包含卷积计算且具有深度结构的前馈神经网络,是深度学习中尤为重要的算法。

CNN发展

(1)LeNet
(2)AlexNet
(3)VGG
(4)GoogleNet
(5)ResNet
(6)GAN
不同的网络有不同的优点和缺点,适用的场景和任务也各不相同。在使用这些网络处理任务时,要针对任务的特点,结合网络的特点,选择合适的网络,并且进行合理的改进,从而使结果能够达到最好。

Pytorch构建CNN模型

在pytorch中进行CNN模型构建,以ResNet18为例
(1)导入库

import torch
torch.manual_seed(0)
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True

import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset

(2)定义模型

class SVHN_Model1(nn.Module):
def __init__(self):
super(SVHN_Model1, self).__init__()

model_conv = models.resnet18(pretrained=True)
model_conv.avgpool = nn.AdaptiveAvgPool2d(1)
model_conv = nn.Sequential(*list(model_conv.children())[:-1])
self.cnn = model_conv

self.fc1 = nn.Linear(512, 11)
self.fc2 = nn.Linear(512, 11)
self.fc3 = nn.Linear(512, 11)
self.fc4 = nn.Linear(512, 11)
self.fc5 = nn.Linear(512, 11)

def forward(self, img):
feat = self.cnn(img)
# print(feat.shape)
feat = feat.view(feat.shape[0], -1)
c1 = self.fc1(feat)
c2 = self.fc2(feat)
c3 = self.fc3(feat)
c4 = self.fc4(feat)
c5 = self.fc5(feat)
return c1, c2, c3, c4, c5

(3)随后对其进行训练

def train(train_loader, model, criterion, optimizer, epoch):
# 切换模型为训练模式
model.train()
train_loss = []

for i, (input, target) in enumerate(train_loader):
if use_cuda:
input = input.cuda()
target = target.cuda()

c0, c1, c2, c3, c4 = model(input)
loss = criterion(c0, target[:, 0]) + \
criterion(c1, target[:, 1]) + \
criterion(c2, target[:, 2]) + \
criterion(c3, target[:, 3]) + \
criterion(c4, target[:, 4])

# loss /= 6
optimizer.zero_grad()
loss.backward()
optimizer.step()

train_loss.append(loss.item())
return np.mean(train_loss)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: