-
Notifications
You must be signed in to change notification settings - Fork 239
Send empty message right after first token generation (continuous batching) #4020
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1373,4 +1373,52 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingUsageChunk() { | |
| writer.EndObject(); // } | ||
| return buffer.GetString(); | ||
| } | ||
|
|
||
| std::string OpenAIChatCompletionsHandler::serializeStreamingFirstTokenControlChunk() { | ||
| OVMS_PROFILE_FUNCTION(); | ||
| Document doc; | ||
| doc.SetObject(); | ||
| Document::AllocatorType& allocator = doc.GetAllocator(); | ||
|
|
||
| Value choices(kArrayType); | ||
| Value choice(kObjectType); | ||
|
|
||
| // choices: array of size N, where N is related to n request parameter | ||
| choices.SetArray(); | ||
| choice.SetObject(); | ||
|
|
||
| choice.AddMember("index", 0, allocator); | ||
dkalinowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if (endpoint == Endpoint::CHAT_COMPLETIONS) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we could document this behavior maybe in the API reference, so it's clear that we send that empty response (and only for CB pipelines right?) |
||
| Value delta(kObjectType); | ||
| delta.SetObject(); | ||
| delta.AddMember("role", Value("assistant", allocator), allocator); | ||
| delta.AddMember("content", Value(rapidjson::kNullType), allocator); | ||
| choice.AddMember("delta", delta, allocator); | ||
| } else if (endpoint == Endpoint::COMPLETIONS) { | ||
| choice.AddMember("text", Value(rapidjson::kNullType), allocator); | ||
| } | ||
|
|
||
| choice.AddMember("finish_reason", Value(rapidjson::kNullType), allocator); | ||
| choices.PushBack(choice, allocator); | ||
|
|
||
| doc.AddMember("choices", choices, allocator); | ||
|
|
||
| // created: integer; Unix timestamp (in seconds) when the MP graph was created. | ||
| doc.AddMember("created", std::chrono::duration_cast<std::chrono::seconds>(created.time_since_epoch()).count(), allocator); | ||
|
|
||
| // model: string; copied from the request | ||
| doc.AddMember("model", Value(request.model.c_str(), allocator), allocator); | ||
|
|
||
| // object: string; defined that the type streamed chunk rather than complete response | ||
| if (endpoint == Endpoint::CHAT_COMPLETIONS) { | ||
| doc.AddMember("object", Value("chat.completion.chunk", allocator), allocator); | ||
| } else if (endpoint == Endpoint::COMPLETIONS) { | ||
| doc.AddMember("object", Value("text_completion.chunk", allocator), allocator); | ||
| } | ||
|
|
||
| StringBuffer buffer; | ||
| Writer<StringBuffer> writer(buffer); | ||
| doc.Accept(writer); | ||
| return buffer.GetString(); | ||
| } | ||
dkalinowski marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } // namespace ovms | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -74,6 +74,7 @@ struct GenAiServableExecutionContext { | |
| std::shared_ptr<ov::genai::TextStreamer> textStreamer; | ||
| bool sendLoopbackSignal = false; | ||
| std::string lastStreamerCallbackOutput; | ||
| size_t loopIteration = 0; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This name does not explain the purpose to me. Also, couldn't this be a bool like |
||
| }; | ||
|
|
||
| struct ExtraGenerationInfo { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,6 +61,7 @@ struct TestParameters { | |
| bool checkLogprobs; | ||
| bool checkFinishReason; | ||
| bool testSpeculativeDecoding; | ||
| bool supportsEmptyControlMsg; | ||
| }; | ||
|
|
||
| class LLMFlowHttpTest : public ::testing::Test { | ||
|
|
@@ -193,6 +194,32 @@ TEST(OpenAiApiHandlerTest, writeLogprobs) { | |
| } | ||
| */ | ||
|
|
||
| // Reusable helper: asserts that a streaming chat completion chunk is the initial | ||
| // initial empty message with role:assistant and content:null. | ||
| inline void assertInitialStreamChatCompletionChunk(const std::string& response, const std::string& expectedModel) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How about test for completion endpoint? |
||
| const std::string dataPrefix = "data:"; | ||
| ASSERT_GE(response.size(), dataPrefix.size()); | ||
| ASSERT_EQ(response.substr(0, dataPrefix.size()), dataPrefix); | ||
| size_t pos = response.find("\n"); | ||
| ASSERT_NE(pos, std::string::npos); | ||
| rapidjson::Document d; | ||
| rapidjson::ParseResult ok = d.Parse(response.substr(dataPrefix.size(), pos - dataPrefix.size()).c_str()); | ||
| ASSERT_EQ(ok.Code(), 0); | ||
| ASSERT_TRUE(d.HasMember("choices")); | ||
| ASSERT_TRUE(d["choices"].IsArray()); | ||
| ASSERT_EQ(d["choices"].Size(), 1); | ||
| const auto& choice = d["choices"][0]; | ||
| ASSERT_EQ(choice["index"].GetInt(), 0); | ||
| ASSERT_TRUE(choice["finish_reason"].IsNull()); | ||
| ASSERT_TRUE(choice["delta"].IsObject()); | ||
| EXPECT_STREQ(choice["delta"]["role"].GetString(), "assistant"); | ||
| ASSERT_TRUE(choice["delta"]["content"].IsNull()); | ||
| ASSERT_TRUE(d.HasMember("created")); | ||
| ASSERT_TRUE(d["created"].IsInt()); | ||
| EXPECT_STREQ(d["model"].GetString(), expectedModel.c_str()); | ||
| EXPECT_STREQ(d["object"].GetString(), "chat.completion.chunk"); | ||
| } | ||
|
|
||
| class LLMFlowHttpTestParameterized : public LLMFlowHttpTest, public ::testing::WithParamInterface<TestParameters> {}; | ||
|
|
||
| TEST_P(LLMFlowHttpTestParameterized, unaryCompletionsJson) { | ||
|
|
@@ -1676,7 +1703,14 @@ TEST_P(LLMFlowHttpTestParameterized, inferChatCompletionsStream) { | |
| ] | ||
| } | ||
| )"; | ||
| ON_CALL(*writer, PartialReply).WillByDefault([this, ¶ms](std::string response) { | ||
| int replyCounter = 0; | ||
| ON_CALL(*writer, PartialReply).WillByDefault([this, ¶ms, &replyCounter](std::string response) { | ||
| if (replyCounter == 0 && params.supportsEmptyControlMsg) { | ||
| replyCounter++; | ||
| assertInitialStreamChatCompletionChunk(response, params.modelName); | ||
| return; | ||
| } | ||
| replyCounter++; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. not needed? |
||
| rapidjson::Document d; | ||
| std::string dataPrefix = "data:"; | ||
| ASSERT_STREQ(response.substr(0, dataPrefix.size()).c_str(), dataPrefix.c_str()); | ||
|
|
@@ -1829,8 +1863,16 @@ TEST_P(LLMFlowHttpTestParameterized, streamChatCompletionsSingleStopString) { | |
| ovms::StatusCode::PARTIAL_END); | ||
| SPDLOG_TRACE("After dispatch"); | ||
|
|
||
| // Check if there is at least one response | ||
| ASSERT_GT(responses.size(), 0); | ||
| if (params.supportsEmptyControlMsg) { | ||
| // Check if there is more than 1 partial response - initial and at least one real response with stop string | ||
| ASSERT_GT(responses.size(), 1); | ||
|
|
||
| // Assert initial message with empty content | ||
| assertInitialStreamChatCompletionChunk(responses[0], params.modelName); | ||
| } else { | ||
| // For legacy there is no initial empty message | ||
| ASSERT_GT(responses.size(), 0); | ||
| } | ||
|
|
||
| if (params.checkFinishReason) { | ||
| ASSERT_TRUE(responses.back().find("\"finish_reason\":\"stop\"") != std::string::npos); | ||
|
|
@@ -1845,7 +1887,7 @@ TEST_P(LLMFlowHttpTestParameterized, streamChatCompletionsSingleStopString) { | |
| // or simply any token (or group of tokens) that has dot in a middle. | ||
|
|
||
| // Check for no existence of a dot: | ||
| for (size_t i = 0; i < responses.size() - numberOfLastResponsesToCheckForStopString; ++i) { | ||
| for (size_t i = params.supportsEmptyControlMsg ? 1 : 0; i < responses.size() - numberOfLastResponsesToCheckForStopString; ++i) { | ||
| // Assert there is no dot '.' in the response | ||
|
|
||
| // Cut "data: " prefix | ||
|
|
@@ -2554,11 +2596,11 @@ INSTANTIATE_TEST_SUITE_P( | |
| LLMFlowHttpTestInstances, | ||
| LLMFlowHttpTestParameterized, | ||
| ::testing::Values( | ||
| // params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding | ||
| TestParameters{"lm_cb_regular", true, true, true, false}, | ||
| TestParameters{"lm_legacy_regular", false, false, false, false}, | ||
| TestParameters{"vlm_cb_regular", false, true, true, false}, | ||
| TestParameters{"vlm_legacy_regular", false, false, false, false})); | ||
| // params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding, supports empty control msg | ||
| TestParameters{"lm_cb_regular", true, true, true, false, true}, | ||
| TestParameters{"lm_legacy_regular", false, false, false, false, false}, | ||
| TestParameters{"vlm_cb_regular", false, true, true, false, true}, | ||
| TestParameters{"vlm_legacy_regular", false, false, false, false, false})); | ||
|
|
||
| const std::string validRequestBodyWithParameter(const std::string& modelName, const std::string& parameter, const std::string& value) { | ||
| std::string requestBody = R"( | ||
|
|
@@ -3367,11 +3409,11 @@ INSTANTIATE_TEST_SUITE_P( | |
| LLMHttpParametersValidationTestInstances, | ||
| LLMHttpParametersValidationTest, | ||
| ::testing::Values( | ||
| // params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding | ||
| TestParameters{"lm_cb_regular", true, true, true, false}, | ||
| TestParameters{"lm_legacy_regular", false, false, false, false}, | ||
| TestParameters{"vlm_cb_regular", false, true, true, false}, | ||
| TestParameters{"vlm_legacy_regular", false, false, false, false})); | ||
| // params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding, supports empty control msg | ||
| TestParameters{"lm_cb_regular", true, true, true, false, true}, | ||
| TestParameters{"lm_legacy_regular", false, false, false, false, false}, | ||
| TestParameters{"vlm_cb_regular", false, true, true, false, true}, | ||
| TestParameters{"vlm_legacy_regular", false, false, false, false, false})); | ||
|
|
||
| // Common tests for all pipeline types (testing logic executed prior pipeline type selection) | ||
| class LLMConfigHttpTest : public ::testing::Test {}; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand the context of
`FirstTokenControlChunk` — what is the "control" aspect here?