Stable Diffusion, one of the most popular open source models, is known for its ability to generate highly detailed and creative images from text prompts, making it a pivotal tool in AI-driven art and design. However, without solid training infrastructure and expertise, pre-training can take prohibitively long and incur unnecessarily high costs.
In this blog post, we introduce an advanced pre-training solution for Stable Diffusion v2 models that leverages Ray and the Anyscale Platform to improve scalability and cost efficiency.
In summary,
We pre-trained the Stable Diffusion v2 model on a massive dataset of ~2 billion images for under $40,000.
We eliminated preprocessing bottlenecks with Ray Data, cutting training time by ~30%.
We further applied system and algorithm optimizations that reduce training costs by 3x compared to baseline methods.
This is the first post in a three-part series. Check out the Definitive Guides with Ray series:
Processing 2 Billion Images for Stable Diffusion Model Training
Pre-Training Stable Diffusion Models on 2 billion Images Without Breaking the Bank
Stable Diffusion is a conditional generation model that generates high-quality images from textual prompts. Figure 1 illustrates its training pipeline:
A pre-trained VAE and a text encoder (OpenCLIP-ViT/H) encode the input images and text prompts.
A trainable U-Net model learns the diffusion process from the image latents and text embeddings.
The loss is calculated from the input noise and the noise predicted by the U-Net.
The model is pre-trained first on 256x256 images (Phase 1) and then on 512x512 images (Phase 2).
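To make the objective concrete, here is a minimal NumPy sketch of one loss computation. It assumes the standard DDPM-style noise-prediction (epsilon) objective; some Stable Diffusion v2 checkpoints use v-prediction instead. `predict_noise` is a hypothetical stand-in for the U-Net, and the latents and embeddings are treated as fixed inputs, matching the frozen encoders above.

```python
import numpy as np


def noise_prediction_loss(latents, text_emb, predict_noise, alphas_cumprod, rng):
    """Epsilon-prediction loss for one batch (DDPM-style sketch).

    latents:        (B, C, H, W) image latents from the frozen VAE
    text_emb:       (B, T, D) prompt embeddings from the frozen text encoder
    predict_noise:  hypothetical stand-in for the trainable U-Net,
                    called as predict_noise(noisy_latents, timesteps, text_emb)
    alphas_cumprod: (num_steps,) cumulative noise schedule, values in (0, 1)
    rng:            numpy.random.Generator used for timesteps and noise
    """
    batch = latents.shape[0]
    # Sample one diffusion timestep per example.
    t = rng.integers(0, len(alphas_cumprod), size=batch)
    # The "input noise" the U-Net must recover.
    noise = rng.standard_normal(latents.shape)
    # Forward diffusion: mix clean latents with noise according to the schedule.
    a = alphas_cumprod[t].reshape(batch, 1, 1, 1)
    noisy = np.sqrt(a) * latents + np.sqrt(1.0 - a) * noise
    # The U-Net predicts the noise, conditioned on the timestep and text embedding.
    pred = predict_noise(noisy, t, text_emb)
    # Mean squared error between the true and the predicted noise.
    return float(np.mean((pred - noise) ** 2))


# Example: a linear beta schedule with 1000 steps and a dummy "U-Net".
rng = np.random.default_rng(0)
betas = np.linspace(1e-4, 0.02, 1000)
alphas_cumprod = np.cumprod(1.0 - betas)
latents = rng.standard_normal((8, 4, 32, 32))  # 4-channel latents of 256x256 images
text_emb = rng.standard_normal((8, 77, 1024))  # 77 tokens x 1024-dim embeddings
zero_unet = lambda x, t, emb: np.zeros_like(x)
print(noise_prediction_loss(latents, text_emb, zero_unet, alphas_cumprod, rng))
```

Only the U-Net receives gradients in this setup; the latents and embeddings are plain inputs to the loss.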
Despite their impressive generation quality, baseline pre-trained Stable Diffusion models [1] might not always be suitable for commercial use. One concern involves the pre-training datasets, which could introduce biases [2] or contain illegal or copyrighted content [3, 4]. To ensure model consistency and fairness and to avoid ethical and legal issues, many organizations choose to pre-train their own models on carefully curated datasets.
Baseline pre-training requires over 200,000 A100 GPU hours on billions of images [5], which highlights how large-scale and computationally intensive pre-training is. Three main factors can severely hinder training efficiency:
Data Preprocessing Bottlenecks: Complex preprocessing logic for Stable Diffusion, especially image and text encoder inference, is computationally expensive and competes with the U-Net model for precious GPU resources.
Failures in Large-Scale Training: Hardware and application failures in large-scale, long-running jobs are common but notoriously hard to diagnose and fix. Failing to recover promptly can waste significant training progress and compute resources.
Inefficient Training Infrastructure: Unoptimized training infrastructure (e.g. training strategy, network configuration) can lead to inefficient use of hardware, yet identifying and resolving bottlenecks in a distributed cluster is challenging.
In the next section, we discuss how to address these technical challenges with Ray Data and Ray Train.
Traditionally, the entire Stable Diffusion model (including the encoders and the U-Net) is placed on a single GPU. Training workers load the input images and prompts from storage (e.g. using a PyTorch DataLoader), then feed them to the model. This works on GPUs with enough memory, such as the A100, but it can still lead to severe underutilization.
Based on the memory profiling results in Figure 3, the "Encoder Forward" step is a likely bottleneck: 0.44 seconds (~39%) of the iteration time is spent in the encoders, while GPU memory (GRAM) utilization is only ~25%. Using an A100 for encoder inference is overkill. If we move the encoders off the A100s and train only the U-Net on them, GPU utilization should improve accordingly.
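As a back-of-the-envelope check, the profiling numbers give a rough upper bound on the speedup from offloading the encoders (Amdahl's law). This sketch assumes the profiled iteration is representative and that the rest of the iteration is unchanged once the encoder time disappears; it ignores data loading and any change in batch size.

```python
def encoder_offload_bound(encoder_seconds: float, encoder_fraction: float):
    """Rough upper bound on the speedup from moving the encoders off the GPU.

    encoder_seconds:  time spent in "Encoder Forward" per iteration
    encoder_fraction: that time as a fraction of the whole iteration
    Assumes the remaining (U-Net) part of the iteration is unchanged.
    """
    iteration_seconds = encoder_seconds / encoder_fraction
    # Amdahl's law with the offloaded fraction reduced to zero.
    max_speedup = 1.0 / (1.0 - encoder_fraction)
    return iteration_seconds, max_speedup


iteration_s, bound = encoder_offload_bound(0.44, 0.39)
print(f"iteration ~{iteration_s:.2f} s, speedup bound ~{bound:.2f}x")
# iteration ~1.13 s, speedup bound ~1.64x
```

The measured gains in Table 1 (1.45x and 1.26x) both fall below this rough bound, so treat it as a sanity check rather than a prediction.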
One straightforward idea is to precompute the encodings before training, which decouples the encoders from the U-Net. As shown in Figure 3, we split the whole training workload into two separate jobs:
An offline preprocessing job that consists of the following steps:
Load and Transform: Load the input data from cloud storage, then apply image transformations and text tokenization.
Model Inference: Feed the transformed data to the encoders, and generate latent vectors and embeddings.
Save Results: Save the latents and embeddings back to the cloud storage.
A model training job that streams the precomputed latents from S3 to the training workers, allowing the A100s to be fully utilized for U-Net training.
For the offline preprocessing job, the "Load and Transform" and "Save Results" stages only require CPUs, whereas the "Model Inference" stage primarily requires GPUs. Running them sequentially may cause CPU-intensive tasks to become a bottleneck for GPU tasks, resulting in low GPU utilization.
Thanks to Ray's native support for heterogeneous resources, Ray Data can schedule CPU and GPU tasks independently across a cluster. It provides fine-grained control over concurrency and batch sizes for each stage, ensuring that the GPUs are always fully utilized.
In addition, Ray Data natively ingests and processes data in a streaming manner. Streaming execution eliminates the need to load all data into memory before model inference, thus significantly reducing memory requirements.
For more details on the advantages of using Ray Data over other solutions for batch inference jobs, please refer to this blog post.
During model training, Ray Data streams precomputed latent variables from S3, and then feeds them into U-Net for training. For comparison, we also built a baseline method using Torch DataLoader, which loads raw images and text from S3 and encodes them on the fly with image and text encoders.
Method | Offline Preprocessing | Training Throughput @ 256x256 (images/s) | Training Throughput @ 512x512 (images/s)
---|---|---|---
Torch DataLoader | ✗ | 2805 | 812
Ray Data | ✓ | 4068 (1.45x) | 1029 (1.26x)
Table 1: Training throughput on 32 x A100-80G with Torch DataLoader and with Ray Data offline preprocessing. Micro batch sizes are 128 and 32 for resolutions 256 and 512, respectively.
Table 1 shows the throughput gains from decoupling the encoders and the U-Net. Ray Data offline preprocessing resulted in 1.45x and 1.26x improvements over the Torch DataLoader baseline at resolutions 256 and 512, respectively.
While offline preprocessing boosts training throughput and eliminates repeated encoding computations across epochs, it's not a one-size-fits-all solution.
When the preprocessing logic changes, the entire dataset must be reprocessed, which takes tens of hours, before training can start. Furthermore, when preprocessing involves dynamic logic (such as random cropping), offline preprocessing must precompute multiple copies for different epochs, which incurs excessive time and storage costs.
Therefore, we propose an end-to-end training pipeline, which combines data ingestion, preprocessing, and model training into a single job. This approach overcomes the limitations of offline preprocessing without sacrificing the throughput gains of decoupled encoders.
Figure 5 shows the overall architecture, including the following key features:
Leveraging heterogeneous instance types for data and training. Ray Data can launch preprocessing workers and training workers on different instance types. This is a unique advantage over existing solutions (such as the PyTorch DataLoader), which require data preprocessing and model training to be colocated on the same node.
Automatic data streaming and sharding. Ray Data streams data all the way from cloud storage through the preprocessing workers to the training workers, and automatically shards data batches evenly across training workers. All data is transferred through Ray's in-memory object store, so no intermediate storage is required.
Figure 6 shows the scalability of online preprocessing. In this experiment, we fixed the number of A100s used for distributed training while scaling up the number of A10G GPUs used for online preprocessing. Training throughput increases linearly until U-Net training on the A100s becomes the bottleneck, and eventually converges to the throughput achieved with offline preprocessing.
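The curve in Figure 6 follows a simple bottleneck model: the pipeline trains only as fast as its slowest stage. The sketch below assumes preprocessing throughput grows linearly with the number of A10G GPUs and caps it at the A100s' training capacity. The per-A10G rate is a hypothetical number chosen for illustration, not a measurement from this post; 4068 images/s is the offline-preprocessing throughput from Table 1.

```python
def online_training_throughput(num_a10g: int,
                               per_a10g_images_per_s: float,
                               a100_capacity_images_per_s: float) -> float:
    """Throughput of a two-stage pipeline (preprocess -> train).

    Preprocessing supply scales linearly with the number of A10Gs; training
    can never consume more than the A100s' capacity.
    """
    supply = num_a10g * per_a10g_images_per_s
    return min(supply, a100_capacity_images_per_s)


A100_CAPACITY = 4068.0  # images/s with offline preprocessing (Table 1, 256x256)
PER_A10G = 150.0        # hypothetical preprocessing rate per A10G, for illustration

knee = A100_CAPACITY / PER_A10G  # A10G count where training becomes the bottleneck
print(f"training-bound beyond ~{knee:.1f} A10Gs")
for n in (8, 16, 24, 32, 40):
    print(n, online_training_throughput(n, PER_A10G, A100_CAPACITY))
```

Past the knee, additional A10Gs only add cost, which is why online preprocessing in Table 2 matches the offline throughput but not its hourly price.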
Method | Resources | Training Throughput (images/s) | Hourly Cost | Training Time on Anyscale Platform | Total Cost Per Epoch
---|---|---|---|---|---
Online Processing (Torch DataLoader) | 4 x p4de.24xlarge | 2811 | $163.44 | 111.3h | $18,192
Offline Processing (Ray Data) | 4 x p4de.24xlarge | 4068 | $163.44 | 76.9h (-30.9%) | $14,753* (-18.9%)
Online Processing (Ray Data) | 4 x p4de.24xlarge + 40 x g5.2xlarge | 4075 | $221.92 | 76.8h (-31.0%) | $16,275 (-10.5%)
Table 2: Cost analysis for Stable Diffusion pre-training on 1,126,400,000 images at resolution 256x256. Costs are estimated with AWS on-demand instance prices ($40.96/h for p4de.24xlarge and $1.212/h for g5.2xlarge). *Includes an extra $2,183 for one-time offline preprocessing.
Table 2 shows the cost analysis for the first phase of pre-training. Both online and offline preprocessing with Ray Data cut training time by ~31%. Offline preprocessing also reduces the total cost per epoch by ~19%, and online preprocessing by ~10%.
Check out the Definitive Guide with Ray on Processing 2 Billion Images for Stable Diffusion Model Training to learn about the implementation in detail.
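The per-epoch figures in Table 2 follow directly from throughput and hourly price. A small sketch (the helper name is ours, for illustration) reproduces the first two rows, using the hourly costs and one-time preprocessing cost listed in the table:

```python
def epoch_time_and_cost(num_images, images_per_s, hourly_cost, one_time_cost=0.0):
    """Hours to process one epoch, and its cost at a fixed hourly cluster price."""
    hours = num_images / images_per_s / 3600
    return hours, hours * hourly_cost + one_time_cost


EPOCH_IMAGES = 1_126_400_000  # Phase 1 (256x256), from Table 2
for name, images_per_s, hourly, extra in [
    ("Torch DataLoader (online)", 2811, 163.44, 0.0),
    ("Ray Data (offline)", 4068, 163.44, 2183.0),  # plus one-time preprocessing
]:
    hours, cost = epoch_time_and_cost(EPOCH_IMAGES, images_per_s, hourly, extra)
    print(f"{name}: {hours:.1f} h, ${cost:,.0f}")
```

Both rows land within a dollar of the table's totals; the remaining difference is rounding.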
Fault tolerance is always a major issue in large-scale distributed training, where various hardware and software failures can occur. To avoid losing training progress, we need to periodically checkpoint the model states and be able to recover from them.
Ray Train provides an out-of-the-box solution for fault-tolerant training. During training, each worker independently syncs its checkpoint to cloud storage (e.g. S3, GCS). When a hardware or software failure inevitably occurs, Ray Train automatically rescales the cluster, restores the latest checkpoint from cloud storage, and continues training.
For more information about how to enable fault-tolerant training with Ray Train, please refer to our user guide [6].
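In Ray Train, workers report checkpoints with `ray.train.report(..., checkpoint=...)`, resume with `ray.train.get_checkpoint()`, and `RunConfig(failure_config=FailureConfig(max_failures=...))` sets how many restarts are allowed. The plain-Python sketch below is not Ray Train's implementation; it only illustrates the checkpoint-and-resume pattern that Ray Train automates, using a hypothetical JSON file per checkpoint and a simulated failure.

```python
import json
import os
import tempfile


def latest_checkpoint(ckpt_dir):
    """Load the highest-step checkpoint named step_<N>.json, or None."""
    steps = [int(f[5:-5]) for f in os.listdir(ckpt_dir)
             if f.startswith("step_") and f.endswith(".json")]
    if not steps:
        return None
    with open(os.path.join(ckpt_dir, f"step_{max(steps)}.json")) as fh:
        return json.load(fh)


def train(ckpt_dir, total_steps, ckpt_every, fail_at=None):
    # Resume from the latest checkpoint if one exists.
    state = latest_checkpoint(ckpt_dir) or {"step": 0, "weight": 0.0}
    for step in range(state["step"], total_steps):
        if step == fail_at:
            raise RuntimeError(f"simulated node failure at step {step}")
        state = {"step": step + 1, "weight": state["weight"] + 1.0}  # stand-in update
        if state["step"] % ckpt_every == 0:
            with open(os.path.join(ckpt_dir, f"step_{state['step']}.json"), "w") as fh:
                json.dump(state, fh)
    return state


def run_with_retries(ckpt_dir, total_steps, ckpt_every, failures, max_failures=3):
    """Restart training after each failure, up to max_failures times."""
    pending = list(failures)
    for _ in range(max_failures + 1):
        try:
            fail_at = pending.pop(0) if pending else None
            return train(ckpt_dir, total_steps, ckpt_every, fail_at=fail_at)
        except RuntimeError:
            continue  # the next attempt resumes from the latest checkpoint
    raise RuntimeError("exceeded max_failures")


print(run_with_retries(tempfile.mkdtemp(), total_steps=12, ckpt_every=5, failures=[7]))
```

Only the work done after the last checkpoint (here, after step 5) is repeated, so the checkpoint interval trades checkpointing overhead against the progress lost per failure.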
Apart from resolving encoding bottlenecks, we also adopted a series of optimizations to accelerate U-Net training. Users can easily integrate these optimizations with Ray Train and Anyscale.
Elastic Fabric Adapter (EFA) [7] provides lower and more consistent latency and higher throughput than the TCP transport traditionally used in cloud-based HPC systems. With the optimized NCCL plugins, it significantly reduces communication overhead and speeds up distributed training. EFA is supported automatically on the Anyscale Platform.
Figure 7: Fully sharded data parallel training - algorithm overview (adapted from FB engineering blog).
FSDP [8] reduces per-GPU memory by sharding model states across multiple devices. We use the SHARD_GRAD_OP mode in our experiments, which partitions the gradients and optimizer states among all workers. During training, each worker aggregates only its portion of the gradients and updates the corresponding weights with its sharded optimizer states. Compared to DDP, no worker holds the full gradients and optimizer states, which reduces peak GPU memory (GRAM) usage and allows larger batch sizes and higher throughput.
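To see the memory effect, here is a rough per-GPU estimate of model-state memory under DDP and SHARD_GRAD_OP. It uses ZeRO-style accounting for mixed-precision Adam (2 bytes per parameter for 16-bit weights, 2 for 16-bit gradients, 12 for fp32 master weights and Adam moments), assumes the U-Net has roughly 865M parameters, and ignores activations and framework overhead; none of these numbers are measurements from the runs in this post.

```python
def per_gpu_model_state_gb(num_params, num_gpus, shard_grad_op=True,
                           param_bytes=2, grad_bytes=2, optim_bytes=12):
    """Approximate per-GPU memory (GB) for model states, excluding activations.

    Default byte counts assume mixed-precision Adam: 16-bit weights and
    gradients, plus fp32 master weights and two fp32 Adam moments (12 bytes).
    """
    if shard_grad_op:
        # SHARD_GRAD_OP (roughly ZeRO-2): full parameters are present during
        # forward/backward, but gradients and optimizer states are split
        # evenly across workers.
        per_param = param_bytes + (grad_bytes + optim_bytes) / num_gpus
    else:
        # DDP: every worker keeps a full copy of everything.
        per_param = param_bytes + grad_bytes + optim_bytes
    return num_params * per_param / 1e9


UNET_PARAMS = 865e6  # assumption: approximate Stable Diffusion v2 U-Net size
for n in (16, 32):
    ddp = per_gpu_model_state_gb(UNET_PARAMS, n, shard_grad_op=False)
    sgo = per_gpu_model_state_gb(UNET_PARAMS, n)
    print(f"{n} GPUs: DDP ~{ddp:.2f} GB, SHARD_GRAD_OP ~{sgo:.2f} GB")
```

The memory freed this way can go to larger per-GPU batches, which matters most on the 40 GB cards.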
torch.compile [9] is a PyTorch feature that optimizes model execution by compiling PyTorch code into more efficient representations. This just-in-time (JIT) compilation process captures the model's computation graph and applies optimizations such as kernel fusion and dead code elimination, reducing execution overhead and accelerating training.
Resolution | Baseline (DDP) | + EFA | + FSDP | + torch.compile
---|---|---|---|---
256x256 | 1075 | 1269 (1.18x) | 1925 (1.79x) | 2910 (2.71x)
512x512 | 264 | 474 (1.86x) | 667 (2.52x) | 805 (3.05x)
Table 3.a: Training throughput (images/s) and speedup on 16 x A100-40G. The baseline uses DDP training with TCP network transport.
Resolution | Baseline (DDP) | + EFA | + FSDP | + torch.compile
---|---|---|---|---
256x256 | 1573 | 4068 (2.59x) | 5014 (3.18x) | 5908 (3.75x)
512x512 | 389 | 1029 (2.64x) | 1168 (3.00x) | 1349 (3.46x)
Table 3.b: Training throughput (images/s) and speedup on 32 x A100-80G. The baseline uses DDP training with TCP network transport.
After applying all the above acceleration methods, we improved the training throughput by ~3x. Tables 3.a and 3.b above show the breakdown of throughput gains for each optimization:
EFA: EFA yields a larger throughput improvement with 32 workers than with 16, because communication overhead grows as training scales out.
FSDP: Fully Sharded Data Parallel achieves a greater speedup on A100-40G GPUs than on A100-80G GPUs, since training on A100-40Gs is more constrained by GPU memory.
torch.compile: This optimization contributes additional speedup in both settings by reducing PyTorch overhead through just-in-time (JIT) compilation.
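The contribution of each individual optimization can be read off Tables 3.a and 3.b by dividing each column by the one before it, assuming (as the "+" headers suggest) that the columns are cumulative. This short script, a helper of our own for illustration, does that for both clusters:

```python
# Throughput (images/s) from Tables 3.a and 3.b. Columns are cumulative:
# Baseline (DDP), + EFA, + FSDP, + torch.compile.
TABLE_3A = {"256x256": [1075, 1269, 1925, 2910], "512x512": [264, 474, 667, 805]}  # 16 x A100-40G
TABLE_3B = {"256x256": [1573, 4068, 5014, 5908], "512x512": [389, 1029, 1168, 1349]}  # 32 x A100-80G
STEPS = ["+ EFA", "+ FSDP", "+ torch.compile"]


def incremental_gains(row):
    """Speedup of each optimization over the column immediately before it."""
    return {step: row[i + 1] / row[i] for i, step in enumerate(STEPS)}


for label, table in (("16 x A100-40G", TABLE_3A), ("32 x A100-80G", TABLE_3B)):
    for resolution, row in table.items():
        gains = incremental_gains(row)
        summary = ", ".join(f"{step} {g:.2f}x" for step, g in gains.items())
        print(f"{label} @ {resolution}: {summary}")
```

For example, at 256x256 FSDP adds about 1.52x on top of EFA on A100-40G, versus about 1.23x on A100-80G.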
In summary, in this blog post we introduced a scalable and cost-efficient solution for Stable Diffusion pre-training built with Ray Data and Ray Train. Compared with vanilla PyTorch solutions, Ray takes full advantage of heterogeneous resources and advanced scheduling capabilities, reducing pre-training costs to under $40,000.
Setting | Value
---|---
Cloud Provider | AWS (us-west-2)
GPU Type | A100-80G
Cluster Setting | 4 x p4de.24xlarge
Global Batch Size | 4096
Training Procedure | Phase 1: 1,126,400,000 samples at resolution 256x256; Phase 2: 1,740,800,000 samples at resolution 512x512
Total A100 Hours | 13,165
Total Training Cost | $39,511 (1-yr reserved instances)
Table 4: Summary of pre-training configurations and costs. Compared with the original results reported by Stability AI (200,000 A100-40G hours) [5], we reduced the training cost by ~10x.
If you are interested in the implementation in greater detail, check out the Definitive Guides with Ray series:
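The A100-hour total in Table 4 is consistent with the optimized throughputs in Table 3.b. Assuming the final run sustained the "+ torch.compile" throughputs on 32 A100-80Gs for both phases, a quick calculation recovers it:

```python
# (samples, images/s) per phase; throughputs are the "+ torch.compile"
# column of Table 3.b (32 x A100-80G), assumed sustained for the whole run.
PHASES = [
    (1_126_400_000, 5908),  # Phase 1, 256x256
    (1_740_800_000, 1349),  # Phase 2, 512x512
]
NUM_GPUS = 32        # 4 x p4de.24xlarge with 8 A100-80G each
TOTAL_COST = 39_511  # USD, Table 4


def a100_hours(phases, num_gpus):
    """Wall-clock hours for each phase times the number of GPUs."""
    return sum(samples / images_per_s / 3600 * num_gpus
               for samples, images_per_s in phases)


total = a100_hours(PHASES, NUM_GPUS)
print(f"{total:,.0f} A100 hours")  # 13,165
print(f"${TOTAL_COST / total:.2f} per A100 hour implied by Table 4")  # $3.00
```

Phase 2 at 512x512 dominates, accounting for roughly 87% of the GPU hours.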
Processing 2 Billion Images for Stable Diffusion Model Training
Pre-Training Stable Diffusion Models on 2 billion Images Without Breaking the Bank
If you are impatient and want to run our reference implementation right away, check out this pre-packaged solution with all the code.