您的位置:首页 > 其它

Tenosorflow基础学习---------Tensorflow训练自己的数据集

2018-09-25 20:55 513 查看

一般我们获得的数据集并非是提前处理好的二进制的格式文件,所以我们需要将数据集进行处理,当然我们这里说的数据集类似于猫狗大战那样的,并不是MNIST和CIFAR-10那样拿来就可以直接用的,而且提前分类和标签的数据集,只不过给的是大量的图片,一般都是比赛提供方给的数据集,而对于这样的数据集当然不可能整张输入和读取,这样不仅的数据不仅数据量大,需要大量的内存消耗,而且时间也是相当的慢,于是在tensorflow中提供了一种专门用于tensorflow的数据集的格式转换。

目录

1.第一步:数据集的加工

2.第二步:图片数据集转化为Tensorflow专用格式

2.1附加读取数据

1.第一步:数据集的加工

数据集中的数据并不是按照规格大小处理,对于不同的的图片,其规格尺寸都不尽相同,因此在数据集提交之前需要对数据集进行处理。对于图像的处理我这里用的是opencv,opencv的功能真的方便,强力推荐大家!!

最简单的方式就是把数据裁剪成规定大小。比如输入模型中的图片大小为[227,227],因此我们这里将图片裁剪成[227,227]的尺寸,给出代码示例:

[code]import cv2
import os
def rebulid(dir):
#walk(top, topdown=True, onerror=None, followlinks=False)
#top-是你所要遍历的目录的地址,返回的是一个三元组(root,dirs,files)
#root所指的是当前正在遍历的这个文件夹的本身的地址
#dirs是一个list,内容是该文件夹中所有的目录的名字(不包括子目录)
#files同样是list,内容是该文件夹中所有的文件(不包含子目录)
#topdown--可选,为True,则优先遍历top目录,否则优先遍历top的子目录(默认为开启)。
# 如果topdown参数为True,walk会遍历top文件夹中每一个子目录
#onerror--可选,需要一个callable对象,当walk需要异常时,会调用
#followlinks--可选,如果为True,则会遍历目录下的快捷方式实际所指的目录
for root,dirs,files in os.walk(dir):
for file in files:
filepath = os.path.join(root,file)
try:
image = cv2.imread(filepath)
dim = (227,227)
resized = cv2.resize(image,dim)
path = "C:\\cat_and_dog\\dog-r\\"+file
cv2.imwrite(path,resized)
except:
print(filepath)
os.remove(filepath)
cv2.waitKey(0)   #退出

在这里导入的是图片集的根目录,os对数据集所在的文件夹进行读取,之后的一个for循环重建了图片数据所在的路径(filepath),在图片被重构后重新写入给定的位置(path)。

2.第二步:图片数据集转化为Tensorflow专用格式

对于数据集来说,最好的办法就是将其转换成Tensorflow专用的数据格式,即TFRecord格式。

将裁剪后的图片的位置进行读取,之后根据文件名称的不同将处于不同文件夹中的图片标签设置为0或者1,如果有更多分类的话可以依据这个格式设置更多的标签类,之后使用创建的数组对所读取的文件位置和标签进行保存,而Numpy对数组的调整重构了存储有对应文件位置和文件标签的矩阵,并返回。

[code]def get_file(file_dir):
images=[]
temp = []
for root,sub_folders,files in os.walk(file_dir):
#图片目录
for name in files:
images.append(os.path.join(root,name))
#get 10 sub-folders:
for name in sub_folders:
temp.append(os.path.join(root,name))
print(files)
#根据文件夹名分配多个标签
labels = []
for one_folder in temp:
n_img = len(os.listdir(one_folder))
#split('\\')[-1]以\\分割字符串,保留最后一段
letter = one_folder.split('\\')[-1]
if letter == 'cat':
#n_img*[0]和np.zeros[n_img]一样
labels = np.append(labels,n_img*[0])
else:
#n_img*[1]和np.ones[n_img]一样
labels = np.append(labels,n_img*[1])
temp = np.array([images,labels])
#转置
temp = temp.transpose()
#shuffle 随机排列
np.random.shuffle(temp)

image_list = list(temp[:,0])
label_list = list(temp[:,1])
label_list = [int(float(i) for i in label_list)]

return image_list,label_list

在获取图片数据文件位置和图片标签之后,即可通过相应的程序对其进行读取,并生成专门用的TFRecord格式的数据集

首先是转换格式的定义,这里需要将数据转换为相应的格式。

[code]def int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

def convert_to_tfrecord(images_list,lables_list,save_dir,name):
filename = os.path.join(save_dir,name+'.tfrecords')
n_samples = len(lables_list)
writer = tf.python_io.TFRecordWriter(filename)
print('\nTransform start.......')
for i in np.arange(0,n_samples):
try:
image = io.imread(images_list[i])
image_raw = image.tostring()
label = int(label[i])
example = tf.train.Example(features=tf.train.Feature(feature={
'label':int64_feature(label),
'image_raw':bytes_feature(image_raw)}))
except IOError as e:
print('Could not read:',images[i])
writer.close()
print('Transform done!')

convert_to_tfrecord(images_list,labels_list,save_dir,name)函数中需要4个参数,其中images_list和labels_list是上一段代码段获取的图片位置和对应标签的列表。save_dir是存储路径,如果希望生成的TFRecord文件存储在当前目录下,直接使用空的双引号""即可。最后是生成的文件名,这里只需填写名称就会自动生成以".tfrecords"格式结尾的数据集。

2.1附加读取数据

当生成完数据集后,在神经网络使用数据集进行训练时,需要一个方法将数据从数据急中取出,下面代码段完成了数据读取功能。

[code]def read_and_decode(tfrecords_file,batch_size):
filename_queue = tf.train.string_input_producer([tfrecords_file])
reader = tf.TFRecordReader()
_,serialized_example = reader.read(filename_queue)
img_features = tf.parse_single_example(
serialized_example,
features={
'label':tf.FixedLenFeature([],tf.int64),
'image_raw':tf.FixedLenFeature([],tf.string),
}
)
image = tf.decode_raw(img_features['image_raw'],tf.uint8)
image = tf.reshape(image,[227,227,3])
lable = tf.cast(img_features['label'],tf.int32)
image_batch,lable_batch = tf.train.shuffle_batch([image,label],
batch_size=batch_size,
min_after_dequeue=100,
num_threads=64,
capacity=200)
return image_batch,tf.reshape(lable_batch,[batch_size])

 

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: 
相关文章推荐