Faster stable diffusion fine-tuning with Ray AIR

By Kai Fricke   

This is part 3 of our generative AI blog series, which dives into a concrete example of how you can use Ray to scale the training of generative AI models. To learn more about using Ray to productionize generative model workloads, see part 1. To learn about how Ray empowers LLM frameworks such as Alpa, see part 2.

In this blog post, we explore how to use Ray AIR to scale and accelerate the fine-tuning process of a stable diffusion model.

When Stable Diffusion was released last year, it took the internet by storm: this seemingly magical machine learning model takes a textual prompt and converts it into a realistic image. For example, below are a couple of prompts given to Stability AI’s DreamStudio, with the resulting images.

Figure 1. With the prompt: "My cat on the moon"
Figure 2. With the prompt: "My cat is an astronaut chasing a rat on the planet mars."

Although this is an exciting new capability, particularly for content creators, it can be hard to control exactly what the output will look like or to personalize it. For instance, if you want to generate a picture of your own cat on the moon, you need to fine-tune the stable diffusion model for that task.

To fine-tune a model for your task, you take an existing model and continue training it on some of your own data to get a personalized version. For instance, you can start with a pre-trained diffusion model and train it to recognize your cat. The model can then include your cat in images generated from subsequent prompts.

This all sounds easy enough, but in practice there are some code and infrastructure challenges to overcome behind the scenes.

Challenges when fine-tuning diffusion models

Fine-tuning a stable diffusion model can take a long time, and you may want to distribute your training process in order to speed things up. 

However, there are three main challenges when it comes to scaling the fine-tuning of diffusion models:

  1. Converting your script to do distributed training: Figuring out how to convert your training script to use distributed data-parallel training across multiple GPUs and multiple nodes can add a lot of complexity to your training script.

  2. Distributed data loading: Data loading, especially in distributed settings, is painful. Existing native PyTorch solutions can be especially hard to scale if you have a large dataset and need to read it efficiently from cloud storage.

  3. Distributed orchestration: Configuring and setting up a distributed training cluster is one of the biggest headaches for a machine learning (ML) practitioner. You have to set up Kubernetes or manage the machines manually, make sure they can communicate with each other, and make sure every worker has the correct environment variables set.

To address these challenges, you can use Ray AIR. It lets you distribute your training data, convert your script into a distributed application, and take advantage of the accelerators on your multi-node cluster. The following sections describe what Ray AIR is and how to use it.

What are Ray and Ray AIR?

Ray is a popular open-source distributed Python framework that makes it easy to scale AI and Python workloads.

Ray AIR (AI Runtime) is a native set of scalable machine learning libraries built on top of Ray. In particular, two AIR libraries are built specifically for scalable model training: Ray Train, which simplifies distributed training for PyTorch and other common ML frameworks, and Ray Data, which simplifies loading and ingesting data from the cloud. If you're new to Ray AIR, check out this Ray Summit 2022 talk.

With Ray, we can address the above three challenges:

  1. Scaling across multiple nodes: Ray Train provides a unified interface for taking an existing training script and enabling distributed multi-GPU, multi-node training.

  2. Distributed data loading: Ray Data has a simple interface for reading files from cloud storage and for efficiently loading and sharding data onto your training GPUs.

  3. Distributed orchestration: Ray’s open-source cluster launcher allows you to create Ray clusters with a single command on AWS, Azure, or Google Cloud.

Fine-tuning a diffusion model with Ray AIR

Let's explore what the training code looks like for our fine-tuning example.

The full code can be found here on GitHub. Along with instructions, it includes the data loading and preprocessing code. In the rest of the post, we focus only on the training code.

Central to our training code is the training function. This function accepts a configuration dict that contains the hyperparameters. It then defines a regular PyTorch training loop.

There are only a few places in our training code where we interact with the Ray AIR API. These are marked with "# Ray AIR code" comments in the code below.

Remember that we want to do data-parallel training for all our models, so each worker trains on its own shard of the data:

  • We load the data shard for each worker with session.get_dataset_shard("train")

  • We iterate over the dataset with train_dataset.iter_torch_batches()

  • We report results to Ray AIR with session.report(results)

Figure 3. Ray Data sends a shard (slice) of the data to each worker for distributed training
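
To make the sharding in Figure 3 concrete, here is a conceptual, pure-Python sketch of what data-parallel sharding means: the rows of the dataset are split into one contiguous slice per worker, and each worker then iterates over only its own slice in batches. The helper names shard_bounds and iter_batches are hypothetical, and this is not Ray Data's actual splitting logic; in the real training code, session.get_dataset_shard and iter_torch_batches handle this for you.

```python
def shard_bounds(num_rows: int, num_workers: int) -> list:
    """Split row indices [0, num_rows) into num_workers contiguous, near-equal shards.

    Hypothetical helper for illustration only; Ray Data's real splitting logic differs.
    """
    if num_workers <= 0:
        raise ValueError("num_workers must be positive")
    base, extra = divmod(num_rows, num_workers)
    bounds = []
    start = 0
    for rank in range(num_workers):
        # The first `extra` workers each receive one additional row.
        size = base + (1 if rank < extra else 0)
        bounds.append((start, start + size))
        start += size
    return bounds


def iter_batches(rows: list, batch_size: int):
    """Yield consecutive batches from one worker's shard (the last batch may be smaller)."""
    for i in range(0, len(rows), batch_size):
        yield rows[i : i + batch_size]


if __name__ == "__main__":
    rows = list(range(10))
    for rank, (lo, hi) in enumerate(shard_bounds(len(rows), num_workers=4)):
        # Each worker only ever sees its own slice of the data.
        print(rank, list(iter_batches(rows[lo:hi], batch_size=2)))
```

Contiguous, near-equal shards keep each worker's share of the work balanced to within one row, which matters because every data-parallel step waits for the slowest worker.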

The code snippet below is compacted for brevity. The full code is more thoroughly annotated.

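import itertools

import torch
from torch.nn.parallel import DistributedDataParallel
from torch.nn.utils import clip_grad_norm_
from ray.air import session

# Helpers such as get_cuda_devices, load_models, collate, get_target, and
# prior_preserving_loss are defined in the full example code (not shown here).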

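def train_fn(config):
    # Each worker has 2 GPUs (resources_per_worker={"GPU": 2} below): the UNet
    # runs on cuda[0], while the text encoder and VAE run on cuda[1].
    cuda = get_cuda_devices()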

    text_encoder, noise_scheduler, vae, unet = load_models(config, cuda)
    text_encoder = DistributedDataParallel(
        text_encoder, device_ids=[cuda[1]], output_device=cuda[1]
    )
    unet = DistributedDataParallel(unet, device_ids=[cuda[0]], output_device=cuda[0])
    optimizer = torch.optim.AdamW(
        itertools.chain(text_encoder.parameters(), unet.parameters()),
        lr=config["lr"],
    )
    # Ray AIR code
    train_dataset = session.get_dataset_shard("train")
    num_train_epochs = config["num_epochs"]

    global_step = 0
    for epoch in range(num_train_epochs):
        for step, batch in enumerate(
            # Ray AIR code
            train_dataset.iter_torch_batches(
                batch_size=config["train_batch_size"], device=cuda[1]
            )
        ):
            batch = collate(batch, cuda[1], torch.bfloat16)
            optimizer.zero_grad()
            # Encode images into latent space and scale by the SD v1 VAE factor.
            latents = vae.encode(batch["images"]).latent_dist.sample() * 0.18215
            # Forward diffusion: sample noise and a random timestep per example,
            # then add that noise to the latents.
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]
            timesteps = torch.randint(
                0,
                noise_scheduler.config.num_train_timesteps,
                (bsz,),
                device=latents.device,
            )
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            encoder_hidden_states = text_encoder(batch["prompt_ids"])[0]
            model_pred = unet(
                noisy_latents.to(cuda[0]),
                timesteps.to(cuda[0]),
                encoder_hidden_states.to(cuda[0]),
            ).sample
            target = get_target(noise_scheduler, noise, latents, timesteps).to(cuda[0])
            loss = prior_preserving_loss(
                model_pred, target, config["prior_loss_weight"]
            )
            loss.backward()
            clip_grad_norm_(
                itertools.chain(text_encoder.parameters(), unet.parameters()),
                config["max_grad_norm"],
            )
            optimizer.step()  # One optimizer updates both the text encoder and the UNet.
            global_step += 1
            results = {
                "step": global_step,
                "loss": loss.detach().item(),
            }
            # Ray AIR code
            session.report(results)
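
Two steps in this loop are delegated to a library call and a helper that the compact snippet does not show. noise_scheduler.add_noise performs the forward diffusion step: for DDPM-style schedulers it computes x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * noise, where abar_t is the cumulative product of (1 - beta) up to timestep t. prior_preserving_loss follows DreamBooth's prior-preservation idea: a loss on your own (instance) images plus a weighted loss on the regularization (class) images, so the model does not forget what a generic cat looks like. The pure-Python sketch below illustrates both on plain lists of floats. It assumes a plain linear beta schedule (the actual Stable Diffusion scheduler configuration may differ), assumes the target is the noise itself (epsilon prediction), and assumes each batch holds instance examples in its first half and class examples in its second half. All function names are hypothetical stand-ins, not the example's actual code.

```python
import math


def linear_alphas_cumprod(num_train_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Return abar_t = prod_{s<=t} (1 - beta_s) for an assumed linear beta schedule."""
    alphas_cumprod = []
    running = 1.0
    for i in range(num_train_timesteps):
        beta = beta_start + (beta_end - beta_start) * i / (num_train_timesteps - 1)
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod


def add_noise(x0, noise, t, alphas_cumprod):
    """Forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[t]
    return [math.sqrt(abar) * x + math.sqrt(1.0 - abar) * n for x, n in zip(x0, noise)]


def mse(pred, target):
    """Mean squared error between two equal-length lists."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)


def prior_preserving_loss(model_pred, target, prior_loss_weight):
    """Instance loss plus weighted prior (class) loss.

    Assumes the batch is [instance examples..., class examples...] in equal halves.
    """
    half = len(model_pred) // 2
    instance_loss = mse(model_pred[:half], target[:half])
    prior_loss = mse(model_pred[half:], target[half:])
    return instance_loss + prior_loss_weight * prior_loss


if __name__ == "__main__":
    abar = linear_alphas_cumprod()
    x0 = [0.5, -0.5, 0.25, -0.25]  # toy "latents"
    noise = [1.0, 0.0, -1.0, 0.5]  # toy noise sample
    print(add_noise(x0, noise, t=500, alphas_cumprod=abar))
    # With epsilon prediction, the training target is the noise itself.
    model_pred = [0.9, 0.1, -1.1, 0.4]  # toy UNet output
    print(prior_preserving_loss(model_pred, noise, prior_loss_weight=1.0))
```

Setting prior_loss_weight to 0 reduces this to a plain fine-tuning loss on your own images; larger values pull the model harder toward its original notion of the class.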

We can then run this training loop with Ray AIR's TorchTrainer:

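from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer

# train_arguments and get_train_dataset are helpers from the full example code.
args = train_arguments().parse_args()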

# Build training dataset.
train_dataset = get_train_dataset(args)

print(f"Loaded training dataset (size: {train_dataset.count()})")

# Train with Ray AIR TorchTrainer.
trainer = TorchTrainer(
    train_fn,
    train_loop_config=vars(args),
    scaling_config=ScalingConfig(
        use_gpu=True,
        num_workers=args.num_workers,
        resources_per_worker={
            "GPU": 2,
        },
    ),
    datasets={
        "train": train_dataset,
    },
)
result = trainer.fit()
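
Passing train_loop_config=vars(args) is what turns the parsed command-line arguments into the config dict that train_fn reads, for example config["lr"] and config["num_epochs"]. Below is a minimal sketch with a hypothetical parser: the flag names mirror the config keys used in train_fn, but the defaults are placeholders rather than the example's actual values, and the real train_arguments() may define different or additional flags.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical stand-in for train_arguments(); all defaults are placeholders."""
    parser = argparse.ArgumentParser(description="DreamBooth fine-tuning (sketch)")
    parser.add_argument("--lr", type=float, default=5e-6)
    parser.add_argument("--num_epochs", type=int, default=4)
    parser.add_argument("--train_batch_size", type=int, default=2)
    parser.add_argument("--prior_loss_weight", type=float, default=1.0)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--num_workers", type=int, default=2)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--lr", "1e-5", "--num_workers", "4"])
    # vars() converts the argparse.Namespace into a plain dict; TorchTrainer
    # passes train_loop_config to train_fn as its `config` argument.
    config = vars(args)
    print(config["lr"], config["num_workers"], config["num_epochs"])
```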


In the TorchTrainer, we can easily configure the scale of training. With num_workers set to 2, the example above runs training on 2 workers with 2 GPUs each, i.e., on 4 GPUs. To run the example on 8 GPUs, simply set the number of workers to 4!
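
The arithmetic behind the scaling configuration is simple: total GPUs = num_workers × GPUs per worker, where GPUs per worker is fixed at 2 because train_fn places the UNet and the text encoder on different devices. A small sketch with hypothetical helper names:

```python
GPUS_PER_WORKER = 2  # matches resources_per_worker={"GPU": 2} in the ScalingConfig above


def total_gpus(num_workers: int, gpus_per_worker: int = GPUS_PER_WORKER) -> int:
    """GPUs reserved by the trainer for a given number of workers."""
    return num_workers * gpus_per_worker


def workers_for(target_gpus: int, gpus_per_worker: int = GPUS_PER_WORKER) -> int:
    """Number of workers needed to use exactly target_gpus GPUs."""
    if target_gpus % gpus_per_worker != 0:
        raise ValueError(
            f"{target_gpus} GPUs cannot be split evenly into workers of {gpus_per_worker}"
        )
    return target_gpus // gpus_per_worker


if __name__ == "__main__":
    print(total_gpus(2), workers_for(8))
```

Because each worker needs a pair of GPUs, an odd GPU count cannot be fully used by this training function; the helper raises instead of silently rounding.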

Figure 4. Training time by number of GPUs


Training time decreases as we add workers, but the scaling is not perfect: ideally, doubling the number of workers would cut the training time in half. The gap comes from communication overhead, since synchronizing the gradients of these large models across workers takes time.
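
One way to quantify "not perfect" is scaling efficiency: the measured speedup divided by the ideal speedup. With T_1 the training time on a baseline of G_1 GPUs and T_N the time on G_N GPUs, efficiency = (T_1 / T_N) / (G_N / G_1), and a value of 1.0 means perfectly linear scaling. The timings in the usage below are hypothetical, not the measurements behind Figure 4.

```python
def speedup(baseline_time: float, scaled_time: float) -> float:
    """How many times faster the scaled run is than the baseline."""
    return baseline_time / scaled_time


def scaling_efficiency(baseline_time, baseline_gpus, scaled_time, scaled_gpus):
    """Measured speedup divided by the ideal (linear) speedup; 1.0 is perfect."""
    ideal = scaled_gpus / baseline_gpus
    return speedup(baseline_time, scaled_time) / ideal


if __name__ == "__main__":
    # Hypothetical timings: 100 minutes on 4 GPUs, 60 minutes on 8 GPUs.
    print(round(scaling_efficiency(100.0, 4, 60.0, 8), 3))
```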

This overhead can likely be reduced by using a larger batch size, which in turn requires optimizing GPU memory usage with libraries such as DeepSpeed. We'll explore this in another blog post.
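
Why would a larger batch help? In data-parallel training, the amount of gradient data each worker synchronizes per step is set by the model size, not the batch size, so a larger batch spreads that fixed communication cost over more samples. The toy cost model below, with made-up unit costs, illustrates the effect; real overheads depend on the interconnect, the model size, and how much communication overlaps with computation.

```python
def comm_fraction(batch_size: int, compute_per_sample: float, comm_per_step: float) -> float:
    """Share of each step spent synchronizing gradients in a toy cost model.

    Assumes compute time grows linearly with batch size, while gradient sync costs
    a fixed amount per step (gradient size depends on the model, not the batch).
    """
    step_time = batch_size * compute_per_sample + comm_per_step
    return comm_per_step / step_time


if __name__ == "__main__":
    for bs in (1, 2, 4, 8):
        # Made-up unit costs: 1.0 time unit of compute per sample, 2.0 per sync.
        print(bs, round(comm_fraction(bs, compute_per_sample=1.0, comm_per_step=2.0), 3))
```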

Launching a Ray cluster on AWS, GCP, or Kubernetes

Ray ships with built-in support for launching AWS and GCP clusters and also has community-maintained integrations for Azure and Aliyun.

To keep this blog post short, we refer you to the Ray Cluster Launcher documentation for details.

If you want to run on Kubernetes, you can check out the KubeRay documentation as well.

And if you’re interested in a managed offering for Ray in general, feel free to sign up for Anyscale.

Putting it all together

Our example comes with a few scripts that can be easily run from the command line.
You can always find the latest version of this code here!

# Get the Ray repo for the example code
git clone https://github.com/ray-project/ray.git
cd ray/doc/source/templates/05_dreambooth_finetuning/dreambooth
pip install -Ur requirements.txt

# Set some environment variables
export DATA_PREFIX="./"
export ORIG_MODEL_NAME="CompVis/stable-diffusion-v1-4"
export ORIG_MODEL_HASH="249dd2d739844dea6a0bc7fc27b3c1d014720b28"
export ORIG_MODEL_DIR="$DATA_PREFIX/model-orig"
export ORIG_MODEL_PATH="$ORIG_MODEL_DIR/models--${ORIG_MODEL_NAME/\//--}/snapshots/$ORIG_MODEL_HASH"
export TUNED_MODEL_DIR="$DATA_PREFIX/model-tuned"
export IMAGES_REG_DIR="$DATA_PREFIX/images-reg"
export IMAGES_OWN_DIR="$DATA_PREFIX/images-own"
export IMAGES_NEW_DIR="$DATA_PREFIX/images-new"

export CLASS_NAME="cat"

mkdir -p $ORIG_MODEL_DIR $TUNED_MODEL_DIR $IMAGES_REG_DIR $IMAGES_OWN_DIR $IMAGES_NEW_DIR

# AT THIS POINT YOU SHOULD COPY YOUR OWN IMAGES INTO
# $IMAGES_OWN_DIR

# Download pre-trained model
python cache_model.py --model_dir=$ORIG_MODEL_DIR --model_name=$ORIG_MODEL_NAME --revision=$ORIG_MODEL_HASH

# Generate regularization images
python run_model.py \
  --model_dir=$ORIG_MODEL_PATH \
  --output_dir=$IMAGES_REG_DIR \
  --prompts="photo of a $CLASS_NAME" \
  --num_samples_per_prompt=200

# Train our model
python train.py \
  --model_dir=$ORIG_MODEL_PATH \
  --output_dir=$TUNED_MODEL_DIR \
  --instance_images_dir=$IMAGES_OWN_DIR \
  --instance_prompt="a photo of unqtkn $CLASS_NAME" \
  --class_images_dir=$IMAGES_REG_DIR \
  --class_prompt="a photo of a $CLASS_NAME"

# Generate our images
python run_model.py \
  --model_dir=$TUNED_MODEL_DIR \
  --output_dir=$IMAGES_NEW_DIR \
  --prompts="photo of a unqtkn $CLASS_NAME" \
  --num_samples_per_prompt=20
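
The ORIG_MODEL_PATH line relies on bash pattern substitution: ${ORIG_MODEL_NAME/\//--} replaces the first "/" in the model name with "--", producing the models--<org>--<name>/snapshots/<revision> folder layout of the Hugging Face Hub cache, which cache_model.py presumably downloads into. The equivalent computation in Python, using a hypothetical helper name:

```python
import posixpath


def hf_snapshot_path(model_dir: str, model_name: str, revision: str) -> str:
    r"""Build <model_dir>/models--<org>--<name>/snapshots/<revision>.

    Mirrors the bash expression ${ORIG_MODEL_NAME/\//--}, which replaces only the
    first "/" in the model name with "--".
    """
    repo_folder = "models--" + model_name.replace("/", "--", 1)
    return posixpath.join(model_dir, repo_folder, "snapshots", revision)


if __name__ == "__main__":
    print(
        hf_snapshot_path(
            "./model-orig",
            "CompVis/stable-diffusion-v1-4",
            "249dd2d739844dea6a0bc7fc27b3c1d014720b28",
        )
    )
```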

And finally, here is what my cat on the moon looks like:

Figure 5. My cat on the moon. Not just any cat :)

Conclusion

With Ray AIR, fine-tuning a stable diffusion model is super simple. And super scalable: You can just add more machines to your cluster, and Ray can automatically use them. No code changes needed! Just tell Ray how many workers and GPUs you want to use.

Also, with Ray AIR, you worry less about cluster environment or management. Instead, you focus on writing your distributed training code with AIR's expressive and composable APIs.

Finally, you can leverage all the power from the cloud - and put your cat on the moon!

Next Steps

For further exploration of Ray and Ray AIR: