Serving PyTorch models with FastAPI and Ray Serve

By Simon Mo and Chandler Gibbons   

Deploying a machine learning model can be challenging, depending on where the model will be deployed and which tools are used to serve it in production. Every deployment platform has its own pros and cons, and so does every serving tool. So it is crucial to choose the right combination of serving tools and platform for your model.

In this article, we will highlight the options available for serving a PyTorch model in production, including TorchServe, cloud-hosted platforms, and web frameworks such as Flask and FastAPI. We will also examine how combining FastAPI with Ray Serve can help us scale a PyTorch model-serving API across a Ray cluster.

PyTorch model serving frameworks

The different frameworks available to serve PyTorch models can be divided into three categories:

  • Customized tools such as TorchServe

  • Cloud-hosted solutions such as Amazon SageMaker

  • Web-based frameworks such as Flask, FastAPI, and Ray Serve

TorchServe, developed jointly by AWS and Facebook as part of the PyTorch project, is a flexible and easy-to-use tool for serving PyTorch and TorchScript models. It comes with a convenient CLI, so it can be run locally, and it is easy to scale out using Amazon SageMaker or Amazon EKS.

However, TorchServe has several drawbacks. At the time of writing it was still experimental, so changes, patches, and updates were frequent. It also works only with PyTorch and TorchScript models, which means it is not framework-agnostic, and it depends on Java, since its frontend runs on the JVM. Moreover, customized tools like this are typically hard to develop, deploy, and manage.

Other options for serving machine learning models, and PyTorch models in particular, are cloud-hosted platforms such as Amazon SageMaker, Kubeflow, Google Cloud AI Platform, and Microsoft’s Azure ML SDK. These are powerful serving tools backed by some of the largest tech companies, but they can be very expensive to use. In addition, these tools generally work only within their own ecosystems.

Web-based serving tools such as Flask can address some of these problems. Flask is a web framework that is efficient, easy to set up, and framework-agnostic. However, as with other web-based serving tools, scaling it out can be challenging.

What is Ray Serve?

Ray Serve is a library for serving machine learning models that is built on top of Ray, the distributed computing framework. It provides a simple web server and handles the complex routing, scaling, and testing logic necessary for production deployments. It is also framework-agnostic and Python-first, so models can be configured and served declaratively in pure Python, without YAML or JSON configuration files.

Ray Serve can also be an efficient tool for deploying PyTorch models because it is easy to scale, whether in your data center or in the cloud.

FastAPI and Ray Serve for serving PyTorch models

Ray Serve addresses scalability, management, and the other issues discussed above by providing end-to-end control over the request lifecycle while allowing each model to scale independently. Combining Ray Serve with FastAPI improves this process further: you build your API with FastAPI and then scale it with Ray Serve, from a single CPU up to clusters with 100+ CPUs. This can substantially increase the number of requests served per second.

The Ray Serve docs show how to serve a PyTorch model directly. This approach is excellent if you only want to serve a single model. However, we may want to build out a more comprehensive API, and it's quick and easy to do so with FastAPI. 

Fortunately, Ray Serve has built-in FastAPI support, so we can adapt the code from the Ray Serve docs to work in a FastAPI application. Here's everything needed to build a FastAPI app that serves a ResNet-18 model pretrained on ImageNet:

```python
import ray
from ray import serve
from fastapi import FastAPI, UploadFile, File

import torch
from torchvision import transforms
from torchvision.models import resnet18
from PIL import Image

from io import BytesIO

app = FastAPI()
ray.init(address="auto")    # connect to an already-running Ray cluster
serve.start(detached=True)  # keep Serve running after this script exits


@serve.deployment
@serve.ingress(app)
class ModelServer:
    def __init__(self):
        # On torchvision >= 0.13, pretrained=True is deprecated in favor of
        # resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).
        self.model = resnet18(pretrained=True).eval()
        self.preprocessor = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def classify(self, image_payload_bytes):
        # convert("RGB") drops any alpha channel and expands grayscale or
        # palette images to the three channels the model expects
        pil_image = Image.open(BytesIO(image_payload_bytes)).convert("RGB")

        pil_images = [pil_image]  # batch size is one
        input_tensor = torch.cat(
            [self.preprocessor(i).unsqueeze(0) for i in pil_images])

        with torch.no_grad():
            output_tensor = self.model(input_tensor)
        return {"class_index": int(torch.argmax(output_tensor[0]))}

    @app.get("/")
    def get(self):
        return "Welcome to the PyTorch model server."

    @app.post("/classify_image")
    async def classify_image(self, file: UploadFile = File(...)):
        image_bytes = await file.read()
        return self.classify(image_bytes)


ModelServer.deploy()
```

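We start by creating a FastAPI application, connecting to a Ray cluster, and starting Ray Serve. Calling ray.init(address="auto") connects to a Ray cluster that is already running on this machine, and serve.start(detached=True) keeps Ray Serve running on the cluster after our script exits. If you want to connect to a Ray cluster running on remote machines, you'll need to pass the address of one of the cluster nodes to ray.init() instead.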

Then, we create a class-based FastAPI application that loads the model in the __init__ method and classifies images in the classify method.

When an API user sends a POST request with an image to the asynchronous classify_image endpoint, FastAPI parses the multipart form upload and passes it in as the file parameter. We read the raw bytes of the image from it and send them on to the classify method.

We could run this FastAPI app on its own, without Ray Serve, but then we'd be on the hook for scaling it out manually. Ray Serve makes it easy to turn our FastAPI app into a distributed application that can scale across a cluster. All we had to do was add two decorators to our FastAPI class, @serve.deployment and @serve.ingress(app), and add a call to ModelServer.deploy() at the end of our code. That's it: our FastAPI application is ready to run in a Ray cluster.

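Once the app is up and running, we can test it by using Postman to send a request to http://127.0.0.1:8000/ModelServer/classify_image. By default, Ray Serve exposes a deployment under a route prefix matching its name, which is where /ModelServer comes from.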


In this case, we sent it an image of a cat, and it returned: 

```json
{
  "class_index": 285
}
```

Now, this isn't very descriptive. We could add a human-readable label from the list of ImageNet classes. Whether we should depends on who we expect to consume our API. It may well be more useful to return the ImageNet class index and let the API consumer decide what to do with it. For example, the consumer may already have a list of class descriptions custom-translated into 50 different languages, in which case an English label from our API wouldn't be very helpful.

In this case, class index 285 maps to “Egyptian cat,” which is a decent enough description of the image we uploaded.
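One way to keep both options open is to return the bare index by default and attach a label only when a mapping is supplied. The sketch below is a hypothetical helper, not part of the app above; its label table is a one-entry excerpt containing only the 285 to "Egyptian cat" pair mentioned here. On torchvision 0.13 or later, a full English list is available as ResNet18_Weights.IMAGENET1K_V1.meta["categories"], and a consumer could pass its own translated mapping instead.

```python
from typing import Mapping, Optional


def build_response(class_index: int,
                   labels: Optional[Mapping[int, str]] = None) -> dict:
    """Return the class index, plus a label only if a mapping is supplied and covers it."""
    response = {"class_index": class_index}
    if labels is not None and class_index in labels:
        response["label"] = labels[class_index]
    return response


# One-entry excerpt for illustration; not a complete ImageNet label table.
IMAGENET_LABELS_EXCERPT = {285: "Egyptian cat"}

print(build_response(285))                           # {'class_index': 285}
print(build_response(285, IMAGENET_LABELS_EXCERPT))  # {'class_index': 285, 'label': 'Egyptian cat'}
```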

If you'd like to try the app for yourself, you can find a copy of it in this GitHub repository.
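If you prefer scripting to Postman, the same request can be sent from Python. This is a sketch assuming the requests library and the local URL used above. The form field must be named file so that it matches the file parameter of classify_image. Building the request separately from sending it lets you inspect the multipart body without a running server.

```python
import os

import requests

# Local address from the Postman example; adjust host and port for your deployment.
CLASSIFY_URL = "http://127.0.0.1:8000/ModelServer/classify_image"


def build_classify_request(image_bytes: bytes, filename: str = "cat.jpg",
                           url: str = CLASSIFY_URL) -> requests.PreparedRequest:
    """Build the multipart POST that Postman sends: one form field named "file"."""
    files = {"file": (filename, image_bytes, "image/jpeg")}
    return requests.Request("POST", url, files=files).prepare()


def classify_image_file(path: str, url: str = CLASSIFY_URL) -> dict:
    """Send an image file to the running server and return the decoded JSON."""
    with open(path, "rb") as f:
        request = build_classify_request(f.read(), os.path.basename(path), url)
    with requests.Session() as session:
        response = session.send(request, timeout=30)
    response.raise_for_status()
    return response.json()
```

With the server running, classify_image_file("cat.jpg") (a hypothetical local file) should return a dictionary shaped like the JSON response shown earlier.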

Testing

To measure serving performance, we use our laptop as the Ray head node. We limit the server to two cores by setting the number of replicas to 2, which with the API used above means changing the decorator to @serve.deployment(num_replicas=2), and then run the server exactly as before.

Now, if we saturate the back end by sending many requests, we see the number of queries per second increase from 0.67 to 0.81. To increase that number even more, we can use more cores. For argument’s sake, let’s add 4. This raises throughput to 0.84 queries per second, so the model can answer more requests in the same period of time.
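Putting the reported throughput figures side by side makes the size of each gain explicit. This small calculation uses only the three numbers given above: the first step adds roughly 21% throughput, the second only about 4%, which points to diminishing returns on a single laptop.

```python
def relative_gain(before_qps: float, after_qps: float) -> float:
    """Throughput ratio: 1.0 means no change, 1.5 means 50% more queries per second."""
    return after_qps / before_qps


measurements = [0.67, 0.81, 0.84]  # queries per second, as reported above
for before, after in zip(measurements, measurements[1:]):
    print(f"{before} -> {after} qps: x{relative_gain(before, after):.2f}")
print(f"overall: x{relative_gain(measurements[0], measurements[-1]):.2f}")
# Output:
# 0.67 -> 0.81 qps: x1.21
# 0.81 -> 0.84 qps: x1.04
# overall: x1.25
```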

Testing on a cluster

For parallel computation at larger scale, we can test our model on a Ray cluster (a group of nodes) rather than a single node, and take advantage of the request batching features provided by Ray Serve. One way to create a cloud-based Ray cluster is with Anyscale.

This increases the number of queries per second, and Ray Serve takes care of batching requests and scaling our models across the cluster. Adding more cores and replicas can increase throughput even further.
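Ray Serve exposes batching through the @serve.batch decorator, whose parameters include max_batch_size and batch_wait_timeout_s. The stdlib-only sketch below is not Ray's implementation; it only illustrates the idea under our own assumptions. Requests that arrive within a short window are grouped, passed to the model as one list, and each caller gets back its own result. The MicroBatcher class and the doubling "model" are hypothetical; a real model would stack the inputs with torch.cat and run a single forward pass.

```python
import asyncio
from typing import Any, Callable, List, Optional, Tuple


class MicroBatcher:
    """Hypothetical helper (not Ray's code): groups concurrent calls into one batched call."""

    def __init__(self, batch_fn: Callable[[List[Any]], List[Any]],
                 max_batch_size: int = 8, batch_wait_timeout_s: float = 0.01):
        self._batch_fn = batch_fn
        self._max_batch_size = max_batch_size
        self._timeout = batch_wait_timeout_s
        self._queue: Optional[asyncio.Queue] = None
        self._worker: Optional[asyncio.Task] = None
        self.batch_sizes: List[int] = []  # recorded so the grouping is observable

    async def submit(self, item: Any) -> Any:
        loop = asyncio.get_running_loop()
        if self._worker is None:  # created lazily inside the running event loop
            self._queue = asyncio.Queue()
            self._worker = loop.create_task(self._run())
        future = loop.create_future()
        self._queue.put_nowait((item, future))
        return await future

    async def _run(self) -> None:
        while True:
            batch = [await self._queue.get()]   # wait for the first request
            await asyncio.sleep(self._timeout)  # let concurrent requests arrive
            while len(batch) < self._max_batch_size and not self._queue.empty():
                batch.append(self._queue.get_nowait())
            self.batch_sizes.append(len(batch))
            try:
                results = self._batch_fn([item for item, _ in batch])
            except Exception as exc:  # fail every caller in this batch, keep the worker alive
                for _, future in batch:
                    if not future.done():
                        future.set_exception(exc)
                continue
            for (_, future), result in zip(batch, results):
                if not future.done():
                    future.set_result(result)


async def demo() -> Tuple[List[int], List[int]]:
    # Stand-in "model" that doubles each input in the batch.
    batcher = MicroBatcher(lambda xs: [2 * x for x in xs], max_batch_size=8)
    results = await asyncio.gather(*(batcher.submit(i) for i in range(10)))
    return list(results), batcher.batch_sizes


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Ten concurrent requests with a maximum batch size of 8 are served as two model calls, one with 8 inputs and one with 2, instead of ten separate calls.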

Conclusion

Deploying and serving a scalable machine learning model in production can be challenging, and there are many tools to choose from depending on your needs. In this article, we’ve seen that combining FastAPI with Ray Serve is a strong option for serving a PyTorch model, with Ray Serve making it straightforward to scale that service up.

For more on Ray and PyTorch, check out our blog post on how Hutom.io uses Ray and PyTorch to scale surgical video analysis and review, or catch our upcoming Meetup, where we’ll cover Ray Train, PyTorch, TorchX, and distributed deep learning.

Interested in learning more about Ray Serve? Register for our upcoming Meetup, where we'll discuss productionizing ML at scale with Ray Serve.
