您的位置:首页 > 编程语言 > Python开发

深度学习中构造自己的图像数据集格式

2016-11-23 21:35 501 查看
构建数据集

读取数据集

可视化数据集

完整代码

在接触tensorflow之前我还从未接触过深度学习,所以刚开始写tensorflow就卡在了数据集上了,官方提供了mnist和cifar10两种数据集的训练代码,如果我们希望将自己的图像数据同样用在示例代码上,那么我们必须建立同样格式的数据集,我在初期就是这样做的,实际上在另一篇博文 windows下编译mnisten 我已经写了该如何制作mnist格式的数据集了。而现在我们可以构建自定义格式(我把它称为znyp格式)并输入模型进行训练,这将更加方便且利于理解。

构建数据集

构建数据集的核心代码如下:

from PIL import Image
import numpy as np

im = Image.open('images.jpeg')
im = (np.array(im))

r = im[:,:,0].flatten()
g = im[:,:,1].flatten()
b = im[:,:,2].flatten()
label = [1]

out = np.array(list(label) + list(r) + list(g) + list(b),np.uint8)
out.tofile("out.bin")


这是我在stack overflow中找到的,原理就是对于每一张图片,我们读取图片像素并转为numpy格式然后写入磁盘即可。事实上我尝试使用for循环来读取像素

data = []
for c in range(3):
for w in range(32):
for h in range(32):
data.append(pix[w,h][c])


两种代码的速度比起来,使用for循环会慢很多,所以我不建议使用for循环来读取图片像素,特别是在数据量比较大的情况。接下来我们对这段核心代码进行改造,使它能够真正制作数据集并具有一定的可拓展性。首先为了同时支持灰度图和RGB图像的读取,我们可以做一个判断

if channel == 1:
r = im.flatten()
g = []
b = []
elif channel == 3:
# get pixel from red channel, then green then blue
r = im[:, :, 0].flatten()
g = im[:, :, 1].flatten()
b = im[:, :, 2].flatten()
else:
raise Exception('The channel for the image should be 1 or 3')


接着我们所需做的就是将所有图片得到像素读入并拼接在一起,然后转换为numpy格式,代码如下

def _pickel_dataset(filename, write_to_file_root, category, channel, step_every_print=500):
data = []
labels = []

for idx, filename in enumerate(filename):
im = Image.open(filename)
im = (np.array(im))
H, W = im.shape[0], im.shape[1]

if channel == 1: r = im.flatten() g = [] b = [] elif channel == 3: # get pixel from red channel, then green then blue r = im[:, :, 0].flatten() g = im[:, :, 1].flatten() b = im[:, :, 2].flatten() else: raise Exception('The channel for the image should be 1 or 3')

# append the label
label = int(filename.split(os.sep)[-2])
labels.append(label)
# append the pixel
data += (list(r) + list(g) + list(b))

# convert the list to numpy
data = np.array(data, np.uint8)


在实际运行的时候主要有内存的限制,如果数据量过大会导致程序无法正常运行,具体看情况而定(我测试的10000张60 * 60的RGB图片是没有问题的)。最后我们需要将变量存储到磁盘,这里使用python来读取数据集其实是非常方便的,因为我们可以通过序列化将图像内容写入磁盘然后通过反序列化将图像内容读取到内存当中,所以我们使用python的pickle来实现,将字典保存到磁盘。代码如下

# write to pickle
datadict = {'data': data, 'labels': labels, 'height': H, 'width': W, 'channel': channel, 'num_images': num_images}
f = open(outname, 'wb')
pickle.dump(datadict, f, True)
f.close()


读取数据集

既然已经保存好了数据集,我们肯定要能够读取它,读取代码非常简单

def _unpickle_dataset(filename, category):
with open(filename, 'rb') as f:
datadict = pickle.load(f)
X = datadict['data']
Y = datadict['labels']
H = datadict['height']
W = datadict['width']
C = datadict['channel']
num_images = datadict['num_images']
X = X.reshape(num_images, C, H, W).transpose(0, 2, 3,
1).astype("float")
Y = np.array(Y)
print('X_%s: %s' % (category, X.shape))
print('Y_%s: %s' % (category, Y.shape))
return X, Y


可视化数据集

为了保证数据集的构建是正确的,我们可以将随机取加载的数据集中的几个图像进行可视化来判断,代码如下

def visualize_image(X_train, y_train):
W, H, C = X_train[0].shape
plt.rcParams['figure.figsize'] = (10.0, 8.0)
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
# 这里控制输出的分类个数,列表中有几个元素就输出几个分类,分类从0开始
classes = ['A', 'B', 'C']
num_classes = len(classes)
# 这里控制每个分类的显示图片个数
samples_per_class = 4
for y, cls in enumerate(classes):
idxs = np.flatnonzero(y_train == y)
idxs = np.random.choice(idxs, samples_per_class, replace=False)
for i, idx in enumerate(idxs):
plt_idx = i * num_classes + y + 1

plt.subplot(samples_per_class, num_classes, plt_idx)
if C == 1:
X_show = X_train[idx].reshape((W, H))
elif C == 3:
X_show = X_train[idx]
plt.imshow(X_show.astype('uint8'))
plt.axis('off')
if i == 0:
plt.title(cls)
plt.show()


效果如下(图片都是随便找的):



完整代码

整个数据集的核心代码基本就是上面这些,当然比如对输入图像进行随机排列,同时制作测试集等功能这里就不详细说明了,完整的代码见github/znyp_format
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