
Caffe2 Study Notes 03: From Images to an mdb Dataset

2017-08-28 15:34

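Preface

Importing libraries

Preparation

The write function: read image files and labels and convert them to an mdb file

The read function: read the mdb file and verify it (optional)

Running it

Errors you may run into
CHW vs. HWC

Channel count mismatch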

1. Preface

This article uses training a Caffe2 model to recognize Chinese characters as its running example.

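2. Importing libraries

If the code below prints
Required modules imported.
the imports succeeded. If a module is reported as missing, install it (a web search for the module name usually turns up the right package).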

# -*- coding: UTF-8 -*-
# __future__ imports must come before any other statement, otherwise Python raises a SyntaxError
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

# Jupyter/IPython magic; delete this line when running as a plain .py script
%matplotlib inline
import os
import sys
import math
import argparse

import numpy as np
import skimage
import skimage.io
import skimage.transform
from matplotlib import pyplot
import matplotlib.image as mpimg

import lmdb
from caffe2.proto import caffe2_pb2
from caffe2.python import workspace, model_helper
print("Required modules imported.")


3. Preparation

Set the data path, define the label lookup table, and cap the maximum size of the output database.

path = "/home/hw/H/00_dataOfPlate/15_hanzi/01_new_chn/train" #数据路径
sep = os.path.sep #当前系统(linux)路径分隔符

chn = ["beijing", "tianjin", "hebei", "shanxi", "neimenggu", "liaoning", "jilin", "heilongjiang", "shanghai", "jiangsu", "zhejiang", "anhui", "fujian", "jiangxi", "shandong", "henan", "hubei", "hunan", "guangdong", "guangxi", "hainan", "sichuan", "guizhou", "yunnan", "chongqin", "xizang", "shengxi", "gansu", "qinghai", "ningxia", "xinjiang"] # No. 31 # 标签对应表

LMDB_MAP_SIZE =  1099511627776  #max output file < 1TB
print("prepared") 最大输出文件大小


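4. The write function: read image files and labels and convert them to an mdb file

Directory layout, using the train folder as an example: train contains 26 subfolders named after letters, and each image's label is given by the folder it sits in. For the code below to work, the subfolder names must be entries of chn (such as beijing), because chn.index(dirs) turns the folder name into the label.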

- train
  - A
    - 0001.bmp
    - 0002.bmp
    - …
    - 4000.bmp
  - B
    - 0001.bmp
    - …
    - 4100.bmp
  - …

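| First-level directory | Second-level directory | Image |
|---|---|---|
| train | A | 1022.bmp |
| train | A | … |
| train | A | 4032.bmp |
| train | B | 1022.bmp |
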
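The labeling rule (an image's label is the index of its parent folder's name in the lookup table) can be isolated into a small function. This is a sketch: label_from_path is a hypothetical helper, it uses pathlib's PurePosixPath because the data lives on Linux with "/" separators, and it assumes the folder names match entries of chn.

```python
from pathlib import PurePosixPath

def label_from_path(pic_path, names):
    """Label = position of the image's parent folder name in `names`."""
    folder = PurePosixPath(pic_path).parent.name
    return names.index(folder)  # raises ValueError if the folder is not in the table

names = ["beijing", "tianjin", "hebei"]  # first three entries of chn
print(label_from_path("train/beijing/0001.bmp", names))  # 0
print(label_from_path("train/hebei/4000.bmp", names))    # 2
```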
def write_db_with_caffe2(output_file):
    print(">>> Write database ...")
    # LMDB_MAP_SIZE comes from the preparation step (section 3)
    env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE)
    checksum = 0   # sum over all images of (pixel sum * label)
    checksumm = 0  # plain sum of all pixels
    j = 0          # running record index, also used as the LMDB key

    with env.begin(write=True) as txn:
        for dirs in os.listdir(path):
            new_path = path + sep + dirs
            label = chn.index(dirs)  # the folder name determines the label
            for pics in os.listdir(new_path):
                pic_path = new_path + sep + pics
                # np.float was only an alias of float64 and is removed in newer NumPy
                img_data = skimage.img_as_float(skimage.io.imread(pic_path)).astype(np.float64)
                # print("before: {}".format(img_data.shape))  # per-image debug output

                if img_data.ndim == 2:  # single-channel file read as (H, W)
                    img_data = img_data[:, :, np.newaxis]
                img_data = img_data[:, :, :1]  # 3 channels -> 1 channel
                img_data = img_data.swapaxes(1, 2).swapaxes(0, 1)  # HWC -> CHW
                # print("after: {}".format(img_data.shape))

                tensor_protos = caffe2_pb2.TensorProtos()
                img_tensor = tensor_protos.protos.add()
                img_tensor.dims.extend(img_data.shape)
                img_tensor.data_type = 1  # TensorProto.FLOAT

                flatten_img = img_data.reshape(np.prod(img_data.shape))
                img_tensor.float_data.extend(flatten_img.flat)

                label_tensor = tensor_protos.protos.add()
                label_tensor.data_type = 2  # TensorProto.INT32
                label_tensor.int32_data.append(label)
                txn.put('{}'.format(j).encode('ascii'), tensor_protos.SerializeToString())

                checksum += np.sum(img_data) * label
                checksumm += np.sum(img_data)
                j += 1
    env.close()
    print("Checksum/write: {}".format(int(checksum)))
    print("Checksumm/write: {}".format(int(checksumm)))


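5. The read function: read the mdb file and verify it (optional)

Call it as:
read_db_with_caffe2(db_file, expected_checksum)

- db_file: path to the database

- expected_checksum: the expected checksum; it should equal the Checksum/write value printed by write_db_with_caffe2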
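Written out, the two values that both functions accumulate are as follows, where $x_j$ is the $j$-th image tensor in CHW layout, $\ell_j$ its label, and $N$ the number of records (744000 in section 6):

```latex
\text{checksum} = \sum_{j=0}^{N-1} \ell_j \sum_{c,h,w} x_j[c,h,w],
\qquad
\text{checksumm} = \sum_{j=0}^{N-1} \sum_{c,h,w} x_j[c,h,w]
```

Weighting by $\ell_j$ makes checksum sensitive to an image being stored with the wrong label, but images with $\ell_j = 0$ (the first entry of chn, beijing) contribute nothing to it; only checksumm covers them.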

def read_db_with_caffe2(db_file, expected_checksum):
    print(">>> Read database...")
    model = model_helper.ModelHelper(name="lmdbtest")
    # Total number of records in the DB. It must be exact, otherwise the checksum
    # check fails ("Read/write checksums dont match")
    batch_size = 744000
    data, label = model.TensorProtosDBInput(
        [], ["data", "label"], batch_size=batch_size,
        db=db_file, db_type="lmdb")

    checksum = 0
    checksumm = 0  # must be initialized before the += below
    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net)

    for _ in range(0, 1):
        workspace.RunNet(model.net.Proto().name)

        img_datas = workspace.FetchBlob("data")
        labels = workspace.FetchBlob("label")
        for j in range(batch_size):
            checksum += np.sum(img_datas[j, :]) * labels[j]
            checksumm += np.sum(img_datas[j, :])
    print("Checksum/read: {}".format(int(checksum)))
    print("Checksumm/read: {}".format(int(checksumm)))
    print("difference between read and write: {}".format(np.abs(expected_checksum - checksum)))
    # Compare |expected - actual| with a tolerance: expected_checksum was truncated
    # by int(), and the DB stores float32 values, so an exact match is not expected
    assert np.isclose(checksum, expected_checksum, rtol=1e-6, atol=1.0), \
        "Read/write checksums dont match"


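6. Running it

This takes a long time, so be patient: with 744000 grayscale 20×20 images, reading them in took about twenty minutes. When I then read the DB back for testing, my computer froze o(╯□╰)o.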
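A plausible cause of the freeze is batch_size = 744000 in read_db_with_caffe2: the whole dataset is fetched as one float32 blob of roughly 744000 × 20 × 20 × 4 bytes ≈ 1.19 GB before the Python loop sums it. A lighter alternative, sketched here under assumptions, is to read in smaller batches and run the net several times. As I understand it, Caffe2's DB reader wraps around to the first record when it reaches the end, which is also why batch_size has to match the record count exactly; so a smaller batch size should divide the record count evenly, or some records get counted twice. plan_batches is a hypothetical helper that picks the largest such batch size not exceeding a limit you choose.

```python
def plan_batches(total_records, max_batch):
    """Largest batch size <= max_batch that divides total_records, plus the batch count."""
    if total_records <= 0 or max_batch <= 0:
        raise ValueError("total_records and max_batch must be positive")
    for batch_size in range(min(max_batch, total_records), 0, -1):
        if total_records % batch_size == 0:
            return batch_size, total_records // batch_size

batch_size, num_batches = plan_batches(744000, 10000)
print(batch_size, num_batches)       # 9920 75
print(batch_size * 1 * 20 * 20 * 4)  # 15872000 bytes per float32 batch
```

In read_db_with_caffe2 you would then pass this batch_size to TensorProtosDBInput and loop `for _ in range(num_batches):` instead of `range(0, 1)`; both checksums are plain sums, so they accumulate across batches unchanged.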

write_db_with_caffe2("./chn_db")
read_db_with_caffe2("./chn_db", 640020532) #640020532为校验值,应该等于write中输出的checksum大小


7. Errors you may run into

1. CHW vs. HWC:

input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1


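Cause: images are loaded with shape HWC (height, width, channels), but Caffe2 expects image data in CHW order by default (note order "NCHW" in the operator below), so the data has to be converted. Without the conversion, the first axis, the image height of 20, is taken as the channel count, and you get: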

RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 20 is not equal to kernel channels * group:1*1 Error from operator:
input: "data" input: "conv1_w" input: "conv1_b" output: "conv1" name: "" type: "Conv" arg { name: "kernel" i: 5 } arg { name: "exhaustive_search" i: 0 } arg { name: "order" s: "NCHW" } engine: "CUDNN"


Fix: when converting the images to the mdb file, add
img_data = img_data.swapaxes(1, 2).swapaxes(0, 1)
(already included in the program above)
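A small NumPy check of what this fix does, on a dummy array rather than a real image: the two swapaxes calls move the channel axis from last to first, which is the same as transpose(2, 0, 1). It also shows where "# of input channels 20" comes from: once a batch axis is prepended, NCHW reads axis 1 as channels, which is the image height without the transpose. The 20×20 size matches the images in section 6; the array contents are arbitrary.

```python
import numpy as np

hwc = np.arange(20 * 20 * 1, dtype=np.float32).reshape(20, 20, 1)  # H, W, C

chw = hwc.swapaxes(1, 2).swapaxes(0, 1)  # the conversion used in write_db_with_caffe2
print(hwc.shape, "->", chw.shape)        # (20, 20, 1) -> (1, 20, 20)
print(np.array_equal(chw, hwc.transpose(2, 0, 1)))  # True

# With a batch axis N prepended, NCHW interprets axis 1 as the channel count
print(hwc[np.newaxis].shape[1])  # 20: the "# of input channels 20" in the error
print(chw[np.newaxis].shape[1])  # 1: matches a single-channel kernel
```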

2. Channel count mismatch:

input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1


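Cause: here the images came back with three channels even though they are grayscale (skimage.io.imread returns the channels stored in the file, so a grayscale image saved in a three-channel format is read as three channels). After the HWC → CHW conversion from item 1, the data therefore has 3 channels, which does not match the single input channel of the LeNet model in the MNIST example, giving: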

RuntimeError: [enforce fail at conv_op_impl.h:30] C == filter.dim32(1) * group_. Convolution op: input channels does not match: # of input channels 3 is not equal to kernel channels * group:1*1 Error from operator:

input: "data" input: "conv1_w" input: "conv1_b" output: "conv1" name: "" type: "Conv" arg { name: "kernel" i: 5 } arg { name: "exhaustive_search" i: 0 } arg { name: "order" s: "NCHW" } engine: "CUDNN"


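Fix: when converting the images to the mdb file, add the following before the HWC → CHW conversion:
img_data = img_data[:,:,:1]
(already included in the program above)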
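The slice keeps only the first channel, which loses nothing when a grayscale image is stored as three identical channels. The sketch below checks that on dummy arrays, and also handles a true single-channel file, for which imread returns a 2-D (H, W) array and [:, :, :1] would raise an IndexError. to_single_channel is a hypothetical helper name; for genuinely colored images a proper grayscale conversion (for example skimage.color.rgb2gray) would be the better choice than keeping one channel.

```python
import numpy as np

def to_single_channel(img):
    """Return an (H, W, 1) array from an (H, W) or (H, W, C) image."""
    if img.ndim == 2:             # true grayscale file: no channel axis
        return img[:, :, np.newaxis]
    return img[:, :, :1]          # keep the first channel, as in the article

gray = np.random.rand(20, 20)
rgb_gray = np.stack([gray, gray, gray], axis=-1)  # grayscale saved as 3 channels

print(to_single_channel(gray).shape)      # (20, 20, 1)
print(to_single_channel(rgb_gray).shape)  # (20, 20, 1)
print(np.array_equal(to_single_channel(rgb_gray)[:, :, 0], gray))  # True
```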