Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions core/src/plugins/base_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ export enum ContextCompactionTrigger {
Auto = 'Auto',
}

/**
* Optional control flag a plugin callback may merge into its return value to
* decide whether to short-circuit the callback chain. Omitting the flag, or
* setting it to `true`, keeps the default behavior: the returned value
* short-circuits the remaining plugins and is propagated up. Setting it to
* `false` lets the chain continue, forwarding the rest of the payload to the
* next plugin.
*/
type ShortCircuitControl = {
shortCircuit?: boolean;
};

/**
* Base class for creating plugins.
*
Expand Down Expand Up @@ -135,7 +147,7 @@ export abstract class BasePlugin {
async onUserMessageCallback(params: {
invocationContext: InvocationContext;
userMessage: Content;
}): Promise<Content | undefined> {
}): Promise<(Content & ShortCircuitControl) | undefined> {
return;
}

Expand All @@ -154,7 +166,7 @@ export abstract class BasePlugin {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
async beforeRunCallback(params: {
invocationContext: InvocationContext;
}): Promise<Content | undefined> {
}): Promise<(Content & ShortCircuitControl) | undefined> {
return;
}

Expand All @@ -174,7 +186,7 @@ export abstract class BasePlugin {
async onEventCallback(params: {
invocationContext: InvocationContext;
event: Event;
}): Promise<Event | undefined> {
}): Promise<(Event & ShortCircuitControl) | undefined> {
return;
}

Expand Down Expand Up @@ -210,7 +222,7 @@ export abstract class BasePlugin {
async beforeAgentCallback(params: {
agent: BaseAgent;
callbackContext: Context;
}): Promise<Content | undefined> {
}): Promise<(Content & ShortCircuitControl) | undefined> {
return;
}

Expand All @@ -230,7 +242,7 @@ export abstract class BasePlugin {
async afterAgentCallback(params: {
agent: BaseAgent;
callbackContext: Context;
}): Promise<Content | undefined> {
}): Promise<(Content & ShortCircuitControl) | undefined> {
return;
}

Expand All @@ -251,7 +263,7 @@ export abstract class BasePlugin {
async beforeModelCallback(params: {
callbackContext: Context;
llmRequest: LlmRequest;
}): Promise<LlmResponse | undefined> {
}): Promise<(LlmResponse & ShortCircuitControl) | undefined> {
return;
}

Expand All @@ -271,7 +283,7 @@ export abstract class BasePlugin {
async afterModelCallback(params: {
callbackContext: Context;
llmResponse: LlmResponse;
}): Promise<LlmResponse | undefined> {
}): Promise<(LlmResponse & ShortCircuitControl) | undefined> {
return;
}

Expand All @@ -294,7 +306,7 @@ export abstract class BasePlugin {
callbackContext: Context;
llmRequest: LlmRequest;
error: Error;
}): Promise<LlmResponse | undefined> {
}): Promise<(LlmResponse & ShortCircuitControl) | undefined> {
return;
}

Expand All @@ -315,7 +327,7 @@ export abstract class BasePlugin {
async beforeToolSelection(params: {
callbackContext: Context;
tools: Readonly<Record<string, BaseTool>>;
}): Promise<Readonly<Record<string, BaseTool>> | undefined> {
}): Promise<(Readonly<Record<string, BaseTool>> & ShortCircuitControl) | undefined> {
return;
}

Expand Down Expand Up @@ -399,7 +411,7 @@ export abstract class BasePlugin {
toolArgs: Record<string, unknown>;
toolContext: Context;
result: Record<string, unknown>;
}): Promise<Record<string, unknown> | undefined> {
}): Promise<(Record<string, unknown> & ShortCircuitControl) | undefined> {
return;
}

Expand Down
90 changes: 71 additions & 19 deletions core/src/plugins/plugin_manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ import {BasePlugin, ContextCompactionTrigger} from './base_plugin.js';
* for that specific event is halted, and the returned value is propagated up
* the call stack. This allows plugins to short-circuit operations like agent
* runs, tool calls, or model requests.
*
* A plugin can opt out of the early exit by returning a payload that includes
* a literal `shortCircuit: false` field; in that case the remaining fields
* (with `shortCircuit` stripped) are forwarded to the next plugin as the new
* value of the transformed input field, and execution continues. Returning a
* payload with `shortCircuit: true` (or omitting the field entirely) preserves
* the legacy early-exit behavior.
*/
export class PluginManager {
private readonly plugins: Set<BasePlugin> = new Set();
Expand Down Expand Up @@ -83,35 +90,63 @@ export class PluginManager {
* Runs the same callback for all plugins. This is a utility method to reduce
* duplication below.
*
* Plugins receive the latest transformed value (via `latest`); the caller is
* responsible for splicing it into the right field of the plugin params. If
* no prior plugin produced a value, `latest` is `undefined` and the caller
* should fall back to the original input.
*
* A plugin return value of `undefined` continues the chain unchanged. A
* return value with `shortCircuit: false` strips the flag and forwards the
* remaining fields to subsequent plugins. Any other return value short-
* circuits the chain; a `shortCircuit` flag, if present, is stripped before
* the value is returned.
*
* @param plugins The set of plugins to run
* @param callback A closure containing the callback method to run on each
* plugin
* plugin. Receives the latest transformed value, or `undefined`.
* @param callbackName The name of the function being called in the closure
* above. Used for logging purposes.
* @returns A promise containing the plugin method result. Must be casted to
* the proper type for the plugin method.
*/
private async runCallbacks(
plugins: Set<BasePlugin>,
callback: (plugin: BasePlugin) => Promise<unknown>,
callback: (plugin: BasePlugin, latest: unknown) => Promise<unknown>,
callbackName: string,
): Promise<unknown> {
let current: unknown = undefined;
for (const plugin of plugins) {
try {
const result = await callback(plugin);
if (result !== undefined) {
logger.debug(
`Plugin '${plugin.name}' returned a value for callback '${callbackName}', exiting early.`,
);
return result;
const result = await callback(plugin, current);
if (result === undefined) continue;

let value: unknown = result;
if (
typeof result === 'object' &&
result !== null &&
'shortCircuit' in result
) {
const {shortCircuit = true, ...data} = result as Record<
string,
unknown
>;
value = data;
if (shortCircuit === false) {
current = data;
continue;
}
}
logger.debug(
`Plugin '${plugin.name}' returned a value for callback '${callbackName}', exiting early.`,
);
return value;
} catch (e) {
const errorMessage = `Error in plugin '${plugin.name}' during '${callbackName}' callback: ${e}`;
logger.error(errorMessage);
throw new Error(errorMessage);
}
}
return undefined;
return current;
}

/**
Expand All @@ -126,8 +161,11 @@ export class PluginManager {
}): Promise<Content | undefined> {
return (await this.runCallbacks(
this.plugins,
(plugin: BasePlugin) =>
plugin.onUserMessageCallback({userMessage, invocationContext}),
(plugin: BasePlugin, latest: unknown) =>
plugin.onUserMessageCallback({
userMessage: (latest as Content) ?? userMessage,
invocationContext,
}),
'onUserMessageCallback',
)) as Content | undefined;
}
Expand Down Expand Up @@ -174,8 +212,11 @@ export class PluginManager {
}): Promise<Event | undefined> {
return (await this.runCallbacks(
this.plugins,
(plugin: BasePlugin) =>
plugin.onEventCallback({invocationContext, event}),
(plugin: BasePlugin, latest: unknown) =>
plugin.onEventCallback({
invocationContext,
event: (latest as Event) ?? event,
}),
'onEventCallback',
)) as Event | undefined;
}
Expand Down Expand Up @@ -228,8 +269,11 @@ export class PluginManager {
}): Promise<Readonly<Record<string, BaseTool>> | undefined> {
return (await this.runCallbacks(
this.plugins,
(plugin: BasePlugin) =>
plugin.beforeToolSelection({callbackContext, tools}),
(plugin: BasePlugin, latest: unknown) =>
plugin.beforeToolSelection({
callbackContext,
tools: (latest as Readonly<Record<string, BaseTool>>) ?? tools,
}),
'beforeToolSelection',
)) as Readonly<Record<string, BaseTool>> | undefined;
}
Expand Down Expand Up @@ -306,8 +350,13 @@ export class PluginManager {
}): Promise<Record<string, unknown> | undefined> {
return (await this.runCallbacks(
this.plugins,
(plugin: BasePlugin) =>
plugin.afterToolCallback({tool, toolArgs, toolContext, result}),
(plugin: BasePlugin, latest: unknown) =>
plugin.afterToolCallback({
tool,
toolArgs,
toolContext,
result: (latest as Record<string, unknown>) ?? result,
}),
'afterToolCallback',
)) as Record<string, unknown> | undefined;
}
Expand Down Expand Up @@ -362,8 +411,11 @@ export class PluginManager {
}): Promise<LlmResponse | undefined> {
return (await this.runCallbacks(
this.plugins,
(plugin: BasePlugin) =>
plugin.afterModelCallback({callbackContext, llmResponse}),
(plugin: BasePlugin, latest: unknown) =>
plugin.afterModelCallback({
callbackContext,
llmResponse: (latest as LlmResponse) ?? llmResponse,
}),
'afterModelCallback',
)) as LlmResponse | undefined;
}
Expand Down
92 changes: 92 additions & 0 deletions core/test/plugins/plugin_manager_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,98 @@ describe('PluginManager', () => {
expect(plugin2.callLog).not.toContain('beforeRunCallback');
});

it('should continue when a plugin returns { shortCircuit: false, ... }', async () => {
const transformed = {message: 'translated'} as unknown as Record<
string,
unknown
>;
plugin1.returnValues['afterToolCallback'] = {
shortCircuit: false,
...transformed,
};
service.registerPlugin(plugin1);
service.registerPlugin(plugin2);

const result = await service.runAfterToolCallback({
tool: {} as BaseTool,
toolArgs: {},
toolContext: {} as Context,
result: {message: 'original'},
});

expect(plugin1.callLog).toContain('afterToolCallback');
expect(plugin2.callLog).toContain('afterToolCallback');
expect(result).toEqual(transformed);
expect((result as Record<string, unknown>).shortCircuit).toBeUndefined();
});

it('should pipe transformed value into the next plugin', async () => {
let plugin2Received: Record<string, unknown> | undefined;
plugin1.returnValues['afterToolCallback'] = {
shortCircuit: false,
message: 'translated',
};
plugin2 = new (class extends TestPlugin {
override async afterToolCallback(params: {
tool: BaseTool;
toolArgs: Record<string, unknown>;
toolContext: Context;
result: Record<string, unknown>;
}) {
plugin2Received = params.result;
return undefined;
}
})('plugin2');

service.registerPlugin(plugin1);
service.registerPlugin(plugin2);

await service.runAfterToolCallback({
tool: {} as BaseTool,
toolArgs: {},
toolContext: {} as Context,
result: {message: 'original'},
});

expect(plugin2Received).toEqual({message: 'translated'});
});

it('strips the shortCircuit flag from a short-circuiting return value', async () => {
plugin1.returnValues['afterToolCallback'] = {
shortCircuit: true,
message: 'replaced',
};
service.registerPlugin(plugin1);
service.registerPlugin(plugin2);

const result = await service.runAfterToolCallback({
tool: mockTool,
toolArgs: {},
toolContext: mockToolContext,
result: {message: 'original'},
});

expect(result).toEqual({message: 'replaced'});
expect((result as Record<string, unknown>).shortCircuit).toBeUndefined();
expect(plugin2.callLog).not.toContain('afterToolCallback');
});

it('short-circuits when shortCircuit is omitted from the return value', async () => {
plugin1.returnValues['afterToolCallback'] = {message: 'replaced'};
service.registerPlugin(plugin1);
service.registerPlugin(plugin2);

const result = await service.runAfterToolCallback({
tool: mockTool,
toolArgs: {},
toolContext: mockToolContext,
result: {message: 'original'},
});

expect(result).toEqual({message: 'replaced'});
expect(plugin2.callLog).not.toContain('afterToolCallback');
});

it('should call all plugins if no plugin returns a value', async () => {
service.registerPlugin(plugin1);
service.registerPlugin(plugin2);
Expand Down