您的位置:首页 > 理论基础 > 计算机网络

tensorflow之安装及简单神经网络搭建

2017-09-02 11:05 603 查看
一.linux/bantu系统在python2.7安装tensorflor

1.安装pip:

sudo apt-get install python-pippython-dev


2.安装链接:

sudo pip install –upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl[/code] 
3.安装检验:

python
import tensorflow


4000
二. 搭建神经网络

1.简单的线性拟合

#encoding=utf-8
import tensorflow as tf
import numpy as np
#创建数据
x_real = np.random.rand(100).astype(np.float32)#产生随机数据x,类型为float32
y_real = x_real*0.1+0.3 #真实值y,线性函数生成
#创建tensorflow的架构
w_pre= tf.Variable(tf.random_uniform([1],-1.0,1.0))#产生随机一维变量,范围[-1.0,1.0],为需要学习的权值,实际0.1
b_pre = tf.Variable(tf.zeros([1])) #一维变量,初值为0,为需要学习的偏置,实际训练接近0.3
y_pre = w_pre*x_real + b_pre #线性模拟真实值,得到预测值
loss = tf.reduce_mean(tf.square(y_pre - y_real)) #计算预测值与真实值的差值,最小平方差均值
optimizer = tf.train.GradientDescentOptimizer(0.5)#神经网络的优化器,学习效率0.5
train = optimizer.minimize(loss) #优化器使误差最小化
init = tf.initialize_all_variables() #初始化变量
#初始化
sess = tf.Session()
sess.run(init) #激活神经网络,非常重要,完成初始化激活
#开始训练
for step in range(201):
sess.run(train) #sess指针指向run,并开始train
if step % 20 == 0:
print(step,sess.run(w_pre),sess.run(b_pre))


2.session的两种打开方式:Session()

#encoding=utf-8
import tensorflow as tf
#产生数据
matrix1 = tf.constant([[1,2]])     #一行两列常量矩阵
matrix2 = tf.constant([[3],[4]])    #两行一列常量矩阵
result = tf.matmul(matrix1,matrix2)     #矩阵乘法


#方法一
sess = tf.Session()
result1 = sess.run(result)
print('第一种方法的结果:')
print(result1)
sess.close()#关闭session


#方法2
with tf.Session() as sess:
result2 = sess.run(result)
print('第二种方法的结果:')
print(result2)


3.定义变量:tf.Variable

#encoding=utf-8
import tensorflow as tf

S_name = tf.Variable(0,name = 'counter')#变量名为counter
print(S_name.name)#打印变量,结果为counter:0
one = tf.constant(1)  #产生常量1
value = tf.add(S_name,one)  #原始数据+1
update = tf.assign(S_name,value)#把value加载到S_name
init = tf.initialize_all_variables() #初始化所有变量
with tf.Session() as sess:
sess.run(init)#激活方法,必须run之后才算激活成功,将所有变量激活,定义变量必须激活
for _ in range(3):
sess.run(update)
print(sess.run(S_name))


4.传入值:placeholder

函数格式:tf.placeholder(dtype,shape=None,name=None)

dtype:函数类型,常用tf.float32,还要tf.float64

shape:数据形状,默认None代表一维数值,也可以是多维,如[2,4]表示2行4列,也可以不定[None,3],表示不定行,3列

name:表示名称

如下为一个简单的应用:

#encoding=utf-8
import tensorflow as tf
#定义数据
in_1 = tf.placeholder(tf.float32) #传入值,为float32类型,placeholder获取值
in_2 = tf.placeholder(tf.float32)
output = tf.mul(in_1,in_2)#数值乘法运算
with tf.Session() as sess:
print(sess.run(output,feed_dict={in_1:[7.],in_2:[2.]}))#feed_dict传入值,传进去值为字典类型


5.激励函数

所谓激励寒素,就是将计算结果放到一个小范围,用于判断标签,范围一般为[-1,1]或[0,1]

https://www.tensorflow.org/api_docs/python/nn.html

这个网址里面有说到tensorflow里面常见的几种激励函数,可以作为了解

6.导入画图工具matplotlib

python可视化库有两种显示模式,一种是block阻塞模式,一种是interactive交互模式,在python的consol命令中,默认交互模式,在python脚本中,默认阻塞模式。区别在于:

(1)在交互模式下,plt.plot(x)或plt.imshow(x)是直接出图像,不需要plt.show()。一般如果在脚本中开启了plt.ion()模式,那么必须要在plt.show()之前通过plt.ioff()关闭交互模式,不然图像会一闪而过,当然,也可以通过plt.pause(10),其中10表示10秒

(2)在阻塞模式下,打开一个窗口必须关掉以后才能打开一个新的窗口,且plt.plot(x)或plt.imshow(x)是直接出图像,需要plt.show()进行显示。

#encoding=utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
#create data
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data*0.1+0.3
init = tf.initialize_all_variables() #初始化变量
#初始化
sess = tf.Session()
sess.run(init) #激活神经网络,非常重要
#画图,交互模式
plt.ion()
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data,y_data)
lines = ax.plot(x_data,y_data,'r',lw=2)
plt.ioff()
plt.show()


7.优化器,可参考一些其他博客的详述:http://blog.csdn.net/u012759136/article/details/52302426

(1)梯度下降方法GradientDescent

所谓梯度下降方法就是利用负梯度方向来觉得每次迭代的新的搜索方向,使得每次迭代能使待优化的目标函数逐步减小。梯度下降法是2范数下的最速下降法,梯度则是函数的偏导数。常用的最速下降是X(k+1)=X(k)-a*g(k),其中a为学习速率。

(2)随机梯度算法(stochastic gradient descent)SGD

所谓随机梯度算法每次选取训练样本中的一个样本进行学习,优点在于学习快速,可进行在线更新,缺点在于每次更新的方向不一定按照正确的方向进行,可能引起扰动。

个人觉得这边博客对随机梯度下降讲解比较详细:http://www.sohu.com/a/131923387_473283

(3)Adagrad对学习速率进行了一个约束。Adadelta是Adgrad的扩展,不仅在学习速率进行约束,还对计算进行简化。RMSprop是Adadelta的一个特例。Adam是带有动量项的RMSprop,利用梯度的一阶矩估计和二阶矩估计来动态调整每个参数的学习率。

(4)Momentum模拟物理动力,积累之前的动量来代替真正的梯度

(5)FTL算法:http://blog.csdn.net/wenzishou/article/details/73558017
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: