This guide is the third in a series that explores the challenges and solutions involved in training Stable Diffusion models at scale. Stable Diffusion models are a class of generative models that have shown promising results in generating high-quality images. However, training these models requires large amounts of data and computational resources, which can be challenging to manage. This guide focuses on model training, particularly the challenges involved in running it at scale.
In this guide, we will learn how to:
💻 Train a Stable Diffusion model using Ray Train + PyTorch Lightning
💡 Understand the strategies for optimizing the training process
🚀 Scale the training process to handle extensive datasets and computational demands
🔍 Identify common challenges faced during large-scale training and how to address them
The code and algorithms behind Stable Diffusion are publicly available, allowing for model enhancement through training on high-quality datasets. Such training can improve image quality and address commercial use rights issues. However, training a diffusion model like Stable Diffusion from scratch requires a robust system for distributed computing. Without this, the training process could be lengthy and inefficient, leading to wasted time and resources.
Shown below is the end-to-end architecture diagram of Stable Diffusion model training. We already covered the Transformation and Encoding stages in the first guide; in this guide we go over the training stage of the Stable Diffusion U-Net.
Note that at any point in the above diagram we can choose to persist the data stream to storage and load it later to complete the remaining stages. In the first guide, we stored the output of the “Encoding” stage to S3 so that the Stable Diffusion U-Net inputs are ready.
Our approach is practical and code-oriented rather than deeply mathematical. We aim to equip you with the intuitive understanding necessary to adapt our code for your projects. The guide specifically focuses on implementing the model training stage for the v2-base model of Stable Diffusion.
You don't need access to a distributed computing cluster to follow and run the code in this guide, as the same code runs on both single machines and clusters, thanks to Ray’s versatility. If you are looking to scale your code, try the reference implementation on the Anyscale platform. What you will get is a VSCode remote session running within a minimally sized cluster.
On a high level, the progression of this guide mimics the following steps needed to train a Stable Diffusion model, which we cover in much more detail later on:
Loading the preprocessed data: We start by loading the preprocessed data we prepared in our first guide.
Defining a Stable Diffusion model: Our Stable Diffusion model uses a network architecture called a U-Net. During training we start with preprocessed images (referred to as "image latents") and encoded text, and add random noise to the image data to obtain "noisy latents". The U-Net then learns to predict the noise we added. At inference time, we can, for instance, start with a text description of an image we want to generate, and our Stable Diffusion model will invert the prior process (called a diffusion process) by iteratively subtracting noise from a completely random starting image, until the resulting image "looks like" the text input we prompted the model with.
Defining a scalable training procedure: We scale the base model from the last step with Ray Train to enable running on a GPU compute cluster. We will go over how to set up and execute a Ray Train job, and how to integrate Ray Train with Ray Data to directly stream data onto the GPU training workers.
Running end-to-end training: Lastly, we show you how to carry out the training procedure itself on an Anyscale-managed Ray Cluster. We go over the specifics of the training process and discuss features like fault tolerance which Ray Train provides out of the box.
Let's start by loading the preprocessed data we prepared for training a Stable Diffusion model in the first guide.
We load our preprocessed dataset by implementing a simple load_precomputed_dataset function:
import numpy as np
import ray


def convert_precision_to_fp16(batch):
    # Cast every column in the batch to 16-bit floats.
    for k, v in batch.items():
        batch[k] = v.astype(np.float16)
    return batch


def load_precomputed_dataset(
    data_uri: str, num_data_loading_workers: int, resolution: int = 256
) -> ray.data.Dataset:
    # Lazily read only the columns needed for training.
    ds = ray.data.read_parquet(
        data_uri,
        columns=[f"image_latents_{resolution}", "caption_embeddings"],
        concurrency=num_data_loading_workers,
    )

    # batch_size=None passes entire blocks to the function as batches.
    return ds.map_batches(
        convert_precision_to_fp16,
        batch_size=None,
        concurrency=num_data_loading_workers,
    )
load_precomputed_dataset will perform two steps:
Use ray.data.read_parquet to read the preprocessed image latents and caption embeddings.
Ray Data adopts lazy execution, so this call simply returns a Ray Dataset that represents our Parquet dataset.
Apply convert_precision_to_fp16 to convert the data to 16-bit floating point precision.
We chose to halve the precision on the data loading workers to reduce memory consumption on the training workers.
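To make the memory saving concrete, here is a back-of-the-envelope sketch of the per-sample footprint before and after the cast. The tensor shapes are assumptions that this guide does not state: a 4 x 32 x 32 image latent for a 256 x 256 image and a 77 x 1024 caption embedding. It also assumes the stored data is 32-bit before the conversion.

```python
# Assumed shapes (hypothetical, for illustration only).
LATENT_SHAPE = (4, 32, 32)   # channels x height x width of one image latent
CAPTION_SHAPE = (77, 1024)   # tokens x embedding dimension of one caption


def num_elements(shape):
    """Number of scalar values in a tensor of the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count


def bytes_per_sample(bytes_per_value):
    """Bytes needed for one (latent, caption embedding) pair."""
    total = num_elements(LATENT_SHAPE) + num_elements(CAPTION_SHAPE)
    return total * bytes_per_value


fp32_bytes = bytes_per_sample(4)  # 32-bit floats
fp16_bytes = bytes_per_sample(2)  # 16-bit floats
print(f"fp32: {fp32_bytes} bytes, fp16: {fp16_bytes} bytes per sample")
```

Under these assumptions every batch that reaches a training worker is half the size it would otherwise be, which is the effect the conversion step is after.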
This is all we need to have our features ready for model training! As we will see later, Ray Train integrates seamlessly with Ray Data. Therefore no additional work is required in terms of constructing data loading functionality.
Let's discuss the training process of a Stable Diffusion model in a bit more detail. Consider the following diagram as the basis for our discussion:
The diagram contains a quick reminder of the preprocessing we applied: After cropping and normalizing images, we encoded them to a latent space using a VAE encoder. Tokenized text data is encoded with a CLIP model to generate text embeddings.
The core component of the training process is a U-Net, which we will define as a PyTorch Lightning module. The central idea of Stable Diffusion model training is to sample a timestep t, add random noise to the image latents according to a noise schedule, and thereby obtain noisy latents. The purpose of the U-Net is to learn to predict that noise. We use the mean squared error (MSE) between the actual and the predicted noise as the loss function for training. Here's a visual representation of the full forward diffusion process as presented in the paper titled Denoising Diffusion Probabilistic Models.
Figure 3. Forward diffusion process. Picture taken from Denoising Diffusion Probabilistic Models.
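The forward process and the loss can be sketched in a few lines of plain Python. This is a toy illustration under stated assumptions, not the code used in this guide: it uses the linear beta schedule from the DDPM paper (1e-4 to 0.02 over 1000 steps), a four-element list in place of real image latents, and hypothetical helper names.

```python
import math
import random


def linear_beta_schedule(num_timesteps, beta_start=1e-4, beta_end=0.02):
    """Per-step noise variances, increasing linearly (assumed schedule)."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    return [beta_start + i * step for i in range(num_timesteps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product of (1 - beta_s) for all s up to t."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def add_noise(latents, noise, alpha_bar_t):
    """Jump straight to timestep t: scale the signal down and mix in noise."""
    signal_scale = math.sqrt(alpha_bar_t)
    noise_scale = math.sqrt(1.0 - alpha_bar_t)
    return [signal_scale * x + noise_scale * n for x, n in zip(latents, noise)]


def mse(predicted, target):
    return sum((p - t) ** 2 for p, t in zip(predicted, target)) / len(target)


rng = random.Random(0)
alpha_bars = cumulative_alphas(linear_beta_schedule(1000))
latents = [0.5, -1.0, 0.25, 2.0]
noise = [rng.gauss(0.0, 1.0) for _ in latents]
t = rng.randrange(1000)
noisy_latents = add_noise(latents, noise, alpha_bars[t])

# The training target is the noise itself: a perfect predictor has zero loss,
# while a predictor that always outputs zeros does not.
perfect_loss = mse(noise, noise)
zero_guess_loss = mse([0.0] * len(noise), noise)
```

Because alpha_bar shrinks toward zero as t grows, late timesteps are almost pure noise, while early timesteps are almost the original latents.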
As the U-Net gets better and better at predicting noise from noisy images, given an encoded text description, we can then use this to our advantage during inference. When provided with random noise and text input, we can instruct the U-Net to predict the noise, subtract it, and feed the result of it back into the U-Net, rinse and repeat. After a fixed number of steps we've removed "all the noise" and what's left is an image that fits the description. Here's a visual representation of the reverse diffusion process that leads to image generation:
Figure 4. Backward (reverse) diffusion process. Picture taken from Denoising Diffusion Probabilistic Models.
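The "predict the noise, subtract it, repeat" loop can also be sketched in plain Python. The example below makes several simplifying assumptions: the image is a single number, the schedule is the linear one from the DDPM paper, the random term that the real sampler adds at each step is left out, and the trained U-Net is replaced by a hypothetical oracle that already knows the clean value. It only illustrates the shape of the loop.

```python
import math


def make_schedule(num_timesteps, beta_start=1e-4, beta_end=0.02):
    """Linear betas and their cumulative products (assumed schedule)."""
    step = (beta_end - beta_start) / (num_timesteps - 1)
    betas = [beta_start + i * step for i in range(num_timesteps)]
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return betas, alpha_bars


def oracle_noise_prediction(x_t, clean, alpha_bar_t):
    """Stand-in for the U-Net: it cheats by knowing the clean value."""
    return (x_t - math.sqrt(alpha_bar_t) * clean) / math.sqrt(1.0 - alpha_bar_t)


def denoise(x_t, clean, betas, alpha_bars):
    """Walk from the last timestep back to the first, removing noise each step."""
    for t in reversed(range(len(betas))):
        predicted_noise = oracle_noise_prediction(x_t, clean, alpha_bars[t])
        # Remove this step's share of the noise, then undo the signal scaling.
        noise_share = betas[t] / math.sqrt(1.0 - alpha_bars[t])
        x_t = (x_t - noise_share * predicted_noise) / math.sqrt(1.0 - betas[t])
    return x_t


betas, alpha_bars = make_schedule(1000)
clean_value = 0.7
# A heavily noised starting point at the final timestep.
start = math.sqrt(alpha_bars[-1]) * clean_value + math.sqrt(1.0 - alpha_bars[-1]) * 1.3
recovered = denoise(start, clean_value, betas, alpha_bars)
```

With a real model the noise prediction is only approximate and is conditioned on the text embedding, which is what steers the result toward the prompt.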
We're leaving out all mathematical details of this so-called diffusion process, but it is remarkable that learning to predictably destroy an image into pure noise means that you can predictably create art from pure noise, one iteration at a time. The idea of "just" chiseling away the random noise resembles the art of sculpting:
“The sculpture is already complete within the marble block, before I start my work. It is already there, I just have to chisel away the superfluous material.” Michelangelo
If you want to dive deep into the technical details of the diffusion process, we recommend you read the paper titled Denoising Diffusion Probabilistic Models.
Let's define Stable Diffusion as a PyTorch Lightning module next. The definition is not difficult to understand, but because the StableDiffusion class contains many methods, it ends up being a bit long. To prepare you for that, let's first walk through its interface.
Overall, here is the interface of the StableDiffusion class:
class StableDiffusion(pl.LightningModule):
    def __init__(self, args):
        ...

    def on_fit_start(self):
        ...

    def forward(self, batch):
        ...

    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        ...

    def configure_optimizers(self):
        ...
Let's break down the most important methods of StableDiffusion, one by one.
Initialization:
We initialize a U-Net model and a noise scheduler in the __init__ method. A few things to note:
The U-Net model config is fetched from the Hugging Face Model Hub given a model name.
The noise scheduler we chose to use is a Denoising Diffusion Probabilistic Model (DDPM) scheduler.
The scheduler is responsible for adding noise at different timesteps during training, and we load a pre-configured scheduler from the Hugging Face Model Hub.
The loss function is a simple mean squared error loss function.
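import pytorch_lightning as pl  # assumption: the lightning.pytorch import path also works
import torch.nn.functional as F
from diffusers import DDPMScheduler, UNet2DConditionModel
from transformers import PretrainedConfig


class StableDiffusion(pl.LightningModule):

    def __init__(self, model_name: str):
        super().__init__()  # must run before submodules are assigned
        # Build the U-Net from its published config; weights start untrained.
        model_config = PretrainedConfig.get_config_dict(model_name, subfolder="unet")
        self.unet = UNet2DConditionModel(**model_config[0])
        self.noise_scheduler = DDPMScheduler.from_pretrained(
            model_name, subfolder="scheduler"
        )
        self.loss_fn = F.mse_loss

    ...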
Forward pass:
The forward pass of the model is defined in the forward method. Here, we perform the following:
Take the image latents and caption embeddings/latents as input
Sample noise at different timesteps
Add the noise to the image latents to get noisy latents
Pass the noisy latents through the U-Net model along with the encoded caption text to get the predicted noise
class StableDiffusion(pl.LightningModule):
    ...

    def forward(self, batch):
        image_latents = batch[f"image_latents_{self.resolution}"]
        caption_latents = batch["caption_embeddings"]

        # Sample a timestep per example, then noise the latents accordingly.
        timesteps = self._sample_timesteps(image_latents)
        noise = torch.randn_like(image_latents)
        noised_image_latents = self.noise_scheduler.add_noise(
            image_latents, noise, timesteps
        )

        # The U-Net predicts the noise; the true noise is returned as the target.
        outputs = self.unet(noised_image_latents, timesteps, caption_latents)["sample"]
        return outputs, noise
Training step:
The training_step method is where we compute the loss for the model. We calculate the mean squared error between the predicted noise and the actual noise. PyTorch Lightning will automatically perform the backward pass and the optimizer step with the returned loss.
class StableDiffusion(pl.LightningModule):
    ...

    def training_step(self, batch, batch_idx):
        outputs, targets = self.forward(batch)
        loss = self.loss_fn(outputs, targets)
        return loss
Configuring optimizers:
The configure_optimizers method defines the optimizer and learning rate scheduler for the model. We use the get_linear_schedule_with_warmup function from the transformers library to create a linear learning rate scheduler with warmup, and the AdamW optimizer from the torch.optim module.
class StableDiffusion(pl.LightningModule):
    ...

    def configure_optimizers(self) -> OptimizerLRScheduler:
        optimizer = torch.optim.AdamW(
            self.trainer.model.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.num_training_steps,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # update the learning rate every step
                "frequency": 1,
            },
        }
get_linear_schedule_with_warmup creates a learning rate schedule with a warmup period during which the learning rate increases linearly from 0 to the initial lr set in the optimizer, and then decreases linearly from that initial lr to 0. See the figure below to visualize this.
Note that in our implementation, we set num_training_steps to a very large number to keep the learning rate at an almost constant value after the warmup period. To view the full code implementing the StableDiffusion module, or to provision a Ray cluster and run it yourself, check out the reference implementation on training Stable Diffusion on Anyscale.
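The shape of this schedule, and the effect of a very large num_training_steps, can be checked with a small standalone sketch. The function below is written from the description above rather than copied from the transformers library, and its name and the step counts used are illustrative assumptions.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """Factor applied to the optimizer's initial learning rate at a given step."""
    if step < num_warmup_steps:
        # Warmup: ramp linearly from 0 up to 1.
        return step / max(1, num_warmup_steps)
    # Afterwards: decay linearly from 1 down to 0 at num_training_steps.
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 1e-4
warmup = 10_000

# Short horizon: the decay after warmup is clearly visible.
short_run = [
    base_lr * linear_warmup_multiplier(s, warmup, 20_000)
    for s in (0, 5_000, 10_000, 15_000, 20_000)
]

# Very large horizon: 100,000 steps after warmup the rate has barely moved.
long_run = base_lr * linear_warmup_multiplier(110_000, warmup, 100_000_000)

print(short_run)
print(long_run)
```

With the large horizon the decay over any realistic number of steps is negligible, which is how the schedule ends up behaving like warmup followed by a near-constant learning rate.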
Now that we have defined the StableDiffusion model, we can define a PyTorch Lightning training script to train the model. Sample training code looks like the following:
def lightning_training_loop(
    # Parameters without defaults must come first.
    train_loader: DataLoader,
    val_loader: DataLoader,
    init_from_pretrained: bool = False,
    model_name: str = "stabilityai/stable-diffusion-2-base",
    resolution: int = 256,
    lr: float = 1e-4,
    num_warmup_steps: int = 10_000,
    weight_decay: float = 1e-2,
) -> None:
    model = StableDiffusion(
        init_from_pretrained=init_from_pretrained,
        model_name=model_name,
        resolution=resolution,
        lr=lr,
        num_warmup_steps=num_warmup_steps,
        weight_decay=weight_decay,
    )

    trainer = pl.Trainer(
        accelerator="gpu",
        devices="auto",
        precision="bf16-mixed",
        strategy=FSDPStrategy(...),
        # ... (remaining Trainer arguments elided)
    )
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
PyTorch Lightning provides a high-level interface for training PyTorch models. The Trainer class is used to configure and run the training loop. The fit method trains the model on the training data and validates it on the validation data. The training loop above requires data loaders to be prepared. We will not go into detail on that here, because we will focus on producing data loaders from Ray Data when we scale the training process.
To scale our training with Ray Train, our resulting code will look something like this:
from ray.train import FailureConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer


def train_func(config):
    ...


trainer = TorchTrainer(
    train_func,
    scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=True),
    run_config=RunConfig(
        name=experiment_name,
        storage_path=storage_path,
        failure_config=FailureConfig(max_failures=max_failures),
    ),
    datasets=ray_datasets,
)
trainer.fit()
We perform two main steps in the code above:
Instantiate a trainer by using the ray.train.torch.TorchTrainer class, which ties together the following components:
train_func: A Python function that contains our model training logic. train_func will be a slightly modified version of the lightning_training_loop we defined earlier, as we will see below.
scaling_config: An object specifying the number of workers and whether to use GPUs. For more complex resource configurations, see the Ray Train documentation.
run_config: A configuration of the experiment name, storage settings, and the maximum number of failures allowed.
datasets: A Python dictionary of training and validation Ray Dataset objects that will be used for training the model.
Call trainer.fit to launch workers on the Ray cluster, distribute the training function to the workers, and launch a process to manage the training job. Here's a high-level overview of the process:
Ray Train integrates with Ray Data to offer a performant and scalable streaming solution for loading and preprocessing large datasets. By default, Ray Data shards all datasets across workers as shown in the diagram below:
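The idea of sharding can be illustrated with a toy example. The sketch below is not how Ray Data splits a dataset internally; it only shows the property that matters for training, namely that each worker receives a disjoint, near-equal slice of the rows. The function name and the round-robin assignment are assumptions made for illustration.

```python
def shard_rows(rows, num_workers):
    """Assign rows to workers round-robin so shards are disjoint and near-equal."""
    return [rows[worker::num_workers] for worker in range(num_workers)]


rows = list(range(10))
shards = shard_rows(rows, num_workers=4)

for worker_index, shard in enumerate(shards):
    print(f"worker {worker_index} sees rows {shard}")
```

Every row ends up on exactly one worker, so the workers together cover the dataset once per pass without duplicating work.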
To access the ray_datasets we passed to our TorchTrainer from within a worker's training function, we make use of the get_dataset_shard method from Ray Train. To generate data loaders from the Ray datasets, we use the iter_torch_batches method. Here's how our train_func will handle data loading and preprocessing:
from ray import train


def train_func(config):
    # Each worker receives its own shard of the datasets passed to TorchTrainer.
    train_ds = train.get_dataset_shard("train")
    val_ds = train.get_dataset_shard("validation")

    train_dataloader = train_ds.iter_torch_batches(
        batch_size=config["batch_size"]
    )
    val_dataloader = val_ds.iter_torch_batches(
        batch_size=config["batch_size"]
    )

    ...

    trainer = pl.Trainer(
        ...
    )

    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader
    )
In case you need to perform some last-mile preprocessing on the training worker, you can use the collate_fn argument in iter_torch_batches. This function will be called on each worker to preprocess the data before it is fed into the model. Here's how we will use the collate_fn argument to move the data to the GPU:
from functools import partial

import ray.train.torch
import torch


def move_to_device_collate_fn(batch, device):
    for k, v in batch.items():
        batch[k] = torch.tensor(v).to(device)
    return batch


def train_func(config):
    ...

    # Bind the worker's device so the collate function takes only the batch.
    collate_fn = partial(
        move_to_device_collate_fn,
        device=ray.train.torch.get_device()
    )
    train_dataloader = train_ds.iter_torch_batches(
        batch_size=config["batch_size"],
        collate_fn=collate_fn,
        prefetch_batches=config["num_batches_to_prefetch"],
        # ...
    )
    val_dataloader = val_ds.iter_torch_batches(
        batch_size=config["batch_size"],
        collate_fn=collate_fn,
        prefetch_batches=config["num_batches_to_prefetch"],
        # ...
    )
Note that because the collate_fn operation happens locally on the training workers, we avoid adding any heavy transformation in this function, as it may become the bottleneck. Additionally, we can further reduce bottlenecks by passing iter_torch_batches a prefetch_batches argument. If prefetch_batches is set to N, Ray Data will launch background threads to fetch and process the next N batches.
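To see why prefetching helps, consider the following minimal sketch of a prefetching iterator built with the Python standard library. It is not Ray's implementation. It only illustrates the general pattern of a background thread filling a bounded buffer while the consumer, here standing in for the training loop, works on the current batch. All names are hypothetical.

```python
import queue
import threading


def prefetching_iterator(batches, prefetch_batches):
    """Yield items from `batches` while a background thread loads ahead."""
    buffer = queue.Queue(maxsize=max(1, prefetch_batches))
    done = object()  # sentinel marking the end of the stream

    def producer():
        for batch in batches:
            buffer.put(batch)  # blocks while the buffer is full
        buffer.put(done)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        batch = buffer.get()  # blocks only if the producer has fallen behind
        if batch is done:
            return
        yield batch


def example_batches(n):
    """Stand-in for a data source that produces batches one at a time."""
    for i in range(n):
        yield {"batch_index": i}


consumed = [
    batch["batch_index"]
    for batch in prefetching_iterator(example_batches(5), prefetch_batches=2)
]
print(consumed)
```

The bounded buffer is the important design choice: it caps how far ahead the loader can run, so prefetching overlaps data loading with compute without letting memory usage grow unchecked.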
Having shown how to load the data within a training function, let's now define the main training function that we will use to train the Stable Diffusion model using Ray Train. The training function will perform the following steps:
Initialize random seeds for reproducibility
Load the training and validation datasets
Initialize the Stable Diffusion model
Initialize the PyTorch Lightning Trainer
Train the model
def train_func(config: dict):
    pl.seed_everything(config["seed"])
    ...
    train_dataloader = train_ds.iter_torch_batches(
        batch_size=config["batch_size_per_worker"],
        collate_fn=collate_fn,
        # ...
    )
    validation_dataloader = validation_ds.iter_torch_batches(
        batch_size=config["batch_size_per_worker"],
        collate_fn=collate_fn,
        # ...
    )
    model = StableDiffusion(
        init_from_pretrained=config["init_from_pretrained"],
        model_name=config["model_name"],
        resolution=config["resolution"],
        lr=config["lr"],
        # ...
    )
    ...
    trainer = pl.Trainer(
        accelerator="gpu",
        precision="bf16-mixed",
        strategy=RayFSDPStrategy() if config["fsdp"] else RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=callbacks,
        # ...
    )

    trainer.fit(
        model,
        train_dataloaders=train_dataloader,
        val_dataloaders=validation_dataloader,
    )
You might have noticed that we specified a distribution strategy in our above training function as either RayFSDPStrategy or RayDDPStrategy. Fully Sharded Data Parallelism (FSDP) and Distributed Data Parallelism (DDP) are two strategies used in distributed machine learning training to handle large models and datasets by distributing the workload across multiple machines. While both aim to improve training efficiency and scalability, they employ different approaches to achieve these goals.
In general, the choice between Distributed Data Parallelism (DDP) and Fully Sharded Data Parallelism (FSDP) in distributed training typically depends on the specific requirements of your training job, such as model size, resource constraints, and scalability needs. Both methods have their advantages and scenarios where they perform best. Here’s a breakdown to help you choose the right strategy:
Figure 8: Distributed Data Parallelism (DDP) in distributed training. Picture taken from the Facebook Engineering blog post on FSDP.
How it Works:
DDP is one of the most straightforward and commonly used strategies for distributed training. It replicates the entire model on each GPU or node in the cluster, with each replica working on a different subset of the data.
The core idea is to parallelize the training by distributing data batches across different nodes, where each node computes gradients independently. After each backward pass, DDP synchronizes gradients across all nodes to update the model weights consistently. This approach is relatively simple to implement and can significantly speed up training for large datasets.
Advantages:
Simplicity: It is easier to implement and integrate with existing models because each process handles a complete copy of the model.
Performance: Typically, DDP offers good speedup and performance, especially when the communication overhead is managed (e.g., with efficient collective communication primitives).
Disadvantages:
Memory Utilization: Each GPU needs memory for a full model copy, which can be inefficient for very large models.
Scaling Limit: As model sizes grow, the memory requirement can become a bottleneck.
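The core mechanic of DDP, where each replica computes gradients on its own data and then all replicas apply the same averaged gradient, can be simulated in a few lines. The sketch below is a single-process toy with a one-parameter linear model. The function names are hypothetical, and all_reduce_mean stands in for the collective communication that a real DDP implementation performs across GPUs.

```python
def local_gradient(weight, shard):
    """Gradient of the mean squared error of y = weight * x on one worker's shard."""
    return sum(2.0 * (weight * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    """Stand-in for the collective that averages gradients across workers."""
    return sum(values) / len(values)


def ddp_step(replica_weights, shards, lr):
    """One synchronized step: local backward passes, then a shared update."""
    grads = [local_gradient(w, shard) for w, shard in zip(replica_weights, shards)]
    synced = all_reduce_mean(grads)
    return [w - lr * synced for w in replica_weights]


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # samples of y = 2x
shards = [data[:2], data[2:]]  # each worker sees a different half
weights = [0.0, 0.0]           # every worker holds a full copy of the model

for _ in range(200):
    weights = ddp_step(weights, shards, lr=0.01)

print(weights)
```

Because every replica applies the same averaged gradient, the copies never drift apart. With equal-sized shards the averaged gradient also matches the gradient over the full batch.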
Figure 9: Fully Sharded Data Parallelism (FSDP) in distributed training in FULL_SHARD mode. Picture taken from the Facebook Engineering blog post on FSDP.
How it Works: FSDP takes a different approach to distributed training by sharding (splitting) both the data and any of the following across all available GPUs:
Optimizer states
Gradients
Model parameters
The extent of “model-based sharding” is configurable in Lightning. The available sharding configurations include:
FULL_SHARD: Shards model parameters, gradients, and optimizer states (default).
SHARD_GRAD_OP: Shards gradients and optimizer states only. Model parameters get replicated.
NO_SHARD: No sharding (similar to regular DDP).
HYBRID_SHARD: Shards model parameters, gradients, and optimizer states within a single machine, but replicates across machines.
Advantages:
Memory Efficiency: Greatly reduces the memory footprint on each GPU for large models, which allows training much larger models or increasing the batch size.
Disadvantages:
Development Complexity: More complex to configure and debug, especially when using other optimizations together with it.
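A rough per-GPU memory estimate shows how the sharding modes differ. The sketch below follows the descriptions in the list above and rests on several assumptions that are not from this guide: a hypothetical model with one billion parameters, 4 bytes per parameter and per gradient, 8 bytes of AdamW state per parameter, and 16 GPUs spread over two 8-GPU machines. It ignores activations, communication buffers, and mixed precision.

```python
BYTES_PARAMS = 4     # assumed: fp32 parameters
BYTES_GRADS = 4      # assumed: fp32 gradients
BYTES_OPTIMIZER = 8  # assumed: two fp32 AdamW moment buffers per parameter


def per_gpu_bytes(num_params, num_gpus, strategy, gpus_per_node):
    """Model-state memory per GPU under each sharding mode (rough estimate)."""
    if strategy == "NO_SHARD":
        param_div = rest_div = 1
    elif strategy == "SHARD_GRAD_OP":
        param_div, rest_div = 1, num_gpus
    elif strategy == "FULL_SHARD":
        param_div = rest_div = num_gpus
    elif strategy == "HYBRID_SHARD":
        # Sharded within a machine, replicated across machines.
        param_div = rest_div = gpus_per_node
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    per_param = BYTES_PARAMS / param_div + (BYTES_GRADS + BYTES_OPTIMIZER) / rest_div
    return num_params * per_param


estimates_gb = {
    strategy: per_gpu_bytes(1_000_000_000, 16, strategy, gpus_per_node=8) / 1e9
    for strategy in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD")
}
for strategy, gigabytes in estimates_gb.items():
    print(f"{strategy}: {gigabytes:.2f} GB per GPU")
```

Even this crude model shows the trade-off: the more state is sharded, the less memory each GPU needs, at the price of extra communication to gather the shards during training.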
Model Size: If your model fits comfortably in the memory of a single GPU, DDP is typically sufficient and easier to manage. For very large models that exceed GPU memory limits, FSDP is likely necessary.
Resource Availability: If you have access to a large number of GPUs and need to optimize memory usage, FSDP can be more advantageous.
Development Complexity: DDP is simpler to configure and debug. If your team is smaller or less experienced with distributed systems, starting with DDP might be preferable.
Training Time and Efficiency: Consider the training time and scalability needs. FSDP can sometimes be slower per iteration but allows for training larger models or using larger batch sizes, which can be a critical factor.
In summary, no one strategy is always clearly better than the other. The choice largely depends on the specific constraints and requirements of your training workload. In many practical scenarios, teams might start with DDP due to its simplicity and lower barrier to entry, and move to FSDP as model sizes and scaling needs grow.
In our specific case, we found that FSDP in SHARD_GRAD_OP mode was the best choice for training Stable Diffusion models, striking a balance between more efficient use of GPU memory and increased communication overhead. See our blog post for more details.
Finally, we can define the train script as follows:
...

@app.command()
def train(
    storage_path: str,
    train_data_uri: str,
    validation_data_uri: str,
    experiment_name: str = "exp-base",
    model_name: str = "stabilityai/stable-diffusion-2-base",
    init_from_pretrained: bool = False,
    resolution: int = 256,
    lr: float = 1e-4,
    weight_decay: float = 1e-2,
    max_steps: int = -1,
    num_warmup_steps: int = 10000,
    batch_size_per_worker: int = 32,
    num_data_loading_workers: int = 16,
    num_training_workers: int = 16,
    seed: int = 420,
    max_failures: int = 0,
):

    # Lazily define the training and validation datasets.
    ray_datasets = {
        "train": load_precomputed_dataset(
            data_uri=train_data_uri,
            resolution=resolution,
            num_data_loading_workers=num_data_loading_workers,
        ),
        "validation": load_precomputed_dataset(
            data_uri=validation_data_uri,
            resolution=resolution,
            num_data_loading_workers=num_data_loading_workers,
        )
    }

    trainer = TorchTrainer(
        train_func,
        # Passed to train_func as its config argument on every worker.
        train_loop_config={
            "model_name": model_name,
            "resolution": resolution,
            "lr": lr,
            "weight_decay": weight_decay,
            "init_from_pretrained": init_from_pretrained,
            "seed": seed,
            "batch_size_per_worker": batch_size_per_worker,
            "max_steps": max_steps,
            "num_warmup_steps": num_warmup_steps,
            "project_name": experiment_name,
        },
        scaling_config=ScalingConfig(
            num_workers=num_training_workers, use_gpu=True
        ),
        run_config=RunConfig(
            name=experiment_name,
            storage_path=storage_path,
            failure_config=FailureConfig(max_failures=max_failures),
        ),
        datasets=ray_datasets,
    )

    trainer.fit()
The train script is a command-line interface that allows you to run the training process on a Ray cluster. The script takes in various arguments such as the storage path, training and validation data URIs, experiment name, model name, and other hyperparameters. The script then uses the TorchTrainer class to combine all the components we discussed earlier and calls trainer.fit to run the training process on a Ray cluster.
To view the full code implementing the train script, or to provision a Ray cluster and run it yourself, check out the reference implementation on training Stable Diffusion on Anyscale.
Here's an illustration of how experiment restoration works in Ray Train:
Ray Train has built-in fault tolerance to recover from worker failures (i.e. RayActorErrors). When a failure is detected, the workers will be shut down and new workers will be added in. The training function will be restarted, but progress from the previous execution can be resumed through checkpointing.
To retain progress during recovery, the training function must implement logic for both saving and loading checkpoints. Each instance of recovery from a worker failure is considered a retry. The number of retries is configurable through the max_failures attribute of the FailureConfig argument set in the RunConfig passed to the Trainer.
Here is how we would update our train_func to save and recover from a checkpoint:
def train_func():
    # Recover from checkpoint
    if ray.train.get_checkpoint():
        ...

    for i in range(num_epochs):
        ...
        # Save a checkpoint
        ray.train.report(metrics, checkpoint=checkpoint)


trainer = TorchTrainer(
    ...,
    run_config=RunConfig(
        # FailureConfig is set on the RunConfig; -1 retries indefinitely.
        failure_config=FailureConfig(max_failures=-1),
    ),
)
To manually restore a training job, you can use the restore method of the TorchTrainer class. This method will restore the training job from the last checkpoint and continue training from that point.
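The interplay between retries and checkpoints can be illustrated with a small simulation in plain Python. This is not Ray Train's implementation. WorkerFailure, run_with_retries, and make_flaky_train_fn are hypothetical names, and the "checkpoint" is just a dictionary that survives restarts.

```python
class WorkerFailure(Exception):
    """Stand-in for a worker crash."""


def run_with_retries(train_fn, max_failures):
    """Re-run train_fn after failures; max_failures=-1 retries indefinitely."""
    checkpoint = {"step": 0}  # the latest saved progress survives restarts
    failures = 0
    while True:
        try:
            train_fn(checkpoint)
            return checkpoint, failures
        except WorkerFailure:
            failures += 1
            if max_failures != -1 and failures > max_failures:
                raise


def make_flaky_train_fn(total_steps, crash_at_steps):
    pending_crashes = set(crash_at_steps)

    def train_fn(checkpoint):
        # Resume from the last checkpoint instead of starting from step 0.
        for step in range(checkpoint["step"], total_steps):
            if step in pending_crashes:
                pending_crashes.discard(step)
                raise WorkerFailure(f"worker died at step {step}")
            checkpoint["step"] = step + 1  # "save" a checkpoint after each step

    return train_fn


final_checkpoint, num_failures = run_with_retries(
    make_flaky_train_fn(total_steps=10, crash_at_steps=[3, 7]),
    max_failures=2,
)
print(final_checkpoint, num_failures)
```

Without the checkpoint, each retry would start again from step 0. With it, a failure only costs the work done since the last save, which is why the training function has to implement both saving and loading.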
So far our implementation can be summarized with the following diagram:
We performed the preprocessing and stored the dataset on S3. That means that to use this processed data for Stable Diffusion training, we have to pick it up again and feed it into the U-Net model that's central to all Stable Diffusion models. This approach is also known as “offline preprocessing”.
The alternative to this approach is to preprocess the raw LAION data online, i.e. "on the fly", before feeding it into a Stable Diffusion model for training. This way, we can "merge" preprocessing and training into a single end-to-end pipeline in a Ray cluster. Schematically, online preprocessing looks like this:
You might want to perform offline preprocessing for these reasons:
Cost Efficiency: Preprocessing data in advance allows for reuse in multiple training runs, reducing the need for repeated pre-processing and potentially lowering overall costs.
Note: This assumes preprocessing is deterministic, which may not apply if you want to perform transformations like random image cropping in your custom training of Stable Diffusion.
Simplified Optimization: Enables full focus on enhancing the training process itself, such as maximizing GPU utilization, without the overhead of coupling to pre-processing.
Challenges include higher storage costs due to the need to save processed data, and less flexibility since adjustments require rerunning the entire preprocessing.
Consider online preprocessing for these reasons:
Dynamic Preprocessing: A natural choice when dynamic, on-the-fly adjustments to preprocessing steps are necessary.
Immediate Training Start: Training can begin as soon as raw data is available, eliminating preprocessing wait times.
However, this method may increase the complexity of managing resources and can lead to higher operational costs if data needs frequent reprocessing for each training session.
Read our blog post to better quantify the above trade-offs between offline and online preprocessing and choose the right approach for your use case.
If we want to implement online preprocessing, the change is simple enough with Ray Data. Instead of using load_precomputed_dataset to fetch our training and validation data, we use the get_laion_streaming_dataset method from our first guide, which implements the core preprocessing logic.
Here is what will change in our train script:
...

@app.command()
def train(
    storage_path: str,
    online_preprocessing: bool,
    processed_train_data_uri: str | None,
    processed_validation_data_uri: str | None,
    resolution: int = 256,
    batch_size_per_worker: int = 32,
    num_encoders: int = 16,
    num_data_loading_workers: int = 16,
    # ... (remaining arguments elided)
):
    if online_preprocessing:
        # Preprocess the raw data on the fly while training.
        ray_datasets = {
            "train": get_laion_streaming_dataset(
                input_uri=raw_train_data_uri,
                batch_size=batch_size_per_worker,
                resolution=resolution,
                num_encoders=num_encoders,
            ),
            "validation": get_laion_streaming_dataset(
                input_uri=raw_validation_data_uri,
                batch_size=batch_size_per_worker,
                resolution=resolution,
                num_encoders=num_encoders,
            ),
        }
    else:
        # Load the latents and embeddings computed offline.
        ray_datasets = {
            "train": load_precomputed_dataset(
                data_uri=processed_train_data_uri,
                resolution=resolution,
                num_data_loading_workers=num_data_loading_workers,
            ),
            "validation": load_precomputed_dataset(
                data_uri=processed_validation_data_uri,
                resolution=resolution,
                num_data_loading_workers=num_data_loading_workers,
            ),
        }
    ...
In this guide, we learned the following:
How to train a Stable Diffusion model using Ray Train + PyTorch Lightning
How to assess online vs. offline preprocessing and integrate Ray Data for our training pipeline
How to scale the training process to handle extensive datasets and computational demands
In future guides, we aim to show how to use Ray to train more recent versions of Stable Diffusion, such as Stable Diffusion V3, which entail even more challenging resource and data processing requirements.
Training Stable Diffusion models at scale can be a challenging task due to the large amounts of data and computational resources required. Ray Train provides a powerful framework for distributed training that can help you efficiently train Stable Diffusion models on large datasets. By following the steps outlined in this guide, you can train Stable Diffusion models at scale and take advantage of the benefits of distributed training.
To view the full code implementing the end-to-end training process, or to provision a Ray cluster and run it yourself, check out the reference implementation on training Stable Diffusion on Anyscale.