Background: The Importance of Trainability in Neural Pruning
Neural network pruning, a method to improve computational efficiency, has gained significant traction in recent years. Its goal is to remove redundant parameters from a neural network without substantially degrading performance. The process typically has three phases: pre-training a dense model, pruning unnecessary connections to obtain a sparse model, and retraining the sparse model to recover performance. Pruning methods fall into two primary categories: unstructured pruning, which removes individual weights, and structured (filter) pruning, which removes whole filters or channels. Structured pruning is better aligned with modern architectures such as ResNets because it produces networks that are genuinely faster on standard hardware, not merely smaller.
A key observation behind this work is the crucial role of trainability. Pruning damages trainability, and if that damage is left unattended, the sparse network under-performs and its final accuracy becomes unusually sensitive to the retraining learning rate, which can bias comparisons between methods. In essence, trainability measures how effectively the network can still learn after pruning.
Method: Introducing Trainability Preserving Pruning (TPP)
Trainability Preserving Pruning (TPP) is proposed to address this problem. TPP is a filter pruning algorithm designed to maintain trainability through a regularized training process. The method decouples the pruned (unimportant) filters from the retained (important) filters, minimizing the dependencies that typically impair trainability after pruning.
TPP leverages two main strategies:
- Regularizing the weight gram matrix to encourage zero correlation between pruned and kept filters. This approach avoids over-penalizing important weights, which could otherwise lead to optimization issues and suboptimal training.
- Incorporating a Batch Normalization (BN) regularizer. Because BN parameters are themselves trainable, removing them can significantly harm trainability. TPP therefore also regularizes the learnable BN parameters, mitigating this detrimental effect (a minimal sketch follows this list).
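The paper does not prescribe the exact code below; it is a minimal PyTorch sketch of one plausible form of the BN regularizer, assuming it pushes the learnable scale (gamma) and shift (beta) of the to-be-pruned channels toward zero so those channels can later be removed with little disruption. The names `bn_prune_regularizer`, `pruned_idx`, and `bn_reg_strength` are illustrative, not from the official repository.

```python
import torch
import torch.nn as nn

def bn_prune_regularizer(bn: nn.BatchNorm2d, pruned_idx: torch.Tensor) -> torch.Tensor:
    """Penalize the learnable BN scale (gamma) and shift (beta) of the channels
    marked for pruning, driving them toward zero (illustrative form, not the
    official TPP implementation)."""
    gamma = bn.weight[pruned_idx]  # learnable per-channel scale of pruned channels
    beta = bn.bias[pruned_idx]     # learnable per-channel shift of pruned channels
    return (gamma ** 2).sum() + (beta ** 2).sum()

# Usage: add to the task loss with a small coefficient, e.g.
#   loss = task_loss + bn_reg_strength * bn_prune_regularizer(bn, pruned_idx)
```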
Regularizing the weight gram matrix to encourage zero correlation between pruned and kept filters is the core mechanism by which TPP maintains trainability during pruning. Let's break this process down in detail:
Understanding the Weight Gram Matrix
- What is a Weight Gram Matrix?
- A weight gram matrix is a matrix representation that captures the correlations between different sets of weights or filters in a neural network. In this context, it's used to understand how closely related different filters are within the network.
- Purpose of Regularizing the Weight Gram Matrix
- The goal of regularizing the gram matrix is to reduce the dependencies between the filters that will be pruned (deemed unimportant) and those that will be kept (important). By weakening these dependencies, the network can be pruned without significantly impairing its ability to learn, i.e., its trainability is preserved (see the sketch after this list).
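To make the definition concrete, the short sketch below builds the gram matrix of a convolutional layer in PyTorch. The layer and shapes are illustrative only: each filter is flattened into a row vector, and entry (i, j) of the resulting matrix is the inner product of filters i and j, an (unnormalized) measure of how correlated they are.

```python
import torch.nn as nn

# Illustrative conv layer: 8 output filters, 3 input channels, 3x3 kernels.
conv = nn.Conv2d(3, 8, kernel_size=3)

# Flatten each filter into a row vector: shape (num_filters, fan_in) = (8, 27).
W = conv.weight.reshape(conv.weight.shape[0], -1)

# Gram matrix: entry (i, j) is the inner product of filter i and filter j.
gram = W @ W.t()          # shape (8, 8)

print(gram.shape)         # torch.Size([8, 8])
print(gram[0, 1].item())  # correlation-like score between filters 0 and 1
```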
Regularization Process
- Sorting Filters Based on Importance
- Filters in a layer are sorted based on their L1 norms. Filters with the smallest norms are considered less important and are candidates for pruning.
- Decorrelation Strategy
- The main idea is to adjust the gram matrix of weights such that the correlation entries between the pruned and retained filters approach zero. This means that the pruned filters have minimal influence on the kept filters, thus preserving the trainability of the network.
- Regularization Formula
- The regularization term is given by: \[L_{\mathrm{reg}} = \left \| W_l W_l^{\top} \odot (1 - mm^{\top}) \right \|_F^2\]
- where
- \(L_{\mathrm{reg}}\) is the regularization term added to the training loss.
- \(W_l\) is the weight matrix of the \(l\)-th layer, with one row per filter.
- \(W_l^{\top}\) is the transpose of \(W_l\).
- \(\odot\) denotes the Hadamard (element-wise) product.
- \(m\) is a binary mask vector with 0s for pruned filters and 1s for retained filters.
- \(mm^{\top}\) is the outer product of \(m\) with itself, a mask matrix that equals 1 only where both filters are kept; \(1 - mm^{\top}\) therefore selects every entry involving at least one pruned filter.
- \(\| \cdot \|_F^2\) is the squared Frobenius norm. A code sketch of this computation follows below.
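Below is a minimal PyTorch sketch of how this term could be computed for a single convolutional layer, combining the L1-norm ranking and the mask described above. The function name, the way the mask is built, and the pruning-ratio argument are illustrative assumptions; the official repository (linked at the end) is the reference implementation.

```python
import torch
import torch.nn as nn

def tpp_decorrelation_term(conv: nn.Conv2d, prune_ratio: float) -> torch.Tensor:
    """Sketch of the decorrelation regularizer for one conv layer: penalize the
    gram-matrix entries that involve at least one to-be-pruned filter."""
    # One row per output filter: shape (n_filters, fan_in).
    W = conv.weight.reshape(conv.weight.shape[0], -1)
    n_filters = W.shape[0]
    n_pruned = int(n_filters * prune_ratio)

    # Rank filters by L1 norm; the smallest ones are the pruning candidates.
    l1_norms = W.abs().sum(dim=1)
    pruned_idx = torch.argsort(l1_norms)[:n_pruned]

    # Binary mask m: 0 for pruned filters, 1 for retained filters.
    m = torch.ones(n_filters, device=W.device)
    m[pruned_idx] = 0.0

    # (1 - m m^T) is 1 wherever at least one of the two filters is pruned,
    # so only those gram entries are pushed toward zero.
    mask_matrix = 1.0 - torch.outer(m, m)

    gram = W @ W.t()                          # (n_filters, n_filters)
    return ((gram * mask_matrix) ** 2).sum()  # squared Frobenius norm
```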
- Practical Implementation
- During training, this regularization term is added to the loss function. The effect is to penalize correlations between pruned and retained filters, encouraging the network to shrink these correlations as part of the optimization (see the training-loop sketch below).
- Avoiding Over-Penalization
- Unlike methods that strive for perfect orthogonality (where all off-diagonal entries of the gram matrix are zero), TPP imposes a weaker constraint: it only decorrelates the pruned filters from the kept ones. This keeps gradients flowing effectively through the model while avoiding strict orthogonality, which can be overly restrictive and lead to suboptimal training dynamics.
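Schematically, the (re)training loop then looks like the sketch below, which reuses the `tpp_decorrelation_term` function defined earlier. The model, data loader, layer list, `reg_strength`, and hyperparameter values are placeholders; the official implementation may schedule or weight the penalty differently.

```python
import torch

def train_with_tpp(model, train_loader, conv_layers, prune_ratio=0.5,
                   reg_strength=1e-4, lr=0.01, epochs=1):
    """Schematic (re)training loop: task loss plus the decorrelation penalty
    summed over the conv layers targeted for pruning (placeholder hyperparameters)."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    for _ in range(epochs):
        for images, labels in train_loader:
            task_loss = criterion(model(images), labels)

            # tpp_decorrelation_term is the sketch defined earlier in this post.
            reg_loss = sum(tpp_decorrelation_term(conv, prune_ratio)
                           for conv in conv_layers)

            loss = task_loss + reg_strength * reg_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```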
Experiments
Extensive experiments have been conducted to validate TPP's effectiveness. These studies spanned various networks and datasets, including MLP-7-Linear on MNIST, ResNet56 and VGG19 on CIFAR10/100, and ResNet34 and ResNet50 on ImageNet-1K. The results were promising, showcasing TPP's superior performance compared to other state-of-the-art (SOTA) pruning methods, especially in scenarios of high sparsity. This performance boost is attributed to TPP's ability to maintain trainability even under aggressive pruning conditions.
Specifically, on the ImageNet-1K benchmark, TPP consistently outperformed other methods across different speedup ratios. For instance, TPP showed better resilience to performance drop when compared to methods like Taylor-FO under varying speedup conditions.
GitHub Code:
MingSun-Tse/TPP: [ICLR'23] Trainability Preserving Neural Pruning (PyTorch), https://github.com/MingSun-Tse/TPP