Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions src/llm/apis/openai_completions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1373,4 +1373,52 @@ std::string OpenAIChatCompletionsHandler::serializeStreamingUsageChunk() {
writer.EndObject(); // }
return buffer.GetString();
}

std::string OpenAIChatCompletionsHandler::serializeStreamingFirstTokenControlChunk() {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand the context of FirstTokenControlChunk
what is the "control" aspect here?

OVMS_PROFILE_FUNCTION();
Document doc;
doc.SetObject();
Document::AllocatorType& allocator = doc.GetAllocator();

Value choices(kArrayType);
Value choice(kObjectType);

// choices: array of size N, where N is related to n request parameter
choices.SetArray();
choice.SetObject();

choice.AddMember("index", 0, allocator);
if (endpoint == Endpoint::CHAT_COMPLETIONS) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we could document this behavior maybe in the API reference, so it's clear that we send that empty response (and only for CB pipelines right?)

Value delta(kObjectType);
delta.SetObject();
delta.AddMember("role", Value("assistant", allocator), allocator);
delta.AddMember("content", Value(rapidjson::kNullType), allocator);
choice.AddMember("delta", delta, allocator);
} else if (endpoint == Endpoint::COMPLETIONS) {
choice.AddMember("text", Value(rapidjson::kNullType), allocator);
}

choice.AddMember("finish_reason", Value(rapidjson::kNullType), allocator);
choices.PushBack(choice, allocator);

doc.AddMember("choices", choices, allocator);

// created: integer; Unix timestamp (in seconds) when the MP graph was created.
doc.AddMember("created", std::chrono::duration_cast<std::chrono::seconds>(created.time_since_epoch()).count(), allocator);

// model: string; copied from the request
doc.AddMember("model", Value(request.model.c_str(), allocator), allocator);

// object: string; defined that the type streamed chunk rather than complete response
if (endpoint == Endpoint::CHAT_COMPLETIONS) {
doc.AddMember("object", Value("chat.completion.chunk", allocator), allocator);
} else if (endpoint == Endpoint::COMPLETIONS) {
doc.AddMember("object", Value("text_completion.chunk", allocator), allocator);
}

