0503-autograd实战之线性回归
2021-04-24 17:39
120 查看
0503-autograd实战之线性回归
[TOC]
pytorch完整教程目录:https://www.cnblogs.com/nickchen121/p/14662511.html
一、用 variable 实现线性回归(autograd 实战)
import torch as t from torch.autograd import Variable as V # 不是 jupyter 运行请注释掉下面一行,为了 jupyter 显示图片 %matplotlib inline from matplotlib import pyplot as plt from IPython import display t.manual_seed(1000) # 随机数种子 def get_fake_data(batch_size=8): """产生随机数据:y = x * 2 + 3,同时加上了一些噪声""" x = t.rand(batch_size, 1) * 20 y = x * 2 + (1 + t.randn(batch_size, 1)) * 3 # 噪声为 |3-((1 + t.randn(batch_size, 1)) * 3)| return x, y # 查看 x,y 的分布情况 x, y = get_fake_data() plt.scatter(x.squeeze().numpy(), y.squeeze().numpy()) plt.show()
# 随机初始化参数 w = V(t.rand(1, 1), requires_grad=True) b = V(t.zeros(1, 1), requires_grad=True) lr = 0.001 # 学习率 for i in range(8000): x, y = get_fake_data() x, y = V(x), V(y) # forwad:计算 loss y_pred = x.mm(w) + b.expand_as(y) loss = 0.5 * (y_pred - y)**2 loss = loss.sum() # backward:自动计算梯度 loss.backward() # 更新参数 w.data.sub_(lr * w.grad.data) b.data.sub_(lr * b.grad.data) # 梯度清零,不清零则会进行叠加,影响下一次梯度计算 w.grad.data.zero_() b.grad.data.zero_() if i % 1000 == 0: # 画图 display.clear_output(wait=True) x = t.arange(0, 20, dtype=t.float).view(-1, 1) y = x.mm(w.data) + b.data.expand_as(x) plt.plot(x.numpy(), y.numpy(), color='red') # 预测效果 x2, y2 = get_fake_data(batch_size=20) plt.scatter(x2.numpy(), y2.numpy(), color='blue') # 真实数据 plt.xlim(0, 20) plt.ylim(0, 41) plt.show() plt.pause(0.5) break # 注销这一行,可以看到动态效果
# y = x * 2 + 3 w.data.squeeze(), b.data.squeeze() # 打印训练好的 w 和 b
(tensor(2.3009), tensor(0.1634))
二、第五章总结
本章介绍了 torch 的一个核心——autograd,其中 autograd 中的 variable 和 Tensor 都属于 torch 中的基础数据结构,variable 封装了 Tensor ,拥有着和 Tensor 几乎一样的接口,并且提供了自动求导技术。autograd 是 torch 的自动微分引擎,采用动态计算图的技术,可以更高效的计算导数。
这篇文章说实话是有点偏难的,可以多看几遍,尤其是对于还没有写过实际项目的小白,不过相信有前面几个小项目练手,以及最后一个线性回归的小 demo,相信你差也不差的能看懂,但这都不要紧,在未来的项目实战中,你将会对 autograd 的体会越来越深刻。
相关文章推荐
- 机器学习实战线性回归局部加权线性回归笔记
- 深度学习入门实战(二)-用TensorFlow训练线性回归
- Python线性回归实战分析
- SparkML实战之一:线性回归
- 通俗得说线性回归算法(二)线性回归实战
- 线性回归实战
- 机器学习实战--8.预测数值型数据:线性回归
- 机器学习MatLab实战整理--线性回归
- PyTorch:线性回归和逻辑回归实战
- 机器学习实战之线性回归+局部加权线性回归
- machine learning 线性回归实战
- 线性回归与逻辑回归实战
- 机器学习入门+实战初级(一)—— 线性回归
- 学习笔记【机器学习重点与实战】——1 线性回归
- Python机器学习实战--线性回归
- MXnet实战之线性回归
- 14. 实战:多元线性回归程序示例
- MLlib线性回归实战
- TensorFlow学习笔记(4):线性回归,TensorFlow实战
- 通俗得说线性回归算法(二)线性回归实战