您的位置:首页 > 其它

将自己的数据集制作成TFRecord格式,以及文件批量重命名

2018-01-20 10:39 330 查看
主要参考博客:data_to_tf

在进行深度学习时,如何把自己的数据制作成 TFRecord 格式是需要关注的第一步。本文结合参考过的多篇博客和论文,将相关做法整理成 Python 代码,进行记录和分享。

import tensorflow as tf
import os
from PIL import Image
import numpy as np

# Target image geometry: height and width (pixels) used when resizing images
# before they are written, and the channel count used when reshaping them back.
IMAGE_HEIGHT = 128
IMAGE_WIDTH = 128
IMAGE_CHANNEL = 1

# Build an integer-valued feature (used for the class label).
def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` holding an Int64List."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)

# Build a bytes-valued feature (used for the raw image data).
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` holding a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

# Build a TFRecord file from per-class image folders.
def createTFRecord(filename, mapfile,
                   data_dir='/home/sxf/MyProject_Python/normal_code/data_make/my_data_to_tf',
                   classes=('/ori', '/new')):
    '''Write every image under ``data_dir/<class>`` into a TFRecord file.

    Each image is converted to the PIL mode matching IMAGE_CHANNEL, resized
    to IMAGE_WIDTH x IMAGE_HEIGHT and stored as raw bytes together with an
    integer label (the class's position in *classes*). The label -> class
    name mapping is written to *mapfile* as one ``index:name`` line per class.

    :param filename: output path of the TFRecord file.
    :param mapfile: output path of the label-map text file (overwritten).
    :param data_dir: directory containing one sub-directory per class.
    :param classes: class sub-directory names in label order. An ordered
        sequence (not a set) keeps the label numbering identical on every run.
    :return: None
    '''
    # PIL mode whose tobytes() gives exactly IMAGE_CHANNEL bytes per pixel, so
    # every stored image is an IMAGE_HEIGHT x IMAGE_WIDTH x IMAGE_CHANNEL uint8
    # array regardless of the source file's mode.
    mode = {1: 'L', 2: 'LA', 3: 'RGB', 4: 'RGBA'}[IMAGE_CHANNEL]
    class_map = {}
    writer = tf.python_io.TFRecordWriter(filename)
    try:
        for index, name in enumerate(classes):
            class_map[index] = name
            class_path = os.path.join(data_dir, name.strip('/'))
            # Sorted so the record order does not depend on the file system.
            for img_name in sorted(os.listdir(class_path)):
                img_path = os.path.join(class_path, img_name)
                if not os.path.isfile(img_path):
                    continue  # skip sub-directories and other non-files
                with Image.open(img_path) as img:
                    if img.mode != mode:
                        img = img.convert(mode)
                    # PIL's resize takes (width, height).
                    img = img.resize((IMAGE_WIDTH, IMAGE_HEIGHT))
                    img_raw = img.tobytes()  # raw pixel bytes
                example = tf.train.Example(features=tf.train.Features(feature={
                    'label': _int64_feature(index),
                    'image_raw': _bytes_feature(img_raw),
                }))
                writer.write(example.SerializeToString())
    finally:
        # Close the writer even if an image fails to load.
        writer.close()
    with open(mapfile, 'w') as txtfile:
        for key in class_map:
            txtfile.write(str(key) + ":" + class_map[key] + "\n")

# Read the generated TFRecord file and decode each example into an image tensor.
def read_and_decode(filename, num_epochs=1):
    """Build TF1 queue-based ops that yield one decoded (image, label) pair.

    The returned tensors only produce values after queue runners have been
    started on a session. Because ``num_epochs`` is passed to
    ``tf.train.string_input_producer``, local variables are created and must
    be initialized (``tf.local_variables_initializer()``).

    :param filename: path of a TFRecord file with 'label' (int64) and
        'image_raw' (bytes) features, as written by createTFRecord.
    :param num_epochs: number of passes over the file before the input queue
        raises OutOfRangeError (default 1); None cycles indefinitely.
    :return: ``(img, labels)``: the image reshaped to
        [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL] and standardized with
        ``tf.image.per_image_standardization``, and the label cast to int32.
    """
    # Reader that pulls serialized examples out of TFRecord files.
    reader = tf.TFRecordReader()
    # Queue holding the list of input files.
    filename_queue = tf.train.string_input_producer(
        [filename], shuffle=False, num_epochs=num_epochs)
    # Read one example (read_up_to could read several at once).
    _, serialized_example = reader.read(filename_queue)

    # Parse a single example (parse_example would parse a batch).
    features = tf.parse_single_example(
        serialized_example,
        features={'label': tf.FixedLenFeature([], tf.int64),
                  'image_raw': tf.FixedLenFeature([], tf.string)})
    # Turn the raw byte string back into a uint8 pixel array.
    img = tf.decode_raw(features['image_raw'], tf.uint8)
    # Rows x columns x channels, matching how the writer resized the image.
    img = tf.reshape(img, [IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_CHANNEL])
    img = tf.image.per_image_standardization(img)
    labels = tf.cast(features['label'], tf.int32)
    return img, labels

# Build shuffled, one-hot-labelled batches from a TFRecord file.
def createBatch(filename, batchsize, num_classes=2, min_after_dequeue=10):
    """Return ``(image_batch, label_batch)`` tensors drawn from *filename*.

    :param filename: TFRecord file readable by read_and_decode.
    :param batchsize: number of examples per batch.
    :param num_classes: depth of the one-hot label encoding; must equal the
        number of classes the records were written with (default 2).
    :param min_after_dequeue: minimum number of examples kept in the shuffle
        queue after a dequeue (default 10); larger values mix better but use
        more memory.
    :return: image batch and one-hot label batch of shape
        [batchsize, num_classes].

    ``tf.train.shuffle_batch`` is called with its default
    ``allow_smaller_final_batch=False``, so leftover examples that cannot fill
    a complete batch at the end of the input are not returned.
    """
    images, labels = read_and_decode(filename)

    # Queue capacity: enough room for the shuffle buffer plus a few batches.
    capacity = min_after_dequeue + 3 * batchsize

    image_batch, label_batch = tf.train.shuffle_batch(
        [images, labels],
        batch_size=batchsize,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
    )

    label_batch = tf.one_hot(label_batch, depth=num_classes)
    return image_batch, label_batch

if __name__ == "__main__":
    # Build train/test TFRecord files, then read each back in batches of two
    # and print the labels of every batch until that input is exhausted.
    base_dir = '/home/sxf/MyProject_Python/normal_code/data_make/my_data_to_tf'
    mapfile = os.path.join(base_dir, 'classmap.txt')
    train_filename = os.path.join(base_dir, 'train.tfrecords')
    test_filename = os.path.join(base_dir, 'test.tfrecords')
    # NOTE(review): both calls below read the same source image folders, so
    # train.tfrecords and test.tfrecords contain the same images (and the
    # second call rewrites the same mapfile). Use separate source folders if
    # a real held-out test set is intended.
    createTFRecord(train_filename, mapfile)
    createTFRecord(test_filename, mapfile)
    image_batch, label_batch = createBatch(filename=train_filename, batchsize=2)
    test_images, test_labels = createBatch(filename=test_filename, batchsize=2)
    with tf.Session() as sess:
        # Local variables are needed too: the epoch counters of the input
        # producers live there.
        initop = tf.group(tf.global_variables_initializer(),
                          tf.local_variables_initializer())
        sess.run(initop)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            # OutOfRangeError signals that the input epochs are exhausted.
            try:
                step = 0
                while True:
                    _image_batch, _label_batch = sess.run([image_batch, label_batch])
                    step += 1
                    print(step)
                    print(_label_batch)
            except tf.errors.OutOfRangeError:
                print(" trainData done!")

            try:
                step = 0
                while True:
                    _test_images, _test_labels = sess.run([test_images, test_labels])
                    step += 1
                    print(step)
                    print(_test_labels)
            except tf.errors.OutOfRangeError:
                print(" TEST done!")
        finally:
            # Always stop and join the queue-runner threads, even on an
            # unexpected error, so the session can shut down cleanly.
            coord.request_stop()
            coord.join(threads)
存在的注意点:
原始数据中的图片要保证通道数一致(且与 IMAGE_CHANNEL 相同),否则写入的字节数与预期不符,读取时 reshape 会出错。

另一个要解决的点就是进行图片重命名的问题。

直接上代码:

# 将原始路径下的图片进行重命名,并复制保存到新的路径下,最终返回新的路径下的文件名列表。
import os
import shutil

def rename(path_ori, newpath, flage=True, newname_front=None):
    """Rename every file in *path_ori* to ``<prefix>_<n><ext>``.

    Files are numbered 1, 2, 3, ... in sorted file-name order. Sub-directories
    are left untouched and do not consume a number. The original extension
    is kept. Renaming goes through unique temporary names first, so a file
    that already carries one of the target names is never overwritten before
    it has been processed.

    :param path_ori: directory whose files are renamed in place.
    :param newpath: directory that receives copies of the renamed files when
        *flage* is true; its listing is returned in every case.
    :param flage: if true, copy each renamed file into *newpath* under its new
        name, overwriting any same-named file there.
    :param newname_front: name prefix (surrounding whitespace is stripped).
        If None, the user is prompted for it interactively.
    :return: ``os.listdir(newpath)`` after the operation.
    """
    import uuid  # only needed for collision-free temporary names

    if newname_front is None:
        newname_front = input("please input the new name style:")
    print('new name is the format like %s_1.....' % (newname_front))
    newname_front = newname_front.strip()

    files = [name for name in sorted(os.listdir(path_ori))
             if not os.path.isdir(os.path.join(path_ori, name))]

    # Phase 1: move each file to a unique temporary name. Renaming straight
    # to the final name could silently replace a not-yet-processed file that
    # already has that name (os.rename overwrites the target on POSIX).
    staged = []
    for number, name in enumerate(files, start=1):
        filetype = os.path.splitext(name)[1]
        tmp_path = os.path.join(path_ori, '.rename_tmp_' + uuid.uuid4().hex)
        os.rename(os.path.join(path_ori, name), tmp_path)
        staged.append((tmp_path, newname_front + '_' + str(number) + filetype))

    # Phase 2: give each temporary file its final name, then copy if asked.
    for tmp_path, newname in staged:
        rename_dir = os.path.join(path_ori, newname)
        os.rename(tmp_path, rename_dir)
        if flage:
            shutil.copyfile(rename_dir, os.path.join(newpath, newname))
    return os.listdir(newpath)

###########################
# test: rename the images in `path` in place (no copies into `newpath`) and
# print the file listing of `newpath`.
###########################
if __name__ == "__main__":
    path = '/home/sxf/MyProject_Python/ori_image/new'
    newpath = '/home/sxf/MyProject_Python/ori_image/ori'
    # With flage=False the files are only renamed in place; rename() still
    # returns the current listing of newpath.
    newpath_files = rename(path, newpath, flage=False)
    print(newpath_files)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: