从零搭建Pytorch模型教程(四)编写训练过程--参数解析
前言 训练过程主要是指编写train.py文件,其中包括参数的解析、训练日志的配置、设置随机数种子、classdataset的初始化、网络的初始化、学习率的设置、损失函数的设置、优化方式的设置、tensorboard的配置、训练过程的搭建等。
由于篇幅问题,这些内容将分成多篇文章来写。本文介绍参数解析的两种方式。
欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。
一个模型中包含众多的训练参数,如文件保存目录、数据集目录、学习率、epoch数量、模块中的参数等。
参数解析常用的有两种方式。
一种是将所有参数都放在yaml文件中,通过读取yaml文件来配置参数。这种一般用于比较复杂的项目,例如有多个模型,对应多组参数。这样就可以每个模型配置一个yaml文件,里面对应的是每个模型的对应的参数。
另一种是直接在train.py文件中通过argparser解析器来配置。这种一般用于仅一个模型或比较简单的项目中。每次只需要改一两个参数的。
yaml文件解析
yaml语法规则
大小写敏感 使用缩进表示层级关系 缩进时不允许使用Tab键,只允许使用空格。(可以将你的ide的tab按键输出替换成4个空格) 缩进的空格数目不重要,只要相同层级的元素左侧对齐即可 #表示注释
yaml文件示例
TRAIN: RESUME_PATH: "/path/to/your/net.pth" DATASET: ucf24 # `ava`, `ucf24` or `jhmdb21` BATCH_SIZE: 10 TOTAL_BATCH_SIZE: 128 SOLVER: MOMENTUM: 0.9 WEIGHT_DECAY: 5e-4 LR_DECAY_RATE: 0.5 NOOBJECT_SCALE: 1 DATA: TRAIN_JITTER_SCALES: [256, 320] TRAIN_CROP_SIZE: 224 TEST_CROP_SIZE: 224 MEAN: [0.4345, 0.4051, 0.3775] STD: [0.2768, 0.2713, 0.2737] MODEL: NUM_CLASSES: 24 BACKBONE: darknet WEIGHTS: BACKBONE: "weights/yolo.weights" FREEZE_BACKBONE_2D: False LISTDATA: BASE_PTH: "datasets/ucf24" TRAIN_FILE: "path/to/your/classdataset/trainlist.txt" TEST_FILE: "path/to/your/classdataset/testlist.txt"
yaml的解析
这里介绍两种方法,一种比较复杂的,像上面这个有两级。解析比较麻烦,代码如下:
import yaml import argparser from fvcore.common.config import CfgNode cfg = CfgNode() cfg.TRAIN= CfgNode() #每一级都要这样新建一个节点 cfg.TRAIN.RESUME_PATH = "Train" cfg.TRAIN.DATASET = "ucf24" # `ava`, `ucf24` or `jhmdb21` cfg.TRAIN.BATCH_SIZE=10 cfg.TRAIN.TOTAL_BATCH_SIZE=128 ... cfg.SOLVER= CfgNode() #每一级都要这样新建一个节点 cfg.SOLVER.MOMENTUM=0.9 cfg.SOLVER.WEIGHT_DECAY=5e-4 ... yaml_path = "yaml_test.yaml" cfg.merge_from_file(yaml_path) #访问方法 print(cfg.TRAIN.RESUME_PATH)
它的麻烦在于需要将所有的元素都初始化一遍,然后通过cfg.merge_from_file(yaml_path)来根据yaml文件更新这些元素。
另一种是比较简单的解析二级的方法。
import yaml import argparser with open(yaml_path,'r') as file: opt = argparse.Namespace(**yaml.load(file.read(),Loader=yaml.FullLoader)) #访问方法 print(opt.TRAIN["RESUME_PATH"])
即第一级是在argparse的Namespace中,可通过点号来访问,第二级仍然是字典的形式。但它简单太多了。如果只有一级的话,直接通过点号就可以了。如果不使用argparse.Namespace,则两级都是字典,直接通过访问字典的形式也可以。
argparser解析
argparser解析的形式一般放在train.py文件的最前面,适用于参数相对比较少,每次只需要改一两个参数的情况。(我本人习惯将它放在其它文件中,例如单独搞一个parser.py或直接放在util.py中,只因为如果放在train前每次都要滑动很长才能到train的部分,相当麻烦)
先来个标准示例
import argparser def get_args(): parser = argparse.ArgumentParser(description='Training') parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training') parser.add_argument('--batchsize', default=8, type=int, help='batchsize') parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2') return parser.parse_args() #使用方法 python train.py --color_jitter \ --batchsize=16 \ --gpu_ids='0' #访问元素 opt = get_args() print(opt.batchsize)
这里列举了三种形式,一种是action的,当action='store_true'时,默认是false,在设置参数时直接--color_jitter即可变成True,另外两种如上所示。
下一篇将介绍编写训练过程的训练日志的配置、设置随机数种子、classdataset的初始化、网络的初始化、学习率的设置、损失函数的设置、优化方式的设置等内容。
欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。
CV技术指南创建了一个免费的知识星球。关注公众号添加编辑的微信号可邀请加入。
征稿通知:欢迎可以写以下内容的朋友联系我(微信号:“FewDesire”)。
- TVM入门到实践的教程
- TensorRT入门到实践的教程
- MNN入门到实践的教程
- 数字图像处理与Opencv入门到实践的教程
- OpenVINO入门到实践的教程
- libtorch入门到实践的教程
- Oneflow入门到实践的教程
- Detectron入门到实践的教程
- CUDA入门到实践的教程
- caffe源码阅读
- pytorch源码阅读
- 深度学习从入门到精通(从卷积神经网络开始讲起)
- 最新顶会的解读。例如最近的CVPR2022论文。
- 各个方向的系统性综述、主要模型发展演变、各个模型的创新思路和优缺点、代码解析等。
- 若自己有想写的且这上面没提到的,可以跟我联系。
声明:有一定报酬,具体请联系详谈。若有想法写但觉得自己能力不够,也可以先联系本人(微信号:FewDesire)了解。添加前请先备注“投稿”。
其它文章
招聘 | 迁移科技招聘深度学习、视觉、3D视觉、机器人算法工程师等多个职位
Attention Mechanism in Computer Vision
从零搭建Pytorch模型教程(三)搭建Transformer网络
- 防止训练模型时信息丢失,用于TensorFlow、Keras和PyTorch的检查点教程
- 【深度学习】读取和存储训练好的模型参数(pyTorch)
- 解决了PyTorch 使用torch.nn.DataParallel 进行多GPU训练的一个BUG:模型(参数)和数据不在相同设备上
- pytorch模型加载跑测试集和训练过程中跑测试集结果不一致的问题
- 实用:使用caffe训练模型时solver.prototxt中的参数设置解析
- 记一次pytorch训练模型及搭建(pytorch图片数据载入,模型训练)
- Pytorch保存与加载训练模型,同时可以保存中间训练过程中的训练模型
- YOLO模型训练可视化训练过程中的中间参数
- (更新视频教程)Tensorflow object detection API 搭建属于自己的物体识别模型(2)——训练并使用自己的模型
- YOLO模型训练可视化训练过程中的中间参数
- Pytorch加载部分预训练模型的参数实例
- 使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
- 从零搭建Pytorch模型教程(一)数据读取
- 写一个Linux下搭建Discuz论坛的全过程教程。
- 使用谷歌Colab(Colaboratory)免费GPU训练自己的模型及谷歌网盘无限容量(Google drive)申请教程
- SpringMVC中HandlerMethod的请求参数解析过程
- 华为机试在线训练-牛客网(34)参数解析
- Struts 2.3.24源码解析+Struts2拦截参数,处理请求,返回到前台过程详析
- 搭建一个解析接口教程(自己的解析接口,可自定义广告)
- mapreduce中map处理过程?参数如何解析传递给map方法?