diff --git a/pkg/constants/constants.go b/pkg/constants/constants.go index 533360664b..433069faf0 100644 --- a/pkg/constants/constants.go +++ b/pkg/constants/constants.go @@ -43,6 +43,9 @@ const ( // LabelSupport indicates support status for a runtime, e.g. "deprecated". LabelSupport string = "trainer.kubeflow.org/support" + // LabelJobName is the label to identify job-owned ConfigMap and Secret resources. + LabelJobName string = "trainer.kubeflow.org/trainjob-name" + // SupportDeprecated indicates the runtime is deprecated when used with LabelSupport. SupportDeprecated string = "deprecated" diff --git a/pkg/runtime/core/trainingruntime_test.go b/pkg/runtime/core/trainingruntime_test.go index e933ec844f..94bc5a6c3d 100644 --- a/pkg/runtime/core/trainingruntime_test.go +++ b/pkg/runtime/core/trainingruntime_test.go @@ -1950,6 +1950,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), wantObjs: []runtime.Object{ testingutil.MakeConfigMapWrapper(fmt.Sprintf("test-job%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "test-job"}). ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid"). WithData(map[string]string{ constants.MPIHostfileName: `test-job-node-0-0.test-job slots=8 @@ -1958,6 +1959,7 @@ test-job-node-0-1.test-job slots=8 }). Obj(), testingutil.MakeSecretWrapper(fmt.Sprintf("test-job%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "test-job"}). ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid"). WithImmutable(true). 
WithData(map[string][]byte{ diff --git a/pkg/runtime/framework/core/framework_test.go b/pkg/runtime/framework/core/framework_test.go index b4c5a17fab..fc40ea10c6 100644 --- a/pkg/runtime/framework/core/framework_test.go +++ b/pkg/runtime/framework/core/framework_test.go @@ -1297,6 +1297,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { }). Obj(), testingutil.MakeSecretWrapper(fmt.Sprintf("test-job%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "test-job"}). WithImmutable(true). WithType(corev1.SecretTypeSSHAuth). WithData(map[string][]byte{ @@ -1306,6 +1307,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "test-job", "uid"). Obj(), testingutil.MakeConfigMapWrapper(fmt.Sprintf("test-job%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "test-job"}). WithData(map[string]string{ constants.MPIHostfileName: `test-job-launcher-0-0.test-job slots=1 test-job-node-0-0.test-job slots=1 diff --git a/pkg/runtime/framework/plugins/mpi/mpi.go b/pkg/runtime/framework/plugins/mpi/mpi.go index ff36bab316..7d4d6758fa 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi.go +++ b/pkg/runtime/framework/plugins/mpi/mpi.go @@ -29,6 +29,7 @@ import ( "golang.org/x/crypto/ssh" corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" apiruntime "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/validation/field" corev1ac "k8s.io/client-go/applyconfigurations/core/v1" @@ -38,6 +39,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/cache" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/handler" + "sigs.k8s.io/controller-runtime/pkg/predicate" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" configapi "github.com/kubeflow/trainer/v2/pkg/apis/config/v1alpha1" @@ -218,19 +220,27 @@ func (m 
*MPI) EnforceMLPolicy(info *runtime.Info, trainJob *trainer.TrainJob) er func (m *MPI) ReconcilerBuilders() []runtime.ReconcilerBuilder { return []runtime.ReconcilerBuilder{ func(b *builder.Builder, cl client.Client, cache cache.Cache) *builder.Builder { - return b.Watches( + return b.WatchesMetadata( &corev1.ConfigMap{}, handler.EnqueueRequestForOwner( m.client.Scheme(), m.client.RESTMapper(), &trainer.TrainJob{}, handler.OnlyControllerOwner(), ), + builder.WithPredicates(predicate.NewPredicateFuncs(func(obj client.Object) bool { + _, ok := obj.GetLabels()[constants.LabelJobName] + return ok + })), ) }, func(b *builder.Builder, cl client.Client, cache cache.Cache) *builder.Builder { - return b.Watches( + return b.WatchesMetadata( &corev1.Secret{}, handler.EnqueueRequestForOwner( m.client.Scheme(), m.client.RESTMapper(), &trainer.TrainJob{}, handler.OnlyControllerOwner(), ), + builder.WithPredicates(predicate.NewPredicateFuncs(func(obj client.Object) bool { + _, ok := obj.GetLabels()[constants.LabelJobName] + return ok + })), ) }, } @@ -244,7 +254,9 @@ func (m *MPI) Build(ctx context.Context, info *runtime.Info, trainJob *trainer.T var objects []apiruntime.ApplyConfiguration // SSHAuthSecret is immutable. - if err := m.client.Get(ctx, client.ObjectKey{Name: sshAuthSecretName(trainJob.Name), Namespace: trainJob.Namespace}, &corev1.Secret{}); err != nil { + partialSecret := &metav1.PartialObjectMetadata{} + partialSecret.SetGroupVersionKind(corev1.SchemeGroupVersion.WithKind("Secret")) + if err := m.client.Get(ctx, client.ObjectKey{Name: sshAuthSecretName(trainJob.Name), Namespace: trainJob.Namespace}, partialSecret); err != nil { if client.IgnoreNotFound(err) != nil { return nil, err } @@ -275,6 +287,9 @@ func (m *MPI) buildSSHAuthSecret(trainJob *trainer.TrainJob) (*corev1ac.SecretAp return nil, err } return corev1ac.Secret(sshAuthSecretName(trainJob.Name), trainJob.Namespace). + WithLabels(map[string]string{ + constants.LabelJobName: trainJob.Name, + }). 
WithType(corev1.SecretTypeSSHAuth). WithData(map[string][]byte{ corev1.SSHAuthPrivateKey: privatePEM, @@ -310,6 +325,9 @@ func (m *MPI) buildHostFileConfigMap(info *runtime.Info, trainJob *trainer.Train } } return corev1ac.ConfigMap(fmt.Sprintf("%s%s", trainJob.Name, constants.MPIHostfileConfigMapSuffix), trainJob.Namespace). + WithLabels(map[string]string{ + constants.LabelJobName: trainJob.Name, + }). WithData(map[string]string{ constants.MPIHostfileName: hostFile.String(), }). diff --git a/pkg/runtime/framework/plugins/mpi/mpi_test.go b/pkg/runtime/framework/plugins/mpi/mpi_test.go index db5628a5b8..32c0df3414 100644 --- a/pkg/runtime/framework/plugins/mpi/mpi_test.go +++ b/pkg/runtime/framework/plugins/mpi/mpi_test.go @@ -203,6 +203,7 @@ func TestMPI(t *testing.T) { }, wantObjs: []apiruntime.Object{ utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). WithImmutable(true). WithType(corev1.SecretTypeSSHAuth). WithData(map[string][]byte{ @@ -212,6 +213,7 @@ func TestMPI(t *testing.T) { ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob"). Obj(), utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). WithData(map[string]string{ constants.MPIHostfileName: `trainJob-node-1-0.trainJob slots=1 trainJob-node-1-1.trainJob slots=1 @@ -339,6 +341,7 @@ trainJob-node-1-1.trainJob slots=1 }, wantObjs: []apiruntime.Object{ utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). WithImmutable(true). WithType(corev1.SecretTypeSSHAuth). 
WithData(map[string][]byte{ @@ -348,6 +351,7 @@ trainJob-node-1-1.trainJob slots=1 ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob"). Obj(), utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). WithData(map[string]string{ constants.MPIHostfileName: `trainJob-node-1-0.trainJob slots=2 `, @@ -476,6 +480,7 @@ trainJob-node-1-1.trainJob slots=1 }, wantObjs: []apiruntime.Object{ utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). WithImmutable(true). WithType(corev1.SecretTypeSSHAuth). WithData(map[string][]byte{ @@ -485,6 +490,7 @@ trainJob-node-1-1.trainJob slots=1 ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob"). Obj(), utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). WithData(map[string]string{ constants.MPIHostfileName: `trainJob-node-1-0.trainJob slots=5 `, @@ -647,6 +653,7 @@ trainJob-node-1-1.trainJob slots=1 }, wantObjs: []apiruntime.Object{ utiltesting.MakeSecretWrapper(fmt.Sprintf("trainJob%s", constants.MPISSHAuthSecretSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). WithImmutable(true). WithType(corev1.SecretTypeSSHAuth). WithData(map[string][]byte{ @@ -656,6 +663,7 @@ trainJob-node-1-1.trainJob slots=1 ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob"). Obj(), utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault). 
+ WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). WithData(map[string]string{ constants.MPIHostfileName: `trainJob-launcher-0-0.trainJob slots=1 trainJob-node-1-0.trainJob slots=1 @@ -668,6 +676,7 @@ trainJob-node-1-0.trainJob slots=1 "sshAuth secret already has existed in the cluster": { objs: []client.Object{ utiltesting.MakeSecretWrapper(sshAuthSecretName("trainJob"), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). ControllerReference(trainer.SchemeGroupVersion.WithKind(trainer.TrainJobKind), "trainJob", "trainJob"). WithImmutable(true). Obj(), @@ -746,6 +755,7 @@ trainJob-node-1-0.trainJob slots=1 }, wantObjs: []apiruntime.Object{ utiltesting.MakeConfigMapWrapper(fmt.Sprintf("trainJob%s", constants.MPIHostfileConfigMapSuffix), metav1.NamespaceDefault). + WithLabels(map[string]string{constants.LabelJobName: "trainJob"}). WithData(map[string]string{ constants.MPIHostfileName: `trainJob-launcher-0-0.trainJob slots=1 `, @@ -839,7 +849,7 @@ trainJob-node-1-0.trainJob slots=1 b := utiltesting.NewClientBuilder().WithObjects(tc.objs...) b.WithInterceptorFuncs(interceptor.Funcs{ Get: func(ctx context.Context, client client.WithWatch, key client.ObjectKey, obj client.Object, opts ...client.GetOption) error { - if _, ok := obj.(*corev1.Secret); ok && errors.Is(tc.wantBuildError, errorGetSSHAuthSecretFromAPI) { + if _, ok := obj.(*metav1.PartialObjectMetadata); ok && errors.Is(tc.wantBuildError, errorGetSSHAuthSecretFromAPI) { return errorGetSSHAuthSecretFromAPI } return client.Get(ctx, key, obj, opts...) 
diff --git a/pkg/util/testing/wrapper.go b/pkg/util/testing/wrapper.go index 9a1006221e..f0c40f36c1 100644 --- a/pkg/util/testing/wrapper.go +++ b/pkg/util/testing/wrapper.go @@ -1425,6 +1425,17 @@ func (c *ConfigMapWrapper) WithData(data map[string]string) *ConfigMapWrapper { return c } +// WithLabels merges labels into c.Labels, allocating the map when it is nil and overwriting keys that already exist; it returns c for chaining. +func (c *ConfigMapWrapper) WithLabels(labels map[string]string) *ConfigMapWrapper { + if c.Labels == nil { + c.Labels = make(map[string]string, len(labels)) + } + for k, v := range labels { + c.Labels[k] = v + } + return c +} + func (c *ConfigMapWrapper) ControllerReference(gvk schema.GroupVersionKind, name, uid string) *ConfigMapWrapper { c.OwnerReferences = append(c.OwnerReferences, metav1.OwnerReference{ APIVersion: gvk.GroupVersion().String(), @@ -1475,6 +1486,17 @@ func (s *SecretWrapper) WithData(data map[string][]byte) *SecretWrapper { return s } +// WithLabels merges labels into s.Labels, allocating the map when it is nil and overwriting keys that already exist; it returns s for chaining. +func (s *SecretWrapper) WithLabels(labels map[string]string) *SecretWrapper { + if s.Labels == nil { + s.Labels = make(map[string]string, len(labels)) + } + for k, v := range labels { + s.Labels[k] = v + } + return s +} + func (s *SecretWrapper) WithImmutable(immutable bool) *SecretWrapper { s.Immutable = &immutable return s