From 41a8ca982e8efc7b8327a7d71a03b4b3fa9aeee4 Mon Sep 17 00:00:00 2001
From: krishdef7
Date: Sat, 4 Apr 2026 01:37:36 +0530
Subject: [PATCH] feat(operator): support multi-slice TPU training via trainer
 replicas

For multi-slice TPU, JobSet models each TPU slice as a ReplicatedJob
replica, with parallelism = hosts per slice and replicas = slice count.
The operator previously blocked this with two hard constraints:

1. builder.go unconditionally set trainer Replicas = 1, destroying any
   value from the runtime template.
2. trainingruntime_webhook.go rejected replicas != 1 for all ancestors,
   including trainer.

Changes:

- builder.go: nil-guard trainer Replicas, preserving the value from the
  runtime template instead of unconditionally overwriting it.
- jobset.go: in Build(), compute perSlice = numNodes / replicas for the
  trainer ancestor so each slice runs the correct number of hosts.
- trainingruntime_webhook.go: allow trainer ancestor replicas > 1 so
  multi-slice configurations pass admission.
- trainingruntime_webhook_test.go: update the invalid_replicas test to
  reflect that trainer replicas > 1 is now valid.
- trainingruntime_test.go: add a test case for 4 slices x 8 hosts
  (NumNodes=32), verifying Parallelism=8 per slice and MinMember=34.

Semantics: numNodes = total hosts across all slices. Per-slice hosts =
numNodes / replicas.

REF: https://github.com/kubeflow/trainer/issues/3407

Signed-off-by: krishdef7
---
 pkg/runtime/core/trainingruntime.go           |  6 ++-
 pkg/runtime/core/trainingruntime_test.go      | 50 +++++++++++++++++++
 .../framework/plugins/jobset/builder.go       |  8 +--
 .../framework/plugins/jobset/jobset.go        | 23 ++++++++-
 pkg/webhooks/trainingruntime_webhook.go       |  2 +-
 pkg/webhooks/trainingruntime_webhook_test.go  |  2 -
 6 files changed, 82 insertions(+), 9 deletions(-)

diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go
index 807d9847e6..185ac5fd69 100644
--- a/pkg/runtime/core/trainingruntime.go
+++ b/pkg/runtime/core/trainingruntime.go
@@ -166,13 +166,17 @@ func (r *TrainingRuntime) newRuntimeInfo(
 	}
 
 	for i, rJob := range jobSetSpecApply.ReplicatedJobs {
-		// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
+		// TODO: Support multiple replicas for non-trainer replicatedJobs.
 		// REF: https://github.com/kubeflow/trainer/issues/2318
 		count := ptr.Deref(rJob.Template.Spec.Parallelism, 1)
 		var ancestor *string
 		if metadata := rJob.Template.ObjectMetaApplyConfiguration; metadata != nil && metadata.Labels != nil {
 			if labelAncestor, ok := metadata.Labels[constants.LabelTrainJobAncestor]; ok {
 				if labelAncestor == constants.AncestorTrainer && mlPolicy != nil {
+					// For multi-slice TPU, numNodes represents total nodes across all slices.
+					// Per-slice Parallelism/Completions is computed in Build() by dividing
+					// the final count (set by EnforceMLPolicy) by trainer replicas.
+					// REF: https://github.com/kubeflow/trainer/issues/3407
 					count = ptr.Deref(mlPolicy.NumNodes, 1)
 
 					// Apply resourcesPerNode from TrainJob to the template spec
diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go
index e933ec844f..0e2c8b1756 100644
--- a/pkg/runtime/core/trainingruntime_test.go
+++ b/pkg/runtime/core/trainingruntime_test.go
@@ -116,6 +116,56 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
 					Obj(),
 			},
 		},
+		"succeeded to build PodGroup and JobSet with multi-slice TPU trainer.": {
+			trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").
+				RuntimeSpec(
+					testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
+						WithMLPolicy(
+							testingutil.MakeMLPolicyWrapper().
+								WithNumNodes(32).
+								Obj(),
+						).
+						PodGroupPolicyCoschedulingSchedulingTimeout(120).
+						Replicas(4, constants.Node).
+						Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
+						Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
+						Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
+						Obj(),
+				).Obj(),
+			trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
+				UID("uid").
+				RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "test-runtime").
+				Trainer(
+					testingutil.MakeTrainJobTrainerWrapper().
+						NumNodes(32).
+						Obj(),
+				).
+				Obj(),
+			wantObjs: []runtime.Object{
+				testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
+					ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
+					Replicas(4, constants.Node).
+					Replicas(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
+					Parallelism(8, constants.Node).
+					Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
+					Completions(8, constants.Node).
+					Completions(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
+					NumNodes(8).
+					Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
+					Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
+					Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
+					PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
+					Obj(),
+				testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
+					ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
+					MinMember(34). // 34 = 8 nodes × 4 slices + 1 DatasetInitializer + 1 ModelInitializer
+					MinResources(corev1.ResourceList{
+						corev1.ResourceCPU: resource.MustParse("34"),
+					}).
+					SchedulingTimeout(120).
+					Obj(),
+			},
+		},
 		"succeeded to build JobSet with NumNodes from the Runtime and container from the TrainJob.": {
 			trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
 				testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go
index d5d8e1ca86..999848b1a8 100644
--- a/pkg/runtime/framework/plugins/jobset/builder.go
+++ b/pkg/runtime/framework/plugins/jobset/builder.go
@@ -118,9 +118,11 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build
 			ancestor = jobMetadata.Labels[constants.LabelTrainJobAncestor]
 		}
 		if ancestor == constants.AncestorTrainer {
-			// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
-			// REF: https://github.com/kubeflow/trainer/issues/2318
-			b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
+			// Preserve replicas for multi-slice TPU support.
+			// REF: https://github.com/kubeflow/trainer/issues/3407
+			if b.Spec.ReplicatedJobs[i].Replicas == nil {
+				b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
+			}
 			// Update values for the Trainer container.
 			for j, container := range rJob.Template.Spec.Template.Spec.Containers {
 				if *container.Name == constants.Node {
diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go
index a629c6fcc8..ceafd9b80c 100644
--- a/pkg/runtime/framework/plugins/jobset/jobset.go
+++ b/pkg/runtime/framework/plugins/jobset/jobset.go
@@ -273,8 +273,27 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine
 
 	for psIdx, ps := range info.TemplateSpec.PodSets {
 		if ps.Count != nil {
-			jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Parallelism = ps.Count
-			jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Completions = ps.Count
+			rJob := &jobSetSpec.ReplicatedJobs[psIdx]
+			jobMetadata := rJob.Template.ObjectMetaApplyConfiguration
+			isTrainer := jobMetadata != nil && jobMetadata.Labels != nil &&
+				jobMetadata.Labels[constants.LabelTrainJobAncestor] == constants.AncestorTrainer
+			if isTrainer {
+				// For multi-slice trainer: count = numNodes (total across all slices).
+				// Parallelism/Completions must be per-slice = numNodes / replicas.
+				replicas := ptr.Deref(rJob.Replicas, 1)
+				if replicas <= 0 {
+					return nil, fmt.Errorf("trainer replicatedJob %d has invalid replicas %d: must be > 0", psIdx, replicas)
+				}
+				if *ps.Count%replicas != 0 {
+					return nil, fmt.Errorf("trainer numNodes %d must be evenly divisible by replicas %d", *ps.Count, replicas)
+				}
+				perSlice := *ps.Count / replicas
+				rJob.Template.Spec.Parallelism = &perSlice
+				rJob.Template.Spec.Completions = &perSlice
+			} else {
+				rJob.Template.Spec.Parallelism = ps.Count
+				rJob.Template.Spec.Completions = ps.Count
+			}
 		}
 		apply.UpsertVolumes(&jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Volumes, ps.Volumes...)
 		for containerIdx, container := range ps.Containers {
diff --git a/pkg/webhooks/trainingruntime_webhook.go b/pkg/webhooks/trainingruntime_webhook.go
index 14f1fc3594..9b6f14160c 100644
--- a/pkg/webhooks/trainingruntime_webhook.go
+++ b/pkg/webhooks/trainingruntime_webhook.go
@@ -76,7 +76,7 @@ func validateReplicatedJobs(rJobs []jobsetv1alpha2.ReplicatedJob) field.ErrorLis
 		}
 
 		if labelAncestor, ok := rJob.Template.Labels[constants.LabelTrainJobAncestor]; ok && ancestors.Has(labelAncestor) {
-			if rJob.Replicas != 1 {
+			if labelAncestor != constants.AncestorTrainer && rJob.Replicas != 1 {
 				allErrs = append(allErrs, field.Invalid(rJobsPath.Index(idx).Child("replicas"),
 					rJob.Replicas, rJobReplicasErrorMsg))
 			}
diff --git a/pkg/webhooks/trainingruntime_webhook_test.go b/pkg/webhooks/trainingruntime_webhook_test.go
index fb39f0600d..ca0b00d963 100644
--- a/pkg/webhooks/trainingruntime_webhook_test.go
+++ b/pkg/webhooks/trainingruntime_webhook_test.go
@@ -56,8 +56,6 @@ func TestValidateReplicatedJobs(t *testing.T) {
 					"2", ""),
 				field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(1).Child("replicas"),
 					"2", ""),
-				field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(3).Child("replicas"),
-					"2", ""),
 			},
 		},
 		"missing required container in replicatedJobs": {
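
Worked example of the semantics above, as a minimal standalone sketch (not part of the patch): it only reproduces the per-slice arithmetic and the MinMember sum from the new test case, and the function name `perSliceHosts` is illustrative, not an identifier in the operator.

```go
package main

import "fmt"

// perSliceHosts mirrors the arithmetic added to Build(): numNodes is the
// total host count across all slices, replicas is the slice count, and the
// result is the per-slice Parallelism/Completions. Non-positive or
// non-divisible combinations are rejected, as in the patch.
func perSliceHosts(numNodes, replicas int32) (int32, error) {
	if replicas <= 0 {
		return 0, fmt.Errorf("replicas %d must be > 0", replicas)
	}
	if numNodes%replicas != 0 {
		return 0, fmt.Errorf("numNodes %d must be evenly divisible by replicas %d", numNodes, replicas)
	}
	return numNodes / replicas, nil
}

func main() {
	// The 4-slice x 8-host case from the new test: NumNodes=32, Replicas=4.
	hosts, err := perSliceHosts(32, 4)
	if err != nil {
		panic(err)
	}
	initializers := int32(2) // dataset initializer + model initializer pods
	fmt.Println("per-slice parallelism:", hosts)        // 8
	fmt.Println("PodGroup MinMember:", 32+initializers) // 34
}
```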