The beginning of the downfall for Batch Normalization? DeepMind released a new family of state-of-the-art networks for Image Classification that surpasses the previous best, EfficientNet, by quite a margin. The main idea behind the new architecture is to train Normalizer-Free Neural Nets, or NF-Nets, which do away with batch normalization, a technique still used in almost every other neural net because of its rich set of advantages. The NF-Net-F1 matches the accuracy of EfficientNet-B7 while being 8.7x faster to train! Given below is the Accuracy vs Latency graph of the most recent state-of-the-art architectures for image classification on the ImageNet dataset.
Almost all neural networks, especially in the Computer Vision-Image Classification domain, rely heavily on Batch Normalization for training deep networks. Be it the smoothing of the loss, the reduction of covariate shift, or the regularization effect, Batch Normalization has always given a clear advantage in training a neural net, till now. That’s about to change! The authors of NF-Net showed that we can train Deep Residual Networks without Batch Normalization by replacing it with a few other techniques, leading to faster training and better accuracy than Batch Normalization-enabled networks. NF-Nets were faster to train and could use a batch size of up to 4096! But before diving into the details of the new alternatives and techniques, let’s first set the context with a short walk-through of Batch Normalization.
Just as a quick recap, Batch Normalization is a method to train very deep networks like ResNet by standardizing the inputs to each layer over each mini-batch. A problem called internal covariate shift had been observed: the distribution of the inputs to each layer keeps changing during training, so it looked like the network was training towards a moving target. This is because neural nets are quite sensitive to the initial weights and the algorithm. Batch normalization solved this to a significant extent in the initial stages of training by shifting and scaling the inputs using the mean and variance calculated over the mini-batch, which smooths the loss landscape and stabilizes the training.
The below image shows the various steps and formulae in the Batch Normalization procedure for each mini-batch.
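The steps can also be followed in code. Below is a minimal sketch for a single feature over one mini-batch, assuming the usual formulation (mini-batch mean, biased variance, a small `eps` for numerical stability, and a learnable scale `gamma` and shift `beta`); the function name and default values are illustrative, not taken from any particular library.

```python
import math

def batch_norm(batch, gamma=1.0, beta=0.0, eps=1e-5):
    """Normalize one feature across a mini-batch, then scale and shift."""
    n = len(batch)
    mean = sum(batch) / n                              # step 1: mini-batch mean
    var = sum((x - mean) ** 2 for x in batch) / n      # step 2: mini-batch variance
    normalized = [(x - mean) / math.sqrt(var + eps)    # step 3: normalize
                  for x in batch]
    return [gamma * x_hat + beta for x_hat in normalized]  # step 4: scale and shift

activations = [2.0, 4.0, 6.0, 8.0]
out = batch_norm(activations)
out_mean = sum(out) / len(out)
out_var = sum((x - out_mean) ** 2 for x in out) / len(out)
```

With the default `gamma` and `beta`, the output has a mean of approximately zero and a variance of approximately one, whatever the scale of the input.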
Pros of Batch Normalization
- It is hugely successful in down-scaling the hidden activations when training deep ResNets (ResNet, ResNeXt, etc) consisting of 1000s of layers. How? Batch normalization, when used on the residual branch (as is typical), scales down the activations on that branch at initialization. This biases the signal towards the skip path, which ensures that the initial training is stable and leads to efficient optimization.
- It eliminates the mean shift effect that arises when activation functions like ReLU & GELU output non-zero mean activations. The main issue was that, as the network went deeper, the mean activation got larger and more positive, which led to the network predicting the same label for all training samples and made the training unstable. Batch normalization solves this by making sure that the mean activation on each channel is zero across the current batch.
- It has a good regularizing effect and can replace Dropout as a regularizer while training neural nets. Networks trained with Batch Normalization do not overfit easily and generalize well. This is mainly due to the noise in the batch statistics, which are computed on a subset of the training data. Also, the validation accuracy can be improved by tuning the batch size.
- It can train neural nets with bigger batch sizes and larger learning rates. Since the loss landscape is smoother, larger learning rates can be used while keeping the training stable. This benefit might not show up for smaller batch sizes. Also, it achieves the same accuracy in fewer steps compared to a non-batch normalized neural net, thereby improving the training speed.
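The mean shift point from the list above can be illustrated with a small experiment. This is only a sketch under simple assumptions (standard-normal pre-activations, a plain ReLU, a fixed random seed), not a reproduction of anything in the paper: zero-mean inputs come out of ReLU with a clearly positive mean, and subtracting the batch mean, as batch normalization does, removes it.

```python
import random

def relu(x):
    return max(0.0, x)

random.seed(0)
# Zero-mean, unit-variance pre-activations (an assumption for illustration).
pre_activations = [random.gauss(0.0, 1.0) for _ in range(10_000)]
activations = [relu(x) for x in pre_activations]

# ReLU discards the negative half, so the output mean is positive (roughly 0.4).
mean_after_relu = sum(activations) / len(activations)

# Subtracting the batch mean, as batch normalization does, re-centers the activations.
centered = [a - mean_after_relu for a in activations]
mean_after_centering = sum(centered) / len(centered)
```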
Cons of Batch Normalization
- It can be computationally intensive in a few networks due to the calculation of mean and scaling parameters and storing them for the back-prop step.
- There can be discrepancies between the behavior of the network at training time and at test time. During training, the network is normalized with statistics computed over each batch, which makes it dependent on that batch-wise setup. So it might not perform well when a single example is provided at inference.
- Batch normalization breaks the independence between examples within a batch. This means the choice of examples in a batch matters, which leads to two further problems: batch size matters, and distributed training becomes error-prone in a way that can let the network cheat the loss function.
- Batch size matters because in a small batch the estimate of the mean is going to be noisy, whereas in a bigger batch the estimate is more reliable. It has been observed that larger batches lead to stable and efficient training. Also, the performance of Batch Normalized networks can degrade if the batch statistics have a large variance during training.
- Coming to the second point (cheating loss functions, difficult distributed training, and difficulty replicating results on other hardware): when we distribute the training across parallel streams, each stream receives a portion, or shard, of a batch and runs a forward pass on it. The discrepancy arises when there is no communication between the Batch Normalization layers, i.e. each stream calculates its mean and variance independently. These statistics then describe only the shard rather than the entire batch, and because the examples in a shard still influence one another through them, the network can exploit this to cheat certain loss functions.
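The shard discrepancy described above is easy to see with numbers. The sketch below splits one hypothetical batch across two hypothetical devices and compares the per-shard statistics with those of the full batch; the values are made up purely for illustration.

```python
def mean_and_var(values):
    """Return the mean and (biased) variance of a list of numbers."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return mean, var

full_batch = [1.0, 2.0, 3.0, 4.0, 10.0, 12.0, 14.0, 16.0]
shards = [full_batch[:4], full_batch[4:]]   # one shard per hypothetical device

global_stats = mean_and_var(full_batch)            # (7.75, 30.6875)
shard_stats = [mean_and_var(s) for s in shards]    # [(2.5, 1.25), (13.0, 5.0)]
```

Neither shard's mean or variance matches the full batch, so unless the Batch Normalization layers synchronize their statistics across devices, each device normalizes its examples differently.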
Background for NF-Nets
From the above discussion on Pros and Cons, we understand that although Batch Normalization has been instrumental in training deep networks, there are major disadvantages, and the authors put forward the idea of heading towards a new direction of Neural Nets free of batch normalization. The way to achieve this was to replace batch norm by suppressing the hidden activations on the residual branch. The authors of this [Paper] implemented a normalizer-free ResNet by suppressing the residual branch at initialization and using Scaled Weight Standardization. Weight Standardization controls the first and second moments of the weights of each output channel individually in convolutional layers. Weight Standardization standardizes the weights in a differentiable way which aims to normalize gradients during back-propagation.
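To make Scaled Weight Standardization more concrete, here is a minimal sketch for a weight matrix with one row per output channel. It assumes the commonly cited formulation: subtract the per-channel mean, divide by the per-channel standard deviation times the square root of the fan-in, and multiply by a gain that depends on the activation function. The function name, the `gain` default of 1.0, and the `eps` value are illustrative assumptions.

```python
import math

def scaled_weight_standardization(weight, gain=1.0, eps=1e-4):
    """Standardize each output channel (row) of `weight` over its fan-in."""
    standardized = []
    for row in weight:
        fan_in = len(row)
        mean = sum(row) / fan_in                           # first moment
        var = sum((w - mean) ** 2 for w in row) / fan_in   # second (central) moment
        scale = gain / (math.sqrt(var + eps) * math.sqrt(fan_in))
        standardized.append([(w - mean) * scale for w in row])
    return standardized

weights = [[1.0, 2.0, 3.0, 4.0], [0.5, -0.5, 1.5, -1.5]]
standardized = scaled_weight_standardization(weights)
```

After standardization every row has zero mean and a squared norm close to `gain ** 2`, regardless of the scale of the raw weights. Because the operation is built from differentiable steps, it can be applied to the raw weights inside the forward pass during training.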
These nets (NF-ResNet) were able to match the accuracy of Batch Normalized ResNets but struggled with larger batch sizes and failed to match the current state-of-the-art EfficientNet. NF-Nets are built on top of this research work.
Main Contributions of the Research
- The authors propose a new method, Adaptive Gradient Clipping, which clips the gradient based on the unit-wise ratio of gradient norms to parameter norms, allowing the training of NF-Nets with larger batches and stronger data augmentations.
- It introduced a family of Normalizer-free ResNets, NF-Nets which surpass the results of the previous state-of-the-art architecture, EfficientNets. The largest NF-Net model achieved a top-1 accuracy of 86.5% (new state-of-the-art) without the use of extra data!
- It shows that the NF-Nets outperform the Batch Normalized networks in terms of validation accuracy when fine-tuned on ImageNet after large-scale pre-training on a dataset of 300 million labelled images. The top-1 accuracy post fine-tuning is 89.2%.
Let us understand the main highlight of the paper i.e. Adaptive Gradient Clipping
Adaptive Gradient Clipping
Before going into Adaptive Gradient Clipping, or AGC for short, what is gradient clipping? Gradient Clipping is a method that limits how large a gradient update can be, in either the positive or the negative direction. To put it in simple terms, we don’t want the optimizer to take big jumps while searching for a minimum, so we simply clip the gradient when it is too large. But we also have to accommodate the scenario where the gradient has to be large enough to escape a local minimum or correct its course while traversing the loss landscape. The intuition is that if a large gradient points in a good direction, we will see it again in later steps, whereas if it is a bad gradient, we want to limit its impact.
The authors hypothesize that gradient clipping will enable the training of NF-Nets with larger batch sizes and larger learning rates. The motivation comes from previous work in Language Modelling [Paper], where gradient clipping stabilized the training and allowed the usage of larger learning rates, thereby accelerating the training. The standard gradient clipping [Paper] is given by,

$$
G \rightarrow \begin{cases} \lambda \dfrac{G}{\lVert G \rVert} & \text{if } \lVert G \rVert > \lambda, \\ G & \text{otherwise,} \end{cases}
$$
where G is the gradient vector, G = ∂L/∂θ, L denotes the loss, and θ denotes a vector with all model parameters. The clip is applied before updating θ. The lambda (λ) is the clipping threshold, a hyperparameter that must be tuned. The authors observed that training stability was extremely sensitive to the chosen clipping threshold, which therefore required very fine-grained tuning. To counter this, they introduced Adaptive Gradient Clipping, or AGC.
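As a point of reference, standard clipping by the gradient norm can be sketched in a few lines. The gradient here is a flat list of numbers and the function name is illustrative; deep learning frameworks ship their own versions of this operation.

```python
import math

def clip_by_norm(grad, lam):
    """Rescale `grad` so that its L2 norm does not exceed the threshold `lam`."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm > lam:
        return [lam * g / norm for g in grad]   # keep the direction, shrink the length
    return list(grad)                           # small gradients pass through unchanged

large = clip_by_norm([3.0, 4.0], lam=1.0)   # norm 5 is rescaled to norm 1
small = clip_by_norm([0.3, 0.4], lam=1.0)   # norm 0.5 is left as-is
```

Note that the threshold is a single absolute number for the whole gradient, regardless of how large the weights are, which is one way to see why it needs careful tuning.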
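The gist of AGC is that it clips the gradient based on a particular ratio: the norm of the gradients G^l to the norm of the weights W^l for layer l (how large the gradient is relative to how large the weight it acts upon is). This ratio gives a simple measure of how much a single gradient step will change the original weights. Specifically, each unit i of the gradient of the l-th layer, G_i^l, is clipped as,

$$
G_i^{l} \rightarrow \begin{cases} \lambda \dfrac{\lVert W_i^{l} \rVert_F^{\star}}{\lVert G_i^{l} \rVert_F} \, G_i^{l} & \text{if } \dfrac{\lVert G_i^{l} \rVert_F}{\lVert W_i^{l} \rVert_F^{\star}} > \lambda, \\ G_i^{l} & \text{otherwise,} \end{cases}
\qquad \lVert W_i^{l} \rVert_F^{\star} = \max\left(\lVert W_i^{l} \rVert_F, \epsilon\right),
$$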
where epsilon = 10^-3 prevents zero-initialized parameters from always having their gradients clipped to zero. AGC enabled stable training of neural networks with batch sizes of up to 4096, which is quite a massive number for a batch. λ, the clipping threshold, is set depending on the optimizer, learning rate, and batch size. The below image depicts the scaling of NF-Nets for larger batch sizes.
Empirically, it was found that a lower clipping threshold is needed for larger batch sizes. Consider a large batch size, where the batch stats are not that noisy. Here the clipping threshold must be set low, or else training will collapse. Likewise, if the batch size is very small, there might be more noise in the batch stats and we can get away with a higher clipping threshold.
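The unit-wise clipping rule can be sketched as follows. This is a simplified illustration rather than the authors' implementation: weights and gradients are plain lists with one row per unit, the Frobenius norm of a row reduces to its L2 norm, and the function names are hypothetical. The `eps` default follows the value of 10^-3 quoted above; `lam` has no default because it depends on the training setup.

```python
import math

def unit_norm(row):
    """L2 norm of one unit's weights or gradients."""
    return math.sqrt(sum(v * v for v in row))

def adaptive_grad_clip(weights, grads, lam, eps=1e-3):
    """Clip each unit's gradient so that ||G_i|| / max(||W_i||, eps) <= lam."""
    clipped = []
    for w_row, g_row in zip(weights, grads):
        w_norm = max(unit_norm(w_row), eps)   # eps guards zero-initialized units
        g_norm = unit_norm(g_row)
        max_norm = lam * w_norm               # largest allowed gradient norm
        if g_norm > max_norm:
            clipped.append([g * max_norm / g_norm for g in g_row])
        else:
            clipped.append(list(g_row))
    return clipped

weights = [[3.0, 4.0], [3.0, 4.0], [0.0, 0.0]]
grads = [[0.03, 0.04], [6.0, 8.0], [1.0, 0.0]]
clipped = adaptive_grad_clip(weights, grads, lam=0.1)
```

In this example the first unit's gradient is small relative to its weights and passes through unchanged, the second is rescaled so that its norm is 0.1 times the weight norm, and the third (a zero-initialized unit) is still allowed a small non-zero update thanks to `eps`.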
Coming to the architecture, NF-Net is a modified version of SE-ResNeXt-D [Paper]. The model has an initial “stem” composed of a 3×3 stride-2 convolution with 16 channels, two 3×3 stride-1 convolutions with 32 channels and 64 channels respectively, and a final 3×3 stride-2 convolution with 128 channels. The activation function used here is GELU [Paper], which performs as well as ReLU & SiLU.
The Residual stages consist of two types of blocks, starting with a transitional block and followed by the non-transitional block. All blocks employ the pre-activation ResNe(X)t bottleneck pattern with an added 3×3 grouped convolution inside the bottleneck. Following the last 1×1 convolution is a Squeeze & Excite layer which globally average pools the activation, applies two linear layers with an interleaved scaled non-linearity to the pooled activation, applies a sigmoid, then rescales the tensor channel-wise by twice the value of this sigmoid.
After all of the residual stages, we apply a 1×1 expansion convolution that doubles the channel count. This layer is primarily helpful when using very thin networks, as it is typically desirable to have the dimensionality of the final activation vectors (which the classifier layer receives) be greater than or equal to the number of classes. Coming to the final layer, it outputs a 1000-way class vector with learnable biases. This layer has its weights initialized to 0.01 and not 0.
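The important feature of NF-Nets is “Normalizer Free”. The input to the main path of the residual block is multiplied by 1/β, where β is the predicted standard deviation of the inputs to that block at initialization, and the output of the block is multiplied by a scalar hyperparameter α, typically set to a small value like α = 0.2. These scalars α and β are instrumental in achieving the Normalizer Free implementation. The formulae are given below,

$$
h_{i+1} = h_i + \alpha \, f_i\!\left(h_i / \beta_i\right), \qquad \beta_i = \sqrt{\mathrm{Var}(h_i)}, \qquad \mathrm{Var}(h_{i+1}) = \mathrm{Var}(h_i) + \alpha^2,
$$

where h_i is the input to the i-th residual block and f_i is the function computed by its residual branch.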
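The interplay of α and β can be sketched in a few lines. The sketch assumes that the residual branch preserves variance at initialization, so that the variance of the signal grows by α² after every block; this recurrence and the helper names are assumptions made for illustration, and the residual branch is passed in as an arbitrary function.

```python
import math

def expected_betas(num_blocks, alpha=0.2, initial_var=1.0):
    """Predict the input standard deviation (beta) of each block at initialization."""
    betas, var = [], initial_var
    for _ in range(num_blocks):
        betas.append(math.sqrt(var))   # beta_i = sqrt(Var(h_i))
        var += alpha ** 2              # assumed: Var(h_{i+1}) = Var(h_i) + alpha^2
    return betas

def nf_residual_block(h, residual_fn, alpha, beta):
    """Compute h + alpha * f(h / beta) element-wise."""
    scaled = [x / beta for x in h]     # bring the branch input back to unit variance
    return [x + alpha * r for x, r in zip(h, residual_fn(scaled))]

betas = expected_betas(3, alpha=0.2)   # [1.0, sqrt(1.04), sqrt(1.08)]
out = nf_residual_block([2.0, 4.0], lambda v: v, alpha=0.5, beta=2.0)   # [2.5, 5.0]
```

Dividing by β keeps the input of every residual branch at roughly unit variance, while the small α limits how quickly the variance on the skip path grows, which is the job batch normalization would otherwise be doing.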
The model and its configurations are quite complex and going through the paper will help if you’re replicating/implementing the NF-Nets. The experiments also are a tad lengthy and to keep it short and crisp in this blog, I have not included the experiments and ablation studies.
Benchmarks and Results
As we see, NF-Nets outperform their Batch Normalized counterparts. The NF-Net-F5 model achieved a top-1 validation accuracy of 86%, improving over the previous state-of-the-art results and making NF-Net the new state-of-the-art network, while the NF-Net-F1 model matches EfficientNet-B7’s score of 84.7% (all depicted in the below table).
Problems which may not yet be solved
- The training examples are still implicitly dependent on the batch if we observe how the implementation is done: the clipping is applied after the gradients have been averaged over the batch.
- The observation of different behaviors in train and test time was quoted as one of the problems of Batch Normalization. But here, in the implementation, they have used Dropout, which also has different behavior in train and test times.
To conclude, there were a lot of things going on in this paper, but a few were especially significant for the successful development of Normalizer Free Neural Nets. The introduced NF-Nets surpassed the performance of the latest state-of-the-art for Image Classification (without using extra data) and were also faster to train. It was also shown that the family of NF-Nets is better suited than the batch normalized variants to fine-tuning after pre-training on large datasets.
- NF-Net: https://arxiv.org/pdf/2102.06171v1.pdf
- EfficientNet: https://arxiv.org/pdf/1905.11946.pdf
- Batch Normalization: https://arxiv.org/pdf/1502.03167.pdf
- ImageNet: http://www.image-net.org/
- ResNet: https://arxiv.org/pdf/1512.03385.pdf
- ReLU: https://arxiv.org/pdf/1811.03378.pdf
- GELU: https://arxiv.org/pdf/1606.08415.pdf
- Unnormalized ResNet: https://arxiv.org/pdf/2101.08692.pdf
- Regularizing and Optimizing LSTM Language Models: https://arxiv.org/pdf/1708.02182.pdf
- Gradient Clipping & Regularization: https://arxiv.org/pdf/1211.5063.pdf
- SE-ResNeXt-D (Squeeze & Excitation Networks): https://arxiv.org/pdf/2102.06171v1.pdf