diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index f4be593..67ff79a 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -183,6 +183,16 @@ import Foundation /// let model = MLXLanguageModel(modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit") /// ``` public struct MLXLanguageModel: LanguageModel { + /// Custom generation options for MLX models. + public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions { + /// Additional key-value pairs injected into the chat template rendering context. + public var additionalContext: [String: MLXLMCommon.JSONValue]? + + public init(additionalContext: [String: MLXLMCommon.JSONValue]? = nil) { + self.additionalContext = additionalContext + } + } + /// The reason the model is unavailable. public enum UnavailableReason: Sendable, Equatable, Hashable { /// The model has not been loaded into memory yet. @@ -292,6 +302,11 @@ import Foundation // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters let generateParameters = toGenerateParameters(options) + // Extract additional context from custom options + let additionalContext: [String: any Sendable]? 
= options[custom: MLXLanguageModel.self] + .flatMap { $0.additionalContext } + .map { $0.mapValues { $0.toSendable() } } + // Build chat history from full transcript var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) @@ -305,6 +320,7 @@ import Foundation chat: chat, processing: .init(resize: .init(width: 512, height: 512)), tools: toolSpecs, + additionalContext: additionalContext, ) let lmInput = try await context.processor.prepare(input: userInput) @@ -407,10 +423,15 @@ import Foundation let generateParameters = toGenerateParameters(options) let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) + let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self] + .flatMap { $0.additionalContext } + .map { $0.mapValues { $0.toSendable() } } + let userInput = MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), - tools: nil + tools: nil, + additionalContext: additionalContext, ) let lmInput = try await context.processor.prepare(input: userInput) @@ -876,10 +897,16 @@ import Foundation let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description) let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt) + + let additionalContext: [String: any Sendable]? 
= options[custom: MLXLanguageModel.self] + .flatMap { $0.additionalContext } + .map { $0.mapValues { $0.toSendable() } } + let userInput = MLXLMCommon.UserInput( chat: chat, processing: .init(resize: .init(width: 512, height: 512)), - tools: nil + tools: nil, + additionalContext: additionalContext, ) let lmInput = try await context.processor.prepare(input: userInput) @@ -1120,4 +1147,18 @@ import Foundation return sampledToken.item(Int.self) } } + extension MLXLMCommon.JSONValue { + /// Recursively converts a `JSONValue` to its primitive Swift equivalent. + func toSendable() -> any Sendable { + switch self { + case .string(let s): return s + case .int(let i): return i + case .double(let d): return d + case .bool(let b): return b + case .null: return NSNull() + case .array(let arr): return arr.map { $0.toSendable() } + case .object(let obj): return obj.mapValues { $0.toSendable() } + } + } + } #endif // MLX diff --git a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift index 82c5f85..aae11b0 100644 --- a/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift +++ b/Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift @@ -220,6 +220,28 @@ import Testing #expect([Priority.low, Priority.medium, Priority.high].contains(response.content)) } + @Test func withAdditionalContext() async throws { + let session = LanguageModelSession(model: model) + + var options = GenerationOptions( + temperature: 0.7, + maximumResponseTokens: 32 + ) + options[custom: MLXLanguageModel.self] = .init( + additionalContext: [ + "user_name": .string("Alice"), + "turn_count": .int(3), + "verbose": .bool(true), + ] + ) + + let response = try await session.respond( + to: "Say hello", + options: options + ) + #expect(!response.content.isEmpty) + } + @Test func unavailableForNonexistentModel() async { let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test") await model.removeFromCache()