From a0a011e26b5faebeb81e87aabf70858f5b47bc22 Mon Sep 17 00:00:00 2001 From: Sai Nageswar S Date: Mon, 9 Feb 2026 13:35:59 +0000 Subject: [PATCH] Adding support for Jina AI late chunking --- embed/embedding.go | 34 +++++++++++++++++--------- embed/gemini_embedding_test.go | 20 +++++++-------- embed/jina_ai_embedding.go | 16 +++++++----- embed/jina_ai_embedding_test.go | 43 +++++++++++++++++++++++++++++++++ 4 files changed, 85 insertions(+), 28 deletions(-) diff --git a/embed/embedding.go b/embed/embedding.go index f25097b..d1e33d1 100644 --- a/embed/embedding.go +++ b/embed/embedding.go @@ -18,9 +18,11 @@ type Embedder interface { } type settings struct { - model string // common - keepAlive time.Duration // ollama - taskName string // Jina & Gemini + model string // common + keepAlive time.Duration // ollama + taskName string // Jina & Gemini + lateChunking bool // Jina + returnMultivector bool // Jina } // Task name constants (using Jina AI naming convention as standard) @@ -55,6 +57,14 @@ func WithKeepAlive(d time.Duration) EmbedOption { return func(s *settings) { s.keepAlive = d } } +func WithLateChunking(enabled bool) EmbedOption { + return func(s *settings) { s.lateChunking = enabled } +} + +func WithReturnMultivector(enabled bool) EmbedOption { + return func(s *settings) { s.returnMultivector = enabled } +} + // ---- Jina & Gemini task helpers ---- func WithTask(name string) EmbedOption { // Supported task names: @@ -66,38 +76,38 @@ func WithTask(name string) EmbedOption { } // Task-specific helper functions for common use cases -func WithRetrievalQuery() EmbedOption { +func WithRetrievalQueryTask() EmbedOption { return WithTask(TaskRetrievalQuery) } -func WithRetrievalPassage() EmbedOption { +func WithRetrievalPassageTask() EmbedOption { return WithTask(TaskRetrievalPassage) } -func WithCodeQuery() EmbedOption { +func WithCodeQueryTask() EmbedOption { return WithTask(TaskCodeQuery) } -func WithCodePassage() EmbedOption { +func WithCodePassageTask() EmbedOption { return WithTask(TaskCodePassage) } -func WithTextMatching() EmbedOption { +func WithTextMatchingTask() EmbedOption { return WithTask(TaskTextMatching) } -func WithClassification() EmbedOption { +func WithClassificationTask() EmbedOption { return WithTask(TaskClassification) } -func WithClustering() EmbedOption { +func WithClusteringTask() EmbedOption { return WithTask(TaskClustering) } -func WithQuestionAnswering() EmbedOption { +func WithQuestionAnsweringTask() EmbedOption { return WithTask(TaskQuestionAnswering) } -func WithFactVerification() EmbedOption { +func WithFactVerificationTask() EmbedOption { return WithTask(TaskFactVerification) } diff --git a/embed/gemini_embedding_test.go b/embed/gemini_embedding_test.go index 0c40479..0e67e1f 100644 --- a/embed/gemini_embedding_test.go +++ b/embed/gemini_embedding_test.go @@ -112,7 +112,7 @@ func TestGeminiGetEmbedding_WithTaskType(t *testing.T) { } ctx := context.Background() - result, err := async.Await(client.GetEmbedding(ctx, "test text", WithRetrievalQuery())) + result, err := async.Await(client.GetEmbedding(ctx, "test text", WithRetrievalQueryTask())) assert.NoError(t, err) assert.Equal(t, []float32{0.1, 0.2, 0.3}, result) @@ -185,15 +185,15 @@ func TestGeminiTaskMapping(t *testing.T) { option EmbedOption geminiTask string }{ - {"WithRetrievalQuery", WithRetrievalQuery(), "RETRIEVAL_QUERY"}, - {"WithRetrievalPassage", WithRetrievalPassage(), "RETRIEVAL_DOCUMENT"}, - {"WithCodeQuery", WithCodeQuery(), "CODE_RETRIEVAL_QUERY"}, - {"WithCodePassage", WithCodePassage(), "RETRIEVAL_DOCUMENT"}, - {"WithTextMatching", WithTextMatching(), "SEMANTIC_SIMILARITY"}, - {"WithClassification", WithClassification(), "CLASSIFICATION"}, - {"WithClustering", WithClustering(), "CLUSTERING"}, - {"WithQuestionAnswering", WithQuestionAnswering(), "QUESTION_ANSWERING"}, - {"WithFactVerification", WithFactVerification(), "FACT_VERIFICATION"}, + {"WithRetrievalQuery", WithRetrievalQueryTask(), "RETRIEVAL_QUERY"}, + {"WithRetrievalPassage", WithRetrievalPassageTask(), "RETRIEVAL_DOCUMENT"}, + {"WithCodeQuery", WithCodeQueryTask(), "CODE_RETRIEVAL_QUERY"}, + {"WithCodePassage", WithCodePassageTask(), "RETRIEVAL_DOCUMENT"}, + {"WithTextMatching", WithTextMatchingTask(), "SEMANTIC_SIMILARITY"}, + {"WithClassification", WithClassificationTask(), "CLASSIFICATION"}, + {"WithClustering", WithClusteringTask(), "CLUSTERING"}, + {"WithQuestionAnswering", WithQuestionAnsweringTask(), "QUESTION_ANSWERING"}, + {"WithFactVerification", WithFactVerificationTask(), "FACT_VERIFICATION"}, } for _, tc := range helperFunctionTests { diff --git a/embed/jina_ai_embedding.go b/embed/jina_ai_embedding.go index 968dfb4..1508499 100644 --- a/embed/jina_ai_embedding.go +++ b/embed/jina_ai_embedding.go @@ -43,9 +43,11 @@ func (c *JinaAIEmbeddingClient) GetEmbedding(ctx context.Context, text string, o } req := jinaAIEmbeddingRequest{ - Model: cfg.model, - Task: cfg.taskName, - Input: []string{text}, + Model: cfg.model, + Task: cfg.taskName, + Input: []string{text}, + LateChunking: cfg.lateChunking, + ReturnMultivector: cfg.returnMultivector, } jsonData, err := json.Marshal(req) @@ -90,7 +92,9 @@ func (c *JinaAIEmbeddingClient) GetEmbedding(ctx context.Context, text string, o } type jinaAIEmbeddingRequest struct { - Model string `json:"model"` // jina-embeddings-v3 - Task string `json:"task"` // retrieval.passage or retrieval.query - Input []string `json:"input"` + Model string `json:"model"` // jina-embeddings-v3 + Task string `json:"task"` // retrieval.passage or retrieval.query + Input []string `json:"input"` + LateChunking bool `json:"late_chunking,omitempty"` // Optional: for long inputs, let Jina handle chunking + ReturnMultivector bool `json:"return_multivector,omitempty"` // Optional: if true, returns multiple vectors for long inputs } diff --git a/embed/jina_ai_embedding_test.go b/embed/jina_ai_embedding_test.go index 89a4e19..97f4ea2 100644 --- a/embed/jina_ai_embedding_test.go +++ b/embed/jina_ai_embedding_test.go @@ -2,6 +2,7 @@ package embed import ( "context" + "encoding/json" "io" "net/http" "net/http/httptest" @@ -116,3 +117,45 @@ func TestGetEmbedding_EmptyData(t *testing.T) { assert.Error(t, result.Err) assert.EqualError(t, result.Err, "no embedding data found") } + +func TestGetEmbedding_WithAllJinaAISettings(t *testing.T) { + var capturedRequest jinaAIEmbeddingRequest + mockResponse := `{ + "data": [ + {"embedding": [0.1, 0.2, 0.3]} + ] + }` + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Capture the request body + body, _ := io.ReadAll(r.Body) + json.Unmarshal(body, &capturedRequest) + + w.Header().Set("Content-Type", "application/json") + io.WriteString(w, mockResponse) + })) + defer server.Close() + + client := &JinaAIEmbeddingClient{ + apiKey: "test-key", + httpClient: server.Client(), + url: server.URL, + } + + ctx := context.Background() + result, err := async.Await(client.GetEmbedding(ctx, "test text", + WithModel("jina-embeddings-v3"), + WithTask(TaskRetrievalQuery), + WithLateChunking(true), + WithReturnMultivector(true))) + + assert.NoError(t, err) + assert.Equal(t, []float32{0.1, 0.2, 0.3}, result) + + // Assert all Jina AI settings are properly sent in the request + assert.Equal(t, "jina-embeddings-v3", capturedRequest.Model) + assert.Equal(t, TaskRetrievalQuery, capturedRequest.Task) + assert.True(t, capturedRequest.LateChunking) + assert.True(t, capturedRequest.ReturnMultivector) + assert.Equal(t, []string{"test text"}, capturedRequest.Input) +}