From 7955c653ba45c97707afedc5c091ce3925296bd6 Mon Sep 17 00:00:00 2001 From: AnasAmchaar Date: Sun, 22 Mar 2026 03:24:30 +0100 Subject: [PATCH 1/3] feat: add refrag client and related classes to the module This update introduces the TavilyRefragClient, RefragChunk, and RefragContext classes to the tavily module, enhancing its functionality for handling refrag operations. --- tavily/__init__.py | 3 +- tavily/refrag/__init__.py | 2 + tavily/refrag/models.py | 53 ++++++ tavily/refrag/refrag.py | 246 ++++++++++++++++++++++++++++ tests/test_refrag.py | 330 ++++++++++++++++++++++++++++++++++++++ 5 files changed, 633 insertions(+), 1 deletion(-) create mode 100644 tavily/refrag/__init__.py create mode 100644 tavily/refrag/models.py create mode 100644 tavily/refrag/refrag.py create mode 100644 tests/test_refrag.py diff --git a/tavily/__init__.py b/tavily/__init__.py index 4a2ea54..f85a3e2 100644 --- a/tavily/__init__.py +++ b/tavily/__init__.py @@ -1,4 +1,5 @@ from .async_tavily import AsyncTavilyClient from .tavily import Client, TavilyClient from .errors import InvalidAPIKeyError, UsageLimitExceededError, MissingAPIKeyError, BadRequestError -from .hybrid_rag import TavilyHybridClient \ No newline at end of file +from .hybrid_rag import TavilyHybridClient +from .refrag import TavilyRefragClient, RefragChunk, RefragContext \ No newline at end of file diff --git a/tavily/refrag/__init__.py b/tavily/refrag/__init__.py new file mode 100644 index 0000000..08105da --- /dev/null +++ b/tavily/refrag/__init__.py @@ -0,0 +1,2 @@ +from .refrag import TavilyRefragClient +from .models import RefragChunk, RefragContext diff --git a/tavily/refrag/models.py b/tavily/refrag/models.py new file mode 100644 index 0000000..d2f7fa8 --- /dev/null +++ b/tavily/refrag/models.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any + + +@dataclass +class RefragChunk: + """A single k-token chunk of a retrieved passage, following the REFRAG paper's + chunk representation (Section 2: C_i = {x_{q+k*i}, ..., x_{q+k*i+k-1}}). + + Attributes: + text: Raw text of the chunk before tokenization. + tokens: Token IDs produced by the tokenizer. + embedding: Chunk embedding from the encoder (None until encode_chunks is called). + expand: Whether the expansion policy selected this chunk for full-token decoding. + source_url: Origin URL from the Tavily search result. + source_score: Relevance score from the Tavily search result. + """ + text: str + tokens: List[int] + embedding: Optional[List[float]] = None + expand: bool = False + source_url: Optional[str] = None + source_score: Optional[float] = None + + +@dataclass +class RefragContext: + """Preprocessed REFRAG-compatible context ready for decoder consumption. + + The decoder receives compressed chunk embeddings for most passages and + full token embeddings only for chunks flagged by the expansion policy + (see REFRAG paper, Section 2 & Figure 1). + + Attributes: + query: Original query text. + chunks: All processed chunks in passage order. + chunk_size: The k value (tokens per chunk) used during chunking. + metadata: Tavily response metadata (response_time, images, etc.). + """ + query: str + chunks: List[RefragChunk] + chunk_size: int + metadata: Dict[str, Any] = field(default_factory=dict) + + @property + def compressed_chunks(self) -> List[RefragChunk]: + """Chunks to feed as embeddings (expand=False).""" + return [c for c in self.chunks if not c.expand] + + @property + def expanded_chunks(self) -> List[RefragChunk]: + """Chunks to feed as full tokens (expand=True).""" + return [c for c in self.chunks if c.expand] diff --git a/tavily/refrag/refrag.py b/tavily/refrag/refrag.py new file mode 100644 index 0000000..23eed59 --- /dev/null +++ b/tavily/refrag/refrag.py @@ -0,0 +1,246 @@ +import tiktoken +from typing import Optional, List, Dict, Any, Callable, Sequence + +import requests +from tavily import TavilyClient +from tavily.config import DEFAULT_MODEL_ENCODING +from .models import RefragChunk, RefragContext + + +def _default_tokenize(text: str, _encoder=[None]) -> List[int]: + if _encoder[0] is None: + _encoder[0] = tiktoken.encoding_for_model(DEFAULT_MODEL_ENCODING) + return _encoder[0].encode(text) + + +def _default_detokenize(tokens: List[int], _encoder=[None]) -> str: + if _encoder[0] is None: + _encoder[0] = tiktoken.encoding_for_model(DEFAULT_MODEL_ENCODING) + return _encoder[0].decode(tokens) + + +def _default_expansion_policy( + chunk_embeddings: List[List[float]], + query_embedding: Optional[List[float]] = None, +) -> List[bool]: + """Default policy: compress everything (expand nothing).""" + return [False] * len(chunk_embeddings) + + +class TavilyRefragClient: + """Preprocessing adapter that converts Tavily search results into + REFRAG-compatible chunked and encoded context. + + This client handles the retrieval-to-REFRAG-input pipeline: + query -> Tavily search -> passage chunking -> encoder -> expansion policy + -> RefragContext ready for an external REFRAG decoder. + + Each pipeline step is available independently (chunk_passages, + encode_chunks, apply_expansion_policy) or as a single call + (prepare_context). + + Parameters: + api_key: Tavily API key (falls back to TAVILY_API_KEY env var). + chunk_size: Number of tokens per chunk (k in the REFRAG paper). + Paper evaluates k=8, 16, 32; default is 16. + tokenizer_function: ``(text: str) -> list[int]``. + Defaults to tiktoken with the gpt-3.5-turbo encoding. + detokenizer_function: ``(tokens: list[int]) -> str``. + Defaults to tiktoken with the gpt-3.5-turbo encoding. + encoder_function: ``(chunks: list[list[int]]) -> list[list[float]]``. + Takes a list of token-ID lists, returns embedding vectors. + No default — raises if encode_chunks is called without one. + expansion_policy: ``(chunk_embeddings, query_embedding) -> list[bool]``. + Returns a boolean mask indicating which chunks to expand. + Defaults to compressing all chunks (expand nothing). + api_base_url: Override the Tavily base URL. + session: Pre-configured requests.Session for HTTP calls. + **tavily_kwargs: Extra keyword arguments forwarded to TavilyClient. + """ + + def __init__( + self, + api_key: Optional[str] = None, + chunk_size: int = 16, + tokenizer_function: Optional[Callable[[str], List[int]]] = None, + detokenizer_function: Optional[Callable[[List[int]], str]] = None, + encoder_function: Optional[Callable[[List[List[int]]], List[List[float]]]] = None, + expansion_policy: Optional[Callable] = None, + api_base_url: Optional[str] = None, + session: Optional[requests.Session] = None, + **tavily_kwargs, + ): + if chunk_size < 1: + raise ValueError("chunk_size must be a positive integer.") + + self.tavily = TavilyClient( + api_key=api_key, + api_base_url=api_base_url, + session=session, + **tavily_kwargs, + ) + self.chunk_size = chunk_size + self.tokenizer_function = tokenizer_function or _default_tokenize + self.detokenizer_function = detokenizer_function or _default_detokenize + self.encoder_function = encoder_function + self.expansion_policy = expansion_policy or _default_expansion_policy + + def chunk_passages( + self, + passages: List[Dict[str, Any]], + chunk_size: Optional[int] = None, + ) -> List[RefragChunk]: + """Split retrieved passages into fixed-size token chunks. + + Each Tavily result dict is expected to have at least a ``content`` + key. The text is tokenized and then split into non-overlapping + chunks of ``chunk_size`` tokens. Leftover tokens shorter than + ``chunk_size`` are kept as a final smaller chunk. + + Args: + passages: List of Tavily result dicts (must contain ``content``). + chunk_size: Override the instance-level chunk_size for this call. + + Returns: + Ordered list of RefragChunk objects. + """ + k = chunk_size or self.chunk_size + chunks: List[RefragChunk] = [] + + for passage in passages: + content = passage.get("content", "") + if not content: + continue + url = passage.get("url") + score = passage.get("score") + tokens = self.tokenizer_function(content) + + for start in range(0, len(tokens), k): + chunk_tokens = tokens[start : start + k] + chunk_text = self.detokenizer_function(chunk_tokens) + chunks.append( + RefragChunk( + text=chunk_text, + tokens=chunk_tokens, + source_url=url, + source_score=score, + ) + ) + + return chunks + + def encode_chunks(self, chunks: List[RefragChunk]) -> List[RefragChunk]: + """Apply the encoder function to produce chunk embeddings. + + Requires ``encoder_function`` to have been set during construction. + + Args: + chunks: List of RefragChunk objects (tokens must be populated). + + Returns: + The same list with ``embedding`` fields populated. + + Raises: + RuntimeError: If no encoder_function was provided. + """ + if self.encoder_function is None: + raise RuntimeError( + "encoder_function must be provided to encode chunks. " + "Pass it to the TavilyRefragClient constructor." + ) + + token_lists = [c.tokens for c in chunks] + embeddings = self.encoder_function(token_lists) + + for chunk, emb in zip(chunks, embeddings): + chunk.embedding = emb + + return chunks + + def apply_expansion_policy( + self, + chunks: List[RefragChunk], + query: Optional[str] = None, + ) -> List[RefragChunk]: + """Run the expansion policy to decide which chunks to expand. + + The policy receives chunk embeddings and an optional query + embedding and returns a boolean mask. + + Args: + chunks: Chunks with embeddings already populated. + query: Optional query text (tokenized and passed to encoder + to produce a query embedding when encoder_function + is available). + + Returns: + The same list with ``expand`` flags set. + + Raises: + ValueError: If any chunk lacks an embedding. + """ + for c in chunks: + if c.embedding is None: + raise ValueError( + "All chunks must have embeddings before applying the " + "expansion policy. Call encode_chunks first." + ) + + chunk_embeddings = [c.embedding for c in chunks] + + query_embedding = None + if query is not None and self.encoder_function is not None: + query_tokens = self.tokenizer_function(query) + query_embedding = self.encoder_function([query_tokens])[0] + + expand_mask = self.expansion_policy(chunk_embeddings, query_embedding) + + for chunk, should_expand in zip(chunks, expand_mask): + chunk.expand = should_expand + + return chunks + + def prepare_context( + self, + query: str, + max_results: int = 10, + encode: bool = True, + apply_policy: bool = True, + **search_kwargs, + ) -> RefragContext: + """Full pipeline: search -> chunk -> encode -> policy -> RefragContext. + + Args: + query: The search query. + max_results: Maximum number of Tavily search results. + encode: Whether to run the encoder on chunks. Requires + encoder_function to be set. + apply_policy: Whether to run the expansion policy. Only + effective when encode=True. + **search_kwargs: Extra arguments forwarded to TavilyClient.search. + + Returns: + A RefragContext ready for decoder consumption. + """ + response = self.tavily.search( + query, max_results=max_results, **search_kwargs + ) + + passages = response.get("results", []) + metadata = { + k: v for k, v in response.items() if k != "results" + } + + chunks = self.chunk_passages(passages) + + if encode and self.encoder_function is not None: + self.encode_chunks(chunks) + if apply_policy: + self.apply_expansion_policy(chunks, query=query) + + return RefragContext( + query=query, + chunks=chunks, + chunk_size=self.chunk_size, + metadata=metadata, + ) diff --git a/tests/test_refrag.py b/tests/test_refrag.py new file mode 100644 index 0000000..742ffa0 --- /dev/null +++ b/tests/test_refrag.py @@ -0,0 +1,330 @@ +import pytest +from tests.request_intercept import intercept_requests, clear_interceptor +import tavily.tavily as sync_tavily +from tavily.refrag import TavilyRefragClient, RefragChunk, RefragContext + + +dummy_search_response = { + "query": "What is REFRAG?", + "answer": None, + "images": [], + "results": [ + { + "title": "REFRAG Paper", + "url": "https://arxiv.org/abs/2509.01092", + "content": "REFRAG is an efficient decoding framework that compresses senses and expands", + "score": 0.95, + "raw_content": None, + }, + { + "title": "RAG Overview", + "url": "https://example.com/rag", + "content": "Retrieval augmented generation improves LLM accuracy", + "score": 0.88, + "raw_content": None, + }, + ], + "response_time": 0.8, +} + + +@pytest.fixture +def interceptor(): + yield intercept_requests(sync_tavily) + clear_interceptor(sync_tavily) + + +def _dummy_encoder(token_lists): + """Return a fixed-length embedding per chunk (dimension = 4).""" + return [[float(len(tl)), 0.1, 0.2, 0.3] for tl in token_lists] + + +def _dummy_expansion_policy(chunk_embeddings, query_embedding=None): + """Expand the first chunk, compress the rest.""" + return [i == 0 for i in range(len(chunk_embeddings))] + + +# --------------------------------------------------------------------------- +# Constructor tests +# --------------------------------------------------------------------------- + +class TestConstructor: + def test_default_chunk_size(self, interceptor): + client = TavilyRefragClient(api_key="tvly-test") + assert client.chunk_size == 16 + + def test_custom_chunk_size(self, interceptor): + client = TavilyRefragClient(api_key="tvly-test", chunk_size=32) + assert client.chunk_size == 32 + + def test_invalid_chunk_size_raises(self, interceptor): + with pytest.raises(ValueError, match="chunk_size must be a positive integer"): + TavilyRefragClient(api_key="tvly-test", chunk_size=0) + + def test_no_encoder_by_default(self, interceptor): + client = TavilyRefragClient(api_key="tvly-test") + assert client.encoder_function is None + + +# --------------------------------------------------------------------------- +# chunk_passages tests +# --------------------------------------------------------------------------- + +class TestChunkPassages: + def test_basic_chunking(self, interceptor): + client = TavilyRefragClient(api_key="tvly-test", chunk_size=4) + passages = [ + {"content": "one two three four five six seven eight", "url": "https://a.com", "score": 0.9} + ] + chunks = client.chunk_passages(passages) + + assert len(chunks) >= 2 + assert all(isinstance(c, RefragChunk) for c in chunks) + for c in chunks: + assert len(c.tokens) <= 4 + assert c.source_url == "https://a.com" + assert c.source_score == 0.9 + assert c.embedding is None + assert c.expand is False + + def test_empty_content_skipped(self, interceptor): + client = TavilyRefragClient(api_key="tvly-test", chunk_size=4) + passages = [{"content": "", "url": "https://a.com", "score": 0.5}] + chunks = client.chunk_passages(passages) + assert len(chunks) == 0 + + def test_multiple_passages(self, interceptor): + client = TavilyRefragClient(api_key="tvly-test", chunk_size=4) + passages = [ + {"content": "alpha beta gamma delta", "url": "https://a.com", "score": 0.9}, + {"content": "epsilon zeta", "url": "https://b.com", "score": 0.7}, + ] + chunks = client.chunk_passages(passages) + urls = [c.source_url for c in chunks] + assert "https://a.com" in urls + assert "https://b.com" in urls + + def test_chunk_size_override(self, interceptor): + client = TavilyRefragClient(api_key="tvly-test", chunk_size=100) + passages = [{"content": "a b c d e f g h", "url": "https://a.com", "score": 0.5}] + chunks_big = client.chunk_passages(passages, chunk_size=100) + chunks_small = client.chunk_passages(passages, chunk_size=2) + assert len(chunks_small) > len(chunks_big) + + def test_missing_content_key(self, interceptor): + client = TavilyRefragClient(api_key="tvly-test", chunk_size=4) + passages = [{"url": "https://a.com"}] + chunks = client.chunk_passages(passages) + assert len(chunks) == 0 + + +# --------------------------------------------------------------------------- +# encode_chunks tests +# --------------------------------------------------------------------------- + +class TestEncodeChunks: + def test_encode_populates_embeddings(self, interceptor): + client = TavilyRefragClient( + api_key="tvly-test", chunk_size=4, encoder_function=_dummy_encoder + ) + passages = [{"content": "one two three four five six", "url": "https://a.com", "score": 0.9}] + chunks = client.chunk_passages(passages) + assert all(c.embedding is None for c in chunks) + + client.encode_chunks(chunks) + assert all(c.embedding is not None for c in chunks) + assert len(chunks[0].embedding) == 4 + + def test_encode_without_encoder_raises(self, interceptor): + client = TavilyRefragClient(api_key="tvly-test", chunk_size=4) + passages = [{"content": "hello world", "url": "https://a.com", "score": 0.9}] + chunks = client.chunk_passages(passages) + + with pytest.raises(RuntimeError, match="encoder_function must be provided"): + client.encode_chunks(chunks) + + def test_encoder_receives_correct_tokens(self, interceptor): + received = [] + + def tracking_encoder(token_lists): + received.extend(token_lists) + return [[0.0] * 4 for _ in token_lists] + + client = TavilyRefragClient( + api_key="tvly-test", chunk_size=4, encoder_function=tracking_encoder + ) + passages = [{"content": "one two three four five six", "url": "https://a.com", "score": 0.9}] + chunks = client.chunk_passages(passages) + client.encode_chunks(chunks) + + assert len(received) == len(chunks) + for token_list, chunk in zip(received, chunks): + assert token_list == chunk.tokens + + +# --------------------------------------------------------------------------- +# apply_expansion_policy tests +# --------------------------------------------------------------------------- + +class TestExpansionPolicy: + def test_default_policy_compresses_all(self, interceptor): + client = TavilyRefragClient( + api_key="tvly-test", chunk_size=4, encoder_function=_dummy_encoder + ) + passages = [{"content": "one two three four five six seven eight", "url": "https://a.com", "score": 0.9}] + chunks = client.chunk_passages(passages) + client.encode_chunks(chunks) + client.apply_expansion_policy(chunks) + + assert all(c.expand is False for c in chunks) + + def test_custom_policy_applied(self, interceptor): + client = TavilyRefragClient( + api_key="tvly-test", + chunk_size=4, + encoder_function=_dummy_encoder, + expansion_policy=_dummy_expansion_policy, + ) + passages = [{"content": "one two three four five six seven eight", "url": "https://a.com", "score": 0.9}] + chunks = client.chunk_passages(passages) + client.encode_chunks(chunks) + client.apply_expansion_policy(chunks) + + assert chunks[0].expand is True + assert all(c.expand is False for c in chunks[1:]) + + def test_policy_without_embeddings_raises(self, interceptor): + client = TavilyRefragClient( + api_key="tvly-test", + chunk_size=4, + encoder_function=_dummy_encoder, + expansion_policy=_dummy_expansion_policy, + ) + passages = [{"content": "hello world tokens here", "url": "https://a.com", "score": 0.9}] + chunks = client.chunk_passages(passages) + + with pytest.raises(ValueError, match="must have embeddings"): + client.apply_expansion_policy(chunks) + + +# --------------------------------------------------------------------------- +# RefragContext dataclass tests +# --------------------------------------------------------------------------- + +class TestRefragContext: + def test_compressed_and_expanded_properties(self): + chunks = [ + RefragChunk(text="a", tokens=[1], embedding=[0.1], expand=False), + RefragChunk(text="b", tokens=[2], embedding=[0.2], expand=True), + RefragChunk(text="c", tokens=[3], embedding=[0.3], expand=False), + ] + ctx = RefragContext(query="test", chunks=chunks, chunk_size=1) + + assert len(ctx.compressed_chunks) == 2 + assert len(ctx.expanded_chunks) == 1 + assert ctx.expanded_chunks[0].text == "b" + + def test_empty_chunks(self): + ctx = RefragContext(query="test", chunks=[], chunk_size=16) + assert ctx.compressed_chunks == [] + assert ctx.expanded_chunks == [] + + +# --------------------------------------------------------------------------- +# prepare_context (end-to-end) tests +# --------------------------------------------------------------------------- + +class TestPrepareContext: + def test_full_pipeline(self, interceptor): + interceptor.set_response(200, json=dummy_search_response) + client = TavilyRefragClient( + api_key="tvly-test", + chunk_size=4, + encoder_function=_dummy_encoder, + expansion_policy=_dummy_expansion_policy, + ) + + ctx = client.prepare_context("What is REFRAG?", max_results=5) + + assert isinstance(ctx, RefragContext) + assert ctx.query == "What is REFRAG?" + assert ctx.chunk_size == 4 + assert len(ctx.chunks) > 0 + assert all(c.embedding is not None for c in ctx.chunks) + assert ctx.chunks[0].expand is True + assert ctx.metadata.get("response_time") == 0.8 + + request = interceptor.get_request() + assert request.method == "POST" + assert request.url == "https://api.tavily.com/search" + assert request.json()["query"] == "What is REFRAG?" + + def test_pipeline_without_encoder(self, interceptor): + interceptor.set_response(200, json=dummy_search_response) + client = TavilyRefragClient(api_key="tvly-test", chunk_size=4) + + ctx = client.prepare_context("What is REFRAG?", encode=True) + + assert len(ctx.chunks) > 0 + assert all(c.embedding is None for c in ctx.chunks) + assert all(c.expand is False for c in ctx.chunks) + + def test_pipeline_skip_encode(self, interceptor): + interceptor.set_response(200, json=dummy_search_response) + client = TavilyRefragClient( + api_key="tvly-test", + chunk_size=4, + encoder_function=_dummy_encoder, + ) + + ctx = client.prepare_context("What is REFRAG?", encode=False) + + assert all(c.embedding is None for c in ctx.chunks) + + def test_pipeline_encode_but_skip_policy(self, interceptor): + interceptor.set_response(200, json=dummy_search_response) + client = TavilyRefragClient( + api_key="tvly-test", + chunk_size=4, + encoder_function=_dummy_encoder, + expansion_policy=_dummy_expansion_policy, + ) + + ctx = client.prepare_context("What is REFRAG?", apply_policy=False) + + assert all(c.embedding is not None for c in ctx.chunks) + assert all(c.expand is False for c in ctx.chunks) + + def test_search_kwargs_forwarded(self, interceptor): + interceptor.set_response(200, json=dummy_search_response) + client = TavilyRefragClient( + api_key="tvly-test", + chunk_size=4, + encoder_function=_dummy_encoder, + ) + + client.prepare_context( + "What is REFRAG?", + max_results=3, + search_depth="advanced", + topic="general", + ) + + request = interceptor.get_request() + body = request.json() + assert body["search_depth"] == "advanced" + assert body["topic"] == "general" + + def test_empty_results(self, interceptor): + interceptor.set_response(200, json={ + "query": "nothing", + "results": [], + "response_time": 0.1, + }) + client = TavilyRefragClient( + api_key="tvly-test", chunk_size=4, encoder_function=_dummy_encoder + ) + ctx = client.prepare_context("nothing") + assert len(ctx.chunks) == 0 + assert ctx.compressed_chunks == [] + assert ctx.expanded_chunks == [] From d2b9da9f873dd14635ebd68644896a0b75c10173 Mon Sep 17 00:00:00 2001 From: AnasAmchaar Date: Sun, 22 Mar 2026 03:30:49 +0100 Subject: [PATCH 2/3] feat: add REFRAG preprocessing example with pluggable encoder and expansion policy Made-with: Cursor --- examples/refrag.py | 85 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 examples/refrag.py diff --git a/examples/refrag.py b/examples/refrag.py new file mode 100644 index 0000000..0696fc2 --- /dev/null +++ b/examples/refrag.py @@ -0,0 +1,85 @@ +# REFRAG Preprocessing Example +# +# This example shows how to use TavilyRefragClient to retrieve web passages +# via Tavily, chunk them into k-token blocks, encode them with a pluggable +# encoder, and apply an expansion policy -- producing a RefragContext that +# is ready to be fed into a REFRAG-compatible decoder. +# +# Replace the dummy encoder and policy below with your own REFRAG-trained +# models (e.g. RoBERTa encoder + RL expansion policy from the REFRAG paper). + +import os +import math +from tavily import TavilyRefragClient + + +# -- Pluggable encoder ------------------------------------------------------- +# In a real setup, this would call a lightweight encoder like RoBERTa to +# produce chunk embeddings. Here we use a trivial placeholder. + +def my_encoder(token_lists): + """Encode each token chunk into a fixed-size embedding vector.""" + embeddings = [] + for tokens in token_lists: + dim = 8 + emb = [float(t % 100) / 100.0 for t in tokens[:dim]] + emb += [0.0] * (dim - len(emb)) + embeddings.append(emb) + return embeddings + + +# -- Pluggable expansion policy ----------------------------------------------- +# In a real setup, this would be an RL-trained policy that decides which +# chunks need full token-level attention in the decoder. Here we expand +# chunks whose embedding norm exceeds a threshold. + +def my_expansion_policy(chunk_embeddings, query_embedding=None): + """Expand chunks whose L2 norm is above the median (heuristic demo).""" + norms = [math.sqrt(sum(x * x for x in emb)) for emb in chunk_embeddings] + if not norms: + return [] + threshold = sorted(norms)[len(norms) // 2] + return [n > threshold for n in norms] + + +# -- Build the client --------------------------------------------------------- + +client = TavilyRefragClient( + api_key=os.environ["TAVILY_API_KEY"], + chunk_size=16, # k=16 as in the paper + encoder_function=my_encoder, + expansion_policy=my_expansion_policy, +) + + +# -- Option 1: Full pipeline in one call -------------------------------------- + +ctx = client.prepare_context( + query="What are the latest advances in retrieval augmented generation?", + max_results=5, + search_depth="advanced", +) + +print(f"Query: {ctx.query}") +print(f"Total chunks: {len(ctx.chunks)}") +print(f"Compressed (feed as embeddings): {len(ctx.compressed_chunks)}") +print(f"Expanded (feed as full tokens): {len(ctx.expanded_chunks)}") +print(f"Metadata: {ctx.metadata}") + +for i, chunk in enumerate(ctx.chunks[:5]): + status = "EXPAND" if chunk.expand else "COMPRESS" + print(f" [{status}] chunk {i}: {chunk.text[:60]}...") + + +# -- Option 2: Step-by-step (useful when you already have passages) ----------- + +response = client.tavily.search("REFRAG paper Meta 2025", max_results=3) +passages = response["results"] + +chunks = client.chunk_passages(passages, chunk_size=8) +chunks = client.encode_chunks(chunks) +chunks = client.apply_expansion_policy(chunks, query="REFRAG paper Meta 2025") + +print(f"\nStep-by-step: {len(chunks)} chunks from {len(passages)} passages") +for c in chunks[:3]: + print(f" expand={c.expand} url={c.source_url} tokens={len(c.tokens)}") From 41b4b9a5565345440e254962a7c60628003b5096 Mon Sep 17 00:00:00 2001 From: AnasAmchaar Date: Sun, 22 Mar 2026 12:39:45 +0100 Subject: [PATCH 3/3] fix: handle None chunk_size in TavilyRefragClient Updated the chunk size assignment to use a conditional expression, ensuring that the default chunk size is used only when chunk_size is explicitly set to None. --- tavily/refrag/refrag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tavily/refrag/refrag.py b/tavily/refrag/refrag.py index 23eed59..e4051e5 100644 --- a/tavily/refrag/refrag.py +++ b/tavily/refrag/refrag.py @@ -104,7 +104,7 @@ def chunk_passages( Returns: Ordered list of RefragChunk objects. """ - k = chunk_size or self.chunk_size + k = chunk_size if chunk_size is not None else self.chunk_size chunks: List[RefragChunk] = [] for passage in passages: