TensorFlow Image Classification in Practice, Explained (Part 2)

2016-07-18 11:21
    # Tail of training(loss, learning_rate); the beginning of the function is in part 1.
    global_step = tf.Variable(0, name='global_step', trainable=False)
    train_op = optimizer.minimize(loss, global_step=global_step)
    return train_op


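Set the minimization objective. global_step is only a step counter: it is created with trainable=False, so the optimizer never updates it through gradients, and minimize() increments it by one each time train_op runs.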
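To make the role of global_step concrete, here is a plain-Python analogy, not TensorFlow's implementation. Each call to the hypothetical minimize_step helper applies one gradient-descent update to a single scalar weight and then increments the step counter, which is roughly what one run of train_op does for every trainable variable. The quadratic loss and the learning rate are illustrative assumptions.

```python
# Plain-Python analogy of optimizer.minimize(loss, global_step=global_step).
# Assumption (illustrative only): loss(w) = (w - 3)^2, so d(loss)/dw = 2 * (w - 3).


def grad(w):
    """Gradient of the illustrative loss (w - 3)^2."""
    return 2.0 * (w - 3.0)


def minimize_step(state, learning_rate):
    """One run of a hypothetical train_op: update w, then bump global_step."""
    state["w"] -= learning_rate * grad(state["w"])
    state["global_step"] += 1  # a counter, never updated by gradients
    return state


state = {"w": 0.0, "global_step": 0}
for _ in range(50):
    minimize_step(state, learning_rate=0.1)

print(state["global_step"])  # 50
print(state["w"])            # close to the minimum at w = 3
```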

def evaluation(logits, labels):
    correct = tf.nn.in_top_k(logits, labels, 1)
    return tf.reduce_sum(tf.cast(correct, tf.int32))


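Compare the predictions with the ground truth: in_top_k with k = 1 marks an example as correct when its true label has the highest logit, and the sum gives the number of correct predictions in the batch.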
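A dependency-free sketch of the same computation. For k = 1, tf.nn.in_top_k counts an example as correct when the logit of its true label equals the row maximum (TensorFlow treats classes tied at the top-k boundary as inside the top k), and reduce_sum over the cast booleans gives the count. The logits below are made up; the labels match the ones used later in this post.

```python
def count_top1_correct(logits, labels):
    """Count rows whose true label has the largest logit.

    Mirrors tf.reduce_sum(tf.cast(tf.nn.in_top_k(logits, labels, 1), tf.int32)):
    a tie for the maximum still counts as correct.
    """
    correct = 0
    for row, label in zip(logits, labels):
        if row[label] >= max(row):
            correct += 1
    return correct


logits = [[2.0, 0.5],   # predicts class 0
          [0.1, 1.2],   # predicts class 1
          [0.9, 0.3],   # predicts class 0
          [0.4, 0.4]]   # tie between classes 0 and 1
labels = [0, 1, 1, 1]
print(count_top1_correct(logits, labels))  # 3
```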

def placeholder_inputs(batch_size):
    images_placeholder = tf.placeholder(tf.float32, shape=(batch_size, IMAGE_PIXELS))


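Here batch_size sets how many images the placeholder takes in at a time, so in principle it should be possible to process the images without first loading all of them into memory. This is something I still need to look into.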

    labels_placeholder = tf.placeholder(tf.int32, shape=(batch_size,))
    return images_placeholder, labels_placeholder
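The hunch above holds in principle: each sess.run only needs one batch, so images can be read batch by batch instead of all at once. Below is a minimal plain-Python sketch of that idea, assuming a hypothetical load_image callable standing in for the Image.open/resize code used later in this post; only batch_size files are read per batch. Because placeholder_inputs fixes the batch dimension, a final partial batch would have to be dropped or padded; declaring the placeholder with shape=(None, IMAGE_PIXELS) instead accepts any batch size.

```python
from itertools import islice


def iter_batches(filenames, batch_size, load_image):
    """Yield lists of at most batch_size loaded items, reading lazily.

    load_image is a hypothetical callable (e.g. one that opens and resizes
    a file); nothing is loaded until its batch is requested.
    """
    it = iter(filenames)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        yield [load_image(name) for name in chunk]


# Usage with a fake loader that records what it was asked to load.
loaded = []


def fake_load(name):
    loaded.append(name)
    return name.upper()


batches = iter_batches(['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg', 'e.jpg'], 2, fake_load)
first = next(batches)
print(first)   # ['A.JPG', 'B.JPG']
print(loaded)  # ['a.jpg', 'b.jpg']  (later files have not been read yet)
```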

def fill_feed_dict(images_feed, labels_feed, images_pl, labels_pl):
    feed_dict = {
        images_pl: images_feed,
        labels_pl: labels_feed,
    }
    return feed_dict


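The all-important feed_dict is how data gets fed into the graph. It is simply a dictionary that maps each placeholder to the value it should take for this run.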
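To see why a plain dictionary is enough, here is a toy analogy rather than TensorFlow's internals: placeholders are just hashable objects used as dictionary keys, and a hypothetical run function looks up each placeholder's value in feed_dict when it evaluates a computation. The Placeholder class and run are invented for illustration.

```python
class Placeholder:
    """A named slot with no value of its own (toy stand-in for tf.placeholder)."""

    def __init__(self, name):
        self.name = name


def run(fetch, feed_dict):
    """Evaluate fetch, resolving placeholders from feed_dict.

    A missing placeholder raises KeyError, loosely like TensorFlow
    complaining that a placeholder was not fed.
    """
    return fetch(lambda ph: feed_dict[ph])


images_pl = Placeholder('images')
labels_pl = Placeholder('labels')

# Same shape as fill_feed_dict: placeholder object -> concrete data.
feed_dict = {images_pl: [[0.1, 0.2], [0.3, 0.4]], labels_pl: [0, 1]}

# A toy "graph" needing both inputs: number of images and sum of labels.
result = run(lambda get: (len(get(images_pl)), sum(get(labels_pl))), feed_dict)
print(result)  # (2, 1)
```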

def do_eval(sess,
            eval_correct,
            images_placeholder,
            labels_placeholder,
            data_set):
    # And run one epoch of eval.


This function evaluates how well the current parameters perform.

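    true_count = 0  # Counts the number of correct predictions.
    # 4 is the number of training images loaded below.
    steps_per_epoch = 4 // FLAGS.batch_size
    num_examples = steps_per_epoch * FLAGS.batch_size
    for step in range(steps_per_epoch):
        # The whole data_set (all 4 images) is fed on every step, so this
        # assumes FLAGS.batch_size == 4, matching placeholder_inputs(4).
        feed_dict = fill_feed_dict(data_set, train_labels,
                                   images_placeholder,
                                   labels_placeholder)
        true_count += sess.run(eval_correct, feed_dict=feed_dict)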


It is worth looking more closely at what this feed_dict actually is.

    precision = float(true_count) / num_examples  # float() avoids integer division in Python 2
    print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %
          (num_examples, true_count, precision))
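do_eval above feeds the same four images on every step, which only works because there is exactly one batch. Here is a sketch of the more general batch-by-batch loop it follows, in plain Python so it runs anywhere. count_correct is a hypothetical callable standing in for sess.run(eval_correct, feed_dict=...), and the toy data are invented.

```python
def evaluate(images, labels, batch_size, count_correct):
    """Batch-wise precision@1, following the structure of do_eval.

    count_correct is a hypothetical callable: it receives one batch of
    images and labels and returns how many were classified correctly.
    Trailing examples that do not fill a whole batch are skipped.
    """
    steps_per_epoch = len(images) // batch_size
    num_examples = steps_per_epoch * batch_size
    if num_examples == 0:
        raise ValueError('batch_size is larger than the data set')
    true_count = 0
    for step in range(steps_per_epoch):
        start = step * batch_size
        end = start + batch_size
        true_count += count_correct(images[start:end], labels[start:end])
    precision = float(true_count) / num_examples
    return num_examples, true_count, precision


def always_one(batch_images, batch_labels):
    """Fake classifier that always predicts class 1."""
    return sum(1 for y in batch_labels if y == 1)


images = ['a', 'b', 'c', 'd', 'e']
labels = [0, 1, 1, 1, 0]
print(evaluate(images, labels, 2, always_one))  # (4, 3, 0.75)
```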

# Load the images and labels for training.
train_images = []
for filename in ['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg']:
    image = Image.open(FLAGS.train_dir + '/' + filename)
    image = image.resize((IMAGE_SIZE, IMAGE_SIZE))
    train_images.append(np.array(image))


Read the training images into memory.

train_images = np.array(train_images)


Convert the list to a NumPy array.

train_images = train_images.reshape(4,IMAGE_PIXELS)


Flatten each image into a single row, giving one row of IMAGE_PIXELS values per image.
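reshape(4, IMAGE_PIXELS) only works if IMAGE_PIXELS equals the number of values in each image. IMAGE_PIXELS is defined in part 1 and not shown here; note that an RGB image opened with PIL becomes an array of shape (height, width, 3), so IMAGE_PIXELS would need to be IMAGE_SIZE * IMAGE_SIZE * 3, whereas a grayscale ('L' mode) image gives IMAGE_SIZE * IMAGE_SIZE. A dependency-free sketch of the row-major flattening that NumPy's reshape performs by default, on a made-up 2 x 2 RGB image:

```python
def flatten_image(pixels):
    """Flatten a height x width x channels nested list in row-major order,
    the same element order np.reshape uses by default (C order)."""
    return [value for row in pixels for pixel in row for value in pixel]


# A tiny 2 x 2 RGB "image" with made-up values.
image = [[[1, 2, 3], [4, 5, 6]],
         [[7, 8, 9], [10, 11, 12]]]

row = flatten_image(image)
IMAGE_SIZE, CHANNELS = 2, 3
assert len(row) == IMAGE_SIZE * IMAGE_SIZE * CHANNELS  # 12 values per image
print(row)  # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
```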

label = [0,1,1,1]
train_labels = np.array(label)


The ground-truth labels for the four images.

def run_training():
    # Tell TensorFlow that the model will be built into the default Graph.
    with tf.Graph().as_default():
        # Generate placeholders for the images and labels.
        images_placeholder, labels_placeholder = placeholder_inputs(4)

        # Build a Graph that computes predictions from the inference model.
        logits = inference(images_placeholder,
                           FLAGS.hidden1,
                           FLAGS.hidden2)

        # Add to the Graph the Ops for loss calculation.
        loss = cal_loss(logits, labels_placeholder)

        # Add to the Graph the Ops that calculate and apply gradients.
        train_op = training(loss, FLAGS.learning_rate)

        # Add the Op to compare the logits to the labels during evaluation.
        eval_correct = evaluation(logits, labels_placeholder)

        # Create a saver for writing training checkpoints.
        saver = tf.train.Saver()

        # Create a session for running Ops on the Graph.
        sess = tf.Session()


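Note that there are two kinds of session to choose from: Session and InteractiveSession. The only difference is that an InteractiveSession installs itself as the default session when it is created, so eval() and run() use it without being handed a session explicitly; a plain Session becomes the default only inside a with block. Example usage: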

InteractiveSession:

sess = tf.InteractiveSession()
a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
# We can just use 'c.eval()' without passing 'sess'
print(c.eval())
sess.close()


Session:

a = tf.constant(5.0)
b = tf.constant(6.0)
c = a * b
with tf.Session():
    # We can also use 'c.eval()' here.
    print(c.eval())


If that is still unclear, here is another example:

import tensorflow as tf

n_values = 32  # number of evenly spaced points (illustrative value)
x = tf.linspace(-3.0, 3.0, n_values)

# %% Construct a tf.Session to execute the graph.
sess = tf.Session()
result = sess.run(x)

# %% Alternatively pass a session to the eval fn:
x.eval(session=sess)
# x.eval() does not work, as it requires a session!

# %% We can setup an interactive session if we don't
# want to keep passing the session around:
sess.close()
sess = tf.InteractiveSession()

# %% Now this will work!
x.eval()


        # Run the Op to initialize the variables.
        # (Later TensorFlow releases rename this tf.global_variables_initializer().)
        init = tf.initialize_all_variables()
        sess.run(init)


Initialize all variables in the graph.

        # And then after everything is built, start the training loop.
        for step in range(FLAGS.max_steps):
            start_time = time.time()
            feed_dict = fill_feed_dict(train_images, train_labels,
                                       images_placeholder,
                                       labels_placeholder)
            _, loss_value = sess.run([train_op, loss],
                                     feed_dict=feed_dict)


Train: each sess.run call drives exactly one optimization step.

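            duration = time.time() - start_time
            if step % 100 == 0:
                # Print status to stdout.
                print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
            if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
                # 'model.ckpt' is an illustrative file prefix inside train_dir; passing
                # FLAGS.train_dir itself as the prefix writes files such as
                # '<train_dir>-999' next to the directory instead of inside it.
                saver.save(sess, FLAGS.train_dir + '/model.ckpt', global_step=step)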


Save a checkpoint into train_dir every 1000 steps and once more at the final step.
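The hypothetical helper below, written for illustration only, lists the step indices at which the condition above calls saver.save.

```python
def checkpoint_steps(max_steps, every=1000):
    """Return the step indices at which the training loop saves.

    Mirrors: if (step + 1) % every == 0 or (step + 1) == max_steps.
    """
    return [step for step in range(max_steps)
            if (step + 1) % every == 0 or (step + 1) == max_steps]


print(checkpoint_steps(2500))  # [999, 1999, 2499]
print(checkpoint_steps(2000))  # [999, 1999]
```

Because global_step=step is passed, the step number is appended to each checkpoint's file name (for example a -999 suffix), and tf.train.Saver keeps only the five most recent checkpoints unless max_to_keep is changed.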

        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                train_images)