您的位置:首页 > 编程语言 > Python开发

[caffe的python接口学习三]:生成solver文件

2017-09-24 16:25 423 查看
作者:JackGao16 CSDN

邮箱:gshuai16@mail.ustc.edu.cn

caffe在训练和测试的时候,不但需要对网络中的参数进行设置,同时也要对一些超参数进行设置,这些超参数和网络的本身无关,和数据集以及相应的计算资源存在关系。而我们就通过solver文件、来设置这些超参数。

通常我们见到的solver.prototxt文件:

base_lr: 0.001
display: 782
gamma: 0.1
lr_policy: “step”
max_iter: 78200
momentum: 0.9
snapshot: 7820
snapshot_prefix: “snapshot”
solver_mode: GPU
solver_type: SGD
stepsize: 26067
test_interval: 782
test_iter: 313
test_net: “/home/xxx/data/val.prototxt”
train_net: “/home/xxx/data/proto/train.prototxt”
weight_decay: 0.0005


这里的参数并非随意的设置,具体的设置要根据参数的含义和具体的算法和计算资源来决定。

具体的含义待后续整理完进行补充。

通过python生成solver文件

可以通过以下的代码实现:

# -*- coding: utf-8 -*-
"""
Created on Sun Jul 17 18:20:57 2016

@author: root
"""
path='/home/gaoshuai/caffe/examples/mnist/'
solver_file=path+'solver.prototxt'     #solver文件保存位置

sp={}
sp['train_net']=‘“’+path+'train.prototxt”'  # 训练配置文件
sp['test_net']=‘“’+path+'val.prototxt”'     # 测试配置文件
sp['test_iter']='313'                  # 测试迭代次数
sp['test_interval']='782'              # 测试间隔
sp['base_lr']='0.001'                  # 基础学习率
sp['display']='782'                    # 屏幕日志显示间隔
sp['max_iter']='78200'                 # 最大迭代次数
sp['lr_policy']='“step”'                 # 学习率变化规律
sp['gamma']='0.1'                      # 学习率变化指数
sp['momentum']='0.9'                   # 动量
sp['weight_decay']='0.0005'            # 权值衰减
sp['stepsize']='26067'                 # 学习率变化频率
sp['snapshot']='7820'                   # 保存model间隔
sp['snapshot_prefix']=‘"snapshot"’       # 保存的model前缀
sp['solver_mode']='GPU'                # 是否使用gpu
sp['solver_type']='SGD'                # 优化算法

def write_solver():
#写入文件
with open(solver_file, 'w') as f:
for key, value in sorted(sp.items()):
if not(type(value) is str):
raise TypeError('All solver parameters must be strings')
f.write('%s: %s\n' % (key, value))
if __name__ == '__main__':
write_solver()


说明:

上面的代码会在目录位置:/home/gaoshuai/caffe/examples/mnist/处生成一个solver.prototxt的文件。

当然,由于代码的书写方式不同,下面还有一种方式同样达到一样的效果:

# -*- coding: utf-8 -*-

from caffe.proto import caffe_pb2
s = caffe_pb2.SolverParameter()

path='/home/gaoshuai/caffe/examples/mnist/'
solver_file=path+'solver1.prototxt'

s.train_net = path+'train.prototxt'
s.test_net.append(path+'val.prototxt')
s.test_interval = 782
s.test_iter.append(313)
s.max_iter = 78200

s.base_lr = 0.001
s.momentum = 0.9
s.weight_decay = 5e-4
s.lr_policy = 'step'
s.stepsize=26067
s.gamma = 0.1
s.display = 782
s.snapshot = 7820
s.snapshot_prefix = 'shapshot'
s.type = “SGD”
s.solver_mode = caffe_pb2.SolverParameter.GPU

with open(solver_file, 'w') as f:
f.write(str(s))


可以参考链接:dnney的专栏博客:caffe的python接口学习(2)生成solver文件
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: