Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions dlblas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

# This imports all kernels dynamically
import dlblas.kernels # noqa

# This registers the operators with PyTorch (torch.ops.dlblas.xxx)
import dlblas.torch_ops # noqa
from dlblas.utils import get_op

__version__ = "0.0.7"
Expand Down Expand Up @@ -147,3 +150,20 @@ def flash_attention_v2(q, k, v):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids_1d):
    """Apply rotary position embedding through the operator chosen by ``get_op``.

    The five arguments are forwarded unchanged: once as the tuple handed to
    ``get_op`` for operator lookup, and once as the positional arguments of
    the operator it returns.

    Returns:
        Whatever the selected operator returns.
    """
    call_args = (q, k, cos, sin, position_ids_1d)
    selected_op = get_op("apply_rotary_pos_emb", call_args)
    return selected_op(*call_args)


def vector_add(a: Tensor, b: Tensor) -> Tensor:
    """Add two vectors element-wise by calling the registered ``dlblas`` op.

    Equivalent entry points:
        - ``dlblas.vector_add(a, b)``
        - ``torch.ops.dlblas.vector_add(a, b)``

    Args:
        a: First input tensor (1D).
        b: Second input tensor (1D).

    Returns:
        The tensor returned by ``torch.ops.dlblas.vector_add(a, b)``.
    """
    # NOTE(review): relies on ``torch`` and ``Tensor`` being imported at module
    # scope; those imports are not visible in this hunk -- confirm.
    registered_op = torch.ops.dlblas.vector_add
    return registered_op(a, b)
103 changes: 103 additions & 0 deletions dlblas/kernels/vector_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright (c) 2025, DeepLink.
"""Triton kernel implementation for vector_add.

Simple vector addition: c = a + b
"""

import torch
import triton
import triton.language as tl

# Candidate launch configurations for the autotuner, kept as a module-level
# constant (the original author notes this keeps them available during
# torch.compile tracing). Each pair is (BLOCK_SIZE, num_warps); the order
# of the candidates is preserved.
AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCK_SIZE": block_size}, num_warps=warps)
    for block_size, warps in ((1024, 4), (2048, 8), (512, 2), (256, 1))
]


@triton.autotune(configs=AUTOTUNE_CONFIGS, key=["N"])
@triton.jit
def vector_add_kernel(
    a_ptr,  # First input vector
    b_ptr,  # Second input vector
    c_ptr,  # Output vector
    N,  # Total element count
    BLOCK_SIZE: tl.constexpr,  # Elements handled by one program instance
):
    """Element-wise ``c = a + b`` over ``N`` elements.

    Program ``pid`` covers offsets ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``;
    offsets at or beyond ``N`` are masked out of both the loads and the store.
    """
    program_index = tl.program_id(axis=0)

    # Offsets of the elements owned by this program instance
    lane_offsets = program_index * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # The last block may be partial when N is not a multiple of BLOCK_SIZE
    in_bounds = lane_offsets < N

    # Masked-off lanes read 0.0 instead of touching memory
    lhs = tl.load(a_ptr + lane_offsets, mask=in_bounds, other=0.0)
    rhs = tl.load(b_ptr + lane_offsets, mask=in_bounds, other=0.0)

    # Write the sum back for the in-bounds lanes only
    tl.store(c_ptr + lane_offsets, lhs + rhs, mask=in_bounds)


def vector_add_impl(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute ``c = a + b`` for two 1D device tensors with the Triton kernel.

    This is the function registered as the device implementation of
    ``torch.ops.dlblas.vector_add``.

    Args:
        a: First input tensor (1D, on a CUDA or NPU device).
        b: Second input tensor (1D, same device, length and dtype as ``a``).

    Returns:
        A new contiguous tensor holding ``a + b``.

    Raises:
        RuntimeError: If ``a`` is not on a CUDA/NPU device, the two inputs
            live on different devices, either input is not 1D, or the
            lengths or dtypes differ.
    """
    # Only CUDA and NPU devices are supported by this implementation
    device_type = a.device.type
    if device_type not in ["cuda", "npu"]:
        raise RuntimeError(
            f"vector_add only supports CUDA or NPU tensors, got {device_type}"
        )

    # The kernel dereferences both inputs, so they must share a device
    if b.device != a.device:
        raise RuntimeError(f"device mismatch: {a.device} vs {b.device}")

    if a.dim() != 1 or b.dim() != 1:
        raise RuntimeError(
            f"vector_add expects 1D tensors, got {a.dim()}D and {b.dim()}D"
        )

    if a.shape[0] != b.shape[0]:
        raise RuntimeError(f"vector length mismatch: {a.shape[0]} vs {b.shape[0]}")

    if a.dtype != b.dtype:
        raise RuntimeError(f"dtype mismatch: {a.dtype} vs {b.dtype}")

    N = a.shape[0]

    # The kernel addresses elements as ``ptr + offset`` (unit stride), so a
    # strided 1D view such as ``x[::2]`` must be made contiguous first.
    a = a.contiguous()
    b = b.contiguous()

    # Allocate the output with the same dtype/device as the inputs
    c = torch.empty_like(a)

    # Nothing to compute for empty vectors; skip launching an empty grid
    if N == 0:
        return c

    def grid(meta):
        # One program instance per BLOCK_SIZE-sized chunk of the vector
        return (triton.cdiv(N, meta["BLOCK_SIZE"]),)

    vector_add_kernel[grid](
        a,
        b,
        c,
        N,
    )

    return c
37 changes: 37 additions & 0 deletions dlblas/torch_ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2025, DeepLink.
"""PyTorch operator registration module.

This module registers dlBLAS operators as PyTorch native operators,
enabling calls via torch.ops.dlblas.xxx()

All operators in this module follow PyTorch's torch.library registration mechanism,
making them compatible with:
- torch.compile
- torch.jit.trace / torch.jit.script
- PyTorch autograd (when implemented)
"""

import torch

# torch.library may be missing on older PyTorch builds (the warning text
# below states it requires 2.0+); in that case registration is skipped.
try:
    from torch.library import Library
except ImportError:
    import warnings

    HAS_TORCH_LIBRARY = False
    warnings.warn(
        "torch.library not available (requires PyTorch 2.0+). "
        "PyTorch ops registration disabled."
    )
else:
    HAS_TORCH_LIBRARY = True

if HAS_TORCH_LIBRARY:
    # Create a library handle for dlblas operators
    # "FRAGMENT" allows incremental registration across modules
    _lib = Library("dlblas", "FRAGMENT")

    # Import and register all operators.
    # _lib must already be bound here: each operator module imports the
    # shared handle back from this package.
    from . import vector_add  # noqa: F401

    __all__ = ["vector_add"]
else:
    # No operator modules were imported, so there is nothing to export;
    # advertising "vector_add" here would break ``from ... import *``.
    __all__ = []
149 changes: 149 additions & 0 deletions dlblas/torch_ops/vector_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) 2025, DeepLink.
"""Register vector_add as a PyTorch operator.

This module registers the vector_add Triton kernel as a native PyTorch operator,
enabling calls via:
- torch.ops.dlblas.vector_add(a, b)
- Support for torch.compile
- Support for torch.jit tracing
"""

import torch
from torch.library import Library
from typing import Tuple

# Import the Triton kernel implementation
from dlblas.kernels.vector_add import vector_add_impl

# Get the shared library handle from parent module
from dlblas.torch_ops import _lib

# ===== Step 1: Define Meta Function =====


def vector_add_meta(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Validate the inputs of vector_add and describe its output.

    Registered under the ``Meta`` dispatch key, so it only needs to produce
    a tensor with the right metadata; no addition is performed.

    Args:
        a: First input tensor.
        b: Second input tensor.

    Returns:
        An uninitialized tensor created with ``torch.empty_like(a)``.

    Raises:
        RuntimeError: If either input is not 1D, or if the two inputs differ
            in length or dtype.
    """
    rank_a, rank_b = a.dim(), b.dim()
    if not (rank_a == 1 and rank_b == 1):
        raise RuntimeError(
            f"vector_add expects 1D tensors, got {rank_a}D and {rank_b}D"
        )

    len_a, len_b = a.shape[0], b.shape[0]
    if len_a != len_b:
        raise RuntimeError(f"vector length mismatch: {len_a} vs {len_b}")

    dtype_a, dtype_b = a.dtype, b.dtype
    if dtype_a != dtype_b:
        raise RuntimeError(f"dtype mismatch: {dtype_a} vs {dtype_b}")

    # Metadata only: the contents of the returned tensor are never read
    return torch.empty_like(a)


# ===== Step 2: Register Operator Schema =====

# Declare the operator signature (schema) on the shared library handle
_lib.define("vector_add(Tensor a, Tensor b) -> Tensor")


# ===== Step 3: Register Implementations =====

# Dispatch key -> implementation, registered unconditionally in this order:
#   PrivateUse1 - Triton implementation (per the original note, Ascend NPU
#                 uses "PrivateUse1" as its device key in PyTorch)
#   CUDA        - the same Triton implementation
#   Meta        - shape/dtype inference only
for _dispatch_key, _kernel in (
    ("PrivateUse1", vector_add_impl),
    ("CUDA", vector_add_impl),
    ("Meta", vector_add_meta),
):
    _lib.impl("vector_add", _kernel, _dispatch_key)
# Do not leave the loop variables behind as module attributes
del _dispatch_key, _kernel


# ===== Step 4: Optional - Register CPU Fallback =====


def vector_add_cpu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """CPU fallback for vector_add built on PyTorch's native addition.

    Applies the same input checks as the Meta and device implementations so
    that the operator accepts and rejects the same inputs on every backend,
    instead of silently broadcasting or type-promoting on CPU.

    Args:
        a: First input tensor (1D).
        b: Second input tensor (1D, same length and dtype as ``a``).

    Returns:
        A new tensor holding ``a + b``.

    Raises:
        RuntimeError: If either input is not 1D, or if the two inputs differ
            in length or dtype.
    """
    if a.dim() != 1 or b.dim() != 1:
        raise RuntimeError(
            f"vector_add expects 1D tensors, got {a.dim()}D and {b.dim()}D"
        )

    if a.shape[0] != b.shape[0]:
        raise RuntimeError(f"vector length mismatch: {a.shape[0]} vs {b.shape[0]}")

    if a.dtype != b.dtype:
        raise RuntimeError(f"dtype mismatch: {a.dtype} vs {b.dtype}")

    return a + b


# Route CPU tensors to the native-PyTorch fallback
_lib.impl("vector_add", vector_add_cpu, dispatch_key="CPU")


# ===== Step 5: Optional - Autograd Support =====


def vector_add_backward(
    a: torch.Tensor,
    b: torch.Tensor,
    grad_output: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Gradient of ``c = a + b`` with respect to both inputs.

    Both partial derivatives are 1, so the incoming gradient is passed
    through unchanged to each input.

    Args:
        a: Original input ``a`` (not read).
        b: Original input ``b`` (not read).
        grad_output: Gradient of the loss with respect to the output ``c``.

    Returns:
        ``(grad_a, grad_b)``, both being ``grad_output`` itself.
    """
    grad_a = grad_b = grad_output
    return grad_a, grad_b


class VectorAddFunction(torch.autograd.Function):
    """Custom autograd function for vector_add with gradient support.

    ``forward`` runs ``vector_add_impl``; ``backward`` passes the incoming
    gradient through to both inputs, since d(a + b)/da = d(a + b)/db = 1.
    """

    @staticmethod
    def forward(ctx, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Return ``vector_add_impl(a, b)``.

        Nothing is saved on ``ctx``: the gradient of an addition does not
        depend on the input values, so keeping ``a`` and ``b`` alive for the
        backward pass would only hold on to their memory.
        """
        return vector_add_impl(a, b)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Return ``(grad_a, grad_b)``, both equal to ``grad_output``."""
        return grad_output, grad_output


def vector_add_with_autograd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Differentiable vector addition.

    Routes the call through ``VectorAddFunction`` so that gradients can flow
    back to ``a`` and ``b``. The forward pass of that function calls
    ``vector_add_impl`` (the Triton implementation).

    Args:
        a: First input tensor.
        b: Second input tensor.

    Returns:
        The result of ``VectorAddFunction.apply(a, b)``.
    """
    summed = VectorAddFunction.apply(a, b)
    return summed


# ===== Convenience exports =====

# Names exported by ``from dlblas.torch_ops.vector_add import *``.
# VectorAddFunction and vector_add_backward are defined in this module but
# are not part of this list.
__all__ = [
    "vector_add_impl",
    "vector_add_meta",
    "vector_add_cpu",
    "vector_add_with_autograd",
]
Loading
Loading