在python中编写caffe的prototxt文件
2016-05-16 20:06
507 查看
利用python可以创建caffe的网络定义的prototxt文件,利用这种方法的一个好处就是:可以保证training,testing和deploy网络的一致性!
下面是利用python编写train和test的prototxt文件的一个小事例:
下面是利用python编写train和test的prototxt文件的一个小事例:
import caffe from caffe import layers as L from caffe import params as P # Function: Set lenet net def init_net_lenet(netName,netType,batch_size,prototxt_root,data_root): # -------- # set the type and name of the net, e.g., train_lenet, or test_lenet NetTypeName = netType + '_'+netName # -------- n = caffe.NetSpec() # -------- # set the input layer n.data, n.label = L.ImageData( image_data_param={"batch_size": batch_size, "is_color" : False,"shuffle": True}, source= data_root+netType+".data", transform_param=dict(scale=1./255), ntop=2) # ------ # set other layers n.conv1 = L.Convolution(n.data, kernel_size=5, num_output=6, weight_filler=dict(type='xavier')) n.tanh1 = L.TanH(n.conv1, in_place=True) n.pool1 = L.Pooling(n.conv1, kernel_size=2, stride=2, pool=P.Pooling.MAX) n.conv2 = L.Convolution(n.pool1, kernel_size=5, num_output=16, weight_filler=dict(type='xavier')) n.tanh2 = L.TanH(n.conv2, in_place=True) n.pool2 = L.Pooling(n.conv2, kernel_size=2, stride=2, pool=P.Pooling.MAX) n.ip3 = L.InnerProduct(n.pool2, num_output=120, weight_filler=dict(type='xavier')) n.tanh3 = L.TanH(n.ip3, in_place=True) n.ip4 = L.InnerProduct(n.ip3, num_output=84, weight_filler=dict(type='xavier')) n.tanh4 = L.TanH(n.ip4, in_place=True) n.ip5 = L.InnerProduct(n.ip4, num_output=2, weight_filler=dict(type='xavier')) n.loss = L.SoftmaxWithLoss(n.ip5, n.label) # -------- # write the prototxt file print('Writing net to %s' % prototxt_root+NetTypeName+'.prototxt') with open(prototxt_root+ NetTypeName+'.prototxt', 'w') as f: f.write(str(n.to_proto())) print 'done...' # -------- # return the name of the output layer (used for predicting) return 'ip5'
相关文章推荐
- Python - 两圆相交求交点坐标
- 《python程序设计》第二章基本程序设计笔记
- 《python语言程序设计》第一章python概述 笔记
- Python With
- 远程访问jupyter notebook
- python遍历文件夹
- Python使用lxml模块和Requests模块抓取HTML页面的教程
- 登录知乎的爬虫
- Ubuntu 14.04下OpenCV 3.0+Python 2.7安装测试
- Python中数字类型
- 我的Python成长之路---第八天---Python基础(24)---2016年3月5日(晴)
- 我的Python成长之路---第八天---Python基础(23)---2016年3月5日(晴)
- Python之路3【第一篇】Python简介入门
- Python使用xslt提取网页数据
- Python实现协程的生产者与消费者
- 机器学习实战--k近邻算法
- python实现欧拉计划24题
- 灰帽子Python 学习记录 6
- Python-Jenkins API使用 —— 在后端代码中操控Jenkins
- windows下用Python把pdf文件转化为图片(png格式)