diff --git a/pkg/runtime/framework/plugins/jobset/jobset.go b/pkg/runtime/framework/plugins/jobset/jobset.go index a629c6fcc8..b433dbb731 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset.go +++ b/pkg/runtime/framework/plugins/jobset/jobset.go @@ -322,6 +322,29 @@ func (j *JobSet) Build(ctx context.Context, info *runtime.Info, trainJob *traine WithController(true). WithBlockOwnerDeletion(true)) + // Container images in spec.replicatedJobs are immutable in JobSet. + // When resuming a suspended JobSet, omit images from the apply config so SSA + // does not attempt to update them. The job resumes with the image it was created with. + if oldJobSet != nil && + ptr.Deref(oldJobSet.Spec.Suspend, false) && + !ptr.Deref(trainJob.Spec.Suspend, false) { + for i := range jobSet.Spec.ReplicatedJobs { + if jobSet.Spec.ReplicatedJobs[i].Template == nil || + jobSet.Spec.ReplicatedJobs[i].Template.Spec == nil || + jobSet.Spec.ReplicatedJobs[i].Template.Spec.Template == nil || + jobSet.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec == nil { + continue + } + podSpec := jobSet.Spec.ReplicatedJobs[i].Template.Spec.Template.Spec + for j := range podSpec.Containers { + podSpec.Containers[j].Image = nil + } + for j := range podSpec.InitContainers { + podSpec.InitContainers[j].Image = nil + } + } + } + return []apiruntime.ApplyConfiguration{jobSet}, nil } diff --git a/pkg/runtime/framework/plugins/jobset/jobset_test.go b/pkg/runtime/framework/plugins/jobset/jobset_test.go index 2e8e5d3968..91336d6915 100644 --- a/pkg/runtime/framework/plugins/jobset/jobset_test.go +++ b/pkg/runtime/framework/plugins/jobset/jobset_test.go @@ -1586,3 +1586,153 @@ func TestValidate(t *testing.T) { }) } } + +// TestBuild_ResumeOmitsContainerImages verifies that when a suspended JobSet is +// resumed, container images are omitted from the SSA apply config so that the +// immutable spec.replicatedJobs field is not updated. The job resumes with the +// image it was originally created with. 
+func TestBuild_ResumeOmitsContainerImages(t *testing.T) {
+	const upgradedImage = "registry.example.com/trainer:v2"
+
+	// existingJobSet returns a fresh JobSet fixture that shares the TrainJob's
+	// name and namespace, so Build finds it when it looks up the old object.
+	existingJobSet := func(suspend bool) *jobsetv1alpha2.JobSet {
+		return &jobsetv1alpha2.JobSet{
+			ObjectMeta: metav1.ObjectMeta{Name: "trainjob", Namespace: metav1.NamespaceDefault},
+			Spec:       jobsetv1alpha2.JobSetSpec{Suspend: ptr.To(suspend)},
+		}
+	}
+
+	// assertImages verifies the Image field of every container in the slice and
+	// reports how many containers it inspected. kind is only used in failure
+	// messages ("container" or "initContainer").
+	assertImages := func(t *testing.T, kind string, containers []corev1ac.ContainerApplyConfiguration, wantNil bool) int {
+		t.Helper()
+		for i := range containers {
+			c := &containers[i]
+			// Name is a pointer in apply configurations; never dereference it blindly
+			// while building a failure message.
+			name := ptr.Deref(c.Name, "<unnamed>")
+			switch {
+			case wantNil && c.Image != nil:
+				t.Errorf("Expected nil image for %s %s on resume, got %q", kind, name, *c.Image)
+			case !wantNil && c.Image == nil:
+				t.Errorf("Expected non-nil image for %s %s, got nil", kind, name)
+			}
+		}
+		return len(containers)
+	}
+
+	cases := map[string]struct {
+		// existingJobSet is pre-loaded into the fake client; nil means "not created yet".
+		existingJobSet *jobsetv1alpha2.JobSet
+		// trainJobSuspend is the desired spec.suspend of the TrainJob under test.
+		trainJobSuspend bool
+		// wantNilImages expects every container image to be omitted from the apply config.
+		wantNilImages bool
+		// wantNilResult expects Build to return no objects at all.
+		wantNilResult bool
+	}{
+		"suspended JobSet being resumed: images must be nil in apply config": {
+			existingJobSet:  existingJobSet(true),
+			trainJobSuspend: false,
+			wantNilImages:   true,
+		},
+		"no existing JobSet: images present in apply config (creation path)": {
+			existingJobSet:  nil,
+			trainJobSuspend: false,
+			wantNilImages:   false,
+		},
+		"both running: existing guard fires, nil result": {
+			existingJobSet:  existingJobSet(false),
+			trainJobSuspend: false,
+			wantNilResult:   true,
+		},
+		"TrainJob suspended: images present (not a resume)": {
+			existingJobSet:  existingJobSet(true),
+			trainJobSuspend: true,
+			wantNilImages:   false,
+		},
+	}
+
+	for name, tc := range cases {
+		t.Run(name, func(t *testing.T) {
+			_, ctx := ktesting.NewTestContext(t)
+			clientBuilder := utiltesting.NewClientBuilder()
+			// Guard on the typed pointer: passing a nil *JobSet through the
+			// object interface would hand the builder a non-nil interface value.
+			if tc.existingJobSet != nil {
+				clientBuilder = clientBuilder.WithObjects(tc.existingJobSet)
+			}
+			cli := clientBuilder.Build()
+
+			p, err := New(ctx, cli, nil, nil)
+			if err != nil {
+				t.Fatalf("Failed to initialize JobSet plugin: %v", err)
+			}
+
+			trainJob := utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "trainjob").
+				Suspend(tc.trainJobSuspend).
+				RuntimeRef(trainer.GroupVersion.WithKind(trainer.ClusterTrainingRuntimeKind), "runtime").
+				Trainer(utiltesting.MakeTrainJobTrainerWrapper().
+					Container(upgradedImage, nil, nil, corev1.ResourceList{}).
+					Obj()).
+				Obj()
+
+			info := &runtime.Info{
+				Labels:      map[string]string{},
+				Annotations: map[string]string{},
+				Scheduler:   &runtime.Scheduler{},
+				TemplateSpec: runtime.TemplateSpec{
+					PodSets: []runtime.PodSet{
+						{
+							Name:       constants.Node,
+							Containers: []runtime.Container{{Name: constants.Node, Image: upgradedImage}},
+						},
+					},
+					ObjApply: jobsetv1alpha2ac.JobSetSpec().
+						WithReplicatedJobs(
+							jobsetv1alpha2ac.ReplicatedJob().
+								WithName(constants.Node).
+								WithTemplate(batchv1ac.JobTemplateSpec().
+									WithSpec(batchv1ac.JobSpec().
+										WithTemplate(corev1ac.PodTemplateSpec().
+											WithSpec(corev1ac.PodSpec().
+												WithContainers(
+													corev1ac.Container().
+														WithName(constants.Node).
+														WithImage(upgradedImage),
+												),
+											),
+										),
+									),
+								),
+						),
+				},
+			}
+
+			objs, err := p.(framework.ComponentBuilderPlugin).Build(ctx, info, trainJob)
+			if err != nil {
+				t.Fatalf("Unexpected error from Build: %v", err)
+			}
+
+			if tc.wantNilResult {
+				if objs != nil {
+					t.Errorf("Expected nil result (running guard), got %v", objs)
+				}
+				return
+			}
+			// len() of a nil slice is 0, so a separate nil comparison is redundant.
+			if len(objs) == 0 {
+				t.Fatal("Expected non-nil apply config, got nil")
+			}
+
+			jobSetApply, ok := objs[0].(*jobsetv1alpha2ac.JobSetApplyConfiguration)
+			if !ok {
+				t.Fatalf("Expected *JobSetApplyConfiguration, got %T", objs[0])
+			}
+
+			inspected := 0
+			for _, rJob := range jobSetApply.Spec.ReplicatedJobs {
+				if rJob.Template == nil || rJob.Template.Spec == nil ||
+					rJob.Template.Spec.Template == nil || rJob.Template.Spec.Template.Spec == nil {
+					continue
+				}
+				podSpec := rJob.Template.Spec.Template.Spec
+				inspected += assertImages(t, "container", podSpec.Containers, tc.wantNilImages)
+				inspected += assertImages(t, "initContainer", podSpec.InitContainers, tc.wantNilImages)
+			}
+			// The fixture declares one container, so zero inspected containers means
+			// the assertions above never ran and the case would pass vacuously.
+			if inspected == 0 {
+				t.Fatal("No containers were inspected in the apply config; image assertions did not run")
+			}
+		})
+	}
+}