Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions pkg/runtime/framework/plugins/flux/flux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,41 @@ func TestFlux(t *testing.T) {
wantInitContainers []string
wantCommand []string
wantTTY bool
wantInfo *runtime.Info
wantMLPolicyError error
wantBuildError error
}{
"no action when flux policy is nil": {
info: &runtime.Info{
RuntimePolicy: runtime.RuntimePolicy{},
},
trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(),
wantInfo: &runtime.Info{
RuntimePolicy: runtime.RuntimePolicy{},
},
wantMLPolicyError: nil,
wantBuildError: nil,
},
"flux mutations are applied correctly": {

wantInfo: &runtime.Info{
RuntimePolicy: runtime.RuntimePolicy{
MLPolicySource: &trainer.MLPolicySource{
Flux: &trainer.FluxMLPolicySource{
NumProcPerNode: &procs,
},
},
},
TemplateSpec: runtime.TemplateSpec{
PodSets: []runtime.PodSet{
{
Name: constants.Node,
Ancestor: ptr.To(constants.AncestorTrainer),
Count: ptr.To[int32](1),
},
},
},
},
Comment on lines 76 to +95
Copy link

Copilot AI Apr 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second test case is missing explicit wantMLPolicyError and wantBuildError field assignments for consistency with the first test case. Add wantMLPolicyError: nil and wantBuildError: nil to clarify the expected behavior.

Copilot generated this review using guidance from repository custom instructions.
info: &runtime.Info{
RuntimePolicy: runtime.RuntimePolicy{
MLPolicySource: &trainer.MLPolicySource{
Expand All @@ -84,6 +111,8 @@ func TestFlux(t *testing.T) {
},
},
},
wantMLPolicyError: nil,
wantBuildError: nil,
trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("test-uid").
Trainer(utiltesting.MakeTrainJobTrainerWrapper().NumNodes(2).Obj()).
Expand All @@ -109,8 +138,15 @@ func TestFlux(t *testing.T) {
p, _ := New(ctx, cli, nil, nil)

err := p.(framework.EnforceMLPolicyPlugin).EnforceMLPolicy(tc.info, tc.trainJob)
if err != nil {
t.Fatalf("EnforceMLPolicy failed: %v", err)
if diff := gocmp.Diff(tc.wantMLPolicyError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error from EnforceMLPolicy (-want, +got): %s", diff)
}
if diff := gocmp.Diff(tc.wantInfo, tc.info,
cmpopts.SortSlices(func(a, b string) bool { return a < b }),
cmpopts.SortMaps(func(a, b int) bool { return a < b }),
utiltesting.PodSetEndpointsCmpOpts,
); len(diff) != 0 {
t.Errorf("Unexpected info from EnforceMLPolicy (-want, +got): %s", diff)
}

if tc.info.RuntimePolicy.MLPolicySource != nil && tc.info.RuntimePolicy.MLPolicySource.Flux != nil && tc.info.TemplateSpec.ObjApply != nil {
Expand Down Expand Up @@ -140,8 +176,8 @@ func TestFlux(t *testing.T) {
}

objs, err := p.(framework.ComponentBuilderPlugin).Build(ctx, tc.info, tc.trainJob)
if err != nil {
t.Fatalf("Build failed: %v", err)
if diff := gocmp.Diff(tc.wantBuildError, err, cmpopts.EquateErrors()); len(diff) != 0 {
t.Errorf("Unexpected error from Build (-want, +got): %s", diff)
}

typedObjs, _ := utiltesting.ToObject(cli.Scheme(), objs...)
Expand Down
Loading