Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions embed/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ type Embedder interface {
}

type settings struct {
model string // common
keepAlive time.Duration // ollama
taskName string // Jina & Gemini
model string // common
keepAlive time.Duration // ollama
taskName string // Jina & Gemini
lateChunking bool // Jina
returnMultivector bool // Jina
}

// Task name constants (using Jina AI naming convention as standard)
Expand Down Expand Up @@ -55,6 +57,14 @@ func WithKeepAlive(d time.Duration) EmbedOption {
return func(s *settings) { s.keepAlive = d }
}

// WithLateChunking returns an EmbedOption that stores enabled in the
// lateChunking setting (marked as Jina-specific in the settings struct).
func WithLateChunking(enabled bool) EmbedOption {
	return func(cfg *settings) {
		cfg.lateChunking = enabled
	}
}

// WithReturnMultivector returns an EmbedOption that stores enabled in the
// returnMultivector setting (marked as Jina-specific in the settings struct).
func WithReturnMultivector(enabled bool) EmbedOption {
	return func(cfg *settings) {
		cfg.returnMultivector = enabled
	}
}

// ---- Jina & Gemini task helpers ----
func WithTask(name string) EmbedOption {
// Supported task names:
Expand All @@ -66,38 +76,38 @@ func WithTask(name string) EmbedOption {
}

// Task-specific helper functions for common use cases
func WithRetrievalQuery() EmbedOption {
func WithRetrievalQueryTask() EmbedOption {
return WithTask(TaskRetrievalQuery)
}

func WithRetrievalPassage() EmbedOption {
func WithRetrievalPassageTask() EmbedOption {
return WithTask(TaskRetrievalPassage)
}

func WithCodeQuery() EmbedOption {
func WithCodeQueryTask() EmbedOption {
return WithTask(TaskCodeQuery)
}

func WithCodePassage() EmbedOption {
func WithCodePassageTask() EmbedOption {
return WithTask(TaskCodePassage)
}

func WithTextMatching() EmbedOption {
func WithTextMatchingTask() EmbedOption {
return WithTask(TaskTextMatching)
}

func WithClassification() EmbedOption {
func WithClassificationTask() EmbedOption {
return WithTask(TaskClassification)
}

func WithClustering() EmbedOption {
func WithClusteringTask() EmbedOption {
return WithTask(TaskClustering)
}

func WithQuestionAnswering() EmbedOption {
func WithQuestionAnsweringTask() EmbedOption {
return WithTask(TaskQuestionAnswering)
}

func WithFactVerification() EmbedOption {
func WithFactVerificationTask() EmbedOption {
return WithTask(TaskFactVerification)
}
20 changes: 10 additions & 10 deletions embed/gemini_embedding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func TestGeminiGetEmbedding_WithTaskType(t *testing.T) {
}

ctx := context.Background()
result, err := async.Await(client.GetEmbedding(ctx, "test text", WithRetrievalQuery()))
result, err := async.Await(client.GetEmbedding(ctx, "test text", WithRetrievalQueryTask()))

assert.NoError(t, err)
assert.Equal(t, []float32{0.1, 0.2, 0.3}, result)
Expand Down Expand Up @@ -185,15 +185,15 @@ func TestGeminiTaskMapping(t *testing.T) {
option EmbedOption
geminiTask string
}{
{"WithRetrievalQuery", WithRetrievalQuery(), "RETRIEVAL_QUERY"},
{"WithRetrievalPassage", WithRetrievalPassage(), "RETRIEVAL_DOCUMENT"},
{"WithCodeQuery", WithCodeQuery(), "CODE_RETRIEVAL_QUERY"},
{"WithCodePassage", WithCodePassage(), "RETRIEVAL_DOCUMENT"},
{"WithTextMatching", WithTextMatching(), "SEMANTIC_SIMILARITY"},
{"WithClassification", WithClassification(), "CLASSIFICATION"},
{"WithClustering", WithClustering(), "CLUSTERING"},
{"WithQuestionAnswering", WithQuestionAnswering(), "QUESTION_ANSWERING"},
{"WithFactVerification", WithFactVerification(), "FACT_VERIFICATION"},
{"WithRetrievalQuery", WithRetrievalQueryTask(), "RETRIEVAL_QUERY"},
{"WithRetrievalPassage", WithRetrievalPassageTask(), "RETRIEVAL_DOCUMENT"},
{"WithCodeQuery", WithCodeQueryTask(), "CODE_RETRIEVAL_QUERY"},
{"WithCodePassage", WithCodePassageTask(), "RETRIEVAL_DOCUMENT"},
{"WithTextMatching", WithTextMatchingTask(), "SEMANTIC_SIMILARITY"},
{"WithClassification", WithClassificationTask(), "CLASSIFICATION"},
{"WithClustering", WithClusteringTask(), "CLUSTERING"},
{"WithQuestionAnswering", WithQuestionAnsweringTask(), "QUESTION_ANSWERING"},
{"WithFactVerification", WithFactVerificationTask(), "FACT_VERIFICATION"},
}

for _, tc := range helperFunctionTests {
Expand Down
16 changes: 10 additions & 6 deletions embed/jina_ai_embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ func (c *JinaAIEmbeddingClient) GetEmbedding(ctx context.Context, text string, o
}

req := jinaAIEmbeddingRequest{
Model: cfg.model,
Task: cfg.taskName,
Input: []string{text},
Model: cfg.model,
Task: cfg.taskName,
Input: []string{text},
LateChunking: cfg.lateChunking,
ReturnMultivector: cfg.returnMultivector,
}

jsonData, err := json.Marshal(req)
Expand Down Expand Up @@ -90,7 +92,9 @@ func (c *JinaAIEmbeddingClient) GetEmbedding(ctx context.Context, text string, o
}

type jinaAIEmbeddingRequest struct {
Model string `json:"model"` // jina-embeddings-v3
Task string `json:"task"` // retrieval.passage or retrieval.query
Input []string `json:"input"`
Model string `json:"model"` // jina-embeddings-v3
Task string `json:"task"` // retrieval.passage or retrieval.query
Input []string `json:"input"`
LateChunking bool `json:"late_chunking,omitempty"` // Optional: for long inputs, let Jina handle chunking
ReturnMultivector bool `json:"return_multivector,omitempty"` // Optional: if true, returns multiple vectors for long inputs
}
43 changes: 43 additions & 0 deletions embed/jina_ai_embedding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package embed

import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -116,3 +117,45 @@ func TestGetEmbedding_EmptyData(t *testing.T) {
assert.Error(t, result.Err)
assert.EqualError(t, result.Err, "no embedding data found")
}

// TestGetEmbedding_WithAllJinaAISettings checks that the model, task,
// late-chunking and return-multivector options passed to GetEmbedding are all
// present in the JSON request body received by the (mock) Jina AI endpoint.
func TestGetEmbedding_WithAllJinaAISettings(t *testing.T) {
	var capturedRequest jinaAIEmbeddingRequest
	mockResponse := `{
		"data": [
			{"embedding": [0.1, 0.2, 0.3]}
		]
	}`

	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Capture the request body. Errors are reported with t.Errorf (not
		// t.Fatal, which must not be called off the test goroutine) so a
		// malformed request fails here rather than surfacing later as
		// confusing assertions against a zero-value struct.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("reading request body: %v", err)
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return
		}
		if err := json.Unmarshal(body, &capturedRequest); err != nil {
			t.Errorf("decoding request body: %v", err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}

		w.Header().Set("Content-Type", "application/json")
		if _, err := io.WriteString(w, mockResponse); err != nil {
			t.Errorf("writing mock response: %v", err)
		}
	}))
	defer server.Close()

	client := &JinaAIEmbeddingClient{
		apiKey:     "test-key",
		httpClient: server.Client(),
		url:        server.URL,
	}

	ctx := context.Background()
	result, err := async.Await(client.GetEmbedding(ctx, "test text",
		WithModel("jina-embeddings-v3"),
		WithTask(TaskRetrievalQuery),
		WithLateChunking(true),
		WithReturnMultivector(true)))

	assert.NoError(t, err)
	assert.Equal(t, []float32{0.1, 0.2, 0.3}, result)

	// Assert all Jina AI settings are properly sent in the request.
	assert.Equal(t, "jina-embeddings-v3", capturedRequest.Model)
	assert.Equal(t, TaskRetrievalQuery, capturedRequest.Task)
	assert.True(t, capturedRequest.LateChunking)
	assert.True(t, capturedRequest.ReturnMultivector)
	assert.Equal(t, []string{"test text"}, capturedRequest.Input)
}