diff --git a/apps/web/src/hooks/useHandleNewThread.ts b/apps/web/src/hooks/useHandleNewThread.ts index e31809cdd..ee7bc13b0 100644 --- a/apps/web/src/hooks/useHandleNewThread.ts +++ b/apps/web/src/hooks/useHandleNewThread.ts @@ -50,6 +50,16 @@ export function useHandleNewThread() { const hasBranchOption = options?.branch !== undefined; const hasWorktreePathOption = options?.worktreePath !== undefined; const hasEnvModeOption = options?.envMode !== undefined; + const applyStickyModel = (threadId: ThreadId, draftModel?: string | null) => { + if (stickyModel && !draftModel) { + setProvider(threadId, inferProviderForModel(stickyModel)); + setModel(threadId, stickyModel); + } + if (!draftModel && Object.keys(stickyModelOptions).length > 0) { + setModelOptions(threadId, stickyModelOptions); + } + }; + const storedDraftThread = getDraftThreadByProjectId(projectId); const latestActiveDraftThread: DraftThreadState | null = routeThreadId ? getDraftThread(routeThreadId) @@ -64,6 +74,9 @@ export function useHandleNewThread() { }); } setProjectDraftThreadId(projectId, storedDraftThread.threadId); + const existingDraft = + useComposerDraftStore.getState().draftsByThreadId[storedDraftThread.threadId]; + applyStickyModel(storedDraftThread.threadId, existingDraft?.model); if (routeThreadId === storedDraftThread.threadId) { return; } @@ -89,6 +102,8 @@ export function useHandleNewThread() { }); } setProjectDraftThreadId(projectId, routeThreadId); + const existingDraft = useComposerDraftStore.getState().draftsByThreadId[routeThreadId]; + applyStickyModel(routeThreadId, existingDraft?.model); return Promise.resolve(); } @@ -102,13 +117,7 @@ export function useHandleNewThread() { envMode: options?.envMode ?? "local", runtimeMode: DEFAULT_RUNTIME_MODE, }); - if (stickyModel) { - setProvider(threadId, inferProviderForModel(stickyModel)); - setModel(threadId, stickyModel); - } - if (Object.keys(stickyModelOptions).length > 0) { - setModelOptions(threadId, stickyModelOptions); - } + applyStickyModel(threadId); await navigate({ to: "/$threadId",