How Hutom.io uses Ray and PyTorch to Scale Surgical Video Analysis and Review

By Jihun Yoon   

Jihun Yoon is a machine learning research engineer at Hutom, where he develops deep learning models for a surgical data analysis platform that helps hospitals and surgeons provide better-optimized patient care.


Benefits of Minimally Invasive Surgeries

Over the last 30 years, the number of open surgeries has decreased significantly in favor of minimally invasive surgeries. These surgeries increasingly use robots to make procedure quality more consistent across medical specialists, leading to better patient outcomes and fewer complications. Robot-assisted surgeries are performed through small incisions and are visible to the surgeon only through a camera feed. As a result, more surgeries than ever are recorded as videos that document everything taking place inside the patient. To analyze and evaluate these surgical videos automatically, localization and motion estimation of the surgical tools are essential, which calls for computer vision. The problem is that existing tools fall short in this regard.

Hutom: Reshaping Surgery through Machine Learning and Computer Vision

[Figure: Reshaping surgery through digital innovation]
The core of the new paradigm of surgery is effective management of big data.

Hutom is leading the new paradigm of surgery with its big data platform, which helps hospitals and surgeons provide more optimized patient care. Its automated machine learning and computer vision system helps patients and doctors through:

  • Better personalized surgical planning (e.g., CT/MRI 2D images, virtual anatomy modeling)

  • Real-time surgeon assistance (e.g., AI-empowered imaging recognition, synchronized camera movement)

  • Surgery analysis and review for learning and archival purposes (e.g., surgical video data analysis)

This technology not only leads to improved surgical performance, but also provides accountability to patients and insurance companies by keeping a record of the surgical events that transpired over the course of a surgery. All of this labeling and archiving is done with machine learning and computer vision.

To train deep learning models on medical data, Hutom has had to overcome a number of technical challenges, including:

  • Medical data that is very difficult to acquire due to legal and administrative requirements

  • High annotation costs, since labels must come from highly paid medical experts

  • Small, non-diverse datasets in which the test distribution often differs from the training distribution, causing models to generalize poorly

To mitigate these challenges, we utilized synthetic data generation and implemented various domain randomization and semi-supervised methods. This provided performance gains, but we found that better gains often come from rapidly scaling up computing resources to optimize models via hyperparameter search. To do this, we utilize Ray and its ecosystem to scale and deploy our PyTorch models.

Hyperparameter search with Ray Tune and Ray Train

Most computer vision recognition algorithms have been developed on general datasets (e.g., COCO, ImageNet). COCO (Common Objects in Context), for example, doesn't have much in common with surgical video, so the best-performing hyperparameters from these models don't necessarily generalize well to surgical video datasets. For this reason, we need to develop better-performing models through hyperparameter search. The problem is that while extensive hyperparameter search leads to performance gains, implementing search algorithms and distributed training comes with technical challenges. To solve these, we use Ray Tune, a hyperparameter tuning library built on Ray. Some of the main features we use include:

  • State-of-the-art schedulers such as ASHA that stop low-performing trials early

  • Flexible search spaces, including conditional dependencies defined with tune.sample_from

  • Automatic logging of results to TensorBoard and MLflow

  • Tight integration with Ray Train for distributed training

The Ray ecosystem has also been invaluable in implementing the data ingest portion of training. To solve the distributed training challenge, we use Ray Train (formerly Ray SGD), a lightweight library for distributed training of deep learning models. What we like about Ray Train is that:

  • It only takes a couple of lines of code to scale single-process training code to a cluster (see the sketch after this list)

  • The integrations with PyTorch and Ray Tune make it easy to tune our distributed model
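
To make the later example more concrete, here is a minimal sketch of what a Ray Train training function might look like for our PyTorch detection model. The build_model, build_data_loader, and evaluate helpers are hypothetical placeholders and the loop is heavily simplified; it only illustrates how Ray Train wraps a standard PyTorch training loop.

import torch
from ray import train
from ray.train.torch import prepare_model, prepare_data_loader

def train_func(config):
    # hypothetical helpers that build the torchvision detection model and data loader
    model = prepare_model(build_model())  # wraps the model in DistributedDataParallel
    loader = prepare_data_loader(build_data_loader(config["batch_size"]))  # adds a distributed sampler

    optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"])
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=config["lr_steps"])

    for epoch in range(config["epochs"]):
        for images, targets in loader:
            # torchvision detection models return a dict of losses in training mode
            loss = sum(model(images, targets).values())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        lr_scheduler.step()
        # report metrics so Ray Tune's scheduler can compare trials
        train.report(val_avg_mAP=evaluate(model))  # evaluate() is a hypothetical validation helper

With this structure, the same train_func runs unchanged on a single GPU or across a cluster; the Trainer in the snippet below decides how many workers execute it.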

Training an Instance Segmentation Model on Surgical Videos with Ray and PyTorch

[Figure: Bounding-box average precision]
Using Ray Tune led to a performance improvement and reduced training time by 51%.

This section describes a simplified example of how we trained a Mask R-CNN ResNet-50 FPN model from the official torchvision models on gastrectomy surgical videos obtained from our Vi-hub product, an AI-empowered video hub which records video during surgery. Using 10 trials, 2 workers per trial, and a single NVIDIA DGX A100, we started with the baseline configuration from the original paper and set up a search space over the number of epochs, learning rate, and learning rate decay steps. Training this baseline model on our system takes around 20 hours and does not yield the best performance. Ray Tune's implementation of the Asynchronous Successive Halving Algorithm (ASHA) is more performant and takes around 9 hours 50 minutes, a 51% reduction in training time. ASHA is a simple but powerful scheduler that aggressively terminates low-performing trials (dynamic resource allocation), as the image below shows.

[Figure: Ray Tune training loss]
Ray Tune automatically logs training results to TensorBoard. This graph shows training loss for 10 hyperparameter trials with ASHA; the longest trial is the best-performing model.

Perhaps most importantly, Ray Tune makes it easy to use both common search spaces and custom/conditional search spaces with any scheduler, including ASHA. This is useful when you want to encode conditional dependencies, such as choosing the learning rate range based on batch size or the step-decay milestones based on the number of epochs. Below is a code snippet of this in action:

import ray
from ray import train
from ray.train import Trainer
from ray import tune
from ray.tune.integration.mlflow import MLflowLoggerCallback
from ray.tune.schedulers import ASHAScheduler
from ray.tune import CLIReporter
# ... other imports here

# specify our hyperparameter ranges to search over
# we can encode conditional dependencies with "tune.sample_from" and Python lambdas
config = {
    "epochs":
        tune.choice([13, 20]),
    "lr":
        tune.sample_from(lambda spec: tune.uniform(lower=0.01, upper=0.02)
                         if spec.config.batch_size == 8 else
                         tune.uniform(lower=0.02, upper=0.03)
                         if spec.config.batch_size == 16 else 0.02),
    "lr_steps":
        tune.sample_from(lambda spec: [8, 11] if spec.config.epochs == 13 else
                         [12, 16] if spec.config.epochs == 20 else [8, 11]),
    # ... other search space entries and run settings here
}

# create our scheduler, specifying the metric to optimize for and other run settings
scheduler = ASHAScheduler(
    time_attr="training_iteration",
    metric="val_avg_mAP",
    mode="max",
    max_t=100,
    grace_period=1,
    reduction_factor=2)

# report trial progress on the command line
reporter = CLIReporter(metric_columns=["val_avg_mAP", "training_iteration"])

# use Ray Train to enable distributed training on GPUs
trainer = Trainer(backend="torch",
                  num_workers=config["ray_workers"],
                  use_gpu=config["use_gpu"])
trainable = trainer.to_tune_trainable(train_func)

# connect Ray Train to Tune and run the hyperparameter search
analysis = tune.run(trainable,
                    num_samples=config["num_samples"],
                    config=config,
                    scheduler=scheduler,
                    verbose=2,
                    progress_reporter=reporter,
                    callbacks=[
                        # use an MLflow callback to save results and metadata
                        MLflowLoggerCallback(
                            tracking_uri=config["tracking_uri"],
                            experiment_name=config["experiment_name"],
                            save_artifact=True)
                    ])

The code used in this example is available on GitHub.

When the code runs, it automatically logs training results to TensorBoard and the MLflow tracking server, which we find useful for managing our training results.

[Figure: Training results in the MLflow tracking server]
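
As a small, hedged addition (assuming the analysis object and the val_avg_mAP metric from the snippet above), the ExperimentAnalysis returned by tune.run can be queried for the winning trial, which is the model we then deploy:

# pick the best trial and its hyperparameters once the search finishes
best_config = analysis.get_best_config(metric="val_avg_mAP", mode="max")
best_trial = analysis.get_best_trial(metric="val_avg_mAP", mode="max")
print("Best hyperparameters:", best_config)
print("Best validation mAP:", best_trial.last_result["val_avg_mAP"])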

Scalable Model Deployment with Ray Serve

[Figure: Surg-Gram]
Surgical video analysis on Surg-Gram with online/batch serving.

After we get the best-performing model from the hyperparameter search, we need to deploy it to analyze surgical videos via localization of surgical instruments and organs. Surg-Gram is our intelligent surgical platform, and it utilizes Ray Serve to help scale our surgical video analysis, review, and deployment. The image above shows how Surg-Gram uses Ray Serve. A serve_client is allocated to each surgery analysis task. Each client sends images asynchronously to the backend server, and each model replica computes predictions. Based on the number of GPUs (or fraction of a GPU) assigned to each model and the number of replicas, Ray Serve places model replicas on our GPUs automatically. Ray Serve also provides a batching feature, which helps our model use vectorization to perform computation in parallel. We have found that Ray Serve makes it very easy to use these features and scale our models, as the pseudocode below shows.

from ray import serve
import asyncio
import httpx
import torchvision
import torchvision.transforms as T
# ... other imports here

config = {"num_gpus": 0.5, "num_replicas": 16, "max_concurrent_queries": 1000, "max_batch_size": 8}

# start the Serve instance before deploying
serve.start()

@serve.deployment(route_prefix="/detection",
                  ray_actor_options={"num_gpus": config["num_gpus"]},
                  num_replicas=config["num_replicas"],
                  max_concurrent_queries=config["max_concurrent_queries"])
class DetectionModelEndpoint:
    def __init__(self):
        self.model = torchvision.models.detection.__dict__[
            "maskrcnn_resnet50_fpn"](pretrained=True).eval().cuda()
        self.preprocessor = T.ToTensor()

    def __del__(self):
        # release GPU memory
        del self.model

    # batch incoming requests so the model can run vectorized inference
    @serve.batch(max_batch_size=config["max_batch_size"])
    async def __call__(self, starlette_requests):
        batch_size = len(starlette_requests)
        pil_images = []
        for request in starlette_requests:
            image_payload_bytes = await request.body()
            # ... decode the payload and append the image to pil_images
        # ... run batched inference with self.model and post-process the outputs
        return results

DetectionModelEndpoint.deploy()

async def inference_client(test_image_bytes):
    async with httpx.AsyncClient() as client:
        resp = await client.post("http://localhost:8000/detection", data=test_image_bytes)
        return resp

To measure the serving performance, we ran two experiments on an NVIDIA DGX A100 with different numbers of replicas and clients and different batch sizes. The first benchmark shows how long Mask R-CNN ResNet-50 FPN inference took as we scaled out the number of replicas and clients. Each client requested a prediction for one image, and it took 0.27 seconds to handle 32 clients' requests using 16 replicas, an 87% reduction in latency compared to handling the same number of requests with 1 replica.

[Figure: Scaling-out benchmark]
Using Ray Serve led to an 87% reduction in inference time by scaling out model replicas. The 50th percentile latency was estimated with 30 trials.
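
As an illustration only, a measurement like this can be scripted by fanning out the asynchronous client from the snippet above with asyncio.gather; estimate_p50_latency is a hypothetical helper, and the client and trial counts simply mirror the numbers reported above.

import asyncio
import statistics
import time

import httpx

async def estimate_p50_latency(test_image_bytes, num_clients=32, num_trials=30):
    # send num_clients concurrent requests per trial and report the median latency
    latencies = []
    async with httpx.AsyncClient() as client:
        for _ in range(num_trials):
            start = time.perf_counter()
            await asyncio.gather(*[
                client.post("http://localhost:8000/detection", data=test_image_bytes)
                for _ in range(num_clients)
            ])
            latencies.append(time.perf_counter() - start)
    return statistics.median(latencies)

# e.g. asyncio.run(estimate_p50_latency(image_bytes)) once the deployment is running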

The second benchmark shows how effective increasing the batch size is: there is a 29% reduction in latency with a max_batch_size of 8 requests per batch instead of 1. Scaling out replicas and increasing the batch size are both very effective for decreasing latency. However, one should not blindly increase these numbers. If there are more replicas than clients, latency can actually increase, and a bigger batch size does not guarantee a performance improvement either.

[Figure: Batching benchmark]
Using Ray Serve led to a 29% reduction in inference time by increasing the batch size. The 50th percentile latency was estimated with 30 trials.

Conclusion

In this post, we showed how Ray is helping us scale training, tuning, and serving of our instance segmentation model for surgical video analysis. Ray made distributed hyperparameter search and scaling model serving very easy. We are currently exploring Ray Datasets for flexible distributed data loading and Ray Workflows for fast, durable application flows. Lastly, we are hiring! Join us and help reshape surgery through machine learning and computer vision.
