diff --git a/core/src/plugins/base_plugin.ts b/core/src/plugins/base_plugin.ts index 3dbea1fc..315f633e 100644 --- a/core/src/plugins/base_plugin.ts +++ b/core/src/plugins/base_plugin.ts @@ -23,6 +23,18 @@ export enum ContextCompactionTrigger { Auto = 'Auto', } +/** + * Optional control flag a plugin callback may merge into its return value to + * decide whether to short-circuit the callback chain. Omitting the flag, or + * setting it to `true`, keeps the default behavior: the returned value + * short-circuits the remaining plugins and is propagated up. Setting it to + * `false` lets the chain continue, forwarding the rest of the payload to the + * next plugin. + */ +type ShortCircuitControl = { + shortCircuit?: boolean; +}; + /** * Base class for creating plugins. * @@ -135,7 +147,7 @@ export abstract class BasePlugin { async onUserMessageCallback(params: { invocationContext: InvocationContext; userMessage: Content; - }): Promise { + }): Promise<(Content & ShortCircuitControl) | undefined> { return; } @@ -154,7 +166,7 @@ export abstract class BasePlugin { // eslint-disable-next-line @typescript-eslint/no-unused-vars async beforeRunCallback(params: { invocationContext: InvocationContext; - }): Promise { + }): Promise<(Content & ShortCircuitControl) | undefined> { return; } @@ -174,7 +186,7 @@ export abstract class BasePlugin { async onEventCallback(params: { invocationContext: InvocationContext; event: Event; - }): Promise { + }): Promise<(Event & ShortCircuitControl) | undefined> { return; } @@ -210,7 +222,7 @@ export abstract class BasePlugin { async beforeAgentCallback(params: { agent: BaseAgent; callbackContext: Context; - }): Promise { + }): Promise<(Content & ShortCircuitControl) | undefined> { return; } @@ -230,7 +242,7 @@ export abstract class BasePlugin { async afterAgentCallback(params: { agent: BaseAgent; callbackContext: Context; - }): Promise { + }): Promise<(Content & ShortCircuitControl) | undefined> { return; } @@ -251,7 +263,7 @@ export abstract class BasePlugin { async beforeModelCallback(params: { callbackContext: Context; llmRequest: LlmRequest; - }): Promise { + }): Promise<(LlmResponse & ShortCircuitControl) | undefined> { return; } @@ -271,7 +283,7 @@ export abstract class BasePlugin { async afterModelCallback(params: { callbackContext: Context; llmResponse: LlmResponse; - }): Promise { + }): Promise<(LlmResponse & ShortCircuitControl) | undefined> { return; } @@ -294,7 +306,7 @@ export abstract class BasePlugin { callbackContext: Context; llmRequest: LlmRequest; error: Error; - }): Promise { + }): Promise<(LlmResponse & ShortCircuitControl) | undefined> { return; } @@ -315,7 +327,7 @@ export abstract class BasePlugin { async beforeToolSelection(params: { callbackContext: Context; tools: Readonly>; - }): Promise> | undefined> { + }): Promise<(Readonly> & ShortCircuitControl) | undefined> { return; } @@ -399,7 +411,7 @@ export abstract class BasePlugin { toolArgs: Record; toolContext: Context; result: Record; - }): Promise | undefined> { + }): Promise<(Record & ShortCircuitControl) | undefined> { return; } diff --git a/core/src/plugins/plugin_manager.ts b/core/src/plugins/plugin_manager.ts index 7daee7a1..051d4b43 100644 --- a/core/src/plugins/plugin_manager.ts +++ b/core/src/plugins/plugin_manager.ts @@ -30,6 +30,13 @@ import {BasePlugin, ContextCompactionTrigger} from './base_plugin.js'; * for that specific event is halted, and the returned value is propagated up * the call stack. This allows plugins to short-circuit operations like agent * runs, tool calls, or model requests. + * + * A plugin can opt out of the early exit by returning a payload that includes + * a literal `shortCircuit: false` field; in that case the remaining fields + * (with `shortCircuit` stripped) are forwarded to the next plugin as the new + * value of the transformed input field, and execution continues. Returning a + * payload with `shortCircuit: true` (or omitting the field entirely) preserves + * the legacy early-exit behavior. */ export class PluginManager { private readonly plugins: Set = new Set(); @@ -83,9 +90,20 @@ export class PluginManager { * Runs the same callback for all plugins. This is a utility method to reduce * duplication below. * + * Plugins receive the latest transformed value (via `latest`); the caller is + * responsible for splicing it into the right field of the plugin params. If + * no prior plugin produced a value, `latest` is `undefined` and the caller + * should fall back to the original input. + * + * A plugin return value of `undefined` continues the chain unchanged. A + * return value with `shortCircuit: false` strips the flag and forwards the + * remaining fields to subsequent plugins. Any other return value short- + * circuits the chain; a `shortCircuit` flag, if present, is stripped before + * the value is returned. + * * @param plugins The set of plugins to run * @param callback A closure containing the callback method to run on each - * plugin + * plugin. Receives the latest transformed value, or `undefined`. * @param callbackName The name of the function being called in the closure * above. Used for logging purposes. * @returns A promise containing the plugin method result. Must be casted to @@ -93,25 +111,42 @@ export class PluginManager { */ private async runCallbacks( plugins: Set, - callback: (plugin: BasePlugin) => Promise, + callback: (plugin: BasePlugin, latest: unknown) => Promise, callbackName: string, ): Promise { + let current: unknown = undefined; for (const plugin of plugins) { try { - const result = await callback(plugin); - if (result !== undefined) { - logger.debug( - `Plugin '${plugin.name}' returned a value for callback '${callbackName}', exiting early.`, - ); - return result; + const result = await callback(plugin, current); + if (result === undefined) continue; + + let value: unknown = result; + if ( + typeof result === 'object' && + result !== null && + 'shortCircuit' in result + ) { + const {shortCircuit = true, ...data} = result as Record< + string, + unknown + >; + value = data; + if (shortCircuit === false) { + current = data; + continue; + } } + logger.debug( + `Plugin '${plugin.name}' returned a value for callback '${callbackName}', exiting early.`, + ); + return value; } catch (e) { const errorMessage = `Error in plugin '${plugin.name}' during '${callbackName}' callback: ${e}`; logger.error(errorMessage); throw new Error(errorMessage); } } - return undefined; + return current; } /** @@ -126,8 +161,11 @@ export class PluginManager { }): Promise { return (await this.runCallbacks( this.plugins, - (plugin: BasePlugin) => - plugin.onUserMessageCallback({userMessage, invocationContext}), + (plugin: BasePlugin, latest: unknown) => + plugin.onUserMessageCallback({ + userMessage: (latest as Content) ?? userMessage, + invocationContext, + }), 'onUserMessageCallback', )) as Content | undefined; } @@ -174,8 +212,11 @@ export class PluginManager { }): Promise { return (await this.runCallbacks( this.plugins, - (plugin: BasePlugin) => - plugin.onEventCallback({invocationContext, event}), + (plugin: BasePlugin, latest: unknown) => + plugin.onEventCallback({ + invocationContext, + event: (latest as Event) ?? event, + }), 'onEventCallback', )) as Event | undefined; } @@ -228,8 +269,11 @@ export class PluginManager { }): Promise> | undefined> { return (await this.runCallbacks( this.plugins, - (plugin: BasePlugin) => - plugin.beforeToolSelection({callbackContext, tools}), + (plugin: BasePlugin, latest: unknown) => + plugin.beforeToolSelection({ + callbackContext, + tools: (latest as Readonly>) ?? tools, + }), 'beforeToolSelection', )) as Readonly> | undefined; } @@ -306,8 +350,13 @@ export class PluginManager { }): Promise | undefined> { return (await this.runCallbacks( this.plugins, - (plugin: BasePlugin) => - plugin.afterToolCallback({tool, toolArgs, toolContext, result}), + (plugin: BasePlugin, latest: unknown) => + plugin.afterToolCallback({ + tool, + toolArgs, + toolContext, + result: (latest as Record) ?? result, + }), 'afterToolCallback', )) as Record | undefined; } @@ -362,8 +411,11 @@ export class PluginManager { }): Promise { return (await this.runCallbacks( this.plugins, - (plugin: BasePlugin) => - plugin.afterModelCallback({callbackContext, llmResponse}), + (plugin: BasePlugin, latest: unknown) => + plugin.afterModelCallback({ + callbackContext, + llmResponse: (latest as LlmResponse) ?? llmResponse, + }), 'afterModelCallback', )) as LlmResponse | undefined; } diff --git a/core/test/plugins/plugin_manager_test.ts b/core/test/plugins/plugin_manager_test.ts index a2e76a9d..4cce3485 100644 --- a/core/test/plugins/plugin_manager_test.ts +++ b/core/test/plugins/plugin_manager_test.ts @@ -219,6 +219,98 @@ describe('PluginManager', () => { expect(plugin2.callLog).not.toContain('beforeRunCallback'); }); + it('should continue when a plugin returns { shortCircuit: false, ... }', async () => { + const transformed = {message: 'translated'} as unknown as Record< + string, + unknown + >; + plugin1.returnValues['afterToolCallback'] = { + shortCircuit: false, + ...transformed, + }; + service.registerPlugin(plugin1); + service.registerPlugin(plugin2); + + const result = await service.runAfterToolCallback({ + tool: {} as BaseTool, + toolArgs: {}, + toolContext: {} as Context, + result: {message: 'original'}, + }); + + expect(plugin1.callLog).toContain('afterToolCallback'); + expect(plugin2.callLog).toContain('afterToolCallback'); + expect(result).toEqual(transformed); + expect((result as Record).shortCircuit).toBeUndefined(); + }); + + it('should pipe transformed value into the next plugin', async () => { + let plugin2Received: Record | undefined; + plugin1.returnValues['afterToolCallback'] = { + shortCircuit: false, + message: 'translated', + }; + plugin2 = new (class extends TestPlugin { + override async afterToolCallback(params: { + tool: BaseTool; + toolArgs: Record; + toolContext: Context; + result: Record; + }) { + plugin2Received = params.result; + return undefined; + } + })('plugin2'); + + service.registerPlugin(plugin1); + service.registerPlugin(plugin2); + + await service.runAfterToolCallback({ + tool: {} as BaseTool, + toolArgs: {}, + toolContext: {} as Context, + result: {message: 'original'}, + }); + + expect(plugin2Received).toEqual({message: 'translated'}); + }); + + it('strips the shortCircuit flag from a short-circuiting return value', async () => { + plugin1.returnValues['afterToolCallback'] = { + shortCircuit: true, + message: 'replaced', + }; + service.registerPlugin(plugin1); + service.registerPlugin(plugin2); + + const result = await service.runAfterToolCallback({ + tool: mockTool, + toolArgs: {}, + toolContext: mockToolContext, + result: {message: 'original'}, + }); + + expect(result).toEqual({message: 'replaced'}); + expect((result as Record).shortCircuit).toBeUndefined(); + expect(plugin2.callLog).not.toContain('afterToolCallback'); + }); + + it('short-circuits when shortCircuit is omitted from the return value', async () => { + plugin1.returnValues['afterToolCallback'] = {message: 'replaced'}; + service.registerPlugin(plugin1); + service.registerPlugin(plugin2); + + const result = await service.runAfterToolCallback({ + tool: mockTool, + toolArgs: {}, + toolContext: mockToolContext, + result: {message: 'original'}, + }); + + expect(result).toEqual({message: 'replaced'}); + expect(plugin2.callLog).not.toContain('afterToolCallback'); + }); + it('should call all plugins if no plugin returns a value', async () => { service.registerPlugin(plugin1); service.registerPlugin(plugin2);