Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions __tests__/integration/models/activeModelService.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1009,8 +1009,45 @@ describe('ActiveModelService Integration', () => {
});
});

describe('checkMemoryForModel critical with other models message', () => {
it('includes other models in critical message', async () => {
describe('checkMemoryForModel critical message and resolvableByUnload', () => {
it('marks block resolvable and mentions other models when the model fits alone but the total is over budget', async () => {
  // Budget math for this scenario: an 8GB device yields a 4.8GB budget.
  // The 2GB text model needs 2GB*1.5=3.0GB, which fits alone (<=4.8GB), but the
  // already-loaded 2GB image model also needs 3.0GB -> 6.0GB total > 4.8GB.
  const GIB = 1024 * 1024 * 1024;

  useAppStore.setState({
    downloadedModels: [
      createDownloadedModel({ id: 'fits-alone-text', fileSize: 2 * GIB }),
    ],
    downloadedImageModels: [
      createONNXImageModel({ id: 'img-already', size: 2 * GIB }),
    ],
    settings: { imageThreads: 4 } as any,
  });

  // The image model must be loaded first so it counts as "other loaded" memory.
  mockLocalDreamService.isModelLoaded.mockResolvedValue(true);
  await activeModelService.loadImageModel('img-already');

  const eightGigDevice = createDeviceInfo({ totalMemory: 8 * GIB });
  mockHardwareService.getDeviceInfo.mockResolvedValue(eightGigDevice);

  const { severity, canLoad, resolvableByUnload, message } =
    await activeModelService.checkMemoryForModel('fits-alone-text', 'text');

  expect(severity).toBe('critical');
  expect(canLoad).toBe(false);
  // Unloading the loaded image model would free enough room, so the UI can
  // offer a one-tap recovery instead of a dead-end alert.
  expect(resolvableByUnload).toBe(true);
  expect(message).toContain('other models are loaded');
});

it('marks block NOT resolvable when the model is too large even on its own', async () => {
// 8GB device -> 4.8GB budget. Text 6GB*1.5=9GB exceeds budget by itself,
// so unloading other models cannot help.
const textModel = createDownloadedModel({
id: 'huge-text',
fileSize: 6 * 1024 * 1024 * 1024,
Expand All @@ -1025,11 +1062,9 @@ describe('ActiveModelService Integration', () => {
settings: { imageThreads: 4 } as any,
});

// Load image model
mockLocalDreamService.isModelLoaded.mockResolvedValue(true);
await activeModelService.loadImageModel('img-already');

// 8GB device - 6GB text * 1.5 = 9GB + image model memory = way over budget
mockHardwareService.getDeviceInfo.mockResolvedValue(
createDeviceInfo({ totalMemory: 8 * 1024 * 1024 * 1024 })
);
Expand All @@ -1038,7 +1073,8 @@ describe('ActiveModelService Integration', () => {

expect(result.severity).toBe('critical');
expect(result.canLoad).toBe(false);
expect(result.message).toContain('other models are loaded');
expect(result.resolvableByUnload).toBe(false);
expect(result.message).toContain('too large for your device');
});
});

Expand Down
77 changes: 77 additions & 0 deletions __tests__/rntl/components/ModelSelectorModal.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ jest.mock('../../../src/services', () => ({
loadImageModel: jest.fn().mockResolvedValue(undefined),
unloadImageModel: jest.fn().mockResolvedValue(undefined),
unloadTextModel: jest.fn().mockResolvedValue(undefined),
unloadAllModels: jest.fn().mockResolvedValue({ textUnloaded: true, imageUnloaded: true }),
checkMemoryForModel: jest
.fn()
.mockResolvedValue({ canLoad: true, severity: 'safe', message: '', resolvableByUnload: false }),
},
llmService: {
isModelLoaded: jest.fn(() => false),
Expand Down Expand Up @@ -78,6 +82,15 @@ describe('ModelSelectorModal', () => {

beforeEach(() => {
jest.clearAllMocks();
// clearAllMocks wipes implementations; restore the defaults the component relies on.
activeModelService.checkMemoryForModel.mockResolvedValue({
canLoad: true,
severity: 'safe',
message: '',
resolvableByUnload: false,
});
activeModelService.loadImageModel.mockResolvedValue(undefined);
activeModelService.unloadAllModels.mockResolvedValue({ textUnloaded: true, imageUnloaded: true });
mockUseAppStore.mockReturnValue({
downloadedModels: [
{
Expand Down Expand Up @@ -652,6 +665,70 @@ describe('ModelSelectorModal', () => {
expect(activeModelService.loadImageModel).toHaveBeenCalledWith('img1');
});

// Shared setup for the blocked-load recovery tests below.
// Store stub: one downloaded image model ("SD Model"), no text models, and no
// image model currently active.
const blockedImageStore = () => {
  const sdModel = { id: 'img1', name: 'SD Model', size: 2000000000, style: 'Creative' };
  return mockUseAppStore.mockReturnValue({
    downloadedModels: [],
    downloadedImageModels: [sdModel],
    activeImageModelId: null,
  });
};
// Makes checkMemoryForModel resolve to a critical, non-loadable verdict with
// the given recoverability flag and message.
const mockMemoryCheck = (resolvableByUnload: boolean, message: string) => {
  const blockedVerdict = {
    canLoad: false,
    severity: 'critical',
    resolvableByUnload,
    message,
  };
  return activeModelService.checkMemoryForModel.mockResolvedValue(blockedVerdict);
};
// Renders the modal with the image tab selected as the initial tab.
const renderImageTab = () => {
  return render(<ModelSelectorModal {...defaultProps} initialTab="image" />);
};

it('offers an unload-and-load recovery instead of loading when blocked by other models', async () => {
  blockedImageStore();
  mockMemoryCheck(true, 'Cannot load SD Model while other models are loaded.');

  const view = renderImageTab();
  await act(async () => {
    fireEvent.press(view.getByText('SD Model'));
  });

  // Selecting the model must surface the recovery action rather than load it.
  expect(view.getByText('Unload others & load')).toBeTruthy();
  expect(activeModelService.loadImageModel).not.toHaveBeenCalled();
});

it('unloads other models then loads the image model when recovery action is pressed', async () => {
  mockMemoryCheck(true, 'Cannot load SD Model while other models are loaded.');
  blockedImageStore();

  const { getByText } = renderImageTab();
  // First press selects the model and raises the "Not enough memory" alert.
  await act(async () => {
    fireEvent.press(getByText('SD Model'));
  });
  // Second press confirms the recovery action offered by that alert.
  await act(async () => {
    fireEvent.press(getByText('Unload others & load'));
  });

  expect(activeModelService.unloadAllModels).toHaveBeenCalled();
  expect(activeModelService.loadImageModel).toHaveBeenCalledWith('img1');

  // The recovery only frees memory if the unload happens BEFORE the load, so
  // pin the ordering instead of merely checking that both calls occurred.
  const [unloadOrder] = activeModelService.unloadAllModels.mock.invocationCallOrder;
  const [loadOrder] = activeModelService.loadImageModel.mock.invocationCallOrder;
  expect(unloadOrder).toBeLessThan(loadOrder);
});

it('shows an informational alert with no recovery when the model is too large to load at all', async () => {
  blockedImageStore();
  mockMemoryCheck(false, 'SD Model is too large for your device. Choose a smaller model.');

  const view = renderImageTab();
  await act(async () => {
    fireEvent.press(view.getByText('SD Model'));
  });

  // A non-resolvable block must neither load nor unload anything.
  expect(activeModelService.loadImageModel).not.toHaveBeenCalled();
  expect(activeModelService.unloadAllModels).not.toHaveBeenCalled();
  // Only the informational alert is shown; no recovery button is offered.
  expect(view.queryByText('Unload others & load')).toBeNull();
  expect(view.getByText('Model too large')).toBeTruthy();
});

it('does not call loadImageModel when pressing the currently active image model', async () => {
mockUseAppStore.mockReturnValue({
downloadedModels: [],
Expand Down
115 changes: 115 additions & 0 deletions __tests__/unit/hooks/useChatModelActions.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ jest.mock('../../../src/services/activeModelService', () => ({
activeModelService: {
loadTextModel: jest.fn(),
unloadTextModel: jest.fn(),
unloadAllModels: jest.fn(),
checkMemoryForModel: jest.fn(),
getActiveModels: jest.fn(),
},
Expand All @@ -39,6 +40,7 @@ const { llmService } = require('../../../src/services/llm');

const mockLoadTextModel = activeModelService.loadTextModel as jest.Mock;
const mockUnloadTextModel = activeModelService.unloadTextModel as jest.Mock;
const mockUnloadAllModels = activeModelService.unloadAllModels as jest.Mock;
const mockCheckMemoryForModel = activeModelService.checkMemoryForModel as jest.Mock;
const mockGetActiveModels = activeModelService.getActiveModels as jest.Mock;
const mockGetMultimodalSupport = llmService.getMultimodalSupport as jest.Mock;
Expand Down Expand Up @@ -104,6 +106,36 @@ function makeDeps(overrides: Partial<any> = {}) {
};
}

/**
 * Queues one critical, non-loadable memory-check result flagged as resolvable
 * by unloading, plus one multimodal-support result with vision disabled.
 */
function mockResolvableBlock() {
  const resolvableCriticalBlock = {
    canLoad: false,
    severity: 'critical',
    message: 'other models are loaded',
    resolvableByUnload: true,
  };
  mockCheckMemoryForModel.mockResolvedValueOnce(resolvableCriticalBlock);
  mockGetMultimodalSupport.mockReturnValueOnce({ vision: false });
}

/**
 * Finds an alert button by label from the first setAlertState call.
 *
 * Returns `undefined` when no alert has been raised yet, when the first alert
 * carries no `buttons` array, or when no button has the given label. This lets
 * callers assert absence with `toBeUndefined()` / presence with `toBeDefined()`
 * and get a readable failure instead of an opaque TypeError.
 *
 * @param deps Dependency bag whose `setAlertState` is a Jest mock function.
 * @param text Exact button label to look for.
 * @returns The matching button object, or `undefined` if there is none.
 */
function alertButton(deps: any, text: string) {
  // calls[0] is the first invocation; [0] is its first argument (the alert state).
  const firstAlert = deps.setAlertState.mock.calls[0]?.[0];
  return firstAlert?.buttons?.find((b: any) => b.text === text);
}

/**
 * Drains the waitForRenderFrame timer together with the unload→load promise
 * chain. Because the two entry points sequence the timer and the unload
 * promise differently (startLoad-then-unload vs unload-then-proceed), each
 * round advances the fake timers and then yields to the microtask queue twice,
 * so either ordering gets flushed.
 */
async function flushRecoveryChain() {
  const ROUNDS = 4;
  const TIMER_STEP_MS = 400;

  let remaining = ROUNDS;
  while (remaining > 0) {
    jest.advanceTimersByTime(TIMER_STEP_MS);
    // Two microtask turns per round, awaited strictly one after the other.
    await Promise.resolve();
    await Promise.resolve();
    remaining -= 1;
  }
}

// ─────────────────────────────────────────────
// initiateModelLoad
// ─────────────────────────────────────────────
Expand Down Expand Up @@ -260,6 +292,47 @@ describe('initiateModelLoad — Load Anyway button', () => {
jest.useRealTimers();
});

it('offers "Unload others & load" (not "Load Anyway") when the block is resolvable by unloading', async () => {
  jest.useFakeTimers();
  try {
    mockResolvableBlock();
    mockUnloadAllModels.mockResolvedValueOnce({ textUnloaded: true, imageUnloaded: false });
    mockLoadTextModel.mockResolvedValueOnce(undefined);

    const deps = makeDeps();
    await initiateModelLoad(deps, false);

    const unloadBtn = alertButton(deps, 'Unload others & load');
    expect(unloadBtn).toBeDefined();
    // No "Load Anyway" override on a resolvable block — recovery is the safe default.
    expect(alertButton(deps, 'Load Anyway')).toBeUndefined();

    unloadBtn.onPress();
    expect(deps.setIsModelLoading).toHaveBeenCalledWith(true);
    await flushRecoveryChain();

    expect(mockUnloadAllModels).toHaveBeenCalled();
    expect(mockLoadTextModel).toHaveBeenCalled();
  } finally {
    // Restore real timers even when an assertion above throws, so fake timers
    // never leak into the tests that run after this one.
    jest.useRealTimers();
  }
});

it('still loads (does not leave UI stuck) when unloadAllModels rejects during recovery', async () => {
  jest.useFakeTimers();
  try {
    mockResolvableBlock();
    mockUnloadAllModels.mockRejectedValueOnce(new Error('unload failed'));
    mockLoadTextModel.mockResolvedValueOnce(undefined);

    const deps = makeDeps();
    await initiateModelLoad(deps, false);

    alertButton(deps, 'Unload others & load').onPress();
    await flushRecoveryChain();

    // The rejection is caught and logged; the load proceeds rather than hanging.
    expect(mockUnloadAllModels).toHaveBeenCalled();
    expect(mockLoadTextModel).toHaveBeenCalled();
  } finally {
    // Restore real timers even when an assertion above throws, so fake timers
    // never leak into the tests that run after this one.
    jest.useRealTimers();
  }
});

it('doLoadTextModel does not post system message when showGenerationDetails=false', async () => {
jest.useFakeTimers();
mockCheckMemoryForModel.mockResolvedValueOnce({ canLoad: false, message: 'OOM', severity: 'critical' });
Expand Down Expand Up @@ -329,6 +402,48 @@ describe('handleModelSelectFn — Load Anyway button', () => {
expect(deps.setIsModelLoading).toHaveBeenCalled();
});

it('offers "Unload others & load" and unloads then loads when block is resolvable', async () => {
  jest.useFakeTimers();
  try {
    mockResolvableBlock();
    mockUnloadAllModels.mockResolvedValueOnce({ textUnloaded: true, imageUnloaded: false });
    mockLoadTextModel.mockResolvedValueOnce(undefined);

    const deps = makeDeps();
    await handleModelSelectFn(deps, createDownloadedModel({ id: 'model-z' }));

    const unloadBtn = alertButton(deps, 'Unload others & load');
    expect(unloadBtn).toBeDefined();
    expect(alertButton(deps, 'Load Anyway')).toBeUndefined();

    unloadBtn.onPress();
    // Loading indicator is shown before unloading begins so the UI doesn't
    // look frozen during the (potentially multi-second) unload.
    expect(deps.setIsModelLoading).toHaveBeenCalledWith(true);
    await flushRecoveryChain();

    expect(mockUnloadAllModels).toHaveBeenCalled();
    expect(mockLoadTextModel).toHaveBeenCalledWith('model-z');
  } finally {
    // Restore real timers even when an assertion above throws, so fake timers
    // never leak into the tests that run after this one.
    jest.useRealTimers();
  }
});

it('still loads (does not leave UI stuck) when unloadAllModels rejects', async () => {
  jest.useFakeTimers();
  try {
    mockResolvableBlock();
    mockUnloadAllModels.mockRejectedValueOnce(new Error('unload failed'));
    mockLoadTextModel.mockResolvedValueOnce(undefined);

    const deps = makeDeps();
    await handleModelSelectFn(deps, createDownloadedModel({ id: 'model-z' }));

    alertButton(deps, 'Unload others & load').onPress();
    await flushRecoveryChain();

    // Rejection is caught and logged; load still proceeds rather than hanging.
    expect(mockUnloadAllModels).toHaveBeenCalled();
    expect(mockLoadTextModel).toHaveBeenCalledWith('model-z');
  } finally {
    // Restore real timers even when an assertion above throws, so fake timers
    // never leak into the tests that run after this one.
    jest.useRealTimers();
  }
});

it('executes Load Anyway callback in low memory warning', async () => {
mockCheckMemoryForModel.mockResolvedValueOnce({ canLoad: true, severity: 'warning', message: 'Low memory' });
mockLoadTextModel.mockResolvedValueOnce(undefined);
Expand Down
53 changes: 50 additions & 3 deletions src/components/ModelSelectorModal/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { useTheme, useThemedStyles } from '../../theme';
import { useAppStore, useRemoteServerStore } from '../../stores';
import { DownloadedModel, ONNXImageModel, RemoteModel } from '../../types';
import { activeModelService, llmService, remoteServerManager } from '../../services';
import { CustomAlert, AlertState, initialAlertState, showAlert } from '../CustomAlert';
import { CustomAlert, AlertState, initialAlertState, showAlert, showLoadingAlert } from '../CustomAlert';
import { createAllStyles } from './styles';
import { TextTab } from './TextTab';
import { ImageTab } from './ImageTab';
Expand Down Expand Up @@ -97,8 +97,9 @@ export const ModelSelectorModal: React.FC<ModelSelectorModalProps> = ({
// Vision-language models (supportsVision) are text models and belong in the text tab.
const remoteVisionModels = useMemo(() => [], []);

const handleSelectImageModel = async (model: ONNXImageModel) => {
if (activeImageModelId === model.id) return;
// Performs the actual image-model load and post-load wiring. Shared by the
// direct path and the "unload others & load" recovery path.
const loadImageModelNow = async (model: ONNXImageModel) => {
setIsLoadingImage(true);
try {
await activeModelService.loadImageModel(model.id);
Expand All @@ -114,6 +115,52 @@ export const ModelSelectorModal: React.FC<ModelSelectorModalProps> = ({
}
};

// Recovery path for a memory-blocked image load: unloads whatever is currently
// loaded (text and/or image) to free RAM, then loads the requested image model.
// Used by the recovery action so the user is never left at a dead end when
// another model is occupying memory.
const unloadOthersAndLoadImageModel = async (model: ONNXImageModel) => {
// Swap the alert state to the loading variant while the unload is in flight.
setAlertState(showLoadingAlert('Freeing memory', 'Unloading other models…'));
try {
await activeModelService.unloadAllModels();
} catch (error) {
// Best effort: an unload failure is only logged, and the load below is
// still attempted.
logger.error('Failed to unload models before image load:', error);
} finally {
// Reset the alert state whether or not the unload succeeded.
setAlertState(initialAlertState);
}
// Runs after the alert state has been reset, on both the success and the
// failure path of the unload.
await loadImageModelNow(model);
};

/**
 * Handles a tap on an image model row.
 *
 * - Tapping the already-active model is a no-op.
 * - Otherwise memory is checked first. A block that unloading would resolve
 *   raises an alert offering "Unload others & load"; a block that unloading
 *   cannot resolve raises an informational "Model too large" alert.
 * - If the memory check itself fails, the error is logged and the load is
 *   attempted anyway, so the tap never silently does nothing.
 */
const handleSelectImageModel = async (model: ONNXImageModel) => {
  if (activeImageModelId === model.id) return;

  // Check memory up front so we can offer a one-tap recovery instead of a
  // dead-end "Failed to Load" alert when another model is occupying RAM.
  let memoryCheck;
  try {
    memoryCheck = await activeModelService.checkMemoryForModel(model.id, 'image');
  } catch (error) {
    // The check is advisory: if it fails, fall through to the normal load
    // path rather than letting the rejection escape the press handler.
    logger.error('Memory check failed before image load:', error);
  }

  if (memoryCheck && !memoryCheck.canLoad) {
    if (memoryCheck.resolvableByUnload) {
      setAlertState(
        showAlert('Not enough memory', memoryCheck.message, [
          { text: 'Cancel', style: 'cancel' },
          {
            text: 'Unload others & load',
            style: 'default',
            onPress: () => {
              setAlertState(initialAlertState);
              // Fire-and-forget by design: the alert callback is synchronous.
              void unloadOthersAndLoadImageModel(model);
            },
          },
        ]),
      );
    } else {
      // Too large for this device even on its own — unloading cannot help.
      setAlertState(showAlert('Model too large', memoryCheck.message));
    }
    return;
  }

  await loadImageModelNow(model);
};

const handleUnloadImageModel = async () => {
setIsLoadingImage(true);
try {
Expand Down
Loading
Loading