-
Notifications
You must be signed in to change notification settings - Fork 157
feat: Add REFRAG preprocessing adapter for next-gen RAG context compression #162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
AnasAmchaar
wants to merge
3
commits into
tavily-ai:master
Choose a base branch
from
AnasAmchaar:feat/refrag-preprocessing-adapter
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| # REFRAG Preprocessing Example | ||
| # | ||
| # This example shows how to use TavilyRefragClient to retrieve web passages | ||
| # via Tavily, chunk them into k-token blocks, encode them with a pluggable | ||
| # encoder, and apply an expansion policy -- producing a RefragContext that | ||
| # is ready to be fed into a REFRAG-compatible decoder. | ||
| # | ||
| # Replace the dummy encoder and policy below with your own REFRAG-trained | ||
| # models (e.g. RoBERTa encoder + RL expansion policy from the REFRAG paper). | ||
|
|
||
| import os | ||
| import math | ||
| from tavily import TavilyRefragClient | ||
|
|
||
|
|
||
| # -- Pluggable encoder ------------------------------------------------------- | ||
| # In a real setup, this would call a lightweight encoder like RoBERTa to | ||
| # produce chunk embeddings. Here we use a trivial placeholder. | ||
|
|
||
| def my_encoder(token_lists): | ||
| """Encode each token chunk into a fixed-size embedding vector.""" | ||
| embeddings = [] | ||
| for tokens in token_lists: | ||
| dim = 8 | ||
| emb = [float(t % 100) / 100.0 for t in tokens[:dim]] | ||
| emb += [0.0] * (dim - len(emb)) | ||
| embeddings.append(emb) | ||
| return embeddings | ||
|
|
||
|
|
||
| # -- Pluggable expansion policy ----------------------------------------------- | ||
| # In a real setup, this would be an RL-trained policy that decides which | ||
| # chunks need full token-level attention in the decoder. Here we expand | ||
| # chunks whose embedding norm exceeds a threshold. | ||
|
|
||
| def my_expansion_policy(chunk_embeddings, query_embedding=None): | ||
| """Expand chunks whose L2 norm is above the median (heuristic demo).""" | ||
| norms = [math.sqrt(sum(x * x for x in emb)) for emb in chunk_embeddings] | ||
| if not norms: | ||
| return [] | ||
| threshold = sorted(norms)[len(norms) // 2] | ||
| return [n > threshold for n in norms] | ||
|
|
||
|
|
||
| # -- Build the client --------------------------------------------------------- | ||
|
|
||
| client = TavilyRefragClient( | ||
| api_key=os.environ["TAVILY_API_KEY"], | ||
| chunk_size=16, # k=16 as in the paper | ||
| encoder_function=my_encoder, | ||
| expansion_policy=my_expansion_policy, | ||
| ) | ||
|
|
||
|
|
||
| # -- Option 1: Full pipeline in one call -------------------------------------- | ||
|
|
||
| ctx = client.prepare_context( | ||
| query="What are the latest advances in retrieval augmented generation?", | ||
| max_results=5, | ||
| search_depth="advanced", | ||
| ) | ||
|
|
||
| print(f"Query: {ctx.query}") | ||
| print(f"Total chunks: {len(ctx.chunks)}") | ||
| print(f"Compressed (feed as embeddings): {len(ctx.compressed_chunks)}") | ||
| print(f"Expanded (feed as full tokens): {len(ctx.expanded_chunks)}") | ||
| print(f"Metadata: {ctx.metadata}") | ||
|
|
||
| for i, chunk in enumerate(ctx.chunks[:5]): | ||
| status = "EXPAND" if chunk.expand else "COMPRESS" | ||
| print(f" [{status}] chunk {i}: {chunk.text[:60]}...") | ||
|
|
||
|
|
||
| # -- Option 2: Step-by-step (useful when you already have passages) ----------- | ||
|
|
||
| response = client.tavily.search("REFRAG paper Meta 2025", max_results=3) | ||
| passages = response["results"] | ||
|
|
||
| chunks = client.chunk_passages(passages, chunk_size=8) | ||
| chunks = client.encode_chunks(chunks) | ||
| chunks = client.apply_expansion_policy(chunks, query="REFRAG paper Meta 2025") | ||
|
|
||
| print(f"\nStep-by-step: {len(chunks)} chunks from {len(passages)} passages") | ||
| for c in chunks[:3]: | ||
| print(f" expand={c.expand} url={c.source_url} tokens={len(c.tokens)}") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| from .async_tavily import AsyncTavilyClient | ||
| from .tavily import Client, TavilyClient | ||
| from .errors import InvalidAPIKeyError, UsageLimitExceededError, MissingAPIKeyError, BadRequestError | ||
| from .hybrid_rag import TavilyHybridClient | ||
| from .hybrid_rag import TavilyHybridClient | ||
| from .refrag import TavilyRefragClient, RefragChunk, RefragContext |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| from .refrag import TavilyRefragClient | ||
| from .models import RefragChunk, RefragContext |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| from dataclasses import dataclass, field | ||
| from typing import Optional, List, Dict, Any | ||
|
|
||
|
|
||
| @dataclass | ||
| class RefragChunk: | ||
| """A single k-token chunk of a retrieved passage, following the REFRAG paper's | ||
| chunk representation (Section 2: C_i = {x_{q+k*i}, ..., x_{q+k*i+k-1}}). | ||
|
|
||
| Attributes: | ||
| text: Raw text of the chunk before tokenization. | ||
| tokens: Token IDs produced by the tokenizer. | ||
| embedding: Chunk embedding from the encoder (None until encode_chunks is called). | ||
| expand: Whether the expansion policy selected this chunk for full-token decoding. | ||
| source_url: Origin URL from the Tavily search result. | ||
| source_score: Relevance score from the Tavily search result. | ||
| """ | ||
| text: str | ||
| tokens: List[int] | ||
| embedding: Optional[List[float]] = None | ||
| expand: bool = False | ||
| source_url: Optional[str] = None | ||
| source_score: Optional[float] = None | ||
|
|
||
|
|
||
| @dataclass | ||
| class RefragContext: | ||
| """Preprocessed REFRAG-compatible context ready for decoder consumption. | ||
|
|
||
| The decoder receives compressed chunk embeddings for most passages and | ||
| full token embeddings only for chunks flagged by the expansion policy | ||
| (see REFRAG paper, Section 2 & Figure 1). | ||
|
|
||
| Attributes: | ||
| query: Original query text. | ||
| chunks: All processed chunks in passage order. | ||
| chunk_size: The k value (tokens per chunk) used during chunking. | ||
| metadata: Tavily response metadata (response_time, images, etc.). | ||
| """ | ||
| query: str | ||
| chunks: List[RefragChunk] | ||
| chunk_size: int | ||
| metadata: Dict[str, Any] = field(default_factory=dict) | ||
|
|
||
| @property | ||
| def compressed_chunks(self) -> List[RefragChunk]: | ||
| """Chunks to feed as embeddings (expand=False).""" | ||
| return [c for c in self.chunks if not c.expand] | ||
|
|
||
| @property | ||
| def expanded_chunks(self) -> List[RefragChunk]: | ||
| """Chunks to feed as full tokens (expand=True).""" | ||
| return [c for c in self.chunks if c.expand] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,246 @@ | ||
| import tiktoken | ||
| from typing import Optional, List, Dict, Any, Callable, Sequence | ||
|
|
||
| import requests | ||
| from tavily import TavilyClient | ||
| from tavily.config import DEFAULT_MODEL_ENCODING | ||
| from .models import RefragChunk, RefragContext | ||
|
|
||
|
|
||
| def _default_tokenize(text: str, _encoder=[None]) -> List[int]: | ||
| if _encoder[0] is None: | ||
| _encoder[0] = tiktoken.encoding_for_model(DEFAULT_MODEL_ENCODING) | ||
| return _encoder[0].encode(text) | ||
|
|
||
|
|
||
| def _default_detokenize(tokens: List[int], _encoder=[None]) -> str: | ||
| if _encoder[0] is None: | ||
| _encoder[0] = tiktoken.encoding_for_model(DEFAULT_MODEL_ENCODING) | ||
| return _encoder[0].decode(tokens) | ||
|
|
||
|
|
||
| def _default_expansion_policy( | ||
| chunk_embeddings: List[List[float]], | ||
| query_embedding: Optional[List[float]] = None, | ||
| ) -> List[bool]: | ||
| """Default policy: compress everything (expand nothing).""" | ||
| return [False] * len(chunk_embeddings) | ||
|
|
||
|
|
||
| class TavilyRefragClient: | ||
| """Preprocessing adapter that converts Tavily search results into | ||
| REFRAG-compatible chunked and encoded context. | ||
|
|
||
| This client handles the retrieval-to-REFRAG-input pipeline: | ||
| query -> Tavily search -> passage chunking -> encoder -> expansion policy | ||
| -> RefragContext ready for an external REFRAG decoder. | ||
|
|
||
| Each pipeline step is available independently (chunk_passages, | ||
| encode_chunks, apply_expansion_policy) or as a single call | ||
| (prepare_context). | ||
|
|
||
| Parameters: | ||
| api_key: Tavily API key (falls back to TAVILY_API_KEY env var). | ||
| chunk_size: Number of tokens per chunk (k in the REFRAG paper). | ||
| Paper evaluates k=8, 16, 32; default is 16. | ||
| tokenizer_function: ``(text: str) -> list[int]``. | ||
| Defaults to tiktoken with the gpt-3.5-turbo encoding. | ||
| detokenizer_function: ``(tokens: list[int]) -> str``. | ||
| Defaults to tiktoken with the gpt-3.5-turbo encoding. | ||
| encoder_function: ``(chunks: list[list[int]]) -> list[list[float]]``. | ||
| Takes a list of token-ID lists, returns embedding vectors. | ||
| No default — raises if encode_chunks is called without one. | ||
| expansion_policy: ``(chunk_embeddings, query_embedding) -> list[bool]``. | ||
| Returns a boolean mask indicating which chunks to expand. | ||
| Defaults to compressing all chunks (expand nothing). | ||
| api_base_url: Override the Tavily base URL. | ||
| session: Pre-configured requests.Session for HTTP calls. | ||
| **tavily_kwargs: Extra keyword arguments forwarded to TavilyClient. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| api_key: Optional[str] = None, | ||
| chunk_size: int = 16, | ||
| tokenizer_function: Optional[Callable[[str], List[int]]] = None, | ||
| detokenizer_function: Optional[Callable[[List[int]], str]] = None, | ||
| encoder_function: Optional[Callable[[List[List[int]]], List[List[float]]]] = None, | ||
| expansion_policy: Optional[Callable] = None, | ||
| api_base_url: Optional[str] = None, | ||
| session: Optional[requests.Session] = None, | ||
| **tavily_kwargs, | ||
| ): | ||
| if chunk_size < 1: | ||
| raise ValueError("chunk_size must be a positive integer.") | ||
|
|
||
| self.tavily = TavilyClient( | ||
| api_key=api_key, | ||
| api_base_url=api_base_url, | ||
| session=session, | ||
| **tavily_kwargs, | ||
| ) | ||
| self.chunk_size = chunk_size | ||
| self.tokenizer_function = tokenizer_function or _default_tokenize | ||
| self.detokenizer_function = detokenizer_function or _default_detokenize | ||
| self.encoder_function = encoder_function | ||
| self.expansion_policy = expansion_policy or _default_expansion_policy | ||
|
|
||
| def chunk_passages( | ||
| self, | ||
| passages: List[Dict[str, Any]], | ||
| chunk_size: Optional[int] = None, | ||
| ) -> List[RefragChunk]: | ||
| """Split retrieved passages into fixed-size token chunks. | ||
|
|
||
| Each Tavily result dict is expected to have at least a ``content`` | ||
| key. The text is tokenized and then split into non-overlapping | ||
| chunks of ``chunk_size`` tokens. Leftover tokens shorter than | ||
| ``chunk_size`` are kept as a final smaller chunk. | ||
|
|
||
| Args: | ||
| passages: List of Tavily result dicts (must contain ``content``). | ||
| chunk_size: Override the instance-level chunk_size for this call. | ||
|
|
||
| Returns: | ||
| Ordered list of RefragChunk objects. | ||
| """ | ||
| k = chunk_size if chunk_size is not None else self.chunk_size | ||
| chunks: List[RefragChunk] = [] | ||
|
|
||
| for passage in passages: | ||
| content = passage.get("content", "") | ||
| if not content: | ||
| continue | ||
| url = passage.get("url") | ||
| score = passage.get("score") | ||
| tokens = self.tokenizer_function(content) | ||
|
|
||
| for start in range(0, len(tokens), k): | ||
| chunk_tokens = tokens[start : start + k] | ||
| chunk_text = self.detokenizer_function(chunk_tokens) | ||
| chunks.append( | ||
| RefragChunk( | ||
| text=chunk_text, | ||
| tokens=chunk_tokens, | ||
| source_url=url, | ||
| source_score=score, | ||
| ) | ||
| ) | ||
|
|
||
| return chunks | ||
|
|
||
| def encode_chunks(self, chunks: List[RefragChunk]) -> List[RefragChunk]: | ||
| """Apply the encoder function to produce chunk embeddings. | ||
|
|
||
| Requires ``encoder_function`` to have been set during construction. | ||
|
|
||
| Args: | ||
| chunks: List of RefragChunk objects (tokens must be populated). | ||
|
|
||
| Returns: | ||
| The same list with ``embedding`` fields populated. | ||
|
|
||
| Raises: | ||
| RuntimeError: If no encoder_function was provided. | ||
| """ | ||
| if self.encoder_function is None: | ||
| raise RuntimeError( | ||
| "encoder_function must be provided to encode chunks. " | ||
| "Pass it to the TavilyRefragClient constructor." | ||
| ) | ||
|
|
||
| token_lists = [c.tokens for c in chunks] | ||
| embeddings = self.encoder_function(token_lists) | ||
|
|
||
| for chunk, emb in zip(chunks, embeddings): | ||
| chunk.embedding = emb | ||
|
|
||
| return chunks | ||
|
|
||
| def apply_expansion_policy( | ||
| self, | ||
| chunks: List[RefragChunk], | ||
| query: Optional[str] = None, | ||
| ) -> List[RefragChunk]: | ||
| """Run the expansion policy to decide which chunks to expand. | ||
|
|
||
| The policy receives chunk embeddings and an optional query | ||
| embedding and returns a boolean mask. | ||
|
|
||
| Args: | ||
| chunks: Chunks with embeddings already populated. | ||
| query: Optional query text (tokenized and passed to encoder | ||
| to produce a query embedding when encoder_function | ||
| is available). | ||
|
|
||
| Returns: | ||
| The same list with ``expand`` flags set. | ||
|
|
||
| Raises: | ||
| ValueError: If any chunk lacks an embedding. | ||
| """ | ||
| for c in chunks: | ||
| if c.embedding is None: | ||
| raise ValueError( | ||
| "All chunks must have embeddings before applying the " | ||
| "expansion policy. Call encode_chunks first." | ||
| ) | ||
|
|
||
| chunk_embeddings = [c.embedding for c in chunks] | ||
|
|
||
| query_embedding = None | ||
| if query is not None and self.encoder_function is not None: | ||
| query_tokens = self.tokenizer_function(query) | ||
| query_embedding = self.encoder_function([query_tokens])[0] | ||
|
|
||
| expand_mask = self.expansion_policy(chunk_embeddings, query_embedding) | ||
|
|
||
| for chunk, should_expand in zip(chunks, expand_mask): | ||
| chunk.expand = should_expand | ||
|
|
||
| return chunks | ||
|
|
||
| def prepare_context( | ||
| self, | ||
| query: str, | ||
| max_results: int = 10, | ||
| encode: bool = True, | ||
| apply_policy: bool = True, | ||
| **search_kwargs, | ||
| ) -> RefragContext: | ||
| """Full pipeline: search -> chunk -> encode -> policy -> RefragContext. | ||
|
|
||
| Args: | ||
| query: The search query. | ||
| max_results: Maximum number of Tavily search results. | ||
| encode: Whether to run the encoder on chunks. Requires | ||
| encoder_function to be set. | ||
| apply_policy: Whether to run the expansion policy. Only | ||
| effective when encode=True. | ||
| **search_kwargs: Extra arguments forwarded to TavilyClient.search. | ||
|
|
||
| Returns: | ||
| A RefragContext ready for decoder consumption. | ||
| """ | ||
| response = self.tavily.search( | ||
| query, max_results=max_results, **search_kwargs | ||
| ) | ||
|
|
||
| passages = response.get("results", []) | ||
| metadata = { | ||
| k: v for k, v in response.items() if k != "results" | ||
| } | ||
|
|
||
| chunks = self.chunk_passages(passages) | ||
|
|
||
| if encode and self.encoder_function is not None: | ||
| self.encode_chunks(chunks) | ||
| if apply_policy: | ||
| self.apply_expansion_policy(chunks, query=query) | ||
|
|
||
| return RefragContext( | ||
| query=query, | ||
| chunks=chunks, | ||
| chunk_size=self.chunk_size, | ||
| metadata=metadata, | ||
| ) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Silent skip of encoding contradicts docstring and
encode_chunksMedium Severity
When
encode=True(the default) butencoder_functionisNone,prepare_contextsilently skips both encoding and the expansion policy. This contradicts the docstring which statesencode"Requires encoder_function to be set" and is inconsistent withencode_chunks, which raises aRuntimeErrorwhen called without an encoder. A user who explicitly passesencode=Truesignals intent to encode — silently producing aRefragContextwith no embeddings and no expansion flags creates a subtle failure mode where downstream REFRAG decoders receive incomplete data without any error being surfaced.Additional Locations (1)
tavily/refrag/refrag.py#L215-L217