Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions tests/contract/test_turboquant_estimate_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

import numpy as np
import pytest

from semafold import EncodeMetric, EncodeObjective, VectorEncodeRequest
from semafold.turboquant import (
TurboQuantMSEConfig,
TurboQuantMSEVectorCodec,
TurboQuantProdConfig,
TurboQuantProdVectorCodec,
)


def _normalized_rows(*, seed: int, shape: tuple[int, int], dtype: type[np.generic]) -> np.ndarray:
rng = np.random.default_rng(seed)
rows = rng.normal(size=shape).astype(np.float32)
norms = np.linalg.norm(rows.astype(np.float64), axis=1, keepdims=True).astype(np.float32)
norms = np.where(norms == 0.0, np.float32(1.0), norms)
return np.asarray(rows / norms, dtype=dtype)


@pytest.mark.parametrize(
    ("codec", "encode_request"),
    [
        # MSE codec: reconstruction objective over raw (un-normalized) Gaussian rows.
        (
            TurboQuantMSEVectorCodec(
                config=TurboQuantMSEConfig(default_bits_per_scalar=3, default_rotation_seed=7)
            ),
            VectorEncodeRequest(
                data=np.random.default_rng(17).normal(size=(8, 32)).astype(np.float32),
                objective=EncodeObjective.RECONSTRUCTION,
                metric=EncodeMetric.MSE,
                role="embedding",
                seed=19,
            ),
        ),
        # Product codec: inner-product estimation; rows are unit-normalized first.
        (
            TurboQuantProdVectorCodec(
                config=TurboQuantProdConfig(total_bits_per_scalar=4, default_rotation_seed=7, default_qjl_seed=11)
            ),
            VectorEncodeRequest(
                data=_normalized_rows(seed=23, shape=(8, 32), dtype=np.float32),
                objective=EncodeObjective.INNER_PRODUCT_ESTIMATION,
                metric=EncodeMetric.DOT_PRODUCT_ERROR,
                role="embedding",
                seed=29,
            ),
        ),
    ],
)
def test_turboquant_estimate_contract_exposes_exact_accounting_fields(
    codec,
    encode_request: VectorEncodeRequest,
) -> None:
    """Contract: ``estimate()`` must populate every byte-accounting field and
    agree exactly with the footprint produced by a real ``encode()`` of the
    same request."""
    estimate = codec.estimate(encode_request)
    encoding = codec.encode(encode_request)

    # Baseline is the uncompressed input size; every estimated component must be set.
    assert estimate.baseline_bytes == int(encode_request.data.nbytes)
    assert estimate.estimated_payload_bytes is not None
    assert estimate.estimated_metadata_bytes is not None
    assert estimate.estimated_sidecar_bytes is not None
    # These codecs are expected to pass nothing through unquantized and to keep
    # no decoder-side state, so both components are pinned to exactly zero.
    assert estimate.estimated_protected_passthrough_bytes == 0
    assert estimate.estimated_decoder_state_bytes == 0
    assert estimate.estimated_total_bytes is not None
    assert estimate.estimated_compression_ratio is not None

    # The estimated total must be the exact sum of its components (no hidden overhead).
    assert estimate.estimated_total_bytes == (
        estimate.estimated_payload_bytes
        + estimate.estimated_metadata_bytes
        + estimate.estimated_sidecar_bytes
        + estimate.estimated_protected_passthrough_bytes
        + estimate.estimated_decoder_state_bytes
    )
    # Ratio is baseline/total; approx() absorbs float rounding only.
    assert estimate.estimated_compression_ratio == pytest.approx(
        float(estimate.baseline_bytes) / float(estimate.estimated_total_bytes)
    )

    # Estimates are exact, not bounds: each field matches the realized footprint.
    assert encoding.footprint.payload_bytes == estimate.estimated_payload_bytes
    assert encoding.footprint.metadata_bytes == estimate.estimated_metadata_bytes
    assert encoding.footprint.sidecar_bytes == estimate.estimated_sidecar_bytes
    assert encoding.footprint.protected_passthrough_bytes == estimate.estimated_protected_passthrough_bytes
    assert encoding.footprint.decoder_state_bytes == estimate.estimated_decoder_state_bytes
    assert encoding.footprint.total_bytes == estimate.estimated_total_bytes
    assert encoding.footprint.compression_ratio == pytest.approx(estimate.estimated_compression_ratio)
