From 981112418eca1021cd561fb290265bad3f579334 Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 5 Jun 2026 07:24:18 +0000 Subject: [PATCH] feat(embedding): EmbeddingGemma on-device RAG embedder (ONNX, 256-dim) with USE-Lite fallback Engine (lib): - Task-aware EmbeddingEngine (QUERY/DOCUMENT + dimensions); OnnxEmbeddingEngine (EmbeddingGemma 300M via ONNX Runtime: instruction prompts, mean-pool, Matryoshka-256, L2-norm) + HfGemmaTokenizer (tokenizer.json). - ModelFormat.ONNX_EMBEDDER and ModelDescriptor.companions (tokenizer download). App (sample-app): - GemmaChunkEntity (256-dim HNSW) alongside the 100-dim USE store; repository routes vector ops by dimension and returns neutral Chunk/ScoredChunk DTOs. - RagHolder selects the active embedder (Gemma on capable devices, USE fallback), downloads model+companion, and re-indexes legacy chunks (migrateToGemma). - Ingest/retrieve are task-aware; per-embedder distance gate. - Backups round-trip both stores (schema v2, per-chunk dim); catalog descriptor. Note: not compiled in this sandbox (no Android SDK); first build regenerates the ObjectBox model. Tokenizer runtime + ONNX output names need on-device verification. 
--- CHANGELOG.md | 3 + docs/EMBEDDING_GEMMA_PLAN.md | 220 ++++++++++++++++++ gradle/libs.versions.toml | 5 + lib/build.gradle.kts | 3 + lib/consumer-rules.pro | 11 + .../kotlin/com/sagar/aicore/GemmaTokenizer.kt | 49 ++++ .../sagar/aicore/MediaPipeEmbeddingEngine.kt | 9 +- .../com/sagar/aicore/OnnxEmbeddingEngine.kt | 156 +++++++++++++ .../com/sagar/aicore/EmbeddingEngine.kt | 31 ++- .../kotlin/com/sagar/aicore/ModelCatalog.kt | 22 ++ .../nativelm/app/data/backup/BackupManager.kt | 59 +++-- .../nativelm/app/data/backup/BackupModels.kt | 13 +- .../app/data/db/DocumentRepository.kt | 70 ++++-- .../java/com/nativelm/app/data/db/Entities.kt | 35 +++ .../data/db/ObjectBoxDocumentRepository.kt | 190 ++++++++++----- .../nativelm/app/llm/NativeLmModelCatalog.kt | 28 +++ .../com/nativelm/app/llm/NativeLmViewModel.kt | 59 +++-- .../java/com/nativelm/app/llm/RagHolder.kt | 127 ++++++++-- .../app/rag/DefaultDocumentIngestor.kt | 28 +-- .../app/rag/DefaultDocumentRetriever.kt | 38 +-- .../app/rag/DefaultDocumentRetrieverTest.kt | 21 +- .../app/rag/RagContextFormatterTest.kt | 8 +- 22 files changed, 1013 insertions(+), 172 deletions(-) create mode 100644 docs/EMBEDDING_GEMMA_PLAN.md create mode 100644 lib/src/androidMain/kotlin/com/sagar/aicore/GemmaTokenizer.kt create mode 100644 lib/src/androidMain/kotlin/com/sagar/aicore/OnnxEmbeddingEngine.kt diff --git a/CHANGELOG.md b/CHANGELOG.md index 9589192..7970a2d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ and the project loosely follows [Semantic Versioning](https://semver.org/). ## [Unreleased] +### Added +- **EmbeddingGemma RAG embedder** _(app + engine)_ — higher-quality on-device retrieval via EmbeddingGemma 300M (ONNX Runtime, Matryoshka-256, query/document task prompts), default on capable devices with the 6 MB USE-Lite embedder as the friction-free fallback. Existing sources are re-indexed from their stored text on upgrade; backups round-trip both embedders. 
Engine adds a task-aware `EmbeddingEngine` and `ModelFormat.ONNX_EMBEDDER` + companion-file (tokenizer) downloads. See `docs/EMBEDDING_GEMMA_PLAN.md`. + ## [0.8.0] — 2026-06-03 ### Added diff --git a/docs/EMBEDDING_GEMMA_PLAN.md b/docs/EMBEDDING_GEMMA_PLAN.md new file mode 100644 index 0000000..d592edf --- /dev/null +++ b/docs/EMBEDDING_GEMMA_PLAN.md @@ -0,0 +1,220 @@ +# NativeLM — EmbeddingGemma on-device RAG embedder (implementation plan) + +_Branch: `claude/analysis-KV4t2`. Goal: replace the 2018-era Universal Sentence +Encoder (USE-Lite, 100-dim) with **EmbeddingGemma 300M** as the default RAG +embedder, lifting retrieval quality for both chat answers and every Studio +artifact. USE-Lite stays as the low-end / no-download fallback._ + +## Locked decisions + +| Decision | Choice | Why | +|---|---|---| +| **Runtime** | **ONNX Runtime (Android)** + a SentencePiece/HF tokenizer | Self-contained; full control of task-prompts, pooling, normalization, Matryoshka. **No Google telemetry deps** (protects the zero-telemetry stance — commit `d5b5fa9`). KMP/iOS-friendly. Avoids the MediaPipe `TextEmbedder` path that broke before. | +| **Dimension** | **256** (Matryoshka truncation of the 768-native vector) | Best quality/size/speed balance on-device; the migration path already reserved in `DocumentChunkEntity`. ~2.5× storage vs the current 100-dim but far better retrieval; half the index cost of 512. | +| **Rollout** | **Default on capable devices; USE-Lite stays as fallback** | Friction-free first run preserved. Low-end / no-download installs keep working on USE-Lite. Two HNSW indexes coexist; the active embedder selects which one. | + +--- + +## Why the earlier attempt failed (recorded so we don't repeat it) + +EmbeddingGemma is a 300M transformer, **not** a TFLite *Task* model. Three +independent landmines, any one of which sinks a naive swap: + +1. 
**Wrong loader.** `MediaPipeEmbeddingEngine` calls + `TextEmbedder.createFromFile()`, which only accepts TFLite Task models with + baked-in tokenizer metadata (USE/BERT-style). EmbeddingGemma won't load there + (or returns garbage). **→ This plan introduces a separate ONNX engine; it does + not touch the MediaPipe path.** +2. **Dimension lock.** `DocumentChunkEntity.@HnswIndex(dimensions = 100L)` is an + annotation **literal**. EmbeddingGemma emits 768/512/256 → ObjectBox throws the + moment a longer vector is inserted/queried. **→ This plan adds a new 256-dim + entity rather than editing the 100-dim one.** +3. **Missing task prompts.** EmbeddingGemma *requires* instruction prefixes; + `EmbeddingEngine.embed(text)` is symmetric and `DefaultDocumentRetriever` calls + `embed(query)` with no role. Even if it loaded, retrieval would look "broken." + **→ This plan makes the interface task-aware.** + +--- + +## Architecture + +``` + ┌─ EmbeddingTask.QUERY → "task: search result | query: {q}" +query / chunk text ──► │ │ + └─ EmbeddingTask.DOCUMENT → "title: {t|none} | text: {c}"│ + ▼ + ┌──────────────────────────────────────┐ + │ OnnxEmbeddingEngine (androidMain) │ + │ tokenize (SentencePiece/HF) │ + │ → ORT run → last_hidden_state │ + │ → mean-pool over attention mask │ + │ → truncate to 256 (Matryoshka) │ + │ → L2 normalize │ + └──────────────────────────────────────┘ + │ 256-dim FloatArray + ▼ + active embedder selects index ──► GemmaChunkEntity (256-dim HNSW) [default] + DocumentChunkEntity (100-dim HNSW) [USE fallback / legacy] +``` + +The **active embedder** is an install-level property (which EMBEDDING model is +downloaded + chosen). It determines (a) which engine `embed*` routes to and +(b) which HNSW entity ingestion/retrieval use. Switching embedders triggers a +**re-index from stored chunk text** (no re-extraction needed). 
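
The post-processing stages in the diagram above (mean-pool over the attention mask, truncate to the Matryoshka dimension, then L2-normalize) can be sketched in isolation from ONNX Runtime. This is a minimal illustration on plain arrays, not the engine's actual code; `meanPool` and `truncateAndNormalize` are hypothetical helper names chosen for this example.

```kotlin
import kotlin.math.sqrt

/** Mean-pool token vectors, counting only positions whose attention mask is non-zero. */
fun meanPool(tokens: Array<FloatArray>, mask: LongArray): FloatArray {
    val dim = tokens.first().size
    val sum = FloatArray(dim)
    var count = 0
    for (t in tokens.indices) {
        if (mask[t] == 0L) continue // skip padding positions
        for (d in 0 until dim) sum[d] += tokens[t][d]
        count++
    }
    if (count > 0) for (d in 0 until dim) sum[d] /= count.toFloat()
    return sum
}

/** Matryoshka order: keep the first [outputDim] values, then L2-normalize the shortened vector. */
fun truncateAndNormalize(pooled: FloatArray, outputDim: Int): FloatArray {
    val v = pooled.copyOf(minOf(outputDim, pooled.size))
    var squares = 0.0
    for (x in v) squares += x.toDouble() * x
    val norm = sqrt(squares)
    if (norm > 0.0) for (i in v.indices) v[i] = (v[i] / norm).toFloat()
    return v
}

fun main() {
    // Three token vectors; the third is padding (mask = 0) and must not affect the mean.
    val tokens = arrayOf(
        floatArrayOf(3f, 0f, 9f),
        floatArrayOf(5f, 0f, 9f),
        floatArrayOf(100f, 100f, 100f),
    )
    val mask = longArrayOf(1, 1, 0)
    val pooled = meanPool(tokens, mask)      // [4, 0, 9]
    val v = truncateAndNormalize(pooled, 2)  // first two dims [4, 0], scaled to unit length
    println(v.joinToString())                // prints 1.0, 0.0
}
```

Normalizing after truncation keeps the stored vector at unit length, which is what a cosine-distance index expects; normalizing first and truncating afterwards would leave a vector shorter than unit length.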
+ +--- + +## Interface contract (lands first — Module 0) + +```kotlin +// lib/commonMain — EmbeddingEngine.kt (BREAKING: task-aware) +enum class EmbeddingTask { QUERY, DOCUMENT } + +interface EmbeddingEngine { + /** Output dimension of this embedder (USE-Lite = 100, EmbeddingGemma = 256). */ + val dimensions: Int + suspend fun initialize(modelPath: String) + /** [title] is only used for DOCUMENT task on prompt-instructed models; ignored otherwise. */ + suspend fun embed(text: String, task: EmbeddingTask, title: String? = null): FloatArray +} +``` + +```kotlin +// lib/commonMain — ModelCatalog.kt +enum class ModelFormat { LITERTLM, MEDIAPIPE_TEXT_EMBEDDER, WHISPER_GGML, ONNX_EMBEDDER } + +// ModelDescriptor gains companion-file support so the tokenizer ships with the model: +data class ModelDescriptor( + /* …existing… */ + val companions: List<CompanionFile> = emptyList(), // NEW — e.g. tokenizer.json +) +data class CompanionFile(val url: String, val fileName: String, val sizeBytes: Long, val sha256: String? = null) +``` + +```kotlin +// sample-app data.db — new 256-dim chunk entity (parallel to DocumentChunkEntity) +@Entity +class GemmaChunkEntity { + @Id var id: Long = 0 + @Index var documentId: Long = 0 + @Index var projectId: Long = 0 + var text: String = ""; var pageNumber: Int = 0; var chunkIndex: Int = 0 + @HnswIndex(dimensions = 256L, distanceType = VectorDistanceType.COSINE, + neighborsPerNode = 48, indexingSearchCount = 200) + var embedding: FloatArray? = null + companion object { const val EMBEDDING_DIM = 256 } +} +``` + +The `DocumentRepository` ingestion/retrieval methods route to the entity matching +the active embedder; `ScoredChunk` stays the common return shape so +`DefaultDocumentRetriever` / `RagContextFormatter` are largely unchanged. 
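
As a rough illustration of routing by embedding dimension, the sketch below uses in-memory lists in place of the two ObjectBox entities. `Chunk` and `DimensionRouter` here are simplified, hypothetical stand-ins written for this example; they are not the repository's real types or API.

```kotlin
/** Simplified, hypothetical chunk shape shared by both stores. */
class Chunk(val documentId: Long, val text: String, val embedding: FloatArray)

/** Hypothetical in-memory router: one bucket per supported embedding dimension. */
class DimensionRouter(supportedDims: Set<Int>) {
    private val stores: Map<Int, MutableList<Chunk>> =
        supportedDims.associateWith { mutableListOf<Chunk>() }

    /** Stores the chunk in the bucket matching its vector length, or fails loudly. */
    fun put(chunk: Chunk) {
        val store = stores[chunk.embedding.size]
            ?: throw IllegalArgumentException("No store for ${chunk.embedding.size}-dim vectors")
        store.add(chunk)
    }

    fun count(dim: Int): Int = stores[dim]?.size ?: 0
}

fun main() {
    val router = DimensionRouter(setOf(100, 256))
    router.put(Chunk(1L, "gemma chunk", FloatArray(256)))
    router.put(Chunk(1L, "use chunk", FloatArray(100)))
    println("256-dim: ${router.count(256)}, 100-dim: ${router.count(100)}")

    // A 768-dim vector has no matching store, so it is rejected rather than silently misfiled.
    val rejected = runCatching { router.put(Chunk(1L, "raw", FloatArray(768))) }.isFailure
    println("768-dim rejected: $rejected")
}
```

Rejecting an unknown dimension up front mirrors the "dimension lock" concern: a vector of the wrong length should never reach an index declared with a fixed size.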
+ +--- + +## Module breakdown + +| Mod | Scope | Key files | Depends on | +|----|-------|-----------|-----------| +| **0** | Contracts: task-aware `EmbeddingEngine`, `ONNX_EMBEDDER` format, `companions` on `ModelDescriptor`, `GemmaChunkEntity` (regen `objectbox-models/default.json`) | `EmbeddingEngine.kt`, `ModelCatalog.kt`, `Entities.kt`, version catalog | — | +| **A** | ONNX engine: ORT session, tokenizer, mean-pool + Matryoshka-256 + L2-norm, task prompts | `OnnxEmbeddingEngine.kt` (androidMain), DI wiring in `AndroidAiEngineComponent.kt` | 0 | +| **B** | Catalog + download: EmbeddingGemma descriptor (`requiresAuth = true`), tokenizer companion download, sha256 pins | `NativeLmModelCatalog.kt`, model-download path | 0 | +| **C** | Repository routing: `GemmaChunkEntity` CRUD + HNSW search; active-embedder selector | `ObjectBoxDocumentRepository.kt`, `RagHolder.kt` | 0 | +| **D** | Ingest/retrieve task wiring: `embedDocument` on ingest, `embedQuery` on retrieve; re-tune distance gate | `DefaultDocumentIngestor.kt`, `DefaultDocumentRetriever.kt` | A,C | +| **E** | Migration: background re-index USE→Gemma from stored text, with progress + resume | new `EmbeddingMigrator.kt`, `NativeLmViewModel.kt` | C,D | +| **F** | UI + gating: embedder shown in Models screen (Recommended/Advanced, Gemma terms), device gating, re-index progress | `ModelManagementScreen.kt`, onboarding terms gate | B,E | +| **G** | Backup/sync compatibility: carry embedder tag; re-index on mismatched import | `BackupManager.kt`, `BackupModels.kt`, sync transport | E | + +--- + +## Migration / re-index plan + +Embeddings are **derived data**; `DocumentChunkEntity.text` is already persisted, +so re-indexing never needs the original PDFs. + +1. On first run after EmbeddingGemma is downloaded + selected, kick a background + `EmbeddingMigrator` (resumable, idempotent — skip docs already in `GemmaChunkEntity`). +2. Stream chunks per project → `embed(text, DOCUMENT, title)` → write to + `GemmaChunkEntity`. 
Reuse the `IngestState.Embedding(done,total)` progress UI. +3. Until a project is migrated, retrieval **falls back to the 100-dim index** so + chat keeps working. +4. After a project migrates, delete its old 100-dim chunks to reclaim storage + (tx-split delete — see gotchas). +5. Low-end devices that never download EmbeddingGemma stay entirely on USE-Lite. + +--- + +## Gotchas (bake these in) + +- **HNSW tx-split (carried over):** chunk deletes and parent-doc deletes go in + **separate transactions**, or HNSW commit deadlocks. Applies to the re-index + cleanup too. +- **Distance gate is USE-tuned.** `DefaultDocumentRetriever.RELEVANCE_MAX_DISTANCE + = 0.75` was tuned for USE-Lite's distribution. EmbeddingGemma's cosine spread + differs — **re-tune per active embedder** (likely a separate constant), or + off-topic queries will over/under-ground. +- **Task prompts are mandatory.** Query = `task: search result | query: …`; + Document = `title: {title or "none"} | text: …`. Wrong/missing prompts quietly + tank recall. +- **Matryoshka order: truncate *then* re-normalize.** Take the first 256 dims of + the pooled vector, *then* L2-normalize — not the reverse. +- **Tokenizer is the fiddly bit.** Ship `tokenizer.json` as a `companions` file + (or app asset) and run it via `onnxruntime-extensions` (in-graph) or the HF + `tokenizers` Android binding. Cap `max_seq_len` (~512) — chunks are ~500 chars + so this is safe and bounds latency/memory. +- **Latency.** A 300M transformer per chunk is far slower than USE's 6MB model; + a big PDF can go from seconds to minutes. Mitigate: quantized (INT8/QAT) ONNX, + XNNPACK threads, batch tokenization, and run ingestion/migration off the main + thread (ties into the deferred foreground-service download/ingest work in + `PLAY_STORE.md §9`). +- **Memory coexistence.** Don't embed and generate simultaneously — the LLM is the + big RAM tenant. Sequence ingestion/migration vs active chat generation. 
+- **Gemma licensing.** EmbeddingGemma is Gemma-licensed → `requiresAuth = true`, + `Authorization: Bearer `, surfaced under the **Advanced — Hugging Face + account** section and the onboarding **terms gate** already built in PR #22. +- **Backup/sync dimension mismatch.** A backup/synced DB may carry vectors from a + different embedder/dimension. Tag exports with the embedder id; on import with a + mismatch, **re-index from the included chunk text** rather than trusting vectors. +- **APK/download budget.** ORT Android AAR (~10–20 MB, arm64-only to match the + existing `abiFilters`) + the quantized ONNX model (~100–200 MB, downloaded, not + bundled) + tokenizer. Confirm against the size budget. + +--- + +## Dependencies to add + +- `com.microsoft.onnxruntime:onnxruntime-android` (full build for op coverage; + revisit ORT-format + `onnxruntime-mobile` later for size). +- `com.microsoft.onnxruntime:onnxruntime-extensions-android` (in-graph tokenizer), + **or** the HF `tokenizers` Android binding as fallback. +- arm64-v8a only, consistent with `libwhisper.so` and the LiteRT-LM footprint. + +--- + +## Testing / verify (the ship bar) + +On-device on **CPH2723 (release build)**: +1. Fresh ingest of a real PDF → confirm chunks land in `GemmaChunkEntity` (256-dim). +2. **Retrieval quality A/B**: a fixed query set, USE-Lite vs EmbeddingGemma — the + win must be visible (recall on names/concepts, fewer off-topic citations). +3. **Latency/memory**: per-chunk embed time + peak RAM during ingest; confirm a + multi-page PDF completes acceptably and coexists with chat. +4. **Migration**: upgrade an install with existing USE-Lite docs → re-index runs, + shows progress, retrieval keeps working throughout, old vectors reclaimed after. +5. **Low-end fallback**: a device that declines the download stays on USE-Lite and + functions unchanged. +6. **Distance-gate tuning**: verify off-topic questions still return + `RetrievedContext.EMPTY` with the re-tuned threshold. 
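
The distance-gate check in step 6 could be expressed as a per-embedder threshold lookup. The sketch below is an assumption-laden illustration: `ScoredChunk` is a simplified stand-in for the app's type, `0.75f` is the USE-Lite value quoted in this plan, and the `0.60f` entry for the 256-dim embedder is a placeholder that would need on-device tuning.

```kotlin
/** Simplified stand-in for the app's scored result type. */
data class ScoredChunk(val text: String, val distance: Float)

/**
 * Maximum cosine distance allowed per embedder dimension.
 * 100 -> 0.75f comes from the plan; 256 -> 0.60f is a PLACEHOLDER, not a tuned value.
 */
val maxDistanceByDim: Map<Int, Float> = mapOf(100 to 0.75f, 256 to 0.60f)

/** Keeps only hits within the active embedder's threshold; unknown embedders return nothing. */
fun gate(hits: List<ScoredChunk>, embedderDim: Int): List<ScoredChunk> {
    val max = maxDistanceByDim[embedderDim] ?: return emptyList()
    return hits.filter { it.distance <= max }
}

fun main() {
    val hits = listOf(
        ScoredChunk("close", 0.2f),
        ScoredChunk("borderline", 0.7f),
        ScoredChunk("off-topic", 0.9f),
    )
    println(gate(hits, 100).map { it.text }) // USE-Lite gate keeps "close" and "borderline"
    println(gate(hits, 256).map { it.text }) // stricter placeholder gate keeps only "close"
}
```

An empty result from `gate` corresponds to the case where retrieval should return no context at all, which is what the off-topic test is meant to confirm.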
+ +**Done when:** EmbeddingGemma is the default embedder on a capable device, an +existing install migrates cleanly, retrieval quality is visibly better, and the +USE-Lite fallback path still works end-to-end. + +--- + +## Out of scope (follow-ups) + +- **Reranker** second stage (cross-encoder) — separate plan; complements this. +- **ORT-format / mobile build** size optimization — after correctness is proven. +- **iOS embedder** — the ONNX engine is KMP-portable; wire `iosMain` later. +- **Token-aware chunking** — still char-based (500/50); revisit independently. diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 4025c45..96b52a7 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -13,6 +13,8 @@ napier = "2.7.1" kotlin-inject = "0.9.0" mediapipe = "0.10.35" litertlm = "0.11.0" +onnxruntime = "1.20.0" +djl-tokenizers = "0.30.0" mlkit-text-recognition = "16.0.1" androidx-core-ktx = "1.15.0" # sample-app only @@ -43,6 +45,9 @@ kotlin-inject-runtime = { module = "me.tatarka.inject:kotlin-inject-runtime", ve kotlin-inject-compiler = { module = "me.tatarka.inject:kotlin-inject-compiler-ksp", version.ref = "kotlin-inject" } mediapipe-tasks-text = { module = "com.google.mediapipe:tasks-text", version.ref = "mediapipe" } litertlm-android = { module = "com.google.ai.edge.litertlm:litertlm-android", version.ref = "litertlm" } +# EmbeddingGemma on-device: ONNX Runtime + a HuggingFace tokenizer (reads tokenizer.json). 
+onnxruntime-android = { module = "com.microsoft.onnxruntime:onnxruntime-android", version.ref = "onnxruntime" } +djl-tokenizers = { module = "ai.djl.huggingface:tokenizers", version.ref = "djl-tokenizers" } androidx-core-ktx = { group = "androidx.core", name = "core-ktx", version.ref = "androidx-core-ktx" } kotlinx-coroutines-android = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-android", version.ref = "kotlinx-coroutines" } # sample-app only diff --git a/lib/build.gradle.kts b/lib/build.gradle.kts index 6c091d0..c0498df 100644 --- a/lib/build.gradle.kts +++ b/lib/build.gradle.kts @@ -51,6 +51,9 @@ kotlin { implementation(libs.litertlm.android) implementation(libs.androidx.core.ktx) implementation(libs.ktor.client.okhttp) + // EmbeddingGemma on-device: ONNX Runtime + HuggingFace tokenizer. + implementation(libs.onnxruntime.android) + implementation(libs.djl.tokenizers) } iosMain.dependencies { implementation(libs.ktor.client.darwin) diff --git a/lib/consumer-rules.pro b/lib/consumer-rules.pro index 1955480..b8b85f1 100644 --- a/lib/consumer-rules.pro +++ b/lib/consumer-rules.pro @@ -33,6 +33,17 @@ -keep class com.google.common.flogger.** { *; } -dontwarn com.google.common.flogger.** +# ---- ONNX Runtime (com.microsoft.onnxruntime) — EmbeddingGemma ---- +# JNI bridge; the native keep above covers the bindings. Keep the API +# surface and silence optional references. +-keep class ai.onnxruntime.** { *; } +-dontwarn ai.onnxruntime.** + +# ---- HuggingFace tokenizers (ai.djl.huggingface) — EmbeddingGemma tokenizer ---- +# Loads a native lib + reflects over JNI types when reading tokenizer.json. +-keep class ai.djl.** { *; } +-dontwarn ai.djl.** + # ---- kotlinx.serialization ---- # Keep generated serializers + the synthetic serializer() accessor. 
-keepclassmembers class **$$serializer { *; } diff --git a/lib/src/androidMain/kotlin/com/sagar/aicore/GemmaTokenizer.kt b/lib/src/androidMain/kotlin/com/sagar/aicore/GemmaTokenizer.kt new file mode 100644 index 0000000..945427a --- /dev/null +++ b/lib/src/androidMain/kotlin/com/sagar/aicore/GemmaTokenizer.kt @@ -0,0 +1,49 @@ +/* + * Copyright (C) 2026 Sagar Gupta + * SPDX-License-Identifier: AGPL-3.0-or-later + */ +package com.sagar.aicore + +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer +import java.nio.file.Paths + +/** Tokenized input for the ONNX embedder: parallel token-id and attention-mask arrays. */ +data class TokenizedInput(val ids: LongArray, val attentionMask: LongArray) { + val length: Int get() = ids.size +} + +/** + * Turns text into model input ids + attention mask for [OnnxEmbeddingEngine]. + * EmbeddingGemma uses the Gemma SentencePiece vocabulary; we load it from the + * `tokenizer.json` companion that ships next to the model. + */ +interface GemmaTokenizer { + fun encode(text: String): TokenizedInput +} + +/** + * [GemmaTokenizer] backed by the HuggingFace tokenizers runtime (DJL binding), + * reading the model's `tokenizer.json`. Truncates to [maxLength] tokens — chunks + * are ~500 chars so this bounds latency/memory without losing content. + * + * Note: the exact padding/truncation knobs depend on the shipped `tokenizer.json`; + * verify ids/mask shapes against the chosen EmbeddingGemma ONNX export on-device. 
+ */ +class HfGemmaTokenizer( + tokenizerJsonPath: String, + private val maxLength: Int = 512, +) : GemmaTokenizer { + + private val tokenizer: HuggingFaceTokenizer = + HuggingFaceTokenizer.builder() + .optTokenizerPath(Paths.get(tokenizerJsonPath)) + .optAddSpecialTokens(true) + .optTruncation(true) + .optMaxLength(maxLength) + .build() + + override fun encode(text: String): TokenizedInput { + val enc = tokenizer.encode(text) + return TokenizedInput(ids = enc.ids, attentionMask = enc.attentionMask) + } +} diff --git a/lib/src/androidMain/kotlin/com/sagar/aicore/MediaPipeEmbeddingEngine.kt b/lib/src/androidMain/kotlin/com/sagar/aicore/MediaPipeEmbeddingEngine.kt index f086255..b63e941 100644 --- a/lib/src/androidMain/kotlin/com/sagar/aicore/MediaPipeEmbeddingEngine.kt +++ b/lib/src/androidMain/kotlin/com/sagar/aicore/MediaPipeEmbeddingEngine.kt @@ -31,6 +31,9 @@ class MediaPipeEmbeddingEngine( private var textEmbedder: TextEmbedder? = null private val mutex = Mutex() + /** USE-Lite is a fixed 100-dim embedder; [task]/title are ignored (it is symmetric). */ + override val dimensions: Int = 100 + /** * Initializes the embedder with a model path. */ @@ -56,7 +59,11 @@ class MediaPipeEmbeddingEngine( } } - override suspend fun embed(text: String): FloatArray = withContext(Dispatchers.IO) { + override suspend fun embed( + text: String, + task: EmbeddingTask, + title: String?, + ): FloatArray = withContext(Dispatchers.IO) { Napier.d(tag = "EmbeddingEngine") { "embed START hash=${System.identityHashCode(this@MediaPipeEmbeddingEngine)} text_len=${text.length} first50=${text.take(50)}" } val embedder = mutex.withLock { textEmbedder } ?: run { Napier.e(tag = "EmbeddingEngine") { "Embedding model not loaded! 
hash=${System.identityHashCode(this@MediaPipeEmbeddingEngine)}" } diff --git a/lib/src/androidMain/kotlin/com/sagar/aicore/OnnxEmbeddingEngine.kt b/lib/src/androidMain/kotlin/com/sagar/aicore/OnnxEmbeddingEngine.kt new file mode 100644 index 0000000..c335e57 --- /dev/null +++ b/lib/src/androidMain/kotlin/com/sagar/aicore/OnnxEmbeddingEngine.kt @@ -0,0 +1,156 @@ +/* + * Copyright (C) 2026 Sagar Gupta + * SPDX-License-Identifier: AGPL-3.0-or-later + */ +package com.sagar.aicore + +import ai.onnxruntime.OnnxTensor +import ai.onnxruntime.OrtEnvironment +import ai.onnxruntime.OrtSession +import io.github.aakira.napier.Napier +import kotlinx.coroutines.Dispatchers +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext +import java.io.File +import java.nio.LongBuffer +import kotlin.math.sqrt + +/** + * On-device [EmbeddingEngine] for **EmbeddingGemma 300M** via ONNX Runtime. + * + * Pipeline per call: apply the task-specific instruction prefix → tokenize → + * run the ONNX graph → mean-pool the token embeddings over the attention mask + * (unless the graph already emits a pooled `sentence_embedding`) → **truncate to + * [outputDim]** (Matryoshka) → **L2-normalize**. Everything stays on-device. + * + * The instruction prefixes are mandatory for EmbeddingGemma retrieval quality: + * - query → `task: search result | query: {text}` + * - document→ `title: {title or "none"} | text: {text}` + * + * The [tokenizer] is created from the `tokenizer.json` companion downloaded next + * to the model. Construct one instance with an Application-scoped lifetime. + */ +class OnnxEmbeddingEngine( + /** Builds the tokenizer from the companion path resolved at [initialize] time. */ + private val tokenizerFactory: (modelDir: String) -> GemmaTokenizer, + /** Matryoshka output dimension; must match the vector store's index dimension. 
*/ + override val dimensions: Int = 256, + private val tokenizerFileName: String = "tokenizer.json", +) : EmbeddingEngine { + + private val outputDim: Int get() = dimensions + + private var env: OrtEnvironment? = null + private var session: OrtSession? = null + private var tokenizer: GemmaTokenizer? = null + private val mutex = Mutex() + + override suspend fun initialize(modelPath: String): Unit = withContext(Dispatchers.IO) { + mutex.withLock { + if (session != null) return@withContext + val modelFile = File(modelPath) + require(modelFile.exists()) { "ONNX model not found: $modelPath" } + try { + val environment = OrtEnvironment.getEnvironment() + val opts = OrtSession.SessionOptions().apply { + setIntraOpNumThreads(Runtime.getRuntime().availableProcessors().coerceAtMost(4)) + setOptimizationLevel(OrtSession.SessionOptions.OptLevel.ALL_OPT) + // XNNPACK gives a solid CPU speedup; fall back silently if unavailable. + runCatching { addXnnpack(mapOf()) } + } + env = environment + session = environment.createSession(modelPath, opts) + tokenizer = tokenizerFactory(modelFile.parentFile?.absolutePath ?: ".") + Napier.d(tag = TAG) { "EmbeddingGemma ONNX session ready (dim=$outputDim)" } + } catch (e: Exception) { + Napier.e(tag = TAG, throwable = e) { "Failed to load ONNX embedder" } + throw HardwareFault.DelegateFailure("Failed to load ONNX embedder from $modelPath: ${e.message}") + } + } + } + + override suspend fun embed( + text: String, + task: EmbeddingTask, + title: String?, + ): FloatArray = withContext(Dispatchers.IO) { + val ortSession = session ?: throw HardwareFault.ModelNotLoaded("ONNX embedder not loaded. 
Call initialize() first.") + val tok = tokenizer ?: throw HardwareFault.ModelNotLoaded("Tokenizer not loaded.") + val environment = env ?: throw HardwareFault.ModelNotLoaded("ORT environment not loaded.") + + val prompt = instruction(text, task, title) + val encoded = tok.encode(prompt) + val seqLen = encoded.length + val shape = longArrayOf(1, seqLen.toLong()) + + val idsTensor = OnnxTensor.createTensor(environment, LongBuffer.wrap(encoded.ids), shape) + val maskTensor = OnnxTensor.createTensor(environment, LongBuffer.wrap(encoded.attentionMask), shape) + try { + val inputs = HashMap<String, OnnxTensor>(2).apply { + put("input_ids", idsTensor) + put("attention_mask", maskTensor) + } + ortSession.run(inputs).use { result -> + val pooled = pool(result, encoded.attentionMask) + matryoshkaNormalize(pooled) + } + } finally { + idsTensor.close() + maskTensor.close() + } + } + + /** Prefer a pre-pooled `sentence_embedding`; otherwise mean-pool token embeddings. */ + private fun pool(result: OrtSession.Result, mask: LongArray): FloatArray { + // Some sentence-transformers ONNX exports emit a pooled output directly. + result.get("sentence_embedding").orElse(null)?.let { out -> + val arr = (out as OnnxTensor).floatBuffer + val v = FloatArray(arr.remaining()) + arr.get(v) + return v + } + // Otherwise mean-pool last_hidden_state / token_embeddings: [1, seq, hidden]. + val tokenOut = (result.get("last_hidden_state").orElse(null) + ?: result.get("token_embeddings").orElse(null) + ?: result.get(0)) as OnnxTensor + @Suppress("UNCHECKED_CAST") + val hidden = (tokenOut.value as Array<Array<FloatArray>>)[0] // [seq][hidden] + val dim = hidden.first().size + val sum = FloatArray(dim) + var count = 0f + for (t in hidden.indices) { + if (mask[t] == 0L) continue + val row = hidden[t] + for (d in 0 until dim) sum[d] += row[d] + count += 1f + } + if (count > 0f) for (d in 0 until dim) sum[d] /= count + return sum + } + + /** Matryoshka: truncate to [outputDim] first, then L2-normalize. 
*/ + private fun matryoshkaNormalize(full: FloatArray): FloatArray { + val n = minOf(outputDim, full.size) + val v = full.copyOf(n) + var norm = 0.0 + for (x in v) norm += (x * x).toDouble() + norm = sqrt(norm) + if (norm > 0.0) for (i in v.indices) v[i] = (v[i] / norm).toFloat() + return v + } + + private fun instruction(text: String, task: EmbeddingTask, title: String?): String = when (task) { + EmbeddingTask.QUERY -> "task: search result | query: $text" + EmbeddingTask.DOCUMENT -> "title: ${title?.takeIf { it.isNotBlank() } ?: "none"} | text: $text" + } + + fun close() { + runCatching { session?.close() } + session = null + } + + private companion object { + const val TAG = "OnnxEmbeddingEngine" + } +} diff --git a/lib/src/commonMain/kotlin/com/sagar/aicore/EmbeddingEngine.kt b/lib/src/commonMain/kotlin/com/sagar/aicore/EmbeddingEngine.kt index 65ed876..2aeacec 100644 --- a/lib/src/commonMain/kotlin/com/sagar/aicore/EmbeddingEngine.kt +++ b/lib/src/commonMain/kotlin/com/sagar/aicore/EmbeddingEngine.kt @@ -4,17 +4,38 @@ */ package com.sagar.aicore +/** + * Whether a piece of text is being embedded as a **search query** or as a stored + * **document** chunk. Prompt-instructed embedders (e.g. EmbeddingGemma) emit + * different vectors for each role and *require* the distinction for good + * retrieval; symmetric embedders (e.g. USE-Lite) ignore it. + */ +enum class EmbeddingTask { QUERY, DOCUMENT } + /** * Interface for generating vector embeddings from text. + * + * [embed] is task-aware: callers state whether the text is a query or a document + * so prompt-instructed models can apply the correct instruction prefix. Symmetric + * models ignore [task]/[title]. [dimensions] is the length of the returned vector + * and must match the vector store's index dimension for the active embedder. */ interface EmbeddingEngine { - /** - * Initializes the embedding engine with the model. - */ + /** Length of the vectors this engine returns (e.g. 
USE-Lite = 100, EmbeddingGemma = 256). */ + val dimensions: Int + + /** Initializes the embedding engine with the model. */ suspend fun initialize(modelPath: String) /** - * Converts a string into a float array representing its vector embedding. + * Converts [text] into a [dimensions]-length embedding. [task] selects the + * query/document instruction on prompt-instructed models; [title] is an optional + * document title used only for [EmbeddingTask.DOCUMENT] on those models. Both are + * ignored by symmetric models. */ - suspend fun embed(text: String): FloatArray + suspend fun embed( + text: String, + task: EmbeddingTask = EmbeddingTask.DOCUMENT, + title: String? = null, + ): FloatArray } diff --git a/lib/src/commonMain/kotlin/com/sagar/aicore/ModelCatalog.kt b/lib/src/commonMain/kotlin/com/sagar/aicore/ModelCatalog.kt index 2018a3a..bd49fe7 100644 --- a/lib/src/commonMain/kotlin/com/sagar/aicore/ModelCatalog.kt +++ b/lib/src/commonMain/kotlin/com/sagar/aicore/ModelCatalog.kt @@ -37,6 +37,21 @@ data class ModelDescriptor( * skipping it on text-only models also frees memory. Defaults to false. */ val supportsVision: Boolean = false, + /** + * Extra files that must be downloaded **alongside** the main model for it to + * run — e.g. an ONNX embedder's `tokenizer.json`. Consumers download each + * companion to the same model directory before initializing the engine. + * Empty for self-contained models. + */ + val companions: List<CompanionFile> = emptyList(), +) + +/** A sidecar file downloaded next to a [ModelDescriptor]'s main model (e.g. a tokenizer). */ +data class CompanionFile( + val url: String, + val fileName: String, + val sizeBytes: Long, + val sha256: String? = null, ) enum class ModelFormat { @@ -46,6 +61,13 @@ enum class ModelFormat { MEDIAPIPE_TEXT_EMBEDDER, /** Whisper GGML/GGUF weights for on-device speech-to-text (whisper.cpp). */ WHISPER_GGML, + + /** + * ONNX transformer embedder run via ONNX Runtime (e.g. EmbeddingGemma 300M). 
+ * Ships a `tokenizer.json` companion; the engine handles tokenization, + * mean-pooling, Matryoshka truncation, and L2 normalization itself. + */ + ONNX_EMBEDDER, } enum class ModelRole { diff --git a/sample-app/src/main/java/com/nativelm/app/data/backup/BackupManager.kt b/sample-app/src/main/java/com/nativelm/app/data/backup/BackupManager.kt index b65b0f2..186afa0 100644 --- a/sample-app/src/main/java/com/nativelm/app/data/backup/BackupManager.kt +++ b/sample-app/src/main/java/com/nativelm/app/data/backup/BackupManager.kt @@ -11,6 +11,7 @@ import com.nativelm.app.data.ThemeMode import com.nativelm.app.data.db.ConversationEntity import com.nativelm.app.data.db.DocumentChunkEntity import com.nativelm.app.data.db.DocumentEntity +import com.nativelm.app.data.db.GemmaChunkEntity import com.nativelm.app.data.db.MessageEntity import com.nativelm.app.data.db.ObjectBox import com.nativelm.app.data.db.ProjectEntity @@ -62,7 +63,8 @@ class BackupManager(private val context: Context) { val conversations = store.boxFor(ConversationEntity::class.java).all val messages = store.boxFor(MessageEntity::class.java).all val documents = store.boxFor(DocumentEntity::class.java).all - val chunks = store.boxFor(DocumentChunkEntity::class.java).all + val useChunks = store.boxFor(DocumentChunkEntity::class.java).all + val gemmaChunks = store.boxFor(GemmaChunkEntity::class.java).all val artifacts = store.boxFor(StudioArtifactEntity::class.java).all val prefsStore = AppPreferences(context) @@ -93,10 +95,15 @@ class BackupManager(private val context: Context) { createdAt = doc.createdAt, ) }, - chunks = chunks.map { + chunks = useChunks.map { ChunkDto( it.id, it.documentId, it.projectId, s(it.text), it.pageNumber, it.chunkIndex, - encodeEmbedding(it.embedding), + encodeEmbedding(it.embedding), dim = DocumentChunkEntity.EMBEDDING_DIM, + ) + } + gemmaChunks.map { + ChunkDto( + it.id, it.documentId, it.projectId, s(it.text), it.pageNumber, it.chunkIndex, + encodeEmbedding(it.embedding), dim = 
GemmaChunkEntity.EMBEDDING_DIM, ) }, artifacts = artifacts.map { @@ -224,9 +231,9 @@ class BackupManager(private val context: Context) { if (m.schemaVersion > BACKUP_SCHEMA_VERSION) { throw BackupException("This backup was made by a newer version of NativeLM. Update the app and try again.") } - if (m.embeddingDim != BACKUP_EMBEDDING_DIM) { - throw BackupException("This backup uses an incompatible embedding format and can't be restored by this version.") - } + // No embedding-dim gate: each chunk carries its own dim and is routed to the + // matching store on import. Chunks from a non-active embedder are re-indexed + // from their text on first use (see RagHolder.migrateToGemma). } /** A document awaiting its source-file bytes from a `files/` zip entry. */ @@ -238,7 +245,8 @@ class BackupManager(private val context: Context) { val convBox = store.boxFor(ConversationEntity::class.java) val msgBox = store.boxFor(MessageEntity::class.java) val docBox = store.boxFor(DocumentEntity::class.java) - val chunkBox = store.boxFor(DocumentChunkEntity::class.java) + val useChunkBox = store.boxFor(DocumentChunkEntity::class.java) + val gemmaChunkBox = store.boxFor(GemmaChunkEntity::class.java) val artBox = store.boxFor(StudioArtifactEntity::class.java) val projMap = HashMap() @@ -290,17 +298,32 @@ class BackupManager(private val context: Context) { } msgBox.put(newMessages) - val newChunks = payload.chunks.map { dto -> - DocumentChunkEntity().apply { - documentId = docMap[dto.documentId] ?: 0L - projectId = projMap[dto.projectId] ?: 0L - text = dto.text - pageNumber = dto.pageNumber - chunkIndex = dto.chunkIndex - embedding = dto.embeddingB64?.let { decodeEmbedding(it) } - } - } - chunkBox.put(newChunks) + // Route each chunk to the store matching its embedding dimension. 
+ val (gemmaDtos, useDtos) = payload.chunks.partition { it.dim == GemmaChunkEntity.EMBEDDING_DIM } + useChunkBox.put( + useDtos.map { dto -> + DocumentChunkEntity().apply { + documentId = docMap[dto.documentId] ?: 0L + projectId = projMap[dto.projectId] ?: 0L + text = dto.text + pageNumber = dto.pageNumber + chunkIndex = dto.chunkIndex + embedding = dto.embeddingB64?.let { decodeEmbedding(it) } + } + }, + ) + gemmaChunkBox.put( + gemmaDtos.map { dto -> + GemmaChunkEntity().apply { + documentId = docMap[dto.documentId] ?: 0L + projectId = projMap[dto.projectId] ?: 0L + text = dto.text + pageNumber = dto.pageNumber + chunkIndex = dto.chunkIndex + embedding = dto.embeddingB64?.let { decodeEmbedding(it) } + } + }, + ) val newArtifacts = payload.artifacts.map { dto -> StudioArtifactEntity().apply { diff --git a/sample-app/src/main/java/com/nativelm/app/data/backup/BackupModels.kt b/sample-app/src/main/java/com/nativelm/app/data/backup/BackupModels.kt index f4a0e5e..0ed15e7 100644 --- a/sample-app/src/main/java/com/nativelm/app/data/backup/BackupModels.kt +++ b/sample-app/src/main/java/com/nativelm/app/data/backup/BackupModels.kt @@ -6,10 +6,15 @@ package com.nativelm.app.data.backup import kotlinx.serialization.Serializable -/** Bump when the on-disk backup format changes incompatibly. Import rejects newer. */ -const val BACKUP_SCHEMA_VERSION = 1 +/** + * Bump when the on-disk backup format changes incompatibly. Import rejects newer. + * v2: each [ChunkDto] carries its own embedding [ChunkDto.dim] (100 USE-Lite / 256 + * EmbeddingGemma) so both stores round-trip; older apps must reject v2 (they'd try to + * load 256-dim vectors into the 100-dim HNSW index and crash). + */ +const val BACKUP_SCHEMA_VERSION = 2 -/** Embedding dimensionality this build can restore (USE-Lite, 100-dim). */ +/** Informational: the legacy/primary embedding dim recorded in the manifest. */ const val BACKUP_EMBEDDING_DIM = 100 /** Suggested file extension / container name for a backup. 
*/ @@ -112,6 +117,8 @@ data class ChunkDto( val chunkIndex: Int, /** Base64 of little-endian float32 embedding bytes; null for un-embedded chunks. */ val embeddingB64: String? = null, + /** Embedding dimension / store this chunk belongs to: 100 (USE-Lite) or 256 (EmbeddingGemma). */ + val dim: Int = 100, ) @Serializable diff --git a/sample-app/src/main/java/com/nativelm/app/data/db/DocumentRepository.kt b/sample-app/src/main/java/com/nativelm/app/data/db/DocumentRepository.kt index 5768df7..17bbb8e 100644 --- a/sample-app/src/main/java/com/nativelm/app/data/db/DocumentRepository.kt +++ b/sample-app/src/main/java/com/nativelm/app/data/db/DocumentRepository.kt @@ -9,6 +9,13 @@ package com.nativelm.app.data.db * embedded chunks. Every source belongs to a [ProjectEntity]; retrieval is scoped * to one project so a notebook only answers from its own sources. * + * Chunk vectors live in one of two HNSW stores selected by the active embedder's + * dimension: 100-dim (USE-Lite, [DocumentChunkEntity]) or 256-dim (EmbeddingGemma, + * [GemmaChunkEntity]). The [dim] parameter routes vector ops to the right store; + * [findSimilarChunks] infers it from the query vector's length. Results are returned + * as the neutral [Chunk]/[ScoredChunk] DTOs so callers don't depend on which store + * is active. + * * Heavy ops are `suspend` (HNSW index / disk). Cascade-delete is explicit and * **tx-split**: a document's chunks are deleted in one transaction, the document * in another (combining them deadlocks the HNSW commit). @@ -28,10 +35,13 @@ interface DocumentRepository { /** Fetch a single document by id, or null if it no longer exists. */ suspend fun getDocument(documentId: Long): DocumentEntity? - /** Persist embedded chunks for [documentId] (stamped with [projectId]) and bump chunkCount. */ - suspend fun addChunks(documentId: Long, projectId: Long, chunks: List) + /** Persist embedded [chunks] for [documentId] into the [dim]-dim store and bump chunkCount. 
*/ + suspend fun addChunks(documentId: Long, projectId: Long, dim: Int, chunks: List<ChunkInput>) - /** Top-[k] chunks by cosine similarity to [queryEmbedding], within [projectId]. */ + /** + * Top-[k] chunks by cosine similarity to [queryEmbedding], within [projectId]. + * The store is selected by `queryEmbedding.size` (100 → USE-Lite, 256 → Gemma). + */ suspend fun findSimilarChunks( queryEmbedding: FloatArray, k: Int, @@ -39,37 +49,63 @@ interface DocumentRepository { ): List<ScoredChunk> /** - * Chunks in [projectId] whose text contains at least one of [terms] - * (case-insensitive), capped at [limit]. The lexical-candidate set for the - * keyword arm of hybrid retrieval — only chunks that could plausibly match are - * loaded, so this stays cheap regardless of corpus size. + * Chunks in [projectId] (in the [dim]-dim store) whose text contains at least one + * of [terms] (case-insensitive), capped at [limit] — the lexical-candidate set for + * the keyword arm of hybrid retrieval. */ suspend fun keywordCandidates( projectId: Long, terms: List<String>, limit: Int, - ): List<DocumentChunkEntity> + dim: Int, + ): List<Chunk> /** Sources in [projectId], newest first. */ suspend fun listDocuments(projectId: Long): List<DocumentEntity> /** - * Every chunk in [projectId] (or just [documentId] when > 0), ordered by - * document then [DocumentChunkEntity.chunkIndex] so a source reads top-to-bottom. - * Backs Studio's map-reduce, which must see the *whole* source set rather than a - * top-k retrieval slice. Embeddings are not needed here. + * Every chunk in [projectId] (or just [documentId] when > 0) from the [dim]-dim + * store, ordered by document then chunkIndex. Backs Studio's map-reduce. Embeddings + * are not included. */ - suspend fun chunksForProject(projectId: Long, documentId: Long = 0): List<DocumentChunkEntity> + suspend fun chunksForProject(projectId: Long, documentId: Long = 0, dim: Int): List<Chunk> + + /** Count of chunks for [projectId] in the [dim]-dim store (0 when not yet indexed there). 
*/ + suspend fun chunkCount(projectId: Long, dim: Int): Long + + /** Distinct document ids that still have chunks in the [dim]-dim store — drives migration. */ + suspend fun documentIdsWithChunks(dim: Int): List<Long> - /** Delete a document and all its chunks (tx-split). */ + /** Remove a document's chunks from the [dim]-dim store only (used after re-embedding). */ + suspend fun clearChunksOfDocument(documentId: Long, dim: Int) + + /** Delete a document and all its chunks from **both** stores (tx-split). */ suspend fun deleteDocument(documentId: Long) - /** Delete every source (and chunks) of [projectId] — used when deleting a project. */ + /** Delete every source (and chunks, both stores) of [projectId]. */ suspend fun deleteDocumentsOfProject(projectId: Long) } -/** A retrieved chunk with its cosine *distance* [score] (lower = closer; ordered closest-first). */ +/** A chunk, decoupled from which HNSW store it came from. */ +data class Chunk( + val id: Long, + val documentId: Long, + val projectId: Long, + val text: String, + val pageNumber: Int, + val chunkIndex: Int, +) + +/** A chunk to insert: text + metadata + its embedding vector. */ +data class ChunkInput( + val text: String, + val pageNumber: Int, + val chunkIndex: Int, + val embedding: FloatArray, +) + +/** A retrieved [Chunk] with its cosine *distance* [score] (lower = closer; ordered closest-first). */ data class ScoredChunk( - val chunk: DocumentChunkEntity, + val chunk: Chunk, val score: Double, ) diff --git a/sample-app/src/main/java/com/nativelm/app/data/db/Entities.kt b/sample-app/src/main/java/com/nativelm/app/data/db/Entities.kt index e294614..0a88f91 100644 --- a/sample-app/src/main/java/com/nativelm/app/data/db/Entities.kt +++ b/sample-app/src/main/java/com/nativelm/app/data/db/Entities.kt @@ -138,6 +138,41 @@ class DocumentChunkEntity { } } +/** + * One embedded chunk stored with its **EmbeddingGemma** vector (256-dim, Matryoshka). 
+ * Parallel to [DocumentChunkEntity] (the 100-dim USE-Lite store): a chunk lives in + * exactly one of the two, selected by the install's active embedder. Keeping them as + * separate entities lets the old index stay readable during migration (re-embed from + * [text] into here, then drop the old rows) and sidesteps an in-place HNSW dimension + * change. See `docs/EMBEDDING_GEMMA_PLAN.md`. + */ +@Entity +class GemmaChunkEntity { + @Id var id: Long = 0 + + @Index + var documentId: Long = 0 + + @Index + var projectId: Long = 0 + var text: String = "" + var pageNumber: Int = 0 + var chunkIndex: Int = 0 + + // dimensions must be a literal (annotation constant); keep in lockstep with EMBEDDING_DIM. + @HnswIndex( + dimensions = 256L, + distanceType = VectorDistanceType.COSINE, + neighborsPerNode = 48, + indexingSearchCount = 200, + ) + var embedding: FloatArray? = null + + companion object { + const val EMBEDDING_DIM: Int = 256 + } +} + /** * A Studio artifact generated *from* a project's sources (a Briefing, FAQ, etc.) — * the output of map-reduce over the source set, not a chat answer. Persisted per diff --git a/sample-app/src/main/java/com/nativelm/app/data/db/ObjectBoxDocumentRepository.kt b/sample-app/src/main/java/com/nativelm/app/data/db/ObjectBoxDocumentRepository.kt index a734213..46fd059 100644 --- a/sample-app/src/main/java/com/nativelm/app/data/db/ObjectBoxDocumentRepository.kt +++ b/sample-app/src/main/java/com/nativelm/app/data/db/ObjectBoxDocumentRepository.kt @@ -10,14 +10,17 @@ import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.withContext /** - * ObjectBox-backed [DocumentRepository]. Vector retrieval uses the HNSW index on - * [DocumentChunkEntity.embedding], post-filtered to a project. All ops run on - * [Dispatchers.IO]; the generated `*_` query metadata lives in this package. + * ObjectBox-backed [DocumentRepository]. 
Holds two parallel HNSW chunk stores — + * 100-dim ([DocumentChunkEntity], USE-Lite) and 256-dim ([GemmaChunkEntity], + * EmbeddingGemma) — and routes vector ops to the one matching the active embedder's + * `dim`. Document metadata ([DocumentEntity]) is shared. All ops run on + * [Dispatchers.IO]; generated `*_` query metadata lives in this package. */ class ObjectBoxDocumentRepository : DocumentRepository { private val documents = ObjectBox.store.boxFor(DocumentEntity::class.java) - private val chunks = ObjectBox.store.boxFor(DocumentChunkEntity::class.java) + private val useChunks = ObjectBox.store.boxFor(DocumentChunkEntity::class.java) + private val gemmaChunks = ObjectBox.store.boxFor(GemmaChunkEntity::class.java) override suspend fun createDocument( projectId: Long, @@ -47,67 +50,106 @@ class ObjectBoxDocumentRepository : DocumentRepository { override suspend fun addChunks( documentId: Long, projectId: Long, - chunks: List, + dim: Int, + chunks: List, ): Unit = withContext(Dispatchers.IO) { - chunks.forEach { - it.id = 0 - it.documentId = documentId - it.projectId = projectId + if (dim == GemmaChunkEntity.EMBEDDING_DIM) { + gemmaChunks.put( + chunks.map { c -> + GemmaChunkEntity().apply { + this.documentId = documentId + this.projectId = projectId + text = c.text + pageNumber = c.pageNumber + chunkIndex = c.chunkIndex + embedding = c.embedding + } + }, + ) + } else { + useChunks.put( + chunks.map { c -> + DocumentChunkEntity().apply { + this.documentId = documentId + this.projectId = projectId + text = c.text + pageNumber = c.pageNumber + chunkIndex = c.chunkIndex + embedding = c.embedding + } + }, + ) } - this@ObjectBoxDocumentRepository.chunks.put(chunks) documents.get(documentId)?.let { doc -> - doc.chunkCount = this@ObjectBoxDocumentRepository.chunks - .query().equal(DocumentChunkEntity_.documentId, documentId).build() - .use { it.count().toInt() } + doc.chunkCount = countForDocument(documentId, dim).toInt() documents.put(doc) } } + private fun 
countForDocument(documentId: Long, dim: Int): Long = + if (dim == GemmaChunkEntity.EMBEDDING_DIM) { + gemmaChunks.query().equal(GemmaChunkEntity_.documentId, documentId).build() + .use { it.count() } + } else { + useChunks.query().equal(DocumentChunkEntity_.documentId, documentId).build() + .use { it.count() } + } + override suspend fun findSimilarChunks( queryEmbedding: FloatArray, k: Int, projectId: Long, ): List = withContext(Dispatchers.IO) { - require(queryEmbedding.size == DocumentChunkEntity.EMBEDDING_DIM) { - "Query embedding dim ${queryEmbedding.size} != ${DocumentChunkEntity.EMBEDDING_DIM}" - } // ObjectBox applies the projectId condition AFTER the HNSW k-NN, not during - // it. So asking for just `k` neighbors globally can return zero rows for - // this project when closer chunks from OTHER projects fill the k slots — - // silently breaking project-scoped grounding. Over-fetch a wide candidate - // set, then filter to the project and keep the k closest. searchK is bounded - // by the index search ef (indexingSearchCount = 200) for recall. + // it. Over-fetch a wide candidate set, then filter to the project and keep the + // k closest (searchK bounded by indexingSearchCount = 200 for recall). 
val searchK = maxOf(k * 30, 150) - chunks.query() - .nearestNeighbors(DocumentChunkEntity_.embedding, queryEmbedding, searchK) - .equal(DocumentChunkEntity_.projectId, projectId) - .build() - .use { query -> - query.findWithScores().asSequence() - .map { ScoredChunk(it.get(), it.score) } - .take(k) - .toList() + if (queryEmbedding.size == GemmaChunkEntity.EMBEDDING_DIM) { + gemmaChunks.query() + .nearestNeighbors(GemmaChunkEntity_.embedding, queryEmbedding, searchK) + .equal(GemmaChunkEntity_.projectId, projectId) + .build() + .use { q -> + q.findWithScores().asSequence() + .map { ScoredChunk(it.get().toChunk(), it.score) } + .take(k).toList() + } + } else { + require(queryEmbedding.size == DocumentChunkEntity.EMBEDDING_DIM) { + "Query embedding dim ${queryEmbedding.size} matches neither store" } + useChunks.query() + .nearestNeighbors(DocumentChunkEntity_.embedding, queryEmbedding, searchK) + .equal(DocumentChunkEntity_.projectId, projectId) + .build() + .use { q -> + q.findWithScores().asSequence() + .map { ScoredChunk(it.get().toChunk(), it.score) } + .take(k).toList() + } + } } override suspend fun keywordCandidates( projectId: Long, terms: List, limit: Int, - ): List = withContext(Dispatchers.IO) { + dim: Int, + ): List = withContext(Dispatchers.IO) { if (terms.isEmpty()) return@withContext emptyList() - // OR of case-insensitive "text contains term" across the query terms. - // Seed the fold as QueryCondition so .or() (which widens from - // PropertyQueryCondition) type-checks. 
- fun contains(term: String) = - DocumentChunkEntity_.text.contains(term, StringOrder.CASE_INSENSITIVE) - val anyTerm: QueryCondition = - terms.drop(1).fold>(contains(terms.first())) { acc, t -> - acc.or(contains(t)) - } - chunks.query(DocumentChunkEntity_.projectId.equal(projectId).and(anyTerm)) - .build() - .use { it.find(0L, limit.toLong()) } + if (dim == GemmaChunkEntity.EMBEDDING_DIM) { + fun contains(t: String) = GemmaChunkEntity_.text.contains(t, StringOrder.CASE_INSENSITIVE) + val any: QueryCondition = + terms.drop(1).fold>(contains(terms.first())) { acc, t -> acc.or(contains(t)) } + gemmaChunks.query(GemmaChunkEntity_.projectId.equal(projectId).and(any)) + .build().use { it.find(0L, limit.toLong()) }.map { it.toChunk() } + } else { + fun contains(t: String) = DocumentChunkEntity_.text.contains(t, StringOrder.CASE_INSENSITIVE) + val any: QueryCondition = + terms.drop(1).fold>(contains(terms.first())) { acc, t -> acc.or(contains(t)) } + useChunks.query(DocumentChunkEntity_.projectId.equal(projectId).and(any)) + .build().use { it.find(0L, limit.toLong()) }.map { it.toChunk() } + } } override suspend fun listDocuments(projectId: Long): List = @@ -122,27 +164,65 @@ class ObjectBoxDocumentRepository : DocumentRepository { override suspend fun chunksForProject( projectId: Long, documentId: Long, - ): List = withContext(Dispatchers.IO) { - val condition = if (documentId > 0L) { - DocumentChunkEntity_.documentId.equal(documentId) + dim: Int, + ): List = withContext(Dispatchers.IO) { + if (dim == GemmaChunkEntity.EMBEDDING_DIM) { + val cond = if (documentId > 0L) GemmaChunkEntity_.documentId.equal(documentId) + else GemmaChunkEntity_.projectId.equal(projectId) + gemmaChunks.query(cond) + .order(GemmaChunkEntity_.documentId).order(GemmaChunkEntity_.chunkIndex) + .build().use { it.find() }.map { it.toChunk() } + } else { + val cond = if (documentId > 0L) DocumentChunkEntity_.documentId.equal(documentId) + else DocumentChunkEntity_.projectId.equal(projectId) + 
useChunks.query(cond) + .order(DocumentChunkEntity_.documentId).order(DocumentChunkEntity_.chunkIndex) + .build().use { it.find() }.map { it.toChunk() } + } + } + + override suspend fun chunkCount(projectId: Long, dim: Int): Long = withContext(Dispatchers.IO) { + if (dim == GemmaChunkEntity.EMBEDDING_DIM) { + gemmaChunks.query().equal(GemmaChunkEntity_.projectId, projectId).build().use { it.count() } + } else { + useChunks.query().equal(DocumentChunkEntity_.projectId, projectId).build().use { it.count() } + } + } + + override suspend fun documentIdsWithChunks(dim: Int): List = withContext(Dispatchers.IO) { + if (dim == GemmaChunkEntity.EMBEDDING_DIM) { + gemmaChunks.query().build() + .use { it.property(GemmaChunkEntity_.documentId).distinct().findLongs() }.toList() } else { - DocumentChunkEntity_.projectId.equal(projectId) + useChunks.query().build() + .use { it.property(DocumentChunkEntity_.documentId).distinct().findLongs() }.toList() + } + } + + override suspend fun clearChunksOfDocument(documentId: Long, dim: Int): Unit = withContext(Dispatchers.IO) { + if (dim == GemmaChunkEntity.EMBEDDING_DIM) { + gemmaChunks.query().equal(GemmaChunkEntity_.documentId, documentId).build().use { it.remove() } + } else { + useChunks.query().equal(DocumentChunkEntity_.documentId, documentId).build().use { it.remove() } } - chunks.query(condition) - .order(DocumentChunkEntity_.documentId) - .order(DocumentChunkEntity_.chunkIndex) - .build() - .use { it.find() } } override suspend fun deleteDocument(documentId: Long): Unit = withContext(Dispatchers.IO) { - // tx-split (landmine #3): remove HNSW-indexed chunks first, then the parent. - chunks.query().equal(DocumentChunkEntity_.documentId, documentId).build().use { it.remove() } + // tx-split (landmine #3): remove HNSW-indexed chunks (both stores) first, then the parent. 
+ useChunks.query().equal(DocumentChunkEntity_.documentId, documentId).build().use { it.remove() } + gemmaChunks.query().equal(GemmaChunkEntity_.documentId, documentId).build().use { it.remove() } documents.remove(documentId) } override suspend fun deleteDocumentsOfProject(projectId: Long): Unit = withContext(Dispatchers.IO) { - chunks.query().equal(DocumentChunkEntity_.projectId, projectId).build().use { it.remove() } + useChunks.query().equal(DocumentChunkEntity_.projectId, projectId).build().use { it.remove() } + gemmaChunks.query().equal(GemmaChunkEntity_.projectId, projectId).build().use { it.remove() } documents.query().equal(DocumentEntity_.projectId, projectId).build().use { it.remove() } } + + private fun DocumentChunkEntity.toChunk() = + Chunk(id, documentId, projectId, text, pageNumber, chunkIndex) + + private fun GemmaChunkEntity.toChunk() = + Chunk(id, documentId, projectId, text, pageNumber, chunkIndex) } diff --git a/sample-app/src/main/java/com/nativelm/app/llm/NativeLmModelCatalog.kt b/sample-app/src/main/java/com/nativelm/app/llm/NativeLmModelCatalog.kt index 9c42983..668b58a 100644 --- a/sample-app/src/main/java/com/nativelm/app/llm/NativeLmModelCatalog.kt +++ b/sample-app/src/main/java/com/nativelm/app/llm/NativeLmModelCatalog.kt @@ -4,6 +4,7 @@ */ package com.nativelm.app.llm +import com.sagar.aicore.CompanionFile import com.sagar.aicore.ModelCatalog import com.sagar.aicore.ModelDescriptor import com.sagar.aicore.ModelFormat @@ -117,6 +118,33 @@ class NativeLmModelCatalog : ModelCatalog { requiresAuth = false, supportsVision = false, ), + // ── Embedding (preferred) — EmbeddingGemma 300M, ONNX, quantized. The + // higher-quality RAG embedder; default on devices that clear the RAM/storage + // floor, with USE-Lite as the friction-free fallback below. Ships a + // tokenizer.json companion. Output is Matryoshka-truncated to 256-dim. 
+ // + // NOTE: url/sizeBytes/sha256 are pinned to a specific EmbeddingGemma ONNX + // export — VERIFY against the chosen Hugging Face revision before release + // (and prefer a QAT/INT8 build for size + quality). Gemma-derived weights: + // surface the Gemma terms in the onboarding terms gate even though the + // onnx-community mirror downloads token-free. + ModelDescriptor( + id = "embeddinggemma-300m-onnx", + url = "https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/onnx/model_quantized.onnx?download=true", + fileName = "embeddinggemma-300m-q.onnx", + sizeBytes = 200_000_000L, // approx — verify against the pinned revision + format = ModelFormat.ONNX_EMBEDDER, + role = ModelRole.EMBEDDING, + minDeviceRamMb = 4000, + requiresAuth = false, + companions = listOf( + CompanionFile( + url = "https://huggingface.co/onnx-community/embeddinggemma-300m-ONNX/resolve/main/tokenizer.json?download=true", + fileName = "embeddinggemma-tokenizer.json", + sizeBytes = 17_000_000L, // approx — verify + ), + ), + ), ModelDescriptor( id = "universal-sentence-encoder", // Public MediaPipe model — no auth required. diff --git a/sample-app/src/main/java/com/nativelm/app/llm/NativeLmViewModel.kt b/sample-app/src/main/java/com/nativelm/app/llm/NativeLmViewModel.kt index 76431d2..60e4b8a 100644 --- a/sample-app/src/main/java/com/nativelm/app/llm/NativeLmViewModel.kt +++ b/sample-app/src/main/java/com/nativelm/app/llm/NativeLmViewModel.kt @@ -210,6 +210,9 @@ class NativeLmViewModel(app: Application) : ViewModel() { private val appContext: Application = app private val engineHolder = EngineHolder(app) private val ragHolder = RagHolder(app, engineHolder) + + /** Guards the one-time legacy→active embedder re-index per ViewModel lifetime. 
*/ + @Volatile private var embeddingMigrationKicked = false private val prefs = AppPreferences(app) private val secureStore = SecureStore(app) @@ -991,30 +994,60 @@ class NativeLmViewModel(app: Application) : ViewModel() { } } - /** Make the embedder usable: download the small USE model if needed, then init. */ + /** + * Make the embedder usable: download the preferred embedder (EmbeddingGemma on + * capable devices, else USE-Lite) if needed, init it, and kick a one-time + * re-index of any legacy chunks into the active store. + */ private suspend fun prepareEmbedding(): Boolean { - if (ragHolder.ensureEmbeddingReady()) return true + if (ragHolder.ensureEmbeddingReady()) { + maybeMigrateEmbeddings() + return true + } if (!downloadEmbeddingModel()) return false - return ragHolder.ensureEmbeddingReady() + val ready = ragHolder.ensureEmbeddingReady() + if (ready) maybeMigrateEmbeddings() + return ready } + /** Download the preferred embedder's model file + any companions (e.g. tokenizer.json). 
*/ private suspend fun downloadEmbeddingModel(): Boolean { - val descriptor = catalog.byId(RagHolder.USE_MODEL_ID) ?: return false - if (engineHolder.isModelDownloaded(descriptor.fileName)) return true - var success = false - engineHolder.downloadModel(descriptor.url, descriptor.fileName, descriptor.sha256) - .collect { state -> + val d = catalog.byId(ragHolder.preferredModelId) ?: return false + val files = buildList { + add(Triple(d.url, d.fileName, d.sha256)) + d.companions.forEach { add(Triple(it.url, it.fileName, it.sha256)) } + } + for ((url, name, sha) in files) { + if (engineHolder.isModelDownloaded(name)) continue + var ok = false + engineHolder.downloadModel(url, name, sha).collect { state -> when (state) { - is DownloadState.Success -> success = true + is DownloadState.Success -> ok = true is DownloadState.Error -> { - Napier.e(tag = TAG) { "USE model download failed: ${state.message}" } - success = false + Napier.e(tag = TAG) { "Embedding download failed: ${state.message}" } + ok = false } else -> Unit } } - if (success) refreshModels() - return success + if (!ok) return false + } + refreshModels() + return true + } + + /** Re-index legacy USE-Lite chunks into the EmbeddingGemma store once, in the background. 
*/ + private fun maybeMigrateEmbeddings() { + if (embeddingMigrationKicked) return + embeddingMigrationKicked = true + viewModelScope.launch { + runCatching { + if (ragHolder.needsMigration()) { + ragHolder.migrateToGemma() + refreshDocuments() + } + }.onFailure { Napier.e(tag = TAG, throwable = it) { "Embedding migration failed" } } + } } // ---- Settings ---- diff --git a/sample-app/src/main/java/com/nativelm/app/llm/RagHolder.kt b/sample-app/src/main/java/com/nativelm/app/llm/RagHolder.kt index 7e68fec..c39b0e1 100644 --- a/sample-app/src/main/java/com/nativelm/app/llm/RagHolder.kt +++ b/sample-app/src/main/java/com/nativelm/app/llm/RagHolder.kt @@ -5,7 +5,9 @@ package com.nativelm.app.llm import android.app.Application +import com.nativelm.app.data.db.ChunkInput import com.nativelm.app.data.db.DocumentEntity +import com.nativelm.app.data.db.GemmaChunkEntity import com.nativelm.app.data.db.ObjectBoxDocumentRepository import com.nativelm.app.rag.DefaultDocumentIngestor import com.nativelm.app.rag.DefaultDocumentRetriever @@ -17,38 +19,88 @@ import com.nativelm.app.rag.extract.AndroidDocumentFileStore import com.nativelm.app.rag.extract.AndroidTextExtractor import com.nativelm.app.rag.extract.MlKitOcrEngine import com.nativelm.app.rag.extract.TextChunker +import com.sagar.aicore.EmbeddingEngine +import com.sagar.aicore.EmbeddingTask +import com.sagar.aicore.HfGemmaTokenizer import com.sagar.aicore.MediaPipeEmbeddingEngine +import com.sagar.aicore.OnnxEmbeddingEngine +import io.github.aakira.napier.Napier import kotlinx.coroutines.flow.Flow +import java.io.File /** * Wires the document-RAG stack — embedding engine, vector store, ingestion and * retrieval — in the same manual-DI style as [EngineHolder]. All sources are - * project-scoped. Reuses EngineHolder's ModelManager to locate the USE-Lite model. + * project-scoped. 
+ * + * Two embedders coexist: **EmbeddingGemma** (256-dim ONNX — preferred on capable + * devices) and **USE-Lite** (100-dim — the friction-free fallback). The active one + * is chosen by [useGemma]; a [SwitchableEmbeddingEngine] forwards ingest/retrieve to + * it so the pipeline always reads the current embedder's `dimensions`, routing + * vectors to the matching HNSW store. */ class RagHolder(app: Application, private val engineHolder: EngineHolder) { - private val embeddingEngine = MediaPipeEmbeddingEngine(app) + private val useEngine = MediaPipeEmbeddingEngine(app) + private val gemmaEngine = OnnxEmbeddingEngine( + tokenizerFactory = { dir -> HfGemmaTokenizer(File(dir, GEMMA_TOKENIZER_FILE).absolutePath) }, + dimensions = GemmaChunkEntity.EMBEDDING_DIM, + ) + + /** Delegates to whichever embedder is currently active (so `dimensions` stays correct). */ + private val active: EmbeddingEngine = object : EmbeddingEngine { + override val dimensions: Int get() = activeEngine().dimensions + override suspend fun initialize(modelPath: String) {} + override suspend fun embed(text: String, task: EmbeddingTask, title: String?) = + activeEngine().embed(text, task, title) + } + private val repository = ObjectBoxDocumentRepository() private val extractor = AndroidTextExtractor(app, MlKitOcrEngine()) private val fileStore = AndroidDocumentFileStore(app) private val ingestor: DocumentIngestor = - DefaultDocumentIngestor(extractor, TextChunker(), embeddingEngine, repository, fileStore) + DefaultDocumentIngestor(extractor, TextChunker(), active, repository, fileStore) private val retriever: DocumentRetriever = - DefaultDocumentRetriever(embeddingEngine, repository) + DefaultDocumentRetriever(active, repository) + + @Volatile private var useReady = false + @Volatile private var gemmaReady = false + + /** Dimension of the active embedder — Studio needs it to read the right chunk store. 
*/ + val activeDim: Int get() = if (useGemma()) GemmaChunkEntity.EMBEDDING_DIM else 100 + + private fun gemmaFilesPresent(): Boolean = + engineHolder.isModelDownloaded(GEMMA_FILE_NAME) && engineHolder.isModelDownloaded(GEMMA_TOKENIZER_FILE) + + private fun deviceCapableForGemma(): Boolean = engineHolder.deviceRamMb >= GEMMA_MIN_RAM_MB + + /** Prefer EmbeddingGemma when its files are present and the device clears the RAM floor. */ + fun useGemma(): Boolean = gemmaFilesPresent() && deviceCapableForGemma() + + private fun activeEngine(): EmbeddingEngine = if (useGemma()) gemmaEngine else useEngine - @Volatile - private var embeddingReady = false + /** The embedder model needed for first-run download: Gemma on capable devices, else USE. */ + val preferredModelId: String + get() = if (deviceCapableForGemma()) GEMMA_MODEL_ID else USE_MODEL_ID val isEmbeddingModelDownloaded: Boolean - get() = engineHolder.isModelDownloaded(USE_FILE_NAME) + get() = useGemma() || engineHolder.isModelDownloaded(USE_FILE_NAME) - /** Initialize the embedder if its model is on disk. Returns true once ready. */ + /** Initialize the active embedder if its model is on disk. Returns true once ready. 
*/ suspend fun ensureEmbeddingReady(): Boolean { - if (embeddingReady) return true - if (!isEmbeddingModelDownloaded) return false - embeddingEngine.initialize(engineHolder.modelPath(USE_FILE_NAME)) - embeddingReady = true + if (useGemma()) { + if (!gemmaReady) { + gemmaEngine.initialize(engineHolder.modelPath(GEMMA_FILE_NAME)) + gemmaReady = true + } + return true + } + if (!engineHolder.isModelDownloaded(USE_FILE_NAME)) return false + if (!useReady) { + useEngine.initialize(engineHolder.modelPath(USE_FILE_NAME)) + useReady = true + } return true } @@ -63,15 +115,10 @@ class RagHolder(app: Application, private val engineHolder: EngineHolder) { suspend fun documents(projectId: Long): List = repository.listDocuments(projectId) - /** - * Every chunk of a project (or one source when [documentId] > 0), in reading - * order, for Studio's whole-source-set map-reduce. See - * [com.nativelm.app.data.db.DocumentRepository.chunksForProject]. - */ - suspend fun chunksForProject(projectId: Long, documentId: Long = 0): List = - repository.chunksForProject(projectId, documentId) + /** Every chunk of a project (or one source), in reading order, from the active store. */ + suspend fun chunksForProject(projectId: Long, documentId: Long = 0) = + repository.chunksForProject(projectId, documentId, activeDim) - /** A single source's metadata (incl. [DocumentEntity.localPath]) for the viewer. */ suspend fun document(id: Long): DocumentEntity? = repository.getDocument(id) suspend fun deleteDocument(id: Long) { @@ -84,11 +131,49 @@ class RagHolder(app: Application, private val engineHolder: EngineHolder) { repository.deleteDocumentsOfProject(projectId) } - /** Wipe every stored source copy (used by Settings → clear all data). */ suspend fun deleteAllSourceFiles() = fileStore.deleteAll() + /** True when Gemma is active but legacy 100-dim chunks still need re-embedding. 
*/ + suspend fun needsMigration(): Boolean = + useGemma() && repository.documentIdsWithChunks(100).isNotEmpty() + + /** + * Re-embed every legacy USE-Lite (100-dim) document into the Gemma (256-dim) store, + * then drop its old chunks. Document-scoped and idempotent: a re-run skips documents + * already migrated (they no longer have 100-dim chunks). [onProgress] reports + * (doneDocs, totalDocs). Returns the number of legacy documents found, including any skipped because their document row is missing. + */ + suspend fun migrateToGemma(onProgress: (done: Int, total: Int) -> Unit = { _, _ -> }): Int { + if (!useGemma()) return 0 + ensureEmbeddingReady() + val docIds = repository.documentIdsWithChunks(100) + docIds.forEachIndexed { i, docId -> + val doc = repository.getDocument(docId) ?: return@forEachIndexed + val chunks = repository.chunksForProject(doc.projectId, docId, dim = 100) + val inputs = chunks.map { c -> + ChunkInput( + text = c.text, + pageNumber = c.pageNumber, + chunkIndex = c.chunkIndex, + embedding = gemmaEngine.embed(c.text, EmbeddingTask.DOCUMENT, doc.title), + ) + } + if (inputs.isNotEmpty()) { + repository.addChunks(docId, doc.projectId, GemmaChunkEntity.EMBEDDING_DIM, inputs) + repository.clearChunksOfDocument(docId, 100) + } + onProgress(i + 1, docIds.size) + } + if (docIds.isNotEmpty()) Napier.d(tag = "RagHolder") { "Migrated ${docIds.size} docs to EmbeddingGemma" } + return docIds.size + } + companion object { const val USE_MODEL_ID = "universal-sentence-encoder" const val USE_FILE_NAME = "universal_sentence_encoder.tflite" + const val GEMMA_MODEL_ID = "embeddinggemma-300m-onnx" + const val GEMMA_FILE_NAME = "embeddinggemma-300m-q.onnx" + const val GEMMA_TOKENIZER_FILE = "embeddinggemma-tokenizer.json" + const val GEMMA_MIN_RAM_MB = 4000L } } diff --git a/sample-app/src/main/java/com/nativelm/app/rag/DefaultDocumentIngestor.kt b/sample-app/src/main/java/com/nativelm/app/rag/DefaultDocumentIngestor.kt index 824eb4b..13cf9e1 100644 --- a/sample-app/src/main/java/com/nativelm/app/rag/DefaultDocumentIngestor.kt 
+++ b/sample-app/src/main/java/com/nativelm/app/rag/DefaultDocumentIngestor.kt @@ -4,12 +4,13 @@ */ package com.nativelm.app.rag -import com.nativelm.app.data.db.DocumentChunkEntity +import com.nativelm.app.data.db.ChunkInput import com.nativelm.app.data.db.DocumentRepository import com.nativelm.app.rag.extract.DocumentFileStore import com.nativelm.app.rag.extract.TextChunker import com.nativelm.app.rag.extract.TextExtractor import com.sagar.aicore.EmbeddingEngine +import com.sagar.aicore.EmbeddingTask import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.FlowCollector @@ -69,22 +70,23 @@ class DefaultDocumentIngestor( emit(IngestState.Chunking(chunks.size)) val documentId = repository.createDocument(projectId, title, uri, localPath, mime, pageCount) - val entities = ArrayList<DocumentChunkEntity>(chunks.size) + val dim = embeddingEngine.dimensions + val inputs = ArrayList<ChunkInput>(chunks.size) chunks.forEachIndexed { i, chunk -> emit(IngestState.Embedding(done = i, total = chunks.size)) - val vector = embeddingEngine.embed(chunk.text) - entities += DocumentChunkEntity().apply { - this.documentId = documentId - this.projectId = projectId - this.text = chunk.text - pageNumber = chunk.pageNumber - chunkIndex = chunk.index - embedding = vector - } + // DOCUMENT task: prompt-instructed embedders (EmbeddingGemma) need the + // document role + title; symmetric embedders (USE-Lite) ignore both. 
+ val vector = embeddingEngine.embed(chunk.text, EmbeddingTask.DOCUMENT, title) + inputs += ChunkInput( + text = chunk.text, + pageNumber = chunk.pageNumber, + chunkIndex = chunk.index, + embedding = vector, + ) } - repository.addChunks(documentId, projectId, entities) + repository.addChunks(documentId, projectId, dim, inputs) emit(IngestState.Embedding(done = chunks.size, total = chunks.size)) - emit(IngestState.Done(documentId, entities.size)) + emit(IngestState.Done(documentId, inputs.size)) } private suspend fun FlowCollector.emitFailure(t: Throwable) { diff --git a/sample-app/src/main/java/com/nativelm/app/rag/DefaultDocumentRetriever.kt b/sample-app/src/main/java/com/nativelm/app/rag/DefaultDocumentRetriever.kt index b1956db..11bf94f 100644 --- a/sample-app/src/main/java/com/nativelm/app/rag/DefaultDocumentRetriever.kt +++ b/sample-app/src/main/java/com/nativelm/app/rag/DefaultDocumentRetriever.kt @@ -4,10 +4,11 @@ */ package com.nativelm.app.rag -import com.nativelm.app.data.db.DocumentChunkEntity +import com.nativelm.app.data.db.Chunk import com.nativelm.app.data.db.DocumentRepository import com.nativelm.app.data.db.ScoredChunk import com.sagar.aicore.EmbeddingEngine +import com.sagar.aicore.EmbeddingTask /** * Hybrid retriever: blends semantic (vector) and lexical (keyword/BM25) search. @@ -27,16 +28,21 @@ import com.sagar.aicore.EmbeddingEngine class DefaultDocumentRetriever( private val embeddingEngine: EmbeddingEngine, private val repository: DocumentRepository, - private val maxDistance: Double = RELEVANCE_MAX_DISTANCE, + /** Override the cosine-distance gate; when null it's chosen per active embedder. */ + private val maxDistance: Double? 
= null, ) : DocumentRetriever { override suspend fun retrieve(projectId: Long, query: String, k: Int): RetrievedContext { if (query.isBlank() || projectId <= 0L) return RetrievedContext.EMPTY + val dim = embeddingEngine.dimensions + val gate = maxDistance ?: gateFor(dim) + // ── Vector arm: nearest neighbors, gated to genuine semantic matches. ── - val queryVector = embeddingEngine.embed(query) + // QUERY task: prompt-instructed embedders (EmbeddingGemma) need the query role. + val queryVector = embeddingEngine.embed(query, EmbeddingTask.QUERY) val vectorHits = repository.findSimilarChunks(queryVector, VECTOR_POOL, projectId) - .filter { it.score <= maxDistance } + .filter { it.score <= gate } val vectorRanking = vectorHits.map { it.chunk.id } // ── Keyword arm: BM25 over chunks that contain a query term. ── @@ -44,7 +50,7 @@ class DefaultDocumentRetriever( val keywordCandidates = if (terms.isEmpty()) { emptyList() } else { - repository.keywordCandidates(projectId, terms, KEYWORD_POOL) + repository.keywordCandidates(projectId, terms, KEYWORD_POOL, dim) } val keywordRanking = KeywordSearch.rank( query, @@ -58,7 +64,7 @@ class DefaultDocumentRetriever( .reciprocalRankFusion(listOf(vectorRanking, keywordRanking)) .take(k) - val byId: Map = + val byId: Map = (vectorHits.map { it.chunk } + keywordCandidates).associateBy { it.id } val ordered = fusedIds.mapNotNull { id -> byId[id] }.map { ScoredChunk(it, 0.0) } if (ordered.isEmpty()) return RetrievedContext.EMPTY @@ -67,14 +73,20 @@ class DefaultDocumentRetriever( return RagContextFormatter.format(ordered) { id -> titles[id] ?: "Source" } } + /** + * Cosine-distance gate per embedder (0 = identical direction … up to 2 = opposite). + * USE-Lite (100-dim) and EmbeddingGemma (256-dim) have different distance + * distributions, so the cutoff differs. Both are deliberately loose — they only + * drop clearly-unrelated hits. 
**Tune against real corpora on-device.** + */ + private fun gateFor(dim: Int): Double = + if (dim == 256) RELEVANCE_MAX_DISTANCE_GEMMA else RELEVANCE_MAX_DISTANCE_USE + companion object { - /** - * Cosine-distance cutoff (0 = identical direction … up to 2 = opposite). - * Vector hits farther than this are dropped. A real USE-Lite match sits well - * below this; the value is deliberately loose so it only filters clearly - * unrelated hits — tune against real corpora. - */ - const val RELEVANCE_MAX_DISTANCE: Double = 0.75 + const val RELEVANCE_MAX_DISTANCE_USE: Double = 0.75 + + /** Provisional — EmbeddingGemma vectors are L2-normalized; retune on-device. */ + const val RELEVANCE_MAX_DISTANCE_GEMMA: Double = 0.55 /** Candidate-pool sizes per arm before fusion (final result is top-k). */ private const val VECTOR_POOL = 30 diff --git a/sample-app/src/test/java/com/nativelm/app/rag/DefaultDocumentRetrieverTest.kt b/sample-app/src/test/java/com/nativelm/app/rag/DefaultDocumentRetrieverTest.kt index 79b76f4..345a233 100644 --- a/sample-app/src/test/java/com/nativelm/app/rag/DefaultDocumentRetrieverTest.kt +++ b/sample-app/src/test/java/com/nativelm/app/rag/DefaultDocumentRetrieverTest.kt @@ -4,11 +4,13 @@ */ package com.nativelm.app.rag -import com.nativelm.app.data.db.DocumentChunkEntity +import com.nativelm.app.data.db.Chunk +import com.nativelm.app.data.db.ChunkInput import com.nativelm.app.data.db.DocumentEntity import com.nativelm.app.data.db.DocumentRepository import com.nativelm.app.data.db.ScoredChunk import com.sagar.aicore.EmbeddingEngine +import com.sagar.aicore.EmbeddingTask import kotlinx.coroutines.test.runTest import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse @@ -19,8 +21,9 @@ class DefaultDocumentRetrieverTest { private val embedder = object : EmbeddingEngine { var calls = 0 + override val dimensions: Int = 100 override suspend fun initialize(modelPath: String) {} - override suspend fun embed(text: String): FloatArray { + override 
suspend fun embed(text: String, task: EmbeddingTask, title: String?): FloatArray { calls++ return FloatArray(100) } @@ -28,23 +31,27 @@ class DefaultDocumentRetrieverTest { private class FakeRepo( private val hits: List<ScoredChunk>, - private val keyword: List<DocumentChunkEntity> = emptyList(), + private val keyword: List<Chunk> = emptyList(), ) : DocumentRepository { override suspend fun createDocument(projectId: Long, title: String, uri: String, localPath: String, mime: String, pageCount: Int): Long = 0 override suspend fun getDocument(documentId: Long): DocumentEntity? = null - override suspend fun addChunks(documentId: Long, projectId: Long, chunks: List<DocumentChunkEntity>) {} + override suspend fun addChunks(documentId: Long, projectId: Long, dim: Int, chunks: List<ChunkInput>) {} override suspend fun findSimilarChunks(queryEmbedding: FloatArray, k: Int, projectId: Long): List<ScoredChunk> = hits - override suspend fun keywordCandidates(projectId: Long, terms: List<String>, limit: Int): List<DocumentChunkEntity> = keyword + override suspend fun keywordCandidates(projectId: Long, terms: List<String>, limit: Int, dim: Int): List<Chunk> = keyword override suspend fun listDocuments(projectId: Long): List<DocumentEntity> = emptyList() + override suspend fun chunksForProject(projectId: Long, documentId: Long, dim: Int): List<Chunk> = emptyList() + override suspend fun chunkCount(projectId: Long, dim: Int): Long = 0 + override suspend fun documentIdsWithChunks(dim: Int): List<Long> = emptyList() + override suspend fun clearChunksOfDocument(documentId: Long, dim: Int) {} override suspend fun deleteDocument(documentId: Long) {} override suspend fun deleteDocumentsOfProject(projectId: Long) {} } private fun scored(score: Double, id: Long = 0, text: String = "fact") = - ScoredChunk(DocumentChunkEntity().apply { this.id = id; documentId = 1; this.text = text }, score) + ScoredChunk(Chunk(id = id, documentId = 1, projectId = 1, text = text, pageNumber = 0, chunkIndex = 0), score) private fun chunk(id: Long, text: String) = - DocumentChunkEntity().apply { this.id = id; documentId = 1; this.text = text } + Chunk(id = id, 
documentId = 1, projectId = 1, text = text, pageNumber = 0, chunkIndex = 0) @Test fun blankQueryReturnsEmptyWithoutEmbedding() = runTest { val r = DefaultDocumentRetriever(embedder, FakeRepo(listOf(scored(0.1)))) diff --git a/sample-app/src/test/java/com/nativelm/app/rag/RagContextFormatterTest.kt b/sample-app/src/test/java/com/nativelm/app/rag/RagContextFormatterTest.kt index 8956307..3b07d7b 100644 --- a/sample-app/src/test/java/com/nativelm/app/rag/RagContextFormatterTest.kt +++ b/sample-app/src/test/java/com/nativelm/app/rag/RagContextFormatterTest.kt @@ -4,7 +4,7 @@ */ package com.nativelm.app.rag -import com.nativelm.app.data.db.DocumentChunkEntity +import com.nativelm.app.data.db.Chunk import com.nativelm.app.data.db.ScoredChunk import org.junit.Assert.assertEquals import org.junit.Assert.assertFalse @@ -15,11 +15,7 @@ class RagContextFormatterTest { private fun scored(docId: Long, text: String, page: Int = 0, score: Double = 0.1) = ScoredChunk( - DocumentChunkEntity().apply { - documentId = docId - this.text = text - pageNumber = page - }, + Chunk(id = 0, documentId = docId, projectId = 1, text = text, pageNumber = page, chunkIndex = 0), score, )