diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 9daf61dc0d..fab56aeffa 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -176,6 +176,9 @@ const ( // OpenMPIEnvDefaultSlots is the OpenMPI default number of slots env key. OpenMPIEnvDefaultSlots string = "OMPI_MCA_orte_set_default_slots" + + // OpenMPIEnvBaseEnvList is the OpenMPI base environment list env key. + OpenMPIEnvBaseEnvList string = "OMPI_MCA_mca_base_env_list" // Distributed envs for torchrun. // Ref: https://github.com/pytorch/pytorch/blob/3a0d0885171376ed610c8175a19ba40411fc6f3f/torch/distributed/argparse_util.py#L45 // TorchEnvNumNodes is the env name for the number of training nodes. diff --git a/pkg/runtime/framework/plugins/jobset/builder.go b/pkg/runtime/framework/plugins/jobset/builder.go index d5d8e1ca86..17be413f39 100644 --- a/pkg/runtime/framework/plugins/jobset/builder.go +++ b/pkg/runtime/framework/plugins/jobset/builder.go @@ -139,7 +139,7 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build } } } - if ancestor == constants.AncestorTrainer || b.isRunLauncherAsNode(info) && *rJob.Name == constants.Node { + if ancestor == constants.AncestorTrainer || *rJob.Name == constants.Node { // TODO (andreyvelich): For MPI we should apply container resources to the Node ReplicatedJob also. // Eventually, we should find better way to propagate resources from TrainJob to JobSet. for j, container := range rJob.Template.Spec.Template.Spec.Containers { diff --git a/pkg/runtime/framework/plugins/jobset/builder_test.go b/pkg/runtime/framework/plugins/jobset/builder_test.go new file mode 100644 index 0000000000..659175afaf --- /dev/null +++ b/pkg/runtime/framework/plugins/jobset/builder_test.go @@ -0,0 +1,125 @@ +/* +Copyright 2025 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package jobset + +import ( + "testing" + + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + batchv1ac "k8s.io/client-go/applyconfigurations/batch/v1" + corev1ac "k8s.io/client-go/applyconfigurations/core/v1" + "k8s.io/utils/ptr" + jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2" + + trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1" + "github.com/kubeflow/trainer/v2/pkg/constants" + "github.com/kubeflow/trainer/v2/pkg/runtime" + utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing" +) + +func TestBuilderTrainerEnvPropagation(t *testing.T) { + testCases := map[string]struct { + trainJobEnv []corev1.EnvVar + initialPodEnv []corev1ac.EnvVarApplyConfiguration + expectedEnv []corev1.EnvVar + }{ + "variables propagated": { + trainJobEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}}, + expectedEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}}, + }, + "no variables propagated (empty case)": { + trainJobEnv: []corev1.EnvVar{}, + expectedEnv: []corev1.EnvVar{}, + }, + "merge with existing variables": { + trainJobEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}}, + initialPodEnv: []corev1ac.EnvVarApplyConfiguration{ + *corev1ac.EnvVar().WithName("EXISTING_VAR").WithValue("existing_value"), + }, + expectedEnv: []corev1.EnvVar{ + {Name: "EXISTING_VAR", Value: "existing_value"}, + {Name: "CUSTOM_VAR", Value: "custom_value"}, + }, + }, + } + + for name, tc := range testCases { + t.Run(name, func(t *testing.T) { + // Setup TrainJob + trainJob := utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper(). + Env(tc.trainJobEnv...). + Obj()). + Obj() + + // Setup runtime info for MPI (launcher is NOT a node) + info := &runtime.Info{ + RuntimePolicy: runtime.RuntimePolicy{ + MLPolicySource: utiltesting.MakeMLPolicySourceWrapper(). + MPIPolicy(nil, trainer.MPIImplementationOpenMPI, nil, ptr.To(false)). + Obj(), + }, + } + + // Create JobSet spec with initial environment + container := corev1ac.Container().WithName(constants.Node) + for i := range tc.initialPodEnv { + container.WithEnv(&tc.initialPodEnv[i]) + } + + jobSetSpec := jobsetv1alpha2ac.JobSetSpec().WithReplicatedJobs( + jobsetv1alpha2ac.ReplicatedJob(). + WithName(constants.Node). + WithTemplate(batchv1ac.JobTemplateSpec(). + WithSpec(batchv1ac.JobSpec(). + WithTemplate(corev1ac.PodTemplateSpec(). + WithSpec(corev1ac.PodSpec().WithContainers(container)), + ), + ), + ), + ) + + builder := NewBuilder(jobsetv1alpha2ac.JobSet("test-job", metav1.NamespaceDefault).WithSpec(jobSetSpec)) + builder.Trainer(info, trainJob) + + // Verify results + var actualEnv []corev1.EnvVar + for _, rJob := range builder.Spec.ReplicatedJobs { + if *rJob.Name == constants.Node { + for _, c := range rJob.Template.Spec.Template.Spec.Containers { + if *c.Name == constants.Node { + for _, env := range c.Env { + actualEnv = append(actualEnv, corev1.EnvVar{Name: *env.Name, Value: *env.Value}) + } + } + } + } + } + + if len(actualEnv) != len(tc.expectedEnv) { + t.Fatalf("Expected %d environment variables, got %d", len(tc.expectedEnv), len(actualEnv)) + } + + for i, expected := range tc.expectedEnv { + if actualEnv[i].Name != expected.Name || actualEnv[i].Value != expected.Value { + t.Errorf("At index %d: expected %s=%s, got %s=%s", i, expected.Name, expected.Value, actualEnv[i].Name, actualEnv[i].Value) + } + } + }) + } +} diff --git a/pkg/runtime/framework/plugins/mpi/mpi.go b/pkg/runtime/framework/plugins/mpi/mpi.go index ff36bab316..41df8937bf 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi.go +++ b/pkg/runtime/framework/plugins/mpi/mpi.go @@ -190,6 +190,19 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er ) switch *info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation { case trainer.MPIImplementationOpenMPI: + // Collect all custom environment variable names from the TrainJob to export via SSH. + var envNames []string + if trainJob.Spec.Trainer != nil { + for _, env := range trainJob.Spec.Trainer.Env { + // Only include variables with a static Value. + // Variables with ValueFrom (e.g. FieldRef, ResourceFieldRef) are pod-specific + // and should not be propagated as cluster-wide constants. + if env.ValueFrom == nil { + envNames = append(envNames, env.Name) + } + } + } + apply.UpsertEnvVars( &info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env, *corev1ac.EnvVar(). @@ -205,6 +218,23 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er WithName(constants.OpenMPIEnvKeyRSHArgs). WithValue(constants.OpenMPIEnvDefaultValueRSHArgs), ) + + // Automatically tell OpenMPI to export the custom variables to all nodes. + if len(envNames) > 0 { + envList := "" + for i, name := range envNames { + if i > 0 { + envList += ";" + } + envList += name + } + apply.UpsertEnvVars( + &info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env, + *corev1ac.EnvVar(). + WithName(constants.OpenMPIEnvBaseEnvList). + WithValue(envList), + ) + } default: return fmt.Errorf("MPI implementation for %v doesn't supported", info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation) } diff --git a/pkg/runtime/framework/plugins/mpi/mpi_test.go b/pkg/runtime/framework/plugins/mpi/mpi_test.go index db5628a5b8..89f7366645 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi_test.go +++ b/pkg/runtime/framework/plugins/mpi/mpi_test.go @@ -829,6 +829,178 @@ trainJob-node-1-0.trainJob slots=1 }, wantBuildError: errorGetSSHAuthSecretFromAPI, }, + "environment variables propagation and filtering for OpenMPI": { + trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "trainJob"). + UID("trainJob"). + Trainer(utiltesting.MakeTrainJobTrainerWrapper(). + Env( + corev1.EnvVar{Name: "STATIC_VAR", Value: "static_value"}, + corev1.EnvVar{Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.name"}}}, + ). + Obj()). + Obj(), + info: &runtime.Info{ + RuntimePolicy: runtime.RuntimePolicy{ + MLPolicySource: utiltesting.MakeMLPolicySourceWrapper(). + MPIPolicy(ptr.To[int32](1), trainer.MPIImplementationOpenMPI, ptr.To("/root/.ssh"), nil). + Obj(), + }, + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: constants.Launcher, + Count: ptr.To[int32](1), + Endpoints: func(yield func(string) bool) { + yield("trainJob-launcher-0-0.trainJob") + }, + Containers: []runtime.Container{{ + Name: constants.Node, + }}, + }, + { + Name: constants.Node, + Count: ptr.To[int32](1), + Endpoints: func(yield func(string) bool) { + yield("trainJob-node-0-0.trainJob") + }, + Containers: []runtime.Container{{ + Name: constants.Node, + }}, + }, + }, + }, + Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)}, + }, + wantInfo: &runtime.Info{ + Labels: nil, + Annotations: nil, + RuntimePolicy: runtime.RuntimePolicy{ + MLPolicySource: utiltesting.MakeMLPolicySourceWrapper(). + MPIPolicy(ptr.To[int32](1), trainer.MPIImplementationOpenMPI, ptr.To("/root/.ssh"), nil). + Obj(), + }, + TemplateSpec: runtime.TemplateSpec{ + PodSets: []runtime.PodSet{ + { + Name: constants.Launcher, + Count: ptr.To[int32](1), + Containers: []runtime.Container{{ + Name: constants.Node, + Env: []corev1ac.EnvVarApplyConfiguration{ + *corev1ac.EnvVar(). + WithName(constants.OpenMPIEnvHostFileLocation). + WithValue(fmt.Sprintf("%s/%s", constants.MPIHostfileDir, constants.MPIHostfileName)), + *corev1ac.EnvVar(). + WithName(constants.OpenMPIEnvKeepFQDNHostNames). + WithValue("true"), + *corev1ac.EnvVar(). + WithName(constants.OpenMPIEnvDefaultSlots). + WithValue("1"), + *corev1ac.EnvVar(). + WithName(constants.OpenMPIEnvKeyRSHArgs). + WithValue(constants.OpenMPIEnvDefaultValueRSHArgs), + *corev1ac.EnvVar(). + WithName(constants.OpenMPIEnvBaseEnvList). + WithValue("STATIC_VAR"), + }, + VolumeMounts: []corev1ac.VolumeMountApplyConfiguration{ + *corev1ac.VolumeMount(). + WithName(constants.MPISSHAuthVolumeName). + WithMountPath("/root/.ssh"), + *corev1ac.VolumeMount(). + WithName(constants.MPIHostfileVolumeName). + WithMountPath(constants.MPIHostfileDir), + }, + }}, + Volumes: []corev1ac.VolumeApplyConfiguration{ + *corev1ac.Volume(). + WithName(constants.MPISSHAuthVolumeName). + WithSecret(corev1ac.SecretVolumeSource(). + WithSecretName(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix)). + WithItems( + corev1ac.KeyToPath(). + WithKey(corev1.SSHAuthPrivateKey). + WithPath(constants.MPISSHPrivateKeyFile), + corev1ac.KeyToPath(). + WithKey(constants.MPISSHPublicKey). + WithPath(constants.MPISSHPublicKeyFile), + corev1ac.KeyToPath(). + WithKey(constants.MPISSHPublicKey). + WithPath(constants.MPISSHAuthorizedKeys), + ), + ), + *corev1ac.Volume(). + WithName(constants.MPIHostfileVolumeName). + WithConfigMap(corev1ac.ConfigMapVolumeSource(). + WithName(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix)). + WithItems( + corev1ac.KeyToPath(). + WithKey(constants.MPIHostfileName). + WithPath(constants.MPIHostfileName). + WithMode(0444), + ), + ), + }, + Endpoints: func(yield func(string) bool) { + yield("trainJob-launcher-0-0.trainJob") + }, + }, + { + Name: constants.Node, + Count: ptr.To[int32](1), + Containers: []runtime.Container{{ + Name: constants.Node, + VolumeMounts: []corev1ac.VolumeMountApplyConfiguration{ + *corev1ac.VolumeMount(). + WithName(constants.MPISSHAuthVolumeName). + WithMountPath("/root/.ssh"), + }, + }}, + Volumes: []corev1ac.VolumeApplyConfiguration{ + *corev1ac.Volume(). + WithName(constants.MPISSHAuthVolumeName). + WithSecret(corev1ac.SecretVolumeSource(). + WithSecretName(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix)). + WithItems( + corev1ac.KeyToPath(). + WithKey(corev1.SSHAuthPrivateKey). + WithPath(constants.MPISSHPrivateKeyFile), + corev1ac.KeyToPath(). + WithKey(constants.MPISSHPublicKey). + WithPath(constants.MPISSHPublicKeyFile), + corev1ac.KeyToPath(). + WithKey(constants.MPISSHPublicKey). + WithPath(constants.MPISSHAuthorizedKeys), + ), + ), + }, + Endpoints: func(yield func(string) bool) { + yield("trainJob-node-0-0.trainJob") + }, + }, + }, + }, + Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)}, + }, + wantObjs: []apiruntime.Object{ + utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault). + WithImmutable(true). + WithType(corev1.SecretTypeSSHAuth). + WithData(map[string][]byte{ + constants.MPISSHPublicKey: []byte("EXIST"), + corev1.SSHAuthPrivateKey: []byte("EXIST"), + }). + ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob"). + Obj(), + utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault). + WithData(map[string]string{ + constants.MPIHostfileName: `trainJob-node-0-0.trainJob slots=1 +`, + }). + ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob"). + Obj(), + }, + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) {