
A Practical Guide to Machine Learning (5): GD/SGD/MSGD in Pseudocode

2016-05-18 22:22, 411 views

GD: Gradient Descent

while True:
    loss = f(params)
    d_loss_wrt_params = ...
    params -= eta * d_loss_wrt_params
    if <stopping condition met>:
        return params
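The pseudocode above can be made concrete with a minimal, runnable sketch. It assumes a simple quadratic loss f(w) = (w - 3)**2 (not from the original post) so that the gradient has the closed form 2*(w - 3); `eta` is the learning rate, and the stopping condition is a near-zero gradient:

```python
# Minimal full-batch gradient descent sketch, assuming the quadratic
# loss f(w) = (w - 3)**2 with analytic gradient 2 * (w - 3).
def gradient_descent(eta=0.1, tol=1e-8, max_iters=10_000):
    params = 0.0  # initial guess
    for _ in range(max_iters):
        d_loss_wrt_params = 2 * (params - 3)  # gradient of the assumed loss
        params -= eta * d_loss_wrt_params
        if abs(d_loss_wrt_params) < tol:  # stopping condition: tiny gradient
            return params
    return params

w = gradient_descent()  # converges near the minimizer w = 3
```

With eta = 0.1 the error (w - 3) shrinks by a constant factor of 0.8 per step, so convergence is geometric.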


SGD: Stochastic Gradient Descent

Training one sample at a time:

for x_i, y_i in training_data:
    loss = f(params, x_i, y_i)
    d_loss_wrt_params = ...
    params -= eta * d_loss_wrt_params
    if <stopping condition met>:
        return params
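As a runnable sketch of the per-sample update, assume a one-parameter linear model y ≈ w * x with squared loss on synthetic data where y = 2x exactly (the model, data, and number of sweeps are illustrative assumptions, not part of the original pseudocode):

```python
# Per-sample SGD sketch: fit w in y = w * x with squared loss.
# The data is synthetic (y = 2 * x), so w should converge to 2.
training_data = [(x, 2.0 * x) for x in [1.0, 2.0, 3.0, 4.0]]

def sgd(eta=0.05, sweeps=200):
    w = 0.0
    for _ in range(sweeps):
        for x_i, y_i in training_data:
            # loss = (w * x_i - y_i) ** 2; its gradient w.r.t. w:
            d_loss_wrt_w = 2 * (w * x_i - y_i) * x_i
            w -= eta * d_loss_wrt_w  # update after every single sample
    return w
```

Note the key difference from full-batch GD: the parameters move after every sample, so each step uses a noisy estimate of the true gradient.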


Going a step further, we can wrap this in an outer loop over epochs:

for j in range(epochs):
    random.shuffle(training_data)
    for x_i, y_i in training_data:
        ...
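The epoch loop can be sketched end to end as follows, again under the assumed toy setup y = 2x with squared loss; the point of `random.shuffle` is that each epoch visits the samples in a different order:

```python
import random

# Epoch-based SGD sketch: each epoch reshuffles the data, then makes
# one per-sample pass. Model and data (y = 2 * x) are illustrative.
random.seed(0)  # fixed seed so the run is reproducible
training_data = [(x, 2.0 * x) for x in [1.0, 2.0, 3.0, 4.0]]

def sgd_epochs(epochs=100, eta=0.05):
    w = 0.0
    for _ in range(epochs):
        random.shuffle(training_data)  # new sample order every epoch
        for x_i, y_i in training_data:
            d_loss_wrt_w = 2 * (w * x_i - y_i) * x_i
            w -= eta * d_loss_wrt_w
    return w
```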


MSGD (Minibatch SGD): Mini-batch Stochastic Gradient Descent

n = len(training_data)
mini_batch_size = ...
mini_batches = [training_data[k:k+mini_batch_size]
                for k in range(0, n, mini_batch_size)]
for mini_batch in mini_batches:
    loss = f(params, mini_batch)
    d_loss_wrt_params = ...
    params -= eta * d_loss_wrt_params
    if <stopping condition met>:
        return params
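A runnable sketch of the minibatch variant, using the same assumed toy problem (y = 2x, squared loss, one weight w); the batch size of 2 and the epoch count are arbitrary choices for illustration:

```python
import random

# Minibatch SGD sketch: the gradient is averaged over each minibatch,
# trading the noise of per-sample SGD against the cost of full-batch GD.
# Data is synthetic (y = 2 * x); mini_batch_size = 2 is arbitrary.
random.seed(0)
training_data = [(x, 2.0 * x) for x in [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]

def msgd(epochs=200, eta=0.02, mini_batch_size=2):
    w = 0.0
    n = len(training_data)
    for _ in range(epochs):
        random.shuffle(training_data)
        mini_batches = [training_data[k:k + mini_batch_size]
                        for k in range(0, n, mini_batch_size)]
        for mini_batch in mini_batches:
            # average the per-sample gradients over the minibatch
            grad = sum(2 * (w * x_i - y_i) * x_i
                       for x_i, y_i in mini_batch) / len(mini_batch)
            w -= eta * grad  # one update per minibatch
    return w
```

Compared with per-sample SGD, each update here uses a smoother gradient estimate, and the minibatch form maps naturally onto vectorized hardware.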