Offline Batch Inference: Comparing Ray, Apache Spark, and SageMaker

By Amog Kamsetty, Eric Liang and Jules S. Damji   

TL;DR

As more companies use large-scale machine learning (ML) models for training and evaluation, offline batch inference becomes an essential workload. It brings several challenges: managing compute infrastructure, making full use of heterogeneous resources, and transferring data from storage to hardware accelerators. Ray performs significantly better on these challenges because it can coordinate clusters of diverse resources, which lets each part of a workload be matched to the specific resources it requires.

In this blog post, we conduct a comparison of three different solutions for offline batch inference: AWS SageMaker Batch Transform, Apache Spark, and Ray Data. Our experiments demonstrate that Ray Data achieves speeds up to 17x faster than SageMaker Batch Transform and 2x faster than Spark for offline image classification. Ray Data also scales effectively to terabyte-sized datasets. The code for these benchmarks is publicly available and can be found here.

Introduction

Figure 1. General parallel batch processing architecture

Offline batch inference is a critical workload for many AI products, especially with the growing usage of pre-trained foundation models. At its core, it seems like a simple problem: given a trained model and a dataset, get model predictions for each data point.

However, there are many challenges to doing this at scale:

  1. Challenge 1: Managing compute infrastructure and cloud clusters, especially when needing to use heterogeneous clusters, consisting of different instance types to maximize throughput.

  2. Challenge 2: Parallelizing data processing and utilizing all resources (CPUs & GPUs) in the cluster at any given point in time.

  3. Challenge 3: Efficiently transferring data between cloud storage, CPU RAM for preprocessing, and GPU RAM for model inference, especially for unstructured data like images and text. This has ramifications for both performance and developer velocity.

  4. Challenge 4: A user experience that makes it easy to iterate and develop while working at scale.

What does the industry recommend for offline batch inference?

The industry hasn’t converged on a standard solution. It seems like every ML practitioner has their own favorite approach.


We see that there are three categories of solutions that attempt to address the above challenges. However, only category three addresses all the challenges.

  1. Batch Services: Cloud providers such as AWS, GCP, and Azure provide batch services to manage compute infrastructure for you. Some newer products like Modal Labs provide a user experience layer around cloud providers to abstract away even more complexity. Regardless of which service you choose, the process is the same: you provide your code, and the service runs your code on each node in a cluster. However, while infrastructure management is necessary (Challenge #1), it is not enough. These services have limitations, such as a lack of software libraries to address Challenges #2, #3, and #4 which makes them suitable only for experienced users who can write their own optimized batch inference code.

  2. Online Inference solutions: Solutions like Bento ML or Ray Serve provide APIs that make it easy to write performant inference code and can abstract away infrastructure complexities. But they are designed for online inference rather than offline batch inference, and these are two different problems with different sets of requirements. These solutions often don’t perform well in the offline case, which has led inference service providers like Bento ML to integrate with Apache Spark for offline inference.

  3. Distributed data systems: These solutions are designed to handle large amounts of data and are better suited to handle the challenges listed above. They break large datasets into reasonable batches for processing and can map a given function (i.e., the model) over the dataset efficiently, effectively handling the challenges of offline batch inference at scale. Examples include Apache Spark and Ray Data.
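The "split into batches, then map a function over each batch" pattern described in the third category can be sketched in plain Python. This is a single-process illustration only, not the API of Spark or Ray Data; the names `batched`, `map_in_batches`, and `toy_model` are hypothetical and chosen for this example.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
R = TypeVar("R")


def batched(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive lists of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(items)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


def map_in_batches(
    items: Iterable[T],
    fn: Callable[[List[T]], List[R]],
    batch_size: int,
) -> Iterator[R]:
    """Apply fn to one batch at a time and stream the results back in order."""
    for batch in batched(items, batch_size):
        yield from fn(batch)


def toy_model(batch: List[int]) -> List[int]:
    # Hypothetical stand-in for a model forward pass over a whole batch.
    return [x * 2 for x in batch]


predictions = list(map_in_batches(range(10), toy_model, batch_size=4))
```

A distributed data system applies the same idea, but runs the batches in parallel across many workers and handles data movement between them.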

Ultimately, we have observed that Ray Data is the best practical solution for offline batch inference for modern deep learning applications, both in terms of performance and user experience.

We ran the following comparisons and experiments to corroborate the above observations:

  1. Comparing AWS SageMaker Batch Transform, Apache Spark, and Ray Data performance and UX on an image classification use case

  2. Scaling up Ray Data batch inference to 10 TB, showing 90%+ GPU utilization

Image Classification: SageMaker Batch Transform vs. Apache Spark vs. Ray Data

We ran an image classification task from the MLPerf Inference Benchmark suite in the offline setting. This benchmark uses images from the ImageNet 2012 Dataset and a pre-trained PyTorch ResNet50 model.

We timed a full end-to-end batch inference workload. With images stored in the parquet format in a public S3 bucket, the workload involved the following three steps:

  1. Reading images from S3

  2. Simple CPU preprocessing (resizing, cropping, normalization)

  3. Model inference on GPU

While all Ray experiments were run on Anyscale, Spark experiments were run on Databricks runtime v12.0, with Apache Spark 3.3.1. More details on the full configurations can be found in the repo.

In particular for Spark, we used two different setups:

  • Single-cluster: Both the CPU preprocessing and GPU inference were done on the same cluster. We tried both versions of the PySpark Pandas UDF API: Series-based and Iterator-based.

  • Multi-cluster: CPU preprocessing is done on one cluster, the preprocessed data is saved to a shared file system, and GPU inference is done on a different cluster.

All the source code for the experiments conducted above is available here, including additional details on various other configurations that we attempted.

Results of throughput from experiments

When running with a 10 GB dataset on a single g4dn.12xlarge instance (48 CPUs, 4 GPUs), we observe the following throughput (images/sec) results for ResNet-50 batch inference:

figure_5_table

The above throughput differences translate directly into cost savings: the faster a batch inference job finishes, the fewer resource-hours it consumes and the cheaper it gets. Additionally, Ray is fully open source.
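The link between throughput and cost can be made concrete with a small calculation. This sketch assumes the cluster is billed by wall-clock time at a fixed hourly price. The image count, throughputs, and price below are hypothetical placeholders, not the benchmark's measurements.

```python
def job_cost_usd(num_images: int, images_per_sec: float, usd_per_hour: float) -> float:
    """Cost of a job billed by wall-clock time on a fixed-price cluster."""
    if images_per_sec <= 0:
        raise ValueError("images_per_sec must be positive")
    hours = num_images / images_per_sec / 3600.0
    return hours * usd_per_hour


# All values below are hypothetical and only illustrate the relationship.
NUM_IMAGES = 1_000_000
USD_PER_HOUR = 4.0

slow_cost = job_cost_usd(NUM_IMAGES, images_per_sec=100.0, usd_per_hour=USD_PER_HOUR)
fast_cost = job_cost_usd(NUM_IMAGES, images_per_sec=200.0, usd_per_hour=USD_PER_HOUR)

print(f"slow: ${slow_cost:.2f}, fast: ${fast_cost:.2f}")
```

On identical hardware, doubling throughput halves the bill, so a throughput ratio can be read directly as a cost ratio.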

Figure 1. Throughput difference among three different solutions for 10GB batch inference

SageMaker Batch Transform for batch inference

As we can see, SageMaker Batch Transform is not well suited for batch inference, both for throughput and cost.

While SageMaker handles Challenge #1 to some degree and abstracts away compute infrastructure management, it doesn’t fully address the remaining challenges.

The low-throughput results are primarily because while SageMaker Batch Transform claims to be a solution for batch inference on large datasets, under the hood, it uses the same architecture as for online serving systems. It starts an HTTP server, and deploys your model as an endpoint. And then for each image, a separate request is sent to the server.

SageMaker also does not provide support for batching across multiple files, and even if you batched your data beforehand, the max payload per request is 100 MB, nowhere near enough to fully saturate your GPUs.
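To see why a request-per-image design hurts throughput, consider a deliberately simple cost model: every request pays a fixed overhead (HTTP handling, serialization) and every image pays a fixed compute time. Both constants below are hypothetical, and the model is an assumption of this sketch rather than a measurement of SageMaker.

```python
import math


def total_seconds(
    num_images: int,
    images_per_request: int,
    overhead_per_request_s: float,
    compute_per_image_s: float,
) -> float:
    """Total time under a fixed per-request overhead plus per-image compute."""
    num_requests = math.ceil(num_images / images_per_request)
    return num_requests * overhead_per_request_s + num_images * compute_per_image_s


# Hypothetical constants: 20 ms of overhead per request, 2 ms of compute per image.
one_per_request = total_seconds(10_000, 1, 0.020, 0.002)
batched_256 = total_seconds(10_000, 256, 0.020, 0.002)

print(f"1 image/request: {one_per_request:.1f}s, 256 images/request: {batched_256:.1f}s")
```

Under these assumptions, the request overhead dominates the unbatched case. The model is conservative, since it ignores that GPUs also process each image faster when given larger batches.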

In terms of developer velocity, we ran into a few additional challenges with SageMaker when running these benchmarks. In particular, the need to spin up a new cluster every time we ran a script and the difficult-to-parse stack traces slowed down developer iterations.

While online inference solutions have their place, their additional overhead and their inability to fully utilize cluster resources and saturate GPUs (by maximizing batch sizes and effectively pipelining data) make them highly inefficient for offline batch inference.

Comparing Spark and Ray for batch inference

Clearly, as we observed above, shoehorning online inference solutions into offline use cases is not the right approach. Let’s compare the two distributed data system approaches: Spark and Ray Data.

We see that Ray Data has 2x faster throughput than Spark. Let’s break down how Ray Data and Spark fare for the challenges mentioned earlier:

Challenge #2: Hybrid CPU and GPU workload

Deep learning workloads involve both CPU and GPU computation. For optimal parallelization, all the CPUs and GPUs in the cluster should be fully utilized at any given time. However, Spark’s scheduling makes this difficult because all the stages are fused together, regardless of their resource requirements. This means that in the single-cluster setup, the total parallelism is limited by the number of GPUs, leading to underutilization of CPUs. It also means that GPUs are completely idle during the reading and preprocessing steps.

When using a separate CPU and GPU cluster, data needs to be persisted and orchestrated between stages. Overall, for these CPU+GPU workloads, Spark does not fully address challenge #2 and cannot utilize all resources in the cluster concurrently. We attempted to use Spark’s stage level scheduling feature to address this issue, but were unsuccessful.

Figure 2. How Ray Data efficiently pipelines and optimizes hardware: CPUs/GPUs

Ray Data, on the other hand, can independently scale the CPU preprocessing and GPU inference steps. It also streams data through these stages, which increases CPU and GPU utilization and reduces out-of-memory errors (OOMs), and it supports fine-grained resource allocation for autoscaling and management.
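The idea of independently sized stages connected by a stream can be illustrated with the standard library. This is a conceptual sketch only, not how Ray Data is implemented: threads stand in for CPU workers, the main loop stands in for a GPU consumer, and a bounded queue provides backpressure so that memory use stays capped. The functions `preprocess`, `infer`, and `run_pipeline` are hypothetical names for this example.

```python
import queue
import threading
from typing import List, Tuple

_DONE = object()  # sentinel that marks the end of one worker's output


def preprocess(x: int) -> int:
    # Hypothetical stand-in for CPU work such as resizing and normalizing.
    return x + 1


def infer(batch: List[int]) -> List[int]:
    # Hypothetical stand-in for a batched GPU forward pass.
    return [x * 10 for x in batch]


def run_pipeline(
    inputs: List[int],
    num_cpu_workers: int = 4,
    batch_size: int = 8,
    max_buffered: int = 32,
) -> List[int]:
    todo: "queue.Queue" = queue.Queue()
    # Bounded buffer: CPU workers block when inference falls behind.
    ready: "queue.Queue" = queue.Queue(maxsize=max_buffered)
    for i, x in enumerate(inputs):
        todo.put((i, x))

    def cpu_worker() -> None:
        while True:
            try:
                i, x = todo.get_nowait()
            except queue.Empty:
                break
            ready.put((i, preprocess(x)))
        ready.put(_DONE)

    workers = [threading.Thread(target=cpu_worker) for _ in range(num_cpu_workers)]
    for w in workers:
        w.start()

    results: List[int] = [0] * len(inputs)
    batch: List[Tuple[int, int]] = []

    def flush() -> None:
        outputs = infer([x for _, x in batch])
        for (i, _), y in zip(batch, outputs):
            results[i] = y
        batch.clear()

    # The inference stage consumes items while preprocessing is still running.
    finished = 0
    while finished < num_cpu_workers:
        item = ready.get()
        if item is _DONE:
            finished += 1
            continue
        batch.append(item)
        if len(batch) == batch_size:
            flush()
    if batch:
        flush()
    for w in workers:
        w.join()
    return results


outputs = run_pipeline(list(range(20)))
```

Here `num_cpu_workers` and `batch_size` are tuned separately, which mirrors the point above: the preprocessing and inference stages have different resource needs and should scale independently. Python threads do not parallelize CPU-bound work, so this shows only the scheduling pattern, not a performance gain.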

Challenge #3: Large, multi-dimensional tensors

Deep learning workloads often involve dealing with complex data types like images, audio, or video, which are represented by multi-dimensional arrays. Having first-class support for these data types is necessary for better performance and developer experience. Ray Data treats the NumPy batch format as a first-class citizen and does not require Pandas. This format is well suited to deep learning and leads to higher performance and better memory efficiency, as shown in previous benchmarks.

Challenge #4: Ray is Python native

Another fundamental difference is that Ray is Python native whereas Spark is not, which also impacts performance. We ran a microbenchmark that applies a time.sleep(1) UDF. Surprisingly, running the UDF on Spark takes 80 seconds, which we believe is due to JVM<>PyArrow overhead. On Ray Data, the same microbenchmark takes a little under 5 seconds.
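A sleep-based UDF is useful for this kind of measurement because its ideal runtime is known in advance, so anything beyond it is framework overhead. The sketch below shows the measurement idea with a plain thread pool. The row count, parallelism, and sleep duration are assumptions made for illustration; the post does not state the microbenchmark's configuration, and this is not the benchmark code itself.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple


def sleep_udf(row: int, seconds: float) -> int:
    # The "work" is a fixed sleep, so the ideal runtime is predictable.
    time.sleep(seconds)
    return row


def time_udf(num_rows: int, parallelism: int, seconds: float) -> Tuple[float, float]:
    """Return (measured_seconds, ideal_seconds) for applying sleep_udf to every row."""
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=parallelism) as pool:
        out = list(pool.map(lambda r: sleep_udf(r, seconds), range(num_rows)))
    elapsed = time.perf_counter() - start
    assert out == list(range(num_rows))
    # Rows run in ceil(num_rows / parallelism) waves of `seconds` each.
    ideal = -(-num_rows // parallelism) * seconds
    return elapsed, ideal


# Hypothetical configuration, scaled down so that it runs quickly.
elapsed, ideal = time_udf(num_rows=8, parallelism=4, seconds=0.05)
print(f"measured {elapsed:.3f}s vs ideal {ideal:.3f}s, overhead {elapsed - ideal:.3f}s")
```

Running the same UDF through a framework and subtracting the ideal time gives a rough estimate of the per-job overhead that the framework adds.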

Not to mention, this makes Spark development more difficult. In particular when debugging, the recommended way is to include debug information in the output pandas dataframe rather than using print statements.

Even as we scale to a larger 300 GB dataset using 4 nodes and 16 GPUs, we see that these performance differences still hold:

Figure 3. Throughput difference between Spark and Ray solution for 300 GB batch inference

While big data processing systems are the best solutions for offline inference, for deep learning workloads, Ray Data outperforms Spark.

Scaling up Ray Data and maximizing performance

Finally, we test the scalability of Ray Data and attempt to maximize performance by using a heterogeneous cluster consisting of a mix of GPU nodes and CPU-only nodes. Deep learning workloads involving complex data types are often memory-bound, so by using additional CPU-only nodes, we add more total memory to our cluster to better maximize GPU utilization. 

Unlike SageMaker or Spark, Ray works natively with heterogeneous clusters consisting of different instance types, fully addressing challenge #1.

We run the same benchmark, except with 10 TB worth of images, using a single cluster of 10 g4dn.12xlarge instances and 10 m5.16xlarge instances.

Throughput: 11,580.958 img/sec

As we can see below, Ray Data achieves over 90% GPU utilization throughout the workload run:

Figure 4. GPU utilization of all 40 GPUs in the cluster. We are roughly at 90%+ GPU utilization during the entire duration of the workload.


The ability to scale to many GPUs and large data sizes without rewriting code or sacrificing throughput is what sets Ray Data apart from other solutions for deep learning workloads.

Conclusion

In this blog, we identified four technical challenges for doing batch inference at scale and investigated three commonly cited industry solutions for offline batch inference. Using the cited approaches, we compared Ray, Spark, and Amazon SageMaker Batch Transform.

We concluded that Ray Data best meets all four challenges:

  1. Ray Data abstracts away compute infrastructure management, and can maximize performance for memory-bound workloads as it natively supports heterogeneous clusters.

  2. Ray Data streams data from cloud storage -> CPU -> GPU, ensuring that all resources in the cluster are utilized at any given point in time. This improves throughput and reduces overall cost.

  3. Ray Data has native support for multi-dimensional tensors, which is vital for deep learning workloads.

  4. Ray Data is Python native and programmatic, making it easy to develop and debug.

Overall, Ray Data is the best option for deep learning offline batch inference. In our image classification benchmarks, as shown in the figures above, Ray Data significantly outperforms SageMaker Batch Transform (by 17x) and Spark (by 2x and 3x) while linearly scaling to TB level data sizes. All code for these benchmarks can be found here.

Lastly, because of its higher throughput and because it is open source, Ray lowers the overall cost of a job beyond the base EC2 cost.

What’s Next

Stay tuned for a follow up post where we dive into Ray Data streaming's design and how to use it for your own ML pipelines. Meanwhile, you can get started with Ray Data for batch inference here.

Finally, we have our Ray Summit 2023 early-bird registration open. Secure your spot, and save some money.

Update 5/10: Thanks to the Databricks Apache Spark developers for pointing out the prefetching flag spark.databricks.execution.pandasUDF.prefetch.maxBatches available in Databricks Runtime that can be used with the Iterator API. It is also possible to implement this manually in Spark open source with background threads. With prefetching set to 4 batches, Spark reaches 159.86 images/sec. The code can be found here.
