3 changes: 3 additions & 0 deletions NOTICE
@@ -52,6 +52,9 @@ Portions of the following files are derived from PyTorch AO:
- comfy_kitchen/backends/eager/quantization.py
Source: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/nvfp4_tensor.py

- comfy_kitchen/tensor/fp8.py
Source: https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py

- comfy_kitchen/float_utils.py (_f32_to_floatx_unpacked, _floatx_unpacked_to_f32)
Source: https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats/custom_cast.py

10 changes: 10 additions & 0 deletions README.md
@@ -17,6 +17,16 @@ Fast kernel library for Diffusion inference with multiple compute backends.
| `apply_rope` | ✓ | ✓ | ✓ |
| `apply_rope1` | ✓ | ✓ | ✓ |

## Distributed Capabilities Matrix

| Layout                  | DTensor | FSDP pre/post all-gather |
|-------------------------|---------|--------------------------|
| `TensorCoreFP8Layout`   | ✓       | ✓                        |
| `TensorCoreNVFP4Layout` | ✓       | ✓                        |
| `TensorCoreMXFP8Layout` |         |                          |

This matrix is relevant for custom nodes that implement distributed operations, such as [Raylight](https://github.com/komikndr/raylight).
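
A minimal sketch of how such a custom node might shard quantized weights with FSDP2 (`MyDiffusionBlock` is a hypothetical module; the `fully_shard` import path depends on the PyTorch release):

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import fully_shard  # FSDP2; exact import path varies by PyTorch version

# Assumes the process group is already initialized (e.g. launched with torchrun).
mesh = init_device_mesh("cuda", (dist.get_world_size(),))

model = MyDiffusionBlock()     # hypothetical module whose weights are QuantizedTensor (e.g. TensorCoreFP8Layout)
fully_shard(model, mesh=mesh)  # FSDP2 then calls fsdp_pre/post_all_gather on each quantized weight
```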


## Quantized Tensors

63 changes: 59 additions & 4 deletions comfy_kitchen/tensor/base.py
@@ -137,6 +137,41 @@ def get_requirements(cls) -> dict[str, Any]:
"fast_matmul_supported": cls.supports_fast_matmul(),
}

@classmethod
def pre_all_gather(cls, qtensor: QuantizedTensor, mesh):
"""Prepare data for FSDP all_gather.

Returns:
Tuple of (all_gather_outputs, metadata):
- all_gather_outputs: data to be gathered (as tuple of tensors)
- metadata: additional info needed after gathering (as tuple)
"""
raise NotImplementedError(f"pre_all_gather not implemented for {cls.__name__}")

@classmethod
def post_all_gather(
cls,
qtensor: QuantizedTensor,
all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: QuantizedTensor | None = None,
):
"""Reconstruct QuantizedTensor after FSDP all_gather.

Args:
qtensor: Original quantized tensor (used for layout context)
all_gather_outputs: Gathered tensors
metadata: Metadata from pre_all_gather
param_dtype: Expected parameter dtype
out: Optional output tensor to update in-place

Returns:
            None if ``out`` was updated in place, otherwise a new (QuantizedTensor, metadata) tuple
"""
raise NotImplementedError(f"post_all_gather not implemented for {cls.__name__}")


class QuantizedTensor(torch.Tensor):
"""
@@ -327,6 +362,25 @@ def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
params = ctx["params_class"](**params_kwargs)
return QuantizedTensor(inner_tensors["_qdata"], ctx["layout_cls"], params)

# ==================== FSDP Hooks ====================

def fsdp_pre_all_gather(self, mesh):
"""FSDP pre_all_gather hook - delegates to layout class."""
return self.layout_cls.pre_all_gather(self, mesh)

def fsdp_post_all_gather(
self,
all_gather_outputs: tuple[torch.Tensor, ...],
metadata: Any,
param_dtype: torch.dtype,
*,
out: QuantizedTensor | None = None,
):
"""FSDP post_all_gather hook - delegates to layout class."""
return self.layout_cls.post_all_gather(
self, all_gather_outputs, metadata, param_dtype, out=out
)
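    # Rough call flow (illustrative; FSDP2 internals simplified, dtype chosen
    # arbitrarily): FSDP2 invokes these hooks around its collective, e.g.
    #
    #     outputs, meta = sharded_qweight.fsdp_pre_all_gather(mesh)
    #     gathered = tuple(
    #         torch.cat(torch.distributed.nn.functional.all_gather(t, group=mesh.get_group()))
    #         for t in outputs
    #     )
    #     unsharded = sharded_qweight.fsdp_post_all_gather(gathered, meta, torch.bfloat16)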

# ==================== Torch Dispatch ====================

@classmethod
@@ -348,7 +402,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
return op_handlers[parent_cls](qt, args, kwargs)

# Step 3: Fallback to dequantization
logger.debug(f"Unhandled op {func} for {layout_cls.__name__ if layout_cls else 'unknown'}, dequantizing")
return cls._dequant_and_fallback(func, args, kwargs)

@classmethod
@@ -438,13 +491,15 @@ def _handle_is_contiguous(qt, args, kwargs):

def _handle_copy_(qt, args, kwargs):
dst, src = args[0], args[1]
if not isinstance(src, QuantizedTensor):
raise TypeError(f"Cannot copy {type(src).__name__} to QuantizedTensor")
    # Respect an explicit positional non_blocking argument rather than treating its mere presence as True.
    non_blocking = args[2] if len(args) > 2 else kwargs.get("non_blocking", False)

if not isinstance(dst, QuantizedTensor):
dst = QuantizedTensor(src._qdata, src._layout_cls, src._params)

if dst._layout_cls != src._layout_cls:
raise TypeError(f"Layout mismatch: {dst._layout_cls} vs {src._layout_cls}")

dst_orig_dtype = dst._params.orig_dtype

dst._qdata.copy_(src._qdata, non_blocking=non_blocking)
dst._params.copy_from(src._params, non_blocking=non_blocking)