Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pkg/runtime/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,17 @@ func (r *TrainingRuntime) newRuntimeInfo(
}

for i, rJob := range jobSetSpecApply.ReplicatedJobs {
// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
// TODO: Support multiple replicas for non-trainer replicatedJobs.
// REF: https://github.com/kubeflow/trainer/issues/2318
count := ptr.Deref(rJob.Template.Spec.Parallelism, 1)
var ancestor *string
if metadata := rJob.Template.ObjectMetaApplyConfiguration; metadata != nil && metadata.Labels != nil {
if labelAncestor, ok := metadata.Labels[constants.LabelTrainJobAncestor]; ok {
if labelAncestor == constants.AncestorTrainer && mlPolicy != nil {
// For multi-slice TPU, numNodes represents total nodes across all slices.
// Per-slice Parallelism/Completions is computed in Build() by dividing
// the final count (set by EnforceMLPolicy) by trainer replicas.
// REF: https://github.com/kubeflow/trainer/issues/3407
count = ptr.Deref(mlPolicy.NumNodes, 1)

// Apply resourcesPerNode from TrainJob to the template spec
Expand Down
50 changes: 50 additions & 0 deletions pkg/runtime/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,56 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
},
},
"succeeded to build PodGroup and JobSet with multi-slice TPU trainer.": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").
RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
WithMLPolicy(
testingutil.MakeMLPolicyWrapper().
WithNumNodes(32).
Obj(),
).
PodGroupPolicyCoschedulingSchedulingTimeout(120).
Replicas(4, constants.Node).
Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Obj(),
).Obj(),
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
RuntimeRef(trainer.SchemeGroupVersion.WithKind(trainer.TrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper().
NumNodes(32).
Obj(),
).
Obj(),
wantObjs: []runtime.Object{
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
Replicas(4, constants.Node).
Replicas(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
Parallelism(8, constants.Node).
Parallelism(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
Completions(8, constants.Node).
Completions(1, constants.DatasetInitializer, constants.ModelInitializer, constants.Launcher).
NumNodes(8).
Container(constants.DatasetInitializer, constants.DatasetInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Container(constants.ModelInitializer, constants.ModelInitializer, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
Container(constants.Node, constants.Node, "test:runtime", []string{"runtime"}, []string{"runtime"}, resRequests).
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
Obj(),
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid").
MinMember(34). // 34 = 8 nodes × 4 slices + 1 DatasetInitializer + 1 ModelInitializer
MinResources(corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("34"),
}).
SchedulingTimeout(120).
Obj(),
},
},
"succeeded to build JobSet with NumNodes from the Runtime and container from the TrainJob.": {
trainingRuntime: testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").Spec).
Expand Down
8 changes: 5 additions & 3 deletions pkg/runtime/framework/plugins/jobset/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,11 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build
ancestor = jobMetadata.Labels[constants.LabelTrainJobAncestor]
}
if ancestor == constants.AncestorTrainer {
// TODO: Support multiple replicas ('.template.spec.replicatedJobs[*].replicas') for replicated Jobs.
// REF: https://github.com/kubeflow/trainer/issues/2318
b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
// Keep the replicas already set on the trainer replicatedJob (one replica per slice for multi-slice TPU); default to 1 only when replicas is unset.
// REF: https://github.com/kubeflow/trainer/issues/3407
if b.Spec.ReplicatedJobs[i].Replicas == nil {
b.Spec.ReplicatedJobs[i].Replicas = ptr.To[int32](1)
}
// Update values for the Trainer container.
for j, container := range rJob.Template.Spec.Template.Spec.Containers {
if *container.Name == constants.Node {
Expand Down
23 changes: 21 additions & 2 deletions pkg/runtime/framework/plugins/jobset/jobset.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,27 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine

for psIdx, ps := range info.TemplateSpec.PodSets {
if ps.Count != nil {
jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Parallelism = ps.Count
jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Completions = ps.Count
rJob := &jobSetSpec.ReplicatedJobs[psIdx]
jobMetadata := rJob.Template.ObjectMetaApplyConfiguration
isTrainer := jobMetadata != nil && jobMetadata.Labels != nil &&
jobMetadata.Labels[constants.LabelTrainJobAncestor] == constants.AncestorTrainer
if isTrainer {
// For multi-slice trainer: count = numNodes (total across all slices).
// Parallelism/Completions must be per-slice = numNodes / replicas.
replicas := ptr.Deref(rJob.Replicas, 1)
Comment thread
krishdef7 marked this conversation as resolved.
if replicas <= 0 {
return nil, fmt.Errorf("trainer replicatedJob %d has invalid replicas %d: must be > 0", psIdx, replicas)
}
if *ps.Count%replicas != 0 {
return nil, fmt.Errorf("trainer numNodes %d must be evenly divisible by replicas %d", *ps.Count, replicas)
}
perSlice := *ps.Count / replicas
rJob.Template.Spec.Parallelism = &perSlice
rJob.Template.Spec.Completions = &perSlice
} else {
rJob.Template.Spec.Parallelism = ps.Count
rJob.Template.Spec.Completions = ps.Count
}
}
apply.UpsertVolumes(&jobSetSpec.ReplicatedJobs[psIdx].Template.Spec.Template.Spec.Volumes, ps.Volumes...)
for containerIdx, container := range ps.Containers {
Expand Down
2 changes: 1 addition & 1 deletion pkg/webhooks/trainingruntime_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func validateReplicatedJobs(rJobs []jobsetv1alpha2.ReplicatedJob) field.ErrorLis
}

if labelAncestor, ok := rJob.Template.Labels[constants.LabelTrainJobAncestor]; ok && ancestors.Has(labelAncestor) {
if rJob.Replicas != 1 {
if labelAncestor != constants.AncestorTrainer && rJob.Replicas != 1 {
allErrs = append(allErrs, field.Invalid(rJobsPath.Index(idx).Child("replicas"), rJob.Replicas, rJobReplicasErrorMsg))
}

Expand Down
2 changes: 0 additions & 2 deletions pkg/webhooks/trainingruntime_webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ func TestValidateReplicatedJobs(t *testing.T) {
"2", ""),
field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(1).Child("replicas"),
"2", ""),
field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(3).Child("replicas"),
"2", ""),
},
},
"missing required container in replicatedJobs": {
Expand Down
Loading