diff --git a/docs/model_server_rest_api_chat.md b/docs/model_server_rest_api_chat.md index 6ed3c48adb..8c4e1e5fdf 100644 --- a/docs/model_server_rest_api_chat.md +++ b/docs/model_server_rest_api_chat.md @@ -9,8 +9,11 @@ The endpoint is exposed via a path: http://server_name:port/v3/chat/completions -### Example request +::::{tab-set} +:::{tab-item} Unary +:sync: unary +**Request:** ``` curl http://localhost/v3/chat/completions \ -H "Content-Type: application/json" \ @@ -26,12 +29,11 @@ curl http://localhost/v3/chat/completions \ "content": "hello" } ], - stream: false + "stream": false }' ``` -### Example response - +**Response:** ```json { "choices": [ @@ -55,6 +57,52 @@ curl http://localhost/v3/chat/completions \ } } ``` +::: + +:::{tab-item} Stream +:sync: stream + +**Request:** +``` +curl http://localhost/v3/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "llama3", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "hello" + } + ], + "stream": true + }' +``` + +**Response:** + +- handshake +- reasoning +- actual content +- end of stream + +``` +data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null},"finish_reason":null}],"created":1772634283,"model":"llama3","object":"chat.completion.chunk"} + +data: {"choices":[{"index":0,"logprobs":null,"delta":{"reasoning_content":"Reasoning..."},"finish_reason":null}],"created":1772634283,"model":"llama3","object":"chat.completion.chunk"} + +data: {"choices":[{"index":0,"logprobs":null,"delta":{"content":"Hello!"},"finish_reason":null}],"created":1772634283,"model":"llama3","object":"chat.completion.chunk"} + +data: [DONE] +``` + +**Note**: First chunk contains role and content=`null` indicating first token has been generated. It is good indication for Time to First Token metric. Last chunk contains content with full message and `data: [DONE]` indicating end of generation. 
+::: +:::: + In case of VLM models, the request can include the images in three different formats: 1) Base64 encoding: @@ -242,7 +290,7 @@ If any of those parameters is not specified and request is made to Prompt Lookup | choices.index | ✅ | ✅ | integer | The index of the choice in the list of choices. | | choices.message | ✅ | ✅ | object | A chat completion message generated by the model. **When streaming, the field name is `delta` instead of `message`.** | | choices.message.role | ⚠️ | ✅ | string | The role of the author of this message. **_Currently hardcoded as `assistant`_** | -| choices.message.content | ✅ | ✅ | string | The contents of the message. | +| choices.message.content | ✅ | ✅ | string or null | The contents of the message | | choices.message.reasoning_content | ✅ | ❌ | string | If model supports reasoning and is deployed with appropriate response parser, the reasoning part of the output is stored in the field. | | choices.message.tool_calls | ✅ | ✅ | array | The tool calls generated by the model, such as function calls. | | choices.finish_reason | ✅ | ✅ | string or null | The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `tool_calls` if stopped due to a tool call, or `null` when generation continues (streaming). 
| diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp index 0402017564..6898b51604 100644 --- a/src/llm/apis/openai_completions.cpp +++ b/src/llm/apis/openai_completions.cpp @@ -1373,4 +1373,52 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingUsageChunk() { writer.EndObject(); // } return buffer.GetString(); } + +std::string OpenAIChatCompletionsHandler::serializeStreamingHandshakeChunk() { + OVMS_PROFILE_FUNCTION(); + Document doc; + doc.SetObject(); + Document::AllocatorType& allocator = doc.GetAllocator(); + + Value choices(kArrayType); + Value choice(kObjectType); + + // choices: array of size N, where N is related to n request parameter + choices.SetArray(); + choice.SetObject(); + + choice.AddMember("index", 0, allocator); + if (endpoint == Endpoint::CHAT_COMPLETIONS) { + Value delta(kObjectType); + delta.SetObject(); + delta.AddMember("role", Value("assistant", allocator), allocator); + delta.AddMember("content", Value(rapidjson::kNullType), allocator); + choice.AddMember("delta", delta, allocator); + } else if (endpoint == Endpoint::COMPLETIONS) { + choice.AddMember("text", Value(rapidjson::kNullType), allocator); + } + + choice.AddMember("finish_reason", Value(rapidjson::kNullType), allocator); + choices.PushBack(choice, allocator); + + doc.AddMember("choices", choices, allocator); + + // created: integer; Unix timestamp (in seconds) when the MP graph was created. 
+ doc.AddMember("created", std::chrono::duration_cast(created.time_since_epoch()).count(), allocator); + + // model: string; copied from the request + doc.AddMember("model", Value(request.model.c_str(), allocator), allocator); + + // object: string; defined that the type streamed chunk rather than complete response + if (endpoint == Endpoint::CHAT_COMPLETIONS) { + doc.AddMember("object", Value("chat.completion.chunk", allocator), allocator); + } else if (endpoint == Endpoint::COMPLETIONS) { + doc.AddMember("object", Value("text_completion.chunk", allocator), allocator); + } + + StringBuffer buffer; + Writer writer(buffer); + doc.Accept(writer); + return buffer.GetString(); +} } // namespace ovms diff --git a/src/llm/apis/openai_completions.hpp b/src/llm/apis/openai_completions.hpp index 0b513fd528..516133f03a 100644 --- a/src/llm/apis/openai_completions.hpp +++ b/src/llm/apis/openai_completions.hpp @@ -127,5 +127,6 @@ class OpenAIChatCompletionsHandler { std::string serializeUnaryResponse(ov::genai::VLMDecodedResults& results); std::string serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason); std::string serializeStreamingUsageChunk(); + std::string serializeStreamingHandshakeChunk(); }; } // namespace ovms diff --git a/src/llm/servable.cpp b/src/llm/servable.cpp index 75480efe37..6d9810ae5f 100644 --- a/src/llm/servable.cpp +++ b/src/llm/servable.cpp @@ -256,6 +256,12 @@ absl::Status GenAiServable::preparePartialResponse(std::shared_ptrlastStreamerCallbackOutput = ""; std::string lastTextChunk = ss.str(); + + bool isFirstToken = GenerationPhase::INPUT_TOKEN_PROCESSING == executionContext->generationPhase; + if (isFirstToken) { + executionContext->generationPhase = GenerationPhase::OUTPUT_TOKEN_PROCESSING; + } + ov::genai::GenerationFinishReason finishReason = generationOutput.finish_reason; if (finishReason == ov::genai::GenerationFinishReason::NONE) { // continue if (lastTextChunk.size() > 0) { @@ -264,6 +270,9 
@@ absl::Status GenAiServable::preparePartialResponse(std::shared_ptrresponse = wrapTextInServerSideEventMessage(serializedChunk); SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Generated subsequent streaming response: {}", executionContext->response); } + } else if (isFirstToken) { + std::string serializedChunk = executionContext->apiHandler->serializeStreamingHandshakeChunk(); + executionContext->response = wrapTextInServerSideEventMessage(serializedChunk); } executionContext->sendLoopbackSignal = true; } else { // finish generation diff --git a/src/llm/servable.hpp b/src/llm/servable.hpp index 83fa4eee5d..e4a5dd5ee2 100644 --- a/src/llm/servable.hpp +++ b/src/llm/servable.hpp @@ -59,6 +59,11 @@ Instance of this class is created for each request and is passed through multipl Note that GenAiServableExecutionContext pointer is the only parameter most of the GenAiServable methods take. */ +enum class GenerationPhase { + INPUT_TOKEN_PROCESSING, + OUTPUT_TOKEN_PROCESSING, +}; + struct GenAiServableExecutionContext { // Common API related members ovms::HttpPayload payload; @@ -74,6 +79,7 @@ struct GenAiServableExecutionContext { std::shared_ptr textStreamer; bool sendLoopbackSignal = false; std::string lastStreamerCallbackOutput; + GenerationPhase generationPhase = GenerationPhase::INPUT_TOKEN_PROCESSING; }; struct ExtraGenerationInfo { diff --git a/src/test/llm/llmnode_test.cpp b/src/test/llm/llmnode_test.cpp index 2e52e4fa59..48ec80d112 100644 --- a/src/test/llm/llmnode_test.cpp +++ b/src/test/llm/llmnode_test.cpp @@ -61,6 +61,7 @@ struct TestParameters { bool checkLogprobs; bool checkFinishReason; bool testSpeculativeDecoding; + bool checkHandshakeChunk; }; class LLMFlowHttpTest : public ::testing::Test { @@ -193,6 +194,32 @@ TEST(OpenAiApiHandlerTest, writeLogprobs) { } */ +// Reusable helper: asserts that a streaming chat completion chunk is the +// initial empty message with role:assistant and content:null. 
+inline void assertInitialStreamChatCompletionChunk(const std::string& response, const std::string& expectedModel) { + const std::string dataPrefix = "data:"; + ASSERT_GE(response.size(), dataPrefix.size()); + ASSERT_EQ(response.substr(0, dataPrefix.size()), dataPrefix); + size_t pos = response.find("\n"); + ASSERT_NE(pos, std::string::npos); + rapidjson::Document d; + rapidjson::ParseResult ok = d.Parse(response.substr(dataPrefix.size(), pos - dataPrefix.size()).c_str()); + ASSERT_EQ(ok.Code(), 0); + ASSERT_TRUE(d.HasMember("choices")); + ASSERT_TRUE(d["choices"].IsArray()); + ASSERT_EQ(d["choices"].Size(), 1); + const auto& choice = d["choices"][0]; + ASSERT_EQ(choice["index"].GetInt(), 0); + ASSERT_TRUE(choice["finish_reason"].IsNull()); + ASSERT_TRUE(choice["delta"].IsObject()); + EXPECT_STREQ(choice["delta"]["role"].GetString(), "assistant"); + ASSERT_TRUE(choice["delta"]["content"].IsNull()); + ASSERT_TRUE(d.HasMember("created")); + ASSERT_TRUE(d["created"].IsInt()); + EXPECT_STREQ(d["model"].GetString(), expectedModel.c_str()); + EXPECT_STREQ(d["object"].GetString(), "chat.completion.chunk"); +} + class LLMFlowHttpTestParameterized : public LLMFlowHttpTest, public ::testing::WithParamInterface {}; TEST_P(LLMFlowHttpTestParameterized, unaryCompletionsJson) { @@ -1676,7 +1703,13 @@ TEST_P(LLMFlowHttpTestParameterized, inferChatCompletionsStream) { ] } )"; - ON_CALL(*writer, PartialReply).WillByDefault([this, ¶ms](std::string response) { + int replyCounter = 0; + ON_CALL(*writer, PartialReply).WillByDefault([this, ¶ms, &replyCounter](std::string response) { + if (replyCounter == 0 && params.checkHandshakeChunk) { + replyCounter++; + assertInitialStreamChatCompletionChunk(response, params.modelName); + return; + } rapidjson::Document d; std::string dataPrefix = "data:"; ASSERT_STREQ(response.substr(0, dataPrefix.size()).c_str(), dataPrefix.c_str()); @@ -1829,8 +1862,16 @@ TEST_P(LLMFlowHttpTestParameterized, streamChatCompletionsSingleStopString) { 
ovms::StatusCode::PARTIAL_END); SPDLOG_TRACE("After dispatch"); - // Check if there is at least one response - ASSERT_GT(responses.size(), 0); + if (params.checkHandshakeChunk) { + // Check if there is more than 1 partial response - initial and at least one real response with stop string + ASSERT_GT(responses.size(), 1); + + // Assert initial message with empty content + assertInitialStreamChatCompletionChunk(responses[0], params.modelName); + } else { + // For legacy there is no initial empty message + ASSERT_GT(responses.size(), 0); + } if (params.checkFinishReason) { ASSERT_TRUE(responses.back().find("\"finish_reason\":\"stop\"") != std::string::npos); @@ -1845,7 +1886,7 @@ TEST_P(LLMFlowHttpTestParameterized, streamChatCompletionsSingleStopString) { // or simply any token (or group of tokens) that has dot in a middle. // Check for no existence of a dot: - for (size_t i = 0; i < responses.size() - numberOfLastResponsesToCheckForStopString; ++i) { + for (size_t i = params.checkHandshakeChunk ? 1 : 0; i < responses.size() - numberOfLastResponsesToCheckForStopString; ++i) { // Assert there is no dot '.' 
in the response + // Cut "data: " prefix @@ -2554,11 +2595,11 @@ INSTANTIATE_TEST_SUITE_P( LLMFlowHttpTestInstances, LLMFlowHttpTestParameterized, ::testing::Values( - // params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding - TestParameters{"lm_cb_regular", true, true, true, false}, - TestParameters{"lm_legacy_regular", false, false, false, false}, - TestParameters{"vlm_cb_regular", false, true, true, false}, - TestParameters{"vlm_legacy_regular", false, false, false, false})); + // params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding, check handshake chunk + TestParameters{"lm_cb_regular", true, true, true, false, true}, + TestParameters{"lm_legacy_regular", false, false, false, false, false}, + TestParameters{"vlm_cb_regular", false, true, true, false, true}, + TestParameters{"vlm_legacy_regular", false, false, false, false, false})); const std::string validRequestBodyWithParameter(const std::string& modelName, const std::string& parameter, const std::string& value) { std::string requestBody = R"( @@ -3367,11 +3408,11 @@ INSTANTIATE_TEST_SUITE_P( LLMHttpParametersValidationTestInstances, LLMHttpParametersValidationTest, ::testing::Values( - // params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding - TestParameters{"lm_cb_regular", true, true, true, false}, - TestParameters{"lm_legacy_regular", false, false, false, false}, - TestParameters{"vlm_cb_regular", false, true, true, false}, - TestParameters{"vlm_legacy_regular", false, false, false, false})); + // params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding, check handshake chunk + TestParameters{"lm_cb_regular", true, true, true, false, true}, + TestParameters{"lm_legacy_regular", false, false, false, false, false}, + TestParameters{"vlm_cb_regular", false, true, true, false, true}, + 
TestParameters{"vlm_legacy_regular", false, false, false, false, false})); // Common tests for all pipeline types (testing logic executed prior pipeline type selection) class LLMConfigHttpTest : public ::testing::Test {};