diff --git a/pkg/runtime/framework/plugins/flux/flux_test.go b/pkg/runtime/framework/plugins/flux/flux_test.go index 02aec04d30..668476e634 100644 --- a/pkg/runtime/framework/plugins/flux/flux_test.go +++ b/pkg/runtime/framework/plugins/flux/flux_test.go @@ -58,14 +58,41 @@ func TestFlux(t *testing.T) { wantInitContainers []string wantCommand []string wantTTY bool + wantInfo *runtime.Info + wantMLPolicyError error + wantBuildError error }{ "no action when flux policy is nil": { info: &runtime.Info{ RuntimePolicy: runtime.RuntimePolicy{}, }, trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(), + wantInfo: &runtime.Info{ + RuntimePolicy: runtime.RuntimePolicy{}, + }, + wantMLPolicyError: nil, + wantBuildError: nil, }, "flux mutations are applied correctly": { + + wantInfo: &runtime.Info{ + RuntimePolicy: runtime.RuntimePolicy{ + MLPolicySource: &trainer.MLPolicySource{ + Flux: &trainer.FluxMLPolicySource{ + NumProcPerNode: &procs, + }, + }, + }, + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: constants.Node, + Ancestor: ptr.To(constants.AncestorTrainer), + Count: ptr.To[int32](1), + }, + }, + }, + }, info: &runtime.Info{ RuntimePolicy: runtime.RuntimePolicy{ MLPolicySource: &trainer.MLPolicySource{ @@ -84,6 +111,8 @@ func TestFlux(t *testing.T) { }, }, }, + wantMLPolicyError: nil, + wantBuildError: nil, trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("test-uid"). Trainer(utiltesting.MakeTrainJobTrainerWrapper().NumNodes(2).Obj()). @@ -109,8 +138,15 @@ func TestFlux(t *testing.T) { p, _ := New(ctx, cli, nil, nil) err := p.(framework.EnforceMLPolicyPlugin).EnforceMLPolicy(tc.info, tc.trainJob) - if err != nil { - t.Fatalf("EnforceMLPolicy failed: %v", err) + if diff := gocmp.Diff(tc.wantMLPolicyError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from EnforceMLPolicy (-want, +got): %s", diff) + } + if diff := gocmp.Diff(tc.wantInfo, tc.info, + cmpopts.SortSlices(func(a, b string) bool { return a < b }), + cmpopts.SortMaps(func(a, b int) bool { return a < b }), + utiltesting.PodSetEndpointsCmpOpts, + ); len(diff) != 0 { + t.Errorf("Unexpected info from EnforceMLPolicy (-want, +got): %s", diff) } if tc.info.RuntimePolicy.MLPolicySource != nil && tc.info.RuntimePolicy.MLPolicySource.Flux != nil && tc.info.TemplateSpec.ObjApply != nil { @@ -140,8 +176,8 @@ func TestFlux(t *testing.T) { } objs, err := p.(framework.ComponentBuilderPlugin).Build(ctx, tc.info, tc.trainJob) - if err != nil { - t.Fatalf("Build failed: %v", err) + if diff := gocmp.Diff(tc.wantBuildError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from Build (-want, +got): %s", diff) } typedObjs, _ := utiltesting.ToObject(cli.Scheme(), objs...)