[train][torchft] Ray Train manages replica group restarts #61475
TimothySeah wants to merge 4 commits into ray-project:master
Conversation
Signed-off-by: Timothy Seah <tseah@anyscale.com>
Code Review
This pull request introduces support for replica group restarts, a key feature for fault tolerance with torchft. The changes are well-structured, introducing ExecutionGroup and ExecutionGroupCallback as base classes to share logic between WorkerGroup and the new ReplicaGroup concepts. The controller logic is updated to handle partial restarts of failing replica groups, and the WorkerGroup is enhanced with a replace_replica_group method. The refactoring is clean and the new functionality is supported by a comprehensive set of tests. I have one suggestion regarding the callback handling to make it more robust for custom callbacks.
```python
            "At least one replacement worker failed to initialize "
            f"in replica group {replica_group_index}."
        )
        raise WorkerGroupStartupFailedError(error_msg) from actor_error
```
Failed partial replacement causes unrecoverable retry loop
Medium Severity
replace_replica_group shuts down the old replica group (killing workers, setting _workers = None) before attempting to create and initialize replacements. If the subsequent _create_workers or _init_train_context fails, the worker group is left in an inconsistent state — the old replica group is destroyed, but _worker_group_state and _latest_poll_status are not updated. On retry, _execute_resize_decision reads the stale _latest_poll_status, re-enters the partial replacement path (all conditions still hold), and calls replace_replica_group again on the already-shut-down replica group, which immediately fails at get_workers(). This consumes all retry attempts without ever falling back to a full restart that could actually recover.
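One way to avoid this failure mode is a "create before destroy" ordering: fully create and initialize the replacement workers before tearing down the old replica group, so any startup failure propagates while state is still consistent and a retry is safe. A minimal sketch of that ordering (all names here — `make_workers`, `init_worker`, `groups` — are invented for illustration, not the actual Ray Train API):

```python
def replace_replica_group(groups, index, make_workers, init_worker):
    """Swap the replica group at `index` for freshly created workers.

    Hypothetical sketch: replacements are created and initialized
    *before* the old group is shut down, so a failure in either step
    propagates while `groups` still holds the intact old group.
    """
    # Any failure here raises before `groups` is mutated.
    new_workers = make_workers()
    for worker in new_workers:
        init_worker(worker)

    # Only after the replacements are fully ready do we swap and
    # tear down the old group.
    old_workers = groups[index]
    groups[index] = new_workers
    for worker in old_workers:
        worker.shutdown()
```

Under this ordering, a failed replacement leaves the old group untouched, so a retry re-enters with the same (valid) state rather than operating on an already-destroyed group.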
Additional Locations (1)
```python
            "At least one replacement worker failed to initialize "
            f"in replica group {replica_group_index}."
        )
        raise WorkerGroupStartupFailedError(error_msg) from actor_error
```
New worker actors orphaned on failed replacement
Low Severity
In replace_replica_group, if _create_workers succeeds but _init_train_context raises RayActorError, the newly created worker actors are never cleaned up. They are not yet tracked in _worker_group_state or _replica_groups, so no subsequent shutdown or full restart will kill them. These orphaned actors hold cluster resources until the controller actor itself dies.
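A common pattern for this is to wrap initialization in a try/except that explicitly kills the just-created, not-yet-tracked workers before re-raising. A sketch under assumed names (`create_workers`, `init_worker`, `kill_worker` are placeholders, not the actual API — in Ray the kill step would be `ray.kill`):

```python
def create_and_init_workers(create_workers, init_worker, kill_worker):
    """Create workers and initialize them, cleaning up on failure.

    Hypothetical sketch: if initialization raises, the new workers are
    not yet tracked by any state that a later shutdown would visit, so
    they must be killed here before the error propagates.
    """
    workers = create_workers()
    try:
        for worker in workers:
            init_worker(worker)
    except Exception:
        # Untracked workers would otherwise hold cluster resources
        # until the controller actor itself dies.
        for worker in workers:
            kill_worker(worker)
        raise
    return workers
```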


Summary
This PR follows up on #61156 by handling torchft worker group failure recovery.
Here are some of the design decisions:
- `ReplicaGroup.shutdown` is similar to `WorkerGroup.shutdown` (they both shut down workers and clear state) but skips some other steps (e.g. callbacks and placement group cleanup).
- `WorkerGroup.replace_replica_group` is similar to `WorkerGroup._start_impl`, so I refactored their shared functionality accordingly. The main difference is that the former runs fewer callbacks.
- A new `WorldRankToOngoingPoll` class provides a clean interface for the `ReplicaGroup` to update the `WorkerGroup`'s polling state.

I also went through every single `WorkerGroupCallback` method and determined whether or not it is relevant for `ReplicaGroup`s. In particular:

- `before_init_train_context` is the same for both `ReplicaGroup` and `WorkerGroup` across all callbacks, so I just put it into a new shared `ExecutionGroupCallback`. The only callout is that I need to fix `DatasetsCallback` in a future PR (torchft + Ray Train phase 2 = integration with Ray Data).
- `before_worker_group_shutdown` is slightly different, so I created a new `ReplicaGroupCallback` with this method. Even though `BackendSetupCallback` is the same, `ReportCallbackHandler` is different (I will fix reporting in a future PR). Finally, we do not want to call `StateManagerCallback` on replica groups because replica groups are not reflected in train run state.
- `after_worker_group_start` is slightly different. In particular, `BackendSetupCallback` is different because we want to call it once per replica group rather than once for the whole worker group. Meanwhile, `WorkingDirectorySetupCallback` is the same; `DatasetSetupCallback` and `ReportCallbackHandler` will be handled in the aforementioned future PRs; and `StateManagerCallback` and `PlacementGroupCleanerCallback` only apply to worker groups.

All other `WorkerGroupCallback` methods are irrelevant:

- `on_worker_group_start`/`on_worker_group_shutdown` are just for timing the worker group.
- `before_worker_group_start` is only used by `StateManagerCallback`, which is irrelevant as explained earlier.
- `after_worker_group_training_start` is never used.
- `after_worker_group_shutdown` is only used by `DatasetsCallback`, which will be handled in an aforementioned future PR.
- `after_worker_group_poll_status`, `before_worker_group_abort`, and `after_worker_group_abort` are irrelevant because they operate on all the workers in the worker group, while replica groups are just thin wrappers around the workers.

Testing
I'm open to more unit test suggestions. I basically tried to unit test different layers of the stack as follows:
- `test_torch_trainer`: e2e test. I also verified that it works as expected. It's still disabled until I add torchft dependencies to the train CI.
- `test_controller`: tests that we correctly decide when to do a replica group restart vs. a full worker group restart. Note that this also tests elastic training.
- `test_worker_group`: tests that when we `replace_replica_group` we correctly update the relevant state (`WorkerGroupState`, replica groups, `WorldRankToOngoingPoll`). I added `mark.parametrize` to other unit tests to verify that other behavior (e.g. callbacks and worker initialization) works as expected with both worker group and replica group restarts.
- `test_worker_group_poll_status`: simple tests of the newly added `WorldRankToOngoingPoll` and `failing_replica_group_indices`.

TODO: will run my prototype driver script in a workspace.
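The decision exercised by `test_controller` — replica group restart when only some groups fail, full worker group restart otherwise — might, in a much-simplified and purely illustrative form, look like the following (the function name, return values, and retry-budget handling are all hypothetical, not the actual controller API):

```python
def choose_recovery_action(failing_groups, num_groups, retries_left):
    """Pick a recovery action after a poll reports worker failures.

    Hypothetical simplification: restart only the failing replica
    groups while some groups are still healthy; fall back to a full
    worker group restart when every group failed; give up once the
    retry budget is exhausted.
    """
    if retries_left <= 0:
        return "abort"
    if failing_groups and len(failing_groups) < num_groups:
        return "restart_replica_groups"
    return "restart_worker_group"
```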