Over the past few releases, the Ray Train team has focused on improving the developer experience and runtime stability for distributed training workloads.
Ray Train V2 builds on this foundation — providing better usability, stronger reliability guarantees, and a cleaner API surface that enables faster feature development across the stack.
To that end, all new functionality introduced this quarter is built on Ray Train V2, and we strongly encourage users to begin migrating. See the Train Tune API Revamp REP and the Migration Guide for full details.
Ray Train already provides functionality for managing model checkpoints. By default, checkpoints are uploaded synchronously to the target storage directory when the train.report API is called, which can leave GPUs idle if checkpoints are large or saved frequently.
To address this, Ray Train V2 now offers built-in asynchronous checkpoint uploading. When enabled, checkpoints are uploaded to cloud storage in a separate CPU thread, so GPU training doesn't pause for I/O. This keeps GPU utilization high even when checkpoints are large or frequent. Read more about asynchronous checkpointing in the documentation.
import tempfile

from ray import train
from ray.train import Checkpoint

def train_fn(config):
    ...
    metrics = {...}
    tmpdir = tempfile.mkdtemp()
    ...  # Save checkpoint to tmpdir
    checkpoint = Checkpoint.from_directory(tmpdir)
    # With ASYNC upload mode, the checkpoint is uploaded to storage in a
    # background thread while training continues.
    train.report(
        metrics,
        checkpoint=checkpoint,
        checkpoint_upload_mode=train.CheckpointUploadMode.ASYNC,
    )
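For context, here is a minimal launch sketch. It assumes a PyTorch setup with TorchTrainer; the worker count and the storage_path bucket are placeholders, and storage_path is where reported checkpoints get uploaded.

from ray.train import RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer

trainer = TorchTrainer(
    train_fn,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
    # Checkpoints reported from train_fn are uploaded here; the bucket path
    # is a placeholder.
    run_config=RunConfig(storage_path="s3://my-bucket/experiments"),
)
result = trainer.fit()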
It is common to validate the model periodically during training to measure performance and catch overfitting. The standard approach is to alternate between training and validation inside the training loop.
In Ray 2.51, Ray Train now allows you to validate the model asynchronously in a separate Ray task. This means you can:
Run validation in parallel without blocking the training loop, leading to higher training goodput
Run validation on different (cheaper, less efficient) hardware than training
Easily scale out the validation using Ray Data or Ray Train
Leverage autoscaling to launch machines only for the duration of the validation
Read more about asynchronous checkpoint validation in the documentation.
import ray.train

def validate_fn(checkpoint, config):
    ...
    return {"score": ...}

def train_fn(config):
    ...
    # Upload the checkpoint asynchronously, then run validate_fn on it in a
    # separate Ray task without blocking the training loop.
    ray.train.report(
        metrics,
        checkpoint=checkpoint,
        checkpoint_upload_mode=ray.train.CheckpointUploadMode.ASYNC,
        validate_fn=validate_fn,
        validate_config={"dataset": ...},
    )
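For illustration, a filled-in validate_fn could look like the sketch below. It assumes, beyond anything stated above, a PyTorch model whose state dict was saved as model.pt in the checkpoint directory and a validate_config whose "dataset" entry is an iterable of (features, labels) batches.

import os

import torch
import torch.nn as nn

def validate_fn(checkpoint, config):
    # Materialize the checkpoint locally (downloading from cloud storage if needed).
    with checkpoint.as_directory() as ckpt_dir:
        # "model.pt" is simply the filename this sketch assumes train_fn wrote.
        state_dict = torch.load(os.path.join(ckpt_dir, "model.pt"), map_location="cpu")

    model = nn.Linear(8, 1)  # stand-in architecture for this sketch
    model.load_state_dict(state_dict)
    model.eval()

    correct, total = 0, 0
    with torch.no_grad():
        for features, labels in config["dataset"]:
            preds = (model(features).squeeze(-1) > 0).long()
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return {"score": correct / max(total, 1)}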
Ray Train now offers a JaxTrainer API, which enables seamless scaling of distributed JAX training on TPUs.
Distributed JAX has historically required a multi-controller configuration, with parallel processes spawned across the devices. This style of execution makes it hard to recover from failures and hard to develop against iteratively.
With this new API, users can now scale JAX workloads natively on TPUs with the same single-controller orchestration and fault tolerance as other commonly used APIs like Ray Train’s TorchTrainer and TensorFlowTrainer. Read more about JAX on Ray Train in the documentation.
from ray.train.v2.jax import JaxTrainer
from ray.train import ScalingConfig

def train_func():
    # Your JAX training code here.
    ...

scaling_config = ScalingConfig(
    num_workers=4,
    use_tpu=True,
    topology="4x4",
    accelerator_type="TPU-V6E",
)
trainer = JaxTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()
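As a sketch of what could go inside train_func, here is a toy JAX loop. It is purely illustrative and single-device; a real workload would shard computation across the TPU slice, and reporting metrics via ray.train.report is carried over from the other trainers as an assumption.

import jax
import jax.numpy as jnp

import ray.train

def train_func():
    # Each Ray Train worker runs this function; jax.devices() shows the
    # accelerators visible to the worker.
    print("Devices visible to this worker:", jax.devices())

    # Toy least-squares problem, purely illustrative.
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (1024, 8))
    true_w = jnp.arange(8, dtype=jnp.float32)
    y = x @ true_w

    def loss_fn(w):
        return jnp.mean((x @ w - y) ** 2)

    grad_fn = jax.jit(jax.grad(loss_fn))
    w = jnp.zeros(8)
    for _ in range(100):
        w = w - 0.1 * grad_fn(w)

    # Report a final metric through the standard Ray Train API (assumed to
    # behave the same way as with the other trainers).
    ray.train.report({"mse": float(loss_fn(w))})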
Ray Train V2 now offers a "local" mode of execution. In local mode, instead of the default strategy of running training code in parallel across multiple Ray actors, Ray Train executes your training function directly in the current process, without requiring any code changes. This provides a simplified debugging environment where you can iterate quickly on your training logic.
Local mode supports two ways of running your code:
Single-process mode: Runs your training function in a single process, ideal for rapid iteration and debugging.
Multi-process mode with torchrun: Launches multiple processes for multi-GPU training, useful for debugging distributed training logic with familiar tools.
Read more about local mode in the Ray Train documentation.
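As a rough sketch of the workflow local mode is aimed at: the training function below is the same one you would hand to a trainer for a full distributed run, and local mode only changes how it is executed. The TorchTrainer setup is illustrative, and the exact configuration switch that enables local or torchrun execution is intentionally not reproduced here; see the linked documentation for it.

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_fn(config):
    # The exact same training function used for full distributed runs; no
    # changes are needed to debug it under local mode.
    ...

# Default execution: train_fn runs in parallel across Ray actors.
trainer = TorchTrainer(
    train_fn,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()

# With local mode enabled (see the Ray Train docs for the configuration
# switch, which is not shown here), the same train_fn runs directly in the
# current process, so standard tools like pdb and breakpoints work as usual.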
Ray Train V2 lays the groundwork for the next generation of distributed training — one that’s more modular, performant, and developer-friendly. Upcoming releases will extend this foundation with enhanced fault tolerance, deeper framework integrations, and unified experiment management across Train and Tune.