Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 43 additions & 2 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,16 @@ import Foundation
/// let model = MLXLanguageModel(modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit")
/// ```
public struct MLXLanguageModel: LanguageModel {
/// Custom generation options for MLX models.
///
/// Attach an instance to a `GenerationOptions` value via the
/// `options[custom: MLXLanguageModel.self]` subscript so MLX-specific
/// settings flow through the model-agnostic generation API.
public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions {
/// Additional key-value pairs injected into the chat template rendering context.
public var additionalContext: [String: MLXLMCommon.JSONValue]?

/// Creates MLX-specific generation options.
/// - Parameter additionalContext: Extra values made available to the chat
///   template when the prompt is rendered; `nil` (the default) injects nothing.
public init(additionalContext: [String: MLXLMCommon.JSONValue]? = nil) {
self.additionalContext = additionalContext
}
}

/// The reason the model is unavailable.
public enum UnavailableReason: Sendable, Equatable, Hashable {
/// The model has not been loaded into memory yet.
Expand Down Expand Up @@ -292,6 +302,11 @@ import Foundation
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
let generateParameters = toGenerateParameters(options)

// Extract additional context from custom options
let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self]
.flatMap { $0.additionalContext }
.map { $0.mapValues { $0.toSendable() } }

// Build chat history from full transcript
var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)

Expand All @@ -305,6 +320,7 @@ import Foundation
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: toolSpecs,
additionalContext: additionalContext,
)
let lmInput = try await context.processor.prepare(input: userInput)

Expand Down Expand Up @@ -407,10 +423,15 @@ import Foundation
let generateParameters = toGenerateParameters(options)
let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)

let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self]
.flatMap { $0.additionalContext }
.map { $0.mapValues { $0.toSendable() } }

let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: nil
tools: nil,
additionalContext: additionalContext
)
let lmInput = try await context.processor.prepare(input: userInput)

Expand Down Expand Up @@ -876,10 +897,16 @@ import Foundation
let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil
let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt)

let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self]
.flatMap { $0.additionalContext }
.map { $0.mapValues { $0.toSendable() } }

let userInput = MLXLMCommon.UserInput(
chat: chat,
processing: .init(resize: .init(width: 512, height: 512)),
tools: nil
tools: nil,
additionalContext: additionalContext,
)
let lmInput = try await context.processor.prepare(input: userInput)

Expand Down Expand Up @@ -1120,4 +1147,18 @@ import Foundation
return sampledToken.item(Int.self)
}
}
extension MLXLMCommon.JSONValue {
    /// Recursively converts a `JSONValue` to its primitive Swift equivalent.
    ///
    /// Scalar cases map to `String`/`Int`/`Double`/`Bool`, `.null` maps to
    /// `NSNull`, and container cases are converted element by element.
    func toSendable() -> any Sendable {
        switch self {
        case let .string(value):
            return value
        case let .int(value):
            return value
        case let .double(value):
            return value
        case let .bool(value):
            return value
        case .null:
            return NSNull()
        case let .array(elements):
            return elements.map { $0.toSendable() }
        case let .object(fields):
            return fields.mapValues { $0.toSendable() }
        }
    }
}
#endif // MLX
22 changes: 22 additions & 0 deletions Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,28 @@ import Testing
#expect([Priority.low, Priority.medium, Priority.high].contains(response.content))
}

@Test func withAdditionalContext() async throws {
    let session = LanguageModelSession(model: model)

    var opts = GenerationOptions(
        temperature: 0.7,
        maximumResponseTokens: 32
    )
    // MLX-specific values carried through the generic options API.
    let custom = MLXLanguageModel.CustomGenerationOptions(
        additionalContext: [
            "user_name": .string("Alice"),
            "turn_count": .int(3),
            "verbose": .bool(true),
        ]
    )
    opts[custom: MLXLanguageModel.self] = custom

    let response = try await session.respond(to: "Say hello", options: opts)
    #expect(!response.content.isEmpty)
}

@Test func unavailableForNonexistentModel() async {
let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test")
await model.removeFromCache()
Expand Down
Loading