Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ jobs:
- uses: astral-sh/setup-uv@v5
- run: uv sync
- run: uv run ruff check .
- run: uv run mypy search/

test:
runs-on: ubuntu-latest
Expand Down
22 changes: 13 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,22 @@ Install dependencies:
uv sync
```

Run from the command line. On first run, the Wikipedia dataset (~20GB) will be downloaded from [Hugging Face](https://huggingface.co/datasets/wikimedia/wikipedia) and cached automatically:
Run the full-text search from the command line. On first run, the Wikipedia dataset (~20GB) will be downloaded from [Hugging Face](https://huggingface.co/datasets/wikimedia/wikipedia) and cached automatically:

```bash
uv run python run.py
# loads of log output
index_documents took 1714.3159050941467 seconds
Index contains 6407814 documents
search took 0.3170650005340576 seconds
search took 4.130218982696533 seconds
search took 0.005632877349853516 seconds
search took 17.051696300506592 seconds
```

Run the semantic (vector) search:

```bash
uv run python run_semantic.py
```

On first run this builds a vector index by embedding all 6.4M documents. Embeddings are checkpointed to `data/checkpoints/` so you can resume if interrupted. The finished index is saved to `data/vector_index.*` and memory-mapped on subsequent runs.

To skip the multi-hour encoding step, download the pre-computed embeddings from [Hugging Face](https://huggingface.co/datasets/bartdegoede/wikipedia-semantic-search), place the JSON and `.npy` files in `data/checkpoints/`, and run `uv run python run_semantic.py`.

If you'd like to download the dataset separately (e.g. before a demo):

```bash
Expand All @@ -50,10 +53,11 @@ In [2]: index.search('python programming language', rank=True)[:5]

## Development

Lint with ruff:
Lint and type check:

```bash
uv run ruff check .
uv run mypy search/
```

Run tests:
Expand Down
Empty file added data/checkpoints/.keep
Empty file.
35 changes: 23 additions & 12 deletions load.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
from collections.abc import Generator

from datasets import load_dataset
from tqdm import tqdm

from search.documents import Abstract

DATASET = "wikimedia/wikipedia"
DATASET_CONFIG = "20231101.en"
DATASET: str = "wikimedia/wikipedia"
DATASET_CONFIG: str = "20231101.en"


def load_documents() -> tuple[int, Generator[Abstract, None, None]]:
"""Load Wikipedia abstracts from HuggingFace.

def load_documents():
Returns (total, iterator) so callers can create fixed-size structures
(like a memmap) without materializing all documents into memory.
The HF Dataset is Arrow-backed and memory-mapped, so iterating over it
doesn't load the full dataset into RAM.
"""
ds = load_dataset(DATASET, DATASET_CONFIG, split="train")
for doc_id, row in enumerate(tqdm(ds, desc="Loading documents")):
title = row["title"]
url = row["url"]
# extract first paragraph as abstract
text = row["text"]
abstract = text.split("\n\n")[0] if text else ""

yield Abstract(ID=doc_id, title=title, url=url, abstract=abstract)

def _generate() -> Generator[Abstract, None, None]:
for doc_id, row in enumerate(ds):
title: str = row["title"]
url: str = row["url"]
# extract first paragraph as abstract
text: str = row["text"]
abstract = text.split("\n\n")[0] if text else ""
yield Abstract(ID=doc_id, title=title, url=url, abstract=abstract)

return len(ds), _generate()
14 changes: 13 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
[project]
name = "python-searchengine"
version = "0.1.0"
version = "0.2.0"
description = "Simple full-text search engine implementation in Python"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"datasets",
"PyStemmer",
"numpy",
"sentence-transformers",
]

[dependency-groups]
dev = [
"mypy>=1.19.1",
"pytest",
"ruff",
]
faiss = [
"faiss-cpu",
]
openai = [
"openai",
]

[tool.ruff]
line-length = 120
Expand All @@ -23,3 +32,6 @@ select = ["E", "F", "W", "I"]

[tool.pytest.ini_options]
testpaths = ["tests"]

[tool.mypy]
ignore_missing_imports = true
14 changes: 8 additions & 6 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from search.timing import timing

logging.basicConfig(level=logging.INFO)
# set httpx logging to WARNING to reduce noise from API calls
logging.getLogger("httpx").setLevel(logging.WARNING)


@timing
Expand All @@ -14,11 +16,11 @@ def index_documents(documents, index):
return index


