Motivation and Overview
While implementing a few kernels from scratch for several transformer-based architectures, I kept running into the LayerNorm operator: it shows up in every residual multi-headed attention block. I’ve always understood it conceptually but took its implementation for granted. I figured I’d do a slightly deeper dive on it here, since this is the first log entry and I’m also testing the formatting for these logs.
Layer Normalization was introduced to address the shortcomings of Batch Normalization, primarily with variable-length inputs. This is especially useful for LSTM-based and Transformer-based models, where sequence length varies greatly. Furthermore, running statistics (the moving mean and variance that Batch Normalization keeps for inference) no longer need to be stored, and the output for a given sample does not depend on batch size.
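To make the batch-size point concrete, here is a minimal sketch in plain Python (standard library only, so it stays dependency-free). The function names `layer_norm` and `batch_norm` are my own for illustration and do not match any library's API. The learned scale and shift parameters are left out, and the batch norm version only uses training-mode batch statistics with no running averages.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize one sample over its own features (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)  # biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def batch_norm(batch, eps=1e-5):
    """Normalize each feature across the batch (training-mode statistics only)."""
    n = len(batch)
    cols = list(zip(*batch))  # one tuple per feature
    means = [sum(c) / n for c in cols]
    variances = [sum((v - m) ** 2 for v in c) / n for c, m in zip(cols, means)]
    return [
        [(v - m) / math.sqrt(s + eps) for v, m, s in zip(row, means, variances)]
        for row in batch
    ]

sample = [1.0, 2.0, 3.0, 4.0]
other = [10.0, 0.0, -5.0, 2.0]

# Layer norm: the result for `sample` is the same alone or inside a batch.
ln_alone = layer_norm(sample)
ln_in_batch = [layer_norm(row) for row in [sample, other]][0]

# Batch norm: the result for `sample` changes when the batch changes.
bn_small = batch_norm([sample, other])[0]
bn_large = batch_norm([sample, other, [0.0, 0.0, 0.0, 0.0]])[0]

print(ln_alone == ln_in_batch)  # layer norm ignores the rest of the batch
print(bn_small == bn_large)     # batch norm does not
```

Because `layer_norm` only ever looks at a single sample, there is nothing to track across batches, which is why no running statistics are needed at inference time.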
Batch Normalization
Before we can understand why Layer Normalization came about, we’ll need to briefly recap what Batch Normalization is and why we use it.
We’ll have to take a step back in time, to before Batch Normalization even existed. This was the era of CNNs, before Batch Normalization was introduced in 2015.
The Internal Covariate Shift
Layer Normalization
Implementation
TODO:
- Section on Batch Normalization, including where it’s used in conv nets and how it’s calculated.
- Section on why batch norm doesn’t work well on LSTMs / Transformers.
- Section on Layer Normalization and why it’s better.
- Computational / Implementation details.