diff --git a/app/src/main/java/com/theveloper/pixelplay/data/ai/GeminiModelService.kt b/app/src/main/java/com/theveloper/pixelplay/data/ai/GeminiModelService.kt index 6669918fd..5e96bc870 100644 --- a/app/src/main/java/com/theveloper/pixelplay/data/ai/GeminiModelService.kt +++ b/app/src/main/java/com/theveloper/pixelplay/data/ai/GeminiModelService.kt @@ -21,21 +21,13 @@ class GeminiModelService @Inject constructor( private val workerManager: AiWorkerManager ) { - /** - * Fetches available Gemini models using the provided API key. - * Returns a list of model names that are available for the user. - */ suspend fun fetchAvailableModels(apiKey: String): Result> { return withContext(Dispatchers.IO) { try { if (apiKey.isBlank()) { return@withContext Result.failure(Exception("API Key is required")) } - - // Use a lightweight model to test the API key and fetch available models - // We'll make a request to list models using the Gemini API val response = makeModelsListRequest(apiKey) - Result.success(response) } catch (e: Exception) { Timber.e(e, "Error fetching Gemini models") @@ -47,7 +39,6 @@ class GeminiModelService @Inject constructor( private suspend fun makeModelsListRequest(apiKey: String): List { return withContext(Dispatchers.IO) { try { - // Make HTTP request to Google's Gemini API to list models val url = "https://generativelanguage.googleapis.com/v1beta/models?key=$apiKey" val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection @@ -56,17 +47,19 @@ class GeminiModelService @Inject constructor( connection.readTimeout = 10000 val responseCode = connection.responseCode - if (responseCode == 200) { + val apiModels = if (responseCode == 200) { val response = connection.inputStream.bufferedReader().use { it.readText() } parseModelsResponse(response) - } else { - val errorMessage = connection.errorStream?.bufferedReader()?.use { it.readText() } - Timber.e("Failed to fetch models: $responseCode - $errorMessage") - // Return default models if API call fails - getDefaultModels() - } + } else emptyList() + + val defaults = getDefaultModels() + (apiModels + defaults).distinctBy { it.name }.sortedWith( + compareBy { model -> + val preferred = defaults.map { it.name } + preferred.indexOf(model.name).takeIf { it >= 0 } ?: Int.MAX_VALUE + }.thenBy { it.displayName.lowercase() } + ) } catch (e: Exception) { - Timber.e(e, "Exception fetching models, returning defaults") getDefaultModels() } } @@ -74,52 +67,41 @@ class GeminiModelService @Inject constructor( private fun parseModelsResponse(jsonResponse: String): List { try { - // Parse the JSON response to extract model names - // Expected format: {"models": [{"name": "models/gemini-...", ...}, ...]} val models = mutableListOf() - - // Simple JSON parsing - extract model names val modelPattern = """"name":\s*"(models/[^"]+)"""".toRegex() val matches = modelPattern.findAll(jsonResponse) + val blacklist = listOf("-2.0", "-2.5", "-preview", "customtools", "search", "tuning", "-001", "-002") + val whitelist = listOf("gemini-3.1-pro-preview") + for (match in matches) { val fullName = match.groupValues[1] val modelName = fullName.removePrefix("models/") - // Keep the UI focused on the cheapest/fastest Gemini family: Flash only. - if (isSupportedFlashModel(modelName)) { + val isWhitelisted = whitelist.any { modelName == it } + val hasForbiddenSuffix = blacklist.any { modelName.contains(it) } + val isBlacklisted = hasForbiddenSuffix && !isWhitelisted + + if (!isBlacklisted && + (modelName.startsWith("gemini", ignoreCase = true) || + modelName.startsWith("gemma", ignoreCase = true)) && + !modelName.contains("embedding", ignoreCase = true)) { models.add(GeminiModel( name = modelName, displayName = formatDisplayName(modelName) )) } } - - return if (models.isNotEmpty()) { - sortFlashModels(models) - } else { - getDefaultModels() - } + return models } catch (e: Exception) { - Timber.e(e, "Error parsing models response") - return getDefaultModels() + return emptyList() } } - /** - * Estimates the token count for a piece of text. - * Uses a conservative 4 chars per token rule for non-Gemini providers, - * but we recommend using the specific countTokens method on AiClient for accuracy. - */ fun estimateTokens(text: String): Int { return (text.length / 4).coerceAtLeast(1) } - /** - * High-level method to perform an AI operation. - * Starts a background worker if [runInBackground] is true, - * otherwise executes immediately and returns the result. - */ suspend fun performAiTask( prompt: String, type: AiSystemPromptType, @@ -146,24 +128,7 @@ class GeminiModelService @Inject constructor( } } - private fun isSupportedFlashModel(modelName: String): Boolean { - return modelName.startsWith("gemini", ignoreCase = true) && - modelName.contains("flash", ignoreCase = true) && - !modelName.contains("embedding", ignoreCase = true) && - !modelName.contains("image", ignoreCase = true) - } - - private fun sortFlashModels(models: List): List { - val preferred = getDefaultModels().map { it.name } - return models.distinctBy { it.name }.sortedWith( - compareBy { model -> - preferred.indexOf(model.name).takeIf { it >= 0 } ?: Int.MAX_VALUE - }.thenBy { it.displayName.lowercase() } - ) - } - private fun formatDisplayName(modelName: String): String { - // Convert "gemini-2.5-flash" to "Gemini 2.5 Flash" return modelName .split("-") .joinToString(" ") { word -> @@ -173,12 +138,13 @@ class GeminiModelService @Inject constructor( private fun getDefaultModels(): List { return listOf( - GeminiModel("gemini-2.5-flash-lite", "Gemini 2.5 Flash Lite (Cheapest/Fastest Default)"), - GeminiModel("gemini-2.5-flash", "Gemini 2.5 Flash"), + GeminiModel("gemini-3.1-flash-lite", "Gemini 3.1 Flash Lite (Recommended Default)"), + GeminiModel("gemini-3.5-flash", "Gemini 3.5 Flash"), + GeminiModel("gemini-3.1-pro-preview", "Gemini 3.1 Pro (Preview)"), GeminiModel("gemini-flash-lite-latest", "Gemini Flash Lite Latest"), GeminiModel("gemini-flash-latest", "Gemini Flash Latest"), - GeminiModel("gemini-2.0-flash-lite", "Gemini 2.0 Flash Lite"), - GeminiModel("gemini-2.0-flash", "Gemini 2.0 Flash") + GeminiModel("gemma-4-31b-it", "Gemma 4 31B IT"), + GeminiModel("gemma-4-26b-a4b-it", "Gemma 4 26B MoE") ) } } diff --git a/app/src/main/java/com/theveloper/pixelplay/data/ai/provider/GeminiAiClient.kt b/app/src/main/java/com/theveloper/pixelplay/data/ai/provider/GeminiAiClient.kt index 094665f40..1181bb70f 100644 --- a/app/src/main/java/com/theveloper/pixelplay/data/ai/provider/GeminiAiClient.kt +++ b/app/src/main/java/com/theveloper/pixelplay/data/ai/provider/GeminiAiClient.kt @@ -5,14 +5,10 @@ import com.google.ai.client.generativeai.type.generationConfig import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext -/** - * Gemini AI provider implementation using the official Android SDK - */ class GeminiAiClient(private val apiKey: String) : AiClient { companion object { - // Cheapest/fastest Gemini family default. Keep Gemini choices Flash-only. - private const val DEFAULT_GEMINI_MODEL = "gemini-2.5-flash-lite" + private const val DEFAULT_GEMINI_MODEL = "gemini-3.1-flash-lite" } private fun createModel(modelName: String, systemPrompt: String, temp: Float = 0.7f): GenerativeModel { @@ -61,18 +57,15 @@ class GeminiAiClient(private val apiKey: String) : AiClient { return withContext(Dispatchers.IO) { try { val generativeModel = createModel(model, systemPrompt) - // Combine system instruction if possible, or just estimate val response = generativeModel.countTokens(prompt) response.totalTokens } catch (e: Exception) { - // Return estimation if SDK fails (prompt.length / 4) + (systemPrompt.length / 4) } } } override suspend fun getAvailableModels(apiKey: String): List { - // Models are usually fetched via HTTP as the SDK doesn't expose a listing method return withContext(Dispatchers.IO) { try { val url = "https://generativelanguage.googleapis.com/v1beta/models?key=$apiKey" @@ -98,7 +91,6 @@ class GeminiAiClient(private val apiKey: String) : AiClient { override suspend fun validateApiKey(apiKey: String): Boolean { return withContext(Dispatchers.IO) { try { - // Use the stable model for validation val generativeModel = GenerativeModel( modelName = DEFAULT_GEMINI_MODEL, apiKey = apiKey @@ -119,45 +111,38 @@ class GeminiAiClient(private val apiKey: String) : AiClient { val modelPattern = """"name":\s*"(models/[^"]+)"""".toRegex() val matches = modelPattern.findAll(jsonResponse) + val blacklist = listOf("-2.0", "-2.5", "-preview", "customtools", "search", "tuning", "-001", "-002") + val whitelist = listOf("gemini-3.1-pro-preview") + for (match in matches) { val fullName = match.groupValues[1] val modelName = fullName.removePrefix("models/") - if (isSupportedFlashModel(modelName)) { + val isWhitelisted = whitelist.any { modelName == it } + val hasForbiddenSuffix = blacklist.any { modelName.contains(it) } + val isBlacklisted = hasForbiddenSuffix && !isWhitelisted + + if (!isBlacklisted && + (modelName.startsWith("gemini", ignoreCase = true) || + modelName.startsWith("gemma", ignoreCase = true)) && + !modelName.contains("embedding", ignoreCase = true)) { models.add(modelName) } } - return if (models.isNotEmpty()) sortFlashModels(models) else getDefaultModels() + val defaults = getDefaultModels() + return (models + defaults).distinct().sorted() } catch (e: Exception) { return getDefaultModels() } } - private fun isSupportedFlashModel(modelName: String): Boolean { - return modelName.startsWith("gemini", ignoreCase = true) && - modelName.contains("flash", ignoreCase = true) && - !modelName.contains("embedding", ignoreCase = true) && - !modelName.contains("image", ignoreCase = true) - } - - private fun sortFlashModels(models: List): List { - val preferred = getDefaultModels() - return models.distinct().sortedWith( - compareBy { model -> - preferred.indexOf(model).takeIf { it >= 0 } ?: Int.MAX_VALUE - }.thenBy { it.lowercase() } - ) - } - private fun getDefaultModels(): List { return listOf( - "gemini-2.5-flash-lite", - "gemini-2.5-flash", - "gemini-flash-lite-latest", - "gemini-flash-latest", - "gemini-2.0-flash-lite", - "gemini-2.0-flash" + "gemini-3.1-flash-lite", + "gemini-3.5-flash", + "gemini-3.1-pro-preview", + "gemini-flash-latest" ) } } diff --git a/app/src/main/java/com/theveloper/pixelplay/presentation/viewmodel/SettingsViewModel.kt b/app/src/main/java/com/theveloper/pixelplay/presentation/viewmodel/SettingsViewModel.kt index c9dbeee9b..fe4251e8c 100644 --- a/app/src/main/java/com/theveloper/pixelplay/presentation/viewmodel/SettingsViewModel.kt +++ b/app/src/main/java/com/theveloper/pixelplay/presentation/viewmodel/SettingsViewModel.kt @@ -179,6 +179,7 @@ class SettingsViewModel @Inject constructor( private val colorSchemeProcessor: ColorSchemeProcessor, private val syncManager: SyncManager, private val aiClientFactory: AiClientFactory, + private val geminiModelService: com.theveloper.pixelplay.data.ai.GeminiModelService, private val aiUsageDao: AiUsageDao, private val lyricsRepository: LyricsRepository, private val musicRepository: MusicRepository, @@ -1125,13 +1126,16 @@ class SettingsViewModel @Inject constructor( _uiState.update { it.copy(isLoadingModels = true, modelsFetchError = null) } try { val provider = AiProvider.fromString(providerName) - val aiClient = aiClientFactory.createClient(provider, apiKey) - val modelStrings = aiClient.getAvailableModels(apiKey) - val models = modelStrings - .map { it.trim() } - .filter { it.isNotBlank() } - .distinct() - .map { com.theveloper.pixelplay.data.ai.GeminiModel(it, formatModelDisplayName(it)) } + val models = if (provider == AiProvider.GEMINI) { + geminiModelService.fetchAvailableModels(apiKey).getOrThrow() + } else { + val aiClient = aiClientFactory.createClient(provider, apiKey) + aiClient.getAvailableModels(apiKey) + .map { it.trim() } + .filter { it.isNotBlank() } + .distinct() + .map { com.theveloper.pixelplay.data.ai.GeminiModel(it, formatModelDisplayName(it)) } + } _uiState.update { it.copy( @@ -1165,7 +1169,6 @@ class SettingsViewModel @Inject constructor( .replace('-', ' ') .replace('_', ' ') .split(' ') - .filter { it.isNotBlank() } .joinToString(" ") { token -> token.lowercase().replaceFirstChar { if (it.isLowerCase()) it.titlecase() else it.toString() } }