
[MXNet] Lecture 04: Batch Normalization

2018-01-18 11:31

Implementation from scratch:

from mxnet import ndarray as nd

def pure_batch_norm(X, gamma, beta, eps=1e-5):
    # Only dense (2-D) and convolutional (4-D) inputs are supported.
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        # Dense layer: one mean/variance per feature (column).
        mean = X.mean(axis=0)
        variance = ((X - mean) ** 2).mean(axis=0)
    else:
        # Convolutional layer: one mean/variance per channel.
        mean = X.mean(axis=(0, 2, 3), keepdims=True)
        variance = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)

    # Normalize, then scale by gamma and shift by beta.
    X_hat = (X - mean) / nd.sqrt(variance + eps)
    return gamma.reshape(mean.shape) * X_hat + beta.reshape(mean.shape)

X = nd.arange(6).reshape((3, 2))
y = pure_batch_norm(X, gamma=nd.array([1, 1]), beta=nd.array([0, 0]))
print(y)
X2 = nd.arange(36).reshape((1, 4, 3, 3))
y2 = pure_batch_norm(X2, gamma=nd.array([1, 1, 1, 1]), beta=nd.array([0, 0, 0, 0]))
print(y2)


For a fully connected layer the input is 2-D and each feature is normalized separately, so gamma has one entry per column (the axis=1 direction).
For a convolutional layer the input is 4-D and each channel is normalized separately, so gamma has one entry per channel (again the axis=1 direction).
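The reduction axes and the resulting gamma shapes can be checked directly. A minimal sketch, using NumPy only to illustrate the shapes (the MXNet code above reduces over the same axes):

```python
import numpy as np

# Dense input: (batch, features); statistics are per feature.
X = np.arange(6, dtype=np.float64).reshape(3, 2)
mean2d = X.mean(axis=0)
print(mean2d.shape)  # one entry per column

# Conv input: (batch, channels, height, width); statistics are per channel.
X4 = np.arange(36, dtype=np.float64).reshape(1, 4, 3, 3)
mean4d = X4.mean(axis=(0, 2, 3), keepdims=True)
print(mean4d.shape)  # one entry per channel, broadcastable over N, H, W

# gamma has one entry per channel; reshaping it to mean4d's shape
# lets it broadcast over the batch and spatial dimensions.
gamma = np.ones(4).reshape(mean4d.shape)
print((gamma * (X4 - mean4d)).shape)
```

This is why the code reshapes gamma and beta to `mean.shape` before multiplying: broadcasting then applies one scale and shift per feature (or per channel).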

To also handle test time, we can maintain moving averages of the statistics:

from mxnet import ndarray as nd

def batch_norm(X, gamma, beta, is_training, moving_mean, moving_variance,
               moving_momentum, eps=1e-5):
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        # Dense layer: per-feature statistics.
        mean = X.mean(axis=0)
        variance = ((X - mean) ** 2).mean(axis=0)
    else:
        # Convolutional layer: per-channel statistics.
        mean = X.mean(axis=(0, 2, 3), keepdims=True)
        variance = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
    # reshape to broadcastable shapes; the in-place updates below
    # write through to the caller's arrays.
    moving_mean = moving_mean.reshape(mean.shape)
    moving_variance = moving_variance.reshape(variance.shape)

    if is_training:
        # Normalize with the batch statistics and update the moving averages.
        X_hat = (X - mean) / nd.sqrt(variance + eps)
        moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * mean
        moving_variance[:] = moving_momentum * moving_variance + (1. - moving_momentum) * variance
    else:
        # At test time, use the accumulated moving averages instead.
        X_hat = (X - moving_mean) / nd.sqrt(moving_variance + eps)
    return X_hat * gamma.reshape(mean.shape) + beta.reshape(mean.shape)

moving_mean = nd.zeros(2)
moving_variance = nd.zeros(2)
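The exponential moving average update can be sanity-checked on its own. A minimal NumPy sketch of the same update rule (the two batch means are made-up values for illustration):

```python
import numpy as np

momentum = 0.9
moving_mean = np.zeros(2)

# Two made-up batch means; the running estimate drifts toward them,
# keeping 90% of the old estimate at each step.
for batch_mean in (np.array([1.0, 2.0]), np.array([3.0, 4.0])):
    moving_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean

print(moving_mean)  # -> [0.39 0.58]
```

Because the estimate starts at zero and each batch contributes only a (1 - momentum) fraction, it takes many batches before moving_mean approximates the true statistics, which is why training must run for a while before test-mode normalization is reliable.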

Gluon implementation