Batchnorm
Journal of Machine Learning Research (W&CP vol. 37, ICML) 2015
[jmlr-batchnorm] Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift
Sergey Ioffe, Christian Szegedy
Synthesis
Why do we need it?
If the distribution of a layer's outputs changes during training, every following layer sees inputs whose distribution keeps shifting (internal covariate shift). Those layers must then continually re-adapt, and have trouble learning a relevant representation of their inputs.
We therefore need a way to normalize the intermediate distributions so that later layers can learn more easily.
One solution could be a normalization step that fixes the means and variances of the layer inputs.
But this solution has a shortcoming: the normalization would have to be computed over the whole training set and also back-propagated through, which is time consuming!
Simplification: each scalar feature is normalized independently (to zero mean and unit variance), instead of whitening all the features jointly.
This has the following advantages:
- it allows for higher learning rates
- it reduces the need for dropout
We also need to restore the representation power of each feature, and make sure that the inserted transformation can represent the identity transform. (We want the network to be able to choose the identity transform, i.e. leave the activations unmodified, if that is optimal.)
Therefore, a pair of parameters is introduced to scale and shift the normalized value (an affine transform).
This pair of parameters is learned.
This step can represent the identity transform by setting the scale to the standard deviation and the shift to the mean of the original activations: if that were optimal, the output of the scale-and-shift would be the same as the input.
The scaled and shifted values are passed to the next network layers.
Using the whole dataset to perform normalization would be impractical, therefore we estimate the mean and variance of the activations for each mini-batch.
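The training-time transform described above can be sketched for a single scalar feature as follows. This is a minimal illustration in plain Python, not the paper's code: the function name `batch_norm_train` and the value `eps=1e-5` are our own assumptions (ε is the small constant added to the variance for numerical stability).

```python
import math


def batch_norm_train(x, gamma, beta, eps=1e-5):
    """Batch-normalize one scalar feature over a mini-batch (training mode).

    x: activations of a single feature, one value per example in the mini-batch.
    gamma, beta: the learned scale and shift for this feature.
    eps: small constant for numerical stability (illustrative value).
    """
    m = len(x)
    mu = sum(x) / m                                    # mini-batch mean
    var = sum((xi - mu) ** 2 for xi in x) / m          # mini-batch (biased) variance
    x_hat = [(xi - mu) / math.sqrt(var + eps) for xi in x]  # zero mean, unit variance
    return [gamma * xh + beta for xh in x_hat]         # learned affine transform


batch = [1.0, 2.0, 3.0, 4.0]
y = batch_norm_train(batch, gamma=1.0, beta=0.0)       # roughly zero mean, unit variance

# Identity: with gamma = sqrt(var + eps) and beta = mean, the input is recovered.
mu = sum(batch) / len(batch)
var = sum((b - mu) ** 2 for b in batch) / len(batch)
restored = batch_norm_train(batch, gamma=math.sqrt(var + 1e-5), beta=mu)
```

The last lines check the identity argument numerically: `restored` matches `batch` up to floating-point rounding.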
Why should it help?
Intuition: the normalized activation can be seen as the input of a sub-network made of the affine transform plus all the layers on top. This sub-network receives inputs in which each scalar feature is normalized, so their distribution is more stable over time (less covariate shift). This should speed up the training of the sub-network, and therefore the training of the entire network as well.
What about inference?
During inference, the normalization of the activations should not depend on the mini-batch statistics: the output should be determined deterministically by the input alone. Therefore, we normalize with population statistics instead (the mean and variance of the activations over the entire training set).
Since these statistics are fixed at inference time, this normalization shift + scale can be combined with the learned shift + scale into a single linear transform, which replaces the batch-norm layer at inference time.
The mean is estimated as the expectation of the mini-batch means, E[x] = E_B[μ_B]. The variance uses the unbiased estimate Var[x] = m/(m-1) · E_B[σ_B²], where σ_B² is the (biased) variance of the activations in a mini-batch, m is the mini-batch size, and the expectation E_B is taken over several training mini-batches of size m.
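A minimal sketch of the inference-time procedure, assuming equally sized mini-batches of a single scalar feature. The helper names `population_stats` and `fold_into_affine` are hypothetical, and `eps=1e-5` is an illustrative value. Deep learning frameworks commonly track these statistics with running (moving) averages during training instead of storing all mini-batch statistics.

```python
import math


def population_stats(batches):
    """Estimate the population mean/variance of one feature from mini-batches.

    Mean: average of the mini-batch means.
    Variance: m / (m - 1) times the average of the biased mini-batch variances.
    Assumes every mini-batch has the same size m > 1.
    """
    m = len(batches[0])
    means, variances = [], []
    for b in batches:
        mu = sum(b) / m
        means.append(mu)
        variances.append(sum((x - mu) ** 2 for x in b) / m)
    mean = sum(means) / len(means)
    var = m / (m - 1) * sum(variances) / len(variances)
    return mean, var


def fold_into_affine(mean, var, gamma, beta, eps=1e-5):
    """Merge normalization and learned scale/shift into a single y = a * x + c."""
    a = gamma / math.sqrt(var + eps)
    c = beta - a * mean
    return a, c


batches = [[1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]]
mean, var = population_stats(batches)
a, c = fold_into_affine(mean, var, gamma=2.0, beta=0.5)

x = 3.0
y_folded = a * x + c                                      # single linear transform
y_explicit = 2.0 * (x - mean) / math.sqrt(var + 1e-5) + 0.5  # normalize, then scale/shift
```

Both computations give the same output, which is why the whole BN layer can be replaced by one per-feature linear transform at test time.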
Where to put your BN layer
The Batchnorm transform is placed right before the non-linearity, i.e. on the linear output Wu + b. The layer input u could be normalized instead, but u is usually the output of a previous non-linearity, so the shape of its distribution is likely to change during training, and constraining only its first and second moments would not eliminate the covariate shift. In contrast, the linear output is more likely to have a symmetric, non-sparse distribution (roughly "more Gaussian"), so normalizing it is more likely to produce activations with a stable distribution.
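A minimal sketch of the placement z = ReLU(BN(Wu + b)) for one output unit, in plain Python with ReLU as an example non-linearity. The helpers `batch_norm`, `relu` and `layer` are hypothetical illustrations. The sketch also shows a side effect of this placement: the bias b has no effect, because the mean subtraction in BN cancels it (the learned shift β takes over its role).

```python
import math


def batch_norm(x, gamma=1.0, beta=0.0, eps=1e-5):
    """Training-mode batch norm for one scalar feature over a mini-batch."""
    m = len(x)
    mu = sum(x) / m
    var = sum((v - mu) ** 2 for v in x) / m
    return [gamma * (v - mu) / math.sqrt(var + eps) + beta for v in x]


def relu(x):
    return [max(0.0, v) for v in x]


def layer(inputs, w, b, gamma, beta):
    """z = relu(BN(w . u + b)) for one output unit, over a mini-batch of inputs."""
    pre = [sum(wi * ui for wi, ui in zip(w, u)) + b for u in inputs]  # linear part
    return relu(batch_norm(pre, gamma, beta))  # BN between linear part and non-linearity


inputs = [[1.0, 0.5], [0.2, -1.0], [-0.3, 2.0], [1.5, 1.5]]
w = [0.8, -0.4]
z_with_bias = layer(inputs, w, b=3.0, gamma=1.2, beta=0.1)
z_no_bias = layer(inputs, w, b=0.0, gamma=1.2, beta=0.1)  # same result as with b = 3.0
```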
Special consideration for convolutional layers
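For convolutional layers, we want the different elements of the same feature map, at different spatial locations, to be normalized the same way. To achieve this, all the activations of a feature map are normalized jointly over the mini-batch and over all spatial locations (for a mini-batch of size m and feature maps of size p×q, the effective mini-batch size is m·p·q), and one pair of scale/shift parameters is learned per feature map rather than per activation.
For 3D convolutional layers, the same observation applies, with the normalization done jointly over spatio-temporal locations.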
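A minimal sketch of convolutional batch norm in plain Python, assuming activations stored as nested lists indexed `x[n][c][h][w]` (batch, channel, height, width). The function name and layout are our own illustrative choices. Each channel gets one mean and variance computed over N·H·W values and one (gamma, beta) pair. A 3D variant would simply add a time axis to the set of values pooled per channel.

```python
import math


def batch_norm_conv(x, gamma, beta, eps=1e-5):
    """Batch norm for conv activations stored as nested lists x[n][c][h][w].

    Each channel c is normalized with ONE mean/variance computed jointly over
    the batch and all spatial positions (N * H * W values), and uses ONE
    learned (gamma[c], beta[c]) pair shared by all locations.
    """
    N, C = len(x), len(x[0])
    H, W = len(x[0][0]), len(x[0][0][0])
    out = [[[[0.0] * W for _ in range(H)] for _ in range(C)] for _ in range(N)]
    for c in range(C):
        # pool every activation of channel c across batch and space
        vals = [x[n][c][i][j] for n in range(N) for i in range(H) for j in range(W)]
        mu = sum(vals) / len(vals)
        var = sum((v - mu) ** 2 for v in vals) / len(vals)
        scale = gamma[c] / math.sqrt(var + eps)
        for n in range(N):
            for i in range(H):
                for j in range(W):
                    out[n][c][i][j] = scale * (x[n][c][i][j] - mu) + beta[c]
    return out


x = [
    [[[1.0, 2.0], [3.0, 4.0]], [[10.0, 10.0], [20.0, 20.0]]],  # example 0: 2 channels
    [[[5.0, 6.0], [7.0, 8.0]], [[30.0, 30.0], [40.0, 40.0]]],  # example 1: 2 channels
]
y = batch_norm_conv(x, gamma=[1.0, 1.0], beta=[0.0, 0.0])
```

Because the statistics are pooled over space, equal input values in the same channel map to equal outputs regardless of their location.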
Stabilization of parameter growth during Gradient Descent
BatchNorm makes the layer output invariant to the scale of the parameters of the preceding linear layer: multiplying all the weights W by a scalar a gives BN(Wu) = BN((aW)u), because the normalization step cancels the scale. Consequently, scaling affects neither the forward pass nor the gradient of the BN output with respect to the layer input u. The gradient with respect to the scaled weights, however, is 1/a times the original gradient: ∂BN((aW)u)/∂(aW) = (1/a) · ∂BN(Wu)/∂W. Larger weights therefore receive smaller gradients, which stabilizes parameter growth.
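The two claims above can be checked numerically with a small sketch in plain Python. Assumptions: a single output unit, no learned scale/shift, `eps` set to 0 so the invariance is exact (with ε > 0 it only holds approximately), and a hypothetical downstream `loss` made of a fixed weighted sum of the BN outputs. Gradients are estimated with central finite differences.

```python
import math


def batch_norm(x, eps=0.0):
    """Training-mode batch norm (no learned affine) for one feature over a mini-batch.

    eps defaults to 0 so the scale invariance is exact; with eps > 0 it is approximate.
    """
    m = len(x)
    mu = sum(x) / m
    var = sum((v - mu) ** 2 for v in x) / m
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def linear(inputs, w):
    """One output unit: w . u for each input vector u in the mini-batch."""
    return [sum(wi * ui for wi, ui in zip(w, u)) for u in inputs]


def loss(w, inputs, coeffs):
    """Hypothetical downstream loss: a fixed weighted sum of the BN outputs."""
    return sum(c * y for c, y in zip(coeffs, batch_norm(linear(inputs, w))))


def numerical_grad(f, w, h=1e-6):
    """Central finite-difference gradient of f at w."""
    grad = []
    for i in range(len(w)):
        w_plus, w_minus = list(w), list(w)
        w_plus[i] += h
        w_minus[i] -= h
        grad.append((f(w_plus) - f(w_minus)) / (2 * h))
    return grad


inputs = [[1.0, 0.5], [0.2, -1.0], [-0.3, 2.0], [1.5, 1.5]]
coeffs = [1.0, -2.0, 0.5, 3.0]
w = [0.8, -0.4]
a = 10.0
w_scaled = [a * wi for wi in w]

# Forward pass: identical BN outputs for W and aW.
out = batch_norm(linear(inputs, w))
out_scaled = batch_norm(linear(inputs, w_scaled))


def f(v):
    return loss(v, inputs, coeffs)


# Gradients: the gradient at aW is 1/a times the gradient at W.
g = numerical_grad(f, w)
g_scaled = numerical_grad(f, w_scaled)
```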
In practice
- Batchnorm should be used with a mini-batch size > 1 (otherwise there is nothing to normalize over: your batch is your sample!), except for convolutional layers, which are also normalized across the spatial dimensions.
- At test time, Batchnorm is replaced by a fixed linear transform.
- Batchnorm is inserted before the non-linearity.
- Batchnorm allows for higher learning rates.
- Batchnorm reduces the need for dropout (removing dropout speeds up training with little impact on generalization).
- Batchnorm reduces the need for weight regularization (L2 regularization can typically be reduced; in the paper, reducing it by a factor of 5 was beneficial).
- Batchnorm allows for faster learning rate decay (because fewer steps are needed to reach the same level of training).
- Bonus: shuffling the training examples more thoroughly, so that the same examples do not always appear together in a mini-batch, improved validation accuracy (by about 1% in the paper). Gentle reminder: always shuffle your inputs during training. Not shuffling already harms training without batch norm, and it is even worse with it: you don't want the same samples to always end up in the same mini-batch.
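A minimal sketch of per-epoch reshuffling with the standard library, assuming the training set is addressed by integer indices. The generator `minibatches` is a hypothetical helper. Each index appears exactly once per epoch, so no example is duplicated within a mini-batch, and a fresh random order each epoch keeps the same examples from always landing in the same mini-batch.

```python
import random


def minibatches(num_examples, batch_size, rng):
    """Yield one epoch of mini-batches of example indices in a fresh random order."""
    order = list(range(num_examples))   # one index per training example
    rng.shuffle(order)                  # new permutation on every call (every epoch)
    for start in range(0, num_examples, batch_size):
        yield order[start:start + batch_size]


rng = random.Random(0)  # seeded only to make the example reproducible
epoch1 = list(minibatches(10, 4, rng))
epoch2 = list(minibatches(10, 4, rng))  # different grouping of the same 10 examples
```

Note that the last mini-batch can be smaller than `batch_size`. If it ever contains a single example, a fully connected BN layer has nothing to normalize over; frameworks usually offer a way to drop it (for example `drop_last=True` in PyTorch's `DataLoader`).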
Results
- Allows higher learning rates, as small changes in the parameters are not amplified as much, thanks to the normalization.
- Reduces the need for high dropout rates, maybe because each example's output depends on the other examples in its mini-batch (the network no longer produces deterministic values for a given training example), which acts as a regularizer and helps the network generalize.
Note on gradient descent
There are several modes of gradient descent, which differ in the batch size (B) relative to the number of training samples (N):
- Batch mode: B = N, one epoch is one iteration (each step takes into account the gradients of all samples).
  - Pro: each step uses the exact gradient of the full training loss.
  - Con: often prohibitively slow!
- Mini-batch mode: 1 < B < N, one epoch consists of N/B iterations (the common mode in deep learning).
  - Pro: allows for more frequent steps, and therefore often converges faster.
  - Con: the steps are noisier, as the mini-batch gradient is only an estimate of the full gradient.
- Stochastic mode: B = 1, one epoch takes N iterations.
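The three modes differ only in how many samples feed each update, which a toy sketch makes concrete. Assumptions: a one-parameter model y = w·x fitted with mean squared error on synthetic data y = 2x (N = 8), a fixed learning rate of 0.01, and no shuffling, for simplicity. The function `gd_epoch` is a hypothetical helper.

```python
def gd_epoch(w, data, batch_size, lr):
    """One epoch of gradient descent on the loss mean((w * x - y) ** 2).

    batch_size == len(data): batch mode; 1 < batch_size < len(data): mini-batch
    mode; batch_size == 1: stochastic mode. Returns (new_w, number_of_steps).
    """
    steps = 0
    for start in range(0, len(data), batch_size):
        batch = data[start:start + batch_size]
        # gradient of the mean squared error over this (mini-)batch only
        grad = sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)
        w -= lr * grad
        steps += 1
    return w, steps


data = [(float(x), 2.0 * x) for x in range(1, 9)]  # toy data y = 2x, N = 8
results = {}
for batch_size in (8, 4, 1):  # batch, mini-batch and stochastic mode
    w = 0.0
    for _ in range(100):
        w, steps = gd_epoch(w, data, batch_size, lr=0.01)
    results[batch_size] = (steps, w)  # updates per epoch, final weight
```

With N = 8, one epoch performs 1, 2 and 8 updates respectively, and on this noise-free toy problem all three modes converge to w ≈ 2.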