StringBuffer buffer;
Writer<StringBuffer> writer(buffer);
doc.Accept(writer);
return buffer.GetString();
}
} // namespace ovms
1 change: 1 addition & 0 deletions src/llm/apis/openai_completions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,5 +127,6 @@ class OpenAIChatCompletionsHandler {
std::string serializeUnaryResponse(ov::genai::VLMDecodedResults& results);
std::string serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason);
std::string serializeStreamingUsageChunk();
std::string serializeStreamingFirstTokenControlChunk();
};
} // namespace ovms
5 changes: 5 additions & 0 deletions src/llm/servable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ absl::Status GenAiServable::preparePartialResponse(std::shared_ptr<GenAiServable
executionContext->lastStreamerCallbackOutput = "";

std::string lastTextChunk = ss.str();
executionContext->loopIteration++;

ov::genai::GenerationFinishReason finishReason = generationOutput.finish_reason;
if (finishReason == ov::genai::GenerationFinishReason::NONE) { // continue
if (lastTextChunk.size() > 0) {
Expand All @@ -264,6 +266,9 @@ absl::Status GenAiServable::preparePartialResponse(std::shared_ptr<GenAiServable
executionContext->response = wrapTextInServerSideEventMessage(serializedChunk);
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Generated subsequent streaming response: {}", executionContext->response);
}
} else if (executionContext->loopIteration <= 1) {
std::string serializedChunk = executionContext->apiHandler->serializeStreamingFirstTokenControlChunk();
executionContext->response = wrapTextInServerSideEventMessage(serializedChunk);
}
executionContext->sendLoopbackSignal = true;
} else { // finish generation
Expand Down
1 change: 1 addition & 0 deletions src/llm/servable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct GenAiServableExecutionContext {
std::shared_ptr<ov::genai::TextStreamer> textStreamer;
bool sendLoopbackSignal = false;
std::string lastStreamerCallbackOutput;
size_t loopIteration = 0;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This name does not explain the purpose to me. Also, couldn't this be a bool like decodingPhase? Or even an enum like RequestProcessingPhase.prefill / RequestProcessingPhase.decode - starting with prefill and switching to decode after first read finishes.

};

struct ExtraGenerationInfo {
Expand Down
70 changes: 56 additions & 14 deletions src/test/llm/llmnode_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ struct TestParameters {
bool checkLogprobs;
bool checkFinishReason;
bool testSpeculativeDecoding;
bool supportsEmptyControlMsg;
};

class LLMFlowHttpTest : public ::testing::Test {
Expand Down Expand Up @@ -193,6 +194,32 @@ TEST(OpenAiApiHandlerTest, writeLogprobs) {
}
*/

// Reusable helper: asserts that a streaming chat completion chunk is the
// initial empty message with role:assistant and content:null.
// `response` is the raw SSE frame ("data:<json>\n..."); `expectedModel` is the
// model name the chunk must echo back.
inline void assertInitialStreamChatCompletionChunk(const std::string& response, const std::string& expectedModel) {
    // Strip the SSE "data:" framing and isolate the JSON payload (first line).
    const std::string dataPrefix = "data:";
    ASSERT_GE(response.size(), dataPrefix.size());
    ASSERT_EQ(response.substr(0, dataPrefix.size()), dataPrefix);
    size_t pos = response.find('\n');
    ASSERT_NE(pos, std::string::npos);
    rapidjson::Document d;
    rapidjson::ParseResult ok = d.Parse(response.substr(dataPrefix.size(), pos - dataPrefix.size()).c_str());
    ASSERT_EQ(ok.Code(), 0);
    // Guard every member access with HasMember: rapidjson operator[] on a
    // missing key is an assertion failure (UB in release), which would abort
    // the test binary instead of producing a clean test failure.
    ASSERT_TRUE(d.HasMember("choices"));
    ASSERT_TRUE(d["choices"].IsArray());
    ASSERT_EQ(d["choices"].Size(), 1);
    const auto& choice = d["choices"][0];
    ASSERT_TRUE(choice.HasMember("index"));
    ASSERT_EQ(choice["index"].GetInt(), 0);
    ASSERT_TRUE(choice.HasMember("finish_reason"));
    ASSERT_TRUE(choice["finish_reason"].IsNull());
    ASSERT_TRUE(choice.HasMember("delta"));
    ASSERT_TRUE(choice["delta"].IsObject());
    ASSERT_TRUE(choice["delta"].HasMember("role"));
    EXPECT_STREQ(choice["delta"]["role"].GetString(), "assistant");
    ASSERT_TRUE(choice["delta"].HasMember("content"));
    ASSERT_TRUE(choice["delta"]["content"].IsNull());
    ASSERT_TRUE(d.HasMember("created"));
    ASSERT_TRUE(d["created"].IsInt());
    ASSERT_TRUE(d.HasMember("model"));
    EXPECT_STREQ(d["model"].GetString(), expectedModel.c_str());
    ASSERT_TRUE(d.HasMember("object"));
    EXPECT_STREQ(d["object"].GetString(), "chat.completion.chunk");
}

class LLMFlowHttpTestParameterized : public LLMFlowHttpTest, public ::testing::WithParamInterface<TestParameters> {};

TEST_P(LLMFlowHttpTestParameterized, unaryCompletionsJson) {
Expand Down Expand Up @@ -1676,7 +1703,14 @@ TEST_P(LLMFlowHttpTestParameterized, inferChatCompletionsStream) {
]
}
)";
ON_CALL(*writer, PartialReply).WillByDefault([this, &params](std::string response) {
int replyCounter = 0;
ON_CALL(*writer, PartialReply).WillByDefault([this, &params, &replyCounter](std::string response) {
if (replyCounter == 0 && params.supportsEmptyControlMsg) {
replyCounter++;
assertInitialStreamChatCompletionChunk(response, params.modelName);
return;
}
replyCounter++;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not needed?

rapidjson::Document d;
std::string dataPrefix = "data:";
ASSERT_STREQ(response.substr(0, dataPrefix.size()).c_str(), dataPrefix.c_str());
Expand Down Expand Up @@ -1829,8 +1863,16 @@ TEST_P(LLMFlowHttpTestParameterized, streamChatCompletionsSingleStopString) {
ovms::StatusCode::PARTIAL_END);
SPDLOG_TRACE("After dispatch");

// Check if there is at least one response
ASSERT_GT(responses.size(), 0);
if (params.supportsEmptyControlMsg) {
// Check if there is more than 1 partial response - initial and at least one real response with stop string
ASSERT_GT(responses.size(), 1);

// Assert initial message with empty content
assertInitialStreamChatCompletionChunk(responses[0], params.modelName);
} else {
// For legacy there is no initial empty message
ASSERT_GT(responses.size(), 0);
}

if (params.checkFinishReason) {
ASSERT_TRUE(responses.back().find("\"finish_reason\":\"stop\"") != std::string::npos);
Expand All @@ -1845,7 +1887,7 @@ TEST_P(LLMFlowHttpTestParameterized, streamChatCompletionsSingleStopString) {
// or simply any token (or group of tokens) that has dot in a middle.

// Check for no existence of a dot:
for (size_t i = 0; i < responses.size() - numberOfLastResponsesToCheckForStopString; ++i) {
for (size_t i = params.supportsEmptyControlMsg ? 1 : 0; i < responses.size() - numberOfLastResponsesToCheckForStopString; ++i) {
// Assert there is no dot '.' in the response

// Cut "data: " prefix
Expand Down Expand Up @@ -2554,11 +2596,11 @@ INSTANTIATE_TEST_SUITE_P(
LLMFlowHttpTestInstances,
LLMFlowHttpTestParameterized,
::testing::Values(
// params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding
TestParameters{"lm_cb_regular", true, true, true, false},
TestParameters{"lm_legacy_regular", false, false, false, false},
TestParameters{"vlm_cb_regular", false, true, true, false},
TestParameters{"vlm_legacy_regular", false, false, false, false}));
// params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding, supports empty control msg
TestParameters{"lm_cb_regular", true, true, true, false, true},
TestParameters{"lm_legacy_regular", false, false, false, false, false},
TestParameters{"vlm_cb_regular", false, true, true, false, true},
TestParameters{"vlm_legacy_regular", false, false, false, false, false}));

const std::string validRequestBodyWithParameter(const std::string& modelName, const std::string& parameter, const std::string& value) {
std::string requestBody = R"(
Expand Down Expand Up @@ -3367,11 +3409,11 @@ INSTANTIATE_TEST_SUITE_P(
LLMHttpParametersValidationTestInstances,
LLMHttpParametersValidationTest,
::testing::Values(
// params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding
TestParameters{"lm_cb_regular", true, true, true, false},
TestParameters{"lm_legacy_regular", false, false, false, false},
TestParameters{"vlm_cb_regular", false, true, true, false},
TestParameters{"vlm_legacy_regular", false, false, false, false}));
// params: model name, generate expected output, check logprobs, check finish reason, test speculative decoding, supports empty control msg
TestParameters{"lm_cb_regular", true, true, true, false, true},
TestParameters{"lm_legacy_regular", false, false, false, false, false},
TestParameters{"vlm_cb_regular", false, true, true, false, true},
TestParameters{"vlm_legacy_regular", false, false, false, false, false}));

// Common tests for all pipeline types (testing logic executed prior pipeline type selection)
class LLMConfigHttpTest : public ::testing::Test {};
Expand Down