Skip to content

[FEA] cuda.local.array should accept objects with a dtype attribute #846

@a-ellison

Description

@a-ellison

Describe the bug

The dtype parameter of cuda.local.array inside a kernel does not directly accept objects that have a dtype attribute, even though numpy's dtype protocol supports this and cuda.device_array on the host already follows it.

Steps/Code to reproduce bug

import numpy as np
from numba import cuda, types
from numba.extending import typeof_impl, register_model, overload_attribute, models

class Foo:
    """Plain Python object whose only state is a NumPy ``dtype`` attribute."""
    def __init__(self):
        # Fixed uint8 dtype; this is what dtype=foo is expected to resolve to.
        self.dtype = np.dtype('uint8')

class FooType(types.Type):
    """Numba type for ``Foo`` values; records the dtype of the instance it was built from."""
    def __init__(self, val):
        # Copy the dtype off the Python-level Foo instance.
        self.dtype = val.dtype
        # Every FooType is given the same type name, 'Foo'.
        super().__init__(name='Foo')

@typeof_impl.register(Foo)
def typeof_Foo(val, c):
    """Map a ``Foo`` instance to its Numba type.

    Registered with ``typeof_impl`` for ``Foo``. ``val`` is the ``Foo``
    instance being typed; ``c`` is not used by this implementation.
    """
    return FooType(val)

# empty struct
@register_model(FooType)
class FooModel(models.StructModel):
    """Data model registered for ``FooType``: a struct model with no members."""
    def __init__(self, dmm, fe_type):
        # members=[] -- the model carries no fields.
        super().__init__(dmm, fe_type, members=[])

@overload_attribute(FooType, "dtype")
def ol_Foo_dtype(foo):
    """Overload of the ``dtype`` attribute for ``FooType`` values.

    Reads ``foo.dtype`` once and returns an implementation that yields
    that captured value.
    """
    # presumably ``foo`` is the FooType instance here (typing time) -- verify
    dtype = foo.dtype
    return lambda foo: dtype

# Single Foo instance, passed as the dtype argument on both host and device.
foo = Foo()

# Works on host side (numpy dtype protocol: foo.dtype is used):
states = cuda.device_array(10, dtype=foo)
print(f"cuda.device_array dtype: {states.dtype}")  # uint8

# The same doesn't work on the device side
@cuda.jit
def kernel(out):
    # dtype=foo is the call that fails to type (see the TypingError below);
    # as a workaround, passing dtype=foo.dtype instead works.
    state = cuda.local.array(10, dtype=foo)
    out[0] = state[0]

# Launch the kernel; compiling it for this argument is what raises the TypingError.
kernel[1, 1](cuda.device_array(1, dtype=np.uint8))

Error:

/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/dispatcher.py:716: NumbaPerformanceWarning: Grid size 1 will likely result in GPU under-utilization due to low occupancy.
  warn(errors.NumbaPerformanceWarning(msg))
cuda.device_array dtype: uint8
Traceback (most recent call last):
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/main.py", line 50, in <module>
    kernel[1, 1](cuda.device_array(1, dtype=np.uint8))
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/dispatcher.py", line 719, in __call__
    return self.dispatcher.call(
           ~~~~~~~~~~~~~~~~~~~~^
        args, self.griddim, self.blockdim, self.stream, self.sharedmem
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/dispatcher.py", line 1642, in call
    kernel = _dispatcher.Dispatcher._cuda_call(self, *args)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/dispatcher.py", line 1650, in _compile_for_args
    return self.compile(tuple(argtypes))
           ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler_lock.py", line 74, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/dispatcher.py", line 1908, in compile
    kernel = _Kernel(self.py_func, argtypes, **self.targetoptions)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler_lock.py", line 74, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/dispatcher.py", line 155, in __init__
    cres = compile_cuda(
        self.py_func,
    ...<9 lines>...
        lto=lto,
    )
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler_lock.py", line 74, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/compiler.py", line 789, in compile_cuda
    cres = compile_extra(
        typingctx=typingctx,
    ...<8 lines>...
        abi_info=abi_info,
    )
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/compiler.py", line 599, in compile_extra
    return pipeline.compile_extra(func)
           ~~~~~~~~~~~~~~~~~~~~~~^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler.py", line 146, in compile_extra
    return self._compile_bytecode()
           ~~~~~~~~~~~~~~~~~~~~~~^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler.py", line 214, in _compile_bytecode
    return self._compile_core()
           ~~~~~~~~~~~~~~~~~~^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler.py", line 190, in _compile_core
    raise e
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler.py", line 182, in _compile_core
    pm.run(self.state)
    ~~~~~~^^^^^^^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler_machinery.py", line 392, in run
    raise e
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler_machinery.py", line 385, in run
    self._runPass(idx, pass_inst, state)
    ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler_lock.py", line 74, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler_machinery.py", line 337, in _runPass
    mutated |= check(pss.run_pass, internal_state)
               ~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/compiler_machinery.py", line 291, in check
    mangled = func(compiler_state)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/typed_passes.py", line 141, in run_pass
    typemap, return_type, calltypes, errs = type_inference_stage(
                                            ~~~~~~~~~~~~~~~~~~~~^
        state.typingctx,
        ^^^^^^^^^^^^^^^^
    ...<5 lines>...
        raise_errors=self._raise_errors,
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/typed_passes.py", line 120, in type_inference_stage
    errs = infer.propagate(raise_errors=raise_errors)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/typeinfer.py", line 1147, in propagate
    errors = self.constraints.propagate(self)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/typeinfer.py", line 192, in propagate
    constraint(typeinfer)
    ~~~~~~~~~~^^^^^^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/typeinfer.py", line 620, in __call__
    self.resolve(typeinfer, typevars, fnty)
    ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/typeinfer.py", line 643, in resolve
    sig = typeinfer.resolve_call(fnty, pos_args, kw_args)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/core/typeinfer.py", line 1695, in resolve_call
    return self.context.resolve_function_type(fnty, pos_args, kw_args)
           ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/typing/context.py", line 198, in resolve_function_type
    res = self._resolve_user_function_type(func, args, kws)
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba_cuda/numba/cuda/typing/context.py", line 250, in _resolve_user_function_type
    return func.get_call_type(self, args, kws)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba/core/types/functions.py", line 332, in get_call_type
    failures.raise_error()
    ~~~~~~~~~~~~~~~~~~~~^^
  File "/home/scratch.aellison_ent/sandbox/numbda-cuda-local-array-dtype-bug/.venv/lib/python3.13/site-packages/numba/core/types/functions.py", line 230, in raise_error
    raise errors.TypingError(self.format())
numba.core.errors.TypingError: No implementation of function Function(<function local.array at 0x7878d215ba60>) found for signature:
 
 >>> array(Literal[int](10), dtype=Foo)
 
There are 2 candidate implementations:
 - Of which 2 did not match due to:
 Overload of function 'array': File: numba/cuda/cudadecl.py: Line 28.
   With argument(s): '(int64, dtype=Foo)':
  No match.


Expected behavior

cuda.local.array should accept objects with a dtype attribute, consistent with numpy's dtype protocol and the existing behavior of cuda.device_array. The workaround of explicitly writing cuda.local.array(n, dtype=foo.dtype) works, but requiring it on the device side is inconsistent with the host-side API, which accepts dtype=foo directly.

Environment details:

  • Environment location: Bare-metal
  • Method of numba-cuda install: uv
    • numba: 0.65.0
    • numba-cuda: 0.29.0

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions