您的位置:首页 > 其它

懒阳的深度学习日记tf2.0线性回归

2020-04-02 18:30 931 查看

tf-线性回归模型

TensorFlow2.0——简单线性回归
关于初次使用tensorflow遇到的坑
tf2.0 实现线性回归1.0版本的好多API函数都不能用了太难受了,自己还是入门,一步一步探索。本次主要是预测w,以及参数b。预测和自己规定的数据相差不多,效果还可以。自己也会慢慢努力学习,加油。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
#随机生成1000个点,围绕在 y=0.1x + 0.3
num_point =1000
vector_set=[]
for i in range(num_point):
x1 = np.random.normal(0.0,0.55)
y1 = x1 * 0.1 +0.3 + np.random.normal(0.0,0.03)
vector_set.append([x1,y1])
x_data = [v[0] for v in vector_set]
y_data = [v[1] for v in vector_set]

plt.scatter(x_data,y_data,c='r')
plt.show()
#print(len(x_data))
#print(y_data)
w=tf.Variable(tf.random.uniform((1,), -1.0, 1.0))
print(w)
b=tf.Variable(tf.zeros([1]),name='b')
#y=w*x_data +b
#print(y)
losses = []
#以预估值y和真实值y_data之间的均方误差作为损失
#loss = tf.reduce_mean(tf.square(y-y_data),name='loss')
opt = tf.keras.optimizers.SGD(1e-1)
for i in range(1000):
loss = lambda: tf.losses.MeanSquaredError()(w*x_data+b, y_data)
#采用梯度下降法来优化参数
opt.minimize(loss, var_list=[w,b])
losses.append(loss().numpy())
print(w)
print(b)
  • 点赞
  • 收藏
  • 分享
  • 文章举报
逐原之鹿_懒阳 发布了2 篇原创文章 · 获赞 0 · 访问量 24 私信 关注
内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: