
Add trainium, inferentia, and efa parameters to @kubernetes decorator#3086

Draft
emattia wants to merge 7 commits into Netflix:master from
emattia:trn-k8s

Conversation

@emattia
Contributor

@emattia emattia commented Apr 8, 2026

PR Type

  • Bug fix
  • New feature
  • Core Runtime change
  • Docs / tooling
  • Refactoring

Summary

Mirror @batch's AWS-accelerator surface on @kubernetes:

  • @kubernetes(trainium=N) requests N AWS Trainium / Inferentia Neuron
    devices (aws.amazon.com/neuron k8s resource).
  • @kubernetes(inferentia=N) is an alias for trainium, mirroring
    @batch(inferentia=N) for API consistency.
  • @kubernetes(efa=N) requests N AWS Elastic Fabric Adapter network
    interfaces (vpc.amazonaws.com/efa k8s resource).

Plumbed through kubernetes_job, kubernetes_jobsets, kubernetes_cli,
and the argo / airflow runtimes consistently with how the existing
gpu parameter is handled.
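A minimal sketch of the attribute-to-resource translation described above (helper names are hypothetical; the real plumbing lives in kubernetes_job and the other runtimes, and this only illustrates the resource names from the summary):

```python
# Hypothetical mapping sketch, assuming the resource names listed above.
K8S_RESOURCE_FOR_ATTR = {
    "trainium": "aws.amazon.com/neuron",
    "efa": "vpc.amazonaws.com/efa",
    "gpu": "nvidia.com/gpu",
}

def resource_limits(attrs):
    """Translate decorator attributes into k8s resource-limit entries,
    skipping unset (None) values as the PR describes."""
    return {
        K8S_RESOURCE_FOR_ATTR[k]: str(v)
        for k, v in attrs.items()
        if k in K8S_RESOURCE_FOR_ATTR and v is not None
    }
```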

Issue

No tracking issue. Supersedes the original PR scope of just trainium.
Brings the @kubernetes path to parity with @batch for AWS Neuron
and EFA workloads, unblocking customers who run their own EKS clusters
and want first-class Neuron/EFA support without writing raw pod specs.

Reproduction

Runtime: kubernetes (EKS with AWS Neuron and EFA device plugins
installed; nodes labeled with the relevant accelerator).

Flow to run:

from metaflow import FlowSpec, step, kubernetes, environment

NEURON_IMG = "public.ecr.aws/neuron/pytorch-training-neuronx:2.9.0-neuronx-py312-sdk2.29.1-ubuntu24.04"

class NeuronEfaSmoke(FlowSpec):

    @step
    def start(self):
        self.next(self.neuron_only)

    @kubernetes(trainium=1, image=NEURON_IMG)
    @step
    def neuron_only(self):
        import subprocess
        print(subprocess.check_output(["neuron-ls"]).decode())
        self.next(self.inferentia_alias)

    # Equivalent: inferentia is an alias for trainium
    @kubernetes(inferentia=1, image=NEURON_IMG)
    @step
    def inferentia_alias(self):
        self.next(self.gpu_with_efa)

    @environment(vars={"FI_PROVIDER": "efa"})
    @kubernetes(gpu=8, efa=32, image="<aws-dlc-pytorch-cuda>")
    @step
    def gpu_with_efa(self):
        import torch.distributed as dist
        dist.init_process_group(backend="nccl")
        # NCCL debug log will show "Selected provider is efa"
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    NeuronEfaSmoke()

Where evidence shows up: task pod spec (kubectl describe pod) and
NCCL debug log inside the running container.

Before (master)
TypeError: kubernetes() got an unexpected keyword argument 'trainium'

(also for inferentia, efa)

After (this PR)
$ kubectl describe pod ws-...
...
Limits:
  aws.amazon.com/neuron:    1
  vpc.amazonaws.com/efa:    32
  nvidia.com/gpu:           8

# inside the pod
$ NCCL_DEBUG=INFO python train.py
NCCL INFO NET/OFI Initializing aws-ofi-nccl 1.15.0
NCCL INFO NET/OFI Using transport protocol RDMA (platform set)
NCCL INFO NET/OFI Selected provider is efa, fabric is efa-direct (found 32 nics)

Root Cause

Not a bug fix — net-new feature. The underlying Kubernetes resources
(aws.amazon.com/neuron, vpc.amazonaws.com/efa) are advertised by the
respective AWS device plugins; @kubernetes had no decorator-level
surface to request them. @batch already exposed trainium,
inferentia, and efa. This PR brings @kubernetes to parity.

Why This Fix Is Correct

  • Mirrors @batch's API surface exactly. inferentia collapses into
    trainium at step_init and is popped before any runtime translation
    — same shape as batch_decorator.py:175-211, only with trainium as
    canonical (since on K8s the underlying resource name is
    aws.amazon.com/neuron and we surface what users running on Trainium
    hardware naturally type first).
  • Doesn't disturb the existing GPU path. gpu and trainium are
    enforced as mutually exclusive (matching @batch's convention).
  • Argo/Airflow runtimes already had the trainium plumbing pattern from
    earlier in this branch; efa follows the same pattern.
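The alias collapse described above can be sketched roughly as follows (hypothetical standalone helper; the actual logic lives inline in kubernetes_decorator.py's step_init and is not reproduced verbatim here):

```python
def step_init_accelerator_attrs(attributes):
    """Collapse the 'inferentia' alias into the canonical 'trainium' key,
    enforcing the mutual exclusions described in the PR."""
    attrs = dict(attributes)
    if attrs.get("inferentia") is not None:
        if attrs.get("trainium") is not None:
            raise ValueError(
                "Specify only one of 'inferentia' or 'trainium'; they "
                "request the same aws.amazon.com/neuron resource."
            )
        # 'inferentia' is popped so downstream translation only ever
        # sees the canonical key.
        attrs["trainium"] = attrs.pop("inferentia")
    else:
        attrs.pop("inferentia", None)
    if attrs.get("gpu") is not None and attrs.get("trainium") is not None:
        raise ValueError("'gpu' and 'trainium' are mutually exclusive.")
    return attrs
```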

Failure Modes Considered

  1. Backward compat: flows using only gpu / gpu_vendor are
    unaffected — new attributes default to None and resource-limit
    emission is gated on non-None values.
  2. Mutual exclusion: specifying both inferentia and trainium
    raises a clear error in step_init (mirrors @batch). Specifying
    both gpu and trainium was already enforced.
  3. Wire format consistency: inferentia is popped from
    self.attributes after collapsing into trainium, so the runtime
    CLI / argo / airflow translation only ever sees the canonical key.
  4. Cross-runtime: changes propagate through kubernetes_job,
    kubernetes_jobsets, argo, and airflow consistently with how
    trainium was already plumbed.
  5. Validation: efa value validated as positive integer (mirrors
    trainium and tmpfs_size validation patterns in the same file).
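The positive-integer validation mentioned in point 5 follows this general shape (a sketch with a hypothetical helper name; the real checks sit inline in the decorator alongside the tmpfs_size validation):

```python
def validate_positive_int(name, value):
    """Reject non-integer or non-positive values for an accelerator
    count, mirroring the validation pattern described above."""
    if value is None:
        return None
    try:
        ivalue = int(value)
    except (TypeError, ValueError):
        raise ValueError(
            "Invalid %s value: %r (expected an integer)" % (name, value)
        )
    if ivalue <= 0:
        raise ValueError(
            "Invalid %s value: %r (must be a positive integer)" % (name, value)
        )
    return ivalue
```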

Tests

  • Unit tests added/updated — not yet; planned as a follow-up
    mirroring existing kube_utils tests with parametrize cases for
    mutual-exclusion + resource-limit emission. Happy to land tests
    either in this PR (push another commit) or a follow-up — let me
    know reviewer preference.
  • Manual reproduction provided above
  • Smoke-tested end-to-end on a real EKS cluster with Neuron and
    EFA device plugins. Pod spec contains the right resource limits;
    NCCL via aws-ofi-nccl selects EFA as the network backend.
  • CI passes — TBD (CI doesn't have AWS Trainium/EFA hardware to
    truly exercise the runtime path, only static / unit checks).

Non-Goals

  • Not touching @batch (already has these parameters).
  • Not adding a --inferentia CLI flag — inferentia is purely a
    decorator-time convenience that resolves to trainium before any
    CLI invocation, mirroring @batch's CLI which only exposes the
    canonical name (--inferentia for batch since inferentia is
    canonical there; --trainium for k8s since trainium is canonical
    here).
  • Not adding NCCL/libfabric environment-variable defaults
    (FI_PROVIDER, FI_EFA_USE_DEVICE_RDMA). Users set those via
    @environment for now; auto-injection is a separate ergonomics PR.
  • Not opining on which Trainium/Inferentia instance type a user
    should target; that's a cluster-side concern (instance allowlist
    and AMI selection on the EKS managed nodegroup side).

AI Tool Usage

  • AI tools were used (Anthropic Claude — research on AWS DLC tag
    selection, Karpenter EFA NIC layout prior art, and drafting this
    PR description). All generated code reviewed, understood, and
    tested end-to-end on a live cluster.

@emattia emattia marked this pull request as draft April 8, 2026 00:32
@greptile-apps
Contributor

greptile-apps Bot commented Apr 8, 2026

Greptile Summary

This PR adds a trainium parameter to @kubernetes, mirroring the existing @batch UX, so users can request AWS Trainium/Inferentia Neuron devices via @kubernetes(trainium=X). The core implementation in kubernetes_job.py and kubernetes_jobsets.py is correct — both the aws.amazon.com/neuron resource limit and the automatic aws.amazon.com/neuron:NoSchedule toleration are injected. However, the same toleration is missing from two execution paths:

  • Argo Workflows (non-parallel path): argo_workflows.py adds the Neuron resource limit but calls .tolerations(resources.get("tolerations")) without appending the Neuron toleration, so pods will stay Pending on tainted Neuron nodes.
  • Airflow: airflow.py adds the Neuron resource limit to the resources dict but never passes a matching toleration to KubernetesPodOperator, causing the same scheduling failure.

Confidence Score: 4/5

Two execution paths (Argo Workflows non-parallel and Airflow) carry confirmed P1 bugs: pods will silently fail to schedule on tainted Neuron nodes because the required toleration is not injected. The core kubernetes_job and kubernetes_jobsets paths are correct, so the score is 4 rather than lower: the feature works for the primary direct-Kubernetes execution path.

metaflow/plugins/argo/argo_workflows.py and metaflow/plugins/airflow/airflow.py need the automatic Neuron toleration added to their pod specs.

Vulnerabilities

No security concerns identified. The change adds a resource-limit annotation and a toleration to Kubernetes pod specs; there is no user-controlled input reaching a sensitive API without validation.

Important Files Changed

Filename Overview
metaflow/plugins/kubernetes/kubernetes_decorator.py Adds trainium as a new decorator attribute with mutual-exclusion check against gpu, integer validation, and CLI forwarding — correct, though the validator allows trainium=0 which would spuriously add a Neuron toleration.
metaflow/plugins/kubernetes/kubernetes_job.py Correctly adds aws.amazon.com/neuron resource limit and automatically injects the aws.amazon.com/neuron:NoSchedule toleration when trainium is set.
metaflow/plugins/kubernetes/kubernetes_jobsets.py Correctly adds aws.amazon.com/neuron resource limit and auto-injects the Neuron toleration for the parallel JobSet path, consistent with kubernetes_job.py.
metaflow/plugins/argo/argo_workflows.py Adds Neuron resource limit to the non-parallel pod spec and threads trainium through to the JobSet path, but the non-parallel path omits the required aws.amazon.com/neuron:NoSchedule toleration — pods will fail to schedule on Neuron nodes.
metaflow/plugins/airflow/airflow.py Adds Neuron resource limit to the resources dict but never adds a matching aws.amazon.com/neuron:NoSchedule toleration to the Airflow operator args, so pods will remain pending on tainted Neuron nodes.
metaflow/plugins/kubernetes/kubernetes_cli.py Adds --trainium CLI option and threads it through to the step command correctly.
metaflow/plugins/kubernetes/kubernetes.py Adds trainium parameter to both create_job and create_jobset methods and forwards it to the job/jobset constructors correctly.

Comments Outside Diff (1)

  1. metaflow/plugins/argo/argo_workflows.py, line 2796 (link)

    P1 Missing automatic Neuron toleration in non-parallel Argo Workflows path

    The non-parallel (non-JobSet) Argo Workflows pod spec adds the aws.amazon.com/neuron resource limit (line ~2871) but does not inject the corresponding aws.amazon.com/neuron:NoSchedule toleration. Trainium/Inferentia nodes carry that taint by default, so any pod that reaches this code path with trainium=N will remain in Pending state — it will never be scheduled.

    The JobSet path correctly auto-injects the toleration (via kubernetes_jobsets.py), and kubernetes_job.py does the same. The fix is to extend the toleration list here analogously:

    .tolerations(
        (resources.get("tolerations") or [])
        + (
            [{"key": "aws.amazon.com/neuron", "operator": "Exists", "effect": "NoSchedule"}]
            if resources.get("trainium") is not None
            else []
        )
    )

@codecov

codecov Bot commented Apr 8, 2026

Welcome to Codecov 🎉

Once you merge this PR into your default branch, you're all set! Codecov will compare coverage reports and display results in all future pull requests.

Thanks for integrating Codecov - We've got you covered ☂️

@emattia emattia changed the title Add trainium parameter to @kubernetes decorator Add trainium, inferentia, and efa parameters to @kubernetes decorator May 4, 2026
emattia added 7 commits May 5, 2026 18:25