Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 45 additions & 21 deletions python/tokenspeed/runtime/layers/attention/deepseek_v4_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,28 +758,52 @@ def read_deepseek_v4_indexer_fp8_cache(
f"cache_2d must be [pages, >= {min_stride}], got {tuple(cache_2d.shape)}"
)

out = torch.zeros(
slot_mapping.numel(),
index_head_dim,
device=cache_2d.device,
dtype=torch.float32,
)
out_shape = (slot_mapping.numel(), index_head_dim)
# Also bail when cache_2d has zero pages: the gather below uses
# `where(valid, slots, 0)` to keep offsets in-range, but the resulting
# row-0 read still OOBs against an empty `flat_cache`. The reference
# per-token loop tolerated this (it iterates `slot_mapping.tolist()` and
# `continue`s on `slot < 0`), so preserve that behavior with zeros.
if slot_mapping.numel() == 0 or cache_2d.shape[0] == 0:
return torch.zeros(out_shape, device=cache_2d.device, dtype=torch.float32)

flat_cache = cache_2d.reshape(-1)
for token_idx, raw_slot in enumerate(slot_mapping.tolist()):
slot = int(raw_slot)
if slot < 0:
continue
page = slot // block_size
pos = slot % block_size
page_base = page * cache_2d.stride(0)
value_base = page_base + pos * index_head_dim
scale_base = page_base + block_size * index_head_dim + pos * scale_bytes
scale = flat_cache[scale_base : scale_base + scale_bytes].view(torch.float32)[0]
values = flat_cache[value_base : value_base + index_head_dim].view(
torch.float8_e4m3fn
)
out[token_idx].copy_(values.float() * scale)
return out
# Move slot_mapping to cache_2d.device so the gather offsets composed below
# (which mix it with torch.arange(device=cache_2d.device)) don't fail on
# cross-device tensors when the caller passes a CPU slot_mapping.
slots = slot_mapping.to(device=cache_2d.device, dtype=torch.int64)
valid = slots >= 0
safe_slots = torch.where(valid, slots, torch.zeros_like(slots))
pages = torch.div(safe_slots, block_size, rounding_mode="floor")
pos = safe_slots % block_size
page_base = pages * cache_2d.stride(0)
value_base = page_base + pos * index_head_dim
scale_base = page_base + block_size * index_head_dim + pos * scale_bytes

value_offsets = (
value_base[:, None]
+ torch.arange(
index_head_dim,
device=cache_2d.device,
dtype=torch.int64,
)[None, :]
)
values = flat_cache[value_offsets].view(torch.float8_e4m3fn).float()
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Skip gathers when all slots are invalid

This vectorized path performs flat_cache[value_offsets] unconditionally, so when slot_mapping is entirely padding (<0) and the cache has zero pages, it still tries to read row 0 from an empty buffer and fails with an out-of-bounds index (on CUDA this can surface as a device-side assert). The previous loop-based implementation skipped invalid slots and returned zeros in this case, so this is a behavioral regression for padded/empty-cache inputs unless you short-circuit when valid.any() is false before computing gathers.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch — fixed in 1e29154. Added a shape-only short-circuit that returns zeros when slot_mapping.numel() == 0 or cache_2d.shape[0] == 0. Kept it shape-only (vs valid.any()) so the check stays CUDA-graph-capture-safe — no host sync. Also flipped the early-return from torch.empty to torch.zeros so the contract holds in the "cache has zero pages, all slots are padding" case the reference loop tolerated.

The sibling read_deepseek_v4_indexer_mxfp4_cache has the same latent issue with the same shape — happy to file a follow-up PR for it.


scale_offsets = (
scale_base[:, None]
+ torch.arange(
scale_bytes,
device=cache_2d.device,
dtype=torch.int64,
)[None, :]
)
# scale_bytes is a multiple of 4 (FP32). The reference loop uses only the
# first FP32 per row (`view(torch.float32)[0]`); mirror that here.
scales = flat_cache[scale_offsets].view(torch.float32)[:, 0]

out = values * scales[:, None]
return torch.where(valid[:, None], out, torch.zeros_like(out))


def _compress_v4_state_windows_capturable(
Expand Down