diff --git a/CHANGELOG.md b/CHANGELOG.md index 2531fc0b..41d586f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Fixed + +- Fix AI chat hanging the app during streaming, schema fetch, and conversation loading (#735) + ## [0.31.4] - 2026-04-14 ### Added diff --git a/TablePro/Core/AI/AIProviderFactory.swift b/TablePro/Core/AI/AIProviderFactory.swift index 5b6a6a9e..88994a95 100644 --- a/TablePro/Core/AI/AIProviderFactory.swift +++ b/TablePro/Core/AI/AIProviderFactory.swift @@ -6,6 +6,7 @@ // import Foundation +import os /// Factory for creating AI provider instances enum AIProviderFactory { @@ -16,46 +17,50 @@ enum AIProviderFactory { let config: AIProviderConfig } - private static var cachedProviders: [UUID: (apiKey: String?, provider: AIProvider)] = [:] + private static let cacheLock = OSAllocatedUnfairLock( + initialState: [UUID: (apiKey: String?, provider: AIProvider)]() + ) /// Create or return a cached AI provider for the given configuration static func createProvider( for config: AIProviderConfig, apiKey: String? ) -> AIProvider { - if let cached = cachedProviders[config.id], cached.apiKey == apiKey { - return cached.provider - } + cacheLock.withLock { cache in + if let cached = cache[config.id], cached.apiKey == apiKey { + return cached.provider + } - let provider: AIProvider - switch config.type { - case .claude: - provider = AnthropicProvider( - endpoint: config.endpoint, - apiKey: apiKey ?? "" - ) - case .gemini: - provider = GeminiProvider( - endpoint: config.endpoint, - apiKey: apiKey ?? "" - ) - case .openAI, .openRouter, .ollama, .custom: - provider = OpenAICompatibleProvider( - endpoint: config.endpoint, - apiKey: apiKey, - providerType: config.type - ) + let provider: AIProvider + switch config.type { + case .claude: + provider = AnthropicProvider( + endpoint: config.endpoint, + apiKey: apiKey ?? "" + ) + case .gemini: + provider = GeminiProvider( + endpoint: config.endpoint, + apiKey: apiKey ?? "" + ) + case .openAI, .openRouter, .ollama, .custom: + provider = OpenAICompatibleProvider( + endpoint: config.endpoint, + apiKey: apiKey, + providerType: config.type + ) + } + cache[config.id] = (apiKey, provider) + return provider } - cachedProviders[config.id] = (apiKey, provider) - return provider } static func invalidateCache() { - cachedProviders.removeAll() + cacheLock.withLock { $0.removeAll() } } static func invalidateCache(for configID: UUID) { - cachedProviders.removeValue(forKey: configID) + cacheLock.withLock { $0.removeValue(forKey: configID) } } static func resolveProvider( diff --git a/TablePro/Core/AI/AISchemaContext.swift b/TablePro/Core/AI/AISchemaContext.swift index 15acb2c2..fd48ab68 100644 --- a/TablePro/Core/AI/AISchemaContext.swift +++ b/TablePro/Core/AI/AISchemaContext.swift @@ -6,20 +6,14 @@ // import Foundation -import os import TableProPluginKit /// Builds schema context for AI system prompts struct AISchemaContext { - private static let logger = Logger( - subsystem: "com.TablePro", - category: "AISchemaContext" - ) - // MARK: - Public /// Build a system prompt including database context - @MainActor static func buildSystemPrompt( + static func buildSystemPrompt( databaseType: DatabaseType, databaseName: String, tables: [TableInfo], @@ -28,7 +22,9 @@ struct AISchemaContext { currentQuery: String?, queryResults: String?, settings: AISettings, - identifierQuote: String = "\"" + identifierQuote: String = "\"", + editorLanguage: EditorLanguage, + queryLanguageName: String ) -> String { var parts: [String] = [] @@ -56,7 +52,7 @@ struct AISchemaContext { if settings.includeCurrentQuery, let query = currentQuery, !query.isEmpty { - let lang = PluginManager.shared.editorLanguage(for: databaseType).codeBlockTag + let lang = editorLanguage.codeBlockTag parts.append("\n## Current Query\n```\(lang)\n\(query)\n```") } @@ -66,11 +62,9 @@ struct AISchemaContext { parts.append("\n## Recent Query Results\n\(results)") } - let editorLang = PluginManager.shared.editorLanguage(for: databaseType) - let langName = PluginManager.shared.queryLanguageName(for: databaseType) - let langTag = editorLang.codeBlockTag + let langTag = editorLanguage.codeBlockTag - switch editorLang { + switch editorLanguage { case .sql: parts.append( "\nProvide SQL queries appropriate for" @@ -82,10 +76,10 @@ struct AISchemaContext { ) default: parts.append( - "\nProvide \(langName) queries using `\(langTag)` fenced code blocks." + "\nProvide \(queryLanguageName) queries using `\(langTag)` fenced code blocks." ) parts.append( - "Use \(langName) syntax, not SQL." + "Use \(queryLanguageName) syntax, not SQL." ) } diff --git a/TablePro/Core/AI/AnthropicProvider.swift b/TablePro/Core/AI/AnthropicProvider.swift index 8c95c398..6ef1f9ae 100644 --- a/TablePro/Core/AI/AnthropicProvider.swift +++ b/TablePro/Core/AI/AnthropicProvider.swift @@ -60,16 +60,36 @@ final class AnthropicProvider: AIProvider { guard line.hasPrefix("data: ") else { continue } let jsonString = String(line.dropFirst(6)) - guard jsonString != "[DONE]" else { break } - - if let text = parseContentBlockDelta(jsonString) { - continuation.yield(.text(text)) - } - if let tokens = parseInputTokens(jsonString) { - inputTokens = tokens - } - if let tokens = parseOutputTokens(jsonString) { - outputTokens = tokens + guard jsonString != "[DONE]", + let data = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], + let type = json["type"] as? String + else { continue } + + switch type { + case "content_block_delta": + if let delta = json["delta"] as? [String: Any], + let text = delta["text"] as? String { + continuation.yield(.text(text)) + } + case "message_start": + if let message = json["message"] as? [String: Any], + let usage = message["usage"] as? [String: Any], + let tokens = usage["input_tokens"] as? Int { + inputTokens = tokens + } + case "message_delta": + if let usage = json["usage"] as? [String: Any], + let tokens = usage["output_tokens"] as? Int { + outputTokens = tokens + } + case "error": + if let errorObj = json["error"] as? [String: Any], + let message = errorObj["message"] as? String { + throw AIProviderError.streamingFailed(message) + } + default: + break } } @@ -196,44 +216,4 @@ final class AnthropicProvider: AIProvider { request.httpBody = try JSONSerialization.data(withJSONObject: body) return request } - - private func parseContentBlockDelta(_ jsonString: String) -> String? { - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], - let type = json["type"] as? String, - type == "content_block_delta", - let delta = json["delta"] as? [String: Any], - let text = delta["text"] as? String - else { - return nil - } - return text - } - - private func parseInputTokens(_ jsonString: String) -> Int? { - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], - let type = json["type"] as? String, - type == "message_start", - let message = json["message"] as? [String: Any], - let usage = message["usage"] as? [String: Any], - let inputTokens = usage["input_tokens"] as? Int - else { - return nil - } - return inputTokens - } - - private func parseOutputTokens(_ jsonString: String) -> Int? { - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], - let type = json["type"] as? String, - type == "message_delta", - let usage = json["usage"] as? [String: Any], - let outputTokens = usage["output_tokens"] as? Int - else { - return nil - } - return outputTokens - } } diff --git a/TablePro/Core/AI/OpenAICompatibleProvider.swift b/TablePro/Core/AI/OpenAICompatibleProvider.swift index f31b1192..907645e0 100644 --- a/TablePro/Core/AI/OpenAICompatibleProvider.swift +++ b/TablePro/Core/AI/OpenAICompatibleProvider.swift @@ -63,38 +63,52 @@ final class OpenAICompatibleProvider: AIProvider { for try await line in bytes.lines { if Task.isCancelled { break } + let jsonString: String if self.providerType == .ollama { // Ollama: raw newline-delimited JSON (no SSE "data: " prefix) guard !line.isEmpty else { continue } Self.logger.debug("Ollama stream line: \(line.prefix(200), privacy: .public)") - - if let text = self.parseChatCompletionDelta(line) { - continuation.yield(.text(text)) - } - if let usage = self.parseUsageFromChunk(line) { - inputTokens = usage.inputTokens - outputTokens = usage.outputTokens - } - // Ollama signals completion with "done":true - if let data = line.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any], - json["done"] as? Bool == true - { - break - } + jsonString = line } else { // OpenAI/OpenRouter/Custom: SSE with "data: " prefix guard line.hasPrefix("data: ") else { continue } - let jsonString = String(line.dropFirst(6)) - guard jsonString != "[DONE]" else { break } - - if let text = self.parseChatCompletionDelta(jsonString) { - continuation.yield(.text(text)) - } - if let usage = self.parseUsageFromChunk(jsonString) { - inputTokens = usage.inputTokens - outputTokens = usage.outputTokens - } + let payload = String(line.dropFirst(6)) + guard payload != "[DONE]" else { break } + jsonString = payload + } + + // Single JSON parse per SSE line + guard let data = jsonString.data(using: .utf8), + let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] + else { continue } + + // Text extraction + if let choices = json["choices"] as? [[String: Any]], + let delta = choices.first?["delta"] as? [String: Any], + let content = delta["content"] as? String { + continuation.yield(.text(content)) + } else if let message = json["message"] as? [String: Any], + let content = message["content"] as? String, + !content.isEmpty { + continuation.yield(.text(content)) + } + + // Usage extraction + if let usage = json["usage"] as? [String: Any], + let promptTokens = usage["prompt_tokens"] as? Int, + let completionTokens = usage["completion_tokens"] as? Int { + inputTokens = promptTokens + outputTokens = completionTokens + } else if let done = json["done"] as? Bool, done, + let promptEval = json["prompt_eval_count"] as? Int, + let evalCount = json["eval_count"] as? Int { + inputTokens = promptEval + outputTokens = evalCount + } + + // Ollama signals completion with "done":true + if json["done"] as? Bool == true { + break } } @@ -250,57 +264,6 @@ final class OpenAICompatibleProvider: AIProvider { return request } - // MARK: - Response Parsing - - private func parseChatCompletionDelta(_ jsonString: String) -> String? { - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) - as? [String: Any] - else { - return nil - } - - // OpenAI/OpenRouter format - if let choices = json["choices"] as? [[String: Any]], - let delta = choices.first?["delta"] as? [String: Any], - let content = delta["content"] as? String { - return content - } - - // Ollama format - if let message = json["message"] as? [String: Any], - let content = message["content"] as? String, - !content.isEmpty { - return content - } - - return nil - } - - private func parseUsageFromChunk(_ jsonString: String) -> AITokenUsage? { - guard let data = jsonString.data(using: .utf8), - let json = try? JSONSerialization.jsonObject(with: data) as? [String: Any] - else { - return nil - } - - // OpenAI/OpenRouter format: usage object in the chunk - if let usage = json["usage"] as? [String: Any], - let promptTokens = usage["prompt_tokens"] as? Int, - let completionTokens = usage["completion_tokens"] as? Int { - return AITokenUsage(inputTokens: promptTokens, outputTokens: completionTokens) - } - - // Ollama format: done=true with eval counts - if let done = json["done"] as? Bool, done, - let promptEval = json["prompt_eval_count"] as? Int, - let evalCount = json["eval_count"] as? Int { - return AITokenUsage(inputTokens: promptEval, outputTokens: evalCount) - } - - return nil - } - // MARK: - Model Fetching private func fetchOpenAIModels() async throws -> [String] { diff --git a/TablePro/Core/Autocomplete/SQLSchemaProvider.swift b/TablePro/Core/Autocomplete/SQLSchemaProvider.swift index 55ce1ff3..dce5ff97 100644 --- a/TablePro/Core/Autocomplete/SQLSchemaProvider.swift +++ b/TablePro/Core/Autocomplete/SQLSchemaProvider.swift @@ -165,23 +165,26 @@ actor SQLSchemaProvider { let dbType = connection.type let dbName = connection.database let capturedTables = tables - let idQuote = await MainActor.run { - PluginManager.shared.sqlDialect(for: dbType)?.identifierQuote ?? "\"" + let (idQuote, editorLanguage, queryLanguageName) = await MainActor.run { + let quote = PluginManager.shared.sqlDialect(for: dbType)?.identifierQuote ?? "\"" + let lang = PluginManager.shared.editorLanguage(for: dbType) + let langName = PluginManager.shared.queryLanguageName(for: dbType) + return (quote, lang, langName) } - return await MainActor.run { - AISchemaContext.buildSystemPrompt( - databaseType: dbType, - databaseName: dbName, - tables: capturedTables, - columnsByTable: columnsByTable, - foreignKeys: [:], - currentQuery: nil, - queryResults: nil, - settings: settings, - identifierQuote: idQuote - ) - } + return AISchemaContext.buildSystemPrompt( + databaseType: dbType, + databaseName: dbName, + tables: capturedTables, + columnsByTable: columnsByTable, + foreignKeys: [:], + currentQuery: nil, + queryResults: nil, + settings: settings, + identifierQuote: idQuote, + editorLanguage: editorLanguage, + queryLanguageName: queryLanguageName + ) } // MARK: - Completion Items diff --git a/TablePro/ViewModels/AIChatViewModel.swift b/TablePro/ViewModels/AIChatViewModel.swift index 6331023f..755e9a23 100644 --- a/TablePro/ViewModels/AIChatViewModel.swift +++ b/TablePro/ViewModels/AIChatViewModel.swift @@ -93,6 +93,7 @@ final class AIChatViewModel { /// nonisolated(unsafe) is required because deinit is not @MainActor-isolated, /// so accessing a @MainActor property from deinit requires opting out of isolation. @ObservationIgnored nonisolated(unsafe) private var streamingTask: Task? + @ObservationIgnored private var schemaFetchTask: Task? private var streamingAssistantID: UUID? private var lastUsedFeature: AIFeature = .chat private let chatStorage = AIChatStorage.shared @@ -142,8 +143,10 @@ final class AIChatViewModel { isStreaming = false // Remove empty assistant placeholder left by cancelled stream - if let last = messages.last, last.role == .assistant, last.content.isEmpty { - messages.removeLast() + if let assistantID = streamingAssistantID, + let idx = messages.firstIndex(where: { $0.id == assistantID }), + messages[idx].content.isEmpty { + messages.remove(at: idx) } streamingAssistantID = nil persistCurrentConversation() @@ -211,12 +214,16 @@ final class AIChatViewModel { /// Load saved conversations from disk func loadConversations() { - Task { - let loaded = await chatStorage.loadAll() - conversations = loaded - if let mostRecent = loaded.first { - activeConversationID = mostRecent.id - messages = mostRecent.messages + let storage = chatStorage + Task.detached(priority: .utility) { [weak self] in + let loaded = await storage.loadAll() + await MainActor.run { + guard let self else { return } + self.conversations = loaded + if let mostRecent = loaded.first { + self.activeConversationID = mostRecent.id + self.messages = mostRecent.messages + } } } } @@ -245,6 +252,8 @@ final class AIChatViewModel { func clearSessionData() { streamingTask?.cancel() streamingTask = nil + schemaFetchTask?.cancel() + schemaFetchTask = nil schemaProvider = nil connection = nil tables = [] @@ -315,6 +324,19 @@ final class AIChatViewModel { } private func startStreaming(feature: AIFeature) { + // Cancel any in-flight stream before starting a new one + if streamingTask != nil { + streamingTask?.cancel() + streamingTask = nil + if let id = streamingAssistantID, + let idx = messages.firstIndex(where: { $0.id == id }), + messages[idx].content.isEmpty { + messages.remove(at: idx) + } + streamingAssistantID = nil + isStreaming = false + } + lastUsedFeature = feature lastMessageFailed = false @@ -354,12 +376,11 @@ final class AIChatViewModel { isStreaming = true - streamingTask = Task { [weak self] in - guard let self else { return } + // Capture value types on main actor before detaching + let chatMessages = Array(messages.dropLast()) + streamingTask = Task.detached(priority: .userInitiated) { [weak self] in do { - // Exclude the empty assistant placeholder from sent messages - let chatMessages = Array(self.messages.dropLast()) let stream = resolved.provider.streamChat( messages: chatMessages, model: resolved.model, @@ -367,36 +388,46 @@ final class AIChatViewModel { ) for try await event in stream { - guard !Task.isCancelled, - let idx = self.messages.firstIndex(where: { $0.id == assistantID }) - else { break } - switch event { - case .text(let token): - self.messages[idx].content += token - case .usage(let usage): - self.messages[idx].usage = usage + guard !Task.isCancelled else { break } + await MainActor.run { [weak self] in + guard let self, + let idx = self.messages.firstIndex(where: { $0.id == assistantID }) + else { return } + switch event { + case .text(let token): + self.messages[idx].content += token + case .usage(let usage): + self.messages[idx].usage = usage + } } } - self.isStreaming = false - self.streamingTask = nil - self.streamingAssistantID = nil - self.persistCurrentConversation() + guard !Task.isCancelled else { return } + await MainActor.run { [weak self] in + guard let self else { return } + self.isStreaming = false + self.streamingTask = nil + self.streamingAssistantID = nil + self.persistCurrentConversation() + } } catch { - if !Task.isCancelled { - Self.logger.error("Streaming failed: \(error.localizedDescription)") - self.lastMessageFailed = true - self.errorMessage = error.localizedDescription - - // Remove empty assistant message on error - if let idx = self.messages.firstIndex(where: { $0.id == assistantID }), - self.messages[idx].content.isEmpty { - self.messages.remove(at: idx) + await MainActor.run { [weak self] in + guard let self else { return } + if !Task.isCancelled { + Self.logger.error("Streaming failed: \(error.localizedDescription)") + self.lastMessageFailed = true + self.errorMessage = error.localizedDescription + + // Remove empty assistant message on error + if let idx = self.messages.firstIndex(where: { $0.id == assistantID }), + self.messages[idx].content.isEmpty { + self.messages.remove(at: idx) + } } + self.isStreaming = false + self.streamingTask = nil + self.streamingAssistantID = nil } - self.isStreaming = false - self.streamingTask = nil - self.streamingAssistantID = nil } } } @@ -419,6 +450,8 @@ final class AIChatViewModel { guard let connection else { return nil } let idQuote = PluginManager.shared.sqlDialect(for: connection.type)?.identifierQuote ?? "\"" + let editorLanguage = PluginManager.shared.editorLanguage(for: connection.type) + let queryLanguageName = PluginManager.shared.queryLanguageName(for: connection.type) return AISchemaContext.buildSystemPrompt( databaseType: connection.type, databaseName: connection.database, @@ -428,58 +461,81 @@ final class AIChatViewModel { currentQuery: settings.includeCurrentQuery ? currentQuery : nil, queryResults: settings.includeQueryResults ? queryResults : nil, settings: settings, - identifierQuote: idQuote + identifierQuote: idQuote, + editorLanguage: editorLanguage, + queryLanguageName: queryLanguageName ) } // MARK: - Schema Context - func fetchSchemaContext() async { + func fetchSchemaContext() { let settings = AppSettingsManager.shared.ai guard settings.includeSchema, let connection, let driver = DatabaseManager.shared.driver(for: connection.id) else { return } + schemaFetchTask?.cancel() + let tablesToFetch = Array(tables.prefix(settings.maxSchemaTables)) - var columns: [String: [ColumnInfo]] = [:] - var foreignKeys: [String: [ForeignKeyInfo]] = [:] - - await withTaskGroup(of: (String, [ColumnInfo]).self) { group in - for table in tablesToFetch { - group.addTask { [schemaProvider] in - if let schemaProvider { - let cached = await schemaProvider.getColumns(for: table.name) - if !cached.isEmpty { - return (table.name, cached) - } + let capturedProvider = schemaProvider + + schemaFetchTask = Task.detached(priority: .userInitiated) { [weak self] in + var columns: [String: [ColumnInfo]] = [:] + var foreignKeys: [String: [ForeignKeyInfo]] = [:] + + let fetchColumns: @Sendable (TableInfo) async -> (String, [ColumnInfo]) = { table in + if let provider = capturedProvider { + let cached = await provider.getColumns(for: table.name) + if !cached.isEmpty { + return (table.name, cached) } - do { - let cols = try await driver.fetchColumns(table: table.name) - return (table.name, cols) - } catch { - Self.logger.debug("Schema context: failed to fetch columns for '\(table.name)'") - return (table.name, []) + } + do { + let cols = try await driver.fetchColumns(table: table.name) + return (table.name, cols) + } catch { + return (table.name, []) + } + } + + let concurrencyLimit = 4 + await withTaskGroup(of: (String, [ColumnInfo]).self) { group in + var pending = tablesToFetch.makeIterator() + + // Seed initial batch + for _ in 0..= 0.1 else { return } + lastAutoScrollTime = now + scrollToBottom(proxy: proxy) + } + .onChange(of: viewModel.isStreaming) { _, newValue in + if !newValue, !isUserScrolledUp { + scrollToBottom(proxy: proxy) + } + } } if isUserScrolledUp, let proxy = scrollProxy { @@ -314,10 +321,6 @@ struct AIChatPanelView: View { } } - // MARK: - Schema Context - - private static let logger = Logger(subsystem: "com.TablePro", category: "AIChatPanelView") - // MARK: - Helpers private func scrollToBottom(proxy: ScrollViewProxy) {