Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Presents Recurrent Interface Networks (RINs), a neural network architecture that allocates computation adaptively based on the input
- Hidden units of RINs are partitioned into the interface and latents
- RIN block selectively reads from the interface into latents for high-capacity processing
- Stacking multiple blocks enables effective routing across local and global levels
- Latent self-conditioning technique “warm-starts” the latents at each iteration of the generation process
- RINs yield state-of-the-art image and video generation without cascades or guidance
- Up to 10$\times$ more efficient compared to specialized 2D and 3D U-Nets
Paper Content
Introduction
- Design of effective neural network architectures is important for deep learning
- Convolutional neural networks and Transformers are examples of architectures
- Computation is usually allocated in a fixed, uniform manner
- It is important to allocate computation in an adaptive manner to improve scalability
- Prior work has explored dynamic and input-decoupled computation
- Generating high-dimensional data such as images and videos calls for adaptive computation
- Recurrent Interface Networks (RINs) are a new architecture that allocates computation more effectively
- RINs outperform U-Net architectures for image and video generation
- Latent self-conditioning is proposed to reduce the cost of routing
- RINs lead to significant performance and efficiency gains in diffusion models
Method
Overview
- RINs use tokenization to connect the interface to the input space and learnable embeddings to initialize the latents.
- RINs route information between the interface and latents.
- The interface grows linearly with input size, while the number of latents is much smaller.
- RINs are more efficient than U-Nets, Transformers, and other decoupled architectures.
- RINs are especially useful in recurrent settings.
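As a rough illustration of why the interface grows linearly with input size while the latents do not, the sketch below tokenizes images into non-overlapping patches with NumPy. The function name `patchify`, the patch size of 8, and the latent count of 128 are assumptions chosen for illustration, not values from the paper.

```python
import numpy as np

def patchify(images: np.ndarray, patch: int) -> np.ndarray:
    """Split (B, H, W, C) images into non-overlapping patch tokens.

    Returns an array of shape (B, N, patch * patch * C), where
    N = (H // patch) * (W // patch) is the number of interface tokens.
    """
    b, h, w, c = images.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch"
    gh, gw = h // patch, w // patch
    x = images.reshape(b, gh, patch, gw, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, gh, gw, patch, patch, C)
    return x.reshape(b, gh * gw, patch * patch * c)

NUM_LATENTS = 128  # hypothetical fixed latent count, independent of image size

rng = np.random.default_rng(0)
for side in (64, 128):
    imgs = rng.normal(size=(2, side, side, 3))
    tokens = patchify(imgs, patch=8)
    # 64x64 input -> 64 interface tokens; 128x128 input -> 256 interface tokens
    print(side, tokens.shape[1], NUM_LATENTS)
```

Quadrupling the pixel count quadruples the number of interface tokens, while the latent count stays wherever it was set. In a full model each patch token would also pass through a learned linear embedding and receive a positional embedding; both are omitted here.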
Iterative generation with diffusion models
- Diffusion models learn a series of state transitions to map noise from a known prior distribution to data
- A forward transition from $x_0$ to $x_t$ is defined
- A neural net is learned to predict the noise $\epsilon$ from $x_t$; $x_{t-\Delta}$ is then estimated from the predicted $\tilde{\epsilon}$ and $x_t$
- Samples are generated by iteratively applying the denoising function
- The network takes as input a noisy image $x_t$, a time step $t$, and an optional conditioning variable
- The interface is initialized from an input x, such as an image or video
- Latents are initialized as learned embeddings
- The RIN block routes information by reading from the interface X into the latents Z, processing Z, and writing updates back to X
- Multi-layer perceptrons (MLP) and multi-head attention (MHA) are used to process information
- Readout layer is applied to the corresponding interface tokens to predict local outputs
- Local outputs are combined to form the desired output
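The read, process, and write routing described above can be sketched in NumPy as follows. This is a minimal sketch under several assumptions, not the authors' implementation: the interface and latents share one width `d`, layer norms have no learned parameters, the MLP uses ReLU for brevity, and all names (`rin_block`, `init_block`, the parameter dictionary layout) are hypothetical. The readout layer is omitted.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention(q_in, kv_in, p, heads):
    """Multi-head attention: queries from q_in, keys and values from kv_in."""
    b, nq, d = q_in.shape
    nk, dh = kv_in.shape[1], d // heads
    def split(t, n):  # (B, n, d) -> (B, heads, n, dh)
        return t.reshape(b, n, heads, dh).transpose(0, 2, 1, 3)
    q, k, v = split(q_in @ p["wq"], nq), split(kv_in @ p["wk"], nk), split(kv_in @ p["wv"], nk)
    w = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(dh))  # (B, heads, nq, nk)
    return (w @ v).transpose(0, 2, 1, 3).reshape(b, nq, d) @ p["wo"]

def mlp(x, p):
    return np.maximum(x @ p["w1"], 0.0) @ p["w2"]  # ReLU chosen for brevity

def rin_block(x, z, params, heads=4):
    """One block: read X into Z, process Z, write Z back to X."""
    # Read: latents attend to the interface (cost scales with |Z| * |X|).
    z = z + attention(layer_norm(z), layer_norm(x), params["read"], heads)
    z = z + mlp(layer_norm(z), params["read_mlp"])
    # Process: self-attention among the (few) latents only.
    for lp in params["latent_layers"]:
        zn = layer_norm(z)
        z = z + attention(zn, zn, lp["attn"], heads)
        z = z + mlp(layer_norm(z), lp["mlp"])
    # Write: interface tokens attend to the latents.
    x = x + attention(layer_norm(x), layer_norm(z), params["write"], heads)
    x = x + mlp(layer_norm(x), params["write_mlp"])
    return x, z

def init_attn(rng, d):
    return {k: rng.normal(scale=d ** -0.5, size=(d, d)) for k in ("wq", "wk", "wv", "wo")}

def init_mlp(rng, d, hidden):
    return {"w1": rng.normal(scale=d ** -0.5, size=(d, hidden)),
            "w2": rng.normal(scale=hidden ** -0.5, size=(hidden, d))}

def init_block(rng, d, num_latent_layers=2):
    return {
        "read": init_attn(rng, d), "read_mlp": init_mlp(rng, d, 4 * d),
        "latent_layers": [{"attn": init_attn(rng, d), "mlp": init_mlp(rng, d, 4 * d)}
                          for _ in range(num_latent_layers)],
        "write": init_attn(rng, d), "write_mlp": init_mlp(rng, d, 4 * d),
    }

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 64, 32))  # interface: 64 tokens per example
z = rng.normal(size=(2, 16, 32))  # latents: 16 per example
x_out, z_out = rin_block(x, z, init_block(rng, 32))
print(x_out.shape, z_out.shape)
```

Note that no attention step here is quadratic in the number of interface tokens: the only self-attention runs over the latents, which is where most of the computation is meant to concentrate.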
Latent self-conditioning
- RINs use routing information to allocate compute to parts of the input
- Latents are built by reading interface information
- Without context, there is a “cold-start” problem
- Humans face a similar “cold-start” problem
- RINs can amortize the “warm-up” cost in sequential computation settings
- Propose to “warm-start” latents using latents from a previous step
- Latent self-conditioning conditions on the latent activations of the neural network
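The sketch below shows how warm-started latents might be threaded through an iterative sampling loop. It rests on loudly stated assumptions: `toy_network` is a hypothetical stand-in for a RIN, the warm start adds layer-normalized previous latents to the learned initialization, the schedule is a clipped cosine, and the update is a DDIM-style step. None of these details are confirmed by the summary above.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)

def gamma(t):
    # Assumed cosine schedule, clipped away from 0 so the division below is safe.
    return np.clip(np.cos(0.5 * np.pi * t) ** 2, 1e-4, 1.0)

def toy_network(x_t, t, latents, w_read, w_write):
    """Hypothetical stand-in for a RIN: update latents from x_t, then predict noise."""
    latents = latents + np.tanh(x_t.mean(axis=1, keepdims=True) @ w_read + t)
    eps_pred = 0.1 * x_t + latents.mean(axis=1, keepdims=True) @ w_write
    return eps_pred, latents

def sample(x_T, z_init, w_read, w_write, steps=8, self_cond=True):
    x_t, prev_latents = x_T, None
    for i in range(steps):
        t, t_next = 1.0 - i / steps, 1.0 - (i + 1) / steps
        if self_cond and prev_latents is not None:
            # Warm start: learned init plus normalized latents from the previous step.
            latents = z_init + layer_norm(prev_latents)
        else:
            # Cold start: learned init only.
            latents = z_init
        eps_pred, prev_latents = toy_network(x_t, t, latents, w_read, w_write)
        # DDIM-style step (assumed): estimate x_0, then move to the next noise level.
        g, g_next = gamma(t), gamma(t_next)
        x0_est = np.clip((x_t - np.sqrt(1.0 - g) * eps_pred) / np.sqrt(g), -1.0, 1.0)
        x_t = np.sqrt(g_next) * x0_est + np.sqrt(1.0 - g_next) * eps_pred
    return x_t, prev_latents

rng = np.random.default_rng(0)
d = 16
x_T = rng.normal(size=(2, 64, d))       # noisy interface tokens at t = 1
z_init = rng.normal(size=(8, d))        # learned latent embeddings (random here)
w_read = rng.normal(scale=d ** -0.5, size=(d, d))
w_write = rng.normal(scale=d ** -0.5, size=(d, d))
x_warm, z_warm = sample(x_T, z_init, w_read, w_write, self_cond=True)
x_cold, z_cold = sample(x_T, z_init, w_read, w_write, self_cond=False)
print(x_warm.shape, z_warm.shape)
```

The only difference between the two calls is whether each step starts from `z_init` alone or from `z_init` plus the previous step's latents, which is the "warm start" the bullets describe.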
Experiments
- RINs improve performance on image generation and video prediction
- RINs do not require guidance
- RINs are more efficient than other methods
Implementation details
- Noise schedule is based on cosine and sigmoid functions
- Sigmoid temperature is set to 0.9 by default
- Images are tokenized by extracting non-overlapping patches
- Videos are tokenized using 2$\times$4$\times$4 patches
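For concreteness, here is one common way to parameterize cosine and sigmoid noise schedules. Only the default temperature of 0.9 comes from the notes above; the function names, the `start`/`end` range of -3 to 3, the `clip_min` value, and the variance-preserving form of the forward transition in `add_noise` are assumptions for illustration.

```python
import numpy as np

def _sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def cosine_gamma(t, clip_min=1e-9):
    """Cosine schedule: gamma(0) = 1 (clean data), gamma(1) near 0 (pure noise)."""
    t = np.asarray(t, dtype=float)
    return np.clip(np.cos(0.5 * np.pi * t) ** 2, clip_min, 1.0)

def sigmoid_gamma(t, start=-3.0, end=3.0, tau=0.9, clip_min=1e-9):
    """Sigmoid schedule with temperature tau, rescaled so gamma(0) = 1 and gamma(1) = 0."""
    t = np.asarray(t, dtype=float)
    v_start, v_end = _sigmoid(start / tau), _sigmoid(end / tau)
    out = _sigmoid((t * (end - start) + start) / tau)
    return np.clip((v_end - out) / (v_end - v_start), clip_min, 1.0)

def add_noise(x0, t, rng, gamma_fn=sigmoid_gamma):
    """Assumed forward transition: x_t = sqrt(gamma) * x0 + sqrt(1 - gamma) * eps."""
    g = gamma_fn(t)
    eps = rng.normal(size=x0.shape)
    return np.sqrt(g) * x0 + np.sqrt(1.0 - g) * eps, eps

ts = np.linspace(0.0, 1.0, 5)
print(np.round(cosine_gamma(ts), 3))
print(np.round(sigmoid_gamma(ts), 3))
```

Lowering `tau` sharpens the transition between the low-noise and high-noise regimes, which is the knob the temperature setting controls in this parameterization.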
Experimental setup
- Use ImageNet dataset for image generation
- Use CIFAR-10 to show model can be trained with small datasets
- Use FID and Inception Score as metrics for evaluation
- Use Kinetics-600 dataset for video prediction
- Use FVD and Inception Score as metrics for evaluation
Comparison to state of the art
- Image generation works well with small datasets such as CIFAR-10
- Obtained 1.81 FID without using improved sampling procedure
- Model has 31M parameters and trains in 3 hours on 8 TPUv3 chips
- Video generation works without using guidance
- Latent self-conditioning is important for enhanced routing
- Stacking blocks enhances global and local processing
- Model can handle a wide range of patch sizes
- Sigmoid schedule with appropriate temperature is better during training than cosine schedule
- Noise schedule has less impact during sampling
- Visualizing read attention reveals which parts of the image are most attended to
- RINs bear resemblance to architectures that leverage auxiliary memory
- Latent self-conditioning allows RINs to leverage global context
- RINs are closely related to recurrent models with input attention
- Pixel diffusion models are a predominant approach for image and video generation
- Self-conditioning for diffusion models was originally proposed in (Chen et al., 2022c)
- RINs outperform U-Nets widely used in recent state-of-the-art image and video diffusion models