
Announcing DP Group Fault Tolerance for vLLM WideEP Deployments with Ray Serve LLM

By Jeffrey Wang, Kourosh Hakhamaneshi, Abrar Sheikh and Seiji Eicher   |   April 2, 2026

Serving large MoE models such as DeepSeek-style models efficiently means combining data-parallel attention with expert parallelism across many GPUs. This WideEP-style layout is now a mainstream serving pattern because it improves memory efficiency, batch size, and throughput for large sparse models. The Ray + vLLM stack has leaned heavily into this path; for a background overview, see the Ray Summit talk on Efficient Multi-Node Orchestration of Sparse MoE Model Serving using Ray Serve and the vLLM writeup on large-scale serving.

Figure 1: Data parallel attention - expert parallel MoE deployment pattern commonly used to serve large MoEs. This diagram displays 1 DP group with 4 DP ranks, and each rank hosts a copy of data parallel attention layer and 2 experts.

MoE serving is fundamentally different from regular model serving in one key way. In regular model serving, each replica is an independent, identical copy of the model. In an MoE architecture, expert layers are sharded across an entire group of replicas. That group of replicas, known as a data parallel group (DP group), works together to form a single logical copy of the model. A single query will be routed through many different replicas in the group as different experts are activated. During MoE execution, tokens are dispatched to experts and combined back across the participating ranks. These dispatch/combine paths require every participating rank to be present and healthy. If one rank disappears, tokens can no longer be routed to the missing experts, and the collective execution pattern can no longer proceed correctly.

As WideEP deployments grow larger, their operational blast radius grows with them. A single worker or node failure disrupts expert routing and invalidates the collectives needed to complete inference. As a result, one localized DP rank failure takes an entire DP group out of service. Even worse, because replicas in a DP group do not fate-share by default, queries continue to be routed to the live replicas in the same DP group, and all of those queries fail. In typical WideEP configurations, a DP group spans 16 to 128 DP ranks. If a single rank fails, the entire group of up to 128 GPUs is effectively non-operational.
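To make the missing fate-sharing concrete, here is a minimal sketch that compares per-replica routing with group-aware routing after one rank dies. The two-group layout, the round-robin router, and all names are hypothetical and chosen only for illustration; this is not how Ray Serve or vLLM route requests internally.

```python
from itertools import cycle

GROUP_SIZE = 4  # ranks per DP group (hypothetical)
# Replica id -> DP group id, for two groups of four ranks each.
replica_group = {r: r // GROUP_SIZE for r in range(8)}
failed = {1}  # rank 1 of group 0 has died


def group_is_complete(group_id):
    """A group can serve only if none of its ranks has failed."""
    return all(r not in failed for r, g in replica_group.items() if g == group_id)


def failure_rate(candidates, num_requests=700):
    """Round-robin requests over candidates; a request fails if its group is broken."""
    targets = cycle(candidates)
    failures = sum(
        not group_is_complete(replica_group[next(targets)])
        for _ in range(num_requests)
    )
    return failures / num_requests


live = [r for r in replica_group if r not in failed]
# Per-replica routing: every live replica is a candidate.
naive = failure_rate(live)
# Group-aware routing: only replicas of complete groups are candidates.
group_aware = failure_rate([r for r in live if group_is_complete(replica_group[r])])
print(f"per-replica routing: {naive:.0%} failed, group-aware: {group_aware:.0%} failed")
```

In this toy setup, three of the seven live replicas belong to the broken group, so per-replica routing fails three out of every seven requests, while group-aware routing fails none.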

Recovering from such a localized rank failure requires the serving system to identify the ranks that form the affected DP group and restart them atomically. Traditional LLM serving systems lack the ability to identify associated replicas within a dependent serving group. Without this group awareness, a localized failure makes the entire group unavailable, every request routed to the faulty group fails, and the whole system must be restarted to recover.

This creates a pressing challenge when serving WideEP in production: when one rank in a DP+EP group fails, the remaining ranks cannot cleanly complete the MoE dispatch/combine paths, so the practical blast radius is the entire DP group. The ability to recover from such a failure is therefore critical.

Figure 2: Rank failure in data parallel attention - expert parallel MoE deployments.

In Ray 2.55, we introduced DP group fault tolerance in Ray Serve LLM by leveraging Ray Serve's gang scheduling primitive. This allows users to reduce the blast radius of faults by enforcing fault-tolerance policies at the boundary of DP groups: if one rank fails, the affected group is removed from service and rebuilt as a unit while other healthy groups continue serving.

Basics of DP-EP Deployments

Before diving into how Ray Serve LLM achieves DP group fault tolerance, let’s review what DP and EP are and why they are standard for serving sparse MoE models.

Sparse MoE models like DeepSeek-V3 typically use data parallelism (DP) rather than tensor parallelism (TP) for their attention layers. The reason is that they use Multi-Head Latent Attention (MLA) rather than the traditional Multi-Head Attention (MHA). In MHA, the presence of multiple KV heads makes it natural to shard along the KV-head dimension to achieve tensor parallelism. However, this approach does not translate to MLA, because MLA compresses the KV cache into a low-dimensional latent vector. Sharding this latent would require an all-gather across ranks to reconstruct it, introducing communication overhead and negating the memory efficiency that MLA is designed to deliver. Therefore, DP for MLA, where the attention layer is replicated across all participating ranks, is the commonly adopted pattern.

Figure 3: MLA layers are replicated across ranks while MoE layers are sharded among ranks.

In traditional transformers, attention layers are followed by a single feed-forward linear network. In MoE models, that feed-forward linear network is replaced by a collection of smaller, parallel feed-forward networks, each representing an “expert”, preceded by a learned gating mechanism that decides which experts to activate for each token at runtime.
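The gating step can be sketched as a generic softmax-based top-k router. This is a simplified illustration, not the routing function of any specific model; production routers differ in scoring, normalization, and load balancing, and the logits below are made up.

```python
import math


def top_k_gate(logits, k):
    """Return (expert_index, weight) pairs for the k highest-scoring experts."""
    # Softmax over all experts turns router logits into scores.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    scores = [e / total for e in exps]
    # Keep the k best experts and renormalize their weights to sum to 1.
    chosen = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    norm = sum(scores[i] for i in chosen)
    return [(i, scores[i] / norm) for i in chosen]


# Hypothetical router logits for one token over 8 experts.
logits = [0.1, 2.0, -1.0, 0.5, 3.0, 0.0, -0.5, 1.0]
selected = top_k_gate(logits, k=2)
print(selected)
```

Only the selected experts run for this token; the remaining experts are skipped entirely, which is what makes the layer sparse.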

Figure 4: Token routing / expert selection mechanism in an MoE layer.

In DeepSeek-V3, each MoE layer has 256 experts, of which only 8 are activated per token. To serve MoE models efficiently, experts are distributed across participating ranks (e.g. a typical configuration hosts 8 experts per rank across 32 ranks). To further boost serving throughput, experts can be distributed across a larger number of GPUs, increasing the effective KV cache size because each rank now holds weights from fewer experts, which in turn enables larger batch sizes. This pattern of distributing experts across many GPUs is commonly known as Wide Expert Parallelism (WideEP).
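The relationship between EP width and experts per rank is simple division. The short sketch below assumes experts are split evenly across ranks; the 64-rank width is an extra illustrative value not taken from the text above.

```python
def experts_per_rank(num_experts, ep_size):
    """Number of experts each rank hosts when experts are split evenly."""
    if num_experts % ep_size != 0:
        raise ValueError("this sketch assumes experts divide evenly across ranks")
    return num_experts // ep_size


NUM_EXPERTS = 256  # experts per MoE layer in DeepSeek-V3, as stated above
for ep_size in (32, 64, 128):
    print(f"EP size {ep_size:>3}: {experts_per_rank(NUM_EXPERTS, ep_size)} experts per rank")
```

Each doubling of the EP width halves the expert weights a rank must hold, which is the memory that becomes available for KV cache.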

Dispatch and combine are two communication primitives that enable distributed MoE. Because experts are spread across ranks, a token processed on rank 0 may need to be routed to an expert sitting on rank 1. Dispatch is the all-to-all communication step that sends each token’s hidden state to whichever ranks hold its selected experts. Each rank then runs its local experts on the tokens it received. Combine is the reverse all-to-all: the expert outputs are sent back to the originating ranks, where they are weighted and summed according to the gate scores to produce the final MoE layer output.
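A single-process toy model of dispatch and combine can make the data flow easier to follow. It uses plain Python lists in place of all-to-all collectives, scalar "hidden states", and a made-up expert function; the layout of two experts per rank is an assumption for illustration only.

```python
from collections import defaultdict

EXPERTS_PER_RANK = 2  # hypothetical layout: experts 0-1 on rank 0, 2-3 on rank 1, ...


def expert_fn(expert_id, hidden):
    """Toy expert: scales the hidden state by (expert_id + 1)."""
    return hidden * (expert_id + 1)


def moe_layer(tokens, live_ranks):
    """tokens: list of (origin_rank, hidden, [(expert_id, gate_weight), ...])."""
    # Dispatch: bucket each token's hidden state by the rank that owns each selected expert.
    inbox = defaultdict(list)
    for token_id, (_, hidden, routes) in enumerate(tokens):
        for expert_id, weight in routes:
            owner = expert_id // EXPERTS_PER_RANK
            if owner not in live_ranks:
                raise RuntimeError(f"expert {expert_id} lives on missing rank {owner}")
            inbox[owner].append((token_id, expert_id, weight, hidden))
    # Each rank runs its local experts; combine returns the results to the
    # originating token and accumulates the gate-weighted sum.
    outputs = [0.0] * len(tokens)
    for owner, items in inbox.items():
        for token_id, expert_id, weight, hidden in items:
            outputs[token_id] += weight * expert_fn(expert_id, hidden)
    return outputs


tokens = [
    (0, 1.0, [(0, 0.75), (3, 0.25)]),  # token on rank 0 uses experts on ranks 0 and 1
    (1, 2.0, [(2, 0.5), (5, 0.5)]),    # token on rank 1 uses experts on ranks 1 and 2
]
print(moe_layer(tokens, live_ranks={0, 1, 2}))
```

Calling `moe_layer(tokens, live_ranks={0, 2})` raises an error because both tokens need an expert on rank 1. In a real deployment the equivalent situation is a collective that cannot complete, which is why one missing rank stalls the whole group.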

Together, the replicated DP attention layers and the sharded expert layers form a DP group, and all ranks in a DP group must operate collectively. EP inherently subsumes DP: deploying experts across EP ranks requires the attention layers to run in a data-parallel fashion over the same set of ranks.

Why the DP Group Is the Right Unit

In WideEP deployments, partial groups are not valid serving units. That means the control plane should never expose half-alive groups to traffic. Instead, it should preserve one invariant everywhere:

  • Schedule a whole group all at once

  • Health check the whole group

  • Scale in increments of the whole group

  • Recover the whole group atomically

That is the core idea behind DP group fault tolerance. It is not a new engine algorithm; it is the control-plane behavior required to make WideEP safe in production. With this capability, an important operational tradeoff becomes visible: group width versus number of groups.

The vLLM large-scale serving post highlights an important design opportunity. In the decode results, throughput per GPU remains in a similar range across EP sizes 32, 72, and 96, indicating that efficiency doesn’t require maximizing DPEP width. The EP group width should typically be tuned to the smallest value that maximizes throughput. The reason is that smaller groups lead to a smaller blast radius.
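The width-versus-count tradeoff can be put into numbers with a back-of-the-envelope sketch. It assumes a fixed fleet that is split into equal-width groups and that one failed rank removes exactly one group from service; the fleet size of 288 GPUs is hypothetical, picked only because it divides evenly by all three EP sizes.

```python
def capacity_lost(total_gpus, group_width):
    """Return (number of groups, fraction of capacity lost when one rank fails)."""
    if total_gpus % group_width != 0:
        raise ValueError("total_gpus must be a multiple of group_width")
    num_groups = total_gpus // group_width
    # One failed rank takes its whole group out of service.
    return num_groups, 1 / num_groups


TOTAL_GPUS = 288  # hypothetical fleet size, divisible by 32, 72, and 96
for width in (32, 72, 96):
    groups, lost = capacity_lost(TOTAL_GPUS, width)
    print(f"width {width}: {groups} groups, {lost:.0%} of capacity lost per rank failure")
```

With similar per-GPU throughput across widths, the narrower layout loses roughly a ninth of capacity on a rank failure instead of a third.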

Figure 5: Decode throughput from the vLLM large-scale serving post. This is the key figure supporting the redundancy argument because throughput per GPU stays in a similar range across EP sizes 32, 72, and 96.

State of the Art

Kubernetes provides gang scheduling with scheduling-time guarantees, but lacks runtime fate-sharing guarantees for the replicas in a gang. Grove, a Kubernetes-based scheduler designed for Dynamo deployments, extends this with both scheduling-time and runtime gang scheduling guarantees, though it has not yet been applied to achieve WideEP fault tolerance. At the engine level, SGLang implements elastic expert parallelism fault tolerance mechanisms.

DP Group Fault Tolerance in Ray Serve LLM for vLLM

Ray Serve LLM implements this with gang scheduling. Each DP rank is hosted on a Ray Serve replica, and the replicas that form one logical model instance are managed as a gang.

Figure 6: In Ray Serve LLM, each data parallel attention rank is hosted on a Ray Serve replica, and all data parallel attention ranks in a model replica form a data parallel attention group, mapping to a Ray Serve gang.

This gives Ray Serve the right orchestration semantics for WideEP:

  • All replicas in a DP group are scheduled together.

  • A failure in one member invalidates the group.

  • The failed group is torn down and recreated atomically.

  • The router can continue sending traffic to other healthy groups.

Figure 7: DP group fault tolerance.

DP group fault tolerance is enabled by default in Ray 2.55+. If you are already serving DP deployments on Ray Serve LLM, no code changes are required.

from ray import serve
from ray.serve.llm import (
    build_dp_deployment,
    LLMConfig,
    ModelLoadingConfig,
)

llm_config = LLMConfig(
    model_loading_config=ModelLoadingConfig(
        model_id="microsoft/Phi-tiny-MoE-instruct",
        model_source="microsoft/Phi-tiny-MoE-instruct",
    ),
    deployment_config=dict(
        num_replicas=2,  # <--- Number of DP groups
    ),
    engine_kwargs=dict(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        data_parallel_size=2,  # <--- DP group size
    ),
)

app = build_dp_deployment(llm_config)
deployment_handle = serve.run(app, blocking=False)

For a deeper walkthrough of the feature, see the Ray LLM Office Hours recording on DP group fault tolerance.

How Recovery Works

At runtime, the recovery loop is straightforward:

  • A rank in a DP group becomes unhealthy.

  • The entire DP group is marked unhealthy.

  • Ray Serve stops routing traffic to that group.

  • The failed group is torn down.

  • A healthy replacement group is created and rejoins the deployment.
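The five steps above can be sketched as a small state machine. This is a conceptual model with hypothetical class and method names; it is not Ray Serve's controller code, and real recovery involves tearing down and rescheduling actual replicas rather than resetting flags.

```python
from enum import Enum


class GroupState(Enum):
    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    RECOVERING = "recovering"


class DPGroup:
    def __init__(self, group_id, size):
        self.group_id = group_id
        self.rank_healthy = [True] * size
        self.state = GroupState.HEALTHY

    def report_rank_failure(self, rank):
        # Steps 1-2: one unhealthy rank marks the entire group unhealthy.
        self.rank_healthy[rank] = False
        self.state = GroupState.UNHEALTHY

    def recover(self):
        # Steps 4-5: tear down every rank, then bring up a full replacement group.
        self.state = GroupState.RECOVERING
        self.rank_healthy = [True] * len(self.rank_healthy)
        self.state = GroupState.HEALTHY


def routable(groups):
    # Step 3: only fully healthy groups receive traffic.
    return [g.group_id for g in groups if g.state is GroupState.HEALTHY]


groups = [DPGroup(0, size=4), DPGroup(1, size=4)]
groups[0].report_rank_failure(rank=2)
during_failure = routable(groups)
groups[0].recover()
after_recovery = routable(groups)
print(during_failure, after_recovery)
```

While group 0 is down, only group 1 is routable; once the replacement group is up, both groups receive traffic again.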

Figure 8: DP group fault tolerance lifecycle: Rank failure → entire DP group (gang) is marked as unhealthy → Ray Serve controller detects unhealthy gang and stops routing traffic → gang recovery → back to healthy.

This preserves the invariant that only complete groups ever receive traffic. It also means a rank failure does not have to become a deployment-wide outage as long as there are other healthy groups available. The Grafana dashboard below shows this operation in action. We simulate a failure by killing one of the vLLM DPEP workers and the incoming traffic keeps getting served.

Figure 9: In Ray Serve LLM, with DP group fault tolerance, there is no availability drop, and data parallel attention group atomicity is preserved throughout the deployment.

DP Group Autoscaling

The same principle applies to autoscaling. In standard stateless services, scaling one replica at a time is natural. In WideEP deployments, this pattern will not work. Adding or removing an individual rank would create a partial group, which is not a valid serving topology.

Autoscaling must also respect DP group boundaries. Scale-up and scale-down need to happen in group-sized increments, not individual-replica increments.

Ray Serve LLM supports this through gang-aware autoscaling. Replica counts remain aligned with the DP group size, preserving the same atomicity guarantees during scaling that we rely on during failure recovery.
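One way to picture group-sized scaling is a helper that snaps any desired replica count to whole DP groups. The function below is a hypothetical illustration; rounding up and the parameter names are assumptions, not Ray Serve's actual autoscaling policy.

```python
import math


def gang_aligned_replicas(desired_replicas, group_size, min_groups, max_groups):
    """Round a desired replica count up to whole DP groups within [min_groups, max_groups]."""
    groups = math.ceil(desired_replicas / group_size)
    groups = max(min_groups, min(max_groups, groups))
    return groups * group_size


for desired in (1, 3, 5, 20):
    aligned = gang_aligned_replicas(desired, group_size=2, min_groups=1, max_groups=4)
    print(desired, "->", aligned)
```

With a DP group size of 2, every result is a multiple of 2, so the deployment never contains a partial group.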

Figure 10: In Ray Serve LLM, with gang-aware autoscaling, DP group atomicity is preserved throughout the scale-up and scale-down process. The DP deployment shown in this figure has data parallel size of 2, and therefore the number of Ray Serve replicas is always a multiple of 2 to align with the DP group.

The configuration is similarly simple:

from ray import serve
from ray.serve.llm import (
    build_dp_deployment,
    LLMConfig,
    ModelLoadingConfig,
)

llm_config = LLMConfig(
    model_loading_config=ModelLoadingConfig(
        model_id="microsoft/Phi-tiny-MoE-instruct",
        model_source="microsoft/Phi-tiny-MoE-instruct",
    ),
    deployment_config=dict(
        num_replicas="auto",
        autoscaling_config=dict(
            min_replicas=1,  # <-- Min. number of DP groups
            max_replicas=4,  # <-- Max. number of DP groups
        ),
    ),
    engine_kwargs=dict(
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        data_parallel_size=2,  # <--- DP group size
    ),
)

app = build_dp_deployment(llm_config)
deployment_handle = serve.run(app, blocking=False)

Where vLLM’s Elastic EP Fits

DP group fault tolerance is one layer of a larger resilience story for large MoE serving.

At a high level, there are two complementary layers:

  • Orchestration-level resilience, which manages failure domains across groups.

  • Engine-level elasticity, which manages recovery and continuation within a group.

DP group fault tolerance solves the orchestration layer of the problem. vLLM Elastic EP addresses a complementary engine layer: how the runtime can dynamically evolve its DP+EP topology and eventually handle richer in-engine elasticity and recovery. The right reference here is the community’s vLLM Elastic Expert Parallelism RFC.

Figure 11: Elastic expert parallelism.

These two layers are complementary, not competing:

  • Ray Serve LLM manages which groups exist, which groups are healthy, and where traffic goes.

  • vLLM Elastic EP advances what is possible inside a running DP+EP engine.

Check out Ray Serve LLM’s roadmap for the integration with vLLM’s Elastic Expert Parallelism.

Conclusion

The data parallel attention - expert parallel MoE deployment pattern is commonly used to serve large MoE LLMs. Ray Serve LLM adds fault tolerance and autoscaling for this pattern through DP group fault tolerance and Ray Serve gang-aware autoscaling, minimizing availability drops and adapting to traffic throughout the serving lifecycle. All of the code for reproducing the above results is available here. Stay tuned for more updates on the integration between vLLM elastic expert parallelism and Ray Serve LLM!

Join the Community!

  • Join the Ray Slack #llm channel

  • Ray LLM office hours; sign up here for the Calendar invite

  • Ray LLM office hours past recordings
