
[train][torchft] Ray Train manages replica group restarts #61475

Open
TimothySeah wants to merge 4 commits into ray-project:master from TimothySeah:tseah/torchft-phase-1b
Conversation

@TimothySeah
Contributor

@TimothySeah TimothySeah commented Mar 4, 2026

Summary

This PR follows up on #61156 by handling torchft worker group failure recovery.

Here are some of the design decisions:

  • ReplicaGroup.shutdown is similar to WorkerGroup.shutdown (both shut down workers and clear state) but skips some other steps (e.g. callbacks and placement group cleanup).
  • WorkerGroup.replace_replica_group is similar to WorkerGroup._start_impl so I refactored their shared functionality accordingly. The main difference is that the former runs fewer callbacks.
  • I added a new WorldRankToOngoingPoll class to provide a clean interface for the ReplicaGroup to update the WorkerGroup's polling state.
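The WorldRankToOngoingPoll interface described above might look something like the following sketch. This is a hypothetical illustration of a mapping from world rank to in-flight poll handles, not the actual class from the PR:

```python
# Hypothetical sketch of the WorldRankToOngoingPoll interface; method names
# and semantics are assumptions, not the PR's actual implementation.
from typing import Dict, Iterable, Optional, Set


class WorldRankToOngoingPoll:
    """Tracks in-flight poll handles keyed by worker world rank."""

    def __init__(self) -> None:
        self._polls: Dict[int, object] = {}

    def set_poll(self, world_rank: int, poll: object) -> None:
        self._polls[world_rank] = poll

    def pop_poll(self, world_rank: int) -> Optional[object]:
        return self._polls.pop(world_rank, None)

    def clear_ranks(self, world_ranks: Iterable[int]) -> None:
        # Called when a replica group is replaced: drop stale polls
        # belonging to the ranks of that replica group.
        for rank in world_ranks:
            self._polls.pop(rank, None)

    def ongoing_ranks(self) -> Set[int]:
        return set(self._polls)
```

The point of a dedicated class (rather than a raw dict on the WorkerGroup) is that the ReplicaGroup can clear only its own ranks' polling state without reaching into the WorkerGroup's internals.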

I also went through every WorkerGroupCallback method and determined whether it is relevant for ReplicaGroups. In particular:

  • before_init_train_context is the same for both ReplicaGroup and WorkerGroup across all callbacks, so I moved it into a new shared ExecutionGroupCallback. The only callout is that I need to fix DatasetsCallback in a future PR (torchft + Ray Train phase 2 covers the integration with Ray Data).
  • before_worker_group_shutdown is slightly different, so I created a new ReplicaGroupCallback with this method. Even though BackendSetupCallback is the same, ReportCallbackHandler is different (I will fix reporting in a future PR). Finally, we do not want to call StateManagerCallback on replica groups because replica groups are not reflected in train run state.
  • after_worker_group_start is slightly different. In particular, BackendSetupCallback is different because we want to call it once per replica group rather than once for the entire worker group. Meanwhile, WorkingDirectorySetupCallback is the same, DatasetSetupCallback and ReportCallbackHandler will be handled in the aforementioned future PRs, and StateManagerCallback and PlacementGroupCleanerCallback only apply to worker groups.

All other WorkerGroupCallback methods are irrelevant:

  • on_worker_group_start/on_worker_group_shutdown are just for timing the worker group.
  • before_worker_group_start is only used by StateManagerCallback which is irrelevant as explained earlier.
  • after_worker_group_training_start is never used.
  • after_worker_group_shutdown is only used by DatasetsCallback, which will be handled in the aforementioned future PR.
  • after_worker_group_poll_status, before_worker_group_abort, and after_worker_group_abort are irrelevant because they operate on all the workers in the worker group, while replica groups are just thin wrappers around the workers.
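The callback split described above could be sketched as the class hierarchy below. The class names follow the PR (ExecutionGroupCallback, WorkerGroupCallback, ReplicaGroupCallback), but the method sets shown are illustrative assumptions, not the actual Ray Train interfaces:

```python
# Illustrative sketch of the callback hierarchy; method signatures are
# simplified assumptions, not the real Ray Train callback interfaces.
class ExecutionGroupCallback:
    """Hooks shared by both worker groups and replica groups."""

    def before_init_train_context(self, workers):
        # Returns extra kwargs to merge into each worker's train context.
        return {}


class WorkerGroupCallback(ExecutionGroupCallback):
    """Full worker-group lifecycle hooks (timing, state, placement groups)."""

    def after_worker_group_start(self, worker_group):
        pass

    def before_worker_group_shutdown(self, worker_group):
        pass


class ReplicaGroupCallback(ExecutionGroupCallback):
    """Only the hooks that make sense for a single replica group."""

    def before_worker_group_shutdown(self, replica_group):
        pass
```

Keeping ReplicaGroupCallback separate from WorkerGroupCallback makes the controller's dispatch explicit: callbacks that track run-level state simply never implement the replica-group interface.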

Testing

I'm open to more unit test suggestions. I basically tried to unit test different layers of the stack as follows:

  • test_torch_trainer: e2e test; I verified that it works as expected. It remains disabled until I add torchft dependencies to the train CI.
  • test_controller: tests that we correctly decide when to do a replica group restart or a full worker group restart. Note that this also tests elastic training.
  • test_worker_group: tests that when we replace_replica_group we correctly update the relevant state (WorkerGroupState, replica groups, WorldRankToOngoingPoll). I added mark.parametrize to other unit tests to verify that other behavior (e.g. callbacks and worker initialization) works as expected with both worker group and replica group restarts.
  • test_worker_group_poll_status: simple tests of the newly added WorldRankToOngoingPoll and failing_replica_group_indices.

TODO: will run my prototype driver script in a workspace.

Signed-off-by: Timothy Seah <tseah@anyscale.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for replica group restarts, a key feature for fault tolerance with torchft. The changes are well-structured, introducing ExecutionGroup and ExecutionGroupCallback as base classes to share logic between WorkerGroup and the new ReplicaGroup concepts. The controller logic is updated to handle partial restarts of failing replica groups, and the WorkerGroup is enhanced with a replace_replica_group method. The refactoring is clean and the new functionality is supported by a comprehensive set of tests. I have one suggestion regarding the callback handling to make it more robust for custom callbacks.

@TimothySeah TimothySeah marked this pull request as ready for review March 4, 2026 04:11
@TimothySeah TimothySeah requested a review from a team as a code owner March 4, 2026 04:11

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.

        "At least one replacement worker failed to initialize "
        f"in replica group {replica_group_index}."
    )
    raise WorkerGroupStartupFailedError(error_msg) from actor_error


Failed partial replacement causes unrecoverable retry loop

Medium Severity

replace_replica_group shuts down the old replica group (killing workers, setting _workers = None) before attempting to create and initialize replacements. If the subsequent _create_workers or _init_train_context fails, the worker group is left in an inconsistent state — the old replica group is destroyed, but _worker_group_state and _latest_poll_status are not updated. On retry, _execute_resize_decision reads the stale _latest_poll_status, re-enters the partial replacement path (all conditions still hold), and calls replace_replica_group again on the already-shut-down replica group, which immediately fails at get_workers(). This consumes all retry attempts without ever falling back to a full restart that could actually recover.
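One way to break the retry loop described above is to record that partial replacement has failed, so the next retry falls back to a full worker-group restart instead of re-entering the partial path against an already-destroyed replica group. The sketch below is an illustrative simplification with hypothetical names (Controller, ReplacementFailed), not the PR's actual code:

```python
# Illustrative sketch of a fallback strategy for failed partial replacement.
# All names here are hypothetical; the real controller logic differs.
class ReplacementFailed(RuntimeError):
    pass


class Controller:
    def __init__(self):
        self.partial_replacement_broken = False
        self.events = []

    def replace_replica_group(self, index):
        self.events.append(f"partial:{index}")
        # Simulate _create_workers / _init_train_context failing mid-replacement.
        raise ReplacementFailed()

    def full_restart(self):
        self.events.append("full_restart")

    def recover(self, failing_index):
        if self.partial_replacement_broken:
            # The old replica group is already gone; only a full
            # restart can return the worker group to a consistent state.
            self.full_restart()
            return
        try:
            self.replace_replica_group(failing_index)
        except ReplacementFailed:
            # Mark partial replacement as unusable so later retries do not
            # burn attempts on a path that can no longer succeed.
            self.partial_replacement_broken = True
            self.full_restart()
```

The key property is that a failed partial replacement degrades to a full restart rather than consuming every retry attempt on a state that cannot recover.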

Additional Locations (1)

        "At least one replacement worker failed to initialize "
        f"in replica group {replica_group_index}."
    )
    raise WorkerGroupStartupFailedError(error_msg) from actor_error


New worker actors orphaned on failed replacement

Low Severity

In replace_replica_group, if _create_workers succeeds but _init_train_context raises RayActorError, the newly created worker actors are never cleaned up. They are not yet tracked in _worker_group_state or _replica_groups, so no subsequent shutdown or full restart will kill them. These orphaned actors hold cluster resources until the controller actor itself dies.
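The orphaned-actor issue above is the classic gap between creating resources and registering them in tracked state. A minimal sketch of one fix, assuming hypothetical helper names (the `kill` callable stands in for `ray.kill` on actor handles):

```python
# Illustrative sketch (not the PR's code): if initialization of replacement
# workers fails before they are registered in worker-group state, kill them
# here so they do not leak until the controller actor dies.
def replace_with_cleanup(create_workers, init_train_context, kill):
    new_workers = create_workers()
    try:
        init_train_context(new_workers)
    except Exception:
        # The new actors are not yet tracked in _worker_group_state or
        # _replica_groups, so no later shutdown would reach them.
        for worker in new_workers:
            kill(worker)
        raise
    return new_workers
```

The try/except must cover everything between actor creation and the point where the new workers become visible to the normal shutdown path.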


@ray-gardener ray-gardener bot added the train Ray Train Related Issue label Mar 4, 2026