pytorch目标检测ssd七__训练代码与loss组成解析
2020-07-14 06:22
639 查看
本篇博客是我学习(https://blog.csdn.net/weixin_44791964)博主写的pytorch的ssd的博客后写的,大家可以直接去看这位博主的博客(https://blog.csdn.net/weixin_44791964/article/details/104981486)。这位博主在b站还有配套视频,传送门:(https://www.bilibili.com/video/BV1A7411976Z)。这位博主的在GitHub的源代码(https://github.com/bubbliiiing/ssd-pytorch)。 侵删
这篇博客主要是理清楚ssd目标检测算法的训练思路
下面就是训练文件的代码了,注释都在代码里面
from nets.ssd import get_ssd from nets.ssd_training import Generator,MultiBoxLoss from utils.config import Config #from torchsummary import summary from torch.autograd import Variable import torch.backends.cudnn as cudnn import time import torch import numpy as np import torch.nn as nn import torch.optim as optim import torch.nn.init as init def adjust_learning_rate(optimizer, lr, gamma, step): lr = lr * (gamma ** (step)) for param_group in optimizer.param_groups: param_group['lr'] = lr return lr if __name__ == "__main__": Batch_size = 4 lr = 1e-5 Epoch = 50 Cuda = False Start_iter = 0 # 需要使用device来指定网络在GPU还是CPU运行 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') #获得ssd目标检测算法的模型 model = get_ssd("train",Config["num_classes"]) #载入我们与训练好的预训练的模型,类似于迁移学习的思想,但是嗷,这里用的是gpu训练出来的参数,我电脑没有gpu,实验室电脑连不上,真的dmn了嗷 print('Loading weights into state dict...') model_dict = model.state_dict() #pretrained_dict = torch.load("model_data/ssd_weights.pth") #pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)} #model_dict.update(pretrained_dict) #model.load_state_dict(model_dict) print('Finished!') #设置了模型的cuda参数 net = model if Cuda: net = torch.nn.DataParallel(model) cudnn.benchmark = True net = net.cuda() """ 2007_train.txt这个其实是我们执行voc_annotation.py之后生成的文件, 这个文件里面存放了图片的路径和他所对应的目标 """ annotation_path = '2007_train.txt' with open(annotation_path) as f: lines = f.readlines() np.random.seed(10101) #打开文件之后进行一个shuffle的打乱 np.random.shuffle(lines) np.random.seed(None) num_train = len(lines) """ 使用Generator来对我们的图片进行一次预处理, Generator会利用2007_train.txt文件去生成图片和对应的标签 """ gen = Generator(Batch_size, lines, (Config["min_dim"], Config["min_dim"]), Config["num_classes"]).generate() #设置优化器 optimizer = optim.Adam(net.parameters(), lr=lr) #MultiBoxLoss是ssd使用的loss函数 criterion = MultiBoxLoss(Config['num_classes'], 0.5, True, 0, True, 3, 0.5, False, Cuda) net.train() epoch_size = num_train // Batch_size for epoch in range(Start_iter,Epoch): if epoch%10==0: adjust_learning_rate(optimizer,lr,0.95,epoch) loc_loss = 0 conf_loss = 0 #首先取出一个batch来进行训练 for iteration in range(epoch_size): images, targets = next(gen) with torch.no_grad(): if Cuda: #将图片和target变成变量的形式 images = Variable(torch.from_numpy(images).cuda().type(torch.FloatTensor)) targets = [Variable(torch.from_numpy(ann).cuda().type(torch.FloatTensor)) for ann in targets] else: images = Variable(torch.from_numpy(images).type(torch.FloatTensor)) targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets] # 前向传播 out = net(images) # 清零梯度 optimizer.zero_grad() # 计算loss loss_l, loss_c = criterion(out, targets) loss = loss_l + loss_c # 反向传播 loss.backward() optimizer.step() # 加上 loc_loss += loss_l.item() conf_loss += loss_c.item() print('\nEpoch:'+ str(epoch+1) + '/' + str(Epoch)) print('iter:' + str(iteration) + '/' + str(epoch_size) + ' || Loc_Loss: %.4f || Conf_Loss: %.4f ||' % (loc_loss/(iteration+1),conf_loss/(iteration+1)), end=' ') #每一个batch进行一次权重的保存 print('Saving state, iter:', str(epoch+1)) torch.save(model.state_dict(), 'logs/Epoch%d-Loc%.4f-Conf%.4f.pth'%((epoch+1),loc_loss/(iteration+1),conf_loss/(iteration+1)))
首先就是读取文件,然后利用Generator获得图片及其对应的标签,然后就是基本的训练了
相关文章推荐
- 目标检测之SSD(single shot multibox detector)的pytorch代码阅读总结
- 目标检测ssd复现pytorch代码以及更换自己的数据集
- 目标检测:SSD的multibox_loss_layer和MineHardExamples的理解
- SSD目标检测算法改进:DSOD(不需要预训练的目标检测算法)
- 目标检测:SSD目标检测中PriorBox代码解读
- 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)
- 目标检测算法SSD在window环境下GPU配置训练自己的数据集
- 利用SSD和自己训练好的模型进行目标检测
- 目标检测SSD:训练自己的数据集
- 逐字理解目标检测simple-faster-rcnn-pytorch-master代码(一)
- 移动端实时目标检测网络Mobilenet_v2-ssdlite及其keras实现(附代码地址)
- Linux 编译SSD Caffe目标检测代码
- 目标检测:RFCN的Python代码训练自己的模型
- 逐字理解目标检测simple-faster-rcnn-pytorch-master代码(二)
- 目标检测:SSD目标检测中PriorBox代码解读
- pytorch下实现ssd目标检测算法运行时遇到的错误
- SSD-目标检测代码解读
- mac做目标检测 google colaboratory训练(ssd,vgg模型,直接上手,顺便教你使用免费服务器训练)
- 目标检测算法SSD之训练自己的数据集
- 目标检测:caffe-ssd编译、训练和测试全过程记录