diff --git a/dev/generate-kernel-signatures.py b/dev/generate-kernel-signatures.py index 4be14627f1..f82c0d8624 100644 --- a/dev/generate-kernel-signatures.py +++ b/dev/generate-kernel-signatures.py @@ -16,11 +16,11 @@ "awkward_ListArray_min_range", "awkward_ListArray_validity", "awkward_BitMaskedArray_to_ByteMaskedArray", - "awkward_ListArray_broadcast_tooffsets", - "awkward_ListArray_compact_offsets", + # "awkward_ListArray_broadcast_tooffsets", + # "awkward_ListArray_compact_offsets", "awkward_ListOffsetArray_flatten_offsets", - "awkward_IndexedArray_overlay_mask", - "awkward_ByteMaskedArray_numnull", + # "awkward_IndexedArray_overlay_mask", + # "awkward_ByteMaskedArray_numnull", "awkward_IndexedArray_numnull", "awkward_IndexedArray_numnull_parents", "awkward_IndexedArray_numnull_unique_64", @@ -51,9 +51,9 @@ "awkward_RegularArray_reduce_local_nextparents_64", "awkward_RegularArray_reduce_nonlocal_preparenext_64", # "awkward_missing_repeat", - "awkward_RegularArray_getitem_jagged_expand", - "awkward_ListArray_combinations_length", - "awkward_ListArray_combinations", + # "awkward_RegularArray_getitem_jagged_expand", + # "awkward_ListArray_combinations_length", + # "awkward_ListArray_combinations", "awkward_RegularArray_combinations_64", "awkward_ListArray_getitem_jagged_apply", "awkward_ListArray_getitem_jagged_carrylen", @@ -72,14 +72,14 @@ "awkward_UnionArray_regular_index", "awkward_ListOffsetArray_reduce_nonlocal_nextstarts_64", "awkward_ListArray_getitem_next_range_spreadadvanced", - "awkward_ListArray_localindex", + # "awkward_ListArray_localindex", "awkward_NumpyArray_pad_zero_to_length", "awkward_NumpyArray_reduce_adjust_starts_64", "awkward_NumpyArray_rearrange_shifted", "awkward_NumpyArray_reduce_adjust_starts_shifts_64", "awkward_RegularArray_getitem_next_at", "awkward_BitMaskedArray_to_IndexedOptionArray", - "awkward_ByteMaskedArray_getitem_nextcarry", + # "awkward_ByteMaskedArray_getitem_nextcarry", 
"awkward_ByteMaskedArray_getitem_nextcarry_outindex", "awkward_ByteMaskedArray_reduce_next_64", "awkward_ByteMaskedArray_reduce_next_nonlocal_nextshifts_64", @@ -97,8 +97,8 @@ "awkward_IndexedArray_local_preparenext_64", "awkward_IndexedArray_ranges_next_64", "awkward_IndexedArray_ranges_carry_next_64", - "awkward_IndexedArray_reduce_next_64", - "awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64", + # "awkward_IndexedArray_reduce_next_64", + # "awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64", "awkward_IndexedArray_reduce_next_nonlocal_nextshifts_fromshifts_64", "awkward_IndexedOptionArray_rpad_and_clip_mask_axis1", "awkward_ListOffsetArray_local_preparenext_64", @@ -118,7 +118,7 @@ "awkward_UnionArray_nestedfill_tags_index", "awkward_UnionArray_regular_index_getsize", "awkward_UnionArray_simplify", - "awkward_UnionArray_simplify_one", + # "awkward_UnionArray_simplify_one", "awkward_RecordArray_reduce_nonlocal_outoffsets_64", # "awkward_reduce_count_64", # "awkward_reduce_max", diff --git a/dev/generate-tests.py b/dev/generate-tests.py index f398e6988e..25f1cff771 100644 --- a/dev/generate-tests.py +++ b/dev/generate-tests.py @@ -1567,19 +1567,40 @@ def gencudaunittests(specdict): count += 1 else: args += ", " + arg.name - f.write(" " * 4 + "funcC(" + args + ")\n") + # Determine if this is a cuda.compute kernel (raises errors eagerly) + # or compiled CUDA kernel (raises errors after `ak_cu.synchronize_cuda()`) + CUDA_COPUTE_KERNELS = { + "awkward_ListArray_compact_offsets", + "awkward_ListArray_broadcast_tooffsets", + } + + raises_error_eagerly = ( + spec.templatized_kernel_name in CUDA_COPUTE_KERNELS + ) + if test["error"]: - f.write( - f""" - error_message = re.escape("{test["message"]} in compiled CUDA code ({spec.templatized_kernel_name})") -""" - ) - f.write( - """ with pytest.raises(ValueError, match=rf"{error_message}"): - ak_cu.synchronize_cuda() -""" - ) + error_message_line = f' error_message = re.escape("{test["message"]} in compiled CUDA 
code ({spec.templatized_kernel_name})")\n' + if raises_error_eagerly: + # call a kernel directly inside `pytest.raises()` + f.write( + "\n" + + error_message_line + + ' with pytest.raises(ValueError, match=rf"{error_message}"):\n' + + " " * 8 + + "funcC(" + + args + + ")\n" + ) + else: + f.write(" " * 4 + "funcC(" + args + ")\n") + f.write( + "\n" + + error_message_line + + ' with pytest.raises(ValueError, match=rf"{error_message}"):\n' + " ak_cu.synchronize_cuda()\n" + ) else: + f.write(" " * 4 + "funcC(" + args + ")\n") f.write( """ try: diff --git a/src/awkward/_backends/cupy.py b/src/awkward/_backends/cupy.py index 75ecc16674..0149c45cb2 100644 --- a/src/awkward/_backends/cupy.py +++ b/src/awkward/_backends/cupy.py @@ -85,6 +85,18 @@ def _supports_cuda_compute(self, kernel_name: str) -> bool: "awkward_reduce_argmin_complex", "awkward_reduce_count_64", "awkward_reduce_countnonzero", + "awkward_IndexedArray_overlay_mask", + "awkward_IndexedArray_reduce_next_64", + "awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64", + "awkward_ByteMaskedArray_getitem_nextcarry", + "awkward_ByteMaskedArray_numnull", + "awkward_RegularArray_getitem_jagged_expand", + "awkward_UnionArray_simplify_one", + "awkward_ListArray_broadcast_tooffsets", + "awkward_ListArray_localindex", + "awkward_ListArray_compact_offsets", + "awkward_ListArray_combinations_length", + "awkward_ListArray_combinations", "awkward_reduce_countnonzero_complex", # indexing / structure "awkward_missing_repeat", @@ -128,6 +140,18 @@ def _get_cuda_compute_impl(self, kernel_name: str): "awkward_missing_repeat": cuda_compute.awkward_missing_repeat, "awkward_index_rpad_and_clip_axis0": cuda_compute.awkward_index_rpad_and_clip_axis0, "awkward_index_rpad_and_clip_axis1": cuda_compute.awkward_index_rpad_and_clip_axis1, + "awkward_IndexedArray_overlay_mask": cuda_compute.awkward_IndexedArray_overlay_mask, + "awkward_IndexedArray_reduce_next_64": cuda_compute.awkward_IndexedArray_reduce_next_64, + 
"awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64": cuda_compute.awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64, + "awkward_ByteMaskedArray_getitem_nextcarry": cuda_compute.awkward_ByteMaskedArray_getitem_nextcarry, + "awkward_ByteMaskedArray_numnull": cuda_compute.awkward_ByteMaskedArray_numnull, + "awkward_RegularArray_getitem_jagged_expand": cuda_compute.awkward_RegularArray_getitem_jagged_expand, + "awkward_UnionArray_simplify_one": cuda_compute.awkward_UnionArray_simplify_one, + "awkward_ListArray_broadcast_tooffsets": cuda_compute.awkward_ListArray_broadcast_tooffsets, + "awkward_ListArray_localindex": cuda_compute.awkward_ListArray_localindex, + "awkward_ListArray_compact_offsets": cuda_compute.awkward_ListArray_compact_offsets, + "awkward_ListArray_combinations_length": cuda_compute.awkward_ListArray_combinations_length, + "awkward_ListArray_combinations": cuda_compute.awkward_ListArray_combinations, }.get(kernel_name) def prepare_reducer(self, reducer: ak._reducers.Reducer) -> ak._reducers.Reducer: diff --git a/src/awkward/_connect/cuda/_compute.py b/src/awkward/_connect/cuda/_compute.py index fcdc6459f9..fe30305386 100644 --- a/src/awkward/_connect/cuda/_compute.py +++ b/src/awkward/_connect/cuda/_compute.py @@ -4,7 +4,9 @@ from cuda.compute import ( CountingIterator, + DiscardIterator, OpKind, + inclusive_scan, segmented_reduce, unary_transform, ) @@ -772,6 +774,573 @@ def segment_reduce_countnonzero(segment_id): ) +# Overlays a mask onto an index array: masked positions become -1, unmasked positions keep their original index value. +def awkward_IndexedArray_overlay_mask(toindex, mask, fromindex, length): + def transform(i): + return -1 if mask[i] else fromindex[i] + + indices = CountingIterator(cp.int64(0)) + unary_transform(d_in=indices, d_out=toindex, op=transform, num_items=length) + + +# Skips masked (-1) entries and packs remaining valid entries into nextcarry, tracking where +# each ended up in outindex. 
Builds nextoffsets[j+1] = cumulative count of valid entries in +# segments 0..j as defined by the offsets array. +# +# Example: +# index = [3, -1, 5, -1, 2, -1, 4] +# offsets = [0, 4, 7] (2 segments: positions 0-3 and 4-6) +# outlength = 2 +# +# nextcarry = [3, 5, 2, 4] (valid index values, compacted) +# nextoffsets = [0, 2, 4] (segment 0 has 2 valid, segment 1 has 2 valid) +# outindex = [0, -1, 1, -1, 2, -1, 3] (position in nextcarry, or -1 if masked) +def awkward_IndexedArray_reduce_next_64( + nextcarry, nextoffsets, outindex, index, offsets, outlength +): + nextoffsets[0] = 0 + if outlength == 0: + return + + index_length = int(offsets[outlength]) + if index_length == 0: + nextoffsets[1 : outlength + 1] = 0 + return + + idx_dtype = index.dtype + valid = (index[:index_length] >= 0).astype(idx_dtype) + scan = cp.empty(index_length, dtype=idx_dtype) + inclusive_scan( + d_in=valid, + d_out=scan, + op=lambda a, b: a + b, + init_value=cp.array([0], dtype=idx_dtype), + num_items=index_length, + ) + + def scatter_and_fill(i): + if index[i] >= 0: + k = scan[i] - 1 + nextcarry[k] = index[i] + return k + return -1 + + unary_transform( + d_in=CountingIterator(idx_dtype.type(0)), + d_out=outindex, + op=scatter_and_fill, + num_items=index_length, + ) + + off_dtype = offsets.dtype.type + + def fill_nextoffsets(j): + stop = offsets[j + 1] + nextoffsets[j + 1] = idx_dtype.type(0) if stop == 0 else scan[stop - 1] + return off_dtype(0) + + unary_transform( + d_in=CountingIterator(off_dtype(0)), + d_out=DiscardIterator(), + op=fill_nextoffsets, + num_items=outlength, + ) + + +# For each valid (non-negative) entry at position i, records the number of null (negative) entries +# that appeared before it. The k-th valid entry gets nextshifts[k] = count of nulls before position i. +# For example, index = [0, 1, 2, -1, 3, -1, 4] → nextshifts = [0, 0, 0, 1, 2]. 
+def awkward_IndexedArray_reduce_next_nonlocal_nextshifts_64(nextshifts, index, length): + if length == 0: + return + + index_slice = index[:length] + + # cumsum of (index < 0) gives the running null count at each position. + # this is basically equivalent to calling cuda.compute.inclusive_scan on index_slice < 0 + null_cumsum = cp.cumsum(index_slice < 0) + _ = cp.empty(length, dtype=cp.int64) + + def scatter(i): + null_count = null_cumsum[i] + if index_slice[i] >= 0: + nextshifts[i - null_count] = null_count # output slot = i - null_count + # return a dummy value otherwise + return cp.int64(0) + + indices = CountingIterator(cp.int64(0)) + unary_transform(d_in=indices, d_out=_, op=scatter, num_items=length) + + +# Packs valid entries (where (mask[i] != 0) == validwhen) into tocarry in order. +# Examples: +# mask = [0, 1, 0, 1, 1], validwhen=True → tocarry = [1, 3, 4] +# mask = [0, 1, 0, 1, 1], validwhen=False → tocarry = [0, 2] +# mask = [0, 1, 0, 1, 1, -1, 1], validwhen=True → tocarry = [1, 3, 4, 5, 6] +def awkward_ByteMaskedArray_getitem_nextcarry(tocarry, mask, length, validwhen): + if length == 0: + return + + # valid = ((mask[:length] != 0) == validwhen) + # valid[i] is 1 when the masked element passes the validwhen condition. + + # get the indices of the valid entries using cp.nonzero + valid_indices = cp.nonzero((mask[:length] != 0) == validwhen)[0] + # in case tocarry is not exactly the right size, allocate it in two steps like this + tocarry[: len(valid_indices)] = valid_indices + + +# Counts null (invalid) entries: positions where (mask[i] != 0) != validwhen. 
+# Examples: +# mask = [0, 1, 0, 1, 1], validwhen=True → numnull = 2 (positions 0 and 2 are null) +# mask = [0, 1, 0, 1, 1], validwhen=False → numnull = 3 (positions 1, 3 and 4 are null) +def awkward_ByteMaskedArray_numnull(numnull, mask, length, validwhen): + numnull[0] = cp.count_nonzero((mask[:length] != 0) != validwhen) + + +# Broadcasts a single jagged offset array across all rows of a regular array +# Example: +# singleoffsets = [0, 2, 5], regularsize = 2, regularlength = 3 +# multistarts = [0, 2, 0, 2, 0, 2] +# multistops = [2, 5, 2, 5, 2, 5] +def awkward_RegularArray_getitem_jagged_expand( + multistarts, multistops, singleoffsets, regularsize, regularlength +): + if regularlength == 0 or regularsize == 0: + return + + # Reshape as (regularlength, regularsize) views (no copy) and broadcast-assign + # singleoffsets[:-1] / singleoffsets[1:] across all rows. + multistarts.reshape(regularlength, regularsize)[:] = singleoffsets[:regularsize] + multistops.reshape(regularlength, regularsize)[:] = singleoffsets[ + 1 : regularsize + 1 + ] + + +# THIS KERNEL IS NOT USED (just for archive) +# Fills a tagged index for one union type: assigns a constant tag and +# sequential index into each segment defined by the starts/counts ranges +# Example input: +# tmpstarts = [0, 3], tag = 1, fromcounts = [3, 2] +# Example output: +# totags = [1, 1, 1, 1, 1] +# toindex = [0, 1, 2, 0, 1] +# also, the tmpstarts get rewritten with stops: tmpstarts = [3, 5] +def awkward_UnionArray_nestedfill_tags_index( + totags, toindex, tmpstarts, tag, fromcounts, length +): + if length == 0: + return + + starts = tmpstarts[:length] + counts = fromcounts[:length] + + # Total span of the output arrays we need to touch: + # the last segment's start + its count gives the furthest written position + total_size = int(starts[length - 1]) + int(counts[length - 1]) + + if total_size == 0: + return + + # +1 at each segment start, -1 just past each segment end. 
+ # cumsum of this will later yield 1 inside any covered range, 0 in gaps. + diff = cp.zeros(total_size + 1, dtype=cp.int8) + + def scatter_and_update(i): + start = starts[i] + count = counts[i] + # Mark this segment's range in the difference array + diff[start] += cp.int8(1) + diff[start + count] -= cp.int8(1) + # update tmpstarts (for the next call of this kernel (for a different union type))? + tmpstarts[i] = start + count + return 0 + + # Scatter segment's ranges and update tmpstarts + unary_transform( + d_in=CountingIterator(cp.int64(0)), + d_out=DiscardIterator(), + op=scatter_and_update, + num_items=length, + ) + + # coverage[j] == 1 if position j falls inside any segment's range, 0 otherwise + coverage = cp.cumsum(diff[:total_size]) + + # scan[j] == local index of element j within its segment + # Since it's a cumsum, the first index starts from 1, 2, 3 ... + # so we'll have to -1 before writing it in toindex + scan = cp.cumsum(coverage, dtype=cp.int64) + + def fill(j): + if coverage[j]: + # Mark this position as belonging to the current tag + totags[j] = tag + toindex[j] = scan[j] - 1 + return 0 + + unary_transform( + d_in=CountingIterator(cp.int64(0)), + d_out=DiscardIterator(), + op=fill, + num_items=total_size, + ) + + +# For each position i where fromtags[i] == fromwhich, sets totags[i] = towhich and +# toindex[i] = fromindex[i] + base. Other positions are left unchanged. 
+# Example: +# fromtags = [0, 1, 0, 1, 0], fromindex = [0, 0, 1, 1, 2] +# fromwhich=1, towhich=2, base=10 +# totags = [0, 2, 0, 2, 0] +# toindex = [0, 10, 1, 11, 2] +def awkward_UnionArray_simplify_one( + totags, toindex, fromtags, fromindex, towhich, fromwhich, length, base +): + if length == 0: + return + + def transform(i): + if fromtags[i] == fromwhich: + totags[i] = towhich + toindex[i] = fromindex[i] + base + return 0 # discarded + + indices = CountingIterator(cp.int64(0)) + unary_transform( + d_in=indices, d_out=DiscardIterator(), op=transform, num_items=length + ) + + +# producing a carry index that maps each output element back to its position in the original content +# Example input: +# fromoffsets = [0, 3, 5], fromstarts = [10, 20], fromstops = [13, 22], lencontent = 25 +# Example output: +# i=0: range [10, 13) → [10, 11, 12] +# i=1: range [20, 22) → [20, 21] +# tocarry = [10, 11, 12, 20, 21] +def awkward_ListArray_broadcast_tooffsets( + tocarry, fromoffsets, offsetslength, fromstarts, fromstops, lencontent +): + if offsetslength <= 1: + return + + length = offsetslength - 1 + starts = fromstarts[:length] + stops = fromstops[:length] + # counts[i] = how many elements list i should have + counts = fromoffsets[1:offsetslength] - fromoffsets[:length] + + _K = "awkward_ListArray_broadcast_tooffsets" + if cp.any((starts != stops) & (stops > lencontent)): + raise ValueError(f"stops[i] > len(content) in compiled CUDA code ({_K})") + if cp.any(counts < 0): + raise ValueError( + f"broadcast's offsets must be monotonically increasing in compiled CUDA code ({_K})" + ) + if cp.any(stops - starts != counts): + raise ValueError(f"cannot broadcast nested list in compiled CUDA code ({_K})") + + # For each segment i, write the content indices starts[i], starts[i]+1, ..., stops[i]-1 + # into the contiguous output slice tocarry[fromoffsets[i] : fromoffsets[i+1]]. 
+ def fill_list(i): + start = starts[i] + stop = stops[i] + for j in range(start, stop): + tocarry[fromoffsets[i] + j - start] = j + return 0 + + unary_transform( + d_in=CountingIterator(cp.int64(0)), + d_out=DiscardIterator(), + op=fill_list, + num_items=length, + ) + + +# For each segment i, it fills toindex with the local position of each element within that segment — i.e. 0, 1, 2, ... +# Example: +# offsets = [0, 3, 5] +# toindex = [0, 1, 2, 0, 1] +def awkward_ListArray_localindex(toindex, offsets, length): + if length == 0: + return + + starts = offsets[:length] + stops = offsets[1 : length + 1] + + def fill(i): + start = starts[i] + stop = stops[i] + for j in range(start, stop): + toindex[j] = j - start + return 0 + + unary_transform( + d_in=CountingIterator(cp.int64(0)), + d_out=DiscardIterator(), + op=fill, + num_items=length, + ) + + +# Converts a ListArray's (starts, stops) pairs into offsets. +# tooffsets[0] = 0, tooffsets[i+1] = tooffsets[i] + (fromstops[i] - fromstarts[i]) +# Example: +# fromstarts = [10, 20], fromstops = [13, 22], length = 2 +# tooffsets = [0, 3, 5] +def awkward_ListArray_compact_offsets(tooffsets, fromstarts, fromstops, length): + tooffsets[0] = 0 + if length == 0: + return + + starts = fromstarts[:length] + stops = fromstops[:length] + + if cp.any(stops < starts): + raise ValueError( + "stops[i] < starts[i] in compiled CUDA code (awkward_ListArray_compact_offsets)" + ) + + sizes = stops - starts + + # the same as `tooffsets[1 : length + 1] = cp.cumsum(sizes)` + inclusive_scan( + d_in=sizes, + d_out=tooffsets[1 : length + 1], + op=lambda a, b: a + b, + init_value=cp.array([0], dtype=tooffsets.dtype), + num_items=length, + ) + + +# For each list i, counts the number of n-combinations of its elements +# (with or without replacement) and builds an offsets array into tooffsets. +# totallen[0] is set to the total number of combinations across all lists. 
+# +# Example (n=2, replacement=False): +# starts=[0, 0, 0], stops=[2, 3, 4] +# sizes = [2, 3, 4] +# C(2,2)=1, C(3,2)=3, C(4,2)=6 +# Then the output will be: tooffsets = [0, 1, 4, 10] +# totallen = 10 +def awkward_ListArray_combinations_length( + totallen, tooffsets, n, replacement, starts, stops, length +): + tooffsets[0] = 0 + if length == 0: + totallen[0] = 0 + return + + def combinations_len(i): + size = stops[i] - starts[i] + if replacement: + size = size + (n - 1) + thisn = n + if thisn > size: + return 0 + elif thisn == size: + return 1 + else: + # C(size, n) == C(size, size-n), so use the smaller one + # of the two to minimise the number of loop iterations + if thisn * 2 > size: + thisn = size - thisn + + # Compute C(size, thisn) = size! / (thisn! * (size-thisn)!) incrementally: + # result = size * (size-1) * ... * (size-thisn+1) / thisn! + result = size + for j in range(2, thisn + 1): + result = result * (size - j + 1) + result = result // j + return result + + # Compute the number of combinations for each list + counts = cp.empty(length, dtype=tooffsets.dtype) + unary_transform( + d_in=CountingIterator(cp.int64(0)), + d_out=counts, + op=combinations_len, + num_items=length, + ) + + # Convert counts to offsets: + # tooffsets[i+1] = sum(counts[0..i]) + inclusive_scan( + d_in=counts, + d_out=tooffsets[1 : length + 1], + op=lambda a, b: a + b, + init_value=cp.array([0], dtype=tooffsets.dtype), + num_items=length, + ) + + # Total number of combinations across all lists + totallen[0] = tooffsets[length] + + +# For each list i, enumerates all n-combinations (with or without replacement) +# of its elements and writes the indices into n output carry arrays. +# +# tocarry_ptrs is a CuPy int64 array of length n holding raw device pointers; +# each pointer refers to a pre-allocated int64 array of length totallen. 
+# +# Example (n=2, replacement=False): +# starts=[0], stops=[3] → elements [0,1,2] +# C(3,2) = 3 combinations in total +# combinations: (0,1),(0,2),(1,2) +# +# Output: +# tocarry_ptrs[0] → [0, 0, 1], tocarry_ptrs[1] → [1, 2, 2] +# toindex: [3, 3] +def awkward_ListArray_combinations( + tocarry_ptrs, toindex, fromindex, n, replacement, starts, stops, length +): + if length == 0: + return + + # Step 1: compute per-list combination counts (same as combinations_length!!) + # TODO: we can just pass combination offsets directly in the future (from src/awkward/contents/listoffsetarray.py:1405) + def combinations_len(i): + size = stops[i] - starts[i] + if replacement: + size = size + (n - 1) + thisn = n + if thisn > size: + return 0 + elif thisn == size: + return 1 + else: + if thisn * 2 > size: + thisn = size - thisn + result = size + for j in range(2, thisn + 1): + result = result * (size - j + 1) + result = result // j + return result + + counts = cp.empty(length, dtype=cp.int64) + unary_transform( + d_in=CountingIterator(cp.int64(0)), + d_out=counts, + op=combinations_len, + num_items=length, + ) + + offsets = cp.empty(length + 1, dtype=cp.int64) + offsets[0] = 0 + inclusive_scan( + d_in=counts, + d_out=offsets[1:], + op=lambda a, b: a + b, + init_value=cp.array([0], dtype=cp.int64), + num_items=length, + ) + + totallen = int(offsets[length]) + if totallen == 0: + return + + # Step 2: wrap raw pointers from tocarry_ptrs into CuPy arrays + # raw int64 pointer values from tocarry_ptrs[k] can't be dereferenced inside a Numba closure, so + # we need this intermediate step + # + # (the pointers themselves are allocated at src/awkward/contents/listoffsetarray.py:1456-1464) + carry_arrays = [] + for k in range(n): + ptr_val = int(tocarry_ptrs[k]) + mem = cp.cuda.UnownedMemory(ptr_val, totallen * 8, None) + memptr = cp.cuda.MemoryPointer(mem, 0) + carry_arrays.append(cp.ndarray(totallen, dtype=cp.int64, memptr=memptr)) # pylint: disable=unexpected-keyword-arg + + # 
------------------------------------------------------------------------- + # Step 3: fill carry_arrays[k] for each combination position k in turn. + # + # For each output slot g in [0, totallen): + # + # a) Binary search offsets to find which source list i owns slot g, + # and compute the rank of this combination within that list + # (rank = g - offsets[i], i.e. the 0-based index among all combinations + # of list i in lexicographic order). + # + # b) Unrank: decode the rank back into the actual combination tuple using + # a combinatorial number system. Iterating over positions pos=0..n-1, + # at each position scan forward through candidate values j, counting + # how many combinations have exactly value j at this position + # (= C(effective_size-j-1, n-pos-1)). Subtract each count from the remaining rank + # until we reach the j where remaining < count — that j is the value + # at position pos. + # + # c) Early exit: once pos==k we have the value for position k and write + # it to carry_k[g], skipping the rest of the unranking. This is why + # we do n separate passes (one per k) rather than one pass writing all + # n positions: each pass only needs to unrank up to position k. + # + # d) Content index: add start (the list's base offset into content) to + # convert from a within-list index to an absolute content index. + # For replacement, subtract pos to undo the stars-and-bars shift. 
+ # ------------------------------------------------------------------------- + def make_pass(k, carry_k): + def fill_pos(g): + # a) Find source list i via binary search on offsets + lo = 0 + hi = length - 1 + while lo < hi: + mid = (lo + hi) >> 1 + if offsets[mid + 1] <= g: + lo = mid + 1 + else: + hi = mid + list_i = lo + start = starts[list_i] + size = stops[list_i] - starts[list_i] + rank = g - offsets[list_i] + # For replacement use stars-and-bars effective size + effective_size = size + n - 1 if replacement else size + + # b) Unrank: decode rank into the combination tuple + lower = 0 # lower bound for j at each position (enforces ordering) + remaining = rank + for pos in range(n): + for j in range(lower, effective_size - (n - pos - 1)): + # Count combinations where position pos has value j: + # = C(effective_size - j - 1, n - pos - 1) + top = effective_size - j - 1 + choose = n - pos - 1 + if choose == 0: + count = 1 + else: + if choose * 2 > top: # use smaller equivalent + choose = top - choose + c = top + for q in range(2, choose + 1): + c = c * (top - q + 1) + c = c // q + count = c + if remaining < count: + # c) j is the value at position pos + if pos == k: + # d) write absolute content index and exit early + carry_k[g] = (j - pos if replacement else j) + start + return 0 + lower = j + 1 # next position must be >= j+1 (no repeat) + break + remaining -= count + return 0 + + return fill_pos + + # One parallel pass per combination position k + for k in range(n): + unary_transform( + d_in=CountingIterator(cp.int64(0)), + d_out=DiscardIterator(), + op=make_pass(k, carry_arrays[k]), + num_items=totallen, + ) + + toindex[:n] = totallen + + def awkward_index_rpad_and_clip_axis0(toindex, target, length): """ Fill ``toindex[0..target)`` with the identity mapping ``[0..shorter)``