TensorFlow——训练自己的数据(三)模型训练
2017-07-11 16:45
543 查看
参考:Tensorflow教程-猫狗大战数据集
文件training.py
导入文件
变量声明
获取批次batch
操作定义
进行batch的训练
文件training.py
导入文件
import os import numpy as np import tensorflow as tf import input_data import model
变量声明
N_CLASSES = 2 #猫和狗 IMG_W = 208 # resize图像,太大的话训练时间久 IMG_H = 208 BATCH_SIZE = 16 CAPACITY = 2000 MAX_STEP = 10000 # 一般大于10K learning_rate = 0.0001 # 一般小于0.0001
获取批次batch
train_dir = '/home/kevin/tensorflow/cats_vs_dogs/data/train/' logs_train_dir = '/home/kevin/tensorflow/cats_vs_dogs/logs/train/' train, train_label = input_data.get_files(train_dir) train_batch,train_label_batch=input_data.get_batch(train, train_label, IMG_W, IMG_H, BATCH_SIZE, CAPACITY)
操作定义
train_logits = model.inference(train_batch, BATCH_SIZE, N_CLASSES) train_loss = model.losses(train_logits, train_label_batch) train_op = model.trainning(train_loss, learning_rate) train__acc = model.evaluation(train_logits, train_label_batch) summary_op = tf.summary.merge_all() #这个是log汇总记录 #产生一个会话 sess = tf.Session() #产生一个writer来写log文件 train_writer = tf.summary.FileWriter(logs_train_dir, sess.graph) #产生一个saver来存储训练好的模型 saver = tf.train.Saver() #所有节点初始化 sess.run(tf.global_variables_initializer()) #队列监控 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord)
进行batch的训练
try: #执行MAX_STEP步的训练,一步一个batch for step in np.arange(MAX_STEP): if coord.should_stop(): break #启动以下操作节点,有个疑问,为什么train_logits在这里没有开启? _, tra_loss, tra_acc = sess.run([train_op, train_loss, train__acc]) #每隔50步打印一次当前的loss以及acc,同时记录log,写入writer if step % 50 == 0: print('Step %d, train loss = %.2f, train accuracy = %.2f%%' %(step, tra_loss, tra_acc*100.0)) summary_str = sess.run(summary_op) train_writer.add_summary(summary_str, step) #每隔2000步,保存一次训练好的模型 if step % 2000 == 0 or (step + 1) == MAX_STEP: checkpoint_path = os.path.join(logs_train_dir, 'model.ckpt') saver.save(sess, checkpoint_path, global_step=step) except tf.errors.OutOfRangeError: print('Done training -- epoch limit reached') finally: coord.request_stop()
相关文章推荐
- Tensorflow学习笔记:用minst数据集训练卷积神经网络并用训练后的模型测试自己的BMP图片
- TensorFlow——训练自己的数据(二)模型设计
- TensorFlow——训练自己的数据(四)模型测试
- 利用tensorflow训练自己的图片数据(3)——建立网络模型
- 用tensorflow训练自己的数据_3、训练模型
- TensorFlow——训练自己的数据(五)模型评估
- 深度学习-CAFFE利用CIFAR10网络模型训练自己的图像数据获得模型-2生成图像库的均值文件
- Caffe使用step by step:使用自己数据对已经训练好的模型进行finetuning
- SSD Faster-RCNN使用自己的数据fine-tune训练模型
- Matconvnet 训练自己的数据(使用现有模型)
- mxnet 使用自己的图片数据训练CNN模型
- Caffe使用step by step:使用自己数据对已经训练好的模型进行finetuning
- caffe——cifar10模型训练自己的数据
- 深度学习-CAFFE利用CIFAR10网络模型训练自己的图像数据获得模型-4应用生成模型进行预测
- Caffe windows 用自己的数据训练模型
- 深度学习-CAFFE利用CIFAR10网络模型训练自己的图像数据获得模型-1.制作自己的数据集
- Tensorflow 训练自己的cnn模型 行人识别
- 【深度学习】笔记6:使用caffe中的CIFAR10网络模型和自己的图片数据训练自己的模型(步骤详解)
- tensorflow 分布式 数据并行 异步训练 between-graph 自己写的实例 RNN
- Tensorflow 训练模型数据freeze固话保存在Graph中