Datawhale 零基础入门CV赛事-Task03:字符识别模型
2020-06-02 05:32
477 查看
Datawhale 零基础入门CV赛事-Task03:字符识别模型
学习目标
学习CNN基础知识和原理
使用Pytorch框架构建CNN模型,并完成训练
字符识别模型
CNN介绍
CNN,它的全称是Convolutional Neural Networks,卷积神经网络,是一类包含卷积计算且具有深度结构的前馈神经网络,是深度学习中尤为重要的算法。
CNN发展
(1)LeNet
(2)AlexNet
(3)VGG
(4)GoogleNet
(5)ResNet
(6)GAN
不同的网络有不同的优点和缺点,适用的场景和任务也各不相同。在使用这些网络处理任务时,要针对任务的特点,结合网络的特点,选择合适的网络,并且进行合理的改进,从而使结果能够达到最好。
Pytorch构建CNN模型
在pytorch中进行CNN模型构建,以ResNet18为例
(1)导入库
import torch torch.manual_seed(0) torch.backends.cudnn.deterministic = False torch.backends.cudnn.benchmark = True import torchvision.models as models import torchvision.transforms as transforms import torchvision.datasets as datasets import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.autograd import Variable from torch.utils.data.dataset import Dataset
(2)定义模型
class SVHN_Model1(nn.Module): def __init__(self): super(SVHN_Model1, self).__init__() model_conv = models.resnet18(pretrained=True) model_conv.avgpool = nn.AdaptiveAvgPool2d(1) model_conv = nn.Sequential(*list(model_conv.children())[:-1]) self.cnn = model_conv self.fc1 = nn.Linear(512, 11) self.fc2 = nn.Linear(512, 11) self.fc3 = nn.Linear(512, 11) self.fc4 = nn.Linear(512, 11) self.fc5 = nn.Linear(512, 11) def forward(self, img): feat = self.cnn(img) # print(feat.shape) feat = feat.view(feat.shape[0], -1) c1 = self.fc1(feat) c2 = self.fc2(feat) c3 = self.fc3(feat) c4 = self.fc4(feat) c5 = self.fc5(feat) return c1, c2, c3, c4, c5
(3)随后对其进行训练
def train(train_loader, model, criterion, optimizer, epoch): # 切换模型为训练模式 model.train() train_loss = [] for i, (input, target) in enumerate(train_loader): if use_cuda: input = input.cuda() target = target.cuda() c0, c1, c2, c3, c4 = model(input) loss = criterion(c0, target[:, 0]) + \ criterion(c1, target[:, 1]) + \ criterion(c2, target[:, 2]) + \ criterion(c3, target[:, 3]) + \ criterion(c4, target[:, 4]) # loss /= 6 optimizer.zero_grad() loss.backward() optimizer.step() train_loss.append(loss.item()) return np.mean(train_loss)
相关文章推荐
- Datawhale 零基础入门CV赛事-Task5 模型集成
- Datawhale 零基础入门CV赛事-Task4 模型训练与验证
- 零基础入门CV之街道字符识别(三)
- 零基础入门CV之街道字符识别(二)
- Datawhale 零基础入门CV赛事-Task2 数据读取与数据扩增
- 零基础入门CV赛事-模型训练与验证
- Datawhale零基础入门cV-Task4模型训练与验证
- 零基础入门CV赛事-字符识别模型
- Datawhale 零基础⼊⻔CV-Task3 字符识别模型
- 零基础入门CV赛事- 数据读取与数据扩增
- 零基础入门CV赛事
- markdown入门基础——特殊字符
- JAVA基础——类与接口的实现(返回字符模型)
- 零基础入门CV之街道字符识别(四)
- (转)零基础入门--中文命名实体识别
- 天池-街景字符编码识别4-模型训练与验证
- 【广告算法工程师入门 37】模型特征-算法基础之模型构建
- 基于Qt + Vs2019 图像识别入门基础 增强图像对比度(矩阵的掩膜操作)
- OpenGL 入门基础教程 —— 模型的变换
- 模式识别零基础入门