diff --git a/cli/internal/config/nodebootstrap/flags.go b/cli/internal/config/nodebootstrap/flags.go new file mode 100644 index 0000000..7a6adad --- /dev/null +++ b/cli/internal/config/nodebootstrap/flags.go @@ -0,0 +1,62 @@ +package nodebootstrap + +import ( + "fmt" + "strings" + + corev1 "k8s.io/api/core/v1" +) + +// parseNodeLabels parses a slice of "key=value" strings into a map. An empty +// slice returns a nil map. Duplicate keys cause an error. +func parseNodeLabels(in []string) (map[string]string, error) { + if len(in) == 0 { + return nil, nil + } + out := make(map[string]string, len(in)) + for _, s := range in { + k, v, ok := strings.Cut(s, "=") + if !ok || k == "" { + return nil, fmt.Errorf("invalid --node-label %q: expected key=value", s) + } + if _, dup := out[k]; dup { + return nil, fmt.Errorf("duplicate --node-label key %q", k) + } + out[k] = v + } + return out, nil +} + +// parseTaints parses a slice of "key=value:Effect" or "key:Effect" strings into +// []corev1.Taint. Effect must be one of NoSchedule, PreferNoSchedule, NoExecute. +func parseTaints(in []string) ([]corev1.Taint, error) { + if len(in) == 0 { + return nil, nil + } + out := make([]corev1.Taint, 0, len(in)) + for _, s := range in { + // Effect is everything after the LAST ':'. This allows ':' in the value + // portion when the input is in key=value:Effect form. + idx := strings.LastIndex(s, ":") + if idx < 0 { + return nil, fmt.Errorf("invalid --taint %q: expected key[=value]:Effect", s) + } + head, effect := s[:idx], corev1.TaintEffect(s[idx+1:]) + switch effect { + case corev1.TaintEffectNoSchedule, corev1.TaintEffectPreferNoSchedule, corev1.TaintEffectNoExecute: + default: + return nil, fmt.Errorf("invalid --taint %q: effect must be NoSchedule, PreferNoSchedule, or NoExecute", s) + } + var key, value string + if i := strings.Index(head, "="); i >= 0 { + key, value = head[:i], head[i+1:] + } else { + key = head + } + if key == "" { + return nil, fmt.Errorf("invalid --taint %q: empty key", s) + } + out = append(out, corev1.Taint{Key: key, Value: value, Effect: effect}) + } + return out, nil +} diff --git a/cli/internal/config/nodebootstrap/flags_test.go b/cli/internal/config/nodebootstrap/flags_test.go new file mode 100644 index 0000000..4edd644 --- /dev/null +++ b/cli/internal/config/nodebootstrap/flags_test.go @@ -0,0 +1,95 @@ +package nodebootstrap + +import ( + "testing" + + corev1 "k8s.io/api/core/v1" +) + +func Test_parseNodeLabels(t *testing.T) { + tests := []struct { + name string + in []string + want map[string]string + wantErr bool + }{ + {name: "empty", in: nil, want: nil}, + {name: "simple", in: []string{"a=1", "b=2"}, want: map[string]string{"a": "1", "b": "2"}}, + {name: "empty value", in: []string{"a="}, want: map[string]string{"a": ""}}, + {name: "domain key", in: []string{"nvidia.com/gpu.product=H200"}, want: map[string]string{"nvidia.com/gpu.product": "H200"}}, + {name: "missing equals", in: []string{"justkey"}, wantErr: true}, + {name: "empty key", in: []string{"=value"}, wantErr: true}, + {name: "duplicate", in: []string{"a=1", "a=2"}, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseNodeLabels(tt.in) + if (err != nil) != tt.wantErr { + t.Fatalf("err = %v, wantErr = %v", err, tt.wantErr) + } + if tt.wantErr { + return + } + if len(got) != len(tt.want) { + t.Fatalf("got %v, want %v", got, tt.want) + } + for k, v := range tt.want { + if got[k] != v { + t.Errorf("key %q: got %q, want %q", k, got[k], v) + } + } + }) + } +} + +func Test_parseTaints(t *testing.T) { + tests := []struct { + name string + in []string + want []corev1.Taint + wantErr bool + }{ + {name: "empty", in: nil, want: nil}, + { + name: "key=value:NoSchedule", + in: []string{"nvidia.com/gpu=present:NoSchedule"}, + want: []corev1.Taint{{Key: "nvidia.com/gpu", Value: "present", Effect: corev1.TaintEffectNoSchedule}}, + }, + { + name: "key:NoSchedule (no value)", + in: []string{"dedicated:NoSchedule"}, + want: []corev1.Taint{{Key: "dedicated", Effect: corev1.TaintEffectNoSchedule}}, + }, + { + name: "all effects", + in: []string{"a:NoSchedule", "b:PreferNoSchedule", "c:NoExecute"}, + want: []corev1.Taint{ + {Key: "a", Effect: corev1.TaintEffectNoSchedule}, + {Key: "b", Effect: corev1.TaintEffectPreferNoSchedule}, + {Key: "c", Effect: corev1.TaintEffectNoExecute}, + }, + }, + {name: "missing effect", in: []string{"key=value"}, wantErr: true}, + {name: "invalid effect", in: []string{"key=value:Bogus"}, wantErr: true}, + {name: "empty key", in: []string{":NoSchedule"}, wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseTaints(tt.in) + if (err != nil) != tt.wantErr { + t.Fatalf("err = %v, wantErr = %v", err, tt.wantErr) + } + if tt.wantErr { + return + } + if len(got) != len(tt.want) { + t.Fatalf("got %v, want %v", got, tt.want) + } + for i := range tt.want { + if got[i] != tt.want[i] { + t.Errorf("[%d]: got %+v, want %+v", i, got[i], tt.want[i]) + } + } + }) + } +} diff --git a/cli/internal/config/nodebootstrap/nodebootstrap.go b/cli/internal/config/nodebootstrap/nodebootstrap.go index e6c6d84..541f983 100644 --- a/cli/internal/config/nodebootstrap/nodebootstrap.go +++ b/cli/internal/config/nodebootstrap/nodebootstrap.go @@ -7,6 +7,7 @@ import ( "github.com/spf13/cobra" + "github.com/Azure/aks-flex/plugin/pkg/services/agentpools/api/features/kubeadm" "github.com/Azure/aks-flex/plugin/pkg/services/agentpools/userdata/flex" "github.com/Azure/aks-flex/plugin/pkg/services/agentpools/userdata/ubuntu" "github.com/Azure/aks-flex/plugin/pkg/util/cloudinit" @@ -25,6 +26,8 @@ var flagEnableNvidiaGPURuntime bool var flagVariant string var flagArch string var flagKubeVersion string +var flagNodeLabels []string +var flagTaints []string func init() { r.Handle("ubuntu", writeUbuntuUserData) @@ -37,6 +40,10 @@ func init() { "Kubernetes version for the downloaded binaries.") Command.Flags().StringVar(&flagVariant, "variant", variantCloudInit, fmt.Sprintf("Output variant: %q produces cloud-init YAML user data, %q produces an equivalent standalone bash script.", variantCloudInit, variantScript)) + Command.Flags().StringSliceVar(&flagNodeLabels, "node-label", nil, + "Extra node label to register the node with, as key=value. Repeat for multiple labels. Merged with the labels derived from the AKS cluster (cluster name, managed=false, stretch-managed=true).") + Command.Flags().StringSliceVar(&flagTaints, "taint", nil, + "Taint to register the node with, as key[=value]:Effect (e.g. nvidia.com/gpu=present:NoSchedule). Repeat for multiple taints.") } // marshalUserData marshals the cloud-init UserData according to the selected @@ -62,11 +69,15 @@ func marshalUserData(ud *cloudinit.UserData, w io.Writer) error { } func writeFlexUserData(ctx context.Context, w io.Writer) error { + kc, err := kubeadmConfigFromFlags(ctx) + if err != nil { + return err + } ud, err := flex.UserData( flex.WithEnableNvidiaGPURuntime(flagEnableNvidiaGPURuntime), flex.WithArch(flagArch), flex.WithKubeVersion(flagKubeVersion), - flex.WithKubeadmConfig(configcmd.DefaultKubeadmConfig(ctx)), + flex.WithKubeadmConfig(kc), ) if err != nil { return fmt.Errorf("generating flex userdata: %w", err) @@ -75,9 +86,38 @@ func writeFlexUserData(ctx context.Context, w io.Writer) error { } func writeUbuntuUserData(ctx context.Context, w io.Writer) error { - ud, err := ubuntu.UserData(configcmd.DefaultKubeadmConfig(ctx)) + kc, err := kubeadmConfigFromFlags(ctx) + if err != nil { + return err + } + ud, err := ubuntu.UserData(kc) if err != nil { return fmt.Errorf("generating ubuntu userdata: %w", err) } return marshalUserData(ud, w) } + +// kubeadmConfigFromFlags returns the default kubeadm config (derived from the +// live AKS cluster when reachable) with extra --node-label and --taint flag +// values merged in. +func kubeadmConfigFromFlags(ctx context.Context) (*kubeadm.Config, error) { + kc := configcmd.DefaultKubeadmConfig(ctx) + + extraLabels, err := parseNodeLabels(flagNodeLabels) + if err != nil { + return nil, err + } + if len(extraLabels) > 0 { + kc.AddNodeLabels(extraLabels) + } + + taints, err := parseTaints(flagTaints) + if err != nil { + return nil, err + } + if len(taints) > 0 { + kc.AddK8SRegisterTaints(taints...) + } + + return kc, nil +}