您的位置:首页 > 理论基础 > 计算机网络

TensorFlow 深度学习框架 (2)-- 反向传播优化神经网络

2018-03-16 09:32 731 查看
训练神经网络的过程就是设置神经网络参数的过程,只有经过有效训练的神经网络模型才可以真正的解决分类问题或回归问题。使用监督学习的方式设置神经网络参数需要有一个标注好的训练数据集。监督学习的最重要的思想是,在已知答案的标注数据集上,模型给出的预测结果要尽可能接近真实的答案。
在神经网络优化算法中,最常用的就是反向传播算法。反向传播算法的具体工作流程如图



如图所示是训练的流程图,那么在训练之前,还有一个问题有待解决,就是如何评估模型与标注数据的差距?答案是通过优化损失函数,损失函数就是代表训练数据与标注数据差异的一个指标。针对分类问题,其中最常用的损失函数之一就是交叉熵。关于交叉熵放到后续再仔细讨论,这里知道这么一个概念就可以了。
针对第(1)讲述的那个神经网络结构,训练代码示例和注释 如下import tensorflow as tf
from numpy.random import RandomState

#以下就是第(1)节中的前向传播过程
batch_size = 8
w1 = tf.Variable(tf.random_normal([2,3], stddev = 1,seed = 1))
w2 = tf.Variable(tf.random_normal([3,1], stddev = 1,seed = 1))
x = tf.placeholder(tf.float32, shape = (None,2),name = "x-input")) # shape = (None,2) 代表一个batch 的训练数据
y_ = tf.placeholder(tf.float32,shape = (None,1),name = "y-input")) # 一个batch 的标注数据

#前向传播过程
a = tf.matmul(x,w1)
y = tf.matmul(a,w2)

#定义损失函数(交叉熵)和反向传播算法
cross_entropy = -tf.reduce_mean(
    y_ * tf.log(tf.clip_by_value(y,1e-10,1.0)))
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

#通过随机数生成一个模拟训练数据集及其标注
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size,2)
#定义规则来给出样本的标签,在这里所有 x1 + x2 < 1 的样例被认为是正样本
Y = [[int(x1 + x2 < 1)] for (x1,x2) in X]

#创建会话来运行 TensorFlow 程序
with tf.Session() as sess:
    init_op = tf.initialize_all_variables()
    sess.run(init_op) #变量初始化
    
    #设定训练的轮数
    STEPS = 5000
    for i in range(STEPS):
        #每次选取 batch_size 个样本进行训练
        start = (i * batch_size) % dataset_size
        end = min(start + batch_size,dataset_size)

        #通过选取的样本训练神经网络并更新参数
        sess.run(train_step,feed_dict = {x:X[start:end],y_:Y[start:end]})
        #每 1000 轮观测总体交叉熵的结果
        if i % 1000 == 0:
            total_cross_entropy = sess.run(cross_entropy,
                                            feed_dict = {x:X,y_:Y})
            print("After %d training steps,cross entropy on all data is %g" %(i,total_cross_entropy))
            """
            在运行的过程中,交叉熵越小表明训练随着轮数的增加,训练的结果与真实的差距越来越小
            """
    
    #最后观测训练出来的参数值
    print sess.run(w1)
    print sess.run(w2)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  tensorflow 反向传播