Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarks/cpp/disaggServerBenchmark.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -543,7 +543,7 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const
std::nullopt, // logitsPostProcessorName
std::nullopt, // logitsPostProcessor
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
std::nullopt); // cacheSaltID
std::nullopt); // cacheSalt
request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
return request;
}
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/cpp/gptManagerBenchmark.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -838,7 +838,7 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
std::nullopt, // logitsPostProcessorName
std::nullopt, // logitsPostProcessor
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
std::nullopt); // cacheSaltID
std::nullopt); // cacheSalt
}

void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
Expand Down
12 changes: 6 additions & 6 deletions cpp/include/tensorrt_llm/batch_manager/blockKey.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ using VecTokens = std::vector<TokenIdType>;
using UniqueToken = tensorrt_llm::runtime::UniqueToken;
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;
using MmKey = tensorrt_llm::executor::MmKey;

//! \brief Generate the multimodal extra keys for a single KV cache block.
Expand All @@ -49,7 +48,8 @@ struct BlockKey
// Extra keys for multimodal data (similar to VLLM's approach)
// Each extra key is a pair of (mm_hash, start_offset_in_block)
std::vector<MmKey> extraKeys;
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt;
// Cache salt string. Used as part of the block key so blocks from different salts do not match.
std::optional<std::string> cacheSalt = std::nullopt;

BlockKey() = default;

Expand All @@ -64,12 +64,12 @@ struct BlockKey
}

explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
std::vector<MmKey> extraKeys = {}, std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
std::vector<MmKey> extraKeys = {}, std::optional<std::string> cacheSalt = std::nullopt)
: usesExtraIds{usesExtraIds}
, loraTaskId{loraTaskId}
, uniqueTokens{std::move(uniqueTokens)}
, extraKeys{std::move(extraKeys)}
, cacheSaltID{cacheSaltID}
, cacheSalt{std::move(cacheSalt)}
{
}

Expand All @@ -86,15 +86,15 @@ struct BlockKey
}

//! \brief Count the number of leading tokens that match between this key and \p other.
//! \details Returns 0 immediately when loraTaskId, extraKeys, or cacheSaltID differ, because those fields must
//! \details Returns 0 immediately when loraTaskId, extraKeys, or cacheSalt differ, because those fields must
//! match exactly before token content is considered.
//! \param other The key to compare against.
//! \return Number of leading uniqueTokens that are identical in both keys.
int numMatchingTokens(BlockKey const& other) const noexcept
{
SizeType32 numMatched{0};
if (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && extraKeys == other.extraKeys
&& cacheSaltID == other.cacheSaltID)
&& cacheSalt == other.cacheSalt)
{
auto [matchEnd, otherMatchEnd] = std::mismatch(
uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end());
Expand Down
1 change: 0 additions & 1 deletion cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken;
using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;
using MmKey = tensorrt_llm::executor::MmKey;
using WindowSizeType = SizeType32;

Expand Down
34 changes: 18 additions & 16 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ class GenericLlmRequest
using MillisecondsType = std::chrono::milliseconds;
using TimePoint = std::chrono::time_point<std::chrono::steady_clock>;
using Duration = std::chrono::time_point<std::chrono::steady_clock>::duration;
using CacheSaltIDType = runtime::CacheSaltIDType;

GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const& inputTokens,
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
Expand Down Expand Up @@ -147,11 +146,12 @@ class GenericLlmRequest
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt,
std::optional<TimePoint> arrivalTime = std::nullopt,
std::optional<std::vector<std::tuple<std::string, int>>> agent_hierarchy = std::nullopt,
std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalItemRunCuOffsets = std::nullopt,
std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalRunPositions = std::nullopt,
std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalRunLengths = std::nullopt)
std::optional<std::shared_ptr<std::vector<SizeType32>>> multimodalRunLengths = std::nullopt,
std::optional<std::string> cacheSalt = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
Expand Down Expand Up @@ -213,7 +213,7 @@ class GenericLlmRequest
, mGuidedDecodingParams(std::move(guidedDecodingParams))
, mLanguageAdapterUid(languageAdapterUid)
, mAllottedTimeMs(allottedTimeMs)
, mCacheSaltID(cacheSaltID)
, mCacheSalt(std::move(cacheSalt))
, mAgentHierarchy(std::move(agent_hierarchy))
{
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
Expand Down Expand Up @@ -242,7 +242,7 @@ class GenericLlmRequest
executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
std::optional<std::string> cacheSalt = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens.size())
, mMaxNewTokens(maxNewTokens)
Expand Down Expand Up @@ -283,7 +283,7 @@ class GenericLlmRequest
, mContextPhaseParams(contextPhaseParams)
, mNumReturnSequences(numReturnSequences)
, mLanguageAdapterUid(languageAdapterUid)
, mCacheSaltID(cacheSaltID)
, mCacheSalt(std::move(cacheSalt))
{
if (mEncoderTokens.has_value())
{
Expand Down Expand Up @@ -323,7 +323,7 @@ class GenericLlmRequest
, mGuidedDecodingParams(req.getGuidedDecodingParams())
, mLanguageAdapterUid(req.getLanguageAdapterUid())
, mAllottedTimeMs(req.getAllottedTimeMs())
, mCacheSaltID(req.getCacheSaltID())
, mCacheSalt(req.getCacheSalt())
{
if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY)
{
Expand Down Expand Up @@ -1897,9 +1897,9 @@ class GenericLlmRequest
return mLanguageAdapterUid;
}

[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const
[[nodiscard]] std::optional<std::string> getCacheSalt() const
{
return mCacheSaltID;
return mCacheSalt;
}

std::vector<SizeType32> getLanguageAdapterRouting(
Expand Down Expand Up @@ -2196,8 +2196,8 @@ class GenericLlmRequest

bool mUseDraftModel{false};

// Cache salt id for each request.
std::optional<CacheSaltIDType> mCacheSaltID{std::nullopt};
// Cache salt string. Used in BlockKey hashing/matching and surfaced in KV cache events.
std::optional<std::string> mCacheSalt{std::nullopt};

std::optional<std::vector<std::tuple<std::string, int>>> mAgentHierarchy{std::nullopt};

Expand Down Expand Up @@ -2394,11 +2394,12 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt, std::optional<TimePoint> arrivalTime = std::nullopt,
std::optional<TimePoint> arrivalTime = std::nullopt,
std::optional<std::vector<std::tuple<std::string, int>>> agent_hierarchy = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalItemRunCuOffsets = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalRunPositions = std::nullopt,
std::optional<std::vector<SizeType32>> multimodalRunLengths = std::nullopt)
std::optional<std::vector<SizeType32>> multimodalRunLengths = std::nullopt,
std::optional<std::string> cacheSalt = std::nullopt)
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
std::move(stopWordsList),
Expand Down Expand Up @@ -2431,8 +2432,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics,
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID,
arrivalTime, std::move(agent_hierarchy),
std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, arrivalTime,
std::move(agent_hierarchy),
multimodalItemRunCuOffsets.has_value()
? std::make_shared<std::vector<SizeType32>>(std::move(multimodalItemRunCuOffsets.value()))
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
Expand All @@ -2441,7 +2442,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
multimodalRunLengths.has_value()
? std::make_shared<std::vector<SizeType32>>(std::move(multimodalRunLengths.value()))
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt))
: std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
std::move(cacheSalt))
{
}