if __name__ == '__main__':
if __name__ == "__main__":
index = index_documents(load_documents(), Index())
print(f'Index contains {len(index.documents)} documents')
print(f"Index contains {len(index.documents)} documents")

index.search('London Beer Flood', search_type='AND')
index.search('London Beer Flood', search_type='OR')
index.search('London Beer Flood', search_type='AND', rank=True)
index.search('London Beer Flood', search_type='OR', rank=True)
index.search("London Beer Flood", search_type="AND")
index.search("London Beer Flood", search_type="OR")
index.search("London Beer Flood", search_type="AND", rank=True)
index.search("London Beer Flood", search_type="OR", rank=True)
139 changes: 139 additions & 0 deletions run_semantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import itertools
import json
import logging
import tempfile
import time
from pathlib import Path

import numpy as np

from load import load_documents
from search.embeddings import embed_batch, embed_text, get_embedding_model
from search.timing import timing
from search.vector_index import VectorIndex

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# set httpx logging to WARNING to reduce noise from API calls
logging.getLogger("httpx").setLevel(logging.WARNING)

BATCH_SIZE = 256
CHECKPOINT_SIZE = 10_000
INDEX_PATH = "data/vector_index"
CHECKPOINT_DIR = Path("data/checkpoints")


@timing
def build_vector_index(documents, total, model):
    """Embed all documents and build a disk-backed vector index.

    Args:
        documents: generator of Abstract objects (consumed exactly once,
            in order — chunk boundaries depend on this ordering).
        total: total number of documents the generator will yield; needed
            up front to size the on-disk memmap.
        model: sentence-transformers model passed through to embed_batch.

    Returns:
        A VectorIndex loaded (memory-mapped) from the files written to
        INDEX_PATH.

    Work is checkpointed in CHECKPOINT_SIZE-document chunks under
    CHECKPOINT_DIR so an interrupted run can resume: existing chunk files
    are loaded instead of re-embedded.
    """
    logger.info(f"Building index for {total} documents...")
    # Ensure both the checkpoint dir and the index's parent dir exist.
    CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
    Path(INDEX_PATH).parent.mkdir(parents=True, exist_ok=True)

    # Ceiling division: number of CHECKPOINT_SIZE-sized chunks covering total.
    num_chunks = (total + CHECKPOINT_SIZE - 1) // CHECKPOINT_SIZE
    # The memmap is created lazily on the first chunk, once the embedding
    # dimensionality is known (see below).
    matrix = None
    build_start = time.perf_counter()

    for chunk_num, i in enumerate(range(0, total, CHECKPOINT_SIZE), 1):
        # i is the global doc offset of this chunk; files are keyed by it.
        chunk_path = CHECKPOINT_DIR / f"chunk_{i}.npy"
        docs_path = CHECKPOINT_DIR / f"chunk_{i}.json"
        end = min(i + CHECKPOINT_SIZE, total)  # last chunk may be short
        chunk_size = end - i

        # Consume exactly one chunk from the generator. We only hold ~10k
        # Abstract objects at a time instead of all 6.4M. NOTE: this runs
        # even when the chunk is checkpointed, to keep the generator's
        # position in sync with the chunk index.
        chunk_docs = list(itertools.islice(documents, chunk_size))

        # Save doc metadata per-chunk so we never need all docs in memory.
        # On resume these are read back from disk to assemble the final JSON.
        # Write to a temp file then rename so a crash mid-write can't leave
        # a corrupt file that blocks resume.
        if not docs_path.exists():
            chunk_docs_data = {
                str(i + j): {"ID": d.ID, "title": d.title, "abstract": d.abstract, "url": d.url}
                for j, d in enumerate(chunk_docs)
            }
            with tempfile.NamedTemporaryFile("w", dir=CHECKPOINT_DIR, suffix=".json", delete=False) as f:
                json.dump(chunk_docs_data, f)
                tmp_path = Path(f.name)
            # Atomic on POSIX: either the old state or the complete new file.
            tmp_path.rename(docs_path)

        t0 = time.perf_counter()
        if chunk_path.exists():
            # Resume path: reuse previously computed embeddings.
            chunk_vectors = np.load(chunk_path)
            elapsed = time.perf_counter() - t0
            logger.info(f" Chunk {chunk_num}/{num_chunks}: loaded checkpoint ({end}/{total} docs) in {elapsed:.1f}s")
        else:
            logger.info(f" Chunk {chunk_num}/{num_chunks}: embedding docs {i:,}–{end:,} of {total:,}")
            texts = [d.fulltext for d in chunk_docs]
            chunk_vectors = embed_batch(model, texts, batch_size=BATCH_SIZE, show_progress=True)
            elapsed = time.perf_counter() - t0
            logger.info(f" Chunk {chunk_num}/{num_chunks}: embedded in {elapsed:.1f}s")
            # Save raw embeddings via temp file + rename for crash safety
            with tempfile.NamedTemporaryFile(dir=CHECKPOINT_DIR, suffix=".npy", delete=False) as f:
                np.save(f, chunk_vectors)
                tmp_path = Path(f.name)
            tmp_path.rename(chunk_path)

        total_elapsed = time.perf_counter() - build_start
        logger.info(f" Total elapsed: {total_elapsed:.1f}s")

        # We can only create the memmap once we know the embedding dimensions
        # from the first chunk (e.g. 384 for all-MiniLM-L6-v2).
        if matrix is None:
            matrix = np.lib.format.open_memmap(
                f"{INDEX_PATH}.npy", mode="w+", dtype=np.float16,
                shape=(total, chunk_vectors.shape[1]),
            )

        # Normalize in float32 for numerical stability, then downcast to float16
        # to halve disk/memory usage. The precision loss is negligible for ranking.
        norms = np.linalg.norm(chunk_vectors, axis=1, keepdims=True)
        norms[norms == 0] = 1  # avoid division by zero for all-zero vectors
        matrix[i:end] = (chunk_vectors / norms).astype(np.float16)

    # Flush pending memmap writes to disk and release the mapping.
    # (matrix stays None only when total == 0.)
    if matrix is not None:
        matrix.flush()
        del matrix

    # Assemble final document metadata from per-chunk JSON files
    all_docs_data = {}
    for i in range(0, total, CHECKPOINT_SIZE):
        with open(CHECKPOINT_DIR / f"chunk_{i}.json") as f:
            all_docs_data.update(json.load(f))
    with open(f"{INDEX_PATH}.json", "w") as f:
        json.dump(all_docs_data, f)

    # Load the finished index using memory-mapped I/O — the matrix stays on disk
    # and the OS pages in data as needed during search.
    index = VectorIndex()
    index.load(INDEX_PATH)
    return index


if __name__ == "__main__":
    model = get_embedding_model()

    # Prefer a previously persisted index on disk; fall back to a full build.
    index = VectorIndex()
    try:
        index.load(INDEX_PATH)
        logger.info(f"Loaded vector index with {len(index.documents)} documents from disk")
    except FileNotFoundError:
        logger.info("No saved index found, building from scratch...")
        total, documents = load_documents()
        index = build_vector_index(documents, total, model)

    logger.info(f"Index contains {len(index.documents)} documents")

    # Demo queries: an exact title plus paraphrases that only semantic
    # search should match well.
    queries = [
        "London Beer Flood",
        "alcoholic beverage disaster in England",
        "python programming language",
        "large constricting reptiles",
    ]
    for query in queries:
        print(f'\n--- Query: "{query}" ---')
        qvec = embed_text(model, query)
        for doc, score in index.search(qvec, k=5):
            print(f" {score:.4f} | {doc.title}")
21 changes: 21 additions & 0 deletions search/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import numpy as np
from sentence_transformers import SentenceTransformer

DEFAULT_MODEL = "all-MiniLM-L6-v2"


def get_embedding_model(model_name: str = DEFAULT_MODEL) -> SentenceTransformer:
    """Load a sentence-transformers model.

    Args:
        model_name: model identifier understood by SentenceTransformer;
            defaults to all-MiniLM-L6-v2 (384-dim embeddings).

    Returns:
        A SentenceTransformer instance ready for encode() calls.
        NOTE(review): first use presumably downloads weights over the
        network — confirm caching behaviour before demos.
    """
    return SentenceTransformer(model_name)


def embed_text(model, text):
    """Embed a single text string. Returns a float32 numpy array."""
    vector = model.encode(text, convert_to_numpy=True)
    # Force a consistent dtype regardless of what the model emits.
    return vector.astype(np.float32)


def embed_batch(model, texts, batch_size=256, show_progress=False):
    """Embed a list of texts in batches. Returns a (n, dims) float32 numpy array."""
    raw = model.encode(
        texts,
        batch_size=batch_size,
        show_progress_bar=show_progress,
        convert_to_numpy=True,
    )
    # Normalize the dtype so downstream math is model-independent.
    return raw.astype(np.float32)
Loading