您的位置:首页 > 理论基础 > 计算机网络

神经网络入门 Python 十行核心代码

2017-03-26 22:01 316 查看
#!/usr/bin/env python
# --*-- coding:utf-8 --*--
import time
import datetime
import numpy as np
import matplotlib as mpl
import sys
import matplotlib.pyplot as plt

X = np.array([ [0,0,1],[0,1,1],[1,0,1],[1,1,1] ])
y = np.array([[0,1,1,0]]).T
syn0 = 2*np.random.random((3,4)) - 1
syn1 = 2*np.random.random((4,1)) - 1
l1=0
l2=0
for j in range(1000):
l1 = 1/(1+np.exp(-(np.dot(X,syn0))))
l2 = 1/(1+np.exp(-(np.dot(l1,syn1))))
l2_delta = (y - l2)*(l2*(1-l2))
l1_delta = l2_delta.dot(syn1.T) * (l1 * (1-l1))
syn1 += l1.T.dot(l2_delta)
syn0 += X.T.dot(l1_delta)
t = np.arange(len(l2))
mpl.rcParams['font.sans-serif'] = [u'simHei']
mpl.rcParams['axes.unicode_minus'] = False
plt.plot(t, l2, 'g-', linewidth=2, label=u'预测数据')
plt.plot(t, y, 'r-', linewidth=2, label=u'真实数据')
plt.axis([-0.5,3.5,-0.5,1.5])
plt.title(u'神经网络', fontsize=18)
plt.legend(loc='upper right')
plt.grid()
plt.show()
print(l2)
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签:  python 神经网络