Link to paper
The full paper is available here.
You can also find the paper on PapersWithCode here.
Abstract
- Diffusion models generate realistic objects from complex data distributions.
- Diffusion models are computationally inefficient.
- Branched diffusion models offer improvements in efficient generation from multiple classes.
- Branched diffusion models offer advantages such as ease of extension and interpretability.
Paper Content
Introduction
- Diffusion models are popular for generating complex data distributions
- Diffusion models have been successful in generating images, videos, graphs, and tabular data
- Generating samples from diffusion models is computationally costly
- Retraining diffusion models to accommodate new data is costly, which limits their use in continual learning
- Branched diffusion models offer major improvements in computational efficiency for multi-class sample generation
- Branched diffusion models can be extended efficiently to generate new data classes
- Branched diffusion models offer interpretability into the diffusion and generation processes
Related work
- Song et al. (2021) proposed a method of conditional generation by class label
- Ho et al. (2021) proposed an alternative method of classifier-free conditional generation
- Classifier-free method can be applied to both continuous- and discrete-time diffusion models
- Several diverse methods have been proposed to improve training or sampling efficiency
Hierarchically branched diffusion models
- Proposed hierarchically branched diffusion structure for multi-class conditional generation
- Diffusion process shared between closely related classes at later (noisier) diffusion times
- Branches constrained such that each class and time is assigned to exactly one branch
- Branches form a rooted tree from t=T to t=0
- Optimal branch definitions can be computed from data or prior domain knowledge
- Implemented as one multi-task neural network
- Training follows the same procedure as for a standard, non-branched ("linear") diffusion model
- Demonstrated on two datasets of different modalities
- Visualized random sample of digits generated from branched diffusion model
- Branched model generated letters that are realistic and true to training data
- Conditional generation of a single class requires no more computation than sampling from an unconditional diffusion model
- Label-guided diffusion models trained on same data
- Branched diffusion models achieved similar generative performance compared to label-guided models
- Branched diffusion models also able to perform in discrete-time diffusion settings
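The branch constraints above can be sketched as a small data structure. This is a minimal illustration rather than the authors' implementation: the `Branch` type, the integer time grid `T = 1000`, the half-open convention `(t_start, t_end]`, and the specific branch points are all assumptions made for the example. In a multi-task network, the returned branch index would select which output head to use.

```python
from dataclasses import dataclass

T = 1000  # assumed number of diffusion time steps (hypothetical)

@dataclass(frozen=True)
class Branch:
    t_start: int          # lower time bound (closer to data), exclusive
    t_end: int            # upper time bound (closer to noise), inclusive
    classes: frozenset    # classes that share this branch

# Hypothetical branch definitions for three digit classes.
BRANCHES = [
    Branch(500, 1000, frozenset({0, 4, 9})),  # root: all classes share late times
    Branch(0, 500, frozenset({0})),
    Branch(300, 500, frozenset({4, 9})),      # 4 and 9 stay together longer
    Branch(0, 300, frozenset({4})),
    Branch(0, 300, frozenset({9})),
]

def find_branch(branches, cls, t):
    """Return the index of the unique branch covering class `cls` at time `t`."""
    hits = [i for i, b in enumerate(branches)
            if cls in b.classes and b.t_start < t <= b.t_end]
    if len(hits) != 1:
        raise ValueError(f"expected exactly one branch for class {cls} at t={t}")
    return hits[0]

def covers_all_times(branches, classes, total_time):
    """Check that each class's branches tile (0, total_time] without gaps or overlaps."""
    for c in classes:
        spans = sorted((b.t_start, b.t_end) for b in branches if c in b.classes)
        if not spans or spans[0][0] != 0 or spans[-1][1] != total_time:
            return False
        for (_, prev_end), (next_start, _) in zip(spans, spans[1:]):
            if prev_end != next_start:
                return False
    return True
```

Because every class follows a single path from the root to its own leaf, the branches form a rooted tree, and `find_branch` is the only lookup a training or sampling loop would need.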
Efficient sampling from branched diffusion models
- Branched diffusion models offer improved computational efficiency when sampling from multiple classes.
- Sampling from a diffusion model is computationally expensive due to its iterative nature.
- Branched diffusion models share much of the diffusion process for multiple classes.
- Branched models take the same amount of computation as a linear model to generate samples of a single class.
- Branched models save significant computation when sampling multiple classes, because the shared portions of the reverse diffusion can be reused across classes.
- Speedup factor for multi-class sample generation depends on the branching structure between the classes.
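To see how the speedup depends on the branching structure, the sketch below counts reverse-diffusion steps under a simplifying assumption: one sample (or one batch) is generated per class, and each shared branch is traversed only once. The branch points and the tuple layout `(t_start, t_end, classes)` are hypothetical, chosen only for illustration.

```python
# Each branch is (t_start, t_end, classes); times are integer step indices.
T = 1000  # assumed total number of reverse-diffusion steps
BRANCHES = [
    (500, 1000, frozenset({0, 4, 9})),
    (0, 500, frozenset({0})),
    (300, 500, frozenset({4, 9})),
    (0, 300, frozenset({4})),
    (0, 300, frozenset({9})),
]

def linear_steps(num_classes, total_time):
    """A linear model runs the full reverse process once per class."""
    return num_classes * total_time

def branched_steps(branches):
    """A branched model traverses every branch once, sharing the early steps."""
    return sum(t_end - t_start for t_start, t_end, _ in branches)

def speedup(branches, total_time):
    """Ratio of linear to branched step counts for one sample per class."""
    classes = set().union(*(cs for _, _, cs in branches))
    return linear_steps(len(classes), total_time) / branched_steps(branches)
```

With these made-up branch points the branched model needs 1800 steps against 3000 for the linear model, a speedup of about 1.67. Branch points closer to t=0 (classes that stay together longer) would increase the ratio.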
Extending branched diffusion models to novel classes
- Branched diffusion models have a hierarchical structure that allows new classes to be added easily.
- A small branched diffusion model was trained on three MNIST classes and then a new digit class (7) was added.
- The new branch was fine-tuned only on 7s, and the model was able to generate high-quality 7s without affecting the other digits.
- A label-guided (linear) diffusion model was also trained on 0s, 4s, and 9s, and when a new digit class (7) was added, the model suffered from catastrophic forgetting.
- Retraining the linear model on all data would take much longer than training the branched diffusion model, and the linear model still experienced inappropriate influence from the new task.
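The extension step can be pictured as an edit to the branch definitions. The sketch below is an assumption-laden illustration, not the paper's procedure: `add_class`, the tuple layout, the branch points, and the choice to attach the digit 7 to the path of 9 are all hypothetical. The point it shows is that only the newly created branch needs training, while the existing branches keep their weights.

```python
# Each branch is (t_start, t_end, classes); times are integer step indices.
BRANCHES = [
    (500, 1000, frozenset({0, 4, 9})),
    (0, 500, frozenset({0})),
    (300, 500, frozenset({4, 9})),
    (0, 300, frozenset({4})),
    (0, 300, frozenset({9})),
]

def add_class(branches, new_cls, like_cls, t_join):
    """Attach `new_cls` to the path of `like_cls` for all times above `t_join`.

    `t_join` must be an existing branch point on the path of `like_cls`.
    Returns the updated branch list and the index of the one new branch,
    which is the only branch that would need fine-tuning.
    """
    starts = {t_start for t_start, _, cs in branches if like_cls in cs}
    if t_join == 0 or t_join not in starts:
        raise ValueError("t_join must be an existing branch point of like_cls")
    updated = []
    for t_start, t_end, cs in branches:
        if like_cls in cs and t_start >= t_join:
            cs = cs | {new_cls}  # reuse the shared branch as-is, no retraining
        updated.append((t_start, t_end, cs))
    updated.append((0, t_join, frozenset({new_cls})))  # new class-specific branch
    return updated, len(updated) - 1

new_branches, trainable_index = add_class(BRANCHES, new_cls=7, like_cls=9, t_join=300)
```

Reusing the shared branches without retraining relies on the new class being indistinguishable from its neighbors at those times, which is an assumption that would need to be checked against the data.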
Interpretability of branched diffusion models
- Branched diffusion models are efficient and extendable.
- Branched diffusion models can reveal insight into common features between classes.
- Branched diffusion models can generate analogous objects between different classes.
Hybrids at branch points
- Branch points act as minimal times of indistinguishability
- Classes split off at branch points and start reverse diffusing
- Two branches meet at a branch point when their classes become noisy enough to be indistinguishable
- Hybrid objects are generated from reverse diffusion
- Hybrids show shared characteristics between classes
- Hybrids act as a smooth transition between two endpoints
Transmutation between classes
- Diffusion models can traverse the diffusion process both forward and in reverse.
- Branched diffusion models can be used to start with an object from one class and generate the analogous object in a different class.
- On the MNIST branched diffusion model, transmuting between 4s and 9s preserved the slantedness of the digit.
- On the tabular branched diffusion model, letters were transmuted between V and Y, and for all features the values before and after transmutation were positively correlated.
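Transmutation as described above can be sketched as two steps: forward diffuse the source object up to the earliest time at which both classes share a branch, then reverse diffuse down the target class's branches. The function names and the `forward_diffuse` and `reverse_diffuse` callables are hypothetical placeholders for a trained model's noising and sampling routines; the branch points are made up.

```python
# Each branch is (t_start, t_end, classes); times are integer step indices.
BRANCHES = [
    (500, 1000, frozenset({0, 4, 9})),
    (0, 500, frozenset({0})),
    (300, 500, frozenset({4, 9})),
    (0, 300, frozenset({4})),
    (0, 300, frozenset({9})),
]

def earliest_shared_time(branches, cls_a, cls_b):
    """Earliest time at which both classes lie on the same branch."""
    shared = [t_start for t_start, _, cs in branches if cls_a in cs and cls_b in cs]
    if not shared:
        raise ValueError("the two classes never share a branch")
    return min(shared)

def transmute(x0, src_cls, dst_cls, branches, forward_diffuse, reverse_diffuse):
    """Turn an object of `src_cls` into its analogue in `dst_cls`.

    forward_diffuse(x, t) -> noised x at time t            (assumed signature)
    reverse_diffuse(x_t, t, cls) -> sample of cls at t=0   (assumed signature)
    """
    t_branch = earliest_shared_time(branches, src_cls, dst_cls)
    x_t = forward_diffuse(x0, t_branch)            # noise up to the branch point
    return reverse_diffuse(x_t, t_branch, dst_cls)  # denoise along the target class
```

Stopping at the branch point rather than at t=T is what lets features of the source object survive: only as much noise is added as is needed for the two classes to become indistinguishable.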
Discussion
- Branched diffusion models represent a tradeoff between training and sampling complexity
- Branched models require more parameters and more training time than their linear counterparts, but need significantly less computation when sampling multiple classes
- Branched diffusion models require branch points to be times of indistinguishability between classes, which can be challenging for certain image datasets
- Branched diffusion can easily accommodate other modalities such as tabular data, graphs, and text
Conclusion
- Proposed a novel form of diffusion models which introduces branch points to encode hierarchical relationship between data classes
- Branched diffusion models offer an alternative framework of conditional generation for discrete classes
- Flexibly applied to many traditional diffusion-model paradigms
- More efficient to sample multiple classes from a branched diffusion model than a linear one
- Easily extendable to new classes through a short and efficient finetuning step
- Can offer insights into interpretability
- Reverse-diffusion intermediates at branch points are hybrids which encode shared or interpolated characteristics of multiple data classes
- Can transmute objects from one class into the analogous object in another class
Experimental details
- Trained models and performed analyses on a single Nvidia Quadro P6000
- Used MNIST and tabular letter-recognition datasets
- Used variance-preserving stochastic differential equation for continuous-time diffusion models
- Used discrete-time Gaussian noising process for discrete-time diffusion model
- Computed branch definitions using an algorithm
- Used UNet and dense architectures for MNIST and tabular letters models respectively
- Used group normalization after every layer
- Used time and label embeddings
- Trained with a batch size of 128 examples
- Used learning rate of 0.001
- Trained each model for a different number of epochs
- Used predictor-corrector algorithm for continuous-time diffusion models
- Used sampling algorithm for discrete-time diffusion model
- Compared quality of samples generated from branched and linear diffusion models using Fréchet Inception Distance
- Leveraged branching structure to generate samples
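For reference, the variance-preserving SDE mentioned in the list above is commonly written in the form below, following Song et al. (2021). The noise schedule $\beta(t)$ is left generic here because this summary does not state which schedule was used.

```latex
% Forward variance-preserving SDE with noise schedule \beta(t) and Wiener process w
\mathrm{d}x = -\tfrac{1}{2}\,\beta(t)\,x\,\mathrm{d}t + \sqrt{\beta(t)}\,\mathrm{d}w

% Resulting Gaussian perturbation kernel
x_t \mid x_0 \sim \mathcal{N}\!\left(
  x_0\, e^{-\frac{1}{2}\int_0^t \beta(s)\,\mathrm{d}s},\;
  \left(1 - e^{-\int_0^t \beta(s)\,\mathrm{d}s}\right) I
\right)
```

The closed-form kernel is what allows a noised sample at any time $t$ to be drawn directly from $x_0$ during training, without simulating the SDE step by step.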