diff --git a/src/cli-core.ts b/src/cli-core.ts index 666fb77..eb79fdf 100644 --- a/src/cli-core.ts +++ b/src/cli-core.ts @@ -290,6 +290,11 @@ async function handlePrompt( promptRetries: globalFlags.promptRetries, verbose: globalFlags.verbose, waitForCompletion: flags.wait !== false, + sessionOptions: { + model: globalFlags.model, + allowedTools: globalFlags.allowedTools, + maxTurns: globalFlags.maxTurns, + }, }); if ("queued" in result) { diff --git a/src/client.ts b/src/client.ts index 015fa8a..db38990 100644 --- a/src/client.ts +++ b/src/client.ts @@ -946,6 +946,21 @@ export class AcpClient { this.eventHandlers = {}; } + updateSessionOptions(update: Partial>): void { + if (!this.options.sessionOptions) { + this.options.sessionOptions = {}; + } + if (update.model !== undefined) { + this.options.sessionOptions.model = update.model; + } + if (update.allowedTools !== undefined) { + this.options.sessionOptions.allowedTools = update.allowedTools; + } + if (update.maxTurns !== undefined) { + this.options.sessionOptions.maxTurns = update.maxTurns; + } + } + updateRuntimeOptions(options: { permissionMode?: PermissionMode; nonInteractivePermissions?: NonInteractivePermissionPolicy; @@ -1249,6 +1264,7 @@ export class AcpClient { sessionId, cwd: asAbsoluteCwd(cwd), mcpServers: this.options.mcpServers ?? [], + _meta: buildClaudeCodeOptionsMeta(this.options.sessionOptions), }), ); @@ -1336,6 +1352,43 @@ export class AcpClient { value: string, ): Promise { const connection = this.getConnection(); + // For model changes: prefer session/set_config_option (supports alias resolution + // via resolveModelPreference on Claude adapter) then fall back to session/set_model + // for adapters like Droid that only support the dedicated method. + if (configId === "model") { + try { + return await this.runConnectionRequest(() => + connection.setSessionConfigOption({ + sessionId, + configId, + value, + }), + ); + } catch (error) { + const acp = extractAcpError(error); + if (acp && isLikelySessionControlUnsupportedError(acp)) { + // Adapter doesn't support set_config_option — fall back to set_model + try { + await this.runConnectionRequest(() => + connection.unstable_setSessionModel({ sessionId, modelId: value }), + ); + return { configOptions: [] }; + } catch (fallbackError) { + throw maybeWrapSessionControlError( + "session/set_model", + fallbackError, + `for model="${value}"`, + ); + } + } + throw maybeWrapSessionControlError( + "session/set_config_option", + error, + `for "model"="${value}"`, + ); + } + } + try { return await this.runConnectionRequest(() => connection.setSessionConfigOption({ diff --git a/src/queue-owner-env.ts b/src/queue-owner-env.ts index b013f8d..da11c11 100644 --- a/src/queue-owner-env.ts +++ b/src/queue-owner-env.ts @@ -76,6 +76,22 @@ export function parseQueueOwnerPayload(raw: string): QueueOwnerRuntimeOptions { options.promptRetries = Math.max(0, Math.round(record.promptRetries)); } + const sessionOpts = asRecord(record.sessionOptions); + if (sessionOpts) { + options.sessionOptions = {}; + if (typeof sessionOpts.model === "string" && sessionOpts.model.trim().length > 0) { + options.sessionOptions.model = sessionOpts.model; + } + if (Array.isArray(sessionOpts.allowedTools)) { + options.sessionOptions.allowedTools = sessionOpts.allowedTools.filter( + (t): t is string => typeof t === "string", + ); + } + if (typeof sessionOpts.maxTurns === "number" && Number.isFinite(sessionOpts.maxTurns)) { + options.sessionOptions.maxTurns = Math.max(1, Math.round(sessionOpts.maxTurns)); + } + } + return options; } diff --git a/src/session-runtime.ts b/src/session-runtime.ts index ad9d639..30c4a36 100644 --- a/src/session-runtime.ts +++ b/src/session-runtime.ts @@ -140,6 +140,25 @@ function sessionOptionsFromRecord(record: SessionRecord): SessionAgentOptions | return Object.keys(sessionOptions).length > 0 ? sessionOptions : undefined; } +function mergeSessionOptions( + preferred: SessionAgentOptions | undefined, + fallback: SessionAgentOptions | undefined, +): SessionAgentOptions | undefined { + const merged: SessionAgentOptions = { + ...fallback, + }; + if (preferred?.model !== undefined) { + merged.model = preferred.model; + } + if (preferred?.allowedTools !== undefined) { + merged.allowedTools = preferred.allowedTools; + } + if (preferred?.maxTurns !== undefined) { + merged.maxTurns = preferred.maxTurns; + } + return Object.keys(merged).length > 0 ? merged : undefined; +} + function persistSessionOptions( record: SessionRecord, options: SessionAgentOptions | undefined, @@ -228,6 +247,7 @@ export type SessionSendOptions = { maxQueueDepth?: number; client?: AcpClient; promptRetries?: number; + sessionOptions?: SessionAgentOptions; } & TimedRunOptions; export type SessionEnsureOptions = { @@ -338,6 +358,7 @@ type RunSessionPromptOptions = { suppressSdkConsoleErrors?: boolean; verbose?: boolean; promptRetries?: number; + sessionOptions?: SessionAgentOptions; onClientAvailable?: (controller: ActiveSessionController) => void; onClientClosed?: () => void; onPromptActive?: () => Promise | void; @@ -541,6 +562,7 @@ async function runQueuedTask( authPolicy?: AuthPolicy; suppressSdkConsoleErrors?: boolean; promptRetries?: number; + sessionOptions?: SessionAgentOptions; onClientAvailable?: (controller: ActiveSessionController) => void; onClientClosed?: () => void; onPromptActive?: () => Promise | void; @@ -566,12 +588,22 @@ async function runQueuedTask( suppressSdkConsoleErrors: task.suppressSdkConsoleErrors ?? options.suppressSdkConsoleErrors, verbose: options.verbose, promptRetries: options.promptRetries, + sessionOptions: options.sessionOptions, onClientAvailable: options.onClientAvailable, onClientClosed: options.onClientClosed, onPromptActive: options.onPromptActive, client: options.sharedClient, }); + // Persist the record after each prompt so that control commands + // (e.g. set_config_option for model changes) see the updated acpSessionId + // when the session was recreated due to a loadSession fallback. + if (result.record) { + await writeSessionRecord(result.record).catch(() => { + // best effort — control commands will still work via sharedClient + }); + } + if (task.waitForCompletion) { task.send({ type: "result", @@ -665,7 +697,7 @@ async function runSessionPrompt(options: RunSessionPromptOptions): Promise { + // If the sharedClient has a reusable session, route through it directly + // instead of creating a temporary connection that doesn't affect the live session. + const currentRecord = await resolveSessionRecord(options.sessionId); + if (sharedClient.hasReusableSession(currentRecord.acpSessionId)) { + const response = await withTimeout( + (async () => + await sharedClient.setSessionConfigOption( + currentRecord.acpSessionId, + configId, + value, + ))(), + timeoutMs, + ); + if (configId === "model") { + sharedClient.updateSessionOptions({ model: value }); + } + return response; + } + const result = await runSessionSetConfigOptionDirect({ sessionRecordId: options.sessionId, configId, @@ -1302,6 +1364,10 @@ export async function runSessionQueueOwner(options: QueueOwnerRuntimeOptions): P timeoutMs, verbose: options.verbose, }); + // Update sharedClient's sessionOptions so future reconnections use the new model. + if (configId === "model") { + sharedClient.updateSessionOptions({ model: value }); + } return result.response; }, }); @@ -1405,6 +1471,7 @@ export async function runSessionQueueOwner(options: QueueOwnerRuntimeOptions): P authPolicy: options.authPolicy, suppressSdkConsoleErrors: options.suppressSdkConsoleErrors, promptRetries: options.promptRetries, + sessionOptions: options.sessionOptions, onClientAvailable: setActiveController, onClientClosed: clearActiveController, onPromptActive: async () => { @@ -1446,6 +1513,24 @@ export async function runSessionQueueOwner(options: QueueOwnerRuntimeOptions): P export async function sendSession(options: SessionSendOptions): Promise { const waitForCompletion = options.waitForCompletion !== false; + // If a model is requested and an owner is already running, update the model + // BEFORE submitting the prompt (fresh sessions are handled via trySetModel in createSession). + if (options.sessionOptions?.model) { + await trySetConfigOptionOnRunningOwner( + options.sessionId, + "model", + options.sessionOptions.model, + options.timeoutMs, + options.verbose, + ).catch((err: unknown) => { + if (options.verbose) { + process.stderr.write( + `[acpx] warning: failed to pre-set model on running owner: ${formatErrorMessage(err)}\n`, + ); + } + }); + } + const queuedToOwner = await submitToRunningOwner(options, waitForCompletion); if (queuedToOwner) { return queuedToOwner; diff --git a/src/session-runtime/prompt-runner.ts b/src/session-runtime/prompt-runner.ts index 9967a87..042c33b 100644 --- a/src/session-runtime/prompt-runner.ts +++ b/src/session-runtime/prompt-runner.ts @@ -12,6 +12,7 @@ import { writeSessionRecord, } from "../session-persistence.js"; import { withInterrupt, withTimeout } from "../session-runtime-helpers.js"; +import type { SessionAgentOptions } from "../session-runtime.js"; import type { AuthPolicy, McpServer, @@ -58,6 +59,25 @@ function sessionOptionsFromRecord(record: SessionRecord): return Object.keys(sessionOptions).length > 0 ? sessionOptions : undefined; } +function mergeSessionOptions( + preferred: SessionAgentOptions | undefined, + fallback: SessionAgentOptions | undefined, +): SessionAgentOptions | undefined { + const merged: SessionAgentOptions = { + ...fallback, + }; + if (preferred?.model !== undefined) { + merged.model = preferred.model; + } + if (preferred?.allowedTools !== undefined) { + merged.allowedTools = preferred.allowedTools; + } + if (preferred?.maxTurns !== undefined) { + merged.maxTurns = preferred.maxTurns; + } + return Object.keys(merged).length > 0 ? merged : undefined; +} + type WithConnectedSessionOptions = { sessionRecordId: string; mcpServers?: McpServer[]; @@ -67,6 +87,7 @@ type WithConnectedSessionOptions = { authPolicy?: AuthPolicy; timeoutMs?: number; verbose?: boolean; + sessionOptions?: SessionAgentOptions; onClientAvailable?: (controller: ActiveSessionController) => void; onClientClosed?: () => void; run: (client: AcpClient, sessionId: string, record: SessionRecord) => Promise; @@ -92,7 +113,7 @@ async function withConnectedSession( authCredentials: options.authCredentials, authPolicy: options.authPolicy, verbose: options.verbose, - sessionOptions: sessionOptionsFromRecord(record), + sessionOptions: mergeSessionOptions(options.sessionOptions, sessionOptionsFromRecord(record)), }); let activeSessionIdForControl = record.acpSessionId; let notifiedClientAvailable = false; diff --git a/src/session-runtime/queue-owner-process.ts b/src/session-runtime/queue-owner-process.ts index 43e5642..5d22295 100644 --- a/src/session-runtime/queue-owner-process.ts +++ b/src/session-runtime/queue-owner-process.ts @@ -1,5 +1,6 @@ import { spawn } from "node:child_process"; import { realpathSync } from "node:fs"; +import type { SessionAgentOptions } from "../session-runtime.js"; import type { AuthPolicy, McpServer, @@ -19,6 +20,7 @@ export type QueueOwnerRuntimeOptions = { ttlMs?: number; maxQueueDepth?: number; promptRetries?: number; + sessionOptions?: SessionAgentOptions; }; type SessionSendLike = { @@ -33,6 +35,7 @@ type SessionSendLike = { ttlMs?: number; maxQueueDepth?: number; promptRetries?: number; + sessionOptions?: SessionAgentOptions; }; export function sanitizeQueueOwnerExecArgv( @@ -131,6 +134,7 @@ export function queueOwnerRuntimeOptionsFromSend( ttlMs: options.ttlMs, maxQueueDepth: options.maxQueueDepth, promptRetries: options.promptRetries, + sessionOptions: options.sessionOptions, }; } diff --git a/test/integration.test.ts b/test/integration.test.ts index 87d1017..7d63b55 100644 --- a/test/integration.test.ts +++ b/test/integration.test.ts @@ -868,6 +868,48 @@ test("integration: exec --model skips session/set_model when agent does not adve }); }); +test("integration: prompt --model updates existing session model before prompt", async () => { + await withTempHome(async (homeDir) => { + const cwd = await fs.mkdtemp(path.join(os.tmpdir(), "acpx-integration-cwd-")); + + try { + const ensured = await runCli( + [...baseAgentArgs(cwd), "sessions", "ensure", "--name", "model-prompt-session"], + homeDir, + ); + assert.equal(ensured.code, 0, ensured.stderr); + + const result = await runCli( + [ + ...baseAgentArgs(cwd), + "--format", + "json", + "--model", + "haiku", + "prompt", + "-s", + "model-prompt-session", + "echo hello", + ], + homeDir, + ); + assert.equal(result.code, 0, result.stderr); + + const payloads = parseJsonRpcOutputLines(result.stdout); + const setConfigRequest = payloads.find( + (payload) => + payload.method === "session/set_config_option" && + (payload.params as { configId?: string; value?: string } | undefined)?.configId === + "model" && + (payload.params as { configId?: string; value?: string } | undefined)?.value === "haiku", + ); + assert(setConfigRequest, "expected session/set_config_option for model=haiku"); + } finally { + await fs.rm(cwd, { recursive: true, force: true }); + } + }); +}); + test("integration: exec --model fails when session/set_model fails", async () => { await withTempHome(async (homeDir) => { const cwd = await fs.mkdtemp(path.join(os.tmpdir(), "acpx-integration-cwd-"));