tensorflow之从文件中读取数据(适用场景:大规模数据集,亲测有效~)
2018-03-22 12:45
656 查看
网上从文件中读取样本和标签的资料很多,但大多讲的不全面,或只讲原理,或只有变为.tfrecords部分,或没有调用的栗子。寄几and男票一起捣鼓了两天,终于有了目前这个完整版的代码,希望对看到的朋友有所帮助。
图1
图2
我们用ray14_train.py进行train,这个.py文件和train_y.csv不在同一目录下。所以,在标签文件train_y.csv中,我们需要将图片名称这一列变为相对路径,如图4所示,这个新csv我们存为y_train.csv,测试集也这么处理。
图3
图4import numpy as np
import pandas as pd
import cv2
import csv
from os import path as osp
import os
n = 86524, my_number = 25596, my_catelogy = 2, batch = 4):
for e in range(nb_epoch):
n_batch = 0
for my_batch_train in range(int(n/batch)):
Xu_batch, Yu_batch = self.sess.run([source_train, y_train])
Xu_batch = transform_batch_images(Xu_batch)
Yu_batch = np_utils.to_categorical(Yu_batch, 2)
# print('train label',Yu_batch)
feed_dict = { self.x: Xu_batch, self.y_: Yu_batch ,self.istrain:True}
cost, Ft_loss = self.sess.run([cost, Ft_loss], feed_dict=feed_dict)
n_batch += 1
#every 1000 minibatch print loss
if n_batch % 1000==0:
print("Epoch %d total_loss %f Ft_loss %f" % (e + 1, cost,Ft_loss))其中,从文件读取部分代码是:Xu_batch, Yu_batch = self.sess.run([source_train, y_train])9.测试的代码就不写了,类似8。
参考资料:1.https://zhuanlan.zhihu.com/p/272386302.https://www.cnblogs.com/wktwj/p/7257526.html
1. 准备样本和标签
样本图示如图1,标签文件train_y.csv如图2,这是个2分类问题。图1
图2
2.生成记录样本的记录文件
我们的图片存储路径如图3红框所示,标签文件train_y.csv存储路径如图3绿框所示。我们用ray14_train.py进行train,这个.py文件和train_y.csv不在同一目录下。所以,在标签文件train_y.csv中,我们需要将图片名称这一列变为相对路径,如图4所示,这个新csv我们存为y_train.csv,测试集也这么处理。
图3
图4import numpy as np
import pandas as pd
import cv2
import csv
from os import path as osp
import os
base_path = os.path.join('images','images224') train_y_path = os.path.join(base_path,'train_y.csv') train_y = np.loadtxt(train_y_path, delimiter=",", skiprows=0, usecols=(0,1), dtype=str) train_y_pd = pd.DataFrame(train_y) for i in range(train_y.shape[0]): train_y_pd.iloc[i,0] = os.path.join(base_path,train_y[i,0]) train_y_pd.to_csv(os.path.join(base_path, 'y_train.csv'),header=None,index=None)
先将2运行,得到y_train.csv和y_test.csv,从3开始要正式读取了。
3.读取csv存于数组中,将图片路径和标签存于数组中
def load_file(example_list_file): lines = np.genfromtxt(example_list_file,delimiter=",",dtype=[('col1', 'S120'), ('col2', 'i8')]) examples = [] labels = [] for example,label in lines: examples.append(example) labels.append(label) #convert to numpy array return np.asarray(examples),np.asarray(labels),len(lines)
4.使用cv2读取图片
def extract_image(filename,height,width): # print(filename) image = cv2.imread(filename) # image = cv2.resize(image,(height,width)) b,g,r = cv2.split(image) rgb_image = cv2.merge([r,g,b]) return rgb_image
5.将图片和标签转化为tfrecords文件
def trans2tfRecord(train_file,name,output_dir,height,width): if not os.path.exists(output_dir) or os.path.isfile(output_dir): os.makedirs(output_dir) _examples,_labels,examples_num = load_file(train_file) filename = name + '.tfrecords' writer = tf.python_io.TFRecordWriter(filename) for i,[example,label] in enumerate(zip(_examples,_labels)): # print("NO{}".format(i)) #need to convert the example(bytes) to utf-8 example = example.decode("UTF-8") image = extract_image(example,height,width) image_raw = image.tostring() example = tf.train.Example(features=tf.train.Features(feature={ 'image_raw':_bytes_feature(image_raw), 'height':_int64_feature(image.shape[0]), 'width': _int64_feature(32), 'depth': _int64_feature(32), 'label': _int64_feature(label) })) b6c0 writer.write(example.SerializeToString()) writer.close() return filename
def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
6.从tfrecords文件中读取训练数据
def read_tfRecord(file_tfRecord,shuffle=False): # 这个函数需要传入一个文件名,系统会自动将它转为一个文件名队列,这个队列存的是训练或测试过程用到的数据 # tf.train.string_input_producer有两个重要的参数,一个是num_epochs,这个设成默认none就行,none表示无限次 # 它表示将全部样本入队次数,一般程序迭代几次就入队几次。程序运行开始,数据就开始出队,为了保证队列一直不空, # 我们设为none,使全部样本入队无数次(无限循环)。 # 另外一个就是shuffle,shuffle是指在一个epoch内文件的顺序是否被打乱(但是我测试时发现无论是True还是False,其实都打乱了)。 queue = tf.train.string_input_producer([file_tfRecord], shuffle=shuffle) reader = tf.TFRecordReader() _,serialized_example = reader.read(queue) features = tf.parse_single_example( serialized_example, features={ 'image_raw': tf.FixedLenFeature([], tf.string), 'height': tf.FixedLenFeature([], tf.int64), 'width':tf.FixedLenFeature([], tf.int64), 'depth': tf.FixedLenFeature([], tf.int64), 'label': tf.FixedLenFeature([], tf.int64) } ) image = tf.decode_raw(features['image_raw'],tf.uint8) #height = tf.cast(features['height'], tf.int64) #width = tf.cast(features['width'], tf.int64) image = tf.reshape(image,[224,224,3]) image = tf.cast(image, tf.float32) image = tf.image.per_image_standardization(image) label = tf.cast(features['label'], tf.int64) print(image,label) return image,label
7.调用3-6,开始训练
with tf.Session() as sess: # 训练过程 base_path = os.path.join('images','images224') data_train_path = os.path.join(base_path,'y_train.csv') data_test_path = os.path.join(base_path,'y_test.csv') # 首次执行程序需要运行一旦生成之后就可以注释掉了:利用csv生成y_train.tfrecords和y_test.tfrecords文件,这俩文件是训练集和测试集的样本与标签, filename = trans2tfRecord(data_train_path, 'y_train', base_path, 224, 224) filename2 = trans2tfRecord(data_train_path, 'y_test', base_path, 224, 224) img_batch, path_batch = read_tfRecord(filename, shuffle=True) img_batch2, path_batch2 = read_tfRecord(filename2, shuffle=False) image_batches, label_batches = tf.train.batch([img_batch, path_batch], batch_size=batch, capacity=4096) image_batches2, label_batches2 = tf.train.batch([img_batch2, path_batch2], batch_size=batch, capacity=4096) tf.local_variables_initializer().run() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess,coord=coord) # 定义一个模型 model=ATDA(sess=sess) model.create_model() # 训练模型:(image_batches,label_batches)是训练集,(image_batches2,label_batches2)是测试集, model.fit_ATDA(source_train=image_batches, y_train=label_batches, target_val=image_batches2, y_val=label_batches2, # n是训练集总数,my_number是测试集总数,my_catelogy是标签种类,batch是迭代次数 nb_epoch=epochs, n = 86524, my_number = 25596, my_catelogy = 2,batch = 16) coord.request_stop() # 请求线程结束 coord.join() # 等待线程结束8.model.fit_ATDA(),这部分是训练模型。def fit_ATDA(source_train, y_train, target_val, y_val, nb_epoch=30,
n = 86524, my_number = 25596, my_catelogy = 2, batch = 4):
for e in range(nb_epoch):
n_batch = 0
for my_batch_train in range(int(n/batch)):
Xu_batch, Yu_batch = self.sess.run([source_train, y_train])
Xu_batch = transform_batch_images(Xu_batch)
Yu_batch = np_utils.to_categorical(Yu_batch, 2)
# print('train label',Yu_batch)
feed_dict = { self.x: Xu_batch, self.y_: Yu_batch ,self.istrain:True}
cost, Ft_loss = self.sess.run([cost, Ft_loss], feed_dict=feed_dict)
n_batch += 1
#every 1000 minibatch print loss
if n_batch % 1000==0:
print("Epoch %d total_loss %f Ft_loss %f" % (e + 1, cost,Ft_loss))其中,从文件读取部分代码是:Xu_batch, Yu_batch = self.sess.run([source_train, y_train])9.测试的代码就不写了,类似8。
参考资料:1.https://zhuanlan.zhihu.com/p/272386302.https://www.cnblogs.com/wktwj/p/7257526.html
相关文章推荐
- Tensorflow从文件读取数据
- Ubuntu下Tensorflow加载MNIST数据集(数据下载和读取)
- MATLAB怎样有效读取excel文件中的数据?
- vbs读取文件内的信息将非有效数据移动到指定路径
- python读取文件中的一行有效数据
- Tensorflow从文件读取数据
- Tensorflow 从bin文件中读取数据并
- tensorflow数据集制作/文件队列读取方式
- C++ 简单读写文本文件、统计文件的行数、读取文件数据到数组
- 读取某个文件夹下的所有文件并读取文件中的文本数据
- C++ 简单读写文本文件、统计文件的行数、读取文件数据到数组
- c语言读取obj文件转换数据
- read.table()读取数据文件
- 从plist文件中读取数据
- Java如何读取数据文件,如txt文件或者.dat文件 中的内容
- 【python图像处理】txt文件数据的读取与写入
- Oracle DBA的神器: PRM恢复工具,可脱离Oracle软件运行,直接读取Oracle数据文件中的数据
- tensorflow学习笔记三:实例数据下载与读取
- js读取json文件中的json数据
- matlab 批量读取数据文件.mat .dat