tensorflow图像分类实战解析(下)
2016-07-18 11:21
441 查看
global_step = tf.Variable(0, name='global_step', trainable=False) train_op = optimizer.minimize(loss, global_step=global_step) return train_op
设定最小化目标(loss),并用不参与训练的global_step变量记录已执行的优化步数
def evaluation(logits, labels):
    """Count the predictions in a batch whose top-1 class equals the label.

    Args:
        logits: Tensor of per-class scores, passed to ``tf.nn.in_top_k``.
        labels: Tensor of target class ids, passed to ``tf.nn.in_top_k``.

    Returns:
        An int32 scalar tensor holding the number of correct predictions.
    """
    # k=1: a prediction counts only when the label is the single best class.
    hits = tf.nn.in_top_k(logits, labels, 1)
    # Booleans are cast to int32 so they can be summed into a count.
    hits_as_int = tf.cast(hits, tf.int32)
    return tf.reduce_sum(hits_as_int)
衡量预测结果(logits)与ground-truth标签的吻合程度:统计一个batch中top-1预测正确的样本数
def placeholder_inputs(batch_size): images_placeholder = tf.placeholder(tf.float32, shape=(batch_size,IMAGE_PIXELS))
这里的batchsize定义了placeholder一次性读入图片的数目,所以理论上应该是有能力不需要将所有图片全部读入内存中再进行处理的。这个我还得继续研究
labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size)) return images_placeholder, labels_placeholder def fill_feed_dict(images_feed,labels_feed, images_pl, labels_pl): feed_dict = { images_pl: images_feed, labels_pl: labels_feed, } return feed_dict
很重要的feed-dict 用来灌数据,其实就是在建立一个字典
def do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set): # And run one epoch of eval.
评价现在参数训练效果的函数
true_count = 0 # Counts the number of correct predictions. steps_per_epoch = 4 // FLAGS.batch_size num_examples = steps_per_epoch * FLAGS.batch_size for step in xrange(steps_per_epoch): feed_dict = fill_feed_dict(train_images,train_labels, images_placeholder, labels_placeholder) true_count += sess.run(eval_correct, feed_dict=feed_dict)
有必要再加深了解一下这个feed_dict是什么
precision = true_count / num_examples print(' Num examples: %d Num correct: %d Precision @ 1: %0.04f' % (num_examples, true_count, precision)) # Get the sets of images and labels for training, validation, and train_images = [] for filename in ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg']: image = Image.open(FLAGS.train_dir+'/'+filename) image = image.resize((IMAGE_SIZE,IMAGE_SIZE)) train_images.append(np.array(image))
读取图像训练列表到内存中
train_images = np.array(train_images)
转换为numpy的array
train_images = train_images.reshape(4,IMAGE_PIXELS)
将每张图像展平为一维向量,得到形状为(4, IMAGE_PIXELS)的数组
label = [0,1,1,1] train_labels = np.array(label)
图像的ground-truth
def run_training(): # Tell TensorFlow that the model will be built into the default Graph. with tf.Graph().as_default(): # Generate placeholders for the images and labels. images_placeholder, labels_placeholder = placeholder_inputs(4) # Build a Graph that computes predictions from the inference model. logits = inference(images_placeholder, FLAGS.hidden1, FLAGS.hidden2) # Add to the Graph the Ops for loss calculation. loss = cal_loss(logits, labels_placeholder) # Add to the Graph the Ops that calculate and apply gradients. train_op = training(loss, FLAGS.learning_rate) # Add the Op to compare the logits to the labels during evaluation. eval_correct = evaluation(logits, labels_placeholder) # Create a saver for writing training checkpoints. saver = tf.train.Saver() # Create a session for running Ops on the Graph. sess = tf.Session()
我们注意到,这里有两种session可以选择:Session和InteractiveSession。唯一的区别在于InteractiveSession在创建时会把自己设为默认session,因此eval和run方法无需显式传入session就会在这个InteractiveSession中执行,用法样例如下:
InteractiveSession:
# Demo: evaluating a tensor through an InteractiveSession.
sess = tf.InteractiveSession()
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
# We can just use 'c.eval()' without passing 'sess'
print(c.eval())
# Release the session's resources once the demo is done.
sess.close()
Session:
# Demo: evaluating a tensor inside a plain Session used as a context manager.
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
with tf.Session():
    # We can also use 'c.eval()' here.
    print(c.eval())
如果不是很明白,还可以看看这个例子:
# Demo: three ways of getting the value of a tensor.
# NOTE(review): n_values is not defined in this snippet; it must be set
# before this line runs.
x = tf.linspace(-3.0, 3.0, n_values)

# %% Construct a tf.Session to execute the graph.
sess = tf.Session()
result = sess.run(x)

# %% Alternatively pass a session to the eval fn:
x.eval(session=sess)
# x.eval() does not work, as it requires a session!

# %% We can setup an interactive session if we don't
# want to keep passing the session around:
sess.close()
sess = tf.InteractiveSession()

# %% Now this will work!
x.eval()
# Run the Op to initialize the variables. init = tf.initialize_all_variables() sess.run(init)
初始化图中的所有变量
# And then after everything is built, start the training loop. for step in xrange(FLAGS.max_steps): start_time = time.time() feed_dict = fill_feed_dict(train_images,train_labels, images_placeholder, labels_placeholder) _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
进行训练,每次run只会驱动一次优化过程
duration = time.time() - start_time if step % 100 == 0: # Print status to stdout. print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)) if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps: saver.save(sess, FLAGS.train_dir, global_step=step)
每训练1000步以及在最后一步时,将checkpoint存入train_dir路径中
print('Training Data Eval:') do_eval(sess, eval_correct, images_placeholder, labels_placeholder, train_images)
相关文章推荐
- 构造函数语义学
- 自动档车高级驾驶技术完整攻略
- 在日期数据上加一天
- linux服务器在运行210天左右宕机
- XML编码utf-8有中文无法解析或乱码 C#
- C语言_初学结构体_plusC14.2
- tensorflow 图像分类实战解析(上)
- iOS 下关于 MD5 的那个坑
- wheelView
- java学习笔记2
- (转载)C# 编程 使用可空类型
- Mock 模拟测试简介及 Mockito 使用入门
- JAVA list+for循环实现分页
- Window attributes属性详解
- usaco 2006 nov poj3255 严格次短路
- 计蒜客-程序设计竞赛入门
- 建造者模式
- Android学习一(windows安装Git)
- 蓝牙模块
- 第四章 流程控制与数组