您的位置:首页 > 其它

[MXNet逐梦之旅]练习一·使用MXNet拟合直线手动实现

2019-01-14 15:59 393 查看

[MXNet逐梦之旅]练习一·使用MXNet拟合直线手动实现

  • code
[code]#%%
from matplotlib import pyplot as plt
from mxnet import autograd, nd
import random

#%%
num_inputs = 1
num_examples = 100
true_w = 1.56
true_b = 1.24
features = nd.arange(0,10,0.1).reshape((-1, 1))
labels = true_w * features + true_b
labels += nd.random.normal(scale=0.2, shape=labels.shape)

features[0], labels[0]

#%%
# 本函数已保存在d2lzh包中方便以后使用
def data_iter(batch_size, features, labels):
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices)  # 样本的读取顺序是随机的
for i in range(0, num_examples, batch_size):
j = nd.array(indices[i: min(i + batch_size, num_examples)])
yield features.take(j), labels.take(j)  # take函数根据索引返回对应元素

#%%
batch_size = 10

for X, y in data_iter(batch_size, features, labels):
print(X, y)
break

#%%
w = nd.random.normal(scale=0.01, shape=(num_inputs, 1))
b = nd.zeros(shape=(1,))

#%%

w.attach_grad()
b.attach_grad()

#%%
def linreg(X, w, b):  # 本函数已保存在d2lzh包中方便以后使用
return nd.dot(X, w) + b

#%%

def squared_loss(y_hat, y):  # 本函数已保存在d2lzh包中方便以后使用
return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

#%%

def sgd(params, lr, batch_size):  # 本函数已保存在d2lzh包中方便以后使用
for param in params:
param[:] = param - lr * param.grad / batch_size

#%%

lr = 0.05
num_epochs = 20
net = linreg
loss = squared_loss

for epoch in range(num_epochs):  # 训练模型一共需要num_epochs个迭代周期
# 在每一个迭代周期中,会使用训练数据集中所有样本一次(假设样本数能够被批量大小整除)。X
# 和y分别是小批量样本的特征和标签
for X, y in data_iter(batch_size, features, labels):
with autograd.record():
l = loss(net(X, w, b), y)  # l是有关小批量X和y的损失
l.backward()  # 小批量的损失对模型参数求梯度
sgd([w, b], lr, batch_size)  # 使用小批量随机梯度下降迭代模型参数
train_l = loss(net(features, w, b), labels)
print('epoch %d, loss %f' % (epoch + 1, train_l.mean().asnumpy()))

#%%
true_w, w

#%%
true_b, b

#%%
plt.scatter(features.asnumpy(), labels.asnumpy(), 1)

labels1 = linreg(features,w,b)
plt.scatter(features.asnumpy(), labels1.asnumpy(), 1)
plt.show()
  • out

黄色是原始数据

绿色为拟合数据

 

内容来自用户分享和网络整理,不保证内容的准确性,如有侵权内容,可联系管理员处理 点击这里给我发消息
标签: