diff --git a/__tests__/integration/models/activeModelService.test.ts b/__tests__/integration/models/activeModelService.test.ts index c9358b66..828e0c0d 100644 --- a/__tests__/integration/models/activeModelService.test.ts +++ b/__tests__/integration/models/activeModelService.test.ts @@ -1009,8 +1009,45 @@ describe('ActiveModelService Integration', () => { }); }); - describe('checkMemoryForModel critical with other models message', () => { - it('includes other models in critical message', async () => { + describe('checkMemoryForModel critical message and resolvableByUnload', () => { + it('marks block resolvable and mentions other models when the model fits alone but the total is over budget', async () => { + // 8GB device -> 4.8GB budget. Text 2GB*1.5=3.0GB fits alone (<=4.8), + // image 2GB*1.5=3.0GB already loaded -> total 6.0GB > 4.8GB budget. + const textModel = createDownloadedModel({ + id: 'fits-alone-text', + fileSize: 2 * 1024 * 1024 * 1024, + }); + const imageModel = createONNXImageModel({ + id: 'img-already', + size: 2 * 1024 * 1024 * 1024, + }); + useAppStore.setState({ + downloadedModels: [textModel], + downloadedImageModels: [imageModel], + settings: { imageThreads: 4 } as any, + }); + + // Load image model so it counts as "other loaded" memory + mockLocalDreamService.isModelLoaded.mockResolvedValue(true); + await activeModelService.loadImageModel('img-already'); + + mockHardwareService.getDeviceInfo.mockResolvedValue( + createDeviceInfo({ totalMemory: 8 * 1024 * 1024 * 1024 }) + ); + + const result = await activeModelService.checkMemoryForModel('fits-alone-text', 'text'); + + expect(result.severity).toBe('critical'); + expect(result.canLoad).toBe(false); + // Unloading the loaded image model would free enough room, so the UI can + // offer a one-tap recovery instead of a dead-end alert. + expect(result.resolvableByUnload).toBe(true); + expect(result.message).toContain('other models are loaded'); + }); + + it('marks block NOT resolvable when the model is too large even on its own', async () => { + // 8GB device -> 4.8GB budget. Text 6GB*1.5=9GB exceeds budget by itself, + // so unloading other models cannot help. const textModel = createDownloadedModel({ id: 'huge-text', fileSize: 6 * 1024 * 1024 * 1024, @@ -1025,11 +1062,9 @@ describe('ActiveModelService Integration', () => { settings: { imageThreads: 4 } as any, }); - // Load image model mockLocalDreamService.isModelLoaded.mockResolvedValue(true); await activeModelService.loadImageModel('img-already'); - // 8GB device - 6GB text * 1.5 = 9GB + image model memory = way over budget mockHardwareService.getDeviceInfo.mockResolvedValue( createDeviceInfo({ totalMemory: 8 * 1024 * 1024 * 1024 }) ); @@ -1038,7 +1073,8 @@ describe('ActiveModelService Integration', () => { expect(result.severity).toBe('critical'); expect(result.canLoad).toBe(false); - expect(result.message).toContain('other models are loaded'); + expect(result.resolvableByUnload).toBe(false); + expect(result.message).toContain('too large for your device'); }); }); diff --git a/__tests__/rntl/components/ModelSelectorModal.test.tsx b/__tests__/rntl/components/ModelSelectorModal.test.tsx index a7fac1b4..029f3b84 100644 --- a/__tests__/rntl/components/ModelSelectorModal.test.tsx +++ b/__tests__/rntl/components/ModelSelectorModal.test.tsx @@ -48,6 +48,10 @@ jest.mock('../../../src/services', () => ({ loadImageModel: jest.fn().mockResolvedValue(undefined), unloadImageModel: jest.fn().mockResolvedValue(undefined), unloadTextModel: jest.fn().mockResolvedValue(undefined), + unloadAllModels: jest.fn().mockResolvedValue({ textUnloaded: true, imageUnloaded: true }), + checkMemoryForModel: jest + .fn() + .mockResolvedValue({ canLoad: true, severity: 'safe', message: '', resolvableByUnload: false }), }, llmService: { isModelLoaded: jest.fn(() => false), @@ -78,6 +82,15 @@ describe('ModelSelectorModal', () => { beforeEach(() => { jest.clearAllMocks(); + // clearAllMocks wipes implementations; restore the defaults the component relies on. + activeModelService.checkMemoryForModel.mockResolvedValue({ + canLoad: true, + severity: 'safe', + message: '', + resolvableByUnload: false, + }); + activeModelService.loadImageModel.mockResolvedValue(undefined); + activeModelService.unloadAllModels.mockResolvedValue({ textUnloaded: true, imageUnloaded: true }); mockUseAppStore.mockReturnValue({ downloadedModels: [ { @@ -652,6 +665,70 @@ describe('ModelSelectorModal', () => { expect(activeModelService.loadImageModel).toHaveBeenCalledWith('img1'); }); + // Shared setup for the blocked-load recovery tests below. + const blockedImageStore = () => + mockUseAppStore.mockReturnValue({ + downloadedModels: [], + downloadedImageModels: [ + { id: 'img1', name: 'SD Model', size: 2000000000, style: 'Creative' }, + ], + activeImageModelId: null, + }); + const mockMemoryCheck = (resolvableByUnload: boolean, message: string) => + activeModelService.checkMemoryForModel.mockResolvedValue({ + canLoad: false, + severity: 'critical', + resolvableByUnload, + message, + }); + const renderImageTab = () => + render(); + + it('offers an unload-and-load recovery instead of loading when blocked by other models', async () => { + mockMemoryCheck(true, 'Cannot load SD Model while other models are loaded.'); + blockedImageStore(); + + const { getByText } = renderImageTab(); + await act(async () => { + fireEvent.press(getByText('SD Model')); + }); + + // The model is NOT loaded yet — the user is shown a recovery action. + expect(activeModelService.loadImageModel).not.toHaveBeenCalled(); + expect(getByText('Unload others & load')).toBeTruthy(); + }); + + it('unloads other models then loads the image model when recovery action is pressed', async () => { + mockMemoryCheck(true, 'Cannot load SD Model while other models are loaded.'); + blockedImageStore(); + + const { getByText } = renderImageTab(); + await act(async () => { + fireEvent.press(getByText('SD Model')); + }); + await act(async () => { + fireEvent.press(getByText('Unload others & load')); + }); + + expect(activeModelService.unloadAllModels).toHaveBeenCalled(); + expect(activeModelService.loadImageModel).toHaveBeenCalledWith('img1'); + }); + + it('shows an informational alert with no recovery when the model is too large to load at all', async () => { + mockMemoryCheck(false, 'SD Model is too large for your device. Choose a smaller model.'); + blockedImageStore(); + + const { getByText, queryByText } = renderImageTab(); + await act(async () => { + fireEvent.press(getByText('SD Model')); + }); + + expect(activeModelService.loadImageModel).not.toHaveBeenCalled(); + expect(activeModelService.unloadAllModels).not.toHaveBeenCalled(); + expect(queryByText('Unload others & load')).toBeNull(); + expect(getByText('Model too large')).toBeTruthy(); + }); + it('does not call loadImageModel when pressing the currently active image model', async () => { mockUseAppStore.mockReturnValue({ downloadedModels: [], diff --git a/__tests__/unit/hooks/useChatModelActions.test.ts b/__tests__/unit/hooks/useChatModelActions.test.ts index 34383b47..0843d0fa 100644 --- a/__tests__/unit/hooks/useChatModelActions.test.ts +++ b/__tests__/unit/hooks/useChatModelActions.test.ts @@ -19,6 +19,7 @@ jest.mock('../../../src/services/activeModelService', () => ({ activeModelService: { loadTextModel: jest.fn(), unloadTextModel: jest.fn(), + unloadAllModels: jest.fn(), checkMemoryForModel: jest.fn(), getActiveModels: jest.fn(), }, @@ -39,6 +40,7 @@ const { llmService } = require('../../../src/services/llm'); const mockLoadTextModel = activeModelService.loadTextModel as jest.Mock; const mockUnloadTextModel = activeModelService.unloadTextModel as jest.Mock; +const mockUnloadAllModels = activeModelService.unloadAllModels as jest.Mock; const mockCheckMemoryForModel = activeModelService.checkMemoryForModel as jest.Mock; const mockGetActiveModels = activeModelService.getActiveModels as jest.Mock; const mockGetMultimodalSupport = llmService.getMultimodalSupport as jest.Mock; @@ -104,6 +106,36 @@ function makeDeps(overrides: Partial = {}) { }; } +/** Stubs a critical memory block that unloading other models would resolve. */ +function mockResolvableBlock() { + mockCheckMemoryForModel.mockResolvedValueOnce({ + canLoad: false, + severity: 'critical', + message: 'other models are loaded', + resolvableByUnload: true, + }); + mockGetMultimodalSupport.mockReturnValueOnce({ vision: false }); +} + +/** Finds an alert button by label from the first setAlertState call. */ +function alertButton(deps: any, text: string) { + return deps.setAlertState.mock.calls[0][0].buttons.find((b: any) => b.text === text); +} + +/** + * Flushes the waitForRenderFrame timer and the unload→load promise chain. + * The two entry points order the timer and the unload promise differently + * (startLoad-then-unload vs unload-then-proceed), so interleave timer advances + * with microtask flushes to drain whichever ordering applies. + */ +async function flushRecoveryChain() { + for (let i = 0; i < 4; i++) { + jest.advanceTimersByTime(400); + await Promise.resolve(); + await Promise.resolve(); + } +} + // ───────────────────────────────────────────── // initiateModelLoad // ───────────────────────────────────────────── @@ -260,6 +292,47 @@ describe('initiateModelLoad — Load Anyway button', () => { jest.useRealTimers(); }); + it('offers "Unload others & load" (not "Load Anyway") when the block is resolvable by unloading', async () => { + jest.useFakeTimers(); + mockResolvableBlock(); + mockUnloadAllModels.mockResolvedValueOnce({ textUnloaded: true, imageUnloaded: false }); + mockLoadTextModel.mockResolvedValueOnce(undefined); + + const deps = makeDeps(); + await initiateModelLoad(deps, false); + + const unloadBtn = alertButton(deps, 'Unload others & load'); + expect(unloadBtn).toBeDefined(); + // No "Load Anyway" override on a resolvable block — recovery is the safe default. + expect(alertButton(deps, 'Load Anyway')).toBeUndefined(); + + unloadBtn.onPress(); + expect(deps.setIsModelLoading).toHaveBeenCalledWith(true); + await flushRecoveryChain(); + + expect(mockUnloadAllModels).toHaveBeenCalled(); + expect(mockLoadTextModel).toHaveBeenCalled(); + jest.useRealTimers(); + }); + + it('still loads (does not leave UI stuck) when unloadAllModels rejects during recovery', async () => { + jest.useFakeTimers(); + mockResolvableBlock(); + mockUnloadAllModels.mockRejectedValueOnce(new Error('unload failed')); + mockLoadTextModel.mockResolvedValueOnce(undefined); + + const deps = makeDeps(); + await initiateModelLoad(deps, false); + + alertButton(deps, 'Unload others & load').onPress(); + await flushRecoveryChain(); + + // The rejection is caught and logged; the load proceeds rather than hanging. + expect(mockUnloadAllModels).toHaveBeenCalled(); + expect(mockLoadTextModel).toHaveBeenCalled(); + jest.useRealTimers(); + }); + it('doLoadTextModel does not post system message when showGenerationDetails=false', async () => { jest.useFakeTimers(); mockCheckMemoryForModel.mockResolvedValueOnce({ canLoad: false, message: 'OOM', severity: 'critical' }); @@ -329,6 +402,48 @@ describe('handleModelSelectFn — Load Anyway button', () => { expect(deps.setIsModelLoading).toHaveBeenCalled(); }); + it('offers "Unload others & load" and unloads then loads when block is resolvable', async () => { + jest.useFakeTimers(); + mockResolvableBlock(); + mockUnloadAllModels.mockResolvedValueOnce({ textUnloaded: true, imageUnloaded: false }); + mockLoadTextModel.mockResolvedValueOnce(undefined); + + const deps = makeDeps(); + await handleModelSelectFn(deps, createDownloadedModel({ id: 'model-z' })); + + const unloadBtn = alertButton(deps, 'Unload others & load'); + expect(unloadBtn).toBeDefined(); + expect(alertButton(deps, 'Load Anyway')).toBeUndefined(); + + unloadBtn.onPress(); + // Loading indicator is shown before unloading begins so the UI doesn't + // look frozen during the (potentially multi-second) unload. + expect(deps.setIsModelLoading).toHaveBeenCalledWith(true); + await flushRecoveryChain(); + + expect(mockUnloadAllModels).toHaveBeenCalled(); + expect(mockLoadTextModel).toHaveBeenCalledWith('model-z'); + jest.useRealTimers(); + }); + + it('still loads (does not leave UI stuck) when unloadAllModels rejects', async () => { + jest.useFakeTimers(); + mockResolvableBlock(); + mockUnloadAllModels.mockRejectedValueOnce(new Error('unload failed')); + mockLoadTextModel.mockResolvedValueOnce(undefined); + + const deps = makeDeps(); + await handleModelSelectFn(deps, createDownloadedModel({ id: 'model-z' })); + + alertButton(deps, 'Unload others & load').onPress(); + await flushRecoveryChain(); + + // Rejection is caught and logged; load still proceeds rather than hanging. + expect(mockUnloadAllModels).toHaveBeenCalled(); + expect(mockLoadTextModel).toHaveBeenCalledWith('model-z'); + jest.useRealTimers(); + }); + it('executes Load Anyway callback in low memory warning', async () => { mockCheckMemoryForModel.mockResolvedValueOnce({ canLoad: true, severity: 'warning', message: 'Low memory' }); mockLoadTextModel.mockResolvedValueOnce(undefined); diff --git a/src/components/ModelSelectorModal/index.tsx b/src/components/ModelSelectorModal/index.tsx index a60c5b84..f0c4bc5d 100644 --- a/src/components/ModelSelectorModal/index.tsx +++ b/src/components/ModelSelectorModal/index.tsx @@ -12,7 +12,7 @@ import { useTheme, useThemedStyles } from '../../theme'; import { useAppStore, useRemoteServerStore } from '../../stores'; import { DownloadedModel, ONNXImageModel, RemoteModel } from '../../types'; import { activeModelService, llmService, remoteServerManager } from '../../services'; -import { CustomAlert, AlertState, initialAlertState, showAlert } from '../CustomAlert'; +import { CustomAlert, AlertState, initialAlertState, showAlert, showLoadingAlert } from '../CustomAlert'; import { createAllStyles } from './styles'; import { TextTab } from './TextTab'; import { ImageTab } from './ImageTab'; @@ -97,8 +97,9 @@ export const ModelSelectorModal: React.FC = ({ // Vision-language models (supportsVision) are text models and belong in the text tab. const remoteVisionModels = useMemo(() => [], []); - const handleSelectImageModel = async (model: ONNXImageModel) => { - if (activeImageModelId === model.id) return; + // Performs the actual image-model load and post-load wiring. Shared by the + // direct path and the "unload others & load" recovery path. + const loadImageModelNow = async (model: ONNXImageModel) => { setIsLoadingImage(true); try { await activeModelService.loadImageModel(model.id); @@ -114,6 +115,52 @@ export const ModelSelectorModal: React.FC = ({ } }; + // Unloads whatever is currently loaded (text and/or image) to free RAM, then + // loads the requested image model. Used by the recovery action so the user is + // never left at a dead end when another model is occupying memory. + const unloadOthersAndLoadImageModel = async (model: ONNXImageModel) => { + setAlertState(showLoadingAlert('Freeing memory', 'Unloading other models…')); + try { + await activeModelService.unloadAllModels(); + } catch (error) { + logger.error('Failed to unload models before image load:', error); + } finally { + setAlertState(initialAlertState); + } + await loadImageModelNow(model); + }; + + const handleSelectImageModel = async (model: ONNXImageModel) => { + if (activeImageModelId === model.id) return; + + // Check memory up front so we can offer a one-tap recovery instead of a + // dead-end "Failed to Load" alert when another model is occupying RAM. + const memoryCheck = await activeModelService.checkMemoryForModel(model.id, 'image'); + if (!memoryCheck.canLoad) { + if (memoryCheck.resolvableByUnload) { + setAlertState( + showAlert('Not enough memory', memoryCheck.message, [ + { text: 'Cancel', style: 'cancel' }, + { + text: 'Unload others & load', + style: 'default', + onPress: () => { + setAlertState(initialAlertState); + unloadOthersAndLoadImageModel(model); + }, + }, + ]), + ); + } else { + // Too large for this device even on its own — unloading cannot help. + setAlertState(showAlert('Model too large', memoryCheck.message)); + } + return; + } + + await loadImageModelNow(model); + }; + const handleUnloadImageModel = async () => { setIsLoadingImage(true); try { diff --git a/src/screens/ChatScreen/useChatModelActions.ts b/src/screens/ChatScreen/useChatModelActions.ts index a004b895..9e9d357e 100644 --- a/src/screens/ChatScreen/useChatModelActions.ts +++ b/src/screens/ChatScreen/useChatModelActions.ts @@ -88,21 +88,39 @@ export async function initiateModelLoad( if (!alreadyLoading) { const memoryCheck = await activeModelService.checkMemoryForModel(activeModelId, 'text'); if (!memoryCheck.canLoad) { + const startLoad = () => { + deps.setAlertState(hideAlert()); + deps.setIsModelLoading(true); + deps.setLoadingModel(activeModel); + deps.modelLoadStartTimeRef.current = Date.now(); + return waitForRenderFrame(); + }; + // When the block is caused by other loaded models, offer a one-tap + // "unload others & load" recovery so the user is never left stuck. + const buttons = memoryCheck.resolvableByUnload + ? [ + { text: 'Cancel', style: 'cancel' as const }, + { + text: 'Unload others & load', style: 'default' as const, onPress: () => { + startLoad() + .then(() => activeModelService.unloadAllModels()) + .catch(err => logger.error('Failed to unload models before load:', err)) + .then(() => doLoadTextModel(deps)); + }, + }, + ] + : [ + { text: 'Cancel', style: 'cancel' as const }, + { + text: 'Load Anyway', style: 'destructive' as const, onPress: () => { + startLoad().then(() => doLoadTextModel(deps)); + }, + }, + ]; deps.setAlertState(showAlert( 'Insufficient Memory', - `Cannot load ${activeModel.name}. ${memoryCheck.message}\n\nTry unloading other models from the Home screen.`, - [ - { text: 'Cancel', style: 'cancel' }, - { - text: 'Load Anyway', style: 'destructive', onPress: () => { - deps.setAlertState(hideAlert()); - deps.setIsModelLoading(true); - deps.setLoadingModel(activeModel); - deps.modelLoadStartTimeRef.current = Date.now(); - waitForRenderFrame().then(() => doLoadTextModel(deps)); - } - }, - ], + `Cannot load ${activeModel.name}. ${memoryCheck.message}`, + buttons, )); return; } @@ -199,15 +217,37 @@ export async function handleModelSelectFn( } const memoryCheck = await activeModelService.checkMemoryForModel(model.id, 'text'); if (!memoryCheck.canLoad) { - deps.setAlertState(showAlert('Insufficient Memory', memoryCheck.message, [ - { text: 'Cancel', style: 'cancel' }, - { - text: 'Load Anyway', style: 'destructive', onPress: () => { - deps.setAlertState(hideAlert()); - proceedWithModelLoadFn(deps, model); - } - }, - ])); + // When other loaded models are the cause, offer a one-tap "unload others & + // load" recovery so the user is never left at a dead end. Otherwise the + // model is too large on its own and only "Load Anyway" makes sense. + const buttons = memoryCheck.resolvableByUnload + ? [ + { text: 'Cancel', style: 'cancel' as const }, + { + text: 'Unload others & load', style: 'default' as const, onPress: () => { + deps.setAlertState(hideAlert()); + // Show the loading state before unloading — unloading can take a + // few seconds and would otherwise leave the UI looking frozen. + deps.setIsModelLoading(true); + deps.setLoadingModel(model); + deps.modelLoadStartTimeRef.current = Date.now(); + waitForRenderFrame() + .then(() => activeModelService.unloadAllModels()) + .catch(err => logger.error('Failed to unload models before load:', err)) + .then(() => proceedWithModelLoadFn(deps, model)); + }, + }, + ] + : [ + { text: 'Cancel', style: 'cancel' as const }, + { + text: 'Load Anyway', style: 'destructive' as const, onPress: () => { + deps.setAlertState(hideAlert()); + proceedWithModelLoadFn(deps, model); + }, + }, + ]; + deps.setAlertState(showAlert('Insufficient Memory', memoryCheck.message, buttons)); return; } if (memoryCheck.severity === 'warning') { diff --git a/src/services/activeModelService/memory.ts b/src/services/activeModelService/memory.ts index 0fc29abd..d1f3f0ef 100644 --- a/src/services/activeModelService/memory.ts +++ b/src/services/activeModelService/memory.ts @@ -143,6 +143,7 @@ export async function checkMemoryForModel( totalRequiredMemoryGB: 0, remainingAfterLoadGB: 0, message: 'Model not found', + resolvableByUnload: false, }; } @@ -159,17 +160,21 @@ export async function checkMemoryForModel( let severity: MemoryCheckSeverity; let canLoad: boolean; let message: string; + // Resolvable when the model fits on its own but other loaded models push the + // total over budget — unloading them frees enough room. Set only on a block. + let resolvableByUnload = false; if (totalRequiredMemoryGB > memoryBudgetGB) { severity = 'critical'; canLoad = false; - message = - currentlyLoadedMemoryGB > 0 - ? `Cannot load ${modelName} (~${requiredStr} GB) while other models are loaded. ` + - `Total would be ~${totalStr} GB, exceeding your device's ~${budgetStr} GB safe limit (60% of RAM). ` + - `Unload the other model first, or choose a smaller model.` - : `${modelName} requires ~${requiredStr} GB which exceeds your device's ~${budgetStr} GB safe limit (60% of RAM). ` + - `This model is too large for your device. Choose a smaller model.`; + resolvableByUnload = + currentlyLoadedMemoryGB > 0 && requiredMemoryGB <= memoryBudgetGB; + message = resolvableByUnload + ? `Cannot load ${modelName} (~${requiredStr} GB) while other models are loaded. ` + + `Total would be ~${totalStr} GB, exceeding your device's ~${budgetStr} GB safe limit (60% of RAM). ` + + `Unload the other model first, or choose a smaller model.` + : `${modelName} requires ~${requiredStr} GB which exceeds your device's ~${budgetStr} GB safe limit (60% of RAM). ` + + `This model is too large for your device. Choose a smaller model.`; } else if (totalRequiredMemoryGB > warningThresholdGB) { severity = 'warning'; canLoad = true; @@ -192,6 +197,7 @@ export async function checkMemoryForModel( totalRequiredMemoryGB, remainingAfterLoadGB: remainingBudgetGB, message, + resolvableByUnload, }; } @@ -267,5 +273,6 @@ export async function checkMemoryForDualModel( totalRequiredMemoryGB: totalRequiredGB, remainingAfterLoadGB: remainingBudgetGB, message, + resolvableByUnload: false, }; } diff --git a/src/services/activeModelService/types.ts b/src/services/activeModelService/types.ts index 7e9cada0..189da7ea 100644 --- a/src/services/activeModelService/types.ts +++ b/src/services/activeModelService/types.ts @@ -14,6 +14,14 @@ export interface MemoryCheckResult { totalRequiredMemoryGB: number; remainingAfterLoadGB: number; message: string; + /** + * True when a critical block is caused by OTHER models already being loaded, + * i.e. unloading them would free enough room for this one. When false on a + * critical result, the model is too large for the device even on its own and + * unloading cannot help. The UI uses this to decide whether to offer an + * "unload others & load" recovery action instead of a dead-end alert. + */ + resolvableByUnload: boolean; } export interface ActiveModelInfo {