Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 108 additions & 3 deletions ds4_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ static cudaStream_t g_model_prefetch_stream;
static cudaStream_t g_model_upload_stream;
static cublasHandle_t g_cublas;
static int g_cublas_ready;
static int g_cuda_sm_major;
static int g_quality_mode;

struct cuda_model_range {
Expand Down Expand Up @@ -504,6 +505,16 @@ static int cuda_q8_use_dp4a(void) {
return getenv("DS4_CUDA_NO_Q8_DP4A") == NULL;
}

/* Decide whether the single-token ordered f16 matmul decode path should be
 * bypassed in favor of the regular 256-thread reduction.
 *
 * Returns 1 to skip the ordered path, 0 to keep it. Explicit environment
 * overrides win, with "force" taking precedence over "no"; otherwise the
 * default is chosen from the detected SM major version. */
static int cuda_skip_ordered_f16_matmul(void) {
    const int force_ordered = getenv("DS4_CUDA_FORCE_ORDERED_F16_MATMUL") != NULL;
    const int no_ordered = getenv("DS4_CUDA_NO_ORDERED_F16_MATMUL") != NULL;
    if (force_ordered)
        return 0;
    if (no_ordered)
        return 1;
    /* Blackwell-class GPUs measured so far (Thor sm_110 and GB10 sm_121) run
     * the regular 256-thread reduction faster than the ordered 32-thread
     * decode path. Older architectures keep the existing default unless
     * explicitly overridden. */
    return (g_cuda_sm_major >= 11) ? 1 : 0;
}

static int cuda_q8_f16_preload_allowed(const char *label, uint64_t in_dim, uint64_t out_dim) {
if (cuda_q8_label_is_attention_output(label) &&
getenv("DS4_CUDA_ATTENTION_OUTPUT_PRELOAD") == NULL &&
Expand Down Expand Up @@ -1207,6 +1218,7 @@ extern "C" int ds4_gpu_init(void) {
if (!cuda_ok(cudaSetDevice(dev), "set device")) return 0;
cudaDeviceProp prop;
if (cudaGetDeviceProperties(&prop, dev) == cudaSuccess) {
g_cuda_sm_major = prop.major;
fprintf(stderr, "ds4: CUDA backend initialized on %s (sm_%d%d)\n",
prop.name, prop.major, prop.minor);
}
Expand Down Expand Up @@ -5986,7 +5998,7 @@ extern "C" int ds4_gpu_matmul_f16_tensor(ds4_gpu_tensor *out, const void *model_
!serial_f16 &&
!serial_router &&
n_tok == 1u &&
getenv("DS4_CUDA_NO_ORDERED_F16_MATMUL") == NULL;
!cuda_skip_ordered_f16_matmul();
if (!serial_f16 && g_cublas_ready && n_tok > 1) {
const uint64_t xh_count = n_tok * in_dim;
__half *xh = (__half *)cuda_tmp_alloc(xh_count * sizeof(__half), "f16 gemm activations");
Expand Down Expand Up @@ -6047,7 +6059,7 @@ extern "C" int ds4_gpu_matmul_f16_pair_tensor(
getenv("DS4_CUDA_NO_F16_PAIR_MATMUL") != NULL ||
getenv("DS4_CUDA_SERIAL_F16_MATMUL") != NULL ||
getenv("DS4_CUDA_SERIAL_ROUTER") != NULL ||
getenv("DS4_CUDA_NO_ORDERED_F16_MATMUL") != NULL) {
cuda_skip_ordered_f16_matmul()) {
return ds4_gpu_matmul_f16_tensor(out0, model_map, model_size, weight0_offset,
in_dim, out_dim, x, n_tok) &&
ds4_gpu_matmul_f16_tensor(out1, model_map, model_size, weight1_offset,
Expand Down Expand Up @@ -9455,6 +9467,74 @@ __global__ static void moe_down_expert_tile16_row2048_kernel(
}
}

/* MoE down-projection kernel: one (expert, 8-token) tile per blockIdx.y,
 * ROW_SPAN output rows per blockIdx.x.
 *
 * Launch layout (see routed_moe_launch): 256 threads per block, split into
 * 32 "row lanes" (threadIdx.x >> 3) x 8 "reduction lanes" (threadIdx.x & 7).
 * Each group of 8 reduction lanes cooperates on one output row; the rr loop
 * steps the 32 row lanes across ROW_SPAN rows, so ROW_SPAN must be a
 * multiple of 32 (instantiated with 512/1024/2048).
 *
 * tile_total is read from device memory so the grid can be launched at a
 * fixed tile capacity; excess blocks exit immediately (whole block returns,
 * so the early return is barrier-safe).
 *
 * atomic_out != 0: accumulate into down_out[token][row] via atomicAdd, where
 * the token index is recovered as pair / n_expert (so sorted pair ids
 * presumably encode token * n_expert + expert — TODO confirm against the
 * sort that builds sorted_pairs). atomic_out == 0: write the per-pair row
 * down_out[pair][row] directly. */
template <uint32_t ROW_SPAN>
__launch_bounds__(256, 2)
__global__ static void moe_down_expert_tile8_rowspan_kernel(
    float *down_out,
    const char *down_base,           /* base of q2_K down-projection weights */
    const cuda_block_q8_K *midq,     /* quantized mid activations, midq_blocks per pair */
    const uint32_t *sorted_pairs,    /* pair ids grouped by expert */
    const uint32_t *offsets,         /* per-expert start into sorted_pairs */
    const uint32_t *counts,          /* per-expert pair count */
    const uint32_t *tile_total,      /* device-side number of valid tiles */
    const uint32_t *tile_experts,    /* expert id per tile */
    const uint32_t *tile_starts,     /* first local pair index per tile */
    uint64_t down_expert_bytes,
    uint64_t down_row_bytes,
    uint32_t midq_blocks,            /* q8_K blocks per activation row */
    uint32_t out_dim,
    uint32_t n_expert,
    uint32_t atomic_out) {
  uint32_t tile = blockIdx.y;
  if (tile >= *tile_total) return;
  uint32_t lane = threadIdx.x & 7u;      /* reduction lane within a row group */
  uint32_t row_lane = threadIdx.x >> 3u; /* which of the 32 rows this iteration */
  uint32_t expert = tile_experts[tile];
  uint32_t local_start = tile_starts[tile];
  /* Staging area: up to 8 tokens x 8 q8_K activation blocks. */
  __shared__ cuda_block_q8_K sxq[8][8];
  uint32_t pair[8] = {0};
  const cuda_block_q8_K *xqb[8] = {NULL};
  /* Gather up to 8 pairs for this tile; np is the number actually present
   * (the last tile of an expert may be partial). Unused slots stay NULL. */
  uint32_t np = 0;
  for (; np < 8u; np++) {
    uint32_t local_pair = local_start + np;
    if (local_pair >= counts[expert]) break;
    pair[np] = sorted_pairs[offsets[expert] + local_pair];
    xqb[np] = midq + (uint64_t)pair[np] * midq_blocks;
  }
  /* When the whole activation row fits in the 8-block staging buffer, copy
   * it to shared memory once and redirect the pointers there. midq_blocks is
   * a kernel argument, uniform across the block, so this branch (and the
   * barrier inside it) is non-divergent. */
  if (midq_blocks <= 8u) {
    for (uint32_t i = threadIdx.x; i < np * midq_blocks; i += blockDim.x) {
      uint32_t p = i / midq_blocks;
      uint32_t b = i - p * midq_blocks;
      sxq[p][b] = xqb[p][b];
    }
    __syncthreads();
    for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p];
  }
  for (uint32_t rr = 0; rr < ROW_SPAN / 32u; rr++) {
    uint32_t row = blockIdx.x * ROW_SPAN + row_lane + rr * 32u;
    if (row >= out_dim) continue;
    /* One q2_K weight row for this expert/output row. */
    const cuda_block_q2_K *wr = (const cuda_block_q2_K *)(down_base + (uint64_t)expert * down_expert_bytes + (uint64_t)row * down_row_bytes);
    float acc[8] = {0.0f};
    /* Each of the 8 reduction lanes handles every 8th block; the helper
     * presumably accumulates wr[b] dotted with each of the np token
     * activation blocks into acc[0..np) — NULLs mark absent tokens. */
    for (uint32_t b = lane; b < midq_blocks; b += 8u) {
      dev_dot_q2_K_q8_K_block8(wr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL,
                               xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL,
                               xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL,
                               xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np, acc);
    }
    for (uint32_t p = 0; p < np; p++) {
      /* Reduce across the 8 lanes of this row group; lane 0 holds the sum. */
      acc[p] = quarter_warp_sum_f32(acc[p], lane);
      if (lane == 0) {
        if (atomic_out) {
          /* Recover the token index from the pair id and accumulate; tokens
           * routed to several experts hit the same row, hence atomicAdd. */
          uint32_t tok = pair[p] / n_expert;
          atomicAdd(down_out + (uint64_t)tok * out_dim + row, acc[p]);
        } else {
          down_out[(uint64_t)pair[p] * out_dim + row] = acc[p];
        }
      }
    }
  }
}

template <uint32_t ROW_SPAN>
__global__ static void moe_down_expert_tile16_rowspan_kernel(
float *down_out,
Expand Down Expand Up @@ -9826,6 +9906,8 @@ static int routed_moe_launch(
const uint32_t down_row_span =
getenv("DS4_CUDA_MOE_DOWN_ROW512") != NULL ? 512u :
getenv("DS4_CUDA_MOE_DOWN_ROW1024") != NULL ? 1024u : 2048u;
const uint32_t use_down_tile8_rowspan =
use_atomic_down && expert_tile_m == 8u && getenv("DS4_CUDA_MOE_DOWN_TILE8_ROWSPAN") != NULL;
const uint32_t use_down_row2048 = use_atomic_down && expert_tile_m == 8u &&
(getenv("DS4_CUDA_MOE_DOWN_ROW2048") != NULL ||
getenv("DS4_CUDA_MOE_DOWN_ROW256") != NULL ||
Expand Down Expand Up @@ -10125,7 +10207,30 @@ static int routed_moe_launch(
/* The direct decode kernel writes the final token row. */
} else if (sorted_pairs && use_expert_tiles && sorted_offsets && sorted_counts &&
down_tile_total && down_tile_experts && down_tile_starts) {
if (use_down_row2048) {
if (use_down_tile8_rowspan) {
if (down_row_span == 512u) {
dim3 tgrid((out_dim + 511u) / 512u, tile_capacity, 1);
moe_down_expert_tile8_rowspan_kernel<512><<<tgrid, 256>>>(
use_atomic_down ? (float *)out->ptr : (float *)down->ptr,
down_w, midq, sorted_pairs, sorted_offsets, sorted_counts,
tile_total, tile_experts, tile_starts, down_expert_bytes, down_row_bytes,
midq_blocks, out_dim, n_expert, use_atomic_down);
} else if (down_row_span == 1024u) {
dim3 tgrid((out_dim + 1023u) / 1024u, tile_capacity, 1);
moe_down_expert_tile8_rowspan_kernel<1024><<<tgrid, 256>>>(
use_atomic_down ? (float *)out->ptr : (float *)down->ptr,
down_w, midq, sorted_pairs, sorted_offsets, sorted_counts,
tile_total, tile_experts, tile_starts, down_expert_bytes, down_row_bytes,
midq_blocks, out_dim, n_expert, use_atomic_down);
} else {
dim3 tgrid((out_dim + 2047u) / 2048u, tile_capacity, 1);
moe_down_expert_tile8_rowspan_kernel<2048><<<tgrid, 256>>>(
use_atomic_down ? (float *)out->ptr : (float *)down->ptr,
down_w, midq, sorted_pairs, sorted_offsets, sorted_counts,
tile_total, tile_experts, tile_starts, down_expert_bytes, down_row_bytes,
midq_blocks, out_dim, n_expert, use_atomic_down);
}
} else if (use_down_row2048) {
if (down_row_span == 512u) {
dim3 tgrid((out_dim + 511u) / 512u, down_tile_capacity, 1);
moe_down_expert_tile16_rowspan_kernel<512><<<tgrid, 256>>>(
Expand Down