From e0321f1bf53c69114556659983eae5f4e882e8a7 Mon Sep 17 00:00:00 2001 From: komikndr Date: Wed, 28 Jan 2026 17:45:41 +0700 Subject: [PATCH 01/10] c10d implementation for fp8 --- comfy_kitchen/tensor/fp8.py | 58 +++++++++++++++++++++++++++++++++++-- 1 file changed, 55 insertions(+), 3 deletions(-) diff --git a/comfy_kitchen/tensor/fp8.py b/comfy_kitchen/tensor/fp8.py index 2c20bc1..5f9b998 100644 --- a/comfy_kitchen/tensor/fp8.py +++ b/comfy_kitchen/tensor/fp8.py @@ -8,7 +8,6 @@ import torch import comfy_kitchen as ck -from comfy_kitchen.scaled_mm_v2 import scaled_mm_v2 from .base import BaseLayoutParams, QuantizedLayout, dequantize_args, register_layout_op @@ -91,14 +90,17 @@ def _fp8_scaled_mm( bias: torch.Tensor | None = None, out_dtype: torch.dtype | None = None, ) -> torch.Tensor: - return scaled_mm_v2( + """Core FP8 scaled matrix multiplication using torch._scaled_mm.""" + output = torch._scaled_mm( input_qdata.contiguous(), weight_qdata, + bias=bias, scale_a=scale_a, scale_b=scale_b, - bias=bias, out_dtype=out_dtype, ) + # Handle tuple return for older PyTorch versions + return output[0] if isinstance(output, tuple) else output def _make_fp8_shape_handler(aten_op): @@ -211,6 +213,56 @@ def _handle_fp8_addmm(qt, args, kwargs): return torch.addmm(*dequantize_args(args)) +# ==================== Distributed Operations ==================== + +@register_layout_op(torch.ops._c10d_functional.all_gather_into_tensor.default, TensorCoreFP8Layout) +@register_layout_op(torch.ops.c10d_functional.all_gather_into_tensor.default, TensorCoreFP8Layout) +def _handle_all_gather(qt, args, kwargs): + from .base import QuantizedTensor + + input_tensor = None + input_idx = None + for i, arg in enumerate(args): + if isinstance(arg, QuantizedTensor): + input_tensor = arg + input_idx = i + + assert input_tensor is not None + + qdata = input_tensor._qdata + layout_cls = input_tensor._layout_cls + params = input_tensor._params + + qdata_bytes = qdata.contiguous().view(torch.uint8) + + 
new_args = list(args) + new_args[input_idx] = qdata_bytes + + gathered_bytes = torch.ops._c10d_functional.all_gather_into_tensor.default( + *new_args, **kwargs + ) + + gathered_qdata = gathered_bytes.view(qdata.dtype) + return QuantizedTensor(gathered_qdata, layout_cls, params) + + +@register_layout_op(torch.ops._c10d_functional.wait_tensor.default, TensorCoreFP8Layout) +@register_layout_op(torch.ops.c10d_functional.wait_tensor.default, TensorCoreFP8Layout) +def _handle_wait_tensor(qt, args, kwargs): + from .base import QuantizedTensor + + qtensor = args[0] + + waited_bytes = torch.ops._c10d_functional.wait_tensor.default( + qtensor._qdata.view(torch.uint8), + *args[1:], + **kwargs, + ) + + waited_qdata = waited_bytes.view(qtensor._qdata.dtype) + return QuantizedTensor(waited_qdata, qtensor._layout_cls, qtensor._params) + + # ==================== FP8 Shape Operations ==================== # These preserve quantization since FP8 is not packed (1:1 element mapping) From f6b8fb70f536b390886f2bb4320da2aba569f1fb Mon Sep 17 00:00:00 2001 From: komikndr Date: Sat, 31 Jan 2026 17:31:28 +0700 Subject: [PATCH 02/10] add additional ops require for sharding --- comfy_kitchen/tensor/fp8.py | 53 +++++++++++++++++++++++++++++++++---- 1 file changed, 48 insertions(+), 5 deletions(-) diff --git a/comfy_kitchen/tensor/fp8.py b/comfy_kitchen/tensor/fp8.py index 5f9b998..23d8980 100644 --- a/comfy_kitchen/tensor/fp8.py +++ b/comfy_kitchen/tensor/fp8.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -from dataclasses import dataclass +from dataclasses import dataclass, replace from typing import TYPE_CHECKING import torch @@ -214,9 +214,10 @@ def _handle_fp8_addmm(qt, args, kwargs): # ==================== Distributed Operations ==================== +# Required c10d ops : c10d allgather, c10d wait +# Required aten ops : slice, split, new_zeros, as_strided @register_layout_op(torch.ops._c10d_functional.all_gather_into_tensor.default, TensorCoreFP8Layout) 
-@register_layout_op(torch.ops.c10d_functional.all_gather_into_tensor.default, TensorCoreFP8Layout) def _handle_all_gather(qt, args, kwargs): from .base import QuantizedTensor @@ -243,11 +244,11 @@ def _handle_all_gather(qt, args, kwargs): ) gathered_qdata = gathered_bytes.view(qdata.dtype) - return QuantizedTensor(gathered_qdata, layout_cls, params) + gathered_params = replace(params, orig_shape=tuple(gathered_qdata.shape)) + return QuantizedTensor(gathered_qdata, layout_cls, gathered_params) @register_layout_op(torch.ops._c10d_functional.wait_tensor.default, TensorCoreFP8Layout) -@register_layout_op(torch.ops.c10d_functional.wait_tensor.default, TensorCoreFP8Layout) def _handle_wait_tensor(qt, args, kwargs): from .base import QuantizedTensor @@ -260,7 +261,47 @@ def _handle_wait_tensor(qt, args, kwargs): ) waited_qdata = waited_bytes.view(qtensor._qdata.dtype) - return QuantizedTensor(waited_qdata, qtensor._layout_cls, qtensor._params) + waited_params = replace(qtensor._params, orig_shape=tuple(waited_qdata.shape)) + return QuantizedTensor(waited_qdata, qtensor._layout_cls, waited_params) + + +def _wrap_fp8_tensor(qtensor, qdata): + from .base import QuantizedTensor + + new_params = TensorCoreFP8Layout.Params( + scale=qtensor._params.scale, + orig_dtype=qtensor._params.orig_dtype, + orig_shape=tuple(qdata.shape), + ) + return QuantizedTensor(qdata, qtensor._layout_cls, new_params) + + +@register_layout_op(torch.ops.aten.slice.Tensor, TensorCoreFP8Layout) +def _handle_fp8_slice_tensor(qt, args, kwargs): + from .base import QuantizedTensor + + input_tensor = args[0] + if not isinstance(input_tensor, QuantizedTensor): + return torch.ops.aten.slice.Tensor(*args, **kwargs) + + sliced_qdata = torch.ops.aten.slice.Tensor(input_tensor._qdata, *args[1:], **kwargs) + return _wrap_fp8_tensor(input_tensor, sliced_qdata) + + +@register_layout_op(torch.ops.aten.split.Tensor, TensorCoreFP8Layout) +def _handle_fp8_split(qt, args, kwargs): + from .base import QuantizedTensor + + 
input_tensor = args[0] + if not isinstance(input_tensor, QuantizedTensor): + return torch.ops.aten.split.Tensor(*args, **kwargs) + + qdata_chunks = torch.ops.aten.split.Tensor(input_tensor._qdata, *args[1:], **kwargs) + wrapped_chunks = tuple( + _wrap_fp8_tensor(input_tensor, chunk) + for chunk in qdata_chunks + ) + return wrapped_chunks # ==================== FP8 Shape Operations ==================== @@ -270,5 +311,7 @@ def _handle_wait_tensor(qt, args, kwargs): torch.ops.aten.view.default, torch.ops.aten.reshape.default, torch.ops.aten.t.default, + torch.ops.aten.as_strided.default, + torch.ops.aten.new_zeros.default, ): register_layout_op(_aten_op, TensorCoreFP8Layout)(_make_fp8_shape_handler(_aten_op)) From bc5d7ddc6625c560dcb367c471de1f4597d06c16 Mon Sep 17 00:00:00 2001 From: komikndr Date: Sun, 1 Feb 2026 01:15:31 +0700 Subject: [PATCH 03/10] aten ops alias, and broadcast --- comfy_kitchen/tensor/fp8.py | 79 ++++++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 2 deletions(-) diff --git a/comfy_kitchen/tensor/fp8.py b/comfy_kitchen/tensor/fp8.py index 23d8980..18a7784 100644 --- a/comfy_kitchen/tensor/fp8.py +++ b/comfy_kitchen/tensor/fp8.py @@ -214,8 +214,8 @@ def _handle_fp8_addmm(qt, args, kwargs): # ==================== Distributed Operations ==================== -# Required c10d ops : c10d allgather, c10d wait -# Required aten ops : slice, split, new_zeros, as_strided +# Required c10d ops : c10d allgather, c10d wait, c10d broadcast (this is for broadcast_rank0=True) +# Required aten ops : slice, split, new_zeros, as_strided, cat, alias @register_layout_op(torch.ops._c10d_functional.all_gather_into_tensor.default, TensorCoreFP8Layout) def _handle_all_gather(qt, args, kwargs): @@ -265,6 +265,58 @@ def _handle_wait_tensor(qt, args, kwargs): return QuantizedTensor(waited_qdata, qtensor._layout_cls, waited_params) +@register_layout_op(torch.ops.c10d.broadcast_.default, TensorCoreFP8Layout) +def _handle_broadcast(qt, args, kwargs): + from 
.base import QuantizedTensor + import torch + + tensor_list = args[0] + + input_tensor = None + input_idx = None + for idx, t in enumerate(tensor_list): + if isinstance(t, QuantizedTensor): + input_tensor = t + input_idx = idx + break + + if input_tensor is None: + return torch.ops.c10d.broadcast_.default(*args, **kwargs) + + qdata = input_tensor._qdata.contiguous() + qdata_bytes = qdata.view(torch.uint8) + + new_tensor_list = list(tensor_list) + new_tensor_list[input_idx] = qdata_bytes + + new_args = list(args) + new_args[0] = new_tensor_list + + broadcasted = torch.ops.c10d.broadcast_.default( + *new_args, + **kwargs, + ) + + if isinstance(broadcasted, tuple): + tensor_list_out, work = broadcasted + else: + tensor_list_out = broadcasted + work = None + + broadcasted_qdata = tensor_list_out[input_idx].view(qdata.dtype) + + new_out_list = list(tensor_list_out) + new_out_list[input_idx] = _wrap_fp8_tensor( + input_tensor, + broadcasted_qdata, + ) + + if work is not None: + return new_out_list, work + else: + return new_out_list + + def _wrap_fp8_tensor(qtensor, qdata): from .base import QuantizedTensor @@ -304,6 +356,28 @@ def _handle_fp8_split(qt, args, kwargs): return wrapped_chunks +@register_layout_op(torch.ops.aten.cat.default, TensorCoreFP8Layout) +def _handle_fp8_cat(qt, args, kwargs): + from .base import QuantizedTensor + + tensors = args[0] + if not isinstance(tensors, (list, tuple)) or not tensors: + return torch.ops.aten.cat.default(*args, **kwargs) + + qdata_list = [] + first_qtensor = None + for item in tensors: + if not isinstance(item, QuantizedTensor): + return torch.ops.aten.cat.default(*args, **kwargs) + qdata_list.append(item._qdata) + if first_qtensor is None: + first_qtensor = item + + assert first_qtensor is not None + concatenated_qdata = torch.ops.aten.cat.default(qdata_list, *args[1:], **kwargs) + return _wrap_fp8_tensor(first_qtensor, concatenated_qdata) + + # ==================== FP8 Shape Operations ==================== # These preserve 
quantization since FP8 is not packed (1:1 element mapping) @@ -313,5 +387,6 @@ def _handle_fp8_split(qt, args, kwargs): torch.ops.aten.t.default, torch.ops.aten.as_strided.default, torch.ops.aten.new_zeros.default, + torch.ops.aten.alias.default ): register_layout_op(_aten_op, TensorCoreFP8Layout)(_make_fp8_shape_handler(_aten_op)) From d5538c62ab696480a940ce1b1407345379962fd6 Mon Sep 17 00:00:00 2001 From: komikndr Date: Tue, 3 Feb 2026 00:16:19 +0700 Subject: [PATCH 04/10] Add scater ops, and rework copy to quantized dst to src format --- comfy_kitchen/tensor/base.py | 8 ++++-- comfy_kitchen/tensor/fp8.py | 56 ++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/comfy_kitchen/tensor/base.py b/comfy_kitchen/tensor/base.py index bd7ef7a..8dc0dbc 100644 --- a/comfy_kitchen/tensor/base.py +++ b/comfy_kitchen/tensor/base.py @@ -438,13 +438,15 @@ def _handle_is_contiguous(qt, args, kwargs): def _handle_copy_(qt, args, kwargs): dst, src = args[0], args[1] - if not isinstance(src, QuantizedTensor): - raise TypeError(f"Cannot copy {type(src).__name__} to QuantizedTensor") + non_blocking = kwargs.get("non_blocking", len(args) >= 3) + + if not isinstance(dst, QuantizedTensor): + dst = QuantizedTensor(src._qdata, src._layout_cls, src._params) + if dst._layout_cls != src._layout_cls: raise TypeError(f"Layout mismatch: {dst._layout_cls} vs {src._layout_cls}") dst_orig_dtype = dst._params.orig_dtype - non_blocking = kwargs.get("non_blocking", len(args) >= 3) dst._qdata.copy_(src._qdata, non_blocking=non_blocking) dst._params.copy_from(src._params, non_blocking=non_blocking) diff --git a/comfy_kitchen/tensor/fp8.py b/comfy_kitchen/tensor/fp8.py index 18a7784..40efd4a 100644 --- a/comfy_kitchen/tensor/fp8.py +++ b/comfy_kitchen/tensor/fp8.py @@ -328,6 +328,62 @@ def _wrap_fp8_tensor(qtensor, qdata): return QuantizedTensor(qdata, qtensor._layout_cls, new_params) +@register_layout_op(torch.ops.c10d.scatter_.default, 
TensorCoreFP8Layout) +def _handle_scatter(qt, args, kwargs): + from .base import QuantizedTensor + + output_tensors = args[0] + input_tensors = args[1] + + quantized_outputs: list[tuple[int, QuantizedTensor]] = [] + new_output_tensors = list(output_tensors) + for idx, tensor in enumerate(output_tensors): + if isinstance(tensor, QuantizedTensor): + quantized_outputs.append((idx, tensor)) + new_output_tensors[idx] = tensor._qdata.contiguous().view(torch.uint8) + + has_quantized_input = False + new_input_tensors: list[list[torch.Tensor]] = [] + + def process_input_list(entry): + nonlocal has_quantized_input + processed: list[torch.Tensor] = [] + for t in entry: + if isinstance(t, QuantizedTensor): + has_quantized_input = True + processed.append(t._qdata.contiguous().view(torch.uint8)) + else: + processed.append(t) + return processed + + for entry in input_tensors: + if isinstance(entry, (list, tuple)): + new_input_tensors.append(process_input_list(entry)) + else: + new_input_tensors.append(entry) # type: ignore[arg-type] + + if not quantized_outputs and not has_quantized_input: + return torch.ops.c10d.scatter_.default(*args, **kwargs) + + new_args = [new_output_tensors, new_input_tensors, *args[2:]] + result = torch.ops.c10d.scatter_.default(*new_args, **kwargs) + + if isinstance(result, tuple): + output_list, work = result + else: + output_list = result + work = None + + output_list = list(output_list) + for idx, original in quantized_outputs: + qdata = output_list[idx].view(original._qdata.dtype) + output_list[idx] = _wrap_fp8_tensor(original, qdata) + + if work is not None: + return output_list, work + return output_list + + @register_layout_op(torch.ops.aten.slice.Tensor, TensorCoreFP8Layout) def _handle_fp8_slice_tensor(qt, args, kwargs): from .base import QuantizedTensor From 1db1945296e08f583ef4a2d8f575640ee3d56773 Mon Sep 17 00:00:00 2001 From: komikndr Date: Sun, 15 Feb 2026 17:39:08 +0700 Subject: [PATCH 05/10] Add new_zeros op --- 
comfy_kitchen/tensor/fp8.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/comfy_kitchen/tensor/fp8.py b/comfy_kitchen/tensor/fp8.py index 40efd4a..11b3b23 100644 --- a/comfy_kitchen/tensor/fp8.py +++ b/comfy_kitchen/tensor/fp8.py @@ -265,6 +265,8 @@ def _handle_wait_tensor(qt, args, kwargs): return QuantizedTensor(waited_qdata, qtensor._layout_cls, waited_params) +# Should not be use, since tensor.copy_ inside fsdp, calls native dtype that can't recieve correct Quantized class +# This is mainly for state_dict broadcast from rank 0. @register_layout_op(torch.ops.c10d.broadcast_.default, TensorCoreFP8Layout) def _handle_broadcast(qt, args, kwargs): from .base import QuantizedTensor @@ -434,6 +436,13 @@ def _handle_fp8_cat(qt, args, kwargs): return _wrap_fp8_tensor(first_qtensor, concatenated_qdata) +@register_layout_op(torch.ops.aten.new_zeros.default, TensorCoreFP8Layout) +def _handle_new_zeros(qt, args, kwargs): + input_tensor = args[0] + new_zero_qdata = torch.ops.aten.new_zeros.default(input_tensor._qdata, *args[1:], **kwargs) + return _wrap_fp8_tensor(input_tensor, new_zero_qdata) + + # ==================== FP8 Shape Operations ==================== # These preserve quantization since FP8 is not packed (1:1 element mapping) @@ -442,7 +451,6 @@ def _handle_fp8_cat(qt, args, kwargs): torch.ops.aten.reshape.default, torch.ops.aten.t.default, torch.ops.aten.as_strided.default, - torch.ops.aten.new_zeros.default, torch.ops.aten.alias.default ): register_layout_op(_aten_op, TensorCoreFP8Layout)(_make_fp8_shape_handler(_aten_op)) From da71067d17eae0b820fe545dae079b7f5623df40 Mon Sep 17 00:00:00 2001 From: komikndr Date: Sun, 15 Feb 2026 19:30:51 +0700 Subject: [PATCH 06/10] pre and post gather fsdp_hook --- comfy_kitchen/tensor/fp8.py | 65 ++++++++++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/comfy_kitchen/tensor/fp8.py b/comfy_kitchen/tensor/fp8.py index 11b3b23..12563c1 100644 --- 
a/comfy_kitchen/tensor/fp8.py +++ b/comfy_kitchen/tensor/fp8.py @@ -3,7 +3,8 @@ import logging from dataclasses import dataclass, replace -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Optional, Tuple +from .base import QuantizedTensor import torch @@ -454,3 +455,65 @@ def _handle_new_zeros(qt, args, kwargs): torch.ops.aten.alias.default ): register_layout_op(_aten_op, TensorCoreFP8Layout)(_make_fp8_shape_handler(_aten_op)) + + +# ==================== FSDP All-Gather Hooks ==================== + +def _fsdp_pre_all_gather_fp8(self, mesh): + if self._layout_cls != "TensorCoreFP8Layout": + raise NotImplementedError(f"FSDP all_gather not supported for {self._layout_cls}") + + qdata = self._qdata + if not qdata.is_contiguous(): + qdata = qdata.contiguous() + + scale = self._params.scale + if isinstance(scale, torch.Tensor): + scale = scale.to(device=qdata.device) + + return (qdata,), (scale,) + + +def _fsdp_post_all_gather_fp8( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[torch.Tensor] = None, +): + if self._layout_cls != "TensorCoreFP8Layout": + raise NotImplementedError(f"FSDP all_gather not supported for {self._layout_cls}") + + (data,) = all_gather_outputs + (scale,) = metadata + + if out is not None: + from .base import QuantizedTensor + + if not isinstance(out, QuantizedTensor): + raise TypeError(f"Expected QuantizedTensor out, got {type(out)}") + out._qdata = data + out._params = TensorCoreFP8Layout.Params( + scale=scale, + orig_dtype=param_dtype, + orig_shape=tuple(data.shape), + ) + return + + from .base import QuantizedTensor + + params = TensorCoreFP8Layout.Params( + scale=scale, + orig_dtype=param_dtype, + orig_shape=tuple(data.shape), + ) + return QuantizedTensor(data, "TensorCoreFP8Layout", params), (data,) + + +# Monkey patch for now +if not hasattr(QuantizedTensor, "fsdp_pre_all_gather"): + QuantizedTensor.fsdp_pre_all_gather = _fsdp_pre_all_gather_fp8 
# type: ignore[attr-defined] + +if not hasattr(QuantizedTensor, "fsdp_post_all_gather"): + QuantizedTensor.fsdp_post_all_gather = _fsdp_post_all_gather_fp8 # type: ignore[attr-defined] From c738713425b247727dc23e667feec3100a80c69b Mon Sep 17 00:00:00 2001 From: komikndr Date: Sun, 15 Feb 2026 19:47:47 +0700 Subject: [PATCH 07/10] Change README and NOTICE --- NOTICE | 3 +++ README.md | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/NOTICE b/NOTICE index 975dbc1..f839ad9 100644 --- a/NOTICE +++ b/NOTICE @@ -52,6 +52,9 @@ Portions of the following files are derived from PyTorch AO: - comfy_kitchen/backends/eager/quantization.py Source: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py +- comfy_kitchen/tensor/fp8.py + Source: https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py + - comfy_kitchen/float_utils.py (_f32_to_floatx_unpacked, _floatx_unpacked_to_f32) Source: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/custom_cast.py diff --git a/README.md b/README.md index dc5d448..e693823 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,13 @@ Fast kernel library for Diffusion inference with multiple compute backends. 
| `apply_rope` | ✓ | ✓ | ✓ | | `apply_rope1` | ✓ | ✓ | ✓ | +## DTensor Capabilities Matrix +| Layout | Status | +|------------------------|--------| +| `TensorCoreFP8Layout` | ✓ | +| `TensorCoreNVFP4Layout`| | +| `TensorCoreMXFP8Layout`| | + ## Quantized Tensors From bc2f6468e637d0d5816e06b54822e3f3e8f7121d Mon Sep 17 00:00:00 2001 From: komikndr Date: Thu, 26 Feb 2026 03:45:06 +0700 Subject: [PATCH 08/10] Remove monkey patching for fsdp pre/post_all_gather Now it implement in classmethod, more cleaner, and NVFP4 can implement it down the line --- comfy_kitchen/tensor/base.py | 55 ++++++++++++++++- comfy_kitchen/tensor/fp8.py | 114 ++++++++++++++--------------------- 2 files changed, 99 insertions(+), 70 deletions(-) diff --git a/comfy_kitchen/tensor/base.py b/comfy_kitchen/tensor/base.py index 8dc0dbc..03b0ad5 100644 --- a/comfy_kitchen/tensor/base.py +++ b/comfy_kitchen/tensor/base.py @@ -137,6 +137,41 @@ def get_requirements(cls) -> dict[str, Any]: "fast_matmul_supported": cls.supports_fast_matmul(), } + @classmethod + def pre_all_gather(cls, qtensor: QuantizedTensor, mesh): + """Prepare data for FSDP all_gather. + + Returns: + Tuple of (all_gather_outputs, metadata): + - all_gather_outputs: data to be gathered (as tuple of tensors) + - metadata: additional info needed after gathering (as tuple) + """ + raise NotImplementedError(f"pre_all_gather not implemented for {cls.__name__}") + + @classmethod + def post_all_gather( + cls, + qtensor: QuantizedTensor, + all_gather_outputs: tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: QuantizedTensor | None = None, + ): + """Reconstruct QuantizedTensor after FSDP all_gather. 
+ + Args: + qtensor: Original quantized tensor (used for layout context) + all_gather_outputs: Gathered tensors + metadata: Metadata from pre_all_gather + param_dtype: Expected parameter dtype + out: Optional output tensor to update in-place + + Returns: + Either updates out in-place (returns None) or returns new (QuantizedTensor, metadata) tuple + """ + raise NotImplementedError(f"post_all_gather not implemented for {cls.__name__}") + class QuantizedTensor(torch.Tensor): """ @@ -327,6 +362,25 @@ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): params = ctx["params_class"](**params_kwargs) return QuantizedTensor(inner_tensors["_qdata"], ctx["layout_cls"], params) + # ==================== FSDP Hooks ==================== + + def fsdp_pre_all_gather(self, mesh): + """FSDP pre_all_gather hook - delegates to layout class.""" + return self.layout_cls.pre_all_gather(self, mesh) + + def fsdp_post_all_gather( + self, + all_gather_outputs: tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: QuantizedTensor | None = None, + ): + """FSDP post_all_gather hook - delegates to layout class.""" + return self.layout_cls.post_all_gather( + self, all_gather_outputs, metadata, param_dtype, out=out + ) + # ==================== Torch Dispatch ==================== @classmethod @@ -348,7 +402,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return op_handlers[parent_cls](qt, args, kwargs) # Step 3: Fallback to dequantization - logger.debug(f"Unhandled op {func} for {layout_cls.__name__ if layout_cls else 'unknown'}, dequantizing") return cls._dequant_and_fallback(func, args, kwargs) @classmethod diff --git a/comfy_kitchen/tensor/fp8.py b/comfy_kitchen/tensor/fp8.py index 12563c1..0380998 100644 --- a/comfy_kitchen/tensor/fp8.py +++ b/comfy_kitchen/tensor/fp8.py @@ -80,6 +80,49 @@ def state_dict_tensors(cls, qdata: torch.Tensor, params: Params) -> dict[str, to "_scale": params.scale, } + @classmethod + def 
pre_all_gather(cls, qtensor: QuantizedTensor, mesh): + qdata = qtensor._qdata + if not qdata.is_contiguous(): + qdata = qdata.contiguous() + + scale = qtensor._params.scale + if isinstance(scale, torch.Tensor): + scale = scale.to(device=qdata.device) + + return (qdata,), (scale,) + + @classmethod + def post_all_gather( + cls, + qtensor: QuantizedTensor, + all_gather_outputs: tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: QuantizedTensor | None = None, + ): + (data,) = all_gather_outputs + (scale,) = metadata + + if out is not None: + if not isinstance(out, QuantizedTensor): + raise TypeError(f"Expected QuantizedTensor out, got {type(out)}") + out._qdata = data + out._params = cls.Params( + scale=scale, + orig_dtype=param_dtype, + orig_shape=tuple(data.shape), + ) + return + + params = cls.Params( + scale=scale, + orig_dtype=param_dtype, + orig_shape=tuple(data.shape), + ) + return QuantizedTensor(data, qtensor._layout_cls, params), (data,) + # ==================== Helper Utilities ==================== @@ -240,9 +283,7 @@ def _handle_all_gather(qt, args, kwargs): new_args = list(args) new_args[input_idx] = qdata_bytes - gathered_bytes = torch.ops._c10d_functional.all_gather_into_tensor.default( - *new_args, **kwargs - ) + gathered_bytes = torch.ops._c10d_functional.all_gather_into_tensor.default(*new_args, **kwargs) gathered_qdata = gathered_bytes.view(qdata.dtype) gathered_params = replace(params, orig_shape=tuple(gathered_qdata.shape)) @@ -408,10 +449,7 @@ def _handle_fp8_split(qt, args, kwargs): return torch.ops.aten.split.Tensor(*args, **kwargs) qdata_chunks = torch.ops.aten.split.Tensor(input_tensor._qdata, *args[1:], **kwargs) - wrapped_chunks = tuple( - _wrap_fp8_tensor(input_tensor, chunk) - for chunk in qdata_chunks - ) + wrapped_chunks = tuple(_wrap_fp8_tensor(input_tensor, chunk) for chunk in qdata_chunks) return wrapped_chunks @@ -455,65 +493,3 @@ def _handle_new_zeros(qt, args, kwargs): torch.ops.aten.alias.default ): 
register_layout_op(_aten_op, TensorCoreFP8Layout)(_make_fp8_shape_handler(_aten_op)) - - -# ==================== FSDP All-Gather Hooks ==================== - -def _fsdp_pre_all_gather_fp8(self, mesh): - if self._layout_cls != "TensorCoreFP8Layout": - raise NotImplementedError(f"FSDP all_gather not supported for {self._layout_cls}") - - qdata = self._qdata - if not qdata.is_contiguous(): - qdata = qdata.contiguous() - - scale = self._params.scale - if isinstance(scale, torch.Tensor): - scale = scale.to(device=qdata.device) - - return (qdata,), (scale,) - - -def _fsdp_post_all_gather_fp8( - self, - all_gather_outputs: Tuple[torch.Tensor, ...], - metadata: Any, - param_dtype: torch.dtype, - *, - out: Optional[torch.Tensor] = None, -): - if self._layout_cls != "TensorCoreFP8Layout": - raise NotImplementedError(f"FSDP all_gather not supported for {self._layout_cls}") - - (data,) = all_gather_outputs - (scale,) = metadata - - if out is not None: - from .base import QuantizedTensor - - if not isinstance(out, QuantizedTensor): - raise TypeError(f"Expected QuantizedTensor out, got {type(out)}") - out._qdata = data - out._params = TensorCoreFP8Layout.Params( - scale=scale, - orig_dtype=param_dtype, - orig_shape=tuple(data.shape), - ) - return - - from .base import QuantizedTensor - - params = TensorCoreFP8Layout.Params( - scale=scale, - orig_dtype=param_dtype, - orig_shape=tuple(data.shape), - ) - return QuantizedTensor(data, "TensorCoreFP8Layout", params), (data,) - - -# Monkey patch for now -if not hasattr(QuantizedTensor, "fsdp_pre_all_gather"): - QuantizedTensor.fsdp_pre_all_gather = _fsdp_pre_all_gather_fp8 # type: ignore[attr-defined] - -if not hasattr(QuantizedTensor, "fsdp_post_all_gather"): - QuantizedTensor.fsdp_post_all_gather = _fsdp_post_all_gather_fp8 # type: ignore[attr-defined] From 79ebc96301d5a1b1022271a0ee172becdf7633cd Mon Sep 17 00:00:00 2001 From: komikndr Date: Sat, 7 Mar 2026 15:52:43 +0700 Subject: [PATCH 09/10] NVFP4 Enable --- README.md | 2 +- 
comfy_kitchen/tensor/nvfp4.py | 499 ++++++++++++++++++++++++++++++++-- 2 files changed, 479 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index e693823..9d5b90a 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Fast kernel library for Diffusion inference with multiple compute backends. | Layout | Status | |------------------------|--------| | `TensorCoreFP8Layout` | ✓ | -| `TensorCoreNVFP4Layout`| | +| `TensorCoreNVFP4Layout`| ✓ | | `TensorCoreMXFP8Layout`| | diff --git a/comfy_kitchen/tensor/nvfp4.py b/comfy_kitchen/tensor/nvfp4.py index a0e8409..7f68595 100644 --- a/comfy_kitchen/tensor/nvfp4.py +++ b/comfy_kitchen/tensor/nvfp4.py @@ -1,4 +1,5 @@ """NVFP4 (E2M1) block quantization layout for tensor cores.""" + from __future__ import annotations import logging @@ -8,7 +9,7 @@ import torch import comfy_kitchen as ck -from comfy_kitchen.float_utils import F4_E2M1_MAX, F8_E4M3_MAX, roundup +from comfy_kitchen.float_utils import F4_E2M1_MAX, F8_E4M3_MAX, from_blocked, roundup, to_blocked from .base import BaseLayoutParams, QuantizedLayout, dequantize_args, register_layout_op @@ -24,8 +25,8 @@ class TensorCoreNVFP4Layout(QuantizedLayout): Note: Requires SM >= 10.0 (Blackwell) for hardware-accelerated matmul. - Shape operations (view, reshape, transpose) are not supported due to - packed format and block scales - they fall back to dequantization. + View-like operations remain limited because NVFP4 uses packed values plus + blocked scales, but FSDP row-wise alias/slice/split are supported. """ MIN_SM_VERSION = (10, 0) @@ -37,6 +38,7 @@ class Params(BaseLayoutParams): Inherits scale, orig_dtype, orig_shape from BaseLayoutParams. Adds block_scale for per-block scaling factors. 
""" + block_scale: torch.Tensor transposed: bool = False @@ -46,7 +48,9 @@ def _tensor_fields(self) -> list[str]: def _validate_tensor_fields(self): if isinstance(self.scale, torch.Tensor): - object.__setattr__(self, "scale", self.scale.to(dtype=torch.float32, non_blocking=True)) + object.__setattr__( + self, "scale", self.scale.to(dtype=torch.float32, non_blocking=True) + ) @classmethod def quantize( @@ -93,7 +97,6 @@ def get_plain_tensors( @classmethod def state_dict_tensors(cls, qdata: torch.Tensor, params: Params) -> dict[str, torch.Tensor]: - """Return key suffix → tensor mapping for serialization.""" return { "": qdata, "_scale": params.block_scale, @@ -117,14 +120,466 @@ def get_logical_shape_from_storage(cls, storage_shape: tuple[int, ...]) -> tuple """Compute logical (padded) shape from storage shape by reversing packing.""" return (storage_shape[0], storage_shape[1] * 2) + @classmethod + def pre_all_gather(cls, qtensor: QuantizedTensor, mesh): + qdata = qtensor._qdata.contiguous() + block_scale = qtensor._params.block_scale.contiguous() + metadata = { + "scale": qtensor._params.scale, + "orig_dtype": qtensor._params.orig_dtype, + "orig_shape": qtensor._params.orig_shape, + "transposed": qtensor._params.transposed, + "qdata_shape": tuple(qdata.shape), + } + return (qdata, block_scale), metadata + + @classmethod + def post_all_gather( + cls, + qtensor: QuantizedTensor, + all_gather_outputs: tuple[torch.Tensor, ...], + metadata, + param_dtype: torch.dtype, + *, + out: QuantizedTensor | None = None, + ): + from .base import QuantizedTensor + + gathered_qdata, gathered_block_scale = all_gather_outputs + orig_shape = _scaled_rowwise_orig_shape(qtensor, gathered_qdata, metadata.get("orig_shape")) + params = cls.Params( + scale=metadata["scale"], + orig_dtype=metadata.get("orig_dtype", param_dtype), + orig_shape=orig_shape, + block_scale=gathered_block_scale, + transposed=metadata.get("transposed", False), + ) + + if out is not None: + out._qdata = gathered_qdata 
+ out._params = params + return out, (gathered_qdata, gathered_block_scale) + return QuantizedTensor(gathered_qdata, qtensor._layout_cls, params), ( + gathered_qdata, + gathered_block_scale, + ) + + +class _CompositeWork: + def __init__(self, *works): + self._works = [work for work in works if work is not None] + + def wait(self): + result = None + for work in self._works: + result = work.wait() + return result + + def is_completed(self): + return all(getattr(work, "is_completed", lambda: True)() for work in self._works) + + +def _extract_collective_result(result): + if isinstance(result, tuple): + return result[0], result[1] + return result, None + + +def _block_scale_unblocked_shape(qtensor) -> tuple[int, int]: + storage_shape = tuple(qtensor._qdata.shape) + logical_cols = TensorCoreNVFP4Layout.get_logical_shape_from_storage(storage_shape)[1] + return storage_shape[0], logical_cols // 16 + + +def _unblock_block_scale(qtensor, block_scale: torch.Tensor | None = None) -> torch.Tensor: + block_scale = qtensor._params.block_scale if block_scale is None else block_scale + num_rows, num_cols = _block_scale_unblocked_shape(qtensor) + return from_blocked(block_scale, num_rows=num_rows, num_cols=num_cols) + + +def _reblock_block_scale(scale_rows: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return to_blocked(scale_rows.to(dtype=dtype), flatten=False) + + +def _scaled_rowwise_orig_shape( + qtensor, new_qdata: torch.Tensor, orig_shape=None +) -> tuple[int, ...]: + if orig_shape is not None: + return tuple(orig_shape) + orig_shape = qtensor._params.orig_shape + if getattr(qtensor._params, "transposed", False): + return tuple(orig_shape) + if len(orig_shape) != 2 or qtensor._qdata.dim() != 2 or new_qdata.dim() != 2: + return tuple(orig_shape) + + old_logical_rows = int(orig_shape[0]) + old_storage_rows = int(qtensor._qdata.shape[0]) + new_storage_rows = int(new_qdata.shape[0]) + if old_storage_rows == 0: + new_logical_rows = new_storage_rows + else: + new_logical_rows 
= ( + old_logical_rows * new_storage_rows + old_storage_rows - 1 + ) // old_storage_rows + return (new_logical_rows, int(orig_shape[1])) + + +def _wrap_nvfp4_tensor( + qtensor, + qdata: torch.Tensor, + *, + block_scale: torch.Tensor | None = None, + orig_shape: tuple[int, ...] | None = None, + transposed: bool | None = None, +): + from .base import QuantizedTensor + + new_params = TensorCoreNVFP4Layout.Params( + scale=qtensor._params.scale, + orig_dtype=qtensor._params.orig_dtype, + orig_shape=_scaled_rowwise_orig_shape(qtensor, qdata, orig_shape), + block_scale=qtensor._params.block_scale if block_scale is None else block_scale, + transposed=qtensor._params.transposed if transposed is None else transposed, + ) + return QuantizedTensor(qdata, qtensor._layout_cls, new_params) + + +def _normalize_slice_args(size: int, start, end, step) -> tuple[int, int, int]: + step = 1 if step is None else step + if step != 1: + raise NotImplementedError("NVFP4 only supports slice step=1") + start = 0 if start is None else start + end = size if end is None else end + if start < 0: + start += size + if end < 0: + end += size + start = max(0, min(start, size)) + end = max(start, min(end, size)) + return start, end, step + + +def _logical_rows_to_storage_rows(qtensor, start: int, end: int) -> tuple[int, int]: + logical_rows = int(qtensor._params.orig_shape[0]) + storage_rows = int(qtensor._qdata.shape[0]) + if logical_rows <= 0: + return 0, 0 + start_storage = (start * storage_rows) // logical_rows + end_storage = (end * storage_rows + logical_rows - 1) // logical_rows + return start_storage, end_storage + + +def _slice_rows_nvfp4(input_tensor, start, end): + start_storage, end_storage = _logical_rows_to_storage_rows(input_tensor, start, end) + sliced_qdata = torch.ops.aten.slice.Tensor( + input_tensor._qdata, 0, start_storage, end_storage, 1 + ) + block_scale_rows = _unblock_block_scale(input_tensor) + sliced_block_scale = block_scale_rows[start_storage:end_storage] + reblocked_scale 
= _reblock_block_scale( + sliced_block_scale, input_tensor._params.block_scale.dtype + ) + new_orig_shape = (end - start, int(input_tensor._params.orig_shape[1])) + return _wrap_nvfp4_tensor( + input_tensor, sliced_qdata, block_scale=reblocked_scale, orig_shape=new_orig_shape + ) + + +# ==================== Distributed Operations ==================== + + +@register_layout_op( + torch.ops._c10d_functional.all_gather_into_tensor.default, TensorCoreNVFP4Layout +) +def _handle_all_gather(qt, args, kwargs): + from .base import QuantizedTensor + + input_tensor = None + input_idx = None + for i, arg in enumerate(args): + if isinstance(arg, QuantizedTensor): + input_tensor = arg + input_idx = i + break + + assert input_tensor is not None + assert input_idx is not None + + qdata_bytes = input_tensor._qdata.contiguous().view(torch.uint8) + block_scale_bytes = input_tensor._params.block_scale.contiguous().view(torch.uint8) + + q_args = list(args) + q_args[input_idx] = qdata_bytes + gathered_qdata_bytes = torch.ops._c10d_functional.all_gather_into_tensor.default( + *q_args, **kwargs + ) + + b_args = list(args) + b_args[input_idx] = block_scale_bytes + gathered_block_scale_bytes = torch.ops._c10d_functional.all_gather_into_tensor.default( + *b_args, **kwargs + ) + + gathered_qdata = gathered_qdata_bytes.view(input_tensor._qdata.dtype) + gathered_block_scale = gathered_block_scale_bytes.view(input_tensor._params.block_scale.dtype) + return _wrap_nvfp4_tensor(input_tensor, gathered_qdata, block_scale=gathered_block_scale) + + +@register_layout_op(torch.ops._c10d_functional.wait_tensor.default, TensorCoreNVFP4Layout) +def _handle_wait_tensor(qt, args, kwargs): + qtensor = args[0] + + waited_qdata_bytes = torch.ops._c10d_functional.wait_tensor.default( + qtensor._qdata.view(torch.uint8), + *args[1:], + **kwargs, + ) + waited_block_scale_bytes = torch.ops._c10d_functional.wait_tensor.default( + qtensor._params.block_scale.contiguous().view(torch.uint8), + *args[1:], + **kwargs, + ) 
+ + waited_qdata = waited_qdata_bytes.view(qtensor._qdata.dtype) + waited_block_scale = waited_block_scale_bytes.view(qtensor._params.block_scale.dtype) + return _wrap_nvfp4_tensor(qtensor, waited_qdata, block_scale=waited_block_scale) + + +@register_layout_op(torch.ops.c10d.broadcast_.default, TensorCoreNVFP4Layout) +def _handle_broadcast(qt, args, kwargs): + from .base import QuantizedTensor + + tensor_list = args[0] + quantized_entries = [ + (idx, tensor) + for idx, tensor in enumerate(tensor_list) + if isinstance(tensor, QuantizedTensor) + ] + if not quantized_entries: + return torch.ops.c10d.broadcast_.default(*args, **kwargs) + + q_tensor_list = list(tensor_list) + b_tensor_list = list(tensor_list) + for idx, tensor in quantized_entries: + q_tensor_list[idx] = tensor._qdata.contiguous().view(torch.uint8) + b_tensor_list[idx] = tensor._params.block_scale.contiguous().view(torch.uint8) + + q_result = torch.ops.c10d.broadcast_.default(q_tensor_list, *args[1:], **kwargs) + b_result = torch.ops.c10d.broadcast_.default(b_tensor_list, *args[1:], **kwargs) + q_list, q_work = _extract_collective_result(q_result) + b_list, b_work = _extract_collective_result(b_result) + + output_list = list(q_list) + for idx, original in quantized_entries: + output_list[idx] = _wrap_nvfp4_tensor( + original, + q_list[idx].view(original._qdata.dtype), + block_scale=b_list[idx].view(original._params.block_scale.dtype), + orig_shape=original._params.orig_shape, + ) + + if q_work is not None or b_work is not None: + return output_list, _CompositeWork(q_work, b_work) + return output_list + + +@register_layout_op(torch.ops.c10d.scatter_.default, TensorCoreNVFP4Layout) +def _handle_scatter(qt, args, kwargs): + from .base import QuantizedTensor + + output_tensors = args[0] + input_tensors = args[1] + + quantized_outputs = [] + new_q_outputs = list(output_tensors) + new_b_outputs = list(output_tensors) + for idx, tensor in enumerate(output_tensors): + if isinstance(tensor, QuantizedTensor): + 
quantized_outputs.append((idx, tensor)) + new_q_outputs[idx] = tensor._qdata.contiguous().view(torch.uint8) + new_b_outputs[idx] = tensor._params.block_scale.contiguous().view(torch.uint8) + + has_quantized_input = False + q_inputs = [] + b_inputs = [] + for entry in input_tensors: + if isinstance(entry, (list, tuple)): + q_entry = [] + b_entry = [] + for tensor in entry: + if isinstance(tensor, QuantizedTensor): + has_quantized_input = True + q_entry.append(tensor._qdata.contiguous().view(torch.uint8)) + b_entry.append(tensor._params.block_scale.contiguous().view(torch.uint8)) + else: + q_entry.append(tensor) + b_entry.append(tensor) + q_inputs.append(q_entry) + b_inputs.append(b_entry) + else: + q_inputs.append(entry) + b_inputs.append(entry) + + if not quantized_outputs and not has_quantized_input: + return torch.ops.c10d.scatter_.default(*args, **kwargs) + + q_result = torch.ops.c10d.scatter_.default(new_q_outputs, q_inputs, *args[2:], **kwargs) + b_result = torch.ops.c10d.scatter_.default(new_b_outputs, b_inputs, *args[2:], **kwargs) + q_list, q_work = _extract_collective_result(q_result) + b_list, b_work = _extract_collective_result(b_result) + + output_list = list(q_list) + for idx, original in quantized_outputs: + output_list[idx] = _wrap_nvfp4_tensor( + original, + q_list[idx].view(original._qdata.dtype), + block_scale=b_list[idx].view(original._params.block_scale.dtype), + orig_shape=original._params.orig_shape, + ) + + if q_work is not None or b_work is not None: + return output_list, _CompositeWork(q_work, b_work) + return output_list + + +# ==================== NVFP4 Shape Operations ==================== + + +@register_layout_op(torch.ops.aten.alias.default, TensorCoreNVFP4Layout) +def _handle_nvfp4_alias(qt, args, kwargs): + from .base import QuantizedTensor + + input_tensor = args[0] + if not isinstance(input_tensor, QuantizedTensor): + return torch.ops.aten.alias.default(*args, **kwargs) + + aliased_qdata = 
torch.ops.aten.alias.default(input_tensor._qdata) + aliased_block_scale = torch.ops.aten.alias.default(input_tensor._params.block_scale) + return _wrap_nvfp4_tensor( + input_tensor, + aliased_qdata, + block_scale=aliased_block_scale, + orig_shape=input_tensor._params.orig_shape, + transposed=input_tensor._params.transposed, + ) + + +@register_layout_op(torch.ops.aten.slice.Tensor, TensorCoreNVFP4Layout) +def _handle_nvfp4_slice(qt, args, kwargs): + from .base import QuantizedTensor + + input_tensor = args[0] + if not isinstance(input_tensor, QuantizedTensor): + return torch.ops.aten.slice.Tensor(*args, **kwargs) + if getattr(input_tensor._params, "transposed", False): + return torch.ops.aten.slice.Tensor(*dequantize_args(args), **kwargs) + + dim = args[1] if len(args) > 1 else 0 + start = args[2] if len(args) > 2 else None + end = args[3] if len(args) > 3 else None + step = args[4] if len(args) > 4 else None + dim = dim if dim >= 0 else dim + len(input_tensor._params.orig_shape) + + if dim == 0: + start, end, _ = _normalize_slice_args( + int(input_tensor._params.orig_shape[0]), start, end, step + ) + return _slice_rows_nvfp4(input_tensor, start, end) + + if dim == 1: + start, end, step = _normalize_slice_args( + int(input_tensor._params.orig_shape[1]), start, end, step + ) + if start == 0 and end == int(input_tensor._params.orig_shape[1]) and step == 1: + return _handle_nvfp4_alias(qt, (input_tensor,), {}) + return torch.ops.aten.slice.Tensor(*dequantize_args(args), **kwargs) + + return torch.ops.aten.slice.Tensor(*dequantize_args(args), **kwargs) + + +@register_layout_op(torch.ops.aten.split.Tensor, TensorCoreNVFP4Layout) +def _handle_nvfp4_split(qt, args, kwargs): + from .base import QuantizedTensor + + input_tensor = args[0] + if not isinstance(input_tensor, QuantizedTensor): + return torch.ops.aten.split.Tensor(*args, **kwargs) + if getattr(input_tensor._params, "transposed", False): + return torch.ops.aten.split.Tensor(*dequantize_args(args), **kwargs) + + 
split_size = args[1] + dim = kwargs.get("dim", args[2] if len(args) > 2 else 0) + dim = dim if dim >= 0 else dim + len(input_tensor._params.orig_shape) + if dim != 0: + return torch.ops.aten.split.Tensor(*dequantize_args(args), **kwargs) + + logical_rows = int(input_tensor._params.orig_shape[0]) + if isinstance(split_size, int): + chunks = [] + for start in range(0, logical_rows, split_size): + end = min(start + split_size, logical_rows) + chunks.append(_slice_rows_nvfp4(input_tensor, start, end)) + return tuple(chunks) + return torch.ops.aten.split.Tensor(*dequantize_args(args), **kwargs) + + +@register_layout_op(torch.ops.aten.cat.default, TensorCoreNVFP4Layout) +def _handle_nvfp4_cat(qt, args, kwargs): + from .base import QuantizedTensor + + tensors = args[0] + dim = kwargs.get("dim", args[1] if len(args) > 1 else 0) + if dim != 0 or not isinstance(tensors, (list, tuple)) or not tensors: + return torch.ops.aten.cat.default(*dequantize_args(args), **kwargs) + if not all(isinstance(tensor, QuantizedTensor) for tensor in tensors): + return torch.ops.aten.cat.default(*dequantize_args(args), **kwargs) + + first = tensors[0] + if any(getattr(tensor._params, "transposed", False) for tensor in tensors): + return torch.ops.aten.cat.default(*dequantize_args(args), **kwargs) + if any(tensor._params.orig_shape[1] != first._params.orig_shape[1] for tensor in tensors): + return torch.ops.aten.cat.default(*dequantize_args(args), **kwargs) + + qdata = torch.ops.aten.cat.default([tensor._qdata for tensor in tensors], 0) + block_scale_rows = [_unblock_block_scale(tensor) for tensor in tensors] + block_scale = _reblock_block_scale( + torch.ops.aten.cat.default(block_scale_rows, 0), first._params.block_scale.dtype + ) + orig_rows = sum(int(tensor._params.orig_shape[0]) for tensor in tensors) + return _wrap_nvfp4_tensor( + first, + qdata, + block_scale=block_scale, + orig_shape=(orig_rows, int(first._params.orig_shape[1])), + ) + + 
+@register_layout_op(torch.ops.aten.new_zeros.default, TensorCoreNVFP4Layout) +def _handle_new_zeros(qt, args, kwargs): + input_tensor = args[0] + size = tuple(args[1]) if len(args) > 1 else tuple(input_tensor._params.orig_shape) + if len(size) != 2: + return torch.ops.aten.new_zeros.default(*dequantize_args(args), **kwargs) + + device = kwargs.get("device", input_tensor._qdata.device) + storage_shape = TensorCoreNVFP4Layout.get_storage_shape(size) + qdata = torch.zeros(storage_shape, device=device, dtype=input_tensor._qdata.dtype) + block_cols = TensorCoreNVFP4Layout.get_padded_shape(size)[1] // 16 + block_scale_rows = torch.zeros( + (storage_shape[0], block_cols), + device=device, + dtype=input_tensor._params.block_scale.dtype, + ) + block_scale = _reblock_block_scale(block_scale_rows, input_tensor._params.block_scale.dtype) + return _wrap_nvfp4_tensor(input_tensor, qdata, block_scale=block_scale, orig_shape=size) + # ==================== NVFP4 Transpose Operation ==================== # Transpose is a no-op that tracks logical transposition via a flag. @register_layout_op(torch.ops.aten.t.default, TensorCoreNVFP4Layout) def _handle_nvfp4_transpose(qt, args, kwargs): - """Handle transpose as a logical no-op for NVFP4. 
- """ + """Handle transpose as a logical no-op for NVFP4.""" from .base import QuantizedTensor input_tensor = args[0] @@ -157,6 +612,13 @@ def _slice_to_original_shape( return result +def _linear_dequantize_fallback(input_tensor, weight, bias): + input_dense, weight_dense, bias_dense = dequantize_args((input_tensor, weight, bias)) + assert isinstance(input_dense, torch.Tensor) + assert isinstance(weight_dense, torch.Tensor) + return torch.nn.functional.linear(input_dense, weight_dense, bias_dense) + + @register_layout_op(torch.ops.aten.mm.default, TensorCoreNVFP4Layout) def _handle_nvfp4_mm(qt, args, kwargs): """NVFP4 matrix multiplication: output = a @ b @@ -165,17 +627,15 @@ def _handle_nvfp4_mm(qt, args, kwargs): with scaled_mm_nvfp4 since that kernel computes a @ b_phys.T, which equals a @ b_logical when b_logical = b_phys.T. - This handles the common torch.compile decomposition: linear(x, w) → mm(x, w.t()) + This handles the common torch.compile decomposition: linear(x, w) -> mm(x, w.t()) """ from .base import QuantizedTensor a, b = args[0], args[1] - # Fast path: both operands are NVFP4 QuantizedTensors if not (isinstance(a, QuantizedTensor) and isinstance(b, QuantizedTensor)): return torch.mm(*dequantize_args(args)) - # NVFP4 only supports 2D tensors if a._qdata.dim() != 2: return torch.mm(*dequantize_args(args)) @@ -183,7 +643,6 @@ def _handle_nvfp4_mm(qt, args, kwargs): b_transposed = getattr(b._params, "transposed", False) if a_transposed or not b_transposed: - # Can't handle these cases with current kernel, fallback logger.debug("NVFP4 mm: unsupported transpose configuration, falling back to dequantize") return torch.mm(*dequantize_args(args)) @@ -223,26 +682,25 @@ def _handle_nvfp4_linear(qt, args, kwargs): input_tensor, weight = args[0], args[1] bias = args[2] if len(args) > 2 else None - # Fast path: both operands are NVFP4 QuantizedTensors if not (isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor)): - return 
torch.nn.functional.linear(*dequantize_args((input_tensor, weight, bias))) + return _linear_dequantize_fallback(input_tensor, weight, bias) - # NVFP4 only supports 2D tensors if input_tensor._qdata.dim() != 2: - return torch.nn.functional.linear(*dequantize_args((input_tensor, weight, bias))) + return _linear_dequantize_fallback(input_tensor, weight, bias) input_transposed = getattr(input_tensor._params, "transposed", False) weight_transposed = getattr(weight._params, "transposed", False) if input_transposed or weight_transposed: - logger.debug("NVFP4 linear: unsupported transpose configuration, falling back to dequantize") - return torch.nn.functional.linear(*dequantize_args((input_tensor, weight, bias))) + logger.debug( + "NVFP4 linear: unsupported transpose configuration, falling back to dequantize" + ) + return _linear_dequantize_fallback(input_tensor, weight, bias) input_qdata, scale_a, block_scale_a = TensorCoreNVFP4Layout.get_plain_tensors(input_tensor) weight_qdata, scale_b, block_scale_b = TensorCoreNVFP4Layout.get_plain_tensors(weight) out_dtype = kwargs.get("out_dtype", input_tensor._params.orig_dtype) try: - # scaled_mm_nvfp4 computes (a @ b.T) * scale, which is linear semantics result = ck.scaled_mm_nvfp4( input_qdata, weight_qdata, @@ -254,11 +712,10 @@ def _handle_nvfp4_linear(qt, args, kwargs): out_dtype=out_dtype, ) - # Slice output to original (non-padded) shape orig_m = input_tensor._params.orig_shape[0] - orig_n = weight._params.orig_shape[0] # weight is (out_features, in_features) + orig_n = weight._params.orig_shape[0] return _slice_to_original_shape(result, orig_m, orig_n) except (RuntimeError, TypeError) as e: logger.warning(f"NVFP4 scaled_mm failed: {e}, falling back to dequantization") - return torch.nn.functional.linear(*dequantize_args((input_tensor, weight, bias))) + return _linear_dequantize_fallback(input_tensor, weight, bias) From 6bfe97b160621455e727fa4783e7071fd9687a32 Mon Sep 17 00:00:00 2001 From: komikndr Date: Sat, 7 Mar 2026 
15:54:00 +0700
Subject: [PATCH 10/10] Update README for NVFP4

---
 README.md | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 9d5b90a..297b6ae 100644
--- a/README.md
+++ b/README.md
@@ -17,12 +17,15 @@ Fast kernel library for Diffusion inference with multiple compute backends.
 | `apply_rope` | ✓ | ✓ | ✓ |
 | `apply_rope1` | ✓ | ✓ | ✓ |
 
-## DTensor Capabilities Matrix
-| Layout | Status |
-|------------------------|--------|
-| `TensorCoreFP8Layout` | ✓ |
-| `TensorCoreNVFP4Layout`| ✓ |
-| `TensorCoreMXFP8Layout`| |
+## Distributed Capabilities Matrix
+
+| Layout                  | DTensor | FSDP pre/post all-gather |
+|-------------------------|---------|--------------------------|
+| `TensorCoreFP8Layout`   | ✓       | ✓                        |
+| `TensorCoreNVFP4Layout` | ✓       | ✓                        |
+| `TensorCoreMXFP8Layout` |         |                          |
+
+This is for custom nodes that might implement distributed operations, such as [Raylight](https://github.com/komikndr/raylight).
 
 ## Quantized Tensors