How does Batch Normalization help?

Notes for “How Does Batch Normalization Help Optimization?”

These notes provide some context and background for the paper we are going to discuss at the next reading group meetup: How Does Batch Normalization Help Optimization?. This is not a substitute for reading the paper. It’s just meant to give a starting point for those who are not familiar with the topic. In fact, I spend more time talking about the paper that precedes this one.

What is batch normalization?

In 2015, a very important paper appeared that significantly improved the training of deep networks: Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. It introduced the concept of batch normalization (BN), which is now part of every machine learner’s standard toolkit. The paper itself has been cited over 7,700 times.
In the paper, they show that BN stabilizes training, avoids the problem of exploding and vanishing gradients, allows for faster learning rates, makes the choice of initial weights less delicate, and acts as a regularizer. If you don’t know what any of that means, let me summarize it this way: BN makes all the hard things about training a deep learning model a bit (or a lot) easier.
Using BN, the authors were able to train an ImageNet classifier 14x faster than usual, and beat the state of the art for accuracy. No wonder everyone paid attention.

Where does the idea come from?

People have known since the stone age that it is helpful to normalize training data as a preprocessing step before training a model. For images, this might be as simple as subtracting the mean and dividing by the standard deviation, to even out the brightness and contrast between samples.
BN is closer in spirit to another preprocessing strategy called “feature scaling”, where you put each feature onto a similar scale so the model doesn’t have to learn to do that itself. The general idea is that without feature scaling, the loss you are trying to optimize has elliptical contours and gradient descent has to zigzag around a lot to reach the minimum. With feature scaling, the contours are circles and you can go straight to the bullseye.
Feature scaling makes the loss surface easier to navigate, so training goes faster.
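To make the idea concrete, here is a minimal sketch of feature scaling (standardization) in NumPy. The two features and their scales are made up for illustration; the point is just that each column ends up with zero mean and unit standard deviation, regardless of its original scale.

```python
import numpy as np

# Hypothetical toy data: two features on wildly different scales
rng = np.random.default_rng(0)
X = np.column_stack([
    rng.normal(loc=1000.0, scale=200.0, size=100),  # e.g. house size in sq ft
    rng.normal(loc=3.0, scale=1.0, size=100),       # e.g. number of bedrooms
])

# Standardize each feature: subtract its mean, divide by its std
mean = X.mean(axis=0)
std = X.std(axis=0)
X_scaled = (X - mean) / std

# Each column now has mean ~0 and std ~1
print(X_scaled.mean(axis=0), X_scaled.std(axis=0))
```

With both features on the same scale, the loss contours are much closer to circles, which is exactly the situation described above.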
The idea behind BN is: if normalization is good for the input layer, maybe it’s a good idea for all the other layers too? It turns out that it is.
However, it’s not practical to actually normalize all the features of all the input points to every layer. Instead, the normalization is done one minibatch at a time. (Hence the name.)
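A minimal sketch of the training-mode forward pass may help here. This is not the paper’s implementation, just the core computation: normalize each feature using the statistics of the current minibatch, then apply a learned scale (gamma) and shift (beta). At inference time, real implementations use running averages of the batch statistics instead, which this sketch omits.

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    """Batch normalization over one minibatch (training-mode forward pass).

    x: (batch_size, num_features) activations for one minibatch.
    gamma, beta: learned per-feature scale and shift.
    eps: small constant to avoid division by zero.
    """
    mu = x.mean(axis=0)                    # per-feature mean over the batch
    var = x.var(axis=0)                    # per-feature variance over the batch
    x_hat = (x - mu) / np.sqrt(var + eps)  # ~zero mean, unit std per feature
    return gamma * x_hat + beta            # learned scale and shift

# One hypothetical minibatch of 32 examples with 4 features
rng = np.random.default_rng(1)
x = rng.normal(loc=5.0, scale=3.0, size=(32, 4))
y = batch_norm(x, gamma=np.ones(4), beta=np.zeros(4))
```

With gamma = 1 and beta = 0, the output of each feature has (approximately) zero mean and unit standard deviation over the minibatch; gamma and beta let the network undo the normalization where that turns out to be useful.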

How do the authors explain why BN works?

The BN authors argue that without BN, the statistics of the inputs to each layer (after the input layer) vary as training proceeds. This is because the inputs to a layer depend on the weights in all previous layers. Those weights start out random and their values change during training. (That, in fact, is the point of training.) Consequently, the statistics change.
They dub this effect “internal covariate shift” (ICS) which has become a key concept in the folklore of machine learning (and is a great phrase to impress your friends).
BN mitigates ICS, in theory, by forcing each layer’s inputs to have zero mean and unit standard deviation, thus imposing some consistency on the statistics during training.
Their evidence that ICS is the root cause of the problem they are solving is rather indirect: networks with BN train more successfully, but they never measure ICS directly. In software terms, this is sort of like doing a system test instead of a unit test.
Everyone agrees that BN is a wonderful thing. Not everyone is satisfied that we understand why it works. If you are Jeremy Howard (the guru), that’s just fine, because the fastest way forward is to just “ignore all the math”.
But this is not good enough for others. In his famous (infamous?) “machine learning is alchemy” talk at NIPS 2017, Ali Rahimi uses BN as an example of how theory lags experiment in machine learning. Paraphrasing from his presentation:

Here is what we know about batch norm as a field:

  • “It works because it reduces internal covariate shift”
  • Wouldn’t you like to know why reducing internal covariate shift speeds up gradient descent?
  • Wouldn’t you like to see a theorem or an experiment?
  • Wouldn’t you like to see evidence that batch norm reduces internal covariate shift?
  • Wouldn’t you like to know what internal covariate shift is?
  • Wouldn’t you like to see a definition of it?

Batch Norm has become a foundational tool in how we build deep nets and yet as a field we know almost nothing about it.


How does BN help optimization?

The authors of the paper we will discuss at the reading group asked themselves many of the same questions, and came up with some surprising answers.
  1. ICS is actually not a very big effect
  2. BN (maybe) doesn’t reduce ICS significantly
  3. What BN does do is smooth out the loss surface. This is the reason for the training improvements.
The following picture illustrates the smoothness idea. It doesn’t come from this paper, but from Visualizing the Loss Landscape of Neural Nets, which we should definitely read at a future meetup.
The effect of BN on the loss surface. Left: without BN, Right: with BN (artist’s conception)

The authors make this notion of smoothness mathematically precise and prove that BN increases smoothness. This seems like a good foundation for exploring ways to make BN even better, but they don’t offer any suggestions for how to modify the technique.
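The notion of smoothness at play is (roughly) the Lipschitz constant of the gradient: how fast the gradient can change as you move through parameter space. The toy sketch below, which is my own illustration and not the paper’s experiment, estimates this constant empirically for two quadratic losses: one ill-conditioned (elliptical contours) and one well-conditioned (circular contours). A smaller constant means a smoother landscape, which in turn allows larger, safer gradient steps.

```python
import numpy as np

def grad_lipschitz_estimate(A, n_pairs=1000, seed=0):
    """Estimate the Lipschitz constant of the gradient of the quadratic
    loss f(theta) = 0.5 * theta^T A theta (whose gradient is A @ theta)
    by sampling random pairs of points and taking the largest observed
    ratio ||grad(u) - grad(v)|| / ||u - v||."""
    rng = np.random.default_rng(seed)
    dim = A.shape[0]
    best = 0.0
    for _ in range(n_pairs):
        u = rng.normal(size=dim)
        v = rng.normal(size=dim)
        ratio = np.linalg.norm(A @ u - A @ v) / np.linalg.norm(u - v)
        best = max(best, ratio)
    return best

A_rough = np.diag([1.0, 50.0])   # ill-conditioned: long elliptical contours
A_smooth = np.eye(2)             # well-conditioned: circular contours

est_rough = grad_lipschitz_estimate(A_rough)    # close to 50
est_smooth = grad_lipschitz_estimate(A_smooth)  # exactly 1
print(est_rough, est_smooth)
```

The paper’s claim, in this language, is that adding BN reduces the effective Lipschitz constants of the loss and its gradient, so gradient descent can take bigger steps with more predictable results.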

Also, I don’t think they fully connect the dots between having a smoother loss surface and training working better. But I’m still reading the paper…