105 changes: 105 additions & 0 deletions tests/integration/test_turboquant_estimate_consistency.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from __future__ import annotations

import numpy as np
import pytest

from semafold import VectorEncodeRequest
from semafold.turboquant import (
TurboQuantMSEConfig,
TurboQuantMSEVectorCodec,
TurboQuantProdConfig,
TurboQuantProdVectorCodec,
)
from semafold.vector.models import EncodeMetric, EncodeObjective


def _normalized_data(*, seed: int, shape: tuple[int, ...], dtype: type[np.generic]) -> np.ndarray:
rng = np.random.default_rng(seed)
data = rng.normal(size=shape).astype(np.float32)
if len(shape) == 2:
norms = np.linalg.norm(data.astype(np.float64), axis=1, keepdims=True).astype(np.float32)
norms = np.where(norms == 0.0, np.float32(1.0), norms)
data = np.asarray(data / norms, dtype=np.float32)
return data.astype(dtype)


@pytest.mark.parametrize(
    ("codec_factory", "request_factory", "seed"),
    [
        # MSE codec, 1-D float32 input at the minimum 1-bit budget.
        (
            lambda: TurboQuantMSEVectorCodec(
                config=TurboQuantMSEConfig(default_bits_per_scalar=1, default_rotation_seed=5)
            ),
            lambda seed: VectorEncodeRequest(
                data=_normalized_data(seed=seed, shape=(16,), dtype=np.float32),
                objective=EncodeObjective.RECONSTRUCTION,
                metric=EncodeMetric.MSE,
                role="embedding",
                seed=11,
            ),
            101,
        ),
        # MSE codec, 2-D float64 input at 4 bits per scalar.
        (
            lambda: TurboQuantMSEVectorCodec(
                config=TurboQuantMSEConfig(default_bits_per_scalar=4, default_rotation_seed=5)
            ),
            lambda seed: VectorEncodeRequest(
                data=_normalized_data(seed=seed, shape=(6, 32), dtype=np.float64),
                objective=EncodeObjective.RECONSTRUCTION,
                metric=EncodeMetric.MSE,
                role="embedding",
                seed=13,
            ),
            202,
        ),
        # Product codec, 2-D float32 input at 2 total bits per scalar.
        (
            lambda: TurboQuantProdVectorCodec(
                config=TurboQuantProdConfig(total_bits_per_scalar=2, default_rotation_seed=7, default_qjl_seed=17)
            ),
            lambda seed: VectorEncodeRequest(
                data=_normalized_data(seed=seed, shape=(8, 32), dtype=np.float32),
                objective=EncodeObjective.INNER_PRODUCT_ESTIMATION,
                metric=EncodeMetric.DOT_PRODUCT_ERROR,
                role="embedding",
                seed=19,
            ),
            303,
        ),
        # Product codec, 2-D float16 input at 5 total bits per scalar.
        (
            lambda: TurboQuantProdVectorCodec(
                config=TurboQuantProdConfig(total_bits_per_scalar=5, default_rotation_seed=7, default_qjl_seed=17)
            ),
            lambda seed: VectorEncodeRequest(
                data=_normalized_data(seed=seed, shape=(4, 64), dtype=np.float16),
                objective=EncodeObjective.INNER_PRODUCT_ESTIMATION,
                metric=EncodeMetric.DOT_PRODUCT_ERROR,
                role="embedding",
                seed=23,
            ),
            404,
        ),
    ],
)
def test_turboquant_estimate_matches_encode_across_supported_shapes_and_precisions(
    codec_factory,
    request_factory,
    seed: int,
) -> None:
    """Integration: for each (codec, shape, dtype) combination, ``estimate()``
    must match the realized ``encode()`` footprint byte-for-byte.

    The parametrized ``seed`` drives the input data generator; the request's
    own ``seed`` field stays fixed per case. Factories (rather than instances)
    are parametrized so each case gets a fresh codec.
    """
    codec = codec_factory()
    encode_request = request_factory(seed)

    estimate = codec.estimate(encode_request)
    encoding = codec.encode(encode_request)

    # All accounting fields must be populated before any equality checks.
    assert estimate.estimated_total_bytes is not None
    assert estimate.estimated_payload_bytes is not None
    assert estimate.estimated_metadata_bytes is not None
    assert estimate.estimated_sidecar_bytes is not None
    assert estimate.estimated_compression_ratio is not None

    # Exact agreement between the estimate and the realized footprint.
    assert encoding.footprint.total_bytes == estimate.estimated_total_bytes
    assert encoding.footprint.payload_bytes == estimate.estimated_payload_bytes
    assert encoding.footprint.metadata_bytes == estimate.estimated_metadata_bytes
    assert encoding.footprint.sidecar_bytes == estimate.estimated_sidecar_bytes
    assert encoding.footprint.compression_ratio == pytest.approx(estimate.estimated_compression_ratio)
    assert encoding.footprint.baseline_bytes == estimate.baseline_bytes
