[MXNet] Lecture 04: Batch Normalization
2018-01-18 11:31
Implementation from scratch:
For a fully connected layer the input is 2-D, and each feature is normalized separately, so gamma has one entry per column (the axis=1 dimension).
For a convolutional layer the input is 4-D, and each channel is normalized separately, so gamma has one entry per channel (also the axis=1 dimension).
If we also want to handle the test-time case, we can maintain moving averages of the batch statistics:
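As a sanity check on the axis choices above, here is a minimal sketch using NumPy instead of mxnet.ndarray (NumPy is used only for illustration; the input values are made up):

```python
import numpy as np

# 2-D input (batch, features): one mean/variance per column, so gamma
# would have one entry per feature.
X = np.arange(6, dtype=np.float64).reshape(3, 2)
mean = X.mean(axis=0)                       # shape (2,)
var = ((X - mean) ** 2).mean(axis=0)        # shape (2,)
X_hat = (X - mean) / np.sqrt(var + 1e-5)
print(X_hat.mean(axis=0))                   # ~0 for each feature

# 4-D input (batch, channel, height, width): one mean/variance per
# channel, so gamma would have one entry per channel.
X4 = np.arange(36, dtype=np.float64).reshape(1, 4, 3, 3)
mean4 = X4.mean(axis=(0, 2, 3), keepdims=True)   # shape (1, 4, 1, 1)
var4 = ((X4 - mean4) ** 2).mean(axis=(0, 2, 3), keepdims=True)
X4_hat = (X4 - mean4) / np.sqrt(var4 + 1e-5)
print(X4_hat.mean(axis=(0, 2, 3)))               # ~0 for each channel
```

After normalization each feature (or channel) has mean approximately 0 and variance approximately 1, which is exactly what the per-axis mean and variance are chosen to achieve.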
from mxnet import ndarray as nd

def batch_norm(X, gamma, beta, is_training, moving_mean, moving_variance,
               moving_momentum, eps=1e-5):
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        # Fully connected input: normalize each feature (column) over the batch.
        mean = X.mean(axis=0)
        variance = ((X - mean) ** 2).mean(axis=0)
    else:
        # Convolutional input: normalize each channel over batch, height and width.
        mean = X.mean(axis=(0, 2, 3), keepdims=True)
        variance = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
        # Reshape the running statistics so they broadcast against X.
        moving_mean = moving_mean.reshape(mean.shape)
        moving_variance = moving_variance.reshape(variance.shape)
    if is_training:
        X_hat = (X - mean) / nd.sqrt(variance + eps)
        # Update the running statistics in place with an exponential moving average.
        moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * mean
        moving_variance[:] = moving_momentum * moving_variance + (1. - moving_momentum) * variance
    else:
        # At test time, use the accumulated running statistics instead of batch statistics.
        X_hat = (X - moving_mean) / nd.sqrt(moving_variance + eps)
    # Scale and shift: gamma and beta have one entry per feature or channel.
    return X_hat * gamma.reshape(mean.shape) + beta.reshape(mean.shape)

# Running statistics for a 2-feature fully connected layer, initialized to zero.
moving_mean = nd.zeros(2)
moving_variance = nd.zeros(2)
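The moving-average update used above can be illustrated on its own in plain Python (the batch means below are made-up numbers, and the running mean starts from zero just as moving_mean does):

```python
# Exponential moving average of the batch mean, as in batch_norm:
#   moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean
momentum = 0.9
moving_mean = 0.0
batch_means = [1.0, 1.2, 0.8, 1.1]   # per-batch means seen during training
for m in batch_means:
    moving_mean = momentum * moving_mean + (1.0 - momentum) * m
print(moving_mean)                   # ≈ 0.3521
```

Because the running value is initialized to zero, it is biased low early in training and only gradually approaches the true average of the batch means; after many batches this bias washes out, which is why the estimate is usable at test time.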
The simpler pure_batch_norm version (batch statistics only, no moving averages), together with a 2-D and a 4-D example:

from mxnet import ndarray as nd

def pure_batch_norm(X, gamma, beta, eps=1e-5):
    assert len(X.shape) in (2, 4)
    if len(X.shape) == 2:
        # Fully connected input: one mean/variance per feature (column).
        mean = X.mean(axis=0)
        variance = ((X - mean) ** 2).mean(axis=0)
    else:
        # Convolutional input: one mean/variance per channel.
        mean = X.mean(axis=(0, 2, 3), keepdims=True)
        variance = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
    X_hat = (X - mean) / nd.sqrt(variance + eps)
    return gamma.reshape(mean.shape) * X_hat + beta.reshape(mean.shape)

X = nd.arange(6).reshape((3, 2))
y = pure_batch_norm(X, gamma=nd.array([1, 1]), beta=nd.array([0, 0]))
print(y)

X2 = nd.arange(36).reshape((1, 4, 3, 3))
y2 = pure_batch_norm(X2, gamma=nd.array([1, 1, 1, 1]), beta=nd.array([0, 0, 0, 0]))
print(y2)

Gluon implementation