This is a PyTorch implementation of the DeepNorm from the paper DeepNet: Scaling Transformers to 1,000 Layers.
The paper proposes a method to stabilize extremely deep transformers with a new normalization function that replaces LayerNorm and a weight initialization scheme. This combines the performance of Post-LayerNorm and the stability of Pre-LayerNorm. Transformers with DeepNorm are supposed to be stable even without a learning rate warm-up.
The paper first shows that the outputs of a layer (for the same input) change gradually during stable training, whereas in unstable training they change rapidly during the initial training steps. Training is stable, for example, when the weights are initialized to small values or when a learning rate warm-up is used. The authors use the idea of keeping the changes to layer outputs small to derive the new normalization and weight initialization scheme.
Weight Initializations
Usually, the weights are initialized with Xavier or Kaiming initializations. This paper scales the weights (sets the gain) by a constant $\beta$ that depends on the size of the transformer.
DeepNorm suggests scaling the weights of the two linear transforms in the Feed-Forward Network, the value projection transform, and the output projection transform of the attention layer. The weights of these transforms are scaled by (i.e., have a gain equal to) $\beta$.
The scaling is implemented in the initialization of the transformer layer.
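Below is a minimal sketch of how such scaling could be applied with PyTorch's Xavier initialization, assuming the gain is $\beta$; the helper names (`scaled_xavier_init_`, `init_transformer_layer_`) are illustrative and not part of this implementation.

```python
import torch.nn as nn


def scaled_xavier_init_(linear: nn.Linear, gain: float) -> None:
    # Xavier (Glorot) initialization with the DeepNorm gain (beta)
    nn.init.xavier_normal_(linear.weight, gain=gain)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)


def init_transformer_layer_(ffn_1: nn.Linear, ffn_2: nn.Linear,
                            attn_value: nn.Linear, attn_output: nn.Linear,
                            beta: float) -> None:
    # Only the two FFN transforms and the attention value/output projections
    # get the beta gain; other weights keep their default initialization.
    for layer in (ffn_1, ffn_2, attn_value, attn_output):
        scaled_xavier_init_(layer, gain=beta)
```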
Normalization Function
$$x_{l+1} = \mathrm{LN}\bigl(\alpha x_l + G_l(x_l, \theta_l)\bigr)$$
where $\alpha$ is a constant that depends on the depth of the transformer, $\mathrm{LN}$ is Layer Normalization, and $G_l(x_l, \theta_l)$ is the function of the $l$-th transformer sub-layer (FFN or attention).
This function is used to replace Post-LayerNorm.
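As a sketch, the normalization can be written as a small module wrapping LayerNorm, assuming the sub-layer output $G_l(x_l, \theta_l)$ is computed separately and passed in; the class name `DeepNorm` below is an assumption, not necessarily the name used in this implementation.

```python
import torch
import torch.nn as nn


class DeepNorm(nn.Module):
    # Computes LN(alpha * x + g_x), replacing the Post-LayerNorm residual connection

    def __init__(self, alpha: float, d_model: int):
        super().__init__()
        self.alpha = alpha
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x: torch.Tensor, g_x: torch.Tensor) -> torch.Tensor:
        # x: input to the sub-layer; g_x: output of the sub-layer G_l(x, theta_l)
        return self.layer_norm(self.alpha * x + g_x)
```

For example, with a self-attention sub-layer one would call `x = deep_norm(x, self_attn(x))`, and similarly for the FFN sub-layer.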
$\alpha$ and $\beta$ constants
Encoder-only: $\alpha = (2N)^{\frac{1}{4}}$, $\beta = (8N)^{-\frac{1}{4}}$
Decoder-only: $\alpha = (2M)^{\frac{1}{4}}$, $\beta = (8M)^{-\frac{1}{4}}$
Encoder-decoder: encoder $\alpha = 0.81 (N^4 M)^{\frac{1}{16}}$, encoder $\beta = 0.87 (N^4 M)^{-\frac{1}{16}}$; decoder $\alpha = (3M)^{\frac{1}{4}}$, decoder $\beta = (12M)^{-\frac{1}{4}}$
Where $N$ is the number of layers in the encoder and $M$ is the number of layers in the decoder.
Refer to the paper for derivation.
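The constants can be computed directly from the layer counts. Here is a sketch, using a hypothetical helper `deep_norm_constants` that follows the values listed above:

```python
def deep_norm_constants(encoder_layers: int = 0, decoder_layers: int = 0):
    # Returns (encoder_alpha, encoder_beta, decoder_alpha, decoder_beta)
    # following the constants listed above; hypothetical helper.
    n, m = encoder_layers, decoder_layers
    if n and m:  # encoder-decoder
        return (0.81 * (n ** 4 * m) ** (1 / 16),
                0.87 * (n ** 4 * m) ** (-1 / 16),
                (3 * m) ** (1 / 4),
                (12 * m) ** (-1 / 4))
    if n:  # encoder-only
        return (2 * n) ** (1 / 4), (8 * n) ** (-1 / 4), None, None
    if m:  # decoder-only
        return None, None, (2 * m) ** (1 / 4), (8 * m) ** (-1 / 4)
    raise ValueError("Specify encoder_layers and/or decoder_layers")
```

For instance, a 64-layer decoder-only model would use $\alpha = (2 \cdot 64)^{\frac{1}{4}} \approx 3.36$ and $\beta = (8 \cdot 64)^{-\frac{1}{4}} \approx 0.21$.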
Here is an experiment implementation that uses DeepNorm.