Expand Down
15 changes: 9 additions & 6 deletions cpp/include/tensorrt_llm/executor/executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -714,8 +714,9 @@ class Request
/// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut
/// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
/// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
/// @param cacheSaltID Salt ID for KV cache blocks to limit the kv cache reuse to the requests with the same string.
/// @param disaggRequestId Disaggregated request ID.
/// @param cacheSalt Optional cache salt string. If provided, KV cache blocks are tagged so reuse is limited to
/// requests with the same salt. The string is also surfaced in KV cache events. Defaults to std::nullopt.
Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
Expand Down Expand Up @@ -743,8 +744,7 @@ class Request
std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
std::optional<SizeType32> languageAdapterUid = std::nullopt,
std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
std::optional<CacheSaltIDType> cacheSaltID = std::nullopt,
std::optional<IdType> disaggRequestId = std::nullopt);
std::optional<IdType> disaggRequestId = std::nullopt, std::optional<std::string> cacheSalt = std::nullopt);

/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
Expand Down Expand Up @@ -792,7 +792,7 @@ class Request
[[nodiscard]] std::optional<GuidedDecodingParams> getGuidedDecodingParams() const;
[[nodiscard]] std::optional<SizeType32> getLanguageAdapterUid() const;
[[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
[[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
[[nodiscard]] std::optional<std::string> getCacheSalt() const;
[[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;
[[nodiscard]] std::optional<IdType> getDisaggRequestId() const;

Expand Down Expand Up @@ -829,7 +829,7 @@ class Request
void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams);
void setLanguageAdapterUid(SizeType32 languageAdapterUid);
void setAllottedTimeMs(MillisecondsType allottedTimeMs);
void setCacheSaltID(CacheSaltIDType cacheSaltID);
void setCacheSalt(std::optional<std::string> cacheSalt);
void setDisaggRequestId(IdType disaggRequestId);

private:
Expand Down Expand Up @@ -1729,13 +1729,14 @@ struct KVCacheStoredBlockData

KVCacheStoredBlockData(IdType blockHash, tensorrt_llm::runtime::VecUniqueTokens tokens,
std::optional<tensorrt_llm::runtime::LoraTaskIdType> loraId, SizeType32 cacheLevel, SizeType32 priority,
std::vector<MmKey> mmKeys = {})
std::vector<MmKey> mmKeys = {}, std::optional<std::string> cacheSalt = std::nullopt)
: blockHash{blockHash}
, tokens{std::move(tokens)}
, loraId{loraId}
, cacheLevel{cacheLevel}
, priority{priority}
, mmKeys{std::move(mmKeys)}
, cacheSalt{std::move(cacheSalt)}
{
}

Expand All @@ -1751,6 +1752,8 @@ struct KVCacheStoredBlockData
SizeType32 priority;
/// @brief The multimodal keys of the block
std::vector<MmKey> mmKeys;
/// @brief The original cache salt string of the block, if any
std::optional<std::string> cacheSalt;
};

struct KVCacheStoredData
Expand Down
1 change: 0 additions & 1 deletion cpp/include/tensorrt_llm/executor/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ using RandomSeedType = std::uint64_t;
using VecLogProbs = std::vector<FloatType>;
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using MillisecondsType = std::chrono::milliseconds;
using CacheSaltIDType = std::uint64_t;
using LogitsPostProcessor
= std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>;
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
Expand Down
1 change: 0 additions & 1 deletion cpp/include/tensorrt_llm/runtime/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ using TokenIdType = std::int32_t;
using LoraTaskIdType = std::uint64_t;
using TokenExtraIdType = std::uint64_t;
using VecTokenExtraIds = std::vector<TokenExtraIdType>;
using CacheSaltIDType = std::uint64_t;

struct UniqueToken
{
Expand Down
10 changes: 5 additions & 5 deletions cpp/tensorrt_llm/batch_manager/blockKey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,15 @@ std::vector<BlockKey> buildBlockKeys(
currentTokenIdx += uniqueTokens.size();

blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(),
std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID());
std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSalt());
}
return blockKeys;
}

bool BlockKey::operator==(BlockKey const& other) const noexcept
{
return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens
&& extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
&& extraKeys == other.extraKeys && cacheSalt == other.cacheSalt);
}

