
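[Andrew Ng's Course Programming Assignments] 01 - Neural Networks and Deep Learning - Week 4 - PA1&2 - Building a Deep Neural Network Step by Step, and Its Application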

2018-10-06 15:53



Note

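This article is based on Kulbear's [Building your Deep Neural Network - Step by Step] and [Deep Neural Network - Application], as well as 念师's [8. 多层神经网络代码实战 (multi-layer neural network code in practice)]. I built on those articles with my own understanding, aiming to make Andrew Ng's videos as easy to follow as possible. If anything is wrong, corrections are welcome.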

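The materials used in this article have been uploaded to Baidu Netdisk [click to download]. Please download them before you start, or copy the code for the materials from the bottom of this article.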

[Python version used by the author: 3.6.2]

Before we begin

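Before we start, let's look at what we are going to do. In this tutorial we build two neural networks: a two-layer network and a multi-layer network whose number of layers you can choose yourself. This assignment is harder than the previous ones, but I'll try to explain everything in simple terms. Here is the tricky part in brief: this article uses [LINEAR -> ACTIVATION] forward functions. Suppose a multi-layer network has the structure input layer -> hidden layer -> hidden layer -> ··· -> hidden layer -> output layer. In each layer, I first compute `Z = np.dot(W, A_prev) + b`, where `A_prev` is the previous layer's activation (the input X for the first layer); this step is called [linear_forward]. Then I compute `A = relu(Z)` or `A = sigmoid(Z)`; applying the activation on top of the linear step is what [linear_activation_forward] does. Together they form the computation of one layer, so every layer takes two steps: first compute Z, then compute A. You can also refer to the figure below: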
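The two steps of a layer can be sketched in NumPy as follows, assuming each column of `A_prev` is one example. `layer_linear` and `layer_forward` are hypothetical names used only here; the tutorial's own `linear_forward` and `linear_activation_forward` may take different arguments and also return caches for backpropagation.

```python
import numpy as np

def relu(Z):
    # Element-wise max(0, z)
    return np.maximum(0, Z)

def sigmoid(Z):
    # Element-wise logistic function 1 / (1 + e^(-z))
    return 1.0 / (1.0 + np.exp(-Z))

def layer_linear(A_prev, W, b):
    # Step 1 (linear part): Z = W . A_prev + b; b has shape (n, 1) and
    # NumPy broadcasting adds it to every one of the m columns (examples)
    return np.dot(W, A_prev) + b

def layer_forward(A_prev, W, b, activation):
    # Step 2 (activation part): A = g(Z) with g = relu or sigmoid
    Z = layer_linear(A_prev, W, b)
    if activation == "relu":
        A = relu(Z)
    elif activation == "sigmoid":
        A = sigmoid(Z)
    else:
        raise ValueError("activation must be 'relu' or 'sigmoid'")
    return A, Z

# Example: 3 input features, 4 hidden units, 1 output unit, 5 examples
np.random.seed(1)
X = np.random.randn(3, 5)
W1, b1 = np.random.randn(4, 3) * 0.01, np.zeros((4, 1))
W2, b2 = np.random.randn(1, 4) * 0.01, np.zeros((1, 1))
A1, Z1 = layer_forward(X, W1, b1, "relu")      # hidden layer: ReLU
A2, Z2 = layer_forward(A1, W2, b2, "sigmoid")  # output layer: sigmoid
print(A1.shape, A2.shape)  # (4, 5) (1, 5)
```

Stacking more layers is just repeating `layer_forward`, feeding each layer's A into the next one.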

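Here are the steps:

  1. Initialize the network parameters
  2. Forward propagation
     2.1 Compute the linear part of a layer
     2.2 Compute the activation part (ReLU is used L-1 times, Sigmoid once)
     2.3 Combine the linear part and the activation
  3. Compute the cost
  4. Backward propagation
     4.1 Backward formula for the linear part
     4.2 Backward formula for the activation part
     4.3 Combine the linear and activation backward steps
  5. Update the parameters

Note that every forward function has a corresponding backward function. This is why each step of the forward module stores some values in a cache: those values are needed to compute the gradients, and the backward module uses the cache to compute them. Now let's build the two-layer network and the multi-layer network in turn.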
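The forward/backward pairing through a cache can be sketched for ReLU, assuming the cache of an activation step is simply its input Z. `relu_forward` and `relu_backward_from_cache` are hypothetical names; the `relu`/`relu_backward` functions in the materials package's dnn_utils may organize this differently.

```python
import numpy as np

def relu_forward(Z):
    # Forward: compute A and keep Z in the cache for the backward pass
    A = np.maximum(0, Z)
    cache = Z
    return A, cache

def relu_backward_from_cache(dA, cache):
    # Backward: dZ = dA * g'(Z), where g'(Z) is 1 for Z > 0 and 0 otherwise
    Z = cache
    dZ = np.array(dA, copy=True)
    dZ[Z <= 0] = 0
    return dZ

Z = np.array([[-1.0, 2.0], [3.0, -4.0]])
A, cache = relu_forward(Z)
dZ = relu_backward_from_cache(np.ones_like(Z), cache)
print(dZ)
```

The backward step never recomputes Z; it reads it back from the cache, which is why the forward pass has to store it.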

Preparing the packages

First, we need to import some packages:

    import numpy as np
    import h5py
    import matplotlib.pyplot as plt
    import testCases  # see the materials package, or copy it from the bottom of this article
    from dnn_utils import sigmoid, sigmoid_backward, relu, relu_backward  # see the materials package
    import lr_utils  # see the materials package, or copy it from the bottom of this article

With the packages ready, we start building the parameter-initialization function.

To get the same numbers as mine, you need to set the random seed:

    np.random.seed(1)
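Seeding matters because `np.random.randn` draws from a pseudo-random sequence that restarts whenever the seed is reset. A quick check, as a sketch: assuming nothing else draws random numbers after the seed is set, the first six draws scaled by 0.01 are exactly the W1 values in the initialize_parameters test output of this article.

```python
import numpy as np

np.random.seed(1)
first = np.random.randn(2, 3)

np.random.seed(1)  # re-seeding restarts the same pseudo-random sequence
second = np.random.randn(2, 3)

print(np.array_equal(first, second))  # True
print(first * 0.01)  # the same values as W1 = initialize_parameters(3, 2, 1)["W1"]
```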

Initializing the parameters

For a two-layer neural network, the model structure is LINEAR -> RELU -> LINEAR -> SIGMOID.

The initialization function is as follows:

    def initialize_parameters(n_x, n_h, n_y):
        """
        Initialize the parameters of a two-layer network.

        Arguments:
        n_x - number of units in the input layer
        n_h - number of units in the hidden layer
        n_y - number of units in the output layer

        Returns:
        parameters - a python dictionary containing your parameters:
            W1 - weight matrix, shape (n_h, n_x)
            b1 - bias vector, shape (n_h, 1)
            W2 - weight matrix, shape (n_y, n_h)
            b2 - bias vector, shape (n_y, 1)
        """
        # Small random weights, zero biases
        W1 = np.random.randn(n_h, n_x) * 0.01
        b1 = np.zeros((n_h, 1))
        W2 = np.random.randn(n_y, n_h) * 0.01
        b2 = np.zeros((n_y, 1))

        # Use assertions to make sure the shapes are correct
        assert W1.shape == (n_h, n_x)
        assert b1.shape == (n_h, 1)
        assert W2.shape == (n_y, n_h)
        assert b2.shape == (n_y, 1)

        parameters = {"W1": W1,
                      "b1": b1,
                      "W2": W2,
                      "b2": b2}

        return parameters

With initialization done, let's test it:

    print("==============Testing initialize_parameters==============")
    parameters = initialize_parameters(3, 2, 1)
    print("W1 = " + str(parameters["W1"]))
    print("b1 = " + str(parameters["b1"]))
    print("W2 = " + str(parameters["W2"]))
    print("b2 = " + str(parameters["b2"]))

Test results:

    ==============Testing initialize_parameters==============
    W1 = [[ 0.01624345 -0.00611756 -0.00528172]
     [-0.01072969  0.00865408 -0.02301539]]
    b1 = [[ 0.]
     [ 0.]]
    W2 = [[ 0.01744812 -0.00761207]]
    b2 = [[ 0.]]

That's it for the two-layer network. What about an L-layer network? What does its initialization look like?
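A minimal sketch of the general case, assuming the layer sizes are given as a list `layer_dims` whose first entry is the input size. The function name and the sizes `[12288, 20, 7, 5, 1]` are illustrative assumptions, and the 0.01 scale is simply carried over from the two-layer function. The shapes follow the same rule as in the two-layer case: W^[l] is (n^[l], n^[l-1]) and b^[l] is (n^[l], 1).

```python
import numpy as np

def initialize_parameters_deep_sketch(layer_dims):
    """Hypothetical helper: layer_dims[0] is n_x, layer_dims[l] is n^[l]."""
    parameters = {}
    L = len(layer_dims)  # number of layers, counting the input layer
    for l in range(1, L):
        # W^[l] has shape (n^[l], n^[l-1]); b^[l] has shape (n^[l], 1)
        parameters["W" + str(l)] = np.random.randn(layer_dims[l], layer_dims[l - 1]) * 0.01
        parameters["b" + str(l)] = np.zeros((layer_dims[l], 1))
    return parameters

np.random.seed(1)
params = initialize_parameters_deep_sketch([12288, 20, 7, 5, 1])
for name, value in params.items():
    print(name, value.shape)
```

Looping over `layer_dims` lets the same function handle any depth; only the list changes.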

Assume X (the input data) has shape (12288, 209), i.e. 12288 features per example and 209 examples:

    <tbody><tr>
    <td>  </td>
    <td> Shape of W </td>
    <td> Shape of b </td>
    <td> Activation </td>
    <td> Shape of activation </td>
    </tr><tr>
    <td> Layer 1 </td>
    <td> $(n^{[1]}, 12288)$ </td>
    <td> $(n^{[1]}, 1)$ </td>
    <td> $Z^{[1]} = W^{[1]} X + b^{[1]}$ </td>
    <td> $(n^{[1]}, 209)$ </td>
    </tr><tr>
    <td> Layer 2 </td>
    <td> $(n^{[2]}, n^{[1]})$ </td>
    <td> $(n^{[2]}, 1)$ </td>
    <td> $Z^{[2]} = W^{[2]} A^{[1]} + b^{[2]}$ </td>
    <td> $(n^{[2]}, 209)$ </td>
    </tr><tr>
    
    </tr><tr>
    <td> <span class="MathJax_Preview" style="color: inherit; display: none;"></span><span class="MathJax" id="MathJax-Element-17-Frame" tabindex="0" style="position: relative;" data-mathml="<math xmlns=&quot;http://www.w3.org/1998/Math/MathML&quot;><mo>&amp;#x22EE;</mo></math>" role="presentation"><nobr aria-hidden="true"><span class="math" id="MathJax-Span-143" style="width: 0.36em; display: inline-block;"><span style="display: inline-block; position: relative; width: 0.301em; height: 0px; font-size: 120%;"><span style="position: absolute; clip: rect(0.717em, 1000.24em, 2.443em, -999.997em); top: -2.199em; left: 0em;"><span class="mrow" id="MathJax-Span-144"><span class="mo" id="MathJax-Span-145" style="font-family: MathJax_Main;">⋮</span></span><span style="display: inline-block; width: 0px; height: 2.205em;"></span></span></span><span style="display: inline-block; overflow: hidden; vertical-align: -0.139em; border-left: 0px solid; width: 0px; height: 1.718em;"></span></span></nobr><span class="MJX_Assistive_MathML" role="presentation"><math xmlns="http://www.w3.org/1998/Math/MathML"><mo>⋮</mo></math></span></span><script type="math/tex" id="MathJax-Element-17">\vdots</script> </td>
    <td> <span class="MathJax_Preview" style="color: inherit; display: none;"></span><span class="MathJax" id="MathJax-Element-18-Frame" tabindex="0" style="position: relative;" data-mathml="<math xmlns=&quot;http://www.w3.org/1998/Math/MathML&quot;><mo>&amp;#x22EE;</mo></math>" role="presentation"><nobr aria-hidden="true"><span class="math" id="MathJax-Span-146" style="width: 0.36em; display: inline-block;"><span style="display: inline-block; position: relative; width: 0.301em; height: 0px; font-size: 120%;"><span style="position: absolute; clip: rect(0.717em, 1000.24em, 2.443em, -999.997em); top: -2.199em; left: 0em;"><span class="mrow" id="MathJax-Span-147"><span class="mo" id="MathJax-Span-148" style="font-family: MathJax_Main;">⋮</span></span><span style="display: inline-block; width: 0px; height: 2.205em;"></span></span></span><span style="display: inline-block; overflow: hidden; vertical-align: -0.139em; border-left: 0px solid; width: 0px; height: 1.718em;"></span></span></nobr><span class="MJX_Assistive_MathML" role="presentation"><math xmlns="http://www.w3.org/1998/Math/MathML"><mo>⋮</mo></math></span></span><script type="math/tex" id="MathJax-Element-18">\vdots</script>  </td>
    <td> <span class="MathJax_Preview" style="color: inherit; display: none;"></span><span class="MathJax" id="MathJax-Element-19-Frame" tabindex="0" style="position: relative;" data-mathml="<math xmlns=&quot;http://www.w3.org/1998/Math/MathML&quot;><mo>&amp;#x22EE;</mo></math>" role="presentation"><nobr aria-hidden="true"><span class="math" id="MathJax-Span-149" style="width: 0.36em; display: inline-block;"><span style="display: inline-block; position: relative; width: 0.301em; height: 0px; font-size: 120%;"><span style="position: absolute; clip: rect(0.717em, 1000.24em, 2.443em, -999.997em); top: -2.199em; left: 0em;"><span class="mrow" id="MathJax-Span-150"><span class="mo" id="MathJax-Span-151" style="font-family: MathJax_Main;">⋮</span></span><span style="display: inline-block; width: 0px; height: 2.205em;"></span></span></span><span style="display: inline-block; overflow: hidden; vertical-align: -0.139em; border-left: 0px solid; width: 0px; height: 1.718em;"></span></span></nobr><span class="MJX_Assistive_MathML" role="presentation"><math xmlns="http://www.w3.org/1998/Math/MathML"><mo>⋮</mo></math></span></span><script type="math/tex" id="MathJax-Element-19">\vdots</script>  </td>
    <td> <span class="MathJax_Preview" style="color: inherit; display: none;"></span><span class="MathJax" id="MathJax-Element-20-Frame" tabindex="0" style="position: relative;" data-mathml="<math xmlns=&quot;http://www.w3.org/1998/Math/MathML&quot;><mo>&amp;#x22EE;</mo></math>" role="presentation"><nobr aria-hidden="true"><span class="math" id="MathJax-Span-152" style="width: 0.36em; display: inline-block;"><span style="display: inline-block; position: relative; width: 0.301em; height: 0px; font-size: 120%;"><span style="position: absolute; clip: rect(0.717em, 1000.24em, 2.443em, -999.997em); top: -2.199em; left: 0em;"><span class="mrow" id="MathJax-Span-153"><span class="mo" id="MathJax-Span-154" style="font-family: MathJax_Main;">⋮</span></span><span style="display: inline-block; width: 0px; height: 2.205em;"></span></span></span><span style="display: inline-block; overflow: hidden; vertical-align: -0.139em; border-left: 0px solid; width: 0px; height: 1.718em;"></span></span></nobr><span class="MJX_Assistive_MathML" role="presentation"><math xmlns="http://www.w3.org/1998/Math/MathML"><mo>⋮</mo></math></span></span><script type="math/tex" id="MathJax-Element-20">\vdots</script></td>
    <td> <span class="MathJax_Preview" style="color: inherit; display: none;"></span><span class="MathJax" id="MathJax-Element-21-Frame" tabindex="0" style="position: relative;" data-mathml="<math xmlns=&quot;http://www.w3.org/1998/Math/MathML&quot;><mo>&amp;#x22EE;</mo></math>" role="presentation"><nobr aria-hidden="true"><span class="math" id="MathJax-Span-155" style="width: 0.36em; display: inline-block;"><span style="display: inline-block; position: relative; width: 0.301em; height: 0px; font-size: 120%;"><span style="position: absolute; clip: rect(0.717em, 1000.24em, 2.443em, -999.997em); top: -2.199em; left: 0em;"><span class="mrow" id="MathJax-Span-156"><span class="mo" id="MathJax-Span-157" style="font-family: MathJax_Main;">⋮</span></span><span style="display: inline-block; width: 0px; height: 2.205em;"></span></span></span><span style="display: inline-block; overflow: hidden; vertical-align: -0.139em; border-left: 0px solid; width: 0px; height: 1.718em;"></span></span></nobr><span class="MJX_Assistive_MathML" role="presentation"><math xmlns="http://www.w3.org/1998/Math/MathML"><mo>⋮</mo></math></span></span><script type="math/tex" id="MathJax-Element-21">\vdots</script>  </td>
    </tr><tr>
Layer L-1: W has shape $(n^{[L-1]}, n^{[L-2]})$
Layer L: W has shape $(n^{[L]}, n^{[L-1]})$

It is also worth reviewing how the matrix computation works:

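$$W = \begin{bmatrix} j & k & l \\ m & n & o \\ p & q & r \end{bmatrix} \quad X = \begin{bmatrix} a & b & c \\ d & e & f \\ g & h & i \end{bmatrix} \quad b = \begin{bmatrix} s \\ t \\ u \end{bmatrix} \tag{1}$$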

To compute $WX + b$, proceed as follows:

$$WX + b = \begin{bmatrix} (ja + kd + lg) + s & (jb + ke + lh) + s & (jc + kf + li) + s \\ (ma + nd + og) + t & (mb + ne + oh) + t & (mc + nf + oi) + t \\ (pa + qd + rg) + u & (pb + qe + rh) + u & (pc + qf + ri) + u \end{bmatrix} \tag{2}$$
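A quick way to convince yourself of equation (2) is to let NumPy do the work. The sketch below is only an illustration: the concrete numbers standing in for j to r, a to i, and s, t, u are arbitrary values chosen for this example. Because b has shape (3, 1), NumPy broadcasting adds it to every column of WX, which produces the "+ s / + t / + u" pattern in each row of equation (2).

```python
import numpy as np

# Arbitrary numbers standing in for the symbols of equation (1)
W = np.array([[1.0, 2.0, 3.0],    # j k l
              [4.0, 5.0, 6.0],    # m n o
              [7.0, 8.0, 9.0]])   # p q r
X = np.array([[1.0, 0.0, 2.0],    # a b c
              [0.0, 1.0, 1.0],    # d e f
              [3.0, 1.0, 0.0]])   # g h i
b = np.array([[10.0], [20.0], [30.0]])  # s, t, u as a (3, 1) column

# Vectorized: b is broadcast across the three columns of W X
Z = np.dot(W, X) + b

# Entry by entry, exactly as written in equation (2)
Z_manual = np.empty((3, 3))
for r in range(3):
    for c in range(3):
        Z_manual[r, c] = sum(W[r, k] * X[k, c] for k in range(3)) + b[r, 0]

print(Z)
print(np.allclose(Z, Z_manual))  # True
```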

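In practice you never have to carry this out by hand, because np.dot and NumPy broadcasting take care of it. Here is how the multi-layer initialization is implemented: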

```python
def initialize_parameters_deep(layers_dims):
    """
    Initialize the parameters of a multi-layer network.

    Arguments:
    layers_dims - list containing the number of nodes in each layer of the network

    Returns:
    parameters - dictionary containing "W1", "b1", ..., "WL", "bL":
        Wl - weight matrix of shape (layers_dims[l], layers_dims[l-1])
        bl - bias vector of shape (layers_dims[l], 1)
    """
    np.random.seed(3)
    parameters = {}
    L = len(layers_dims)  # number of layers, counting the input layer

    for l in range(1, L):
        # Weights scaled by 1/sqrt(fan-in); biases start at zero
        parameters["W" + str(l)] = np.random.randn(layers_dims[l], layers_dims[l - 1]) / np.sqrt(layers_dims[l - 1])
        parameters["b" + str(l)] = np.zeros((layers_dims[l], 1))

        # Make sure the shapes are what we expect
        assert parameters["W" + str(l)].shape == (layers_dims[l], layers_dims[l - 1])
        assert parameters["b" + str(l)].shape == (layers_dims[l], 1)

    return parameters
```
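Why divide by np.sqrt(layers_dims[l - 1]) rather than multiply by a fixed small constant? A rough sketch of the reasoning: with unit-variance inputs, each entry of Z sums layers_dims[l - 1] products, so its variance grows with the fan-in; dividing the weights by the square root of the fan-in keeps that variance near 1. The layer sizes and the seed below are hypothetical values picked for this demonstration, and 0.01 is used as the fixed constant for comparison.

```python
import numpy as np

np.random.seed(0)                        # hypothetical seed for this demo
n_prev, n_curr, m = 1000, 100, 500       # hypothetical fan-in, layer width, example count

A_prev = np.random.randn(n_prev, m)      # unit-variance activations from the previous layer

W_sqrt = np.random.randn(n_curr, n_prev) / np.sqrt(n_prev)  # scaling used in initialize_parameters_deep
W_small = np.random.randn(n_curr, n_prev) * 0.01            # fixed small constant for comparison

std_sqrt = np.dot(W_sqrt, A_prev).std()
std_small = np.dot(W_small, A_prev).std()

# Theory: std(Z) is roughly sqrt(n_prev) * std(W), i.e. about 1 and about 0.01 * sqrt(1000)
print("std of Z, 1/sqrt(fan-in) scaling:", std_sqrt)
print("std of Z, 0.01 scaling:          ", std_small)
```

With the fixed 0.01 constant, the scale of Z depends on the width of the previous layer. With the 1/sqrt(fan-in) version it stays close to 1 whatever that width is.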

Let's test it:

```python
# Test initialize_parameters_deep
print("==============Testing initialize_parameters_deep==============")
layers_dims = [5, 4, 3]
parameters = initialize_parameters_deep(layers_dims)
print("W1 = " + str(parameters["W1"]))
print("b1 = " + str(parameters["b1"]))
print("W2 = " + str(parameters["W2"]))
print("b2 = " + str(parameters["b2"]))
```

Test output. Note that these values correspond to weights scaled by 0.01 (np.random.randn(...) * 0.01). With the / np.sqrt(layers_dims[l - 1]) scaling in the code above, the same random draws are instead divided by sqrt(5) for W1 and sqrt(4) for W2, so the entries are larger (the first entry of W1 becomes about 0.80); b1 and b2 are still zero.

```
==============Testing initialize_parameters_deep==============
W1 = [[ 0.01788628  0.0043651   0.00096497 -0.01863493 -0.00277388]
[-0.00354759 -0.00082741 -0.00627001 -0.00043818 -0.00477218]
[-0.01313865  0.00884622  0.00881318  0.01709573  0.00050034]
[-0.00404677 -0.0054536  -0.01546477  0.00982367 -0.01101068]]
b1 = [[ 0.]
[ 0.]
[ 0.]
[ 0.]]
W2 = [[-0.01185047 -0.0020565   0.01486148  0.00236716]
[-0.01023785 -0.00712993  0.00625245 -0.00160513]
[-0.00768836 -0.00230031  0.00745056  0.01976111]]
b2 = [[ 0.]
[ 0.]
[ 0.]]
```

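We have now written initialization functions for both the two-layer and the multi-layer network. Next we build the forward propagation functions.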

Forward propagation functions

Forward propagation consists of the following three steps:

• LINEAR
• LINEAR -> ACTIVATION, where the activation function is either ReLU or Sigmoid
• [LINEAR -> RELU] × (L-1) -> LINEAR -> SIGMOID (the whole model)

The linear forward module (vectorized over all examples) computes equation (3), where $A^{[0]} = X$:

$$Z^{[l]} = W^{[l]}A^{[l-1]} + b^{[l]} \tag{3}$$
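"Vectorized over all examples" means that a single matrix product processes every column of $A^{[l-1]}$ (one column per example) at once. The sketch below, which uses hypothetical layer sizes, checks that equation (3) applied to the whole matrix gives the same result as looping over the examples one at a time.

```python
import numpy as np

np.random.seed(1)
n_prev, n_curr, m = 4, 3, 5            # hypothetical layer sizes and number of examples
A_prev = np.random.randn(n_prev, m)    # one column per example
W = np.random.randn(n_curr, n_prev)
b = np.random.randn(n_curr, 1)

# Vectorized: a single matrix product for all m examples
Z_vec = np.dot(W, A_prev) + b

# Same formula applied to one example (column) at a time
Z_loop = np.zeros((n_curr, m))
for i in range(m):
    a_i = A_prev[:, i:i + 1]           # slicing with i:i+1 keeps shape (n_prev, 1)
    Z_loop[:, i:i + 1] = np.dot(W, a_i) + b

print(Z_vec.shape)                     # (3, 5)
print(np.allclose(Z_vec, Z_loop))      # True
```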

Linear part (LINEAR)

In forward propagation, the linear part is computed as follows:

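```python
def linear_forward(A, W, b):
    """
    Implement the linear part of a layer's forward propagation.

    Arguments:
    A - activations from the previous layer (or the input data), shape (size of previous layer, number of examples)
    W - weight matrix, numpy array of shape (size of current layer, size of previous layer)
    b - bias vector, numpy array of shape (size of current layer, 1)

    Returns:
    Z - the input of the activation function, also called the pre-activation parameter
    cache - a tuple (A, W, b), stored so that the backward pass can be computed efficiently
    """
    Z = np.dot(W, A) + b
    assert Z.shape == (W.shape[0], A.shape[1])
    cache = (A, W, b)

    return Z, cache
```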

Let's test the linear part:

```python
# Test linear_forward
print("==============Testing linear_forward==============")
A, W, b = testCases.linear_forward_test_case()
Z, linear_cache = linear_forward(A, W, b)
print("Z = " + str(Z))
```

Test output:

```
==============Testing linear_forward==============
Z = [[ 3.26295337 -1.23429987]]
```

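That completes half of a single layer's forward computation! Now let's build the second half. If you are not sure what this refers to, scroll back up to the "Before we start" section and read it again.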

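Linear-activation part (LINEAR -> ACTIVATION)

For convenience, we group the two functions (linear and activation) into a single LINEAR -> ACTIVATION function, which performs the LINEAR forward step and then the ACTIVATION forward step. Here is the math for the activation functions: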

• Sigmoid: $\sigma(Z) = \sigma(WA + b) = \dfrac{1}{1 + e^{-(WA + b)}}$
• ReLU: $A = \mathrm{ReLU}(Z) = \max(0, Z)$
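The code in this section imports sigmoid and relu from dnn_utils (in the download package). The way linear_activation_forward calls them shows that each one returns a pair (A, cache). The following is a minimal sketch of helpers with that interface. It assumes the cache simply stores Z for later use in backpropagation, and the actual dnn_utils implementation may differ in detail.

```python
import numpy as np

def sigmoid(Z):
    """Sketch: element-wise sigmoid. Returns (A, cache) where cache is Z."""
    A = 1 / (1 + np.exp(-Z))
    cache = Z
    return A, cache

def relu(Z):
    """Sketch: element-wise ReLU, max(0, Z). Returns (A, cache) where cache is Z."""
    A = np.maximum(0, Z)
    cache = Z
    return A, cache

Z = np.array([[-2.0, 0.0, 3.0]])
A_sig, cache_sig = sigmoid(Z)
A_relu, cache_relu = relu(Z)
print(A_sig)    # middle entry is sigmoid(0) = 0.5
print(A_relu)   # negative entries become 0
```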

To implement the LINEAR -> ACTIVATION step we use $A^{[l]} = g(Z^{[l]}) = g(W^{[l]}A^{[l-1]} + b^{[l]})$, where the function $g$ is either sigmoid() or relu(). Only the output layer uses sigmoid(). Now let's build the linear-activation forward function.

```python
def linear_activation_forward(A_prev, W, b, activation):
    """
    Implement forward propagation for one LINEAR -> ACTIVATION layer.

    Arguments:
    A_prev - activations from the previous layer (or the input data), shape (size of previous layer, number of examples)
    W - weight matrix, numpy array of shape (size of current layer, size of previous layer)
    b - bias vector, numpy array of shape (size of current layer, 1)
    activation - name of the activation used in this layer, a string: "sigmoid" or "relu"

    Returns:
    A - output of the activation function, also called the post-activation value
    cache - a tuple (linear_cache, activation_cache), stored so that the backward pass can be computed efficiently
    """
    if activation == "sigmoid":
        Z, linear_cache = linear_forward(A_prev, W, b)
        A, activation_cache = sigmoid(Z)
    elif activation == "relu":
        Z, linear_cache = linear_forward(A_prev, W, b)
        A, activation_cache = relu(Z)
    else:
        # Fail loudly instead of hitting an undefined variable below
        raise ValueError('activation must be "sigmoid" or "relu", got %r' % activation)

    assert A.shape == (W.shape[0], A_prev.shape[1])
    cache = (linear_cache, activation_cache)

    return A, cache
```

Let's test it:

```python
# Test linear_activation_forward
print("==============Testing linear_activation_forward==============")
A_prev, W, b = testCases.linear_activation_forward_test_case()

A, linear_activation_cache = linear_activation_forward(A_prev, W, b, activation="sigmoid")
print("sigmoid, A = " + str(A))

A, linear_activation_cache = linear_activation_forward(A_prev, W, b, activation="relu")
print("ReLU, A = " + str(A))
```

Test output:

```
==============Testing linear_activation_forward==============
sigmoid, A = [[ 0.96890023  0.11013289]]
ReLU, A = [[ 3.43896131  0.        ]]
```

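We have now written the forward functions that the two-layer model needs. What about the multi-layer model? We implement it by calling the two functions above. To make the L-layer network convenient to build, we need a function that repeats the previous one (linear_activation_forward with RELU) L-1 times and then follows it with a single linear_activation_forward with SIGMOID. Its structure looks like this: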
Figure 2: [LINEAR -> RELU] × (L-1) -> LINEAR -> SIGMOID model

In the code below, AL denotes $A^{[L]} = \sigma(Z^{[L]}) = \sigma(W^{[L]}A^{[L-1]} + b^{[L]})$, i.e. the prediction $\hat{Y}$.

The forward propagation code for the multi-layer model is as follows:

```python
def L_model_forward(X, parameters):
    """
    Implement forward propagation for [LINEAR -> RELU] * (L-1) -> LINEAR -> SIGMOID,
    i.e. the forward pass of the multi-layer network: every layer runs LINEAR and then ACTIVATION.

    Arguments:
    X - data, numpy array of shape (number of input nodes, number of examples)
    parameters - output of initialize_parameters_deep()

    Returns:
    AL - the last post-activation value
    caches - list of caches containing:
        every cache of linear_activation_forward() with "relu" (L-1 of them, indexed 0 to L-2)
        the cache of linear_activation_forward() with "sigmoid" (only one, indexed L-1)
    """
    caches = []
    A = X
    L = len(parameters) // 2  # number of layers: each layer has one W and one b

    # [LINEAR -> RELU] for layers 1 .. L-1
    for l in range(1, L):
        A_prev = A
        A, cache = linear_activation_forward(A_prev, parameters["W" + str(l)], parameters["b" + str(l)], "relu")
        caches.append(cache)

    # LINEAR -> SIGMOID for the output layer L
    AL, cache = linear_activation_forward(A, parameters["W" + str(L)], parameters["b" + str(L)], "sigmoid")
    caches.append(cache)

    assert AL.shape == (1, X.shape[1])

    return AL, caches
```

Let's test it:

```python
# Test L_model_forward
print("==============Testing L_model_forward==============")
X, parameters = testCases.L_model_forward_test_case()
AL, caches = L_model_forward(X, parameters)
print("AL = " + str(AL))
print("length of caches = " + str(len(caches)))
```

Test output:

```
==============Testing L_model_forward==============
AL = [[ 0.17007265  0.2524272 ]]
length of caches = 2
```
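Why does the test report a caches length of 2? L_model_forward appends one cache per layer, and it derives the number of layers from len(parameters) // 2 because every layer contributes exactly one W and one b. The sketch below uses a hypothetical helper, layer_schedule, on a hypothetical 3-layer network to show the layer count and the activation each layer receives.

```python
import numpy as np

def layer_schedule(parameters):
    """Hypothetical helper mirroring L_model_forward: the number of layers
    and the activation applied in each of them."""
    L = len(parameters) // 2               # one "W" and one "b" per layer
    activations = ["relu"] * (L - 1) + ["sigmoid"]
    return L, activations

# Hypothetical 3-layer network: 5 inputs -> 4 -> 3 -> 1 output
layers_dims = [5, 4, 3, 1]
parameters = {}
for l in range(1, len(layers_dims)):
    parameters["W" + str(l)] = np.zeros((layers_dims[l], layers_dims[l - 1]))
    parameters["b" + str(l)] = np.zeros((layers_dims[l], 1))

L, activations = layer_schedule(parameters)
print(L)             # 3
print(activations)   # ['relu', 'relu', 'sigmoid']
```

L_model_forward produces one cache for each entry of this list, so len(caches) equals L. The test case above therefore uses a 2-layer network.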

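Computing the cost

We have now completed the forward pass for both models. Next we compute the cost (error), so that we can tell whether the model is actually learning. The cost is computed as: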

$$-\frac{1}{m}\sum_{i=1}^{m}\left(y^{(i)}\log\left(a^{[L](i)}\right) + \left(1 - y^{(i)}\right)\log\left(1 - a^{[L](i)}\right)\right) \tag{4}$$

```python
def compute_cost(AL, Y):
    """
    Implement the cost function defined by equation (4).

    Arguments:
    AL - probability vector corresponding to the label predictions, shape (1, number of examples)
    Y - true label vector (e.g. 0 if not a cat, 1 if a cat), shape (1, number of examples)

    Returns:
    cost - cross-entropy cost
    """
    m = Y.shape[1]
    cost = -np.sum(np.multiply(np.log(AL), Y) + np.multiply(np.log(1 - AL), 1 - Y)) / m

    cost = np.squeeze(cost)  # make sure the cost is a scalar, e.g. [[17]] becomes 17
    assert cost.shape == ()

    return cost
```

Let's test it:

```python
# Test compute_cost
print("==============Testing compute_cost==============")
Y, AL = testCases.compute_cost_test_case()
print("cost = " + str(compute_cost(AL, Y)))
```

Test output:

```
==============Testing compute_cost==============
cost = 0.414931599615
```
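You can check equation (4) by hand on a tiny example. The sketch below also shows an optional safeguard that is not part of the original code. If an entry of AL is exactly 0 or 1, np.log returns -inf and the cost becomes inf (or nan). Clipping AL into [eps, 1 - eps] prevents this. compute_cost_stable and eps are hypothetical names chosen for this illustration.

```python
import math
import numpy as np

def compute_cost_stable(AL, Y, eps=1e-12):
    """Hypothetical variant of compute_cost: clip AL into [eps, 1 - eps]
    so that np.log never receives 0."""
    m = Y.shape[1]
    AL = np.clip(AL, eps, 1 - eps)
    cost = -np.sum(Y * np.log(AL) + (1 - Y) * np.log(1 - AL)) / m
    return float(np.squeeze(cost))

# Tiny example: by hand, the cost is -(log(0.8) + log(0.9)) / 2
Y = np.array([[1, 0]])
AL = np.array([[0.8, 0.1]])
expected = -(math.log(0.8) + math.log(0.9)) / 2
print(compute_cost_stable(AL, Y), expected)

# A prediction of exactly 0 for a positive example: the unclipped formula gives inf,
# the clipped one gives the large but finite value -log(eps)
print(compute_cost_stable(np.array([[0.0]]), np.array([[1]])))
```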

Now that we have computed the cost, we can move on to backpropagation.
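Backpropagation starts from the derivative of the cost with respect to the output AL. Differentiating equation (4) with respect to one entry $a^{[L](i)}$ gives $-\frac{1}{m}\left(\frac{y^{(i)}}{a^{[L](i)}} - \frac{1 - y^{(i)}}{1 - a^{[L](i)}}\right)$. This sketch assumes a common convention: dAL is defined without the $1/m$ factor, and $1/m$ is applied later when computing dW and db. The code compares that analytic dAL (divided by m) with a central finite-difference estimate. The labels and predictions are made-up data for the example.

```python
import numpy as np

def cost(AL, Y):
    """Cross-entropy cost of equation (4)."""
    m = Y.shape[1]
    return -np.sum(Y * np.log(AL) + (1 - Y) * np.log(1 - AL)) / m

Y = np.array([[1.0, 0.0, 1.0]])          # made-up labels
AL = np.array([[0.7, 0.2, 0.4]])         # made-up predictions
m = Y.shape[1]

# Analytic derivative of the per-example loss w.r.t. AL (no 1/m factor, by the assumed convention)
dAL = -(np.divide(Y, AL) - np.divide(1 - Y, 1 - AL))

# Central finite differences on the averaged cost (these include the 1/m factor)
h = 1e-6
numeric = np.zeros_like(AL)
for i in range(AL.shape[1]):
    AL_plus, AL_minus = AL.copy(), AL.copy()
    AL_plus[0, i] += h
    AL_minus[0, i] -= h
    numeric[0, i] = (cost(AL_plus, Y) - cost(AL_minus, Y)) / (2 * h)

print(dAL / m)
print(numeric)
print(np.allclose(dAL / m, numeric, atol=1e-6))  # True
```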

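Backpropagation

Backpropagation computes the gradient of the loss function with respect to the parameters. Here is the flowchart of forward and backward propagation:

With the flowchart in mind, let's look at the formulas for the linear part:

We need to use $dZ^{[l]}$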
