Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pkg/constants/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,9 @@ const (

// OpenMPIEnvDefaultSlots is the OpenMPI default number of slots env key.
OpenMPIEnvDefaultSlots string = "OMPI_MCA_orte_set_default_slots"

// OpenMPIEnvBaseEnvList is the OpenMPI base environment list env key.
OpenMPIEnvBaseEnvList string = "OMPI_MCA_mca_base_env_list"
// Distributed envs for torchrun.
// Ref: https://github.com/pytorch/pytorch/blob/3a0d0885171376ed610c8175a19ba40411fc6f3f/torch/distributed/argparse_util.py#L45
// TorchEnvNumNodes is the env name for the number of training nodes.
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime/framework/plugins/jobset/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (b *Builder) Trainer(info *runtime.Info, trainJob *trainer.TrainJob) *Build
}
}
}
if ancestor == constants.AncestorTrainer || b.isRunLauncherAsNode(info) && *rJob.Name == constants.Node {
if ancestor == constants.AncestorTrainer || *rJob.Name == constants.Node {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is broader than just MPI. The old condition was gated on isRunLauncherAsNode, so it only kicked in for MPI configs. Now any replicated job named node (PyTorch, DeepSpeed, etc.) will get env vars and resourcesPerNode injected unconditionally. Is that the intent here? If so, might be worth noting that in the PR description since it changes behavior for all runtimes, not just MPI.

// TODO (andreyvelich): For MPI we should apply container resources to the Node ReplicatedJob also.
// Eventually, we should find a better way to propagate resources from TrainJob to JobSet.
for j, container := range rJob.Template.Spec.Template.Spec.Containers {
Comment thread
AviralKaushal marked this conversation as resolved.
Expand Down
125 changes: 125 additions & 0 deletions pkg/runtime/framework/plugins/jobset/builder_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
/*
Copyright 2025 The Kubeflow Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package jobset

import (
"testing"

corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
batchv1ac "k8s.io/client-go/applyconfigurations/batch/v1"
corev1ac "k8s.io/client-go/applyconfigurations/core/v1"
"k8s.io/utils/ptr"
jobsetv1alpha2ac "sigs.k8s.io/jobset/client-go/applyconfiguration/jobset/v1alpha2"

trainer "github.com/kubeflow/trainer/v2/pkg/apis/trainer/v1alpha1"
"github.com/kubeflow/trainer/v2/pkg/constants"
"github.com/kubeflow/trainer/v2/pkg/runtime"
utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

// TestBuilderTrainerEnvPropagation runs Builder.Trainer against a JobSet that has a
// single ReplicatedJob named constants.Node with one container named constants.Node,
// and checks which environment variables that container carries afterwards.
// Each case supplies the TrainJob trainer env, optionally pre-populates the container
// env, and lists the env expected on the container in order.
func TestBuilderTrainerEnvPropagation(t *testing.T) {
testCases := map[string]struct {
// trainJobEnv is the env set on the TrainJob trainer.
trainJobEnv []corev1.EnvVar
// initialPodEnv is the env already present on the container before Builder.Trainer runs.
initialPodEnv []corev1ac.EnvVarApplyConfiguration
// expectedEnv is the env expected on the container afterwards, compared by index.
expectedEnv []corev1.EnvVar
}{
"variables propagated": {
trainJobEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}},
expectedEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}},
},
"no variables propagated (empty case)": {
trainJobEnv: []corev1.EnvVar{},
expectedEnv: []corev1.EnvVar{},
},
"merge with existing variables": {
trainJobEnv: []corev1.EnvVar{{Name: "CUSTOM_VAR", Value: "custom_value"}},
initialPodEnv: []corev1ac.EnvVarApplyConfiguration{
*corev1ac.EnvVar().WithName("EXISTING_VAR").WithValue("existing_value"),
},
// The pre-existing variable is expected first, followed by the TrainJob one.
expectedEnv: []corev1.EnvVar{
{Name: "EXISTING_VAR", Value: "existing_value"},
{Name: "CUSTOM_VAR", Value: "custom_value"},
},
},
}

for name, tc := range testCases {
t.Run(name, func(t *testing.T) {
// Set up a TrainJob whose trainer carries the env vars under test.
trainJob := utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
Trainer(utiltesting.MakeTrainJobTrainerWrapper().
Env(tc.trainJobEnv...).
Obj()).
Obj()

// Set up runtime info for MPI (launcher is NOT a node).
// NOTE(review): the final ptr.To(false) argument is presumably the
// run-launcher-as-node flag; confirm against MPIPolicy's signature.
info := &runtime.Info{
RuntimePolicy: runtime.RuntimePolicy{
MLPolicySource: utiltesting.MakeMLPolicySourceWrapper().
MPIPolicy(nil, trainer.MPIImplementationOpenMPI, nil, ptr.To(false)).
Obj(),
},
}

// Create the JobSet spec: one container named constants.Node seeded with the
// case's initial environment. Indexing into the slice hands WithEnv a pointer
// to each distinct element.
container := corev1ac.Container().WithName(constants.Node)
for i := range tc.initialPodEnv {
container.WithEnv(&tc.initialPodEnv[i])
}

jobSetSpec := jobsetv1alpha2ac.JobSetSpec().WithReplicatedJobs(
jobsetv1alpha2ac.ReplicatedJob().
WithName(constants.Node).
WithTemplate(batchv1ac.JobTemplateSpec().
WithSpec(batchv1ac.JobSpec().
WithTemplate(corev1ac.PodTemplateSpec().
WithSpec(corev1ac.PodSpec().WithContainers(container)),
),
),
),
)

builder := NewBuilder(jobsetv1alpha2ac.JobSet("test-job", metav1.NamespaceDefault).WithSpec(jobSetSpec))
builder.Trainer(info, trainJob)

// Verify results: collect the env of every constants.Node container found in a
// constants.Node ReplicatedJob. Name and Value are dereferenced directly, so
// this relies on every collected variable having a literal Value set.
var actualEnv []corev1.EnvVar
for _, rJob := range builder.Spec.ReplicatedJobs {
if *rJob.Name == constants.Node {
for _, c := range rJob.Template.Spec.Template.Spec.Containers {
if *c.Name == constants.Node {
for _, env := range c.Env {
actualEnv = append(actualEnv, corev1.EnvVar{Name: *env.Name, Value: *env.Value})
}
}
}
}
}

// Fail fast on a length mismatch so the index-based comparison below cannot
// read past the end of actualEnv.
if len(actualEnv) != len(tc.expectedEnv) {
t.Fatalf("Expected %d environment variables, got %d", len(tc.expectedEnv), len(actualEnv))
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing tests in jobset_test.go use go-cmp (cmp.Diff) for comparing results. This file does manual index-based comparison with t.Fatalf/t.Errorf. Would be good to keep them consistent so the test patterns stay uniform across the package.

}

for i, expected := range tc.expectedEnv {
if actualEnv[i].Name != expected.Name || actualEnv[i].Value != expected.Value {
t.Errorf("At index %d: expected %s=%s, got %s=%s", i, expected.Name, expected.Value, actualEnv[i].Name, actualEnv[i].Value)
}
}
})
}
}
30 changes: 30 additions & 0 deletions pkg/runtime/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,19 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er
)
switch *info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation {
case trainer.MPIImplementationOpenMPI:
// Collect the names of the TrainJob's custom environment variables to export via SSH (static values only).
var envNames []string
if trainJob.Spec.Trainer != nil {
for _, env := range trainJob.Spec.Trainer.Env {
Comment thread
AviralKaushal marked this conversation as resolved.
// Only include variables with a static Value.
// Variables with ValueFrom (e.g. FieldRef, ResourceFieldRef) are pod-specific
// and should not be propagated as cluster-wide constants.
if env.ValueFrom == nil {
envNames = append(envNames, env.Name)
}
}
}

apply.UpsertEnvVars(
&info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env,
*corev1ac.EnvVar().
Expand All @@ -205,6 +218,23 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er
WithName(constants.OpenMPIEnvKeyRSHArgs).
WithValue(constants.OpenMPIEnvDefaultValueRSHArgs),
)

// Automatically tell OpenMPI to export the custom variables to all nodes.
if len(envNames) > 0 {
envList := ""
for i, name := range envNames {
if i > 0 {
envList += ";"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you could use strings.Join(envNames, ";") here instead of the manual loop. Cleaner and does the same thing.

}
envList += name
}
apply.UpsertEnvVars(
&info.TemplateSpec.PodSets[psIdx].Containers[cIdx].Env,
*corev1ac.EnvVar().
WithName(constants.OpenMPIEnvBaseEnvList).
WithValue(envList),
)
Comment thread
AviralKaushal marked this conversation as resolved.
}
default:
return fmt.Errorf("MPI implementation for %v doesn't supported", info.RuntimePolicy.MLPolicySource.MPI.MPIImplementation)
}
Expand Down
172 changes: 172 additions & 0 deletions pkg/runtime/framework/plugins/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -829,6 +829,178 @@ trainJob-node-1-0.trainJob slots=1
},
wantBuildError: errorGetSSHAuthSecretFromAPI,
},
"environment variables propagation and filtering for OpenMPI": {
trainJob: utiltesting.MakeTrainJobWrapper(metav1.NamespaceDefault, "trainJob").
UID("trainJob").
Trainer(utiltesting.MakeTrainJobTrainerWrapper().
Env(
corev1.EnvVar{Name: "STATIC_VAR", Value: "static_value"},
corev1.EnvVar{Name: "POD_NAME", ValueFrom: &corev1.EnvVarSource{FieldRef: &corev1.ObjectFieldSelector{FieldPath: "metadata.name"}}},
).
Obj()).
Obj(),
info: &runtime.Info{
RuntimePolicy: runtime.RuntimePolicy{
MLPolicySource: utiltesting.MakeMLPolicySourceWrapper().
MPIPolicy(ptr.To[int32](1), trainer.MPIImplementationOpenMPI, ptr.To("/root/.ssh"), nil).
Obj(),
},
TemplateSpec: runtime.TemplateSpec{
PodSets: []runtime.PodSet{
{
Name: constants.Launcher,
Count: ptr.To[int32](1),
Endpoints: func(yield func(string) bool) {
yield("trainJob-launcher-0-0.trainJob")
},
Containers: []runtime.Container{{
Name: constants.Node,
}},
},
{
Name: constants.Node,
Count: ptr.To[int32](1),
Endpoints: func(yield func(string) bool) {
yield("trainJob-node-0-0.trainJob")
},
Containers: []runtime.Container{{
Name: constants.Node,
}},
},
},
},
Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)},
},
wantInfo: &runtime.Info{
Labels: nil,
Annotations: nil,
RuntimePolicy: runtime.RuntimePolicy{
MLPolicySource: utiltesting.MakeMLPolicySourceWrapper().
MPIPolicy(ptr.To[int32](1), trainer.MPIImplementationOpenMPI, ptr.To("/root/.ssh"), nil).
Obj(),
},
TemplateSpec: runtime.TemplateSpec{
PodSets: []runtime.PodSet{
{
Name: constants.Launcher,
Count: ptr.To[int32](1),
Containers: []runtime.Container{{
Name: constants.Node,
Env: []corev1ac.EnvVarApplyConfiguration{
*corev1ac.EnvVar().
WithName(constants.OpenMPIEnvHostFileLocation).
WithValue(fmt.Sprintf("%s/%s", constants.MPIHostfileDir, constants.MPIHostfileName)),
*corev1ac.EnvVar().
WithName(constants.OpenMPIEnvKeepFQDNHostNames).
WithValue("true"),
*corev1ac.EnvVar().
WithName(constants.OpenMPIEnvDefaultSlots).
WithValue("1"),
*corev1ac.EnvVar().
WithName(constants.OpenMPIEnvKeyRSHArgs).
WithValue(constants.OpenMPIEnvDefaultValueRSHArgs),
*corev1ac.EnvVar().
WithName(constants.OpenMPIEnvBaseEnvList).
WithValue("STATIC_VAR"),
},
VolumeMounts: []corev1ac.VolumeMountApplyConfiguration{
*corev1ac.VolumeMount().
WithName(constants.MPISSHAuthVolumeName).
WithMountPath("/root/.ssh"),
*corev1ac.VolumeMount().
WithName(constants.MPIHostfileVolumeName).
WithMountPath(constants.MPIHostfileDir),
},
}},
Volumes: []corev1ac.VolumeApplyConfiguration{
*corev1ac.Volume().
WithName(constants.MPISSHAuthVolumeName).
WithSecret(corev1ac.SecretVolumeSource().
WithSecretName(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix)).
WithItems(
corev1ac.KeyToPath().
WithKey(corev1.SSHAuthPrivateKey).
WithPath(constants.MPISSHPrivateKeyFile),
corev1ac.KeyToPath().
WithKey(constants.MPISSHPublicKey).
WithPath(constants.MPISSHPublicKeyFile),
corev1ac.KeyToPath().
WithKey(constants.MPISSHPublicKey).
WithPath(constants.MPISSHAuthorizedKeys),
),
),
*corev1ac.Volume().
WithName(constants.MPIHostfileVolumeName).
WithConfigMap(corev1ac.ConfigMapVolumeSource().
WithName(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix)).
WithItems(
corev1ac.KeyToPath().
WithKey(constants.MPIHostfileName).
WithPath(constants.MPIHostfileName).
WithMode(0444),
),
),
},
Endpoints: func(yield func(string) bool) {
yield("trainJob-launcher-0-0.trainJob")
},
},
{
Name: constants.Node,
Count: ptr.To[int32](1),
Containers: []runtime.Container{{
Name: constants.Node,
VolumeMounts: []corev1ac.VolumeMountApplyConfiguration{
*corev1ac.VolumeMount().
WithName(constants.MPISSHAuthVolumeName).
WithMountPath("/root/.ssh"),
},
}},
Volumes: []corev1ac.VolumeApplyConfiguration{
*corev1ac.Volume().
WithName(constants.MPISSHAuthVolumeName).
WithSecret(corev1ac.SecretVolumeSource().
WithSecretName(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix)).
WithItems(
corev1ac.KeyToPath().
WithKey(corev1.SSHAuthPrivateKey).
WithPath(constants.MPISSHPrivateKeyFile),
corev1ac.KeyToPath().
WithKey(constants.MPISSHPublicKey).
WithPath(constants.MPISSHPublicKeyFile),
corev1ac.KeyToPath().
WithKey(constants.MPISSHPublicKey).
WithPath(constants.MPISSHAuthorizedKeys),
),
),
},
Endpoints: func(yield func(string) bool) {
yield("trainJob-node-0-0.trainJob")
},
},
},
},
Scheduler: &runtime.Scheduler{PodLabels: make(map[string]string)},
},
wantObjs: []apiruntime.Object{
utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault).
WithImmutable(true).
WithType(corev1.SecretTypeSSHAuth).
WithData(map[string][]byte{
constants.MPISSHPublicKey: []byte("EXIST"),
corev1.SSHAuthPrivateKey: []byte("EXIST"),
}).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
Obj(),
utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault).
WithData(map[string]string{
constants.MPIHostfileName: `trainJob-node-0-0.trainJob slots=1
`,
}).
ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob").
Obj(),
},
},
}
for name, tc := range cases {
t.Run(name, func(t *testing.T) {
Expand Down
Loading