diff --git a/ds4_cuda.cu b/ds4_cuda.cu
index ce18d55c..90c00bd1 100644
--- a/ds4_cuda.cu
+++ b/ds4_cuda.cu
@@ -85,6 +85,7 @@ static cudaStream_t g_model_prefetch_stream;
 static cudaStream_t g_model_upload_stream;
 static cublasHandle_t g_cublas;
 static int g_cublas_ready;
+static int g_cuda_sm_major;
 static int g_quality_mode;
 
 struct cuda_model_range {
@@ -504,6 +505,16 @@ static int cuda_q8_use_dp4a(void) {
   return getenv("DS4_CUDA_NO_Q8_DP4A") == NULL;
 }
 
+static int cuda_skip_ordered_f16_matmul(void) {
+  if (getenv("DS4_CUDA_FORCE_ORDERED_F16_MATMUL") != NULL) return 0;
+  if (getenv("DS4_CUDA_NO_ORDERED_F16_MATMUL") != NULL) return 1;
+  /* Blackwell-class GPUs measured so far (Thor sm_110 and GB10 sm_121) run
+   * the regular 256-thread reduction faster than the ordered 32-thread decode
+   * path. Keep older architectures on the existing default unless explicitly
+   * overridden. */
+  return g_cuda_sm_major >= 11;
+}
+
 static int cuda_q8_f16_preload_allowed(const char *label, uint64_t in_dim, uint64_t out_dim) {
   if (cuda_q8_label_is_attention_output(label) &&
       getenv("DS4_CUDA_ATTENTION_OUTPUT_PRELOAD") == NULL &&
@@ -1207,6 +1218,7 @@ extern "C" int ds4_gpu_init(void) {
   if (!cuda_ok(cudaSetDevice(dev), "set device")) return 0;
   cudaDeviceProp prop;
   if (cudaGetDeviceProperties(&prop, dev) == cudaSuccess) {
+    g_cuda_sm_major = prop.major;
     fprintf(stderr, "ds4: CUDA backend initialized on %s (sm_%d%d)\n",
             prop.name, prop.major, prop.minor);
   }
@@ -5986,7 +5998,7 @@ extern "C" int ds4_gpu_matmul_f16_tensor(ds4_gpu_tensor *out, const void *model_
       !serial_f16 &&
       !serial_router &&
       n_tok == 1u &&
-      getenv("DS4_CUDA_NO_ORDERED_F16_MATMUL") == NULL;
+      !cuda_skip_ordered_f16_matmul();
   if (!serial_f16 && g_cublas_ready && n_tok > 1) {
     const uint64_t xh_count = n_tok * in_dim;
     __half *xh = (__half *)cuda_tmp_alloc(xh_count * sizeof(__half), "f16 gemm activations");
@@ -6047,7 +6059,7 @@ extern "C" int ds4_gpu_matmul_f16_pair_tensor(
       getenv("DS4_CUDA_NO_F16_PAIR_MATMUL") != NULL ||
       getenv("DS4_CUDA_SERIAL_F16_MATMUL") != NULL ||
       getenv("DS4_CUDA_SERIAL_ROUTER") != NULL ||
-      getenv("DS4_CUDA_NO_ORDERED_F16_MATMUL") != NULL) {
+      cuda_skip_ordered_f16_matmul()) {
     return ds4_gpu_matmul_f16_tensor(out0, model_map, model_size, weight0_offset,
                                      in_dim, out_dim, x, n_tok) &&
            ds4_gpu_matmul_f16_tensor(out1, model_map, model_size, weight1_offset,
@@ -9455,6 +9467,74 @@ __global__ static void moe_down_expert_tile16_row2048_kernel(
   }
 }
 
+template <uint32_t ROW_SPAN>
+__launch_bounds__(256, 2)
+__global__ static void moe_down_expert_tile8_rowspan_kernel(
+    float *down_out,
+    const char *down_base,
+    const cuda_block_q8_K *midq,
+    const uint32_t *sorted_pairs,
+    const uint32_t *offsets,
+    const uint32_t *counts,
+    const uint32_t *tile_total,
+    const uint32_t *tile_experts,
+    const uint32_t *tile_starts,
+    uint64_t down_expert_bytes,
+    uint64_t down_row_bytes,
+    uint32_t midq_blocks,
+    uint32_t out_dim,
+    uint32_t n_expert,
+    uint32_t atomic_out) {
+  uint32_t tile = blockIdx.y;
+  if (tile >= *tile_total) return;
+  uint32_t lane = threadIdx.x & 7u;
+  uint32_t row_lane = threadIdx.x >> 3u;
+  uint32_t expert = tile_experts[tile];
+  uint32_t local_start = tile_starts[tile];
+  __shared__ cuda_block_q8_K sxq[8][8];
+  uint32_t pair[8] = {0};
+  const cuda_block_q8_K *xqb[8] = {NULL};
+  uint32_t np = 0;
+  for (; np < 8u; np++) {
+    uint32_t local_pair = local_start + np;
+    if (local_pair >= counts[expert]) break;
+    pair[np] = sorted_pairs[offsets[expert] + local_pair];
+    xqb[np] = midq + (uint64_t)pair[np] * midq_blocks;
+  }
+  if (midq_blocks <= 8u) {
+    for (uint32_t i = threadIdx.x; i < np * midq_blocks; i += blockDim.x) {
+      uint32_t p = i / midq_blocks;
+      uint32_t b = i - p * midq_blocks;
+      sxq[p][b] = xqb[p][b];
+    }
+    __syncthreads();
+    for (uint32_t p = 0; p < np; p++) xqb[p] = sxq[p];
+  }
+  for (uint32_t rr = 0; rr < ROW_SPAN / 32u; rr++) {
+    uint32_t row = blockIdx.x * ROW_SPAN + row_lane + rr * 32u;
+    if (row >= out_dim) continue;
+    const cuda_block_q2_K *wr = (const cuda_block_q2_K *)(down_base + (uint64_t)expert * down_expert_bytes + (uint64_t)row * down_row_bytes);
+    float acc[8] = {0.0f};
+    for (uint32_t b = lane; b < midq_blocks; b += 8u) {
+      dev_dot_q2_K_q8_K_block8(wr + b, xqb[0] ? xqb[0] + b : NULL, xqb[1] ? xqb[1] + b : NULL,
+                               xqb[2] ? xqb[2] + b : NULL, xqb[3] ? xqb[3] + b : NULL,
+                               xqb[4] ? xqb[4] + b : NULL, xqb[5] ? xqb[5] + b : NULL,
+                               xqb[6] ? xqb[6] + b : NULL, xqb[7] ? xqb[7] + b : NULL, np, acc);
+    }
+    for (uint32_t p = 0; p < np; p++) {
+      acc[p] = quarter_warp_sum_f32(acc[p], lane);
+      if (lane == 0) {
+        if (atomic_out) {
+          uint32_t tok = pair[p] / n_expert;
+          atomicAdd(down_out + (uint64_t)tok * out_dim + row, acc[p]);
+        } else {
+          down_out[(uint64_t)pair[p] * out_dim + row] = acc[p];
+        }
+      }
+    }
+  }
+}
+
 template <uint32_t ROW_SPAN>
 __global__ static void moe_down_expert_tile16_rowspan_kernel(
     float *down_out,
@@ -9826,6 +9906,8 @@ static int routed_moe_launch(
   const uint32_t down_row_span =
       getenv("DS4_CUDA_MOE_DOWN_ROW512") != NULL ? 512u :
       getenv("DS4_CUDA_MOE_DOWN_ROW1024") != NULL ? 1024u : 2048u;
+  const uint32_t use_down_tile8_rowspan =
+      use_atomic_down && expert_tile_m == 8u && getenv("DS4_CUDA_MOE_DOWN_TILE8_ROWSPAN") != NULL;
   const uint32_t use_down_row2048 =
       use_atomic_down && expert_tile_m == 8u &&
       (getenv("DS4_CUDA_MOE_DOWN_ROW2048") != NULL || getenv("DS4_CUDA_MOE_DOWN_ROW256") != NULL ||
@@ -10125,7 +10207,30 @@ static int routed_moe_launch(
     /* The direct decode kernel writes the final token row. */
   } else if (sorted_pairs && use_expert_tiles && sorted_offsets && sorted_counts &&
              down_tile_total && down_tile_experts && down_tile_starts) {
-    if (use_down_row2048) {
+    if (use_down_tile8_rowspan) {
+      if (down_row_span == 512u) {
+        dim3 tgrid((out_dim + 511u) / 512u, tile_capacity, 1);
+        moe_down_expert_tile8_rowspan_kernel<512><<<tgrid, 256>>>(
+            use_atomic_down ? (float *)out->ptr : (float *)down->ptr,
+            down_w, midq, sorted_pairs, sorted_offsets, sorted_counts,
+            tile_total, tile_experts, tile_starts, down_expert_bytes, down_row_bytes,
+            midq_blocks, out_dim, n_expert, use_atomic_down);
+      } else if (down_row_span == 1024u) {
+        dim3 tgrid((out_dim + 1023u) / 1024u, tile_capacity, 1);
+        moe_down_expert_tile8_rowspan_kernel<1024><<<tgrid, 256>>>(
+            use_atomic_down ? (float *)out->ptr : (float *)down->ptr,
+            down_w, midq, sorted_pairs, sorted_offsets, sorted_counts,
+            tile_total, tile_experts, tile_starts, down_expert_bytes, down_row_bytes,
+            midq_blocks, out_dim, n_expert, use_atomic_down);
+      } else {
+        dim3 tgrid((out_dim + 2047u) / 2048u, tile_capacity, 1);
+        moe_down_expert_tile8_rowspan_kernel<2048><<<tgrid, 256>>>(
+            use_atomic_down ? (float *)out->ptr : (float *)down->ptr,
+            down_w, midq, sorted_pairs, sorted_offsets, sorted_counts,
+            tile_total, tile_experts, tile_starts, down_expert_bytes, down_row_bytes,
+            midq_blocks, out_dim, n_expert, use_atomic_down);
+      }
+    } else if (use_down_row2048) {
       if (down_row_span == 512u) {
         dim3 tgrid((out_dim + 511u) / 512u, down_tile_capacity, 1);
         moe_down_expert_tile16_rowspan_kernel<512><<<tgrid, 256>>>(
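
Note on the reduction used by the new kernel: quarter_warp_sum_f32(acc[p], lane) folds each pair's partial dot product across the 8 lanes that share a row, but its definition lives elsewhere in ds4_cuda.cu and is not part of this diff. The sketch below is only an assumed shape for such a helper, a standard segmented shuffle reduction over the 8-lane (quarter-warp) groups the kernel carves out with lane = threadIdx.x & 7u; the name quarter_warp_sum_f32_sketch and the width-8 shuffles are illustrative, not the file's actual code.

#include <cstdint>

/* Sketch only, not the ds4_cuda.cu definition: sum a float across each 8-lane
 * (quarter-warp) group. width=8 confines every shuffle to its own 8-lane
 * segment of the warp, so lanes 0, 8, 16 and 24 each end up holding their
 * group's total; the kernel above reads the result on lane == 0. */
__device__ static float quarter_warp_sum_f32_sketch(float v, uint32_t lane) {
  v += __shfl_down_sync(0xffffffffu, v, 4, 8);
  v += __shfl_down_sync(0xffffffffu, v, 2, 8);
  v += __shfl_down_sync(0xffffffffu, v, 1, 8);
  (void)lane; /* mirrors the caller's signature; only lane 0's sum is consumed */
  return v;
}

The new down-projection path stays opt-in: it runs only when use_atomic_down and expert_tile_m == 8u hold and DS4_CUDA_MOE_DOWN_TILE8_ROWSPAN is set, with the per-block row span still chosen by DS4_CUDA_MOE_DOWN_ROW512 or DS4_CUDA_MOE_DOWN_ROW1024 (2048 otherwise).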