diff --git a/pkg/runtime/core/trainingruntime.go b/pkg/runtime/core/trainingruntime.go index 807d9847e6..185ac5fd69 100644 --- a/pkg/runtime/core/trainingruntime.go +++ b/pkg/runtime/core/trainingruntime.go @@ -166,13 +166,17 @@ func (r *TrainingRuntime) newRuntimeInfo( } for i, rJob := range jobSetSpecApply.ReplicatedJobs { - // TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs. + // TODO: Support multiple replicas for non-trainer replicatedJobs. // REF: https://github.com/kubeflow/trainer/issues/2318 count := ptr.Deref(rJob.Template.Spec.Parallelism, 1) var ancestor *string if metadata := rJob.Template.ObjectMetaApplyConfiguration; metadata != nil && metadata.Labels != nil { if labelAncestor, ok := metadata.Labels[constants.LabelTrainJobAncestor]; ok { if labelAncestor == constants.AncestorTrainer && mlPolicy != nil { + // For multi-slice TPU, numNodes represents total nodes across all slices. + // Per-slice Parallelism/Completions is computed in Build() by dividing + // the final count (set by EnforceMLPolicy) by trainer replicas. + // REF: https://github.com/kubeflow/trainer/issues/3407 count = ptr.Deref(mlPolicy.NumNodes, 1) // Apply resourcesPerNode from TrainJob to the template spec diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index e933ec844f..0e2c8b1756 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -116,6 +116,56 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), }, }, + "succeeded to build PodGroup and JobSet with multi-slice TPU trainer.": { + trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime"). + RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec). + WithMLPolicy( + testingutil.MakeMLPolicyWrapper(). + WithNumNodes(32). + Obj(), + ). + PodGroupPolicyCoschedulingSchedulingTimeout(120). + Replicas(4, constants.Node). + Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). + Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). + Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). + Obj(), + ).Obj(), + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + UID("uid"). + RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "test-runtime"). + Trainer( + testingutil.MakeTrainJobTrainerWrapper(). + NumNodes(32). + Obj(), + ). + Obj(), + wantObjs: []runtime.Object{ + testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). + ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid"). + Replicas(4, constants.Node). + Replicas(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher). + Parallelism(8, constants.Node). + Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher). + Completions(8, constants.Node). + Completions(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher). + NumNodes(8). + Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). + Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). + Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests). + PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job"). + Obj(), + testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job"). + ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid"). + MinMember(34). // 34 = 8 nodes × 4 slices + 1 DatasetInitializer + 1 ModelInitializer + MinResources(corev1.ResourceList{ + corev1.ResourceCPU: resource.MustParse("34"), + }). + SchedulingTimeout(120). + Obj(), + }, + }, "succeeded to build JobSet with NumNodes from the Runtime and container from the TrainJob.": { trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec( testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec). diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go index d5d8e1ca86..999848b1a8 100644 --- a/pkg/runtime/framework/plugins/jobset/builder.go +++ b/pkg/runtime/framework/plugins/jobset/builder.go @@ -118,9 +118,11 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build ancestor = jobMetadata.Labels[constants.LabelTrainJobAncestor] } if ancestor == constants.AncestorTrainer { - // TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs. - // REF: https://github.com/kubeflow/trainer/issues/2318 - b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1) + // Preserve replicas for multi-slice TPU support. + // REF: https://github.com/kubeflow/trainer/issues/3407 + if b.Spec.ReplicatedJobs[i].Replicas == nil { + b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1) + } // Update values for the Trainer container. for j, container := range rJob.Template.Spec.Template.Spec.Containers { if *container.Name == constants.Node { diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go index a629c6fcc8..ceafd9b80c 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -273,8 +273,27 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine for psIdx, ps := range info.TemplateSpec.PodSets { if ps.Count != nil { - jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Parallelism = ps.Count - jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Completions = ps.Count + rJob := &jobSetSpec.ReplicatedJobs[psIdx] + jobMetadata := rJob.Template.ObjectMetaApplyConfiguration + isTrainer := jobMetadata != nil && jobMetadata.Labels != nil && + jobMetadata.Labels[constants.LabelTrainJobAncestor] == constants.AncestorTrainer + if isTrainer { + // For multi-slice trainer: count = numNodes (total across all slices). + // Parallelism/Completions must be per-slice = numNodes / replicas. + replicas := ptr.Deref(rJob.Replicas, 1) + if replicas <= 0 { + return nil, fmt.Errorf("trainer replicatedJob %d has invalid replicas %d: must be > 0", psIdx, replicas) + } + if *ps.Count%replicas != 0 { + return nil, fmt.Errorf("trainer numNodes %d must be evenly divisible by replicas %d", *ps.Count, replicas) + } + perSlice := *ps.Count / replicas + rJob.Template.Spec.Parallelism = &perSlice + rJob.Template.Spec.Completions = &perSlice + } else { + rJob.Template.Spec.Parallelism = ps.Count + rJob.Template.Spec.Completions = ps.Count + } } apply.UpsertVolumes(&jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Volumes, ps.Volumes...) for containerIdx, container := range ps.Containers { diff --git a/pkg/webhooks/trainingruntime_webhook.go b/pkg/webhooks/trainingruntime_webhook.go index 14f1fc3594..9b6f14160c 100644 --- a/pkg/webhooks/trainingruntime_webhook.go +++ b/pkg/webhooks/trainingruntime_webhook.go @@ -76,7 +76,7 @@ func validateReplicatedJobs(rJobs []jobsetv1alpha2.ReplicatedJob) field.ErrorLis } if labelAncestor, ok := rJob.Template.Labels[constants.LabelTrainJobAncestor]; ok && ancestors.Has(labelAncestor) { - if rJob.Replicas != 1 { + if labelAncestor != constants.AncestorTrainer && rJob.Replicas != 1 { allErrs = append(allErrs, field.Invalid(rJobsPath.Index(idx).Child("replicas"), rJob.Replicas, rJobReplicasErrorMsg)) } diff --git a/pkg/webhooks/trainingruntime_webhook_test.go b/pkg/webhooks/trainingruntime_webhook_test.go index fb39f0600d..ca0b00d963 100644 --- a/pkg/webhooks/trainingruntime_webhook_test.go +++ b/pkg/webhooks/trainingruntime_webhook_test.go @@ -56,8 +56,6 @@ func TestValidateReplicatedJobs(t *testing.T) { "2", ""), field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(1).Child("replicas"), "2", ""), - field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(3).Child("replicas"), - "2", ""), }, }, "missing required container in replicatedJobs": {