From aede1ce43a5e85b0ffe31bd7d50f544894ee7017 Mon Sep 17 00:00:00 2001 From: Eddie Mattia Date: Tue, 7 Apr 2026 16:27:36 -0700 Subject: [PATCH 1/7] Add trainium parameter to @kubernetes decorator --- metaflow/plugins/airflow/airflow.py | 5 +++++ metaflow/plugins/argo/argo_workflows.py | 8 +++++++ metaflow/plugins/kubernetes/kubernetes.py | 4 ++++ metaflow/plugins/kubernetes/kubernetes_cli.py | 3 +++ .../kubernetes/kubernetes_decorator.py | 22 +++++++++++++++++++ metaflow/plugins/kubernetes/kubernetes_job.py | 20 ++++++++++++++++- .../plugins/kubernetes/kubernetes_jobsets.py | 20 ++++++++++++++++- 7 files changed, 80 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py index 43ef65efacb..2c36d6d3ca6 100644 --- a/metaflow/plugins/airflow/airflow.py +++ b/metaflow/plugins/airflow/airflow.py @@ -449,6 +449,11 @@ def _to_job(self, node): # Don't set GPU limits if gpu isn't specified. if k8s_deco.attributes["gpu"] is not None }, + **{ + "aws.amazon.com/neuron": str(k8s_deco.attributes["trainium"]) + for k in [0] + if k8s_deco.attributes.get("trainium") is not None + }, }, ) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index b49c0f252f6..666b8f015a9 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -2636,6 +2636,7 @@ def _container_templates(self): disk=str(resources["disk"]), gpu=resources["gpu"], gpu_vendor=str(resources["gpu_vendor"]), + trainium=resources.get("trainium"), tolerations=resources["tolerations"], use_tmpfs=use_tmpfs, tmpfs_tempdir=tmpfs_tempdir, @@ -2874,6 +2875,13 @@ def _container_templates(self): for k in [0] if resources["gpu"] is not None }, + **{ + "aws.amazon.com/neuron": str( + resources["trainium"] + ) + for k in [0] + if resources.get("trainium") is not None + }, }, ), # Configure secrets diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index c19b3efe3b9..6e077386953 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -181,6 +181,7 @@ def create_jobset( cpu=None, gpu=None, gpu_vendor=None, + trainium=None, disk=None, memory=None, use_tmpfs=None, @@ -215,6 +216,7 @@ def create_jobset( disk=disk, gpu=gpu, gpu_vendor=gpu_vendor, + trainium=trainium, timeout_in_seconds=run_time_limit, # Retries are handled by Metaflow runtime retries=0, @@ -482,6 +484,7 @@ def create_job_object( cpu=None, gpu=None, gpu_vendor=None, + trainium=None, disk=None, memory=None, use_tmpfs=None, @@ -528,6 +531,7 @@ def create_job_object( disk=disk, gpu=gpu, gpu_vendor=gpu_vendor, + trainium=trainium, timeout_in_seconds=run_time_limit, # Retries are handled by Metaflow runtime retries=0, diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index e15f7b06cb9..8b321ed4f31 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -89,6 +89,7 @@ def kubernetes(): @click.option("--memory", help="Memory requirement for Kubernetes pod.") @click.option("--gpu", help="GPU requirement for Kubernetes pod.") @click.option("--gpu-vendor", help="GPU vendor requirement for Kubernetes pod.") +@click.option("--trainium", help="AWS Trainium/Inferentia Neuron device requirement for Kubernetes pod.") @click.option("--run-id", help="Passed to the top-level 'step'.") @click.option("--task-id", help="Passed to the top-level 'step'.") @click.option("--input-paths", help="Passed to the top-level 'step'.") @@ -178,6 +179,7 @@ def step( memory=None, gpu=None, gpu_vendor=None, + trainium=None, use_tmpfs=None, tmpfs_tempdir=None, tmpfs_size=None, @@ -323,6 +325,7 @@ def _sync_metadata(): memory=memory, gpu=gpu, gpu_vendor=gpu_vendor, + trainium=trainium, use_tmpfs=use_tmpfs, tmpfs_tempdir=tmpfs_tempdir, tmpfs_size=tmpfs_size, diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index bd3ae7e12c4..425170dccd3 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -152,6 +152,7 @@ class KubernetesDecorator(StepDecorator): "namespace": None, "gpu": None, # value of 0 implies that the scheduled node should not have GPUs "gpu_vendor": None, + "trainium": None, # number of AWS Trainium/Inferentia Neuron devices "tolerations": None, # e.g., [{"key": "arch", "operator": "Equal", "value": "amd"}, # {"key": "foo", "operator": "Equal", "value": "bar"}] "labels": None, # e.g. {"test-label": "value", "another-label":"value2"} @@ -382,6 +383,17 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge max(float(my_val or 0), float(v or 0)) ) + # Validate mutually exclusive: gpu and trainium cannot both be set. + if ( + self.attributes["trainium"] is not None + and self.attributes["gpu"] is not None + ): + raise KubernetesException( + "Cannot specify both 'gpu' and 'trainium' for step *{step}*.".format( + step=step + ) + ) + # Check GPU vendor. if self.attributes["gpu_vendor"].lower() not in ("amd", "nvidia"): raise KubernetesException( @@ -412,6 +424,16 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge ) ) + if self.attributes["trainium"] is not None and not ( + isinstance(self.attributes["trainium"], (int, unicode, basestring)) + and float(self.attributes["trainium"]).is_integer() + ): + raise KubernetesException( + "Invalid trainium value *{}* for step *{step}*; it should be an integer".format( + self.attributes["trainium"], step=step + ) + ) + if self.attributes["tmpfs_size"]: if not ( isinstance(self.attributes["tmpfs_size"], (int, unicode, basestring)) diff --git a/metaflow/plugins/kubernetes/kubernetes_job.py b/metaflow/plugins/kubernetes/kubernetes_job.py index b81777bcc7b..e288c6ebd09 100644 --- a/metaflow/plugins/kubernetes/kubernetes_job.py +++ b/metaflow/plugins/kubernetes/kubernetes_job.py @@ -182,6 +182,13 @@ def create_job_spec(self): # Don't set GPU limits if gpu isn't specified. if self._kwargs["gpu"] is not None }, + **{ + "aws.amazon.com/neuron": str( + self._kwargs["trainium"] + ) + for k in [0] + if self._kwargs.get("trainium") is not None + }, }, ), volume_mounts=( @@ -236,7 +243,18 @@ def create_job_spec(self): tolerations=[ client.V1Toleration(**toleration) for toleration in self._kwargs.get("tolerations") or [] - ], + ] + + ( + [ + client.V1Toleration( + key="aws.amazon.com/neuron", + operator="Exists", + effect="NoSchedule", + ) + ] + if self._kwargs.get("trainium") is not None + else [] + ), volumes=( [ client.V1Volume( diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index da0f0fc3130..912743f83b1 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -679,6 +679,13 @@ def dump(self): # Don't set GPU limits if gpu isn't specified. if self._kwargs["gpu"] is not None }, + **{ + "aws.amazon.com/neuron": str( + self._kwargs["trainium"] + ) + for k in [0] + if self._kwargs.get("trainium") is not None + }, }, ), volume_mounts=( @@ -740,7 +747,18 @@ def dump(self): client.V1Toleration(**toleration) for toleration in self._kwargs.get("tolerations") or [] - ], + ] + + ( + [ + client.V1Toleration( + key="aws.amazon.com/neuron", + operator="Exists", + effect="NoSchedule", + ) + ] + if self._kwargs.get("trainium") is not None + else [] + ), volumes=( [ client.V1Volume( From 4f1abfc2e1a13338c3439ff0a8ada835aa10f5b3 Mon Sep 17 00:00:00 2001 From: Eddie Mattia Date: Tue, 7 Apr 2026 18:01:36 -0700 Subject: [PATCH 2/7] Fix neuron toleration gaps and trainium=0 validation --- metaflow/plugins/airflow/airflow.py | 7 +++++++ metaflow/plugins/argo/argo_workflows.py | 9 ++++++++- metaflow/plugins/kubernetes/kubernetes_decorator.py | 3 ++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py index 2c36d6d3ca6..bee60760dbb 100644 --- a/metaflow/plugins/airflow/airflow.py +++ b/metaflow/plugins/airflow/airflow.py @@ -506,6 +506,13 @@ def _to_job(self, node): retry_exponential_backoff=False, # todo : should this be a arg we allow on CLI. not right now - there is an open ticket for this - maybe at some point we will. reattach_on_restart=False, secrets=[], + tolerations=( + [{"key": "aws.amazon.com/neuron", "operator": "Exists", "effect": "NoSchedule"}] + if k8s_deco.attributes.get("trainium") is not None + else [] + ) + ( + k8s_deco.attributes.get("tolerations") or [] + ), ) k8s_operator_args["in_cluster"] = True if AIRFLOW_KUBERNETES_CONN_ID is not None: diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 666b8f015a9..7256a9dc0a9 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -2801,7 +2801,14 @@ def _container_templates(self): # Set node selectors .node_selectors(resources.get("node_selector")) # Set tolerations - .tolerations(resources.get("tolerations")) + .tolerations( + (resources.get("tolerations") or []) + + ( + [{"key": "aws.amazon.com/neuron", "operator": "Exists", "effect": "NoSchedule"}] + if resources.get("trainium") is not None + else [] + ) + ) # Set image pull secrets if present. We need to use pod_spec_patch due to Argo not supporting this on a template level. .pod_spec_patch( { diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 425170dccd3..a089500e5da 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -427,9 +427,10 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge if self.attributes["trainium"] is not None and not ( isinstance(self.attributes["trainium"], (int, unicode, basestring)) and float(self.attributes["trainium"]).is_integer() + and int(float(self.attributes["trainium"])) > 0 ): raise KubernetesException( - "Invalid trainium value *{}* for step *{step}*; it should be an integer".format( + "Invalid trainium value *{}* for step *{step}*; it should be a positive integer".format( self.attributes["trainium"], step=step ) ) From 827420bd633d4ba68d03510b27a1c5ad96a7d2a2 Mon Sep 17 00:00:00 2001 From: Eddie Mattia Date: Tue, 7 Apr 2026 18:04:51 -0700 Subject: [PATCH 3/7] format for pre-commit --- metaflow/plugins/airflow/airflow.py | 13 +++++++++---- metaflow/plugins/argo/argo_workflows.py | 8 +++++++- metaflow/plugins/kubernetes/kubernetes_cli.py | 5 ++++- metaflow/plugins/kubernetes/kubernetes_jobsets.py | 3 ++- 4 files changed, 22 insertions(+), 7 deletions(-) diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py index bee60760dbb..662082ad3c9 100644 --- a/metaflow/plugins/airflow/airflow.py +++ b/metaflow/plugins/airflow/airflow.py @@ -507,12 +507,17 @@ def _to_job(self, node): reattach_on_restart=False, secrets=[], tolerations=( - [{"key": "aws.amazon.com/neuron", "operator": "Exists", "effect": "NoSchedule"}] + [ + { + "key": "aws.amazon.com/neuron", + "operator": "Exists", + "effect": "NoSchedule", + } + ] if k8s_deco.attributes.get("trainium") is not None else [] - ) + ( - k8s_deco.attributes.get("tolerations") or [] - ), + ) + + (k8s_deco.attributes.get("tolerations") or []), ) k8s_operator_args["in_cluster"] = True if AIRFLOW_KUBERNETES_CONN_ID is not None: diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 7256a9dc0a9..6bf2a01438a 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -2804,7 +2804,13 @@ def _container_templates(self): .tolerations( (resources.get("tolerations") or []) + ( - [{"key": "aws.amazon.com/neuron", "operator": "Exists", "effect": "NoSchedule"}] + [ + { + "key": "aws.amazon.com/neuron", + "operator": "Exists", + "effect": "NoSchedule", + } + ] if resources.get("trainium") is not None else [] ) diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 8b321ed4f31..490dca4e261 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -89,7 +89,10 @@ def kubernetes(): @click.option("--memory", help="Memory requirement for Kubernetes pod.") @click.option("--gpu", help="GPU requirement for Kubernetes pod.") @click.option("--gpu-vendor", help="GPU vendor requirement for Kubernetes pod.") -@click.option("--trainium", help="AWS Trainium/Inferentia Neuron device requirement for Kubernetes pod.") +@click.option( + "--trainium", + help="AWS Trainium/Inferentia Neuron device requirement for Kubernetes pod.", +) @click.option("--run-id", help="Passed to the top-level 'step'.") @click.option("--task-id", help="Passed to the top-level 'step'.") @click.option("--input-paths", help="Passed to the top-level 'step'.") diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index 912743f83b1..918b61ceacb 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -684,7 +684,8 @@ def dump(self): self._kwargs["trainium"] ) for k in [0] - if self._kwargs.get("trainium") is not None + if self._kwargs.get("trainium") + is not None }, }, ), From 6022f18e7dd7d9f6795879dbd1400822b3febe8e Mon Sep 17 00:00:00 2001 From: Eddie Mattia Date: Mon, 4 May 2026 10:26:45 -0700 Subject: [PATCH 4/7] Add efa parameter to @kubernetes decorator; Mirror the existing @batch(efa=N) parameter on @kubernetes. When efa is set, the pod requests N vpc.amazonaws.com/efa resources, advertised by the AWS EFA k8s device plugin on EFA-enabled nodes. Plumbed through to argo and airflow runtimes consistently with how trainium= is. --- metaflow/plugins/airflow/airflow.py | 5 +++++ metaflow/plugins/argo/argo_workflows.py | 8 ++++++++ metaflow/plugins/kubernetes/kubernetes.py | 4 ++++ metaflow/plugins/kubernetes/kubernetes_cli.py | 6 ++++++ metaflow/plugins/kubernetes/kubernetes_decorator.py | 12 ++++++++++++ metaflow/plugins/kubernetes/kubernetes_job.py | 7 +++++++ metaflow/plugins/kubernetes/kubernetes_jobsets.py | 8 ++++++++ 7 files changed, 50 insertions(+) diff --git a/metaflow/plugins/airflow/airflow.py b/metaflow/plugins/airflow/airflow.py index 662082ad3c9..09ac2646d74 100644 --- a/metaflow/plugins/airflow/airflow.py +++ b/metaflow/plugins/airflow/airflow.py @@ -454,6 +454,11 @@ def _to_job(self, node): for k in [0] if k8s_deco.attributes.get("trainium") is not None }, + **{ + "vpc.amazonaws.com/efa": str(k8s_deco.attributes["efa"]) + for k in [0] + if k8s_deco.attributes.get("efa") is not None + }, }, ) diff --git a/metaflow/plugins/argo/argo_workflows.py b/metaflow/plugins/argo/argo_workflows.py index 6bf2a01438a..075ff49b5f5 100644 --- a/metaflow/plugins/argo/argo_workflows.py +++ b/metaflow/plugins/argo/argo_workflows.py @@ -2637,6 +2637,7 @@ def _container_templates(self): gpu=resources["gpu"], gpu_vendor=str(resources["gpu_vendor"]), trainium=resources.get("trainium"), + efa=resources.get("efa"), tolerations=resources["tolerations"], use_tmpfs=use_tmpfs, tmpfs_tempdir=tmpfs_tempdir, @@ -2895,6 +2896,13 @@ def _container_templates(self): for k in [0] if resources.get("trainium") is not None }, + **{ + "vpc.amazonaws.com/efa": str( + resources["efa"] + ) + for k in [0] + if resources.get("efa") is not None + }, }, ), # Configure secrets diff --git a/metaflow/plugins/kubernetes/kubernetes.py b/metaflow/plugins/kubernetes/kubernetes.py index 6e077386953..83fa17767c6 100644 --- a/metaflow/plugins/kubernetes/kubernetes.py +++ b/metaflow/plugins/kubernetes/kubernetes.py @@ -182,6 +182,7 @@ def create_jobset( gpu=None, gpu_vendor=None, trainium=None, + efa=None, disk=None, memory=None, use_tmpfs=None, @@ -217,6 +218,7 @@ def create_jobset( gpu=gpu, gpu_vendor=gpu_vendor, trainium=trainium, + efa=efa, timeout_in_seconds=run_time_limit, # Retries are handled by Metaflow runtime retries=0, @@ -485,6 +487,7 @@ def create_job_object( gpu=None, gpu_vendor=None, trainium=None, + efa=None, disk=None, memory=None, use_tmpfs=None, @@ -532,6 +535,7 @@ def create_job_object( gpu=gpu, gpu_vendor=gpu_vendor, trainium=trainium, + efa=efa, timeout_in_seconds=run_time_limit, # Retries are handled by Metaflow runtime retries=0, diff --git a/metaflow/plugins/kubernetes/kubernetes_cli.py b/metaflow/plugins/kubernetes/kubernetes_cli.py index 490dca4e261..3237fd6d4e3 100644 --- a/metaflow/plugins/kubernetes/kubernetes_cli.py +++ b/metaflow/plugins/kubernetes/kubernetes_cli.py @@ -93,6 +93,10 @@ def kubernetes(): "--trainium", help="AWS Trainium/Inferentia Neuron device requirement for Kubernetes pod.", ) +@click.option( + "--efa", + help="Number of Elastic Fabric Adapter network interfaces for Kubernetes pod.", +) @click.option("--run-id", help="Passed to the top-level 'step'.") @click.option("--task-id", help="Passed to the top-level 'step'.") @click.option("--input-paths", help="Passed to the top-level 'step'.") @@ -183,6 +187,7 @@ def step( gpu=None, gpu_vendor=None, trainium=None, + efa=None, use_tmpfs=None, tmpfs_tempdir=None, tmpfs_size=None, @@ -329,6 +334,7 @@ def _sync_metadata(): gpu=gpu, gpu_vendor=gpu_vendor, trainium=trainium, + efa=efa, use_tmpfs=use_tmpfs, tmpfs_tempdir=tmpfs_tempdir, tmpfs_size=tmpfs_size, diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index a089500e5da..78ee6284907 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -153,6 +153,7 @@ class KubernetesDecorator(StepDecorator): "gpu": None, # value of 0 implies that the scheduled node should not have GPUs "gpu_vendor": None, "trainium": None, # number of AWS Trainium/Inferentia Neuron devices + "efa": None, # number of Elastic Fabric Adapter network interfaces "tolerations": None, # e.g., [{"key": "arch", "operator": "Equal", "value": "amd"}, # {"key": "foo", "operator": "Equal", "value": "bar"}] "labels": None, # e.g. {"test-label": "value", "another-label":"value2"} @@ -435,6 +436,17 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge ) ) + if self.attributes["efa"] is not None and not ( + isinstance(self.attributes["efa"], (int, unicode, basestring)) + and float(self.attributes["efa"]).is_integer() + and int(float(self.attributes["efa"])) > 0 + ): + raise KubernetesException( + "Invalid efa value *{}* for step *{step}*; it should be a positive integer".format( + self.attributes["efa"], step=step + ) + ) + if self.attributes["tmpfs_size"]: if not ( isinstance(self.attributes["tmpfs_size"], (int, unicode, basestring)) diff --git a/metaflow/plugins/kubernetes/kubernetes_job.py b/metaflow/plugins/kubernetes/kubernetes_job.py index e288c6ebd09..eae406a3985 100644 --- a/metaflow/plugins/kubernetes/kubernetes_job.py +++ b/metaflow/plugins/kubernetes/kubernetes_job.py @@ -189,6 +189,13 @@ def create_job_spec(self): for k in [0] if self._kwargs.get("trainium") is not None }, + **{ + "vpc.amazonaws.com/efa": str( + self._kwargs["efa"] + ) + for k in [0] + if self._kwargs.get("efa") is not None + }, }, ), volume_mounts=( diff --git a/metaflow/plugins/kubernetes/kubernetes_jobsets.py b/metaflow/plugins/kubernetes/kubernetes_jobsets.py index 918b61ceacb..9b814de422b 100644 --- a/metaflow/plugins/kubernetes/kubernetes_jobsets.py +++ b/metaflow/plugins/kubernetes/kubernetes_jobsets.py @@ -687,6 +687,14 @@ def dump(self): if self._kwargs.get("trainium") is not None }, + **{ + "vpc.amazonaws.com/efa": str( + self._kwargs["efa"] + ) + for k in [0] + if self._kwargs.get("efa") + is not None + }, }, ), volume_mounts=( From 798d1cb4e1e51fd9df430393bfb30323ac25315b Mon Sep 17 00:00:00 2001 From: Eddie Mattia Date: Mon, 4 May 2026 10:44:40 -0700 Subject: [PATCH 5/7] document trainium/efa in docstring --- metaflow/plugins/kubernetes/kubernetes_decorator.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index 78ee6284907..ab6c879740f 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -96,6 +96,17 @@ class KubernetesDecorator(StepDecorator): the scheduled node should not have GPUs. gpu_vendor : str, default KUBERNETES_GPU_VENDOR The vendor of the GPUs to be used for this step. + trainium : int, optional, default None + Number of AWS Trainium / Inferentia Neuron devices required for this + step. Maps to the `aws.amazon.com/neuron` Kubernetes resource managed + by the AWS Neuron device plugin -- same resource regardless of whether + the underlying chip is Trainium or Inferentia, since they share the + device-plugin / AMI / runtime stack. + efa : int, optional, default None + Number of AWS Elastic Fabric Adapter network interfaces required for + this step. Maps to the `vpc.amazonaws.com/efa` Kubernetes resource + managed by the AWS EFA device plugin. Only valid on EFA-capable + instance types where the pool was provisioned with EFA NICs. tolerations : List[Dict[str,str]], default [] The default is extracted from METAFLOW_KUBERNETES_TOLERATIONS. Kubernetes tolerations to use when launching pod in Kubernetes. From 4415bdb8cf19fd597ba1833dbfbd6f3c70735b93 Mon Sep 17 00:00:00 2001 From: Eddie Mattia Date: Mon, 4 May 2026 11:57:22 -0700 Subject: [PATCH 6/7] alias inferentia to trainium like @batch does --- .../kubernetes/kubernetes_decorator.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index ab6c879740f..b6e2961bccb 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -102,6 +102,9 @@ class KubernetesDecorator(StepDecorator): by the AWS Neuron device plugin -- same resource regardless of whether the underlying chip is Trainium or Inferentia, since they share the device-plugin / AMI / runtime stack. + inferentia : int, optional, default None + Alias for `trainium`. Use only one of the two. Provided for API + consistency with `@batch(inferentia=...)`. efa : int, optional, default None Number of AWS Elastic Fabric Adapter network interfaces required for this step. Maps to the `vpc.amazonaws.com/efa` Kubernetes resource @@ -164,6 +167,7 @@ class KubernetesDecorator(StepDecorator): "gpu": None, # value of 0 implies that the scheduled node should not have GPUs "gpu_vendor": None, "trainium": None, # number of AWS Trainium/Inferentia Neuron devices + "inferentia": None, # alias for trainium; both map to aws.amazon.com/neuron "efa": None, # number of Elastic Fabric Adapter network interfaces "tolerations": None, # e.g., [{"key": "arch", "operator": "Equal", "value": "amd"}, # {"key": "foo", "operator": "Equal", "value": "bar"}] @@ -395,6 +399,23 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge max(float(my_val or 0), float(v or 0)) ) + # Alias inferentia to trainium and check that both are not in use. + # `trainium` is canonical on @kubernetes (the underlying Neuron device + # plugin advertises a single `aws.amazon.com/neuron` resource for both + # chip families). `inferentia` is provided for API consistency with + # `@batch(inferentia=...)` -- it collapses into `trainium` and is + # popped from the wire format before any runtime translation. + if ( + self.attributes["inferentia"] is not None + and self.attributes["trainium"] is not None + ): + raise KubernetesException( + "only specify a value for 'inferentia' or 'trainium', not both." + ) + if self.attributes["inferentia"] is not None: + self.attributes["trainium"] = self.attributes["inferentia"] + self.attributes.pop("inferentia", None) + # Validate mutually exclusive: gpu and trainium cannot both be set. if ( self.attributes["trainium"] is not None From 9b816f1943f7e6446952f0517045ffd5a84a0c75 Mon Sep 17 00:00:00 2001 From: Eddie Mattia Date: Mon, 4 May 2026 13:28:33 -0700 Subject: [PATCH 7/7] Make @kubernetes inferentia-trainium alias resolution idempotent --- metaflow/plugins/kubernetes/kubernetes_decorator.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/metaflow/plugins/kubernetes/kubernetes_decorator.py b/metaflow/plugins/kubernetes/kubernetes_decorator.py index b6e2961bccb..c33feeae8f7 100644 --- a/metaflow/plugins/kubernetes/kubernetes_decorator.py +++ b/metaflow/plugins/kubernetes/kubernetes_decorator.py @@ -403,18 +403,17 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge # `trainium` is canonical on @kubernetes (the underlying Neuron device # plugin advertises a single `aws.amazon.com/neuron` resource for both # chip families). `inferentia` is provided for API consistency with - # `@batch(inferentia=...)` -- it collapses into `trainium` and is - # popped from the wire format before any runtime translation. + # `@batch(inferentia=...)` -- it collapses into `trainium` here. if ( - self.attributes["inferentia"] is not None - and self.attributes["trainium"] is not None + self.attributes.get("inferentia") is not None + and self.attributes.get("trainium") is not None ): raise KubernetesException( "only specify a value for 'inferentia' or 'trainium', not both." ) - if self.attributes["inferentia"] is not None: + if self.attributes.get("inferentia") is not None: self.attributes["trainium"] = self.attributes["inferentia"] - self.attributes.pop("inferentia", None) + self.attributes["inferentia"] = None # Validate mutually exclusive: gpu and trainium cannot both be set. if (