Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/compute/attention_cublas.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ static cublasHandle_t get_attn_cublas_handle() {
return handle;
}

void attention_cublas_prewarm() {
    // Eagerly materialize the lazily-constructed static cuBLAS handle.
    // Doing this up front means a later cublasGemmBatchedEx issued inside a
    // stream-capture region never has to run cublasCreate, whose internal
    // workspace cudaMalloc is not permitted while a stream is being captured.
    // Idempotent: repeated calls just re-read the already-created handle.
    static_cast<void>(get_attn_cublas_handle());
}

// ---------------------------------------------------------------------------
// Fused causal softmax FP32 → FP16: reads FP32 S matrix, writes FP16 probs
// to a separate output buffer. Replaces causal_softmax_fp32_inplace_kernel
Expand Down
6 changes: 6 additions & 0 deletions src/compute/attention_cublas.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,10 @@ void attention_cublas_prefill(const Tensor& Q, const Tensor& K, const Tensor& V,
int n_heads, int n_kv_heads, int head_dim, float scale, bool causal,
float softcap = 0.0f, int q_offset = 0, cudaStream_t stream = nullptr);

// Force-create the static cuBLAS handle. Safe to call multiple times.
// Engine init calls this so the first attention_cublas_prefill invocation
// inside a captured stream can reuse the handle without cublasCreate
// (which does internal cudaMalloc for workspace and is illegal under capture).
void attention_cublas_prewarm();

} // namespace imp
29 changes: 28 additions & 1 deletion src/runtime/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "model/gguf_loader.h"
#include "model/chat_template.h"
#include "compute/gemm.h"
#include "compute/attention_cublas.h"
#include "compute/gemm_grouped.h"
#include "compute/sampling.h"
#include "compute/attention.h"
Expand Down Expand Up @@ -963,6 +964,7 @@ bool Engine::init(std::shared_ptr<Model> model, const EngineConfig& config) {
return false;
}
gemm_init();
attention_cublas_prewarm();
scheduler_ = std::make_unique<Scheduler>(config_.max_batch_size);
(void)stream_.create(cudaStreamNonBlocking);

Expand Down Expand Up @@ -2233,7 +2235,32 @@ void Engine::step_prefill_one(std::shared_ptr<Request>& req, int effective_chunk
executor_->use_workspace(0);
}
Tensor logits_out;
executor_->forward_logits(state, logits_out, pf_stream);

// Prefill graph capture (opt-in, Phase 4 of MoE-prefill-graphs work).
// Conditions: env-gated, pool path (stable device buffers), and
// chunk shape stable (in practice all non-last chunks share chunk_len
// = prefill_chunk_size). H2D upload happened above on pf_stream
// *before* this wrapper — captured region is forward_logits only,
// analogous to the decode graph pattern.
static const bool prefill_graph_enabled = (std::getenv("IMP_PREFILL_GRAPH") != nullptr);
const bool can_capture = prefill_graph_enabled && pf_pool_used && config_.use_cuda_graphs;
if (can_capture) {
const int block_count = static_cast<int>(block_table.size());
if (chunk_len != last_prefill_chunk_len_ || block_count != last_prefill_block_count_) {
prefill_graph_runner_.invalidate_for_update();
last_prefill_chunk_len_ = chunk_len;
last_prefill_block_count_ = block_count;
}
prefill_graph_runner_.set_decode_fn([this, &state, &logits_out](cudaStream_t s) {
executor_->forward_logits(state, logits_out, s);
});
prefill_graph_runner_.execute(pf_stream);
if (logits_out.data == nullptr) {
logits_out = executor_->get_logits_view(/*n_sequences=*/1);
}
} else {
executor_->forward_logits(state, logits_out, pf_stream);
}

if (!pf_pool_used) {
free_prefill_buffers(d_token_ids, d_positions, d_block_tables, d_context_lens, pf_stream);
Expand Down
8 changes: 8 additions & 0 deletions src/runtime/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ class Engine {
static constexpr int kMaxGraphPoolSize = 32;
CudaGraphRunner decode_graph_pool_[kMaxGraphPoolSize]; // index = n_sequences - 1
int last_decode_max_blocks_per_graph_[kMaxGraphPoolSize] = {};

// Prefill graph runner — captures forward_logits for non-last chunks of
// chunked prefill. Single runner: in practice chunk_len == prefill_chunk_size
// for all non-last chunks, so per-shape variability collapses to one shape.
// Opt-in via IMP_PREFILL_GRAPH=1 (Phase 4 of MoE-prefill-graphs work).
CudaGraphRunner prefill_graph_runner_;
int last_prefill_chunk_len_ = -1;
int last_prefill_block_count_ = -1;
int32_t* h_sample_pinned_ = nullptr;
// Async conditional graph loop
CudaGraphConditionalRunner async_graph_runner_;
Expand Down
Loading