Abstract
- Neural networks need plasticity to be adaptable and robust.
- The mechanisms behind plasticity loss are not well understood.
- This paper looks into plasticity loss to understand it better.
- Plasticity loss is connected to changes in the loss landscape.
- Saturated units and divergent gradient norms do not, on their own, explain plasticity loss.
- Layer normalization can help preserve plasticity.
Paper Content
Introduction
- A neural network trained to fit a sequence of different learning objectives can lose some of its ability to solve new tasks
- When inputs and prediction targets change over time, networks must learn to ‘overwrite’ prior predictions
- Deep reinforcement learning agents are trained in a way that causes plasticity loss
- Understanding and mitigating plasticity loss is important for developing deep RL agents
- Existing methods act on potential mechanisms of plasticity loss
- Loss landscape curvature is a crucial factor in determining plasticity
Background
- Training a network on one task and then a second can reduce performance on the first task (catastrophic forgetting).
- Training a network on a series of tasks can result in worse performance on later tasks than a randomly initialized network trained on those tasks from scratch (plasticity loss).
Preliminaries
- TD learning is a form of reinforcement learning
- TD learning uses networks trained with sampled transitions from an agent’s interaction with the environment
- Loss landscape analysis studies the structure of the loss landscape traversed by an optimization algorithm
- The Hessian of a network is used to measure the sharpness of the loss landscape
- The gradient covariance is used to measure interference between inputs
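The two diagnostics in the last two bullets can be sketched in a few lines. This is a minimal, dependency-free illustration, not the paper's implementation: it assumes the gradient covariance is normalized to cosine similarities between per-example gradients, and it assumes sharpness is read off as the largest-magnitude Hessian eigenvalue, estimated here by power iteration on an explicit matrix. The function names and the example values are hypothetical.

```python
import math

def gradient_covariance(grads):
    """Matrix of cosine similarities between per-example gradient vectors."""
    norms = [math.sqrt(sum(g * g for g in grad)) for grad in grads]
    n = len(grads)
    cov = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            dot = sum(a * b for a, b in zip(grads[i], grads[j]))
            denom = norms[i] * norms[j]
            cov[i][j] = dot / denom if denom > 0 else 0.0
    return cov

def top_eigenvalue(matrix, iters=100):
    """Largest-magnitude eigenvalue of a symmetric matrix via power iteration."""
    n = len(matrix)
    v = [1.0 / math.sqrt(n)] * n  # assumes this start is not orthogonal to the top eigenvector
    for _ in range(iters):
        w = [sum(matrix[i][j] * v[j] for j in range(n)) for i in range(n)]
        norm = math.sqrt(sum(x * x for x in w))
        if norm == 0.0:
            return 0.0
        v = [x / norm for x in w]
    av = [sum(matrix[i][j] * v[j] for j in range(n)) for i in range(n)]
    return sum(a * b for a, b in zip(av, v))  # Rayleigh quotient with unit-norm v

# Hypothetical per-example gradients: three inputs, two parameters.
example_grads = [[1.0, 0.0], [0.0, 2.0], [-3.0, 0.0]]
C = gradient_covariance(example_grads)

# Hypothetical Hessian of a two-parameter quadratic loss.
sharpness = top_eigenvalue([[2.0, 0.0], [0.0, 1.0]])
```

In this toy case, inputs 0 and 2 have opposed gradients (an update that helps one hurts the other), while inputs 0 and 1 do not interfere at all.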
Defining plasticity
- Plasticity has been studied in neuroscience for decades and is a recent topic of interest in deep learning.
- Classical notions of complexity evaluate whether a hypothesis class contains functions that capture arbitrary patterns, but don’t consider the ability of a search algorithm to find these functions.
- Plasticity is a problem-dependent property that captures the interaction between the network state, optimization process, and training data.
- Capacity is a fixed property of a network architecture.
- Plasticity is measured by the ability of a network to update its predictions in response to a wide array of possible learning signals.
- Plasticity is defined as the difference between the baseline and the expectation of the final loss obtained by an optimization process.
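The definition in the last bullet can be written compactly. The symbols below are notation assumed here for illustration: $\theta$ is the network's current parameters, $b$ the baseline, $\mathcal{L}$ a distribution over learning objectives $\ell$, and $\mathrm{OPT}$ the optimization process that returns the final parameters.

```latex
\mathcal{P}(\theta) \;=\; b \;-\; \mathbb{E}_{\ell \sim \mathcal{L}}\!\left[\, \ell\big(\theta^{*}_{\ell}\big) \right],
\qquad
\theta^{*}_{\ell} \;=\; \mathrm{OPT}(\theta, \ell)
```

Under this reading, a higher value means the network can still reach low loss on new objectives when starting from $\theta$.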
Two simple studies on plasticity
- Plasticity loss can occur even in simple learning problems.
- Optimizer design can interact with nonstationarity to cause instabilities.
Optimizer instability and non-stationarity
- Deep learning methods have been widely adopted due to the robustness of existing optimizers.
- Adam optimizer can yield reasonable initial results, but can experience catastrophic divergence when assumptions on stationarity no longer hold.
- An example of this is shown in Figure 1, where a two-hidden-layer fully-connected neural network is trained to memorize random labels of MNIST images.
- Increasing ε and setting a more aggressive decay rate for the second-moment estimate can avoid catastrophic instability.
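One plausible way to see why these two Adam settings matter is to track the size of the normalized update when gradients suddenly grow, as they might when the task changes. The sketch below is a scalar Adam with the standard bias correction. The gradient stream and the tuned values (`beta2=0.9`, `eps=1e-3`) are hypothetical choices for illustration, not the settings used in the paper.

```python
import math

def adam_step_sizes(grads, beta1=0.9, beta2=0.999, eps=1e-8):
    """Return |m_hat / (sqrt(v_hat) + eps)| for each step of scalar Adam."""
    m = v = 0.0
    sizes = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g          # first-moment estimate
        v = beta2 * v + (1 - beta2) * g * g      # second-moment estimate
        m_hat = m / (1 - beta1 ** t)             # bias correction
        v_hat = v / (1 - beta2 ** t)
        sizes.append(abs(m_hat) / (math.sqrt(v_hat) + eps))
    return sizes

# Hypothetical gradient stream: tiny gradients for a long time, then a jump.
grads = [1e-4] * 5000 + [1.0] * 20

# Largest normalized update during the 20 steps after the jump.
default_peak = max(adam_step_sizes(grads)[5000:])
tuned_peak = max(adam_step_sizes(grads, beta2=0.9, eps=1e-3)[5000:])
```

With the default decay rates, the first-moment estimate reacts to the jump much faster than the second-moment estimate, so the normalized update temporarily exceeds its usual magnitude of about 1 several times over. With a faster second-moment decay the peak stays below 1 in this toy stream.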
Loss landscape evolution under non-stationarity
- Prior work observes reductions in network plasticity
- Causes of this phenomenon are difficult to determine
- Experiment compares evolution of two coupled updating procedures
- Hessian eigenvalue distribution and gradient covariance structure are evaluated
- Gradient descent induces bias that pushes parameters towards regions of parameter space with less friendly loss landscape
- Plausible explanations of plasticity loss do not identify robust causal relationships
- Plasticity loss may arise due to changes in network’s loss landscape
Experimental setting
- Constructed a simple MDP analogue of image classification
- Constructed three variants of a block MDP with state space from CIFAR-10 or MNIST image dataset
- Reward and transition dynamics depend on action taken by agent
- Reward functions allow comparison of tasks aligned with network’s inductive bias
- Trained DQN agents on each environment-observation space combination
- Evaluated ability of each network to fit randomly generated set of target functions
- Updated target network every 1,000 steps
- Logged loss after 2,000 steps of optimization
- Considered two network architectures: MLP and CNN
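The probing step in the list above (fit randomly generated targets from a checkpoint, then log the loss after a fixed number of steps) might look roughly like this. It is a simplified sketch under stated assumptions: a linear model stands in for the MLP and CNN, targets are drawn uniformly from [-1, 1], and the learning rate, baseline, and function names are hypothetical. Only the 2,000-step budget is taken from the text.

```python
import random

def predict(w, x):
    return sum(wi * xi for wi, xi in zip(w, x))

def final_loss_after_fit(weights, inputs, targets, steps=2000, lr=0.05):
    """Run gradient descent on mean squared error from `weights`; return the final loss."""
    w = list(weights)
    n = len(inputs)
    for _ in range(steps):
        grad = [0.0] * len(w)
        for x, y in zip(inputs, targets):
            err = predict(w, x) - y
            for k, xk in enumerate(x):
                grad[k] += 2.0 * err * xk / n
        w = [wi - lr * gk for wi, gk in zip(w, grad)]
    return sum((predict(w, x) - y) ** 2 for x, y in zip(inputs, targets)) / n

def estimate_plasticity(weights, inputs, num_targets=5, baseline=0.0, seed=0):
    """Baseline minus the mean final loss over randomly generated target sets."""
    rng = random.Random(seed)
    losses = []
    for _ in range(num_targets):
        targets = [rng.uniform(-1.0, 1.0) for _ in inputs]
        losses.append(final_loss_after_fit(weights, inputs, targets))
    return baseline - sum(losses) / len(losses)

# Distinct features: any target set can be fit, so the estimate is close to 0.
distinct = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# Collapsed features: all inputs look identical, so random targets cannot be fit.
collapsed = [[1.0, 0.0, 0.0]] * 3

p_distinct = estimate_plasticity([0.0, 0.0, 0.0], distinct)
p_collapsed = estimate_plasticity([0.0, 0.0, 0.0], collapsed)
```

For a linear model the starting weights barely matter, so this example varies the features instead. For a deep network, the starting checkpoint is exactly what the probe is meant to evaluate.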
Falsification of prior hypotheses
- Prior work has proposed explanations for why neural networks may exhibit reduced ability to fit new targets over time
- Explanatory power of these hypotheses has not been rigorously tested
- To test, 128 DQN agents trained under a range of tasks, observation spaces, optimizers, and seeds
- Logged several statistics of the parameters and activations, along with the plasticity of the parameters
- Scatterplots show relationship between plasticity and each statistic
- Correlation between plasticity loss and the quantity of interest is nonexistent or weak
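The analysis above amounts to computing a statistic per trained agent and checking how well it tracks plasticity. The sketch below assumes one such statistic, the fraction of ReLU units that are inactive on every input, and a plain Pearson correlation. The statistics the authors actually logged are not listed in this summary, and both function names are hypothetical.

```python
import math

def dead_unit_fraction(activations):
    """Fraction of units with zero post-ReLU activation on every input.

    `activations[i][j]` is the activation of unit j on input i.
    """
    num_units = len(activations[0])
    dead = sum(
        1 for j in range(num_units)
        if all(row[j] <= 0.0 for row in activations)
    )
    return dead / num_units

def pearson(xs, ys):
    """Pearson correlation coefficient; returns 0.0 if either input is constant."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy) if sx > 0 and sy > 0 else 0.0

# Hypothetical activations: two inputs, two units; unit 0 never fires.
frac = dead_unit_fraction([[0.0, 1.0], [0.0, 0.5]])
```

In use, one would collect `frac` (or any other statistic) and the plasticity estimate for each agent, then call `pearson` on the two lists. A coefficient near zero corresponds to the weak or nonexistent relationship reported above.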
Loss landscape evolution during training
- Plasticity loss is characterized by reduced ability to fit arbitrary new targets
- Learning curves convey ease or difficulty of navigating the loss landscape
- Early training checkpoints quickly attain low losses, but learning curves have increasing variance
- Increasing difficulty of navigating the loss landscape drives plasticity loss
- Scaling architecture reduces plasticity loss, but cannot completely eliminate it
Solutions
- Neural networks can lose plasticity when classifying MNIST digits
- Section 5.1 evaluates whether scaling can reduce plasticity loss
- Section 5.2 evaluates the effect of interventions on plasticity
- Section 5.3 tests the findings on larger-scale tasks
The role of scaling on plasticity
- Plasticity loss is a challenge that may be solved by increasing the size of a network.
- Scaling a CNN to the limit of a single GPU’s memory is not enough to eliminate plasticity loss.
- Plasticity loss is unlikely to be the limiting factor for sufficiently large networks on simple tasks.
Interventions in toy problems
- Evaluated effect of interventions on plasticity loss
- Task used: 100 iterations of 1000 steps
- 4 architectures: MLP, CNN, ResNet-18, Vision Transformer
- Interventions: reset last layer, reset optimizer, layer normalization, Shrink and Perturb, spectral normalization, weight decay
- Smoothing out loss landscape most effective for preserving plasticity
- Two-hot encoding reduces plasticity loss but affects stability of policy
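Three of the techniques named in this list are simple enough to sketch directly. These are minimal, dependency-free versions written under assumptions: layer normalization is shown without the learnable scale and offset, Shrink and Perturb is shown as scaling the parameters and adding Gaussian noise with hypothetical coefficients, and two-hot encoding spreads a scalar target over its two nearest bins. None of the default values come from the paper.

```python
import math
import random

def layer_norm(x, eps=1e-5):
    """Normalize one activation vector to zero mean and (approximately) unit variance."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def shrink_and_perturb(params, shrink=0.8, noise_scale=0.01, rng=None):
    """Scale parameters toward zero, then add Gaussian noise (coefficients are assumptions)."""
    rng = rng or random.Random(0)
    return [shrink * p + noise_scale * rng.gauss(0.0, 1.0) for p in params]

def two_hot(value, bins):
    """Encode a scalar as weights on the two adjacent bins that bracket it.

    `bins` must be sorted in increasing order; values outside the range are clipped.
    """
    v = min(max(value, bins[0]), bins[-1])
    encoding = [0.0] * len(bins)
    for i in range(len(bins) - 1):
        lo, hi = bins[i], bins[i + 1]
        if lo <= v <= hi:
            upper = (v - lo) / (hi - lo)
            encoding[i] = 1.0 - upper
            encoding[i + 1] = upper
            break
    return encoding

normed = layer_norm([1.0, 2.0, 3.0])
shrunk = shrink_and_perturb([1.0, -2.0], shrink=0.5, noise_scale=0.0)
encoded = two_hot(0.25, [0.0, 1.0, 2.0])
```

Layer normalization makes the output nearly insensitive to the scale of its input, which is one reason it is often described as smoothing the optimization problem. The two-hot encoding turns a regression target into a categorical one, so the network can be trained with a cross-entropy loss instead of squared error.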
Application to larger benchmarks
- Layer normalization improves performance on large-scale benchmarks
- Double DQN, RMSProp, ε-greedy exploration, and frame stacking are used
- Layer normalization reduces plasticity loss and improves robustness of RL agents
- Layer normalization induces weaker gradient correlation in environments where it significantly improves performance
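For readers unfamiliar with the agent components named above, the sketch below shows the standard forms of ε-greedy action selection and the Double DQN bootstrap target: the online network picks the next action and the target network evaluates it. The function names, the discount value, and the example numbers are hypothetical, and this is not the benchmark agent's actual code.

```python
import random

def epsilon_greedy(q_values, epsilon, rng):
    """With probability epsilon pick a random action, otherwise the greedy one."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])

def double_dqn_target(reward, done, online_next_q, target_next_q, gamma=0.99):
    """Double DQN target: select the action with the online network,
    evaluate it with the target network."""
    if done:
        return reward
    best_action = max(range(len(online_next_q)), key=lambda a: online_next_q[a])
    return reward + gamma * target_next_q[best_action]

# Hypothetical Q-values for a two-action problem.
action = epsilon_greedy([0.1, 0.9], epsilon=0.0, rng=random.Random(0))
target = double_dqn_target(1.0, False, [0.1, 0.9], [5.0, 2.0], gamma=0.5)
```

Note that the target uses the value 2.0 (the target network's estimate for the action the online network prefers) rather than 5.0, the target network's own maximum. This decoupling is what distinguishes Double DQN from standard DQN.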
Related work
- Initialization and architecture design can affect gradients in neural networks
- ResNets bias each layer’s mapping towards the identity function to improve gradients
- Mean-field analysis, information propagation, and deep kernel shaping have been used to study trainability
- Loss landscape smoothness affects generalization and performance
- Early training periods can be chaotic and related to linear mode connectivity
- Loss of plasticity can limit deep reinforcement learning
- Resetting and distillation can affect performance
Conclusions
- Divide between curriculum learning and foundation models
- Plasticity loss is not a limiting factor in network performance for small environments
- Stabilizing the loss landscape is a crucial step towards promoting plasticity
- Smoother loss landscape is easier to optimize and has better generalization
- Optimizer instability studied on MNIST dataset
- Brownian motion studied on easy classification MDP with MNIST observation space
- MLP, CNN, ResNet, and Vision Transformer architectures studied
- Interventions studied: reset last layer, weight decay, spectral normalization, layer normalization, Shrink and Perturb
- Categorical output representation found to have beneficial effects
- Interventions interfere with learning on primary task
- Learning curves of different networks studied
- Gradient covariance and Hessian eigenvalue density studied
- Positive and negative correlations between variables and plasticity found