181 changes: 181 additions & 0 deletions tests/integration/test_turboquant_kv_rate_distortion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from __future__ import annotations

import numpy as np

from semafold.turboquant.kv import TurboQuantKVConfig, TurboQuantKVPreviewCodec


def _normalize_last_axis(array: np.ndarray) -> np.ndarray:
norms = np.linalg.norm(array.astype(np.float64), axis=-1, keepdims=True).astype(np.float32)
norms = np.where(norms == 0.0, np.float32(1.0), norms)
return np.asarray(array / norms, dtype=np.float32)


def _softmax(array: np.ndarray, *, axis: int = -1) -> np.ndarray:
shifted = array - np.max(array, axis=axis, keepdims=True)
exp = np.exp(shifted)
return exp / np.sum(exp, axis=axis, keepdims=True)


def _attention_output(queries: np.ndarray, keys: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Reference scaled dot-product attention, accumulated in float64.

    Inputs are (batch, head, seq, dim) arrays; the 1/sqrt(dim) scaling mirrors
    standard attention. Returns a float64 (batch, head, q_seq, dim) output.
    """
    q64 = queries.astype(np.float64)
    k64 = keys.astype(np.float64)
    scale = float(np.sqrt(keys.shape[-1], dtype=np.float32))
    weights = _softmax(np.einsum("bhqd,bhkd->bhqk", q64, k64) / scale, axis=-1)
    return np.einsum("bhqk,bhkd->bhqd", weights, values.astype(np.float64))


def _sample_attention_inputs(*, seed: int = 123) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Small deterministic (queries, keys, values) fixture: 2 batches, 2 heads, dim 16.

    Queries and keys are unit-normalized on the last axis; values are raw draws.
    Draw order is fixed (q, then k, then v) so a given seed is reproducible.
    """
    source = np.random.default_rng(seed)
    q = _normalize_last_axis(source.standard_normal((2, 2, 5, 16), dtype=np.float32))
    k = _normalize_last_axis(source.standard_normal((2, 2, 7, 16), dtype=np.float32))
    v = source.standard_normal((2, 2, 7, 16), dtype=np.float32)
    return q, k, v


def _attention_quality(
    *,
    queries: np.ndarray,
    keys: np.ndarray,
    values: np.ndarray,
    codec: TurboQuantKVPreviewCodec,
) -> tuple[dict[str, float | int], float, float]:
    """Round-trip the KV cache through *codec* and score the resulting attention.

    Returns ``(memory_stats, mse, cosine_similarity)``: the codec's memory
    statistics for the compressed artifact, plus MSE and global cosine
    similarity between attention computed from restored keys/values and
    attention computed from the exact inputs.
    """
    artifact = codec.compress(keys, values)
    decoded_keys, decoded_values = codec.decompress(artifact)

    reference = _attention_output(queries, keys, values)
    candidate = _attention_output(queries, decoded_keys, decoded_values)

    mse = float(np.mean(np.square(reference - candidate)))
    # Cosine over the flattened outputs; the epsilons guard the zero-norm case.
    denom = (np.linalg.norm(reference) + 1e-12) * (np.linalg.norm(candidate) + 1e-12)
    cosine = float(np.sum(reference * candidate) / denom)
    return codec.memory_stats(artifact), mse, cosine


def test_turboquant_kv_rate_distortion_tradeoff_is_visible_in_memory_stats_and_attention_quality() -> None:
    """Raising both key and value bit budgets must cost more memory and improve fidelity."""
    queries, keys, values = _sample_attention_inputs()

    def build_codec(key_bits: int, value_bits: int) -> TurboQuantKVPreviewCodec:
        # Seeds are pinned so the two codecs differ only in their bit budgets.
        return TurboQuantKVPreviewCodec(
            config=TurboQuantKVConfig(
                key_total_bits_per_scalar=key_bits,
                value_bits_per_scalar=value_bits,
                default_key_rotation_seed=7,
                default_key_qjl_seed=11,
                default_value_rotation_seed=17,
            )
        )

    low_stats, low_mse, low_cosine = _attention_quality(
        queries=queries, keys=keys, values=values, codec=build_codec(2, 1)
    )
    high_stats, high_mse, high_cosine = _attention_quality(
        queries=queries, keys=keys, values=values, codec=build_codec(5, 4)
    )

    # More bits -> strictly larger compressed cache, smaller compression ratio.
    assert int(low_stats["combined_bytes"]) < int(high_stats["combined_bytes"])
    assert int(low_stats["key_bytes"]) < int(high_stats["key_bytes"])
    assert int(low_stats["value_bytes"]) < int(high_stats["value_bytes"])
    assert float(low_stats["combined_compression_ratio"]) > float(high_stats["combined_compression_ratio"])

    # More bits -> strictly better attention reconstruction on both metrics.
    assert high_mse < low_mse
    assert high_cosine > low_cosine


def test_turboquant_kv_key_bits_mainly_move_key_memory_and_attention_quality() -> None:
    """With value bits fixed, the key bit budget should drive key memory and quality."""
    queries, keys, values = _sample_attention_inputs()

    def build_codec(key_bits: int) -> TurboQuantKVPreviewCodec:
        # Value budget (3 bits) and all seeds are held constant; only key bits vary.
        return TurboQuantKVPreviewCodec(
            config=TurboQuantKVConfig(
                key_total_bits_per_scalar=key_bits,
                value_bits_per_scalar=3,
                default_key_rotation_seed=7,
                default_key_qjl_seed=11,
                default_value_rotation_seed=17,
            )
        )

    low_stats, low_mse, low_cosine = _attention_quality(
        queries=queries, keys=keys, values=values, codec=build_codec(2)
    )
    high_stats, high_mse, high_cosine = _attention_quality(
        queries=queries, keys=keys, values=values, codec=build_codec(5)
    )

    # Key memory (and hence the combined total) grows with the key bit budget...
    assert int(low_stats["key_bytes"]) < int(high_stats["key_bytes"])
    assert int(low_stats["combined_bytes"]) < int(high_stats["combined_bytes"])
    # ...while value memory stays essentially unchanged (small fixed tolerance).
    assert abs(int(low_stats["value_bytes"]) - int(high_stats["value_bytes"])) <= 16

    # Better keys -> better attention reconstruction.
    assert high_mse < low_mse
    assert high_cosine > low_cosine


def test_turboquant_kv_value_bits_mainly_move_value_memory_and_attention_quality() -> None:
    """With key bits fixed, the value bit budget should drive value memory and quality."""
    queries, keys, values = _sample_attention_inputs()

    def build_codec(value_bits: int) -> TurboQuantKVPreviewCodec:
        # Key budget (3 bits) and all seeds are held constant; only value bits vary.
        return TurboQuantKVPreviewCodec(
            config=TurboQuantKVConfig(
                key_total_bits_per_scalar=3,
                value_bits_per_scalar=value_bits,
                default_key_rotation_seed=7,
                default_key_qjl_seed=11,
                default_value_rotation_seed=17,
            )
        )

    low_stats, low_mse, low_cosine = _attention_quality(
        queries=queries, keys=keys, values=values, codec=build_codec(1)
    )
    high_stats, high_mse, high_cosine = _attention_quality(
        queries=queries, keys=keys, values=values, codec=build_codec(4)
    )

    # Value memory (and hence the combined total) grows with the value bit budget...
    assert int(low_stats["value_bytes"]) < int(high_stats["value_bytes"])
    assert int(low_stats["combined_bytes"]) < int(high_stats["combined_bytes"])
    # ...while key memory stays essentially unchanged (small fixed tolerance).
    assert abs(int(low_stats["key_bytes"]) - int(high_stats["key_bytes"])) <= 16

    # Better values -> better attention reconstruction.
    assert high_mse < low_mse
    assert high_cosine > low_cosine
Loading