Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,13 @@ class GeminiModelService @Inject constructor(
private val workerManager: AiWorkerManager
) {

/**
* Fetches available Gemini models using the provided API key.
* Returns a list of model names that are available for the user.
*/
suspend fun fetchAvailableModels(apiKey: String): Result<List<GeminiModel>> {
return withContext(Dispatchers.IO) {
try {
if (apiKey.isBlank()) {
return@withContext Result.failure(Exception("API Key is required"))
}

// Use a lightweight model to test the API key and fetch available models
// We'll make a request to list models using the Gemini API
val response = makeModelsListRequest(apiKey)

Result.success(response)
} catch (e: Exception) {
Timber.e(e, "Error fetching Gemini models")
Expand All @@ -47,7 +39,6 @@ class GeminiModelService @Inject constructor(
private suspend fun makeModelsListRequest(apiKey: String): List<GeminiModel> {
return withContext(Dispatchers.IO) {
try {
// Make HTTP request to Google's Gemini API to list models
val url = "https://generativelanguage.googleapis.com/v1beta/models?key=$apiKey"
val connection = java.net.URL(url).openConnection() as java.net.HttpURLConnection

Expand All @@ -56,70 +47,61 @@ class GeminiModelService @Inject constructor(
connection.readTimeout = 10000

val responseCode = connection.responseCode
if (responseCode == 200) {
val apiModels = if (responseCode == 200) {
val response = connection.inputStream.bufferedReader().use { it.readText() }
parseModelsResponse(response)
} else {
val errorMessage = connection.errorStream?.bufferedReader()?.use { it.readText() }
Timber.e("Failed to fetch models: $responseCode - $errorMessage")
// Return default models if API call fails
getDefaultModels()
}
} else emptyList()

val defaults = getDefaultModels()
(apiModels + defaults).distinctBy { it.name }.sortedWith(
compareBy<GeminiModel> { model ->
val preferred = defaults.map { it.name }
preferred.indexOf(model.name).takeIf { it >= 0 } ?: Int.MAX_VALUE
}.thenBy { it.displayName.lowercase() }
)
} catch (e: Exception) {
Timber.e(e, "Exception fetching models, returning defaults")
getDefaultModels()
}
}
}

private fun parseModelsResponse(jsonResponse: String): List<GeminiModel> {
try {
// Parse the JSON response to extract model names
// Expected format: {"models": [{"name": "models/gemini-...", ...}, ...]}
val models = mutableListOf<GeminiModel>()

// Simple JSON parsing - extract model names
val modelPattern = """"name":\s*"(models/[^"]+)"""".toRegex()
val matches = modelPattern.findAll(jsonResponse)

val blacklist = listOf("-2.0", "-2.5", "-preview", "customtools", "search", "tuning", "-001", "-002")
val whitelist = listOf("gemini-3.1-pro-preview")

for (match in matches) {
val fullName = match.groupValues[1]
val modelName = fullName.removePrefix("models/")

// Keep the UI focused on the cheapest/fastest Gemini family: Flash only.
if (isSupportedFlashModel(modelName)) {
val isWhitelisted = whitelist.any { modelName == it }
val hasForbiddenSuffix = blacklist.any { modelName.contains(it) }
val isBlacklisted = hasForbiddenSuffix && !isWhitelisted

if (!isBlacklisted &&
(modelName.startsWith("gemini", ignoreCase = true) ||
modelName.startsWith("gemma", ignoreCase = true)) &&
!modelName.contains("embedding", ignoreCase = true)) {
models.add(GeminiModel(
name = modelName,
displayName = formatDisplayName(modelName)
))
}
}

return if (models.isNotEmpty()) {
sortFlashModels(models)
} else {
getDefaultModels()
}
return models
} catch (e: Exception) {
Timber.e(e, "Error parsing models response")
return getDefaultModels()
return emptyList()
}
}

/**
* Estimates the token count for a piece of text.
* Uses a conservative 4 chars per token rule for non-Gemini providers,
* but we recommend using the specific countTokens method on AiClient for accuracy.
*/
fun estimateTokens(text: String): Int {
return (text.length / 4).coerceAtLeast(1)
}

/**
* High-level method to perform an AI operation.
* Starts a background worker if [runInBackground] is true,
* otherwise executes immediately and returns the result.
*/
suspend fun performAiTask(
prompt: String,
type: AiSystemPromptType,
Expand All @@ -146,24 +128,7 @@ class GeminiModelService @Inject constructor(
}
}

private fun isSupportedFlashModel(modelName: String): Boolean {
return modelName.startsWith("gemini", ignoreCase = true) &&
modelName.contains("flash", ignoreCase = true) &&
!modelName.contains("embedding", ignoreCase = true) &&
!modelName.contains("image", ignoreCase = true)
}

private fun sortFlashModels(models: List<GeminiModel>): List<GeminiModel> {
val preferred = getDefaultModels().map { it.name }
return models.distinctBy { it.name }.sortedWith(
compareBy<GeminiModel> { model ->
preferred.indexOf(model.name).takeIf { it >= 0 } ?: Int.MAX_VALUE
}.thenBy { it.displayName.lowercase() }
)
}

private fun formatDisplayName(modelName: String): String {
// Convert "gemini-2.5-flash" to "Gemini 2.5 Flash"
return modelName
.split("-")
.joinToString(" ") { word ->
Expand All @@ -173,12 +138,13 @@ class GeminiModelService @Inject constructor(

private fun getDefaultModels(): List<GeminiModel> {
return listOf(
GeminiModel("gemini-2.5-flash-lite", "Gemini 2.5 Flash Lite (Cheapest/Fastest Default)"),
GeminiModel("gemini-2.5-flash", "Gemini 2.5 Flash"),
GeminiModel("gemini-3.1-flash-lite", "Gemini 3.1 Flash Lite (Recommended Default)"),
GeminiModel("gemini-3.5-flash", "Gemini 3.5 Flash"),
GeminiModel("gemini-3.1-pro-preview", "Gemini 3.1 Pro (Preview)"),
GeminiModel("gemini-flash-lite-latest", "Gemini Flash Lite Latest"),
GeminiModel("gemini-flash-latest", "Gemini Flash Latest"),
GeminiModel("gemini-2.0-flash-lite", "Gemini 2.0 Flash Lite"),
GeminiModel("gemini-2.0-flash", "Gemini 2.0 Flash")
GeminiModel("gemma-4-31b-it", "Gemma 4 31B IT"),
GeminiModel("gemma-4-26b-a4b-it", "Gemma 4 26B MoE")
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@ import com.google.ai.client.generativeai.type.generationConfig
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext

/**
* Gemini AI provider implementation using the official Android SDK
*/
class GeminiAiClient(private val apiKey: String) : AiClient {

companion object {
// Cheapest/fastest Gemini family default. Keep Gemini choices Flash-only.
private const val DEFAULT_GEMINI_MODEL = "gemini-2.5-flash-lite"
private const val DEFAULT_GEMINI_MODEL = "gemini-3.1-flash-lite"
}

private fun createModel(modelName: String, systemPrompt: String, temp: Float = 0.7f): GenerativeModel {
Expand Down Expand Up @@ -61,18 +57,15 @@ class GeminiAiClient(private val apiKey: String) : AiClient {
return withContext(Dispatchers.IO) {
try {
val generativeModel = createModel(model, systemPrompt)
// Combine system instruction if possible, or just estimate
val response = generativeModel.countTokens(prompt)
response.totalTokens
} catch (e: Exception) {
// Return estimation if SDK fails
(prompt.length / 4) + (systemPrompt.length / 4)
}
}
}

override suspend fun getAvailableModels(apiKey: String): List<String> {
// Models are usually fetched via HTTP as the SDK doesn't expose a listing method
return withContext(Dispatchers.IO) {
try {
val url = "https://generativelanguage.googleapis.com/v1beta/models?key=$apiKey"
Expand All @@ -98,7 +91,6 @@ class GeminiAiClient(private val apiKey: String) : AiClient {
override suspend fun validateApiKey(apiKey: String): Boolean {
return withContext(Dispatchers.IO) {
try {
// Use the stable model for validation
val generativeModel = GenerativeModel(
modelName = DEFAULT_GEMINI_MODEL,
apiKey = apiKey
Expand All @@ -119,45 +111,38 @@ class GeminiAiClient(private val apiKey: String) : AiClient {
val modelPattern = """"name":\s*"(models/[^"]+)"""".toRegex()
val matches = modelPattern.findAll(jsonResponse)

val blacklist = listOf("-2.0", "-2.5", "-preview", "customtools", "search", "tuning", "-001", "-002")
val whitelist = listOf("gemini-3.1-pro-preview")

for (match in matches) {
val fullName = match.groupValues[1]
val modelName = fullName.removePrefix("models/")

if (isSupportedFlashModel(modelName)) {
val isWhitelisted = whitelist.any { modelName == it }
val hasForbiddenSuffix = blacklist.any { modelName.contains(it) }
val isBlacklisted = hasForbiddenSuffix && !isWhitelisted

if (!isBlacklisted &&
(modelName.startsWith("gemini", ignoreCase = true) ||
modelName.startsWith("gemma", ignoreCase = true)) &&
!modelName.contains("embedding", ignoreCase = true)) {
models.add(modelName)
}
}

return if (models.isNotEmpty()) sortFlashModels(models) else getDefaultModels()
val defaults = getDefaultModels()
return (models + defaults).distinct().sorted()
} catch (e: Exception) {
return getDefaultModels()
}
}

private fun isSupportedFlashModel(modelName: String): Boolean {
return modelName.startsWith("gemini", ignoreCase = true) &&
modelName.contains("flash", ignoreCase = true) &&
!modelName.contains("embedding", ignoreCase = true) &&
!modelName.contains("image", ignoreCase = true)
}

private fun sortFlashModels(models: List<String>): List<String> {
val preferred = getDefaultModels()
return models.distinct().sortedWith(
compareBy<String> { model ->
preferred.indexOf(model).takeIf { it >= 0 } ?: Int.MAX_VALUE
}.thenBy { it.lowercase() }
)
}

private fun getDefaultModels(): List<String> {
return listOf(
"gemini-2.5-flash-lite",
"gemini-2.5-flash",
"gemini-flash-lite-latest",
"gemini-flash-latest",
"gemini-2.0-flash-lite",
"gemini-2.0-flash"
"gemini-3.1-flash-lite",
"gemini-3.5-flash",
"gemini-3.1-pro-preview",
"gemini-flash-latest"
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ class SettingsViewModel @Inject constructor(
private val colorSchemeProcessor: ColorSchemeProcessor,
private val syncManager: SyncManager,
private val aiClientFactory: AiClientFactory,
private val geminiModelService: com.theveloper.pixelplay.data.ai.GeminiModelService,
private val aiUsageDao: AiUsageDao,
private val lyricsRepository: LyricsRepository,
private val musicRepository: MusicRepository,
Expand Down Expand Up @@ -1125,13 +1126,16 @@ class SettingsViewModel @Inject constructor(
_uiState.update { it.copy(isLoadingModels = true, modelsFetchError = null) }
try {
val provider = AiProvider.fromString(providerName)
val aiClient = aiClientFactory.createClient(provider, apiKey)
val modelStrings = aiClient.getAvailableModels(apiKey)
val models = modelStrings
.map { it.trim() }
.filter { it.isNotBlank() }
.distinct()
.map { com.theveloper.pixelplay.data.ai.GeminiModel(it, formatModelDisplayName(it)) }
val models = if (provider == AiProvider.GEMINI) {
geminiModelService.fetchAvailableModels(apiKey).getOrThrow()
} else {
val aiClient = aiClientFactory.createClient(provider, apiKey)
aiClient.getAvailableModels(apiKey)
.map { it.trim() }
.filter { it.isNotBlank() }
.distinct()
.map { com.theveloper.pixelplay.data.ai.GeminiModel(it, formatModelDisplayName(it)) }
}

_uiState.update {
it.copy(
Expand Down Expand Up @@ -1165,7 +1169,6 @@ class SettingsViewModel @Inject constructor(
.replace('-', ' ')
.replace('_', ' ')
.split(' ')
.filter { it.isNotBlank() }
.joinToString(" ") { token ->
token.lowercase().replaceFirstChar { if (it.isLowerCase()) it.titlecase() else it.toString() }
}
Expand Down