Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 36 additions & 27 deletions core/src/agents/processors/content_processor_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -206,17 +206,19 @@ function convertForeignEvent(event: Event): Event {
text: `[${event.author}] said: ${part.text}`,
});
} else if (part.functionCall) {
const argsText = safeStringify(part.functionCall.args);
content.parts?.push({
text: `[${event.author}] called tool \`${part.functionCall.name}\` with parameters: ${argsText}`,
text: `[${event.author}] called tool \`${part.functionCall.name}\` with parameters: ${safeStringify(
part.functionCall.args,
)}`,
});
} else if (part.functionResponse) {
const responseText = safeStringify(part.functionResponse.response);
content.parts?.push({
text: `[${event.author}] tool \`${part.functionResponse.name}\` returned result: ${responseText}`,
text: `[${event.author}] tool \`${part.functionResponse.name}\` returned result: ${safeStringify(
part.functionResponse.response,
)}`,
});
} else {
content.parts?.push(part);
content.parts?.push(cloneDeep(part));
}
}

Expand Down Expand Up @@ -254,15 +256,17 @@ function convertForeignEvent(event: Event): Event {
* 2. All non-function_response parts will be appended to the part list of
* the initial function_response event.
*/
function mergeFunctionResponseEvents(events: Event[]): Event {
export function mergeFunctionResponseEvents(events: Event[]): Event {
if (events.length === 0) {
throw new Error('Cannot merge an empty list of events.');
}

const mergedEvent = createEvent(events[0]);
const partsInMergedEvent = mergedEvent.content?.parts || [];

if (partsInMergedEvent.length === 0) {
const mergedEvent = createEvent({
...events[0],
content: events[0].content ? cloneDeep(events[0].content) : undefined,
});
const partsInMergedEvent = mergedEvent.content?.parts;
if (!partsInMergedEvent || partsInMergedEvent.length === 0) {
throw new Error('There should be at least one function_response part.');
}

Expand All @@ -279,17 +283,19 @@ function mergeFunctionResponseEvents(events: Event[]): Event {
throw new Error('There should be at least one function_response part.');
}
for (const part of event.content.parts) {
if (part.functionResponse && part.functionResponse.id) {
const functionCallId = part.functionResponse.id;
const clonedPart = cloneDeep(part);
if (clonedPart.functionResponse && clonedPart.functionResponse.id) {
const functionCallId = clonedPart.functionResponse.id;
if (functionCallId in partIndicesInMergedEvent) {
partsInMergedEvent[partIndicesInMergedEvent[functionCallId]] = part;
partsInMergedEvent[partIndicesInMergedEvent[functionCallId]] =
clonedPart;
} else {
partsInMergedEvent.push(part);
partsInMergedEvent.push(clonedPart);
partIndicesInMergedEvent[functionCallId] =
partsInMergedEvent.length - 1;
}
} else {
partsInMergedEvent.push(part);
partsInMergedEvent.push(clonedPart);
}
}
}
Expand All @@ -313,7 +319,7 @@ function rearrangeEventsForLatestFunctionResponse(events: Event[]): Event[] {
return events;
}

let functionResponsesIds = new Set<string>(
const functionResponsesIds = new Set<string>(
functionResponses
.filter((response): response is {id: string} => !!response.id)
.map((response) => response.id),
Expand All @@ -334,17 +340,18 @@ function rearrangeEventsForLatestFunctionResponse(events: Event[]): Event[] {
}

// Look for corresponding function call event reversely.
let functionCallEventIdx = -1;
let match: {eventIdx: number; responseIds: Set<string>} | undefined;

for (let idx = events.length - 2; idx >= 0; idx--) {
const event = events[idx];
const functionCalls = getFunctionCalls(event);
if (!functionCalls?.length) {
continue;
}

let matchedInEvent = false;
for (const functionCall of functionCalls) {
if (functionCall.id && functionResponsesIds.has(functionCall.id)) {
functionCallEventIdx = idx;
const functionCallIds = new Set<string>(
functionCalls.map((fc) => fc.id).filter((id): id is string => !!id),
);
Expand All @@ -364,16 +371,17 @@ function rearrangeEventsForLatestFunctionResponse(events: Event[]): Event[] {
` ids provided: ${Array.from(functionResponsesIds).join(', ')}`,
);
}
// Expand the function call events to collect all function responses
// from the function call event to the last response event.
// TODO - b/425992518: bad practice, state can mutated multiple times.
functionResponsesIds = functionCallIds;
match = {eventIdx: idx, responseIds: functionCallIds};
matchedInEvent = true;
break;
}
}
if (matchedInEvent) {
break;
}
}

if (functionCallEventIdx === -1) {
if (!match) {
throw new Error(
`No function call event found for function responses ids: ${Array.from(
functionResponsesIds,
Expand All @@ -384,21 +392,22 @@ function rearrangeEventsForLatestFunctionResponse(events: Event[]): Event[] {
// Collect all function response events between the function call event
// and the last function response event
const functionResponseEvents: Event[] = [];
for (let idx = functionCallEventIdx + 1; idx < events.length - 1; idx++) {
const activeResponses = match.responseIds;
for (let idx = match.eventIdx + 1; idx < events.length - 1; idx++) {
const event = events[idx];
const responses = getFunctionResponses(event);
if (
responses &&
responses.some(
(response) => response.id && functionResponsesIds.has(response.id),
(response) => response.id && activeResponses.has(response.id),
)
) {
functionResponseEvents.push(event);
}
}
functionResponseEvents.push(events[events.length - 1]);

const resultEvents = events.slice(0, functionCallEventIdx + 1);
const resultEvents = events.slice(0, match.eventIdx + 1);
resultEvents.push(mergeFunctionResponseEvents(functionResponseEvents));

return resultEvents;
Expand Down Expand Up @@ -468,7 +477,7 @@ function rearrangeEventsForAsyncFunctionResponsesInHistory(
}

if (functionResponseEventsIndices.size === 1) {
const [responseIndex] = [...functionResponseEventsIndices];
const [responseIndex] = Array.from(functionResponseEventsIndices);
resultEvents.push(events[responseIndex]);
} else {
const indicesArray = Array.from(functionResponseEventsIndices).sort(
Expand Down
Loading
Loading