您的位置:首页 > 编程语言

TensorFlow 机器学实战指南示例代码之 TensorFlow 实现随机训练和批量训练

2018-02-07 14:26 495 查看
"""
批量训练
"""
import os
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

s = tf.Session()

# 声明批量训练的数据量的大小
batch_size = 20

# 声明模型的数据、占位符和变量
x_vals = np.random.normal(1, 0.1, 100)
y_vals = np.repeat(10., 100)
x_data = tf.placeholder(shape=[None, 1], dtype=tf.float32)  # 可显式地设置维度为 20
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)  # 也可设置为 None
A = tf.Variable(tf.random_normal(shape=[1, 1]))
B = tf.Variable(tf.random_normal(shape=[1, 1]))

# 初始化变量
init = tf.global_variables_initializer()
s.run(init)

# 在计算图中增加矩阵乘法操作
my_output = tf.matmul(x_data, A)

# 批量训练时,损失函数是每个数据点 L2 损失的平均值
loss = tf.reduce_mean(tf.square(my_output - y_target))

# 声明优化器
my_opt = tf.train.GradientDescentOptimizer(0.02)
train_step = my_opt.minimize(loss)

# 在训练过程中,通过循环迭代优化模型算法

loss_batch = []  # 初始化一个列表,每隔 5 次迭代保存损失函数

for i in range(100):
rand_index = np.random.choice(100, size=batch_size)
rand_x = np.transpose([x_vals[rand_index]])
rand_y = np.transpose([y_vals[rand_index]])
s.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

if (i+1) % 5 == 0:
print('Step # ' + str(i + 1) + ' A = ' + str(s.run(A)))
temp_loss = s.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
print('Loss = ' + str(temp_loss))
loss_batch.append(temp_loss)

# 为防止上一节代码中变量 A 的值的影响, 在进行随机训练时,需要将变量 A 重新初始化
init1 = tf.global_variables_initializer()
s.run(init1)

# 随机损失代码
loss_stochastic = []
for j in range(100):
rand_index = np.random.choice(100)
rand_x = [[x_vals[rand_index]]]
rand_y = [[y_vals[rand_index]]]
s.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

if (j + 1) % 5 == 0:
print('Step # ' + str(j + 1) + ' A = ' + str(s.run(A)))
temp_loss = s.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
print('Loss = ' + str(temp_loss))
loss_stochastic.append(temp_loss)

# 绘制回归算法的随机训练损失和批量训练损失
plt.plot(range(0, 100, 5), loss_stochastic, 'b-', label='Stochastic Loss')
plt.plot(range(0, 100, 5), loss_batch, 'r--', label='Batch Loss, size = 20')
plt.legend(loc='upper right', prop={'size': 11})
plt.show()
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
相关文章推荐