BlockKey BlockKey::shorten(int newNumTokens) const
Expand All @@ -364,10 +364,10 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no
// Constants provide very good distribution - each input bit affects each output bit with ~50% probability.
size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9);

if (parentHash == 0 && blockKey.cacheSaltID)
if (parentHash == 0 && blockKey.cacheSalt)
{
// Only hashing the cache salt ID for the first block in the sequence
uint64_t c = blockKey.cacheSaltID.value();
// Only mix the cache salt into the hash for the first block in the sequence.
uint64_t c = static_cast<uint64_t>(std::hash<std::string>{}(blockKey.cacheSalt.value()));
seed = hash64Mix(c, seed);
}

Expand Down
8 changes: 4 additions & 4 deletions cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ void CacheTransceiver::respondAndSendAsync(std::shared_ptr<LlmRequest> llmReques
return;
}
setContextState(llmRequest.get());
auto future = mCacheSender->sendAsync(*llmRequest);
auto future = mCacheSender->sendAsync(llmRequest);
mSenderFutures.emplace_back(std::move(llmRequest), std::move(future));
}

Expand All @@ -410,7 +410,7 @@ void CacheTransceiver::respondAndSendLayerWise(

llmRequest->setState(LlmRequestState::kDISAGG_CONTEXT_INIT_AND_TRANS);
setContextState(llmRequest.get());
auto future = mCacheSender->sendAsync(*llmRequest);
auto future = mCacheSender->sendAsync(llmRequest);
mSenderFutures.emplace_back(llmRequest, std::move(future));
}
}
Expand All @@ -419,7 +419,7 @@ void CacheTransceiver::requestAndReceiveSync(std::shared_ptr<LlmRequest> llmRequ
{
TLLM_CHECK(llmRequest && llmRequest->isGenerationOnlyRequest());
{
auto future = mCacheReceiver->receiveAsync(*llmRequest);
auto future = mCacheReceiver->receiveAsync(llmRequest);
future.get();
}
llmRequest->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_COMPLETE);
Expand All @@ -438,7 +438,7 @@ void CacheTransceiver::requestAndReceiveAsync(std::shared_ptr<LlmRequest> llmReq
return;
}

auto future = mCacheReceiver->receiveAsync(*llmRequest);
auto future = mCacheReceiver->receiveAsync(llmRequest);
auto* requestPtr = llmRequest.get();
mRequesterFutures.emplace_back(std::move(llmRequest), std::move(future));
requestPtr->setState(LlmRequestState::kDISAGG_GENERATION_TRANS_IN_PROGRESS);
Expand Down
Loading
Loading