Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions examples/refrag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# REFRAG Preprocessing Example
#
# This example shows how to use TavilyRefragClient to retrieve web passages
# via Tavily, chunk them into k-token blocks, encode them with a pluggable
# encoder, and apply an expansion policy -- producing a RefragContext that
# is ready to be fed into a REFRAG-compatible decoder.
#
# Replace the dummy encoder and policy below with your own REFRAG-trained
# models (e.g. RoBERTa encoder + RL expansion policy from the REFRAG paper).

import os
import math
from tavily import TavilyRefragClient


# -- Pluggable encoder -------------------------------------------------------
# In a real setup, this would call a lightweight encoder like RoBERTa to
# produce chunk embeddings. Here we use a trivial placeholder.

def my_encoder(token_lists):
"""Encode each token chunk into a fixed-size embedding vector."""
embeddings = []
for tokens in token_lists:
dim = 8
emb = [float(t % 100) / 100.0 for t in tokens[:dim]]
emb += [0.0] * (dim - len(emb))
embeddings.append(emb)
return embeddings


# -- Pluggable expansion policy -----------------------------------------------
# In a real setup, this would be an RL-trained policy that decides which
# chunks need full token-level attention in the decoder. Here we expand
# chunks whose embedding norm exceeds a threshold.

def my_expansion_policy(chunk_embeddings, query_embedding=None):
"""Expand chunks whose L2 norm is above the median (heuristic demo)."""
norms = [math.sqrt(sum(x * x for x in emb)) for emb in chunk_embeddings]
if not norms:
return []
threshold = sorted(norms)[len(norms) // 2]
return [n > threshold for n in norms]


# -- Build the client ---------------------------------------------------------

client = TavilyRefragClient(
api_key=os.environ["TAVILY_API_KEY"],
chunk_size=16, # k=16 as in the paper
encoder_function=my_encoder,
expansion_policy=my_expansion_policy,
)


# -- Option 1: Full pipeline in one call --------------------------------------

ctx = client.prepare_context(
query="What are the latest advances in retrieval augmented generation?",
max_results=5,
search_depth="advanced",
)

print(f"Query: {ctx.query}")
print(f"Total chunks: {len(ctx.chunks)}")
print(f"Compressed (feed as embeddings): {len(ctx.compressed_chunks)}")
print(f"Expanded (feed as full tokens): {len(ctx.expanded_chunks)}")
print(f"Metadata: {ctx.metadata}")

for i, chunk in enumerate(ctx.chunks[:5]):
status = "EXPAND" if chunk.expand else "COMPRESS"
print(f" [{status}] chunk {i}: {chunk.text[:60]}...")


# -- Option 2: Step-by-step (useful when you already have passages) -----------

response = client.tavily.search("REFRAG paper Meta 2025", max_results=3)
passages = response["results"]

chunks = client.chunk_passages(passages, chunk_size=8)
chunks = client.encode_chunks(chunks)
chunks = client.apply_expansion_policy(chunks, query="REFRAG paper Meta 2025")

print(f"\nStep-by-step: {len(chunks)} chunks from {len(passages)} passages")
for c in chunks[:3]:
print(f" expand={c.expand} url={c.source_url} tokens={len(c.tokens)}")
3 changes: 2 additions & 1 deletion tavily/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .async_tavily import AsyncTavilyClient
from .tavily import Client, TavilyClient
from .errors import InvalidAPIKeyError, UsageLimitExceededError, MissingAPIKeyError, BadRequestError
from .hybrid_rag import TavilyHybridClient
from .hybrid_rag import TavilyHybridClient
from .refrag import TavilyRefragClient, RefragChunk, RefragContext
2 changes: 2 additions & 0 deletions tavily/refrag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .refrag import TavilyRefragClient
from .models import RefragChunk, RefragContext
53 changes: 53 additions & 0 deletions tavily/refrag/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any


@dataclass
class RefragChunk:
"""A single k-token chunk of a retrieved passage, following the REFRAG paper's
chunk representation (Section 2: C_i = {x_{q+k*i}, ..., x_{q+k*i+k-1}}).

Attributes:
text: Raw text of the chunk before tokenization.
tokens: Token IDs produced by the tokenizer.
embedding: Chunk embedding from the encoder (None until encode_chunks is called).
expand: Whether the expansion policy selected this chunk for full-token decoding.
source_url: Origin URL from the Tavily search result.
source_score: Relevance score from the Tavily search result.
"""
text: str
tokens: List[int]
embedding: Optional[List[float]] = None
expand: bool = False
source_url: Optional[str] = None
source_score: Optional[float] = None


@dataclass
class RefragContext:
"""Preprocessed REFRAG-compatible context ready for decoder consumption.

The decoder receives compressed chunk embeddings for most passages and
full token embeddings only for chunks flagged by the expansion policy
(see REFRAG paper, Section 2 & Figure 1).

Attributes:
query: Original query text.
chunks: All processed chunks in passage order.
chunk_size: The k value (tokens per chunk) used during chunking.
metadata: Tavily response metadata (response_time, images, etc.).
"""
query: str
chunks: List[RefragChunk]
chunk_size: int
metadata: Dict[str, Any] = field(default_factory=dict)

@property
def compressed_chunks(self) -> List[RefragChunk]:
"""Chunks to feed as embeddings (expand=False)."""
return [c for c in self.chunks if not c.expand]

@property
def expanded_chunks(self) -> List[RefragChunk]:
"""Chunks to feed as full tokens (expand=True)."""
return [c for c in self.chunks if c.expand]
246 changes: 246 additions & 0 deletions tavily/refrag/refrag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,246 @@
import tiktoken
from typing import Optional, List, Dict, Any, Callable, Sequence

import requests
from tavily import TavilyClient
from tavily.config import DEFAULT_MODEL_ENCODING
from .models import RefragChunk, RefragContext


def _default_tokenize(text: str, _encoder=[None]) -> List[int]:
if _encoder[0] is None:
_encoder[0] = tiktoken.encoding_for_model(DEFAULT_MODEL_ENCODING)
return _encoder[0].encode(text)


def _default_detokenize(tokens: List[int], _encoder=[None]) -> str:
if _encoder[0] is None:
_encoder[0] = tiktoken.encoding_for_model(DEFAULT_MODEL_ENCODING)
return _encoder[0].decode(tokens)


def _default_expansion_policy(
chunk_embeddings: List[List[float]],
query_embedding: Optional[List[float]] = None,
) -> List[bool]:
"""Default policy: compress everything (expand nothing)."""
return [False] * len(chunk_embeddings)


class TavilyRefragClient:
"""Preprocessing adapter that converts Tavily search results into
REFRAG-compatible chunked and encoded context.

This client handles the retrieval-to-REFRAG-input pipeline:
query -> Tavily search -> passage chunking -> encoder -> expansion policy
-> RefragContext ready for an external REFRAG decoder.

Each pipeline step is available independently (chunk_passages,
encode_chunks, apply_expansion_policy) or as a single call
(prepare_context).

Parameters:
api_key: Tavily API key (falls back to TAVILY_API_KEY env var).
chunk_size: Number of tokens per chunk (k in the REFRAG paper).
Paper evaluates k=8, 16, 32; default is 16.
tokenizer_function: ``(text: str) -> list[int]``.
Defaults to tiktoken with the gpt-3.5-turbo encoding.
detokenizer_function: ``(tokens: list[int]) -> str``.
Defaults to tiktoken with the gpt-3.5-turbo encoding.
encoder_function: ``(chunks: list[list[int]]) -> list[list[float]]``.
Takes a list of token-ID lists, returns embedding vectors.
No default — raises if encode_chunks is called without one.
expansion_policy: ``(chunk_embeddings, query_embedding) -> list[bool]``.
Returns a boolean mask indicating which chunks to expand.
Defaults to compressing all chunks (expand nothing).
api_base_url: Override the Tavily base URL.
session: Pre-configured requests.Session for HTTP calls.
**tavily_kwargs: Extra keyword arguments forwarded to TavilyClient.
"""

def __init__(
self,
api_key: Optional[str] = None,
chunk_size: int = 16,
tokenizer_function: Optional[Callable[[str], List[int]]] = None,
detokenizer_function: Optional[Callable[[List[int]], str]] = None,
encoder_function: Optional[Callable[[List[List[int]]], List[List[float]]]] = None,
expansion_policy: Optional[Callable] = None,
api_base_url: Optional[str] = None,
session: Optional[requests.Session] = None,
**tavily_kwargs,
):
if chunk_size < 1:
raise ValueError("chunk_size must be a positive integer.")

self.tavily = TavilyClient(
api_key=api_key,
api_base_url=api_base_url,
session=session,
**tavily_kwargs,
)
self.chunk_size = chunk_size
self.tokenizer_function = tokenizer_function or _default_tokenize
self.detokenizer_function = detokenizer_function or _default_detokenize
self.encoder_function = encoder_function
self.expansion_policy = expansion_policy or _default_expansion_policy

def chunk_passages(
self,
passages: List[Dict[str, Any]],
chunk_size: Optional[int] = None,
) -> List[RefragChunk]:
"""Split retrieved passages into fixed-size token chunks.

Each Tavily result dict is expected to have at least a ``content``
key. The text is tokenized and then split into non-overlapping
chunks of ``chunk_size`` tokens. Leftover tokens shorter than
``chunk_size`` are kept as a final smaller chunk.

Args:
passages: List of Tavily result dicts (must contain ``content``).
chunk_size: Override the instance-level chunk_size for this call.

Returns:
Ordered list of RefragChunk objects.
"""
k = chunk_size if chunk_size is not None else self.chunk_size
chunks: List[RefragChunk] = []

for passage in passages:
content = passage.get("content", "")
if not content:
continue
url = passage.get("url")
score = passage.get("score")
tokens = self.tokenizer_function(content)

for start in range(0, len(tokens), k):
chunk_tokens = tokens[start : start + k]
chunk_text = self.detokenizer_function(chunk_tokens)
chunks.append(
RefragChunk(
text=chunk_text,
tokens=chunk_tokens,
source_url=url,
source_score=score,
)
)

return chunks

def encode_chunks(self, chunks: List[RefragChunk]) -> List[RefragChunk]:
"""Apply the encoder function to produce chunk embeddings.

Requires ``encoder_function`` to have been set during construction.

Args:
chunks: List of RefragChunk objects (tokens must be populated).

Returns:
The same list with ``embedding`` fields populated.

Raises:
RuntimeError: If no encoder_function was provided.
"""
if self.encoder_function is None:
raise RuntimeError(
"encoder_function must be provided to encode chunks. "
"Pass it to the TavilyRefragClient constructor."
)

token_lists = [c.tokens for c in chunks]
embeddings = self.encoder_function(token_lists)

for chunk, emb in zip(chunks, embeddings):
chunk.embedding = emb

return chunks

def apply_expansion_policy(
self,
chunks: List[RefragChunk],
query: Optional[str] = None,
) -> List[RefragChunk]:
"""Run the expansion policy to decide which chunks to expand.

The policy receives chunk embeddings and an optional query
embedding and returns a boolean mask.

Args:
chunks: Chunks with embeddings already populated.
query: Optional query text (tokenized and passed to encoder
to produce a query embedding when encoder_function
is available).

Returns:
The same list with ``expand`` flags set.

Raises:
ValueError: If any chunk lacks an embedding.
"""
for c in chunks:
if c.embedding is None:
raise ValueError(
"All chunks must have embeddings before applying the "
"expansion policy. Call encode_chunks first."
)

chunk_embeddings = [c.embedding for c in chunks]

query_embedding = None
if query is not None and self.encoder_function is not None:
query_tokens = self.tokenizer_function(query)
query_embedding = self.encoder_function([query_tokens])[0]

expand_mask = self.expansion_policy(chunk_embeddings, query_embedding)

for chunk, should_expand in zip(chunks, expand_mask):
chunk.expand = should_expand

return chunks

def prepare_context(
self,
query: str,
max_results: int = 10,
encode: bool = True,
apply_policy: bool = True,
**search_kwargs,
) -> RefragContext:
"""Full pipeline: search -> chunk -> encode -> policy -> RefragContext.

Args:
query: The search query.
max_results: Maximum number of Tavily search results.
encode: Whether to run the encoder on chunks. Requires
encoder_function to be set.
apply_policy: Whether to run the expansion policy. Only
effective when encode=True.
**search_kwargs: Extra arguments forwarded to TavilyClient.search.

Returns:
A RefragContext ready for decoder consumption.
"""
response = self.tavily.search(
query, max_results=max_results, **search_kwargs
)

passages = response.get("results", [])
metadata = {
k: v for k, v in response.items() if k != "results"
}

chunks = self.chunk_passages(passages)

if encode and self.encoder_function is not None:
self.encode_chunks(chunks)
if apply_policy:
self.apply_expansion_policy(chunks, query=query)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Silent skip of encoding contradicts docstring and encode_chunks

Medium Severity

When encode=True (the default) but encoder_function is None, prepare_context silently skips both encoding and the expansion policy. This contradicts the docstring which states encode "Requires encoder_function to be set" and is inconsistent with encode_chunks, which raises a RuntimeError when called without an encoder. A user who explicitly passes encode=True signals intent to encode — silently producing a RefragContext with no embeddings and no expansion flags creates a subtle failure mode where downstream REFRAG decoders receive incomplete data without any error being surfaced.

Additional Locations (1)
Fix in Cursor Fix in Web


return RefragContext(
query=query,
chunks=chunks,
chunk_size=self.chunk_size,
metadata=metadata,
)
Loading