-
Notifications
You must be signed in to change notification settings - Fork 948
fix(runtimes): propagate trainer environment variables to worker processes #3454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| /* | ||
| Copyright 2025 The Kubeflow Authors. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. | ||
| */ | ||
|
|
||
| package jobset | ||
|
|
||
| import ( | ||
| "testing" | ||
|
|
||
| corev1 "k8s.io/api/core/v1" | ||
| metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" | ||
| batchv1ac "k8s.io/client-go/applyconfigurations/batch/v1" | ||
| corev1ac "k8s.io/client-go/applyconfigurations/core/v1" | ||
| "k8s.io/utils/ptr" | ||
| jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" | ||
|
|
||
| trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" | ||
| "github.com/kubeflow/trainer/v2/pkg/constants" | ||
| "github.com/kubeflow/trainer/v2/pkg/runtime" | ||
| utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing" | ||
| ) | ||
|
|
||
| func TestBuilderTrainerEnvPropagation(t *testing.T) { | ||
| testCases := map[string]struct { | ||
| trainJobEnv []corev1.EnvVar | ||
| initialPodEnv []corev1ac.EnvVarApplyConfiguration | ||
| expectedEnv []corev1.EnvVar | ||
| }{ | ||
| "variables propagated": { | ||
| trainJobEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}}, | ||
| expectedEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}}, | ||
| }, | ||
| "no variables propagated (empty case)": { | ||
| trainJobEnv: []corev1.EnvVar{}, | ||
| expectedEnv: []corev1.EnvVar{}, | ||
| }, | ||
| "merge with existing variables": { | ||
| trainJobEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}}, | ||
| initialPodEnv: []corev1ac.EnvVarApplyConfiguration{ | ||
| *corev1ac.EnvVar().WithName("EXISTING_VAR").WithValue("existing_value"), | ||
| }, | ||
| expectedEnv: []corev1.EnvVar{ | ||
| {Name: "EXISTING_VAR", Value: "existing_value"}, | ||
| {Name: "CUSTOM_VAR", Value: "custom_value"}, | ||
| }, | ||
| }, | ||
| } | ||
|
|
||
| for name, tc := range testCases { | ||
| t.Run(name, func(t *testing.T) { | ||
| // Build the TrainJob whose trainer carries this case's env vars (tc.trainJobEnv). | ||
| trainJob := utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). | ||
| Trainer(utiltesting.MakeTrainJobTrainerWrapper(). | ||
| Env(tc.trainJobEnv...). | ||
| Obj()). | ||
| Obj() | ||
|
|
||
| // Setup runtime info with an OpenMPI policy whose last argument is false (launcher is NOT a node). | ||
| info := &runtime.Info{ | ||
| RuntimePolicy: runtime.RuntimePolicy{ | ||
| MLPolicySource: utiltesting.MakeMLPolicySourceWrapper(). | ||
| MPIPolicy(nil, trainer.MPIImplementationOpenMPI, nil, ptr.To(false)). | ||
| Obj(), | ||
| }, | ||
| } | ||
|
|
||
| // Seed the constants.Node container with tc.initialPodEnv, then wrap it in a JobSet spec with one replicated job named constants.Node. | ||
| container := corev1ac.Container().WithName(constants.Node) | ||
| for i := range tc.initialPodEnv { | ||
| container.WithEnv(&tc.initialPodEnv[i]) | ||
| } | ||
|
|
||
| jobSetSpec := jobsetv1alpha2ac.JobSetSpec().WithReplicatedJobs( | ||
| jobsetv1alpha2ac.ReplicatedJob(). | ||
| WithName(constants.Node). | ||
| WithTemplate(batchv1ac.JobTemplateSpec(). | ||
| WithSpec(batchv1ac.JobSpec(). | ||
| WithTemplate(corev1ac.PodTemplateSpec(). | ||
| WithSpec(corev1ac.PodSpec().WithContainers(container)), | ||
| ), | ||
| ), | ||
| ), | ||
| ) | ||
|
|
||
| builder := NewBuilder(jobsetv1alpha2ac.JobSet("test-job", metav1.NamespaceDefault).WithSpec(jobSetSpec)) | ||
| builder.Trainer(info, trainJob) | ||
|
|
||
| // Verify results: collect the env of the constants.Node container in the constants.Node replicated job, then compare it with tc.expectedEnv by count and by position. | ||
| var actualEnv []corev1.EnvVar | ||
| for _, rJob := range builder.Spec.ReplicatedJobs { | ||
| if *rJob.Name == constants.Node { | ||
| for _, c := range rJob.Template.Spec.Template.Spec.Containers { | ||
| if *c.Name == constants.Node { | ||
| for _, env := range c.Env { | ||
| actualEnv = append(actualEnv, corev1.EnvVar{Name: *env.Name, Value: *env.Value}) | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| if len(actualEnv) != len(tc.expectedEnv) { | ||
| t.Fatalf("Expected %d environment variables, got %d", len(tc.expectedEnv), len(actualEnv)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The existing tests in |
||
| } | ||
|
|
||
| for i, expected := range tc.expectedEnv { | ||
| if actualEnv[i].Name != expected.Name || actualEnv[i].Value != expected.Value { | ||
| t.Errorf("At index %d: expected %s=%s, got %s=%s", i, expected.Name, expected.Value, actualEnv[i].Name, actualEnv[i].Value) | ||
| } | ||
| } | ||
| }) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -190,6 +190,19 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er | |
| ) | ||
| switch *info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation { | ||
| case trainer.MPIImplementationOpenMPI: | ||
| // Collect all custom environment variable names from the TrainJob to export via SSH. | ||
| var envNames []string | ||
| if trainJob.Spec.Trainer != nil { | ||
| for _, env := range trainJob.Spec.Trainer.Env { | ||
|
AviralKaushal marked this conversation as resolved.
|
||
| // Only include variables with a static Value. | ||
| // Variables with ValueFrom (e.g. FieldRef, ResourceFieldRef) are pod-specific | ||
| // and should not be propagated as cluster-wide constants. | ||
| if env.ValueFrom == nil { | ||
| envNames = append(envNames, env.Name) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| apply.UpsertEnvVars( | ||
| &info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env, | ||
| *corev1ac.EnvVar(). | ||
|
|
@@ -205,6 +218,23 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er | |
| WithName(constants.OpenMPIEnvKeyRSHArgs). | ||
| WithValue(constants.OpenMPIEnvDefaultValueRSHArgs), | ||
| ) | ||
|
|
||
| // Automatically tell OpenMPI to export the custom variables to all nodes. | ||
| if len(envNames) > 0 { | ||
| envList := "" | ||
| for i, name := range envNames { | ||
| if i > 0 { | ||
| envList += ";" | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: you could use |
||
| } | ||
| envList += name | ||
| } | ||
| apply.UpsertEnvVars( | ||
| &info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env, | ||
| *corev1ac.EnvVar(). | ||
| WithName(constants.OpenMPIEnvBaseEnvList). | ||
| WithValue(envList), | ||
| ) | ||
|
AviralKaushal marked this conversation as resolved.
|
||
| } | ||
| default: | ||
| return fmt.Errorf("MPI implementation for %v doesn't supported", info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation) | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This change is broader than just MPI. The old condition was gated on
`isRunLauncherAsNode`, so it only kicked in for MPI configs. Now any replicated job named `node` (PyTorch, DeepSpeed, etc.) will get env vars and `resourcesPerNode` injected unconditionally. Is that the intent here? If so, it might be worth noting that in the PR description, since it changes behavior for all runtimes, not just MPI.