diff --git a/NOTICE b/NOTICE
index 975dbc1..f839ad9 100644
--- a/NOTICE
+++ b/NOTICE
@@ -52,6 +52,9 @@ Portions of the following files are derived from PyTorch AO:
 - comfy_kitchen/backends/eager/quantization.py
   Source: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py
 
+- comfy_kitchen/tensor/fp8.py
+  Source: https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py
+
 - comfy_kitchen/float_utils.py (_f32_to_floatx_unpacked, _floatx_unpacked_to_f32)
   Source: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/custom_cast.py
diff --git a/README.md b/README.md
index dc5d448..297b6ae 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,16 @@ Fast kernel library for Diffusion inference with multiple compute backends.
 | `apply_rope` | ✓ | ✓ | ✓ |
 | `apply_rope1` | ✓ | ✓ | ✓ |
 
+## Distributed Capabilities Matrix
+
+| Layout                  | DTensor | FSDP pre/post all-gather |
+|-------------------------|---------|--------------------------|
+| `TensorCoreFP8Layout`   | ✓       | ✓                        |
+| `TensorCoreNVFP4Layout` | ✓       | ✓                        |
+| `TensorCoreMXFP8Layout` |         |                          |
+
+This matrix is aimed at custom nodes that implement distributed operations, such as [Raylight](https://github.com/komikndr/raylight).
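+
+A minimal sketch of the intended FSDP2 usage (illustrative only: it assumes a
+process group launched via `torchrun`, the FSDP2 `fully_shard` API, and module
+weights already converted to `QuantizedTensor` parameters):
+
+```python
+import torch
+import torch.nn as nn
+from torch.distributed.fsdp import fully_shard  # FSDP2; torch >= 2.6
+
+# Weights are assumed to already be QuantizedTensor parameters
+# (e.g. TensorCoreFP8Layout). fully_shard() shards them along dim 0, and
+# each forward-pass all-gather then runs through fsdp_pre_all_gather /
+# fsdp_post_all_gather, so shards travel in quantized form.
+model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).cuda()
+for layer in model:
+    fully_shard(layer)
+fully_shard(model)
+
+out = model(torch.randn(8, 4096, device="cuda"))
+```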
+
 ## Quantized Tensors
 
diff --git a/comfy_kitchen/tensor/base.py b/comfy_kitchen/tensor/base.py
index bd7ef7a..03b0ad5 100644
--- a/comfy_kitchen/tensor/base.py
+++ b/comfy_kitchen/tensor/base.py
@@ -137,6 +137,41 @@ def get_requirements(cls) -> dict[str, Any]:
             "fast_matmul_supported": cls.supports_fast_matmul(),
         }
 
+    @classmethod
+    def pre_all_gather(cls, qtensor: QuantizedTensor, mesh):
+        """Prepare data for FSDP all_gather.
+
+        Returns:
+            Tuple of (all_gather_outputs, metadata):
+            - all_gather_outputs: data to be gathered (as tuple of tensors)
+            - metadata: additional info needed after gathering (as tuple)
+        """
+        raise NotImplementedError(f"pre_all_gather not implemented for {cls.__name__}")
+
+    @classmethod
+    def post_all_gather(
+        cls,
+        qtensor: QuantizedTensor,
+        all_gather_outputs: tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: QuantizedTensor | None = None,
+    ):
+        """Reconstruct a QuantizedTensor after FSDP all_gather.
+
+        Args:
+            qtensor: Original quantized tensor (used for layout context)
+            all_gather_outputs: Gathered tensors
+            metadata: Metadata from pre_all_gather
+            param_dtype: Expected parameter dtype
+            out: Optional output tensor to update in-place
+
+        Returns:
+            Either updates out in-place (returns None) or returns a new (QuantizedTensor, metadata) tuple
+        """
+        raise NotImplementedError(f"post_all_gather not implemented for {cls.__name__}")
+
 
 class QuantizedTensor(torch.Tensor):
     """
@@ -327,6 +362,25 @@ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
         params = ctx["params_class"](**params_kwargs)
         return QuantizedTensor(inner_tensors["_qdata"], ctx["layout_cls"], params)
 
+    # ==================== FSDP Hooks ====================
+
+    def fsdp_pre_all_gather(self, mesh):
+        """FSDP pre_all_gather hook - delegates to layout class."""
+        return self.layout_cls.pre_all_gather(self, mesh)
+
+    def fsdp_post_all_gather(
+        self,
+        all_gather_outputs: tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: QuantizedTensor | None = None,
+    ):
+        """FSDP post_all_gather hook - delegates to layout class."""
+        return self.layout_cls.post_all_gather(
+            self, all_gather_outputs, metadata, param_dtype, out=out
+        )
+
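+    # Rough sketch of the FSDP2 call sequence these hooks take part in
+    # (illustration only; the actual calls are driven by FSDP internals):
+    #   outputs, meta = shard.fsdp_pre_all_gather(mesh)
+    #   gathered = all_gather(outputs)           # collective over the mesh
+    #   full, _ = shard.fsdp_post_all_gather(gathered, meta, param_dtype)
+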
     # ==================== Torch Dispatch ====================
 
     @classmethod
@@ -348,7 +402,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 return op_handlers[parent_cls](qt, args, kwargs)
 
         # Step 3: Fallback to dequantization
-        logger.debug(f"Unhandled op {func} for {layout_cls.__name__ if layout_cls else 'unknown'}, dequantizing")
         return cls._dequant_and_fallback(func, args, kwargs)
 
     @classmethod
@@ -438,13 +491,15 @@ def _handle_is_contiguous(qt, args, kwargs):
 
 def _handle_copy_(qt, args, kwargs):
     dst, src = args[0], args[1]
 
-    if not isinstance(src, QuantizedTensor):
-        raise TypeError(f"Cannot copy {type(src).__name__} to QuantizedTensor")
+    non_blocking = kwargs.get("non_blocking", len(args) >= 3)
+
+    if not isinstance(dst, QuantizedTensor):
+        dst = QuantizedTensor(src._qdata, src._layout_cls, src._params)
+
     if dst._layout_cls != src._layout_cls:
         raise TypeError(f"Layout mismatch: {dst._layout_cls} vs {src._layout_cls}")
 
     dst_orig_dtype = dst._params.orig_dtype
-    non_blocking = kwargs.get("non_blocking", len(args) >= 3)
 
     dst._qdata.copy_(src._qdata, non_blocking=non_blocking)
     dst._params.copy_from(src._params, non_blocking=non_blocking)
diff --git a/comfy_kitchen/tensor/fp8.py b/comfy_kitchen/tensor/fp8.py
index 2c20bc1..0380998 100644
--- a/comfy_kitchen/tensor/fp8.py
+++ b/comfy_kitchen/tensor/fp8.py
@@ -2,13 +2,13 @@
 from __future__ import annotations
 
 import logging
-from dataclasses import dataclass
-from typing import TYPE_CHECKING
+from dataclasses import dataclass, replace
+from typing import TYPE_CHECKING, Any
 
 import torch
 
 import comfy_kitchen as ck
-from comfy_kitchen.scaled_mm_v2 import scaled_mm_v2
 
-from .base import BaseLayoutParams, QuantizedLayout, dequantize_args, register_layout_op
+from .base import BaseLayoutParams, QuantizedLayout, QuantizedTensor, dequantize_args, register_layout_op
@@ -80,6 +80,49 @@ def state_dict_tensors(cls, qdata: torch.Tensor, params: Params) -> dict[str, to
             "_scale": params.scale,
         }
 
+    @classmethod
+    def pre_all_gather(cls, qtensor: QuantizedTensor, mesh):
+        qdata = qtensor._qdata
+        if not qdata.is_contiguous():
+            qdata = qdata.contiguous()
+
+        scale = qtensor._params.scale
+        if isinstance(scale, torch.Tensor):
+            scale = scale.to(device=qdata.device)
+
+        return (qdata,), (scale,)
+
+    @classmethod
+    def post_all_gather(
+        cls,
+        qtensor: QuantizedTensor,
+        all_gather_outputs: tuple[torch.Tensor, ...],
+        metadata: Any,
+        param_dtype: torch.dtype,
+        *,
+        out: QuantizedTensor | None = None,
+    ):
+        (data,) = all_gather_outputs
+        (scale,) = metadata
+
+        if out is not None:
+            if not isinstance(out, QuantizedTensor):
+                raise TypeError(f"Expected QuantizedTensor out, got {type(out)}")
+            out._qdata = data
+            out._params = cls.Params(
+                scale=scale,
+                orig_dtype=param_dtype,
+                orig_shape=tuple(data.shape),
+            )
+            return
+
+        params = cls.Params(
+            scale=scale,
+            orig_dtype=param_dtype,
+            orig_shape=tuple(data.shape),
+        )
+        return QuantizedTensor(data, qtensor._layout_cls, params), (data,)
+
 
 # ==================== Helper Utilities ====================
 
@@ -91,14 +134,17 @@ def _fp8_scaled_mm(
     bias: torch.Tensor | None = None,
     out_dtype: torch.dtype | None = None,
 ) -> torch.Tensor:
-    return scaled_mm_v2(
+    """Core FP8 scaled matrix multiplication using torch._scaled_mm."""
+    output = torch._scaled_mm(
         input_qdata.contiguous(),
         weight_qdata,
+        bias=bias,
         scale_a=scale_a,
         scale_b=scale_b,
-        bias=bias,
         out_dtype=out_dtype,
     )
+    # Handle tuple return for older PyTorch versions
+    return output[0] if isinstance(output, tuple) else output
 
 
 def _make_fp8_shape_handler(aten_op):
@@ -211,6 +257,231 @@ def _handle_fp8_addmm(qt, args, kwargs):
     return torch.addmm(*dequantize_args(args))
 
 
+# ==================== Distributed Operations ====================
+# Required c10d ops: all_gather_into_tensor, wait_tensor, broadcast_ (the latter for broadcast_rank0=True)
+# Required aten ops: slice, split, new_zeros, as_strided, cat, alias
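+#
+# The handlers below share one byte-view pattern (sketch):
+#   payload = qt._qdata.contiguous().view(torch.uint8)   # reinterpret fp8 as raw bytes
+#   gathered = collective(payload, ...)                  # run the c10d op on bytes
+#   result = gathered.view(qt._qdata.dtype)              # reinterpret back to fp8
+# The per-tensor scale is carried in the params rather than gathered, so only
+# qdata crosses the wire here.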
+
+@register_layout_op(torch.ops._c10d_functional.all_gather_into_tensor.default, TensorCoreFP8Layout)
+def _handle_all_gather(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    input_tensor = None
+    input_idx = None
+    for i, arg in enumerate(args):
+        if isinstance(arg, QuantizedTensor):
+            input_tensor = arg
+            input_idx = i
+            break
+
+    assert input_tensor is not None
+
+    qdata = input_tensor._qdata
+    layout_cls = input_tensor._layout_cls
+    params = input_tensor._params
+
+    qdata_bytes = qdata.contiguous().view(torch.uint8)
+
+    new_args = list(args)
+    new_args[input_idx] = qdata_bytes
+
+    gathered_bytes = torch.ops._c10d_functional.all_gather_into_tensor.default(*new_args, **kwargs)
+
+    gathered_qdata = gathered_bytes.view(qdata.dtype)
+    gathered_params = replace(params, orig_shape=tuple(gathered_qdata.shape))
+    return QuantizedTensor(gathered_qdata, layout_cls, gathered_params)
+
+
+@register_layout_op(torch.ops._c10d_functional.wait_tensor.default, TensorCoreFP8Layout)
+def _handle_wait_tensor(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    qtensor = args[0]
+
+    waited_bytes = torch.ops._c10d_functional.wait_tensor.default(
+        qtensor._qdata.view(torch.uint8),
+        *args[1:],
+        **kwargs,
+    )
+
+    waited_qdata = waited_bytes.view(qtensor._qdata.dtype)
+    waited_params = replace(qtensor._params, orig_shape=tuple(waited_qdata.shape))
+    return QuantizedTensor(waited_qdata, qtensor._layout_cls, waited_params)
+
+
+# Should not be used in the normal FSDP path, since tensor.copy_ inside FSDP works on the
+# native dtype and cannot receive the correct Quantized class.
+# This is mainly for state_dict broadcast from rank 0 (broadcast_rank0=True).
+@register_layout_op(torch.ops.c10d.broadcast_.default, TensorCoreFP8Layout)
+def _handle_broadcast(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    tensor_list = args[0]
+
+    input_tensor = None
+    input_idx = None
+    for idx, t in enumerate(tensor_list):
+        if isinstance(t, QuantizedTensor):
+            input_tensor = t
+            input_idx = idx
+            break
+
+    if input_tensor is None:
+        return torch.ops.c10d.broadcast_.default(*args, **kwargs)
+
+    qdata = input_tensor._qdata.contiguous()
+    qdata_bytes = qdata.view(torch.uint8)
+
+    new_tensor_list = list(tensor_list)
+    new_tensor_list[input_idx] = qdata_bytes
+
+    new_args = list(args)
+    new_args[0] = new_tensor_list
+
+    broadcasted = torch.ops.c10d.broadcast_.default(
+        *new_args,
+        **kwargs,
+    )
+
+    if isinstance(broadcasted, tuple):
+        tensor_list_out, work = broadcasted
+    else:
+        tensor_list_out = broadcasted
+        work = None
+
+    broadcasted_qdata = tensor_list_out[input_idx].view(qdata.dtype)
+
+    new_out_list = list(tensor_list_out)
+    new_out_list[input_idx] = _wrap_fp8_tensor(
+        input_tensor,
+        broadcasted_qdata,
+    )
+
+    if work is not None:
+        return new_out_list, work
+    else:
+        return new_out_list
+
+
+def _wrap_fp8_tensor(qtensor, qdata):
+    from .base import QuantizedTensor
+
+    new_params = TensorCoreFP8Layout.Params(
+        scale=qtensor._params.scale,
+        orig_dtype=qtensor._params.orig_dtype,
+        orig_shape=tuple(qdata.shape),
+    )
+    return QuantizedTensor(qdata, qtensor._layout_cls, new_params)
+
+
+@register_layout_op(torch.ops.c10d.scatter_.default, TensorCoreFP8Layout)
+def _handle_scatter(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    output_tensors = args[0]
+    input_tensors = args[1]
+
+    quantized_outputs: list[tuple[int, QuantizedTensor]] = []
+    new_output_tensors = list(output_tensors)
+    for idx, tensor in enumerate(output_tensors):
+        if isinstance(tensor, QuantizedTensor):
+            quantized_outputs.append((idx, tensor))
+            new_output_tensors[idx] = tensor._qdata.contiguous().view(torch.uint8)
+
+    has_quantized_input = False
+    new_input_tensors: list[list[torch.Tensor]] = []
+
+    def process_input_list(entry):
+        nonlocal has_quantized_input
+        processed: list[torch.Tensor] = []
+        for t in entry:
+            if isinstance(t, QuantizedTensor):
+                has_quantized_input = True
+                processed.append(t._qdata.contiguous().view(torch.uint8))
+            else:
+                processed.append(t)
+        return processed
+
+    for entry in input_tensors:
+        if isinstance(entry, (list, tuple)):
+            new_input_tensors.append(process_input_list(entry))
+        else:
+            new_input_tensors.append(entry)  # type: ignore[arg-type]
+
+    if not quantized_outputs and not has_quantized_input:
+        return torch.ops.c10d.scatter_.default(*args, **kwargs)
+
+    new_args = [new_output_tensors, new_input_tensors, *args[2:]]
+    result = torch.ops.c10d.scatter_.default(*new_args, **kwargs)
+
+    if isinstance(result, tuple):
+        output_list, work = result
+    else:
+        output_list = result
+        work = None
+
+    output_list = list(output_list)
+    for idx, original in quantized_outputs:
+        qdata = output_list[idx].view(original._qdata.dtype)
+        output_list[idx] = _wrap_fp8_tensor(original, qdata)
+
+    if work is not None:
+        return output_list, work
+    return output_list
+
+
+@register_layout_op(torch.ops.aten.slice.Tensor, TensorCoreFP8Layout)
+def _handle_fp8_slice_tensor(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    input_tensor = args[0]
+    if not isinstance(input_tensor, QuantizedTensor):
+        return torch.ops.aten.slice.Tensor(*args, **kwargs)
+
+    sliced_qdata = torch.ops.aten.slice.Tensor(input_tensor._qdata, *args[1:], **kwargs)
+    return _wrap_fp8_tensor(input_tensor, sliced_qdata)
+
+
+@register_layout_op(torch.ops.aten.split.Tensor, TensorCoreFP8Layout)
+def _handle_fp8_split(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    input_tensor = args[0]
+    if not isinstance(input_tensor, QuantizedTensor):
+        return torch.ops.aten.split.Tensor(*args, **kwargs)
+
+    qdata_chunks = torch.ops.aten.split.Tensor(input_tensor._qdata, *args[1:], **kwargs)
+    wrapped_chunks = tuple(_wrap_fp8_tensor(input_tensor, chunk) for chunk in qdata_chunks)
+    return wrapped_chunks
+
+
+@register_layout_op(torch.ops.aten.cat.default, TensorCoreFP8Layout)
+def _handle_fp8_cat(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    tensors = args[0]
+    if not isinstance(tensors, (list, tuple)) or not tensors:
+        return torch.ops.aten.cat.default(*args, **kwargs)
+
+    qdata_list = []
+    first_qtensor = None
+    for item in tensors:
+        if not isinstance(item, QuantizedTensor):
+            return torch.ops.aten.cat.default(*args, **kwargs)
+        qdata_list.append(item._qdata)
+        if first_qtensor is None:
+            first_qtensor = item
+
+    assert first_qtensor is not None
+    concatenated_qdata = torch.ops.aten.cat.default(qdata_list, *args[1:], **kwargs)
+    return _wrap_fp8_tensor(first_qtensor, concatenated_qdata)
+
+
+@register_layout_op(torch.ops.aten.new_zeros.default, TensorCoreFP8Layout)
+def _handle_new_zeros(qt, args, kwargs):
+    input_tensor = args[0]
+    new_zero_qdata = torch.ops.aten.new_zeros.default(input_tensor._qdata, *args[1:], **kwargs)
+    return _wrap_fp8_tensor(input_tensor, new_zero_qdata)
+
+
 # ==================== FP8 Shape Operations ====================
 # These preserve quantization since FP8 is not packed (1:1 element mapping)
 
@@ -218,5 +489,7 @@ def _handle_fp8_addmm(qt, args, kwargs):
     torch.ops.aten.view.default,
     torch.ops.aten.reshape.default,
     torch.ops.aten.t.default,
+    torch.ops.aten.as_strided.default,
+    torch.ops.aten.alias.default,
 ):
     register_layout_op(_aten_op, TensorCoreFP8Layout)(_make_fp8_shape_handler(_aten_op))
diff --git a/comfy_kitchen/tensor/nvfp4.py b/comfy_kitchen/tensor/nvfp4.py
index a0e8409..7f68595 100644
--- a/comfy_kitchen/tensor/nvfp4.py
+++ b/comfy_kitchen/tensor/nvfp4.py
@@ -1,4 +1,5 @@
 """NVFP4 (E2M1) block quantization layout for tensor cores."""
+
 from __future__ import annotations
 
 import logging
@@ -8,7 +9,7 @@
 import torch
 
 import comfy_kitchen as ck
-from comfy_kitchen.float_utils import F4_E2M1_MAX, F8_E4M3_MAX, roundup
+from comfy_kitchen.float_utils import F4_E2M1_MAX, F8_E4M3_MAX, from_blocked, roundup, to_blocked
 
 from .base import BaseLayoutParams, QuantizedLayout, dequantize_args, register_layout_op
 
@@ -24,8 +25,8 @@ class TensorCoreNVFP4Layout(QuantizedLayout):
 
     Note:
         Requires SM >= 10.0 (Blackwell) for hardware-accelerated matmul.
-        Shape operations (view, reshape, transpose) are not supported due to
-        packed format and block scales - they fall back to dequantization.
+        View-like operations remain limited because NVFP4 uses packed values plus
+        blocked scales, but FSDP row-wise alias/slice/split are supported.
     """
 
     MIN_SM_VERSION = (10, 0)
@@ -37,6 +38,7 @@ class Params(BaseLayoutParams):
 
         Inherits scale, orig_dtype, orig_shape from BaseLayoutParams.
        Adds block_scale for per-block scaling factors.
""" + block_scale: torch.Tensor transposed: bool = False @@ -46,7 +48,9 @@ def _tensor_fields(self) -> list[str]: def _validate_tensor_fields(self): if isinstance(self.scale, torch.Tensor): - object.__setattr__(self, "scale", self.scale.to(dtype=torch.float32, non_blocking=True)) + object.__setattr__( + self, "scale", self.scale.to(dtype=torch.float32, non_blocking=True) + ) @classmethod def quantize( @@ -93,7 +97,6 @@ def get_plain_tensors( @classmethod def state_dict_tensors(cls, qdata: torch.Tensor, params: Params) -> dict[str, torch.Tensor]: - """Return key suffix → tensor mapping for serialization.""" return { "": qdata, "_scale": params.block_scale, @@ -117,14 +120,466 @@ def get_logical_shape_from_storage(cls, storage_shape: tuple[int, ...]) -> tuple """Compute logical (padded) shape from storage shape by reversing packing.""" return (storage_shape[0], storage_shape[1] * 2) + @classmethod + def pre_all_gather(cls, qtensor: QuantizedTensor, mesh): + qdata = qtensor._qdata.contiguous() + block_scale = qtensor._params.block_scale.contiguous() + metadata = { + "scale": qtensor._params.scale, + "orig_dtype": qtensor._params.orig_dtype, + "orig_shape": qtensor._params.orig_shape, + "transposed": qtensor._params.transposed, + "qdata_shape": tuple(qdata.shape), + } + return (qdata, block_scale), metadata + + @classmethod + def post_all_gather( + cls, + qtensor: QuantizedTensor, + all_gather_outputs: tuple[torch.Tensor, ...], + metadata, + param_dtype: torch.dtype, + *, + out: QuantizedTensor | None = None, + ): + from .base import QuantizedTensor + + gathered_qdata, gathered_block_scale = all_gather_outputs + orig_shape = _scaled_rowwise_orig_shape(qtensor, gathered_qdata, metadata.get("orig_shape")) + params = cls.Params( + scale=metadata["scale"], + orig_dtype=metadata.get("orig_dtype", param_dtype), + orig_shape=orig_shape, + block_scale=gathered_block_scale, + transposed=metadata.get("transposed", False), + ) + + if out is not None: + out._qdata = gathered_qdata + out._params = params + return out, (gathered_qdata, gathered_block_scale) + return QuantizedTensor(gathered_qdata, qtensor._layout_cls, params), ( + gathered_qdata, + gathered_block_scale, + ) + + +class _CompositeWork: + def __init__(self, *works): + self._works = [work for work in works if work is not None] + + def wait(self): + result = None + for work in self._works: + result = work.wait() + return result + + def is_completed(self): + return all(getattr(work, "is_completed", lambda: True)() for work in self._works) + + +def _extract_collective_result(result): + if isinstance(result, tuple): + return result[0], result[1] + return result, None + + +def _block_scale_unblocked_shape(qtensor) -> tuple[int, int]: + storage_shape = tuple(qtensor._qdata.shape) + logical_cols = TensorCoreNVFP4Layout.get_logical_shape_from_storage(storage_shape)[1] + return storage_shape[0], logical_cols // 16 + + +def _unblock_block_scale(qtensor, block_scale: torch.Tensor | None = None) -> torch.Tensor: + block_scale = qtensor._params.block_scale if block_scale is None else block_scale + num_rows, num_cols = _block_scale_unblocked_shape(qtensor) + return from_blocked(block_scale, num_rows=num_rows, num_cols=num_cols) + + +def _reblock_block_scale(scale_rows: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return to_blocked(scale_rows.to(dtype=dtype), flatten=False) + + +def _scaled_rowwise_orig_shape( + qtensor, new_qdata: torch.Tensor, orig_shape=None +) -> tuple[int, ...]: + if orig_shape is not None: + return tuple(orig_shape) + orig_shape 
+
+
+def _scaled_rowwise_orig_shape(
+    qtensor, new_qdata: torch.Tensor, orig_shape=None
+) -> tuple[int, ...]:
+    if orig_shape is not None:
+        return tuple(orig_shape)
+    orig_shape = qtensor._params.orig_shape
+    if getattr(qtensor._params, "transposed", False):
+        return tuple(orig_shape)
+    if len(orig_shape) != 2 or qtensor._qdata.dim() != 2 or new_qdata.dim() != 2:
+        return tuple(orig_shape)
+
+    old_logical_rows = int(orig_shape[0])
+    old_storage_rows = int(qtensor._qdata.shape[0])
+    new_storage_rows = int(new_qdata.shape[0])
+    if old_storage_rows == 0:
+        new_logical_rows = new_storage_rows
+    else:
+        new_logical_rows = (
+            old_logical_rows * new_storage_rows + old_storage_rows - 1
+        ) // old_storage_rows
+    return (new_logical_rows, int(orig_shape[1]))
+
+
+def _wrap_nvfp4_tensor(
+    qtensor,
+    qdata: torch.Tensor,
+    *,
+    block_scale: torch.Tensor | None = None,
+    orig_shape: tuple[int, ...] | None = None,
+    transposed: bool | None = None,
+):
+    from .base import QuantizedTensor
+
+    new_params = TensorCoreNVFP4Layout.Params(
+        scale=qtensor._params.scale,
+        orig_dtype=qtensor._params.orig_dtype,
+        orig_shape=_scaled_rowwise_orig_shape(qtensor, qdata, orig_shape),
+        block_scale=qtensor._params.block_scale if block_scale is None else block_scale,
+        transposed=qtensor._params.transposed if transposed is None else transposed,
+    )
+    return QuantizedTensor(qdata, qtensor._layout_cls, new_params)
+
+
+def _normalize_slice_args(size: int, start, end, step) -> tuple[int, int, int]:
+    step = 1 if step is None else step
+    if step != 1:
+        raise NotImplementedError("NVFP4 only supports slice step=1")
+    start = 0 if start is None else start
+    end = size if end is None else end
+    if start < 0:
+        start += size
+    if end < 0:
+        end += size
+    start = max(0, min(start, size))
+    end = max(start, min(end, size))
+    return start, end, step
+
+
+def _logical_rows_to_storage_rows(qtensor, start: int, end: int) -> tuple[int, int]:
+    logical_rows = int(qtensor._params.orig_shape[0])
+    storage_rows = int(qtensor._qdata.shape[0])
+    if logical_rows <= 0:
+        return 0, 0
+    start_storage = (start * storage_rows) // logical_rows
+    end_storage = (end * storage_rows + logical_rows - 1) // logical_rows
+    return start_storage, end_storage
+
+
+def _slice_rows_nvfp4(input_tensor, start, end):
+    start_storage, end_storage = _logical_rows_to_storage_rows(input_tensor, start, end)
+    sliced_qdata = torch.ops.aten.slice.Tensor(
+        input_tensor._qdata, 0, start_storage, end_storage, 1
+    )
+    block_scale_rows = _unblock_block_scale(input_tensor)
+    sliced_block_scale = block_scale_rows[start_storage:end_storage]
+    reblocked_scale = _reblock_block_scale(
+        sliced_block_scale, input_tensor._params.block_scale.dtype
+    )
+    new_orig_shape = (end - start, int(input_tensor._params.orig_shape[1]))
+    return _wrap_nvfp4_tensor(
+        input_tensor, sliced_qdata, block_scale=reblocked_scale, orig_shape=new_orig_shape
+    )
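+
+
+# Worked example (illustrative): with orig_shape=(6, 64) padded to 8 storage rows,
+# _logical_rows_to_storage_rows maps logical rows [0, 3) to storage rows
+# [0 * 8 // 6, (3 * 8 + 5) // 6) = [0, 4).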
+
+
+# ==================== Distributed Operations ====================
+
+
+@register_layout_op(
+    torch.ops._c10d_functional.all_gather_into_tensor.default, TensorCoreNVFP4Layout
+)
+def _handle_all_gather(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    input_tensor = None
+    input_idx = None
+    for i, arg in enumerate(args):
+        if isinstance(arg, QuantizedTensor):
+            input_tensor = arg
+            input_idx = i
+            break
+
+    assert input_tensor is not None
+    assert input_idx is not None
+
+    qdata_bytes = input_tensor._qdata.contiguous().view(torch.uint8)
+    block_scale_bytes = input_tensor._params.block_scale.contiguous().view(torch.uint8)
+
+    q_args = list(args)
+    q_args[input_idx] = qdata_bytes
+    gathered_qdata_bytes = torch.ops._c10d_functional.all_gather_into_tensor.default(
+        *q_args, **kwargs
+    )
+
+    b_args = list(args)
+    b_args[input_idx] = block_scale_bytes
+    gathered_block_scale_bytes = torch.ops._c10d_functional.all_gather_into_tensor.default(
+        *b_args, **kwargs
+    )
+
+    gathered_qdata = gathered_qdata_bytes.view(input_tensor._qdata.dtype)
+    gathered_block_scale = gathered_block_scale_bytes.view(input_tensor._params.block_scale.dtype)
+    return _wrap_nvfp4_tensor(input_tensor, gathered_qdata, block_scale=gathered_block_scale)
+
+
+@register_layout_op(torch.ops._c10d_functional.wait_tensor.default, TensorCoreNVFP4Layout)
+def _handle_wait_tensor(qt, args, kwargs):
+    qtensor = args[0]
+
+    waited_qdata_bytes = torch.ops._c10d_functional.wait_tensor.default(
+        qtensor._qdata.view(torch.uint8),
+        *args[1:],
+        **kwargs,
+    )
+    waited_block_scale_bytes = torch.ops._c10d_functional.wait_tensor.default(
+        qtensor._params.block_scale.contiguous().view(torch.uint8),
+        *args[1:],
+        **kwargs,
+    )
+
+    waited_qdata = waited_qdata_bytes.view(qtensor._qdata.dtype)
+    waited_block_scale = waited_block_scale_bytes.view(qtensor._params.block_scale.dtype)
+    return _wrap_nvfp4_tensor(qtensor, waited_qdata, block_scale=waited_block_scale)
+
+
+@register_layout_op(torch.ops.c10d.broadcast_.default, TensorCoreNVFP4Layout)
+def _handle_broadcast(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    tensor_list = args[0]
+    quantized_entries = [
+        (idx, tensor)
+        for idx, tensor in enumerate(tensor_list)
+        if isinstance(tensor, QuantizedTensor)
+    ]
+    if not quantized_entries:
+        return torch.ops.c10d.broadcast_.default(*args, **kwargs)
+
+    q_tensor_list = list(tensor_list)
+    b_tensor_list = list(tensor_list)
+    for idx, tensor in quantized_entries:
+        q_tensor_list[idx] = tensor._qdata.contiguous().view(torch.uint8)
+        b_tensor_list[idx] = tensor._params.block_scale.contiguous().view(torch.uint8)
+
+    q_result = torch.ops.c10d.broadcast_.default(q_tensor_list, *args[1:], **kwargs)
+    b_result = torch.ops.c10d.broadcast_.default(b_tensor_list, *args[1:], **kwargs)
+    q_list, q_work = _extract_collective_result(q_result)
+    b_list, b_work = _extract_collective_result(b_result)
+
+    output_list = list(q_list)
+    for idx, original in quantized_entries:
+        output_list[idx] = _wrap_nvfp4_tensor(
+            original,
+            q_list[idx].view(original._qdata.dtype),
+            block_scale=b_list[idx].view(original._params.block_scale.dtype),
+            orig_shape=original._params.orig_shape,
+        )
+
+    if q_work is not None or b_work is not None:
+        return output_list, _CompositeWork(q_work, b_work)
+    return output_list
+
+
+@register_layout_op(torch.ops.c10d.scatter_.default, TensorCoreNVFP4Layout)
+def _handle_scatter(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    output_tensors = args[0]
+    input_tensors = args[1]
+
+    quantized_outputs = []
+    new_q_outputs = list(output_tensors)
+    new_b_outputs = list(output_tensors)
+    for idx, tensor in enumerate(output_tensors):
+        if isinstance(tensor, QuantizedTensor):
+            quantized_outputs.append((idx, tensor))
+            new_q_outputs[idx] = tensor._qdata.contiguous().view(torch.uint8)
+            new_b_outputs[idx] = tensor._params.block_scale.contiguous().view(torch.uint8)
+
+    has_quantized_input = False
+    q_inputs = []
+    b_inputs = []
+    for entry in input_tensors:
+        if isinstance(entry, (list, tuple)):
+            q_entry = []
+            b_entry = []
+            for tensor in entry:
+                if isinstance(tensor, QuantizedTensor):
+                    has_quantized_input = True
+                    q_entry.append(tensor._qdata.contiguous().view(torch.uint8))
+                    b_entry.append(tensor._params.block_scale.contiguous().view(torch.uint8))
+                else:
+                    q_entry.append(tensor)
+                    b_entry.append(tensor)
+            q_inputs.append(q_entry)
+            b_inputs.append(b_entry)
+        else:
+            q_inputs.append(entry)
+            b_inputs.append(entry)
+
+    if not quantized_outputs and not has_quantized_input:
+        return torch.ops.c10d.scatter_.default(*args, **kwargs)
+
+    q_result = torch.ops.c10d.scatter_.default(new_q_outputs, q_inputs, *args[2:], **kwargs)
+    b_result = torch.ops.c10d.scatter_.default(new_b_outputs, b_inputs, *args[2:], **kwargs)
+    q_list, q_work = _extract_collective_result(q_result)
+    b_list, b_work = _extract_collective_result(b_result)
+
+    output_list = list(q_list)
+    for idx, original in quantized_outputs:
+        output_list[idx] = _wrap_nvfp4_tensor(
+            original,
+            q_list[idx].view(original._qdata.dtype),
+            block_scale=b_list[idx].view(original._params.block_scale.dtype),
+            orig_shape=original._params.orig_shape,
+        )
+
+    if q_work is not None or b_work is not None:
+        return output_list, _CompositeWork(q_work, b_work)
+    return output_list
+
+
+# ==================== NVFP4 Shape Operations ====================
+
+
+@register_layout_op(torch.ops.aten.alias.default, TensorCoreNVFP4Layout)
+def _handle_nvfp4_alias(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    input_tensor = args[0]
+    if not isinstance(input_tensor, QuantizedTensor):
+        return torch.ops.aten.alias.default(*args, **kwargs)
+
+    aliased_qdata = torch.ops.aten.alias.default(input_tensor._qdata)
+    aliased_block_scale = torch.ops.aten.alias.default(input_tensor._params.block_scale)
+    return _wrap_nvfp4_tensor(
+        input_tensor,
+        aliased_qdata,
+        block_scale=aliased_block_scale,
+        orig_shape=input_tensor._params.orig_shape,
+        transposed=input_tensor._params.transposed,
+    )
+
+
+@register_layout_op(torch.ops.aten.slice.Tensor, TensorCoreNVFP4Layout)
+def _handle_nvfp4_slice(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    input_tensor = args[0]
+    if not isinstance(input_tensor, QuantizedTensor):
+        return torch.ops.aten.slice.Tensor(*args, **kwargs)
+    if getattr(input_tensor._params, "transposed", False):
+        return torch.ops.aten.slice.Tensor(*dequantize_args(args), **kwargs)
+
+    dim = args[1] if len(args) > 1 else 0
+    start = args[2] if len(args) > 2 else None
+    end = args[3] if len(args) > 3 else None
+    step = args[4] if len(args) > 4 else None
+    dim = dim if dim >= 0 else dim + len(input_tensor._params.orig_shape)
+
+    if dim == 0:
+        start, end, _ = _normalize_slice_args(
+            int(input_tensor._params.orig_shape[0]), start, end, step
+        )
+        return _slice_rows_nvfp4(input_tensor, start, end)
+
+    if dim == 1:
+        start, end, step = _normalize_slice_args(
+            int(input_tensor._params.orig_shape[1]), start, end, step
+        )
+        if start == 0 and end == int(input_tensor._params.orig_shape[1]) and step == 1:
+            return _handle_nvfp4_alias(qt, (input_tensor,), {})
+        return torch.ops.aten.slice.Tensor(*dequantize_args(args), **kwargs)
+
+    return torch.ops.aten.slice.Tensor(*dequantize_args(args), **kwargs)
+
+
+@register_layout_op(torch.ops.aten.split.Tensor, TensorCoreNVFP4Layout)
+def _handle_nvfp4_split(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    input_tensor = args[0]
+    if not isinstance(input_tensor, QuantizedTensor):
+        return torch.ops.aten.split.Tensor(*args, **kwargs)
+    if getattr(input_tensor._params, "transposed", False):
+        return torch.ops.aten.split.Tensor(*dequantize_args(args), **kwargs)
+
+    split_size = args[1]
+    dim = kwargs.get("dim", args[2] if len(args) > 2 else 0)
+    dim = dim if dim >= 0 else dim + len(input_tensor._params.orig_shape)
+    if dim != 0:
+        return torch.ops.aten.split.Tensor(*dequantize_args(args), **kwargs)
+
+    logical_rows = int(input_tensor._params.orig_shape[0])
+    if isinstance(split_size, int):
+        chunks = []
+        for start in range(0, logical_rows, split_size):
+            end = min(start + split_size, logical_rows)
+            chunks.append(_slice_rows_nvfp4(input_tensor, start, end))
+        return tuple(chunks)
+    return torch.ops.aten.split.Tensor(*dequantize_args(args), **kwargs)
+
+
+@register_layout_op(torch.ops.aten.cat.default, TensorCoreNVFP4Layout)
+def _handle_nvfp4_cat(qt, args, kwargs):
+    from .base import QuantizedTensor
+
+    tensors = args[0]
+    dim = kwargs.get("dim", args[1] if len(args) > 1 else 0)
+    if dim != 0 or not isinstance(tensors, (list, tuple)) or not tensors:
+        return torch.ops.aten.cat.default(*dequantize_args(args), **kwargs)
+    if not all(isinstance(tensor, QuantizedTensor) for tensor in tensors):
+        return torch.ops.aten.cat.default(*dequantize_args(args), **kwargs)
+
+    first = tensors[0]
+    if any(getattr(tensor._params, "transposed", False) for tensor in tensors):
+        return torch.ops.aten.cat.default(*dequantize_args(args), **kwargs)
+    if any(tensor._params.orig_shape[1] != first._params.orig_shape[1] for tensor in tensors):
+        return torch.ops.aten.cat.default(*dequantize_args(args), **kwargs)
+
+    qdata = torch.ops.aten.cat.default([tensor._qdata for tensor in tensors], 0)
+    block_scale_rows = [_unblock_block_scale(tensor) for tensor in tensors]
+    block_scale = _reblock_block_scale(
+        torch.ops.aten.cat.default(block_scale_rows, 0), first._params.block_scale.dtype
+    )
+    orig_rows = sum(int(tensor._params.orig_shape[0]) for tensor in tensors)
+    return _wrap_nvfp4_tensor(
+        first,
+        qdata,
+        block_scale=block_scale,
+        orig_shape=(orig_rows, int(first._params.orig_shape[1])),
+    )
+
+
+@register_layout_op(torch.ops.aten.new_zeros.default, TensorCoreNVFP4Layout)
+def _handle_new_zeros(qt, args, kwargs):
+    input_tensor = args[0]
+    size = tuple(args[1]) if len(args) > 1 else tuple(input_tensor._params.orig_shape)
+    if len(size) != 2:
+        return torch.ops.aten.new_zeros.default(*dequantize_args(args), **kwargs)
+
+    device = kwargs.get("device", input_tensor._qdata.device)
+    storage_shape = TensorCoreNVFP4Layout.get_storage_shape(size)
+    qdata = torch.zeros(storage_shape, device=device, dtype=input_tensor._qdata.dtype)
+    block_cols = TensorCoreNVFP4Layout.get_padded_shape(size)[1] // 16
+    block_scale_rows = torch.zeros(
+        (storage_shape[0], block_cols),
+        device=device,
+        dtype=input_tensor._params.block_scale.dtype,
+    )
+    block_scale = _reblock_block_scale(block_scale_rows, input_tensor._params.block_scale.dtype)
+    return _wrap_nvfp4_tensor(input_tensor, qdata, block_scale=block_scale, orig_shape=size)
+
 
 # ==================== NVFP4 Transpose Operation ====================
 # Transpose is a no-op that tracks logical transposition via a flag.
 @register_layout_op(torch.ops.aten.t.default, TensorCoreNVFP4Layout)
 def _handle_nvfp4_transpose(qt, args, kwargs):
-    """Handle transpose as a logical no-op for NVFP4.
-    """
+    """Handle transpose as a logical no-op for NVFP4."""
- """ + """Handle transpose as a logical no-op for NVFP4.""" from .base import QuantizedTensor input_tensor = args[0] @@ -157,6 +612,13 @@ def _slice_to_original_shape( return result +def _linear_dequantize_fallback(input_tensor, weight, bias): + input_dense, weight_dense, bias_dense = dequantize_args((input_tensor, weight, bias)) + assert isinstance(input_dense, torch.Tensor) + assert isinstance(weight_dense, torch.Tensor) + return torch.nn.functional.linear(input_dense, weight_dense, bias_dense) + + @register_layout_op(torch.ops.aten.mm.default, TensorCoreNVFP4Layout) def _handle_nvfp4_mm(qt, args, kwargs): """NVFP4 matrix multiplication: output = a @ b @@ -165,17 +627,15 @@ def _handle_nvfp4_mm(qt, args, kwargs): with scaled_mm_nvfp4 since that kernel computes a @ b_phys.T, which equals a @ b_logical when b_logical = b_phys.T. - This handles the common torch.compile decomposition: linear(x, w) → mm(x, w.t()) + This handles the common torch.compile decomposition: linear(x, w) -> mm(x, w.t()) """ from .base import QuantizedTensor a, b = args[0], args[1] - # Fast path: both operands are NVFP4 QuantizedTensors if not (isinstance(a, QuantizedTensor) and isinstance(b, QuantizedTensor)): return torch.mm(*dequantize_args(args)) - # NVFP4 only supports 2D tensors if a._qdata.dim() != 2: return torch.mm(*dequantize_args(args)) @@ -183,7 +643,6 @@ def _handle_nvfp4_mm(qt, args, kwargs): b_transposed = getattr(b._params, "transposed", False) if a_transposed or not b_transposed: - # Can't handle these cases with current kernel, fallback logger.debug("NVFP4 mm: unsupported transpose configuration, falling back to dequantize") return torch.mm(*dequantize_args(args)) @@ -223,26 +682,25 @@ def _handle_nvfp4_linear(qt, args, kwargs): input_tensor, weight = args[0], args[1] bias = args[2] if len(args) > 2 else None - # Fast path: both operands are NVFP4 QuantizedTensors if not (isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor)): - return torch.nn.functional.linear(*dequantize_args((input_tensor, weight, bias))) + return _linear_dequantize_fallback(input_tensor, weight, bias) - # NVFP4 only supports 2D tensors if input_tensor._qdata.dim() != 2: - return torch.nn.functional.linear(*dequantize_args((input_tensor, weight, bias))) + return _linear_dequantize_fallback(input_tensor, weight, bias) input_transposed = getattr(input_tensor._params, "transposed", False) weight_transposed = getattr(weight._params, "transposed", False) if input_transposed or weight_transposed: - logger.debug("NVFP4 linear: unsupported transpose configuration, falling back to dequantize") - return torch.nn.functional.linear(*dequantize_args((input_tensor, weight, bias))) + logger.debug( + "NVFP4 linear: unsupported transpose configuration, falling back to dequantize" + ) + return _linear_dequantize_fallback(input_tensor, weight, bias) input_qdata, scale_a, block_scale_a = TensorCoreNVFP4Layout.get_plain_tensors(input_tensor) weight_qdata, scale_b, block_scale_b = TensorCoreNVFP4Layout.get_plain_tensors(weight) out_dtype = kwargs.get("out_dtype", input_tensor._params.orig_dtype) try: - # scaled_mm_nvfp4 computes (a @ b.T) * scale, which is linear semantics result = ck.scaled_mm_nvfp4( input_qdata, weight_qdata, @@ -254,11 +712,10 @@ def _handle_nvfp4_linear(qt, args, kwargs): out_dtype=out_dtype, ) - # Slice output to original (non-padded) shape orig_m = input_tensor._params.orig_shape[0] - orig_n = weight._params.orig_shape[0] # weight is (out_features, in_features) + orig_n = weight._params.orig_shape[0] 
         return _slice_to_original_shape(result, orig_m, orig_n)
     except (RuntimeError, TypeError) as e:
         logger.warning(f"NVFP4 scaled_mm failed: {e}, falling back to dequantization")
-        return torch.nn.functional.linear(*dequantize_args((input_tensor, weight, bias)))
+        return _linear_dequantize_fallback(input_tensor, weight, bias)