Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import logging
import os
import threading
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -293,6 +294,7 @@ def __init__(
self.output_fields = output_fields or ['$embeddings']
self.device = device
self.kwargs = kwargs
self._encode_lock = threading.Lock()

# Check if we should proxy to server
server_addr = get_model_server_address()
Expand Down Expand Up @@ -375,29 +377,30 @@ def _encode_local(self, sentences: List[str], batch_size: int, **kwargs) -> np.n
t_gpu = 0.0
t_post = 0.0

# Process in batches
for i in range(0, len(sentences), batch_size):
batch = sentences[i : i + batch_size]

# Preprocess phase
t0 = time.perf_counter()
preprocessed = SentenceTransformerLoader.preprocess(self._model, batch, self._metadata)
t_pre += (time.perf_counter() - t0) * 1000

# GPU inference phase
t0 = time.perf_counter()
raw_output = SentenceTransformerLoader.inference(self._model, preprocessed, self._metadata)
t_gpu += (time.perf_counter() - t0) * 1000

# Postprocess phase
t0 = time.perf_counter()
results = SentenceTransformerLoader.postprocess(self._model, raw_output, len(batch), self.output_fields)
t_post += (time.perf_counter() - t0) * 1000

# Extract embeddings from results (handles both 'embeddings' and '$embeddings')
for result in results:
emb = result.get('$embeddings') or result.get('embeddings') or result
all_embeddings.append(emb)
with self._encode_lock:
# Process in batches
for i in range(0, len(sentences), batch_size):
batch = sentences[i : i + batch_size]

# Preprocess phase
t0 = time.perf_counter()
preprocessed = SentenceTransformerLoader.preprocess(self._model, batch, self._metadata)
t_pre += (time.perf_counter() - t0) * 1000

# GPU inference phase
t0 = time.perf_counter()
raw_output = SentenceTransformerLoader.inference(self._model, preprocessed, self._metadata)
t_gpu += (time.perf_counter() - t0) * 1000

# Postprocess phase
t0 = time.perf_counter()
results = SentenceTransformerLoader.postprocess(self._model, raw_output, len(batch), self.output_fields)
t_post += (time.perf_counter() - t0) * 1000

# Extract embeddings from results (handles both 'embeddings' and '$embeddings')
for result in results:
emb = result.get('$embeddings') or result.get('embeddings') or result
all_embeddings.append(emb)

# Report all perf counters — same shape as model server response
metrics.add_time(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Tests for ai.common.models.transformers.*
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""
Manual reproducer for shared-model SentenceTransformer failures.

This script is intended for GPU environments where the production model is available.
It runs the same shared model instance in sequential and/or concurrent modes and logs
the token lengths observed per batch.

Example:
PYTHONPATH=engine/packages/ai/src python \
engine/packages/ai/tests/ai/common/models/transformers/reproduce_sentence_transformer_origin.py \
--model nomic-ai/nomic-embed-text-v1.5 --device cuda:0 --threads 3 --iterations 30 --mode both
"""

from __future__ import annotations

import argparse
from concurrent.futures import ThreadPoolExecutor
import logging
import threading
from typing import List

from ai.common.models import SentenceTransformer


LOGGER = logging.getLogger('sentence_transformer_reproducer')


def _build_text(prefix: str, token_hint: int, label: str) -> str:
# token_hint is intentionally large to create variable sequence lengths.
payload = ' '.join([f'{label}_token'] * token_hint)
return f'{prefix}{payload}'


def _build_batches(document_prefix: str) -> List[List[str]]:
return [
[
_build_text(document_prefix, 409, 'sample_a'),
_build_text(document_prefix, 756, 'sample_b'),
],
[
_build_text(document_prefix, 947, 'sample_c'),
_build_text(document_prefix, 756, 'sample_d'),
],
[
_build_text(document_prefix, 512, 'sample_e'),
_build_text(document_prefix, 409, 'sample_f'),
],
]


def _token_lengths(model: SentenceTransformer, sentences: List[str]) -> List[int]:
tokenizer = model._model.tokenizer
encoded = tokenizer(
sentences,
padding=True,
truncation=True,
return_tensors='pt',
max_length=model.max_seq_length,
)
attention_mask = encoded['attention_mask']
lengths = attention_mask.sum(dim=1).tolist()
return [int(length) for length in lengths]


def _run_single_batch(model: SentenceTransformer, batch: List[str], mode: str, round_idx: int, worker_id: int) -> None:
lengths = _token_lengths(model, batch)
LOGGER.info(
'encoding_start mode=%s round=%d worker=%d thread=%s token_lengths=%s',
mode,
round_idx,
worker_id,
threading.get_ident(),
lengths,
)
model.encode(batch, batch_size=len(batch), show_progress_bar=False)
LOGGER.info(
'encoding_done mode=%s round=%d worker=%d thread=%s',
mode,
round_idx,
worker_id,
threading.get_ident(),
)


def _run_sequential(model: SentenceTransformer, batches: List[List[str]], iterations: int) -> None:
for round_idx in range(iterations):
batch = batches[round_idx % len(batches)]
_run_single_batch(model, batch, mode='sequential', round_idx=round_idx, worker_id=0)


def _run_concurrent(model: SentenceTransformer, batches: List[List[str]], iterations: int, threads: int) -> None:
with ThreadPoolExecutor(max_workers=threads) as executor:
for round_idx in range(iterations):
futures = []
for worker_id in range(threads):
batch = batches[(round_idx + worker_id) % len(batches)]
futures.append(
executor.submit(
_run_single_batch,
model,
batch,
'concurrent',
round_idx,
worker_id,
)
)
for future in futures:
future.result()


def main() -> None:
parser = argparse.ArgumentParser(description='Reproduce shared-model sentence transformer failures.')
parser.add_argument('--model', default='nomic-ai/nomic-embed-text-v1.5')
parser.add_argument('--device', default='cuda:0')
parser.add_argument('--threads', type=int, default=3)
parser.add_argument('--iterations', type=int, default=20)
parser.add_argument('--mode', choices=['sequential', 'concurrent', 'both'], default='both')
parser.add_argument('--truncate-dim', type=int, default=768)
parser.add_argument('--document-prefix', default='search_document: ')
args = parser.parse_args()

logging.basicConfig(
level=logging.INFO,
format='%(asctime)s %(levelname)s %(name)s %(message)s',
)

LOGGER.info(
'loading_model model=%s device=%s threads=%d iterations=%d truncate_dim=%d',
args.model,
args.device,
args.threads,
args.iterations,
args.truncate_dim,
)
model = SentenceTransformer(
model_name_or_path=args.model,
device=args.device,
trust_remote_code=True,
truncate_dim=args.truncate_dim,
)
LOGGER.info(
'model_loaded model=%s max_seq_length=%d embedding_dim=%d proxy_mode=%s',
args.model,
model.max_seq_length,
model.get_sentence_embedding_dimension(),
model._proxy_mode,
)

batches = _build_batches(args.document_prefix)

if args.mode in ('sequential', 'both'):
LOGGER.info('start_mode mode=sequential')
_run_sequential(model, batches, args.iterations)

if args.mode in ('concurrent', 'both'):
LOGGER.info('start_mode mode=concurrent')
_run_concurrent(model, batches, args.iterations, args.threads)

LOGGER.info('reproducer_complete mode=%s', args.mode)


if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""
Unit tests for ai.common.models.transformers.sentence_transformers.

Focus areas:
- Concurrency serialization in local encode path
"""

from concurrent.futures import ThreadPoolExecutor
import threading
import time

import numpy as np

import ai.common.models.transformers.sentence_transformers as sentence_transformers_module


def test_encode_local_serializes_concurrent_inference(monkeypatch):
"""Local SentenceTransformer.encode serializes shared model access."""
active_calls = 0
max_active_calls = 0
state_lock = threading.Lock()

def fake_get_model_server_address():
return None

def fake_load(model_name, device=None, allocate_gpu=None, exclude_gpus=None, **kwargs):
metadata = {
'embedding_dimension': 1,
'max_seq_length': 128,
'device': device or 'cpu',
'model_name': model_name,
'loader': 'sentence_transformer',
'estimated_memory_gb': 0.0,
}
return object(), metadata, -1

def fake_preprocess(model, inputs, metadata=None):
return {
'encoded': {'input_ids': inputs},
'batch_size': len(inputs),
}

def fake_inference(model, preprocessed, metadata=None, stream=None):
nonlocal active_calls, max_active_calls
with state_lock:
active_calls += 1
if active_calls > max_active_calls:
max_active_calls = active_calls

# Simulate GPU call overlap window.
time.sleep(0.02)

with state_lock:
active_calls -= 1

return [[0.25] for _ in range(preprocessed['batch_size'])]

def fake_postprocess(model, raw_output, batch_size, output_fields, **kwargs):
return [{'$embeddings': row} for row in raw_output]

monkeypatch.setattr(sentence_transformers_module, 'get_model_server_address', fake_get_model_server_address)
monkeypatch.setattr(
sentence_transformers_module.SentenceTransformerLoader,
'load',
staticmethod(fake_load),
)
monkeypatch.setattr(
sentence_transformers_module.SentenceTransformerLoader,
'preprocess',
staticmethod(fake_preprocess),
)
monkeypatch.setattr(
sentence_transformers_module.SentenceTransformerLoader,
'inference',
staticmethod(fake_inference),
)
monkeypatch.setattr(
sentence_transformers_module.SentenceTransformerLoader,
'postprocess',
staticmethod(fake_postprocess),
)

model = sentence_transformers_module.SentenceTransformer('nomic-ai/nomic-embed-text-v1.5', device='cpu')

def run_encode(worker_idx):
sentences = [f'search_document: worker-{worker_idx}-item-{i}' for i in range(4)]
return model.encode(sentences, batch_size=2)

with ThreadPoolExecutor(max_workers=4) as executor:
results = list(executor.map(run_encode, range(8)))

assert max_active_calls == 1
for result in results:
assert isinstance(result, np.ndarray)
assert result.shape == (4, 1)
Loading