您的位置:首页 > 编程语言 > PHP开发

tensorflow基础学习 非线性回归实现,matplotlib可视化结果

2018-11-25 00:14 447 查看

开发平台:win10+py3.6 64bit

使用的工具以及库:

  • pycharm(全宇宙唯一一款专门用做python开发的工具)功能强大
  • tensorflow-gpu 1.9 (win10下配置gpu环境过程多坑,下次补配置教程)
  • tensorflow cpu版也是可以的
  • numpy
  • matplotlib

下面进入正题(代码模块有详细注释):

话不多说,先上个效果图

1. show code

# coding: utf-8
# 导入对应的库
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# 使用numpy生成在 [-0.5, 0.5] 内 **均匀分布**的 200个随机点
# linspace():返回带有数据的一个一维的ndarry
# [:, np.newaxis], 将这一维数组转化为一个 [200, 1]的ndarry, 也就是200行,1列的数组
x_data = np.linspace(-0.5, 0.5, num=200)[:, np.newaxis]

# 生成随机的噪声,不然就是一个很漂亮的曲线了
noise = np.random.normal(0, 0.02, x_data.shape)
y_data = np.square(x_data) + noise

如图, (x, y)

下面是运算时的网络结构,只写了一层隐藏层,由于数据很简单也没必要多写,大家可以自己尝试去加一下

# 定义两个placeholder,用于后面训练的时候实时提供数据
x = tf.placeholder(tf.float32, [None, 1])
y = tf.placeholder(tf.float32, [None, 1])

# 定义神经网络中间层
# 设置神经元的个数[1, 10], 1个特征值输入, 10 个神经元运算
# 一个样本输出10个值
Weights_L1 = tf.Variable(tf.random_normal([1, 10]))
biases_L1 = tf.Variable(tf.zeros([1, 10]))
Wx_plus_b_L1 = tf.matmul(x, Weights_L1) + biases_L1
# Wx_plus_b_L1: [None, 10] None个样本,10个输出值
# biases_L1:[1,10], 给每一个样本计算的值加上偏置
# 激活函数: 使用 tanh
L1 = tf.nn.tanh(Wx_plus_b_L1)

# 定义神经网络输出层
# 因为输出只有一个预测值, 上一层的输出[none, 10], 要求输出[none,]
# 也就是一个样本输出一个结果,none个就有none行

Weights_L2 = tf.Variable(tf.random_normal([10, 1]))
biases_L2 = tf.Variable(tf.zeros([1, 1]))
Wx_plus_b_L2 = tf.matmul(L1, Weights_L2) + biases_L2
prediction = tf.nn.tanh(Wx_plus_b_L2)

# 损失函数, 使用方差计算,以后会用到交叉熵损失计算等
loss = tf.reduce_mean(tf.square(y - prediction))

# 使用梯度下降法训练, 梯度下降就是找损失函数剪小最快的方向去修改权重和偏置
# 0.1是学习率, 可以简单的理解为,每次修改权重和偏置这些参数修改的幅度大小
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

开启会话训练,并且画图

with tf.Session() as sess:
# 变量初始化
sess.run(tf.global_variables_initializer())
# 训练2000步, 传入之前生成的数据
for _ in range(2000):
sess.run(train_step, feed_dict={x: x_data, y: y_data})

# 获得预测值
prediction_value = sess.run(prediction, feed_dict={x: x_data})
# 画图
plt.figure()
plt.scatter(x_data, y_data)
plt.plot(x_data, prediction_value, 'r-', lw=5)
plt.show()

小结

这是我的第一篇博客,感觉来得有点晚,毕竟学了也有一段时间了,以后会接着更新python爬虫,数据分析,数据挖掘,验证码识别这一类的学习经历,当然还会有对应的环境搭配等,欢迎大家多多指点,大家相互交流学习。喜欢的话可以点一下关注,不会迷路。

阅读更多
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: