caffe2 Study Notes 03: From Images to an mdb Dataset
2017-08-28 15:34
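Contents
- Preface
- Importing the libraries
- Preparation
- The write function: read image files and labels and convert them to an mdb file
- The read function: read the mdb file back and verify it (optional)
- Running
- Errors you may run into
  - CHW vs. HWC
  - Mismatched channels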
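1. Preface
This post uses training a Chinese-character recognition model with caffe2 as its running example.
2. Importing the libraries
If the output is Required modules imported., the imports succeeded. If a library is reported as missing, search online for how to install it.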

    # -*- coding: UTF-8 -*-
    # __future__ imports have to come before any other statement.
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    from __future__ import unicode_literals

    # Jupyter notebook magic; delete this line when running as a plain script.
    %matplotlib inline

    import argparse
    import math
    import os
    import sys

    import lmdb
    import matplotlib.image as mpimg
    import numpy as np
    import skimage
    import skimage.io as io
    import skimage.transform
    from matplotlib import pyplot

    from caffe2.proto import caffe2_pb2
    from caffe2.python import workspace, model_helper

    print("Required modules imported.")

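To see every missing package at once instead of stopping at the first ImportError, a small check along these lines may help. This is only a sketch: the helper name `missing_modules` is made up for this illustration, and the list holds the top-level package names used by the imports above.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Top-level packages used in this post.
required = ["numpy", "skimage", "matplotlib", "lmdb", "caffe2"]
print("Missing:", missing_modules(required))
```

An empty list means everything needed by the import block is installed.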
3. Preparation
Set the data path, the label lookup table, and the maximum size of the output file.

    path = "/home/hw/H/00_dataOfPlate/15_hanzi/01_new_chn/train"  # data path
    sep = os.path.sep  # path separator of the current OS (Linux here)

    # Label lookup table (31 classes): the label of a class is its index in this list.
    # The names have to match the subfolder names under path exactly.
    chn = ["beijing", "tianjin", "hebei", "shanxi", "neimenggu", "liaoning",
           "jilin", "heilongjiang", "shanghai", "jiangsu", "zhejiang", "anhui",
           "fujian", "jiangxi", "shandong", "henan", "hubei", "hunan",
           "guangdong", "guangxi", "hainan", "sichuan", "guizhou", "yunnan",
           "chongqin", "xizang", "shengxi", "gansu", "qinghai", "ningxia",
           "xinjiang"]

    LMDB_MAP_SIZE = 1099511627776  # upper bound on the output file: 1 TB
    print("prepared")

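The two settings can be illustrated in a few lines. This is a sketch with a shortened stand-in list (`chn_demo`) and a hypothetical dictionary `label_of`; neither name appears in the original code. It shows that a label is simply a position in the list, and that the map size is exactly 2**40 bytes.

```python
# Shortened stand-in for the full chn list (illustration only).
chn_demo = ["beijing", "tianjin", "hebei"]

# Precompute name -> label once instead of calling list.index() per folder.
label_of = {name: idx for idx, name in enumerate(chn_demo)}
print(label_of["hebei"])  # 2

# The map size used in this post is exactly 2**40 bytes.
LMDB_MAP_SIZE = 1 << 40
print(LMDB_MAP_SIZE == 1099511627776)  # True
```

The map size is only the maximum the database may grow to, so a generous value is the usual choice.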
4. The write function: read image files and labels and convert them to an mdb file
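File layout, using the train folder as an example: train contains one subfolder per class (the 26 letters in this illustration), and the label of each image is determined by the folder it sits in. In the code below the label is chn.index(folder name), so for the Chinese-character data the subfolder names have to be entries of chn.

- train
  - A
    - 0001.bmp
    - 0002.bmp
    - …
    - 4000.bmp
  - B
    - 0001.bmp
    - …
    - 4100.bmp
  - …
    - …
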
Top-level folder | Subfolder | Image |
---|---|---|
train | A | 1022.bmp |
train | A | … |
train | A | 4032.bmp |
train | B | 1022.bmp |
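The folder-to-label rule can be tried on a throwaway directory tree before touching the real data. The following is a minimal sketch, not the author's code: `list_samples` is a hypothetical helper, and the files it creates are empty placeholders rather than real images.

```python
import os
import tempfile


def list_samples(root, class_names):
    """Yield (image_path, label) pairs; the label is the index of the
    image's parent folder name in class_names."""
    for folder in sorted(os.listdir(root)):
        folder_path = os.path.join(root, folder)
        if not os.path.isdir(folder_path):
            continue  # skip stray files at the top level
        label = class_names.index(folder)  # ValueError for unknown folders
        for file_name in sorted(os.listdir(folder_path)):
            yield os.path.join(folder_path, file_name), label


# Tiny throwaway tree mirroring the train/A, train/B layout.
with tempfile.TemporaryDirectory() as root:
    layout = {"A": ["0001.bmp", "0002.bmp"], "B": ["0001.bmp"]}
    for folder, files in layout.items():
        os.makedirs(os.path.join(root, folder))
        for file_name in files:
            open(os.path.join(root, folder, file_name), "wb").close()
    samples = [(os.path.relpath(p, root), label)
               for p, label in list_samples(root, ["A", "B"])]

print(samples)
```

A folder whose name is missing from the class list raises ValueError, which surfaces a misspelled folder name early.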

    def write_db_with_caffe2(output_file):
        print(">>> Write database ...")
        LMDB_MAP_SIZE = 1099511627776  # upper bound on the database size
        env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)

        checksum = 0   # sum over all images of (pixel sum * label)
        checksumm = 0  # plain sum of all pixel values
        j = 0          # running record index, used as the LMDB key
        with env.begin(write=True) as txn:
            for dirs in os.listdir(path):
                new_path = path + sep + dirs
                label = chn.index(dirs)  # label = index of the folder name in chn
                for pics in os.listdir(new_path):
                    pic_path = new_path + sep + pics
                    # np.float64 is the type the removed alias np.float stood for
                    img_data = skimage.img_as_float(
                        skimage.io.imread(pic_path)).astype(np.float64)
                    print("before: {}".format(img_data.shape))
                    img_data = img_data[:, :, :1]  # 3 channels -> 1 channel
                    img_data = img_data.swapaxes(1, 2).swapaxes(0, 1)  # HWC -> CHW
                    print("after: {}".format(img_data.shape))

                    # One record = one TensorProtos holding the image and its label
                    tensor_protos = caffe2_pb2.TensorProtos()
                    img_tensor = tensor_protos.protos.add()
                    img_tensor.dims.extend(img_data.shape)
                    img_tensor.data_type = 1  # FLOAT
                    flatten_img = img_data.reshape(np.prod(img_data.shape))
                    print("after: {}".format(flatten_img.shape))
                    img_tensor.float_data.extend(flatten_img.flat)

                    label_tensor = tensor_protos.protos.add()
                    label_tensor.data_type = 2  # INT32
                    label_tensor.int32_data.append(label)

                    txn.put('{}'.format(j).encode('ascii'),
                            tensor_protos.SerializeToString())

                    checksum += np.sum(img_data) * label
                    checksumm += np.sum(img_data)
                    j += 1

        print("Checksum/write: {}".format(int(checksum)))
        print("Checksumm/write: {}".format(int(checksumm)))

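One detail about the keys: LMDB keeps records sorted by key in byte order, so keys built from plain decimal numbers do not come back in insertion order when the database is walked from start to end. The checksum used in this post does not depend on order, so nothing breaks. If order ever matters, zero-padding the key is a common tweak. The sketch below is not part of the original code, and the width of 8 digits is an assumption.

```python
# Keys as the write function builds them: '0', '1', ..., '11'.
plain_keys = ['{}'.format(j).encode('ascii') for j in range(12)]
# Zero-padded alternative (assumption: 8 digits is wide enough for the dataset).
padded_keys = ['{:08d}'.format(j).encode('ascii') for j in range(12)]

# sorted() on bytes uses the same lexicographic byte order as LMDB's default.
print(sorted(plain_keys)[:4])              # [b'0', b'1', b'10', b'11']
print(sorted(padded_keys) == padded_keys)  # True
```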
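5. The read function: read the mdb file back and verify it (this step is optional)
Call it with the location of the data: read_db_with_caffe2(db_file, expected_checksum)
db_file: path of the database file
expected_checksum: the expected checksum; it should match the value printed by write_db_with_caffe2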

    def read_db_with_caffe2(db_file, expected_checksum):
        print(">>> Read database...")
        model = model_helper.ModelHelper(name="lmdbtest")
        # Total number of records in the database. It has to be exact,
        # otherwise the check fails ("Read/write checksums dont match").
        batch_size = 744000
        data, label = model.TensorProtosDBInput(
            [], ["data", "label"], batch_size=batch_size,
            db=db_file, db_type="lmdb")

        checksum = 0   # sum over all images of (pixel sum * label)
        checksumm = 0  # plain sum of all pixel values
        workspace.RunNetOnce(model.param_init_net)
        workspace.CreateNet(model.net)

        for _ in range(0, 1):  # one batch holds the whole database
            workspace.RunNet(model.net.Proto().name)
            img_datas = workspace.FetchBlob("data")
            labels = workspace.FetchBlob("label")
            for j in range(batch_size):
                checksum += np.sum(img_datas[j, :]) * labels[j]
                checksumm += np.sum(img_datas[j, :])

        print("Checksum/read: {}".format(int(checksum)))
        print("minus of read and write: {}".format(
            np.abs(expected_checksum - checksum)))
        # Take the absolute difference first, then compare it with the tolerance.
        # If expected_checksum is the int()-truncated value printed by the writer,
        # the difference can approach 1, so the tolerance may need loosening.
        assert np.abs(expected_checksum - checksum) < 0.1, \
            "Read/write checksums dont match"

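For the final check to catch a mismatch in either direction, the absolute value has to wrap the difference, with the comparison applied to the result. The sketch below uses made-up numbers and a hypothetical helper `checksums_match` to show how the two placements of the parenthesis behave.

```python
def checksums_match(expected, actual, tol=0.1):
    """True when the two checksums differ by less than tol in either direction."""
    return abs(expected - actual) < tol


expected, actual = 100.0, 250.0  # made-up values that clearly disagree

# abs() around the comparison: the comparison yields True and abs(True) is 1,
# which is truthy, so an assert on it would pass despite the mismatch.
print(abs(expected - actual < 0.1))       # 1
# abs() around the difference: the mismatch is detected.
print(checksums_match(expected, actual))  # False
```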
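6. Running
Running takes a long time, so please be patient. Reading the 744000 grayscale images of 20*20 pixels took about twenty minutes, and reading the db back for the test froze my computer o(╯□╰)o

    write_db_with_caffe2("./chn_db")
    # 640020532 is the checksum; it should equal the Checksum/write value printed by the write step
    read_db_with_caffe2("./chn_db", 640020532)
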
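A back-of-the-envelope estimate gives a feel for how heavy the read step is when the whole database is fetched as one batch. This is only a sketch under stated assumptions: the image data is held as 32-bit floats, each image is a single 20*20 channel, and the batch size of 1000 is a hypothetical alternative rather than a value from the post. The real memory use is higher, since the fetched blob is copied into a NumPy array as well.

```python
num_images = 744000          # from the post
height, width, channels = 20, 20, 1
bytes_per_float = 4          # assumption: 32-bit floats

blob_bytes = num_images * channels * height * width * bytes_per_float
print(blob_bytes)                    # 1190400000
print(round(blob_bytes / 2**30, 2))  # 1.11 (GiB)

# With a smaller (hypothetical) batch size the net would be run repeatedly,
# accumulating the checksum across runs.
batch_size = 1000
num_batches = num_images // batch_size
print(num_batches)                   # 744
```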
7. Errors you may run into
1. CHW vs. HWC:
input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1
Cause: images are read with shape HWC (height/width/channels) by default, while caffe2's default image layout is CHW, so the data has to be converted. Without the conversion, the first axis of a 20*20 image (its height, 20) is taken as the channel axis, which is where the 20 in the message comes from, and the following error is raised:

    RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1 Error from operator:
    input: "data"
    input: "conv1_w"
    input: "conv1_b"
    output: "conv1"
    name: ""
    type: "Conv"
    arg { name: "kernel" i: 5 }
    arg { name: "exhaustive_search" i: 0 }
    arg { name: "order" s: "NCHW" }
    engine: "CUDNN"

Fix: when converting the images to the mdb file, add
img_data = img_data.swapaxes(1, 2).swapaxes(0, 1) (already included in the program above)
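To see what the two axis swaps do, here is a small NumPy sketch on a made-up array. The 4*5 size is deliberately non-square so the axes cannot be confused, and `transpose(2, 0, 1)` is shown as an equivalent one-step form; neither is taken from the original code.

```python
import numpy as np

# Stand-in for an image laid out as HWC: height 4, width 5, 3 channels.
hwc = np.arange(4 * 5 * 3, dtype=np.float32).reshape(4, 5, 3)

chw_swap = hwc.swapaxes(1, 2).swapaxes(0, 1)  # as in the post
chw_transpose = hwc.transpose(2, 0, 1)        # equivalent one-step form

print(chw_swap.shape)                           # (3, 4, 5)
print(np.array_equal(chw_swap, chw_transpose))  # True
# Pixel (row 2, col 3), channel 1 holds the same value in both layouts.
print(hwc[2, 3, 1] == chw_swap[1, 2, 3])        # True
```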
2. Mismatched channels:
input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1
Cause: the images here are read as three-channel arrays whether or not they look grayscale. After the HWC -> CHW conversion from item 1, the channel count is 3, which does not match the single-channel input of the LeNet model from the MNIST example, so the following error is raised:

    RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1 Error from operator:
    input: "data"
    input: "conv1_w"
    input: "conv1_b"
    output: "conv1"
    name: ""
    type: "Conv"
    arg { name: "kernel" i: 5 }
    arg { name: "exhaustive_search" i: 0 }
    arg { name: "order" s: "NCHW" }
    engine: "CUDNN"

Fix: when converting the images to the mdb file, add
img_data = img_data[:,:,:1] (already included in the program above)
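The slice `[:, :, :1]` keeps the channel axis, which is what lets the later HWC -> CHW swap produce a (1, H, W) array. The sketch below uses a made-up array standing in for a grayscale picture stored with three identical channels; for such a picture, keeping only the first channel loses nothing, while for a true color picture it would keep just the first channel.

```python
import numpy as np

# Stand-in for a grayscale picture stored with three identical channels (HWC).
gray = np.arange(20 * 20, dtype=np.float64).reshape(20, 20)
hwc = np.stack([gray, gray, gray], axis=-1)  # shape (20, 20, 3)

one_channel = hwc[:, :, :1]                  # slicing keeps the channel axis
print(one_channel.shape)                     # (20, 20, 1)
print(hwc[:, :, 0].shape)                    # (20, 20): an integer index drops it

chw = one_channel.transpose(2, 0, 1)
print(chw.shape)                             # (1, 20, 20)
```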