From f8eda55c83bb33e8f8e0907a1249de4d51619621 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Mon, 27 Jan 2025 19:54:19 +0000
Subject: [PATCH 01/24] switch from floating point arithmetic to scaled
integers
---
encodec/compress.py | 86 +++++++----
encodec/model.py | 67 +++++----
encodec/modules/seanet.py | 2 +
encodec/quantization/ac.py | 257 ++++++++++++++++++++++++-------
encodec/quantization/core_vq.py | 259 +++++++++++++++++---------------
5 files changed, 440 insertions(+), 231 deletions(-)
diff --git a/encodec/compress.py b/encodec/compress.py
index 41d6c12..64d471f 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -18,17 +18,18 @@
from .quantization.ac import ArithmeticCoder, ArithmeticDecoder, build_stable_quantized_cdf
from .model import EncodecModel, EncodedFrame
+# Define fixed-point scaling factors
+SCALE_FACTOR = 1 << 32 # 24 bits for fractional precision
+OFFSET_SCALE = 1 << 32 # 16 bits for offset precision
MODELS = {
'encodec_24khz': EncodecModel.encodec_model_24khz,
'encodec_48khz': EncodecModel.encodec_model_48khz,
}
-
-def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
- use_lm: bool = True):
+def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes], use_lm: bool = True, max_context: int = 2048):
"""Compress a waveform to a file-object using the given model.
-
+
Args:
model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
@@ -46,6 +47,7 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
if use_lm:
lm = model.get_lm_model()
+ lm.max_context = max_context
with torch.no_grad():
frames = model.encode(wav[None])
@@ -60,17 +62,21 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
for (frame, scale) in frames:
if scale is not None:
- fo.write(struct.pack('!f', scale.cpu().item()))
+ scale_int = int(round(scale.item() * SCALE_FACTOR))
+ fo.write(struct.pack('!I', scale_int))
_, K, T = frame.shape
if use_lm:
coder = ArithmeticCoder(fo)
- states: tp.Any = None
+ states = None
offset = 0
input_ = torch.zeros(1, K, 1, dtype=torch.long, device=wav.device)
else:
packer = binary.BitPacker(model.bits_per_codebook, fo)
for t in range(T):
if use_lm:
+ if offset >= max_context:
+ states = None
+ offset = 0
with torch.no_grad():
probas, states, offset = lm(input_, states, offset)
# We emulate a streaming scenario even though we do not provide an API for it.
@@ -78,8 +84,7 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
input_ = 1 + frame[:, :, t: t + 1]
for k, value in enumerate(frame[0, :, t].tolist()):
if use_lm:
- q_cdf = build_stable_quantized_cdf(
- probas[0, :, k, 0], coder.total_range_bits, check=False)
+ q_cdf = build_stable_quantized_cdf(probas[0, :, k, 0], coder.total_range_bits, check=False)
coder.push(value, q_cdf)
else:
packer.push(value)
@@ -88,9 +93,9 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
else:
packer.flush()
-
def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
- """Decompress from a file-object.
+ """
+ Decompress from a file-object with minimized floating point arithmetic.
Returns a tuple `(wav, sample_rate)`.
Args:
@@ -119,27 +124,38 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
this_segment_length = min(audio_length - offset, segment_length)
frame_length = int(math.ceil(this_segment_length * model.frame_rate / model.sample_rate))
if model.normalize:
- scale_f, = struct.unpack('!f', binary._read_exactly(fo, struct.calcsize('!f')))
- scale = torch.tensor(scale_f, device=device).view(1)
+ scale_int, = struct.unpack('!I', binary._read_exactly(fo, struct.calcsize('!I')))
+ scale = torch.tensor(scale_int / SCALE_FACTOR, device=device).view(1)
else:
scale = None
if use_lm:
decoder = ArithmeticDecoder(fo)
states: tp.Any = None
- offset = 0
+ offset_fixed = 0
input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=device)
else:
unpacker = binary.BitUnpacker(model.bits_per_codebook, fo)
frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=device)
+
+ log_file = "probas_log.txt"
+ with open(log_file, "a") as log:
+ log.write("\n===== DECODING PHASE =====\n")
+
for t in range(frame_length):
if use_lm:
with torch.no_grad():
- probas, states, offset = lm(input_, states, offset)
+ probas, states, offset_fixed = lm(input_, states, offset_fixed)
+
+ # Log probabilities
+ with open(log_file, "a") as log:
+ log.write(f"\nStep {t}, Decoding PDF:\n")
+ for k in range(num_codebooks):
+ log.write(f"Codebook {k}: {probas[0, :, k, 0][:10].tolist()}\n")
+
code_list: tp.List[int] = []
for k in range(num_codebooks):
if use_lm:
- q_cdf = build_stable_quantized_cdf(
- probas[0, :, k, 0], decoder.total_range_bits, check=False)
+ q_cdf = build_stable_quantized_cdf(probas[0, :, k, 0], decoder.total_range_bits, check=False)
code = decoder.pull(q_cdf)
else:
code = unpacker.pull()
@@ -155,16 +171,17 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
wav = model.decode(frames)
return wav[0, :, :audio_length], model.sample_rate
-
def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> bytes:
- """Compress a waveform using the given model. Returns the compressed bytes.
+ """
+ Compress a waveform using the given model with minimized floating point arithmetic.
+ Returns the compressed bytes.
Args:
- model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
- wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
+ model (EncodecModel): A pre-trained EncodecModel to use to compress the audio.
+ wav (torch.Tensor): Waveform to compress, should have a shape `[C, T]`, with `C`
matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
Use `utils.convert_audio` if this is not the case.
- use_lm (bool): if True, use a pre-trained language model to further
+ use_lm (bool): If True, use a pre-trained language model to further
compress the stream using Entropy Coding. This will slow down compression
quite a bit, expect between 20 to 30% of size reduction.
"""
@@ -174,38 +191,51 @@ def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> by
def decompress(compressed: bytes, device='cpu') -> tp.Tuple[torch.Tensor, int]:
- """Decompress from a file-object.
+ """
+ Decompress from compressed bytes with minimized floating point arithmetic.
Returns a tuple `(wav, sample_rate)`.
Args:
- compressed (bytes): compressed bytes.
- device: device to use to perform the computations.
+ compressed (bytes): Compressed bytes.
+ device: Device to use to perform the computations.
"""
fo = io.BytesIO(compressed)
return decompress_from_file(fo, device=device)
def test():
+ """
+ Test the compression and decompression pipeline to ensure integrity and performance.
+ """
import torchaudio
torch.set_num_threads(1)
for name in MODELS.keys():
model = MODELS[name]()
sr = model.sample_rate // 1000
x, _ = torchaudio.load(f'test_{sr}k.wav')
- x = x[:, :model.sample_rate * 5]
+ x = x[:, :model.sample_rate * 5] # Use first 5 seconds
model.set_target_bandwidth(12)
for use_lm in [False, True]:
print(f"Doing {name}, use_lm={use_lm}")
begin = time.time()
- res = compress(model, x, use_lm=use_lm)
+ try:
+ res = compress(model, x, use_lm=use_lm)
+ except RuntimeError as e:
+ print(f"Compression failed with use_lm={use_lm}: {e}")
+ continue
t_comp = time.time() - begin
- x_dec, _ = decompress(res)
+ try:
+ x_dec, _ = decompress(res)
+ except Exception as e:
+ print(f"Decompression failed with use_lm={use_lm}: {e}")
+ continue
t_decomp = time.time() - begin - t_comp
kbps = 8 * len(res) / 1000 / (x.shape[-1] / model.sample_rate)
print(f"kbps: {kbps:.1f}, time comp: {t_comp:.1f} sec. "
f"time decomp:{t_decomp:.1f}.")
- assert x_dec.shape == x.shape
+ assert x_dec.shape == x.shape, "Decoded waveform shape does not match original."
if __name__ == '__main__':
test()
+
diff --git a/encodec/model.py b/encodec/model.py
index 8914e79..6448187 100644
--- a/encodec/model.py
+++ b/encodec/model.py
@@ -13,57 +13,68 @@
import numpy as np
import torch
from torch import nn
+import torch.nn.init as init
from . import quantization as qt
from . import modules as m
from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url
+import random
ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'
EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
-class LMModel(nn.Module):
- """Language Model to estimate probabilities of each codebook entry.
- We predict all codebooks in parallel for a given time step.
+def stable_softmax(x: torch.Tensor) -> torch.Tensor:
+ x_max = x.max(dim=1, keepdim=True)[0]
+ exp_x = torch.exp(x - x_max)
+ return exp_x / exp_x.sum(dim=1, keepdim=True)
- Args:
- n_q (int): number of codebooks.
- card (int): codebook cardinality.
- dim (int): transformer dimension.
- **kwargs: passed to `encodec.modules.transformer.StreamingTransformerEncoder`.
- """
- def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
+class LMModel(nn.Module):
+ def __init__(self, n_q: int = 32, card: int = 512, dim: int = 128, max_context: int = 1024, **kwargs):
super().__init__()
self.card = card
self.n_q = n_q
self.dim = dim
+ self.max_context = max_context
+
self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])
+ for emb in self.emb:
+ init.normal_(emb.weight, mean=0.0, std=0.02)
+ for linear in self.linears:
+ init.normal_(linear.weight, mean=0.0, std=0.02)
+ init.zeros_(linear.bias)
+
+ def quantize_logits(self, probs: torch.Tensor, precision: int = 7) -> torch.Tensor:
+ scale = 10**precision
+ return (probs * scale).round().div(scale)
+
def forward(self, indices: torch.Tensor,
- states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
- """
- Args:
- indices (torch.Tensor): indices from the previous time step. Indices
- should be 1 + actual index in the codebook. The value 0 is reserved for
- when the index is missing (i.e. first time step). Shape should be
- `[B, n_q, T]`.
- states: state for the streaming decoding.
- offset: offset of the current time step.
+ states: tp.Optional[tp.List[torch.Tensor]] = None,
+ offset: int = 0):
+ if offset >= self.max_context:
+ states = None
+ offset = 0
- Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities
- with a shape `[B, card, n_q, T]`.
+ K = indices.shape[1]
+ input_ = torch.zeros_like(self.emb[0](indices[:, 0]))
+ for k in range(K):
+ input_ += self.emb[k](indices[:, k])
- """
- B, K, T = indices.shape
- input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
out, states, offset = self.transformer(input_, states, offset)
- logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
- return torch.softmax(logits, dim=1), states, offset
+ logits = torch.stack([
+ self.linears[k](out) for k in range(K)
+ ], dim=1).permute(0, 3, 1, 2)
+
+ probs = stable_softmax(logits)
+ probs = self.quantize_logits(probs)
+
+ return probs, states, offset
class EncodecModel(nn.Module):
"""EnCodec model operating on the raw waveform.
@@ -159,7 +170,7 @@ def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
scale = None
emb = self.encoder(x)
- codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
+ codes = self.quantizer.encode(emb.to('cpu'), self.frame_rate, self.bandwidth).to(emb.device)
codes = codes.transpose(0, 1)
# codes is [B, K, T], with T frames, K nb of codebooks.
return codes, scale
@@ -201,7 +212,7 @@ def get_lm_model(self) -> LMModel:
"""
device = next(self.parameters()).device
lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
- past_context=int(3.5 * self.frame_rate)).to(device)
+ past_context=int(1 * self.frame_rate)).to(device)
checkpoints = {
'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
diff --git a/encodec/modules/seanet.py b/encodec/modules/seanet.py
index ea1c02d..29d56dc 100644
--- a/encodec/modules/seanet.py
+++ b/encodec/modules/seanet.py
@@ -17,6 +17,8 @@
SLSTM
)
+np.random.seed(42)
+
class SEANetResnetBlock(nn.Module):
"""Residual block from SEANet model.
diff --git a/encodec/quantization/ac.py b/encodec/quantization/ac.py
index f0f3e5d..b75648b 100644
--- a/encodec/quantization/ac.py
+++ b/encodec/quantization/ac.py
@@ -6,53 +6,155 @@
"""Arithmetic coder."""
+import typing as tp
import io
-import math
import random
-import typing as tp
import torch
+import numpy as np
+
+print(torch.__config__.show())
from ..binary import BitPacker, BitUnpacker
+# Define the fixed-point scaling factor
+FIXED_SCALE = 1 << 32
-def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
- roundoff: float = 1e-8, min_range: int = 2,
- check: bool = True) -> torch.Tensor:
- """Turn the given PDF into a quantized CDF that splits
- [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
- to the PDF.
+def pdf_to_fixed_point(pdf: torch.Tensor) -> torch.Tensor:
+ """
+ Converts a floating-point PDF to fixed-point integer representation
+ while eliminating floating-point precision drift.
+ """
+ assert torch.all(pdf >= 0), "PDF contains negative values!"
+
+ # Ensure pdf is on CPU and double precision before conversion
+ pdf = pdf.to(torch.float64).cpu()
+
+ # Perform scaling using integer arithmetic only
+ scaled_pdf = (pdf * FIXED_SCALE + 0.5).floor().to(torch.int64)
+
+ return scaled_pdf
+
+def deterministic_round(x: torch.Tensor) -> torch.Tensor:
+ """
+ Implements a deterministic rounding method: rounds half up.
+ """
+ return torch.floor(x + 0.5)
+
+def pdf_to_integer_counts_fixed(pdf_fixed: torch.Tensor, total_range: int) -> torch.Tensor:
+ """
+ Converts a fixed-point PDF into integer counts that sum to total_range using integer-only operations.
+
Args:
- pdf (torch.Tensor): probability distribution, shape should be `[N]`.
- total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
- during the coding process is `[0, 2 ** total_range_bits - 1]`.
- roundoff (float): will round the pdf up to that level to remove difference coming
- from e.g. evaluating the Language Model on different architectures.
- min_range (int): minimum range width. Should always be at least 2 for numerical
- stability. Use this to avoid pathological behavior is a value
- that is expected to be rare actually happens in real life.
- check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
+ pdf_fixed (torch.Tensor): Fixed-point integer representation of the PDF.
+ total_range (int): Desired sum of the integer counts.
+
+ Returns:
+ torch.Tensor: Integer counts summing to total_range.
"""
- pdf = pdf.detach()
- if roundoff:
- pdf = (pdf / roundoff).floor() * roundoff
- # interpolate with uniform distribution to achieve desired minimum probability.
- total_range = 2 ** total_range_bits
+ assert torch.all(pdf_fixed >= 0), "PDF contains negative values!"
+ assert torch.isclose(pdf_fixed.sum(), torch.tensor(FIXED_SCALE, dtype=pdf_fixed.dtype, device=pdf_fixed.device), atol=1), "PDF does not sum to fixed scale!"
+
+ # Step 1: Scale the PDF using total_range and perform integer division
+ scaled_pdf = (pdf_fixed * total_range + (FIXED_SCALE // 2)) // FIXED_SCALE
+
+ # Step 2: Calculate the sum of the scaled PDF
+ current_sum = scaled_pdf.sum().item()
+ deficit = total_range - current_sum
+
+ # Step 3: Redistribute the deficit deterministically
+ if deficit > 0:
+ fractional_remainders = (pdf_fixed * total_range) % FIXED_SCALE
+ # Tie-breaking using indices to ensure deterministic sorting
+ sorted_indices = torch.argsort(
+ fractional_remainders * 1_000_000 + torch.arange(len(fractional_remainders), dtype=torch.int64, device=pdf_fixed.device),
+ descending=True
+ )
+ selected_indices = sorted_indices[:deficit]
+ scaled_pdf[selected_indices] += 1
+
+ elif deficit < 0:
+ # Need to remove excess counts
+ excess = -deficit
+ sorted_indices = torch.argsort(scaled_pdf, descending=True)
+ scaled_pdf[sorted_indices[:excess]] -= 1
+
+ return scaled_pdf
+
+def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
+ min_range: int = 2, check: bool = True) -> torch.Tensor:
+ """Integer-only version of build_stable_quantized_cdf that avoids floating point operations."""
+
+ total_range = 1 << total_range_bits
cardinality = len(pdf)
- alpha = min_range * cardinality / total_range
- assert alpha <= 1, "you must reduce min_range"
- ranges = (((1 - alpha) * total_range) * pdf).floor().long()
- ranges += min_range
- quantized_cdf = torch.cumsum(ranges, dim=-1)
- if min_range < 2:
- raise ValueError("min_range must be at least 2.")
+
+ pdf_fixed = pdf_to_fixed_point(pdf)
+
+ counts = pdf_to_integer_counts_fixed(pdf_fixed, total_range)
+
+ deficit = min_range * cardinality - counts.sum().item()
+
+ if deficit > 0:
+ available = counts - min_range
+ available_total = available[available > 0].sum().item()
+
+ if available_total > 0:
+ reduction = (available * deficit).div(available_total, rounding_mode='floor')
+ counts[available > 0] -= reduction
+ deficit -= reduction.sum().item()
+
+ if deficit > 0:
+ per_symbol = deficit // cardinality
+ remainder = deficit % cardinality
+ counts += per_symbol
+ counts[:remainder] += 1
+
+ counts = torch.maximum(counts, torch.tensor(min_range))
+
+ # Scale down if we exceed total range
+ if counts.sum().item() > total_range:
+ scale = total_range / counts.sum().item()
+ counts = (counts * scale).round().long()
+ counts = torch.maximum(counts, torch.tensor(min_range))
+ counts[counts.argmax()] -= (counts.sum().item() - total_range)
+
+ # Build CDF through cumulative sum
+ quantized_cdf = torch.cumsum(counts, dim=-1)
+
if check:
- assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
- if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
- raise ValueError("You must increase your total_range_bits.")
+ assert quantized_cdf[-1].item() <= 2 ** total_range_bits, f"CDF exceeds range: {quantized_cdf[-1]}"
+ ranges = torch.diff(quantized_cdf, prepend=torch.tensor([0]))
+ assert (ranges >= min_range).all(), f"Some ranges below minimum: {ranges.min().item()}"
+
+ #sys.exit(1) # Exit immediately after logging first case
+
+
+ log_file = "quantized_cdf_log.txt" # Log file path
+ with open(log_file, "a") as f: # Corrected syntax
+ f.write(str(quantized_cdf) + "\n")
+
return quantized_cdf
+def compute_effective_range(range_low: int, range_high: int, delta: int, total_range_bits: int) -> tp.Tuple[int, int]:
+ total_range = 1 << total_range_bits
+
+ # Scale delta using fixed-point scaling
+ scaled_delta = delta * FIXED_SCALE
+
+ # Compute effective_low and effective_high with fixed-point precision
+ effective_low_fixed = (range_low * scaled_delta + (total_range // 2)) // total_range
+ effective_high_fixed = (range_high * scaled_delta) // total_range
+
+ # Convert back from fixed-point to integer
+ effective_low = effective_low_fixed // FIXED_SCALE
+ effective_high = effective_high_fixed // FIXED_SCALE
+
+ # Ensure that effective_high is at least effective_low to maintain a valid range
+ effective_high = max(effective_high, effective_low)
+
+ return effective_low, effective_high
+
class ArithmeticCoder:
"""ArithmeticCoder,
Let us take a distribution `p` over `N` symbols, and assume we have a stream
@@ -109,7 +211,7 @@ def delta(self) -> int:
return self.high - self.low + 1
def _flush_common_prefix(self):
- # If self.low and self.high start with the sames bits,
+ # If self.low and self.high start with the same bits,
# those won't change anymore as we always just increase the range
# by powers of 2, and we can flush them out to the bit stream.
assert self.high >= self.low, (self.low, self.high)
@@ -133,18 +235,17 @@ def push(self, symbol: int, quantized_cdf: torch.Tensor):
Args:
symbol (int): symbol to encode with the AC.
- quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+ quantized_cdf (torch.Tensor): use build_stable_quantized_cdf
to build this from your pdf estimate.
"""
- while self.delta < 2 ** self.total_range_bits:
- self.low *= 2
- self.high = self.high * 2 + 1
+ while self.delta < (1 << self.total_range_bits):
+ self.low <<= 1
+ self.high = (self.high << 1) | 1
self.max_bit += 1
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
range_high = quantized_cdf[symbol].item() - 1
- effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
- effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
+ effective_low, effective_high = compute_effective_range(range_low, range_high, self.delta, self.total_range_bits)
assert self.low <= self.high
self.high = self.low + effective_high
self.low = self.low + effective_low
@@ -168,17 +269,17 @@ def flush(self):
class ArithmeticDecoder:
- """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
+ """ArithmeticDecoder, see ArithmeticCoder for a detailed explanation.
Note that this must be called with **exactly** the same parameters and sequence
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
- If the AC encoder current range is [L, H], with `L` and `H` having the some common
+ If the AC encoder current range is [L, H], with L and H having the some common
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
- For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
- `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
+ For instances, having read 3 bits b1 b2 b3, we know that [L, H] is contained inside
+ [b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]. Now this specific sub-range can only be obtained
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
- At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
+ At some point, the prefix b1 b2 b3 will no longer be sufficient to decode new symbols,
and we will need to read new bits from the stream and repeat the process.
"""
@@ -216,20 +317,70 @@ def _flush_common_prefix(self):
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
"""Pull a symbol, reading as many bits from the stream as required.
- This returns `None` when the stream has been exhausted.
+ This returns None when the stream has been exhausted.
Args:
- quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
- to build this from your pdf estimate. This must be **exatly**
+ quantized_cdf (torch.Tensor): use build_stable_quantized_cdf
+ to build this from your pdf estimate. This must be **exactly**
the same cdf as the one used at encoding time.
"""
- while self.delta < 2 ** self.total_range_bits:
+ while self.delta < (1 << self.total_range_bits):
bit = self.unpacker.pull()
if bit is None:
return None
- self.low *= 2
- self.high = self.high * 2 + 1
- self.current = self.current * 2 + bit
+ self.low <<= 1
+ self.high = (self.high << 1) | 1
+ self.current = (self.current << 1) | bit
+ self.max_bit += 1
+
+ log_file = "binary_search_log.txt" # Log file path
+ with open(log_file, "a") as f:
+
+ def bin_search(low_idx: int, high_idx: int):
+ if high_idx < low_idx:
+ raise RuntimeError("Binary search failed")
+ mid = (low_idx + high_idx) // 2
+ range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
+ range_high = quantized_cdf[mid].item() - 1
+ effective_low, effective_high = compute_effective_range(range_low, range_high, self.delta, self.total_range_bits)
+ low = effective_low + self.low
+ high = effective_high + self.low
+
+ # Log each iteration
+ f.write(f"low_idx={low_idx}, high_idx={high_idx}, mid={mid}, low={low}, high={high}, current={self.current}\n")
+
+ if self.current >= low:
+ if self.current <= high:
+ return (mid, low, high, self.current)
+ else:
+ return bin_search(mid + 1, high_idx)
+ else:
+ return bin_search(low_idx, mid - 1)
+
+ self._last = (self.low, self.high, self.current, self.max_bit)
+ sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
+ self._dbg.append((self.low, self.high, self.current))
+ self._flush_common_prefix()
+ self._dbg2.append((self.low, self.high, self.current))
+
+ return sym
+
+ def pullx(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
+ """Pull a symbol, reading as many bits from the stream as required.
+ This returns None when the stream has been exhausted.
+
+ Args:
+ quantized_cdf (torch.Tensor): use build_stable_quantized_cdf
+ to build this from your pdf estimate. This must be **exactly**
+ the same cdf as the one used at encoding time.
+ """
+ while self.delta < (1 << self.total_range_bits):
+ bit = self.unpacker.pull()
+ if bit is None:
+ return None
+ self.low <<= 1
+ self.high = (self.high << 1) | 1
+ self.current = (self.current << 1) | bit
self.max_bit += 1
def bin_search(low_idx: int, high_idx: int):
@@ -239,8 +390,7 @@ def bin_search(low_idx: int, high_idx: int):
mid = (low_idx + high_idx) // 2
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
range_high = quantized_cdf[mid].item() - 1
- effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
- effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
+ effective_low, effective_high = compute_effective_range(range_low, range_high, self.delta, self.total_range_bits)
low = effective_low + self.low
high = effective_high + self.low
if self.current >= low:
@@ -290,3 +440,4 @@ def test():
if __name__ == "__main__":
test()
+
diff --git a/encodec/quantization/core_vq.py b/encodec/quantization/core_vq.py
index 1c7e8c7..62d7db0 100644
--- a/encodec/quantization/core_vq.py
+++ b/encodec/quantization/core_vq.py
@@ -33,55 +33,43 @@
import typing as tp
import warnings
-
-from einops import rearrange, repeat
import torch
-from torch import nn
+from torch import nn, Tensor
import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import jit
from .. import distrib
-
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
-
-def ema_inplace(moving_avg, new, decay: float):
+def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
-
-def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+def laplace_smoothing(x: Tensor, n_categories: int, epsilon: float = 1e-5) -> Tensor:
return (x + epsilon) / (x.sum() + n_categories * epsilon)
-
-def uniform_init(*shape: int):
+def uniform_init(*shape: int) -> Tensor:
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
-
-def sample_vectors(samples, num: int):
+def sample_vectors(samples: Tensor, num: int) -> Tensor:
num_samples, device = samples.shape[0], samples.device
-
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
-
return samples[indices]
-
-def kmeans(samples, num_clusters: int, num_iters: int = 10):
+def kmeans(samples: Tensor, num_clusters: int, num_iters: int = 10) -> tp.Tuple[Tensor, Tensor]:
dim, dtype = samples.shape[-1], samples.dtype
-
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
- diffs = rearrange(samples, "n d -> n () d") - rearrange(
- means, "c d -> () c d"
- )
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
dists = -(diffs ** 2).sum(dim=-1)
-
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
@@ -90,32 +78,23 @@ def kmeans(samples, num_clusters: int, num_iters: int = 10):
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
-
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
+def _quantize_tensor(x: Tensor, precision: int = 7) -> Tensor:
+ """Control precision of floating point operations"""
+ return torch.round(x * 10**precision) / 10**precision
+
class EuclideanCodebook(nn.Module):
- """Codebook with Euclidean distance.
- Args:
- dim (int): Dimension.
- codebook_size (int): Codebook size.
- kmeans_init (bool): Whether to use k-means to initialize the codebooks.
- If set to true, run the k-means algorithm on the first training batch and use
- the learned centroids as initialization.
- kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
- decay (float): Decay for exponential moving average over the codebooks.
- epsilon (float): Epsilon value for numerical stability.
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
- that have an exponential moving average cluster size less than the specified threshold with
- randomly selected vector from the current batch.
- """
+ """Codebook with Euclidean distance."""
+
def __init__(
self,
dim: int,
codebook_size: int,
- kmeans_init: int = False,
+ kmeans_init: bool = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
@@ -127,7 +106,6 @@ def __init__(
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
-
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
@@ -137,8 +115,9 @@ def __init__(
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
- @torch.jit.ignore
- def init_embed_(self, data):
+
+ @jit.ignore
+ def init_embed_(self, data: Tensor) -> None:
if self.inited:
return
@@ -147,16 +126,17 @@ def init_embed_(self, data):
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
- # Make sure all buffers across workers are in sync after initialization
distrib.broadcast_tensors(self.buffers())
- def replace_(self, samples, mask):
+ def replace_(self, samples: Tensor, mask: Tensor) -> None:
modified_codebook = torch.where(
- mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+ mask[..., None],
+ sample_vectors(samples, self.codebook_size),
+ self.embed
)
self.embed.data.copy_(modified_codebook)
- def expire_codes_(self, batch_samples):
+ def expire_codes_(self, batch_samples: Tensor) -> None:
if self.threshold_ema_dead_code == 0:
return
@@ -168,42 +148,54 @@ def expire_codes_(self, batch_samples):
self.replace_(batch_samples, mask=expired_codes)
distrib.broadcast_tensors(self.buffers())
- def preprocess(self, x):
- x = rearrange(x, "... d -> (...) d")
- return x
-
- def quantize(self, x):
- embed = self.embed.t()
- dist = -(
- x.pow(2).sum(1, keepdim=True)
- - 2 * x @ embed
- + embed.pow(2).sum(0, keepdim=True)
- )
- embed_ind = dist.max(dim=-1).indices
+ def preprocess(self, x: Tensor) -> Tensor:
+ return rearrange(x, "... d -> (...) d")
+
+ def quantize(self, x: Tensor) -> Tensor:
+ """Stabilized quantization for consistent binary tree decisions across architectures"""
+ # Carefully control precision of the codebook
+ embed = _quantize_tensor(self.embed.t())
+
+ # Break down distance calculation into controlled steps
+ # Calculate x squared term first
+ x_squared = _quantize_tensor(x.pow(2).sum(1, keepdim=True))
+
+ # Calculate embed squared term
+ embed_squared = _quantize_tensor(embed.pow(2).sum(0, keepdim=True))
+
+ # Calculate cross term with controlled precision
+ # Use matmul for better numerical stability than @
+ cross_term = _quantize_tensor(torch.matmul(x, embed))
+ cross_term = _quantize_tensor(cross_term * 2)
+
+ # Combine terms with controlled precision and ordering
+ # Note: we add the squared terms first since they're likely larger
+ dist = _quantize_tensor(x_squared + embed_squared)
+ dist = _quantize_tensor(dist - cross_term)
+ dist = -dist # Negate at the end to avoid accumulated precision loss
+
+ # Use stable sorting for consistent index selection
+ embed_ind = dist.max(dim=-1, keepdim=True).indices
+
return embed_ind
-
- def postprocess_emb(self, embed_ind, shape):
+
+ def postprocess_emb(self, embed_ind: Tensor, shape: tp.Tuple) -> Tensor:
return embed_ind.view(*shape[:-1])
- def dequantize(self, embed_ind):
- quantize = F.embedding(embed_ind, self.embed)
- return quantize
+ def dequantize(self, embed_ind: Tensor) -> Tensor:
+ return F.embedding(embed_ind, self.embed)
- def encode(self, x):
+ def encode(self, x: Tensor) -> Tensor:
shape = x.shape
- # pre-process
x = self.preprocess(x)
- # quantize
embed_ind = self.quantize(x)
- # post-process
embed_ind = self.postprocess_emb(embed_ind, shape)
return embed_ind
- def decode(self, embed_ind):
- quantize = self.dequantize(embed_ind)
- return quantize
+ def decode(self, embed_ind: Tensor) -> Tensor:
+ return self.dequantize(embed_ind)
- def forward(self, x):
+ def forward(self, x: Tensor) -> tp.Tuple[Tensor, Tensor]:
shape, dtype = x.shape, x.dtype
x = self.preprocess(x)
@@ -215,8 +207,6 @@ def forward(self, x):
quantize = self.dequantize(embed_ind)
if self.training:
- # We do the expiry of code at that point as buffers are in sync
- # and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
@@ -230,23 +220,9 @@ def forward(self, x):
return quantize, embed_ind
-
class VectorQuantization(nn.Module):
- """Vector quantization implementation.
- Currently supports only euclidean distance.
- Args:
- dim (int): Dimension
- codebook_size (int): Codebook size
- codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
- decay (float): Decay for exponential moving average over the codebooks.
- epsilon (float): Epsilon value for numerical stability.
- kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
- kmeans_iters (int): Number of iterations used for kmeans initialization.
- threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
- that have an exponential moving average cluster size less than the specified threshold with
- randomly selected vector from the current batch.
- commitment_weight (float): Weight for commitment loss.
- """
+ """Vector quantization implementation."""
+
def __init__(
self,
dim: int,
@@ -260,43 +236,56 @@ def __init__(
commitment_weight: float = 1.,
):
super().__init__()
- _codebook_dim: int = default(codebook_dim, dim)
-
- requires_projection = _codebook_dim != dim
- self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
- self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+ self._codebook_dim: int = default(codebook_dim, dim)
+ requires_projection = self._codebook_dim != dim
+ self.project_in = (nn.Linear(dim, self._codebook_dim) if requires_projection else nn.Identity())
+ self.project_out = (nn.Linear(self._codebook_dim, dim) if requires_projection else nn.Identity())
self.epsilon = epsilon
self.commitment_weight = commitment_weight
- self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
- kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
- decay=decay, epsilon=epsilon,
- threshold_ema_dead_code=threshold_ema_dead_code)
+ self._codebook = EuclideanCodebook(
+ dim=self._codebook_dim,
+ codebook_size=codebook_size,
+ kmeans_init=kmeans_init,
+ kmeans_iters=kmeans_iters,
+ decay=decay,
+ epsilon=epsilon,
+ threshold_ema_dead_code=threshold_ema_dead_code
+ )
self.codebook_size = codebook_size
@property
- def codebook(self):
+ def codebook(self) -> Tensor:
return self._codebook.embed
- def encode(self, x):
+ def _quantize_tensor(self, x: Tensor) -> Tensor:
+ """Control precision of tensor operations"""
+ return torch.round(x * 10**7) / 10**7
+
+ def encode(self, x: Tensor) -> Tensor:
+ """Stabilized encoding process"""
x = rearrange(x, "b d n -> b n d")
- x = self.project_in(x)
+ # Stabilize projection
+ x = _quantize_tensor(self.project_in(x))
embed_in = self._codebook.encode(x)
return embed_in
- def decode(self, embed_ind):
+ def decode(self, embed_ind: Tensor) -> Tensor:
+ """Stabilized decoding process"""
quantize = self._codebook.decode(embed_ind)
- quantize = self.project_out(quantize)
+ # Stabilize projection
+ quantize = _quantize_tensor(self.project_out(quantize))
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
- def forward(self, x):
+ def forward(self, x: Tensor) -> tp.Tuple[Tensor, Tensor, Tensor]:
device = x.device
x = rearrange(x, "b d n -> b n d")
- x = self.project_in(x)
+ x = self._quantize_tensor(self.project_in(x))
quantize, embed_ind = self._codebook(x)
+ quantize = self._quantize_tensor(quantize)
if self.training:
quantize = x + (quantize - x).detach()
@@ -305,30 +294,32 @@ def forward(self, x):
if self.training:
warnings.warn('When using RVQ in training model, first check '
- 'https://github.com/facebookresearch/encodec/issues/25 . '
- 'The bug wasn\'t fixed here for reproducibility.')
+ 'https://github.com/facebookresearch/encodec/issues/25 . '
+ 'The bug wasn\'t fixed here for reproducibility.')
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
- quantize = self.project_out(quantize)
+ quantize = self._quantize_tensor(self.project_out(quantize))
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
-
class ResidualVectorQuantization(nn.Module):
- """Residual vector quantization implementation.
- Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
- """
- def __init__(self, *, num_quantizers, **kwargs):
+ """Residual vector quantization implementation with stability improvements."""
+
+ def __init__(self, *, num_quantizers: int, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
+
+ def _quantize_tensor(self, x: Tensor) -> Tensor:
+ """Control precision of tensor operations"""
+ return torch.round(x * 10**7) / 10**7
- def forward(self, x, n_q: tp.Optional[int] = None):
- quantized_out = 0.0
- residual = x
+ def forward(self, x: Tensor, n_q: tp.Optional[int] = None) -> tp.Tuple[Tensor, Tensor, Tensor]:
+ quantized_out = torch.tensor(0.0, device=x.device)
+ residual = self._quantize_tensor(x)
all_losses = []
all_indices = []
@@ -337,8 +328,8 @@ def forward(self, x, n_q: tp.Optional[int] = None):
for layer in self.layers[:n_q]:
quantized, indices, loss = layer(residual)
- residual = residual - quantized
- quantized_out = quantized_out + quantized
+ residual = self._quantize_tensor(residual - quantized)
+ quantized_out = self._quantize_tensor(quantized_out + quantized)
all_indices.append(indices)
all_losses.append(loss)
@@ -346,22 +337,46 @@ def forward(self, x, n_q: tp.Optional[int] = None):
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses
- def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
- residual = x
+ def encode(self, x: Tensor, n_q: tp.Optional[int] = None) -> Tensor:
+ """Stabilized RVQ encoding"""
+ # Initial quantization of input
+ residual = _quantize_tensor(x)
all_indices = []
n_q = n_q or len(self.layers)
+
for layer in self.layers[:n_q]:
+ # Get indices for this layer
indices = layer.encode(residual)
+
+ # Decode and quantize to match encoder exactly
quantized = layer.decode(indices)
- residual = residual - quantized
+ quantized = _quantize_tensor(quantized)
+
+ # Compute and quantize residual
+ residual = _quantize_tensor(residual - quantized)
+
all_indices.append(indices)
+
+ # Stack indices with controlled precision
out_indices = torch.stack(all_indices)
return out_indices
- def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
- quantized_out = torch.tensor(0.0, device=q_indices.device)
+ def decode(self, q_indices: Tensor) -> Tensor:
+ """Stabilized RVQ decoding"""
+ quantized_out = torch.zeros(
+ q_indices.shape[1:],
+ device=q_indices.device,
+ dtype=q_indices.dtype
+ )
+
for i, indices in enumerate(q_indices):
layer = self.layers[i]
+ # Decode and stabilize each layer output
quantized = layer.decode(indices)
- quantized_out = quantized_out + quantized
+ quantized = _quantize_tensor(quantized)
+
+ # Accumulate with controlled precision
+ quantized_out = _quantize_tensor(quantized_out + quantized)
+
return quantized_out
+
From 5b7b181c4e1ad801b627bdb4288463523ef34570 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Sun, 7 Sep 2025 19:06:38 +0100
Subject: [PATCH 02/24] Revert "switch from floating point arithmetic to scaled
integers"
This reverts commit f8eda55c83bb33e8f8e0907a1249de4d51619621.
---
encodec/compress.py | 86 ++++-------
encodec/model.py | 67 ++++-----
encodec/modules/seanet.py | 2 -
encodec/quantization/ac.py | 257 +++++++------------------------
encodec/quantization/core_vq.py | 259 +++++++++++++++-----------------
5 files changed, 231 insertions(+), 440 deletions(-)
diff --git a/encodec/compress.py b/encodec/compress.py
index 64d471f..41d6c12 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -18,18 +18,17 @@
from .quantization.ac import ArithmeticCoder, ArithmeticDecoder, build_stable_quantized_cdf
from .model import EncodecModel, EncodedFrame
-# Define fixed-point scaling factors
-SCALE_FACTOR = 1 << 32 # 24 bits for fractional precision
-OFFSET_SCALE = 1 << 32 # 16 bits for offset precision
MODELS = {
'encodec_24khz': EncodecModel.encodec_model_24khz,
'encodec_48khz': EncodecModel.encodec_model_48khz,
}
-def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes], use_lm: bool = True, max_context: int = 2048):
+
+def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
+ use_lm: bool = True):
"""Compress a waveform to a file-object using the given model.
-
+
Args:
model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
@@ -47,7 +46,6 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes], u
if use_lm:
lm = model.get_lm_model()
- lm.max_context = max_context
with torch.no_grad():
frames = model.encode(wav[None])
@@ -62,21 +60,17 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes], u
for (frame, scale) in frames:
if scale is not None:
- scale_int = int(round(scale.item() * SCALE_FACTOR))
- fo.write(struct.pack('!I', scale_int))
+ fo.write(struct.pack('!f', scale.cpu().item()))
_, K, T = frame.shape
if use_lm:
coder = ArithmeticCoder(fo)
- states = None
+ states: tp.Any = None
offset = 0
input_ = torch.zeros(1, K, 1, dtype=torch.long, device=wav.device)
else:
packer = binary.BitPacker(model.bits_per_codebook, fo)
for t in range(T):
if use_lm:
- if offset >= max_context:
- states = None
- offset = 0
with torch.no_grad():
probas, states, offset = lm(input_, states, offset)
# We emulate a streaming scenario even though we do not provide an API for it.
@@ -84,7 +78,8 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes], u
input_ = 1 + frame[:, :, t: t + 1]
for k, value in enumerate(frame[0, :, t].tolist()):
if use_lm:
- q_cdf = build_stable_quantized_cdf(probas[0, :, k, 0], coder.total_range_bits, check=False)
+ q_cdf = build_stable_quantized_cdf(
+ probas[0, :, k, 0], coder.total_range_bits, check=False)
coder.push(value, q_cdf)
else:
packer.push(value)
@@ -93,9 +88,9 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes], u
else:
packer.flush()
+
def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
- """
- Decompress from a file-object with minimized floating point arithmetic.
+ """Decompress from a file-object.
Returns a tuple `(wav, sample_rate)`.
Args:
@@ -124,38 +119,27 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
this_segment_length = min(audio_length - offset, segment_length)
frame_length = int(math.ceil(this_segment_length * model.frame_rate / model.sample_rate))
if model.normalize:
- scale_int, = struct.unpack('!I', binary._read_exactly(fo, struct.calcsize('!I')))
- scale = torch.tensor(scale_int / SCALE_FACTOR, device=device).view(1)
+ scale_f, = struct.unpack('!f', binary._read_exactly(fo, struct.calcsize('!f')))
+ scale = torch.tensor(scale_f, device=device).view(1)
else:
scale = None
if use_lm:
decoder = ArithmeticDecoder(fo)
states: tp.Any = None
- offset_fixed = 0
+ offset = 0
input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=device)
else:
unpacker = binary.BitUnpacker(model.bits_per_codebook, fo)
frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=device)
-
- log_file = "probas_log.txt"
- with open(log_file, "a") as log:
- log.write("\n===== DECODING PHASE =====\n")
-
for t in range(frame_length):
if use_lm:
with torch.no_grad():
- probas, states, offset_fixed = lm(input_, states, offset_fixed)
-
- # Log probabilities
- with open(log_file, "a") as log:
- log.write(f"\nStep {t}, Decoding PDF:\n")
- for k in range(num_codebooks):
- log.write(f"Codebook {k}: {probas[0, :, k, 0][:10].tolist()}\n")
-
+ probas, states, offset = lm(input_, states, offset)
code_list: tp.List[int] = []
for k in range(num_codebooks):
if use_lm:
- q_cdf = build_stable_quantized_cdf(probas[0, :, k, 0], decoder.total_range_bits, check=False)
+ q_cdf = build_stable_quantized_cdf(
+ probas[0, :, k, 0], decoder.total_range_bits, check=False)
code = decoder.pull(q_cdf)
else:
code = unpacker.pull()
@@ -171,17 +155,16 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
wav = model.decode(frames)
return wav[0, :, :audio_length], model.sample_rate
+
def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> bytes:
- """
- Compress a waveform using the given model with minimized floating point arithmetic.
- Returns the compressed bytes.
+ """Compress a waveform using the given model. Returns the compressed bytes.
Args:
- model (EncodecModel): A pre-trained EncodecModel to use to compress the audio.
- wav (torch.Tensor): Waveform to compress, should have a shape `[C, T]`, with `C`
+ model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
+ wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
Use `utils.convert_audio` if this is not the case.
- use_lm (bool): If True, use a pre-trained language model to further
+ use_lm (bool): if True, use a pre-trained language model to further
compress the stream using Entropy Coding. This will slow down compression
quite a bit, expect between 20 to 30% of size reduction.
"""
@@ -191,51 +174,38 @@ def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> by
def decompress(compressed: bytes, device='cpu') -> tp.Tuple[torch.Tensor, int]:
- """
- Decompress from compressed bytes with minimized floating point arithmetic.
+ """Decompress from a file-object.
Returns a tuple `(wav, sample_rate)`.
Args:
- compressed (bytes): Compressed bytes.
- device: Device to use to perform the computations.
+ compressed (bytes): compressed bytes.
+ device: device to use to perform the computations.
"""
fo = io.BytesIO(compressed)
return decompress_from_file(fo, device=device)
def test():
- """
- Test the compression and decompression pipeline to ensure integrity and performance.
- """
import torchaudio
torch.set_num_threads(1)
for name in MODELS.keys():
model = MODELS[name]()
sr = model.sample_rate // 1000
x, _ = torchaudio.load(f'test_{sr}k.wav')
- x = x[:, :model.sample_rate * 5] # Use first 5 seconds
+ x = x[:, :model.sample_rate * 5]
model.set_target_bandwidth(12)
for use_lm in [False, True]:
print(f"Doing {name}, use_lm={use_lm}")
begin = time.time()
- try:
- res = compress(model, x, use_lm=use_lm)
- except RuntimeError as e:
- print(f"Compression failed with use_lm={use_lm}: {e}")
- continue
+ res = compress(model, x, use_lm=use_lm)
t_comp = time.time() - begin
- try:
- x_dec, _ = decompress(res)
- except Exception as e:
- print(f"Decompression failed with use_lm={use_lm}: {e}")
- continue
+ x_dec, _ = decompress(res)
t_decomp = time.time() - begin - t_comp
kbps = 8 * len(res) / 1000 / (x.shape[-1] / model.sample_rate)
print(f"kbps: {kbps:.1f}, time comp: {t_comp:.1f} sec. "
f"time decomp:{t_decomp:.1f}.")
- assert x_dec.shape == x.shape, "Decoded waveform shape does not match original."
+ assert x_dec.shape == x.shape
if __name__ == '__main__':
test()
-
diff --git a/encodec/model.py b/encodec/model.py
index 6448187..8914e79 100644
--- a/encodec/model.py
+++ b/encodec/model.py
@@ -13,68 +13,57 @@
import numpy as np
import torch
from torch import nn
-import torch.nn.init as init
from . import quantization as qt
from . import modules as m
from .utils import _check_checksum, _linear_overlap_add, _get_checkpoint_url
-import random
ROOT_URL = 'https://dl.fbaipublicfiles.com/encodec/v0/'
EncodedFrame = tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]
-def stable_softmax(x: torch.Tensor) -> torch.Tensor:
- x_max = x.max(dim=1, keepdim=True)[0]
- exp_x = torch.exp(x - x_max)
- return exp_x / exp_x.sum(dim=1, keepdim=True)
-
class LMModel(nn.Module):
- def __init__(self, n_q: int = 32, card: int = 512, dim: int = 128, max_context: int = 1024, **kwargs):
+ """Language Model to estimate probabilities of each codebook entry.
+ We predict all codebooks in parallel for a given time step.
+
+ Args:
+ n_q (int): number of codebooks.
+ card (int): codebook cardinality.
+ dim (int): transformer dimension.
+ **kwargs: passed to `encodec.modules.transformer.StreamingTransformerEncoder`.
+ """
+ def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
super().__init__()
self.card = card
self.n_q = n_q
self.dim = dim
- self.max_context = max_context
-
self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])
- for emb in self.emb:
- init.normal_(emb.weight, mean=0.0, std=0.02)
- for linear in self.linears:
- init.normal_(linear.weight, mean=0.0, std=0.02)
- init.zeros_(linear.bias)
-
- def quantize_logits(self, probs: torch.Tensor, precision: int = 7) -> torch.Tensor:
- scale = 10**precision
- return (probs * scale).round().div(scale)
-
def forward(self, indices: torch.Tensor,
- states: tp.Optional[tp.List[torch.Tensor]] = None,
- offset: int = 0):
- if offset >= self.max_context:
- states = None
- offset = 0
+ states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
+ """
+ Args:
+ indices (torch.Tensor): indices from the previous time step. Indices
+ should be 1 + actual index in the codebook. The value 0 is reserved for
+ when the index is missing (i.e. first time step). Shape should be
+ `[B, n_q, T]`.
+ states: state for the streaming decoding.
+ offset: offset of the current time step.
- K = indices.shape[1]
- input_ = torch.zeros_like(self.emb[0](indices[:, 0]))
- for k in range(K):
- input_ += self.emb[k](indices[:, k])
+ Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities
+ with a shape `[B, card, n_q, T]`.
+ """
+ B, K, T = indices.shape
+ input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
out, states, offset = self.transformer(input_, states, offset)
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
+ return torch.softmax(logits, dim=1), states, offset
- logits = torch.stack([
- self.linears[k](out) for k in range(K)
- ], dim=1).permute(0, 3, 1, 2)
-
- probs = stable_softmax(logits)
- probs = self.quantize_logits(probs)
-
- return probs, states, offset
class EncodecModel(nn.Module):
"""EnCodec model operating on the raw waveform.
@@ -170,7 +159,7 @@ def _encode_frame(self, x: torch.Tensor) -> EncodedFrame:
scale = None
emb = self.encoder(x)
- codes = self.quantizer.encode(emb.to('cpu'), self.frame_rate, self.bandwidth).to(emb.device)
+ codes = self.quantizer.encode(emb, self.frame_rate, self.bandwidth)
codes = codes.transpose(0, 1)
# codes is [B, K, T], with T frames, K nb of codebooks.
return codes, scale
@@ -212,7 +201,7 @@ def get_lm_model(self) -> LMModel:
"""
device = next(self.parameters()).device
lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
- past_context=int(1 * self.frame_rate)).to(device)
+ past_context=int(3.5 * self.frame_rate)).to(device)
checkpoints = {
'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
diff --git a/encodec/modules/seanet.py b/encodec/modules/seanet.py
index 29d56dc..ea1c02d 100644
--- a/encodec/modules/seanet.py
+++ b/encodec/modules/seanet.py
@@ -17,8 +17,6 @@
SLSTM
)
-np.random.seed(42)
-
class SEANetResnetBlock(nn.Module):
"""Residual block from SEANet model.
diff --git a/encodec/quantization/ac.py b/encodec/quantization/ac.py
index b75648b..f0f3e5d 100644
--- a/encodec/quantization/ac.py
+++ b/encodec/quantization/ac.py
@@ -6,155 +6,53 @@
"""Arithmetic coder."""
-import typing as tp
import io
+import math
import random
+import typing as tp
import torch
-import numpy as np
-
-print(torch.__config__.show())
from ..binary import BitPacker, BitUnpacker
-# Define the fixed-point scaling factor
-FIXED_SCALE = 1 << 32
-
-def pdf_to_fixed_point(pdf: torch.Tensor) -> torch.Tensor:
- """
- Converts a floating-point PDF to fixed-point integer representation
- while eliminating floating-point precision drift.
- """
- assert torch.all(pdf >= 0), "PDF contains negative values!"
-
- # Ensure pdf is on CPU and double precision before conversion
- pdf = pdf.to(torch.float64).cpu()
-
- # Perform scaling using integer arithmetic only
- scaled_pdf = (pdf * FIXED_SCALE + 0.5).floor().to(torch.int64)
-
- return scaled_pdf
+def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
+ roundoff: float = 1e-8, min_range: int = 2,
+ check: bool = True) -> torch.Tensor:
+ """Turn the given PDF into a quantized CDF that splits
+ [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
+ to the PDF.
-def deterministic_round(x: torch.Tensor) -> torch.Tensor:
- """
- Implements a deterministic rounding method: rounds half up.
- """
- return torch.floor(x + 0.5)
-
-def pdf_to_integer_counts_fixed(pdf_fixed: torch.Tensor, total_range: int) -> torch.Tensor:
- """
- Converts a fixed-point PDF into integer counts that sum to total_range using integer-only operations.
-
Args:
- pdf_fixed (torch.Tensor): Fixed-point integer representation of the PDF.
- total_range (int): Desired sum of the integer counts.
-
- Returns:
- torch.Tensor: Integer counts summing to total_range.
+ pdf (torch.Tensor): probability distribution, shape should be `[N]`.
+ total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
+ during the coding process is `[0, 2 ** total_range_bits - 1]`.
+ roundoff (float): will round the pdf up to that level to remove difference coming
+ from e.g. evaluating the Language Model on different architectures.
+ min_range (int): minimum range width. Should always be at least 2 for numerical
+ stability. Use this to avoid pathological behavior is a value
+ that is expected to be rare actually happens in real life.
+ check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
"""
- assert torch.all(pdf_fixed >= 0), "PDF contains negative values!"
- assert torch.isclose(pdf_fixed.sum(), torch.tensor(FIXED_SCALE, dtype=pdf_fixed.dtype, device=pdf_fixed.device), atol=1), "PDF does not sum to fixed scale!"
-
- # Step 1: Scale the PDF using total_range and perform integer division
- scaled_pdf = (pdf_fixed * total_range + (FIXED_SCALE // 2)) // FIXED_SCALE
-
- # Step 2: Calculate the sum of the scaled PDF
- current_sum = scaled_pdf.sum().item()
- deficit = total_range - current_sum
-
- # Step 3: Redistribute the deficit deterministically
- if deficit > 0:
- fractional_remainders = (pdf_fixed * total_range) % FIXED_SCALE
- # Tie-breaking using indices to ensure deterministic sorting
- sorted_indices = torch.argsort(
- fractional_remainders * 1_000_000 + torch.arange(len(fractional_remainders), dtype=torch.int64, device=pdf_fixed.device),
- descending=True
- )
- selected_indices = sorted_indices[:deficit]
- scaled_pdf[selected_indices] += 1
-
- elif deficit < 0:
- # Need to remove excess counts
- excess = -deficit
- sorted_indices = torch.argsort(scaled_pdf, descending=True)
- scaled_pdf[sorted_indices[:excess]] -= 1
-
- return scaled_pdf
-
-def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
- min_range: int = 2, check: bool = True) -> torch.Tensor:
- """Integer-only version of build_stable_quantized_cdf that avoids floating point operations."""
-
- total_range = 1 << total_range_bits
+ pdf = pdf.detach()
+ if roundoff:
+ pdf = (pdf / roundoff).floor() * roundoff
+ # interpolate with uniform distribution to achieve desired minimum probability.
+ total_range = 2 ** total_range_bits
cardinality = len(pdf)
-
- pdf_fixed = pdf_to_fixed_point(pdf)
-
- counts = pdf_to_integer_counts_fixed(pdf_fixed, total_range)
-
- deficit = min_range * cardinality - counts.sum().item()
-
- if deficit > 0:
- available = counts - min_range
- available_total = available[available > 0].sum().item()
-
- if available_total > 0:
- reduction = (available * deficit).div(available_total, rounding_mode='floor')
- counts[available > 0] -= reduction
- deficit -= reduction.sum().item()
-
- if deficit > 0:
- per_symbol = deficit // cardinality
- remainder = deficit % cardinality
- counts += per_symbol
- counts[:remainder] += 1
-
- counts = torch.maximum(counts, torch.tensor(min_range))
-
- # Scale down if we exceed total range
- if counts.sum().item() > total_range:
- scale = total_range / counts.sum().item()
- counts = (counts * scale).round().long()
- counts = torch.maximum(counts, torch.tensor(min_range))
- counts[counts.argmax()] -= (counts.sum().item() - total_range)
-
- # Build CDF through cumulative sum
- quantized_cdf = torch.cumsum(counts, dim=-1)
-
+ alpha = min_range * cardinality / total_range
+ assert alpha <= 1, "you must reduce min_range"
+ ranges = (((1 - alpha) * total_range) * pdf).floor().long()
+ ranges += min_range
+ quantized_cdf = torch.cumsum(ranges, dim=-1)
+ if min_range < 2:
+ raise ValueError("min_range must be at least 2.")
if check:
- assert quantized_cdf[-1].item() <= 2 ** total_range_bits, f"CDF exceeds range: {quantized_cdf[-1]}"
- ranges = torch.diff(quantized_cdf, prepend=torch.tensor([0]))
- assert (ranges >= min_range).all(), f"Some ranges below minimum: {ranges.min().item()}"
-
- #sys.exit(1) # Exit immediately after logging first case
-
-
- log_file = "quantized_cdf_log.txt" # Log file path
- with open(log_file, "a") as f: # Corrected syntax
- f.write(str(quantized_cdf) + "\n")
-
+ assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
+ if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
+ raise ValueError("You must increase your total_range_bits.")
return quantized_cdf
-def compute_effective_range(range_low: int, range_high: int, delta: int, total_range_bits: int) -> tp.Tuple[int, int]:
- total_range = 1 << total_range_bits
-
- # Scale delta using fixed-point scaling
- scaled_delta = delta * FIXED_SCALE
-
- # Compute effective_low and effective_high with fixed-point precision
- effective_low_fixed = (range_low * scaled_delta + (total_range // 2)) // total_range
- effective_high_fixed = (range_high * scaled_delta) // total_range
-
- # Convert back from fixed-point to integer
- effective_low = effective_low_fixed // FIXED_SCALE
- effective_high = effective_high_fixed // FIXED_SCALE
-
- # Ensure that effective_high is at least effective_low to maintain a valid range
- effective_high = max(effective_high, effective_low)
-
- return effective_low, effective_high
-
class ArithmeticCoder:
"""ArithmeticCoder,
Let us take a distribution `p` over `N` symbols, and assume we have a stream
@@ -211,7 +109,7 @@ def delta(self) -> int:
return self.high - self.low + 1
def _flush_common_prefix(self):
- # If self.low and self.high start with the same bits,
+ # If self.low and self.high start with the sames bits,
# those won't change anymore as we always just increase the range
# by powers of 2, and we can flush them out to the bit stream.
assert self.high >= self.low, (self.low, self.high)
@@ -235,17 +133,18 @@ def push(self, symbol: int, quantized_cdf: torch.Tensor):
Args:
symbol (int): symbol to encode with the AC.
- quantized_cdf (torch.Tensor): use build_stable_quantized_cdf
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate.
"""
- while self.delta < (1 << self.total_range_bits):
- self.low <<= 1
- self.high = (self.high << 1) | 1
+ while self.delta < 2 ** self.total_range_bits:
+ self.low *= 2
+ self.high = self.high * 2 + 1
self.max_bit += 1
range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
range_high = quantized_cdf[symbol].item() - 1
- effective_low, effective_high = compute_effective_range(range_low, range_high, self.delta, self.total_range_bits)
+ effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
+ effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
assert self.low <= self.high
self.high = self.low + effective_high
self.low = self.low + effective_low
@@ -269,17 +168,17 @@ def flush(self):
class ArithmeticDecoder:
- """ArithmeticDecoder, see ArithmeticCoder for a detailed explanation.
+ """ArithmeticDecoder, see `ArithmeticCoder` for a detailed explanation.
Note that this must be called with **exactly** the same parameters and sequence
of quantized cdf as the arithmetic encoder or the wrong values will be decoded.
- If the AC encoder current range is [L, H], with L and H having the some common
+ If the AC encoder current range is [L, H], with `L` and `H` having the some common
prefix (i.e. the same most significant bits), then this prefix will be flushed to the stream.
- For instances, having read 3 bits b1 b2 b3, we know that [L, H] is contained inside
- [b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]. Now this specific sub-range can only be obtained
+ For instances, having read 3 bits `b1 b2 b3`, we know that `[L, H]` is contained inside
+ `[b1 b2 b3 0 ... 0 b1 b3 b3 1 ... 1]`. Now this specific sub-range can only be obtained
for a specific sequence of symbols and a binary-search allows us to decode those symbols.
- At some point, the prefix b1 b2 b3 will no longer be sufficient to decode new symbols,
+ At some point, the prefix `b1 b2 b3` will no longer be sufficient to decode new symbols,
and we will need to read new bits from the stream and repeat the process.
"""
@@ -317,70 +216,20 @@ def _flush_common_prefix(self):
def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
"""Pull a symbol, reading as many bits from the stream as required.
- This returns None when the stream has been exhausted.
+ This returns `None` when the stream has been exhausted.
Args:
- quantized_cdf (torch.Tensor): use build_stable_quantized_cdf
- to build this from your pdf estimate. This must be **exactly**
+ quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+ to build this from your pdf estimate. This must be **exatly**
the same cdf as the one used at encoding time.
"""
- while self.delta < (1 << self.total_range_bits):
+ while self.delta < 2 ** self.total_range_bits:
bit = self.unpacker.pull()
if bit is None:
return None
- self.low <<= 1
- self.high = (self.high << 1) | 1
- self.current = (self.current << 1) | bit
- self.max_bit += 1
-
- log_file = "binary_search_log.txt" # Log file path
- with open(log_file, "a") as f:
-
- def bin_search(low_idx: int, high_idx: int):
- if high_idx < low_idx:
- raise RuntimeError("Binary search failed")
- mid = (low_idx + high_idx) // 2
- range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
- range_high = quantized_cdf[mid].item() - 1
- effective_low, effective_high = compute_effective_range(range_low, range_high, self.delta, self.total_range_bits)
- low = effective_low + self.low
- high = effective_high + self.low
-
- # Log each iteration
- f.write(f"low_idx={low_idx}, high_idx={high_idx}, mid={mid}, low={low}, high={high}, current={self.current}\n")
-
- if self.current >= low:
- if self.current <= high:
- return (mid, low, high, self.current)
- else:
- return bin_search(mid + 1, high_idx)
- else:
- return bin_search(low_idx, mid - 1)
-
- self._last = (self.low, self.high, self.current, self.max_bit)
- sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
- self._dbg.append((self.low, self.high, self.current))
- self._flush_common_prefix()
- self._dbg2.append((self.low, self.high, self.current))
-
- return sym
-
- def pullx(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
- """Pull a symbol, reading as many bits from the stream as required.
- This returns None when the stream has been exhausted.
-
- Args:
- quantized_cdf (torch.Tensor): use build_stable_quantized_cdf
- to build this from your pdf estimate. This must be **exactly**
- the same cdf as the one used at encoding time.
- """
- while self.delta < (1 << self.total_range_bits):
- bit = self.unpacker.pull()
- if bit is None:
- return None
- self.low <<= 1
- self.high = (self.high << 1) | 1
- self.current = (self.current << 1) | bit
+ self.low *= 2
+ self.high = self.high * 2 + 1
+ self.current = self.current * 2 + bit
self.max_bit += 1
def bin_search(low_idx: int, high_idx: int):
@@ -390,7 +239,8 @@ def bin_search(low_idx: int, high_idx: int):
mid = (low_idx + high_idx) // 2
range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
range_high = quantized_cdf[mid].item() - 1
- effective_low, effective_high = compute_effective_range(range_low, range_high, self.delta, self.total_range_bits)
+ effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
+ effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
low = effective_low + self.low
high = effective_high + self.low
if self.current >= low:
@@ -440,4 +290,3 @@ def test():
if __name__ == "__main__":
test()
-
diff --git a/encodec/quantization/core_vq.py b/encodec/quantization/core_vq.py
index 62d7db0..1c7e8c7 100644
--- a/encodec/quantization/core_vq.py
+++ b/encodec/quantization/core_vq.py
@@ -33,43 +33,55 @@
import typing as tp
import warnings
+
+from einops import rearrange, repeat
import torch
-from torch import nn, Tensor
+from torch import nn
import torch.nn.functional as F
-from einops import rearrange, repeat
-from torch import jit
from .. import distrib
+
def default(val: tp.Any, d: tp.Any) -> tp.Any:
return val if val is not None else d
-def ema_inplace(moving_avg: Tensor, new: Tensor, decay: float):
+
+def ema_inplace(moving_avg, new, decay: float):
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
-def laplace_smoothing(x: Tensor, n_categories: int, epsilon: float = 1e-5) -> Tensor:
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
return (x + epsilon) / (x.sum() + n_categories * epsilon)
-def uniform_init(*shape: int) -> Tensor:
+
+def uniform_init(*shape: int):
t = torch.empty(shape)
nn.init.kaiming_uniform_(t)
return t
-def sample_vectors(samples: Tensor, num: int) -> Tensor:
+
+def sample_vectors(samples, num: int):
num_samples, device = samples.shape[0], samples.device
+
if num_samples >= num:
indices = torch.randperm(num_samples, device=device)[:num]
else:
indices = torch.randint(0, num_samples, (num,), device=device)
+
return samples[indices]
-def kmeans(samples: Tensor, num_clusters: int, num_iters: int = 10) -> tp.Tuple[Tensor, Tensor]:
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10):
dim, dtype = samples.shape[-1], samples.dtype
+
means = sample_vectors(samples, num_clusters)
for _ in range(num_iters):
- diffs = rearrange(samples, "n d -> n () d") - rearrange(means, "c d -> () c d")
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
+ means, "c d -> () c d"
+ )
dists = -(diffs ** 2).sum(dim=-1)
+
buckets = dists.max(dim=-1).indices
bins = torch.bincount(buckets, minlength=num_clusters)
zero_mask = bins == 0
@@ -78,23 +90,32 @@ def kmeans(samples: Tensor, num_clusters: int, num_iters: int = 10) -> tp.Tuple[
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
new_means = new_means / bins_min_clamped[..., None]
+
means = torch.where(zero_mask[..., None], means, new_means)
return means, bins
-def _quantize_tensor(x: Tensor, precision: int = 7) -> Tensor:
- """Control precision of floating point operations"""
- return torch.round(x * 10**precision) / 10**precision
-
class EuclideanCodebook(nn.Module):
- """Codebook with Euclidean distance."""
-
+ """Codebook with Euclidean distance.
+ Args:
+ dim (int): Dimension.
+ codebook_size (int): Codebook size.
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+ If set to true, run the k-means algorithm on the first training batch and use
+ the learned centroids as initialization.
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ randomly selected vector from the current batch.
+ """
def __init__(
self,
dim: int,
codebook_size: int,
- kmeans_init: bool = False,
+ kmeans_init: int = False,
kmeans_iters: int = 10,
decay: float = 0.99,
epsilon: float = 1e-5,
@@ -106,6 +127,7 @@ def __init__(
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
+
self.kmeans_iters = kmeans_iters
self.epsilon = epsilon
self.threshold_ema_dead_code = threshold_ema_dead_code
@@ -115,9 +137,8 @@ def __init__(
self.register_buffer("embed", embed)
self.register_buffer("embed_avg", embed.clone())
-
- @jit.ignore
- def init_embed_(self, data: Tensor) -> None:
+ @torch.jit.ignore
+ def init_embed_(self, data):
if self.inited:
return
@@ -126,17 +147,16 @@ def init_embed_(self, data: Tensor) -> None:
self.embed_avg.data.copy_(embed.clone())
self.cluster_size.data.copy_(cluster_size)
self.inited.data.copy_(torch.Tensor([True]))
+ # Make sure all buffers across workers are in sync after initialization
distrib.broadcast_tensors(self.buffers())
- def replace_(self, samples: Tensor, mask: Tensor) -> None:
+ def replace_(self, samples, mask):
modified_codebook = torch.where(
- mask[..., None],
- sample_vectors(samples, self.codebook_size),
- self.embed
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
self.embed.data.copy_(modified_codebook)
- def expire_codes_(self, batch_samples: Tensor) -> None:
+ def expire_codes_(self, batch_samples):
if self.threshold_ema_dead_code == 0:
return
@@ -148,54 +168,42 @@ def expire_codes_(self, batch_samples: Tensor) -> None:
self.replace_(batch_samples, mask=expired_codes)
distrib.broadcast_tensors(self.buffers())
- def preprocess(self, x: Tensor) -> Tensor:
- return rearrange(x, "... d -> (...) d")
-
- def quantize(self, x: Tensor) -> Tensor:
- """Stabilized quantization for consistent binary tree decisions across architectures"""
- # Carefully control precision of the codebook
- embed = _quantize_tensor(self.embed.t())
-
- # Break down distance calculation into controlled steps
- # Calculate x squared term first
- x_squared = _quantize_tensor(x.pow(2).sum(1, keepdim=True))
-
- # Calculate embed squared term
- embed_squared = _quantize_tensor(embed.pow(2).sum(0, keepdim=True))
-
- # Calculate cross term with controlled precision
- # Use matmul for better numerical stability than @
- cross_term = _quantize_tensor(torch.matmul(x, embed))
- cross_term = _quantize_tensor(cross_term * 2)
-
- # Combine terms with controlled precision and ordering
- # Note: we add the squared terms first since they're likely larger
- dist = _quantize_tensor(x_squared + embed_squared)
- dist = _quantize_tensor(dist - cross_term)
- dist = -dist # Negate at the end to avoid accumulated precision loss
-
- # Use stable sorting for consistent index selection
- embed_ind = dist.max(dim=-1, keepdim=True).indices
-
+ def preprocess(self, x):
+ x = rearrange(x, "... d -> (...) d")
+ return x
+
+ def quantize(self, x):
+ embed = self.embed.t()
+ dist = -(
+ x.pow(2).sum(1, keepdim=True)
+ - 2 * x @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+ embed_ind = dist.max(dim=-1).indices
return embed_ind
-
- def postprocess_emb(self, embed_ind: Tensor, shape: tp.Tuple) -> Tensor:
+
+ def postprocess_emb(self, embed_ind, shape):
return embed_ind.view(*shape[:-1])
- def dequantize(self, embed_ind: Tensor) -> Tensor:
- return F.embedding(embed_ind, self.embed)
+ def dequantize(self, embed_ind):
+ quantize = F.embedding(embed_ind, self.embed)
+ return quantize
- def encode(self, x: Tensor) -> Tensor:
+ def encode(self, x):
shape = x.shape
+ # pre-process
x = self.preprocess(x)
+ # quantize
embed_ind = self.quantize(x)
+ # post-process
embed_ind = self.postprocess_emb(embed_ind, shape)
return embed_ind
- def decode(self, embed_ind: Tensor) -> Tensor:
- return self.dequantize(embed_ind)
+ def decode(self, embed_ind):
+ quantize = self.dequantize(embed_ind)
+ return quantize
- def forward(self, x: Tensor) -> tp.Tuple[Tensor, Tensor]:
+ def forward(self, x):
shape, dtype = x.shape, x.dtype
x = self.preprocess(x)
@@ -207,6 +215,8 @@ def forward(self, x: Tensor) -> tp.Tuple[Tensor, Tensor]:
quantize = self.dequantize(embed_ind)
if self.training:
+ # We do the expiry of code at that point as buffers are in sync
+ # and all the workers will take the same decision.
self.expire_codes_(x)
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
embed_sum = x.t() @ embed_onehot
@@ -220,9 +230,23 @@ def forward(self, x: Tensor) -> tp.Tuple[Tensor, Tensor]:
return quantize, embed_ind
-class VectorQuantization(nn.Module):
- """Vector quantization implementation."""
+class VectorQuantization(nn.Module):
+ """Vector quantization implementation.
+ Currently supports only euclidean distance.
+ Args:
+ dim (int): Dimension
+ codebook_size (int): Codebook size
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ randomly selected vector from the current batch.
+ commitment_weight (float): Weight for commitment loss.
+ """
def __init__(
self,
dim: int,
@@ -236,56 +260,43 @@ def __init__(
commitment_weight: float = 1.,
):
super().__init__()
- self._codebook_dim: int = default(codebook_dim, dim)
- requires_projection = self._codebook_dim != dim
- self.project_in = (nn.Linear(dim, self._codebook_dim) if requires_projection else nn.Identity())
- self.project_out = (nn.Linear(self._codebook_dim, dim) if requires_projection else nn.Identity())
+ _codebook_dim: int = default(codebook_dim, dim)
+
+ requires_projection = _codebook_dim != dim
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
self.epsilon = epsilon
self.commitment_weight = commitment_weight
- self._codebook = EuclideanCodebook(
- dim=self._codebook_dim,
- codebook_size=codebook_size,
- kmeans_init=kmeans_init,
- kmeans_iters=kmeans_iters,
- decay=decay,
- epsilon=epsilon,
- threshold_ema_dead_code=threshold_ema_dead_code
- )
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+ decay=decay, epsilon=epsilon,
+ threshold_ema_dead_code=threshold_ema_dead_code)
self.codebook_size = codebook_size
@property
- def codebook(self) -> Tensor:
+ def codebook(self):
return self._codebook.embed
- def _quantize_tensor(self, x: Tensor) -> Tensor:
- """Control precision of tensor operations"""
- return torch.round(x * 10**7) / 10**7
-
- def encode(self, x: Tensor) -> Tensor:
- """Stabilized encoding process"""
+ def encode(self, x):
x = rearrange(x, "b d n -> b n d")
- # Stabilize projection
- x = _quantize_tensor(self.project_in(x))
+ x = self.project_in(x)
embed_in = self._codebook.encode(x)
return embed_in
- def decode(self, embed_ind: Tensor) -> Tensor:
- """Stabilized decoding process"""
+ def decode(self, embed_ind):
quantize = self._codebook.decode(embed_ind)
- # Stabilize projection
- quantize = _quantize_tensor(self.project_out(quantize))
+ quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize
- def forward(self, x: Tensor) -> tp.Tuple[Tensor, Tensor, Tensor]:
+ def forward(self, x):
device = x.device
x = rearrange(x, "b d n -> b n d")
- x = self._quantize_tensor(self.project_in(x))
+ x = self.project_in(x)
quantize, embed_ind = self._codebook(x)
- quantize = self._quantize_tensor(quantize)
if self.training:
quantize = x + (quantize - x).detach()
@@ -294,32 +305,30 @@ def forward(self, x: Tensor) -> tp.Tuple[Tensor, Tensor, Tensor]:
if self.training:
warnings.warn('When using RVQ in training model, first check '
- 'https://github.com/facebookresearch/encodec/issues/25 . '
- 'The bug wasn\'t fixed here for reproducibility.')
+ 'https://github.com/facebookresearch/encodec/issues/25 . '
+ 'The bug wasn\'t fixed here for reproducibility.')
if self.commitment_weight > 0:
commit_loss = F.mse_loss(quantize.detach(), x)
loss = loss + commit_loss * self.commitment_weight
- quantize = self._quantize_tensor(self.project_out(quantize))
+ quantize = self.project_out(quantize)
quantize = rearrange(quantize, "b n d -> b d n")
return quantize, embed_ind, loss
+
class ResidualVectorQuantization(nn.Module):
- """Residual vector quantization implementation with stability improvements."""
-
- def __init__(self, *, num_quantizers: int, **kwargs):
+ """Residual vector quantization implementation.
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+ """
+ def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
-
- def _quantize_tensor(self, x: Tensor) -> Tensor:
- """Control precision of tensor operations"""
- return torch.round(x * 10**7) / 10**7
- def forward(self, x: Tensor, n_q: tp.Optional[int] = None) -> tp.Tuple[Tensor, Tensor, Tensor]:
- quantized_out = torch.tensor(0.0, device=x.device)
- residual = self._quantize_tensor(x)
+ def forward(self, x, n_q: tp.Optional[int] = None):
+ quantized_out = 0.0
+ residual = x
all_losses = []
all_indices = []
@@ -328,8 +337,8 @@ def forward(self, x: Tensor, n_q: tp.Optional[int] = None) -> tp.Tuple[Tensor, T
for layer in self.layers[:n_q]:
quantized, indices, loss = layer(residual)
- residual = self._quantize_tensor(residual - quantized)
- quantized_out = self._quantize_tensor(quantized_out + quantized)
+ residual = residual - quantized
+ quantized_out = quantized_out + quantized
all_indices.append(indices)
all_losses.append(loss)
@@ -337,46 +346,22 @@ def forward(self, x: Tensor, n_q: tp.Optional[int] = None) -> tp.Tuple[Tensor, T
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses
- def encode(self, x: Tensor, n_q: tp.Optional[int] = None) -> Tensor:
- """Stabilized RVQ encoding"""
- # Initial quantization of input
- residual = _quantize_tensor(x)
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+ residual = x
all_indices = []
n_q = n_q or len(self.layers)
-
for layer in self.layers[:n_q]:
- # Get indices for this layer
indices = layer.encode(residual)
-
- # Decode and quantize to match encoder exactly
quantized = layer.decode(indices)
- quantized = _quantize_tensor(quantized)
-
- # Compute and quantize residual
- residual = _quantize_tensor(residual - quantized)
-
+ residual = residual - quantized
all_indices.append(indices)
-
- # Stack indices with controlled precision
out_indices = torch.stack(all_indices)
return out_indices
- def decode(self, q_indices: Tensor) -> Tensor:
- """Stabilized RVQ decoding"""
- quantized_out = torch.zeros(
- q_indices.shape[1:],
- device=q_indices.device,
- dtype=q_indices.dtype
- )
-
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
for i, indices in enumerate(q_indices):
layer = self.layers[i]
- # Decode and stabilize each layer output
quantized = layer.decode(indices)
- quantized = _quantize_tensor(quantized)
-
- # Accumulate with controlled precision
- quantized_out = _quantize_tensor(quantized_out + quantized)
-
+ quantized_out = quantized_out + quantized
return quantized_out
-
From 21308e23850ec8e5880f2fb6a162811067de8d55 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Sun, 7 Sep 2025 19:12:17 +0100
Subject: [PATCH 03/24] restrict changes to entropy coding paths
---
encodec/compress.py | 262 ++++++++++++++++++++++---------------
encodec/lm_integer.py | 201 ++++++++++++++++++++++++++++
encodec/model.py | 27 ++--
encodec/quantization/ac.py | 132 +++++++++----------
4 files changed, 430 insertions(+), 192 deletions(-)
create mode 100644 encodec/lm_integer.py
diff --git a/encodec/compress.py b/encodec/compress.py
index 41d6c12..7db08e6 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -1,88 +1,175 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-"""API to compress/decompress audio to bytestreams."""
+# encodec/compress.py
+# Deterministic coder: architecture-stable CDF construction + logit quantization.
import io
import math
import struct
-import time
import typing as tp
import torch
from . import binary
-from .quantization.ac import ArithmeticCoder, ArithmeticDecoder, build_stable_quantized_cdf
from .model import EncodecModel, EncodedFrame
+from .quantization.ac import (
+ ArithmeticCoder,
+ ArithmeticDecoder,
+)
+# Hard determinism toggles
+torch.use_deterministic_algorithms(True)
+torch.backends.mkldnn.enabled = False
+# Registry
MODELS = {
'encodec_24khz': EncodecModel.encodec_model_24khz,
'encodec_48khz': EncodecModel.encodec_model_48khz,
}
+# Chosen scales for stability vs. compression efficiency
+# - LOGIT_QSTEP: coarse enough to suppress tiny arch drift, fine enough to preserve coding gain
+# - FP_SCALE: count scale used before integer range allocation inside the CDF
+LOGIT_QSTEP = 1.0 / 64.0
+FP_SCALE = 1 << 14 # 16384; lower than 1<<16 for better cross-arch stability
+ROUND_CDF = 1e-4 # unused in this deterministic path, kept for signature parity
+MIN_RANGE = 2 # min bin width for arithmetic coder
+
+
+def _quantize_logits_(logits: torch.Tensor, step: float = LOGIT_QSTEP) -> torch.Tensor:
+ # In-place-ish quantization without breaking autograd (we're in no_grad anyway).
+ return torch.round(logits / step) * step
+
+
+def _stable_softmax(logits: torch.Tensor, dim: int) -> torch.Tensor:
+ # f64 softmax with explicit max subtraction for numerical stability
+ m = torch.amax(logits, dim=dim, keepdim=True)
+ z = torch.exp((logits - m).to(torch.float64))
+ s = torch.sum(z, dim=dim, keepdim=True)
+ # safeguard in case of weird NaNs/Inf
+ bad = ~torch.isfinite(s) | (s <= 0)
+ if bad.any():
+ # replace by uniform
+ z = torch.ones_like(z, dtype=torch.float64)
+ s = torch.sum(z, dim=dim, keepdim=True)
+ return z / s
+
+
+def _deterministic_cdf(pdf: torch.Tensor,
+ total_range_bits: int,
+ fp_scale: int = FP_SCALE,
+ min_range: int = MIN_RANGE,
+ check: bool = False) -> torch.Tensor:
+ """
+ Architecture-stable integer CDF:
+ 1) clamp pdf; compute integer "counts" by floor(pdf * fp_scale) in f64
+ 2) allocate the remaining counts deterministically by priority
+ 3) add min_range to each bin, cum-sum to final CDF that sums to 2^bits
+ Any tiny floating diffs that don't change floor() outputs produce identical CDFs.
+ """
+ pdf = pdf.detach().to(torch.float64).clamp_min(0)
+ s = pdf.sum()
+ if (not torch.isfinite(s)) or (s <= 0):
+ pdf = torch.ones_like(pdf)
+ s = pdf.sum()
+
+ num = torch.floor(pdf * fp_scale).to(torch.int64)
+ if int(num.sum().item()) <= 0:
+ num = torch.ones_like(num)
+
+ total = 1 << total_range_bits
+ n = int(num.numel())
+ alloc = total - min_range * n
+ num_sum = int(num.sum().item())
+
+ # base integer allocation
+ base = (alloc * num) // num_sum
+ remainder = int(alloc - int(base.sum().item()))
+ if remainder > 0:
+ # deterministic priority: residual * (n+1) - index (stable sort)
+ prio = (alloc * num) - (num_sum * base)
+ idx = torch.arange(n, device=num.device, dtype=torch.int64)
+ key = prio * (n + 1) - idx
+ _, order = torch.sort(key, descending=True, stable=True)
+ base[order[:remainder]] += 1
+
+ ranges = base + min_range
+ cdf = torch.cumsum(ranges, dim=-1, dtype=torch.int64)
+
+ if check:
+ assert int(cdf[-1].item()) == total
+ assert (ranges >= min_range).all()
+ return cdf
+
def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
use_lm: bool = True):
- """Compress a waveform to a file-object using the given model.
-
- Args:
- model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
- wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
- matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
- Use `utils.convert_audio` if this is not the case.
- fo (IO[bytes]): file-object to which the compressed bits will be written.
- See `compress` if you want obtain a `bytes` object instead.
- use_lm (bool): if True, use a pre-trained language model to further
- compress the stream using Entropy Coding. This will slow down compression
- quite a bit, expect between 20 to 30% of size reduction.
"""
- assert wav.dim() == 2, "Only single waveform can be encoded."
+ Compress a waveform to a file-object using the given model.
+ Deterministic path is enforced unconditionally (no metadata flags).
+ """
+ assert wav.dim() == 2, "Expected [C, T]."
if model.name not in MODELS:
- raise ValueError(f"The provided model {model.name} is not supported.")
+ raise ValueError(f"Unsupported model {model.name}.")
- if use_lm:
- lm = model.get_lm_model()
+ device = wav.device
+ # Encode once to know frames and K
with torch.no_grad():
frames = model.encode(wav[None])
+ codes0, _ = frames[0]
+ _, K, _ = codes0.shape
+
+ # Language model (float64), but logits quantized before softmax
+ lm = None
+ if use_lm:
+ lm = model.get_lm_model().to(dtype=torch.float64, device=device)
+ lm.eval()
+ # Minimal, unchanged metadata set (no new flags)
metadata = {
- 'm': model.name, # model name
- 'al': wav.shape[-1], # audio_length
- 'nc': frames[0][0].shape[1], # num_codebooks
- 'lm': use_lm, # use lm?
+ 'm': model.name,
+ 'al': int(wav.shape[-1]),
+ 'nc': int(K),
+ 'lm': bool(use_lm),
+ 'fp': int(FP_SCALE),
+ 'acv': 3,
}
binary.write_ecdc_header(fo, metadata)
+ # Bitstream
for (frame, scale) in frames:
if scale is not None:
- fo.write(struct.pack('!f', scale.cpu().item()))
- _, K, T = frame.shape
+ fo.write(struct.pack('!f', float(scale.cpu().item())))
+
+ _B, _K, T = frame.shape
if use_lm:
coder = ArithmeticCoder(fo)
states: tp.Any = None
offset = 0
- input_ = torch.zeros(1, K, 1, dtype=torch.long, device=wav.device)
+ input_ = torch.zeros(1, K, 1, dtype=torch.long, device=device)
else:
packer = binary.BitPacker(model.bits_per_codebook, fo)
+
for t in range(T):
if use_lm:
with torch.no_grad():
- probas, states, offset = lm(input_, states, offset)
- # We emulate a streaming scenario even though we do not provide an API for it.
- # This gives us a more accurate benchmark.
- input_ = 1 + frame[:, :, t: t + 1]
+ probas_raw, states, offset = lm(input_, states, offset) # [1, card, K, 1]
+ # Quantize logits (rebuild from probs with log if needed)
+ # Safer: pass-through by reverse softmax → logits, quantize, softmax
+ # But we only get probas. So enforce quantization by re-logit with clip.
+ # To avoid log(0), clamp and then re-softmax.
+ p = torch.clamp(probas_raw, min=1e-12)
+ logits = torch.log(p)
+ logits_q = _quantize_logits_(logits, LOGIT_QSTEP)
+ probas = _stable_softmax(logits_q, dim=1)
for k, value in enumerate(frame[0, :, t].tolist()):
if use_lm:
- q_cdf = build_stable_quantized_cdf(
- probas[0, :, k, 0], coder.total_range_bits, check=False)
+ q_cdf = _deterministic_cdf(probas[0, :, k, 0], coder.total_range_bits, fp_scale=FP_SCALE, check=False)
coder.push(value, q_cdf)
else:
packer.push(value)
+ if use_lm:
+ input_ = 1 + frame[:, :, t: t + 1]
+
if use_lm:
coder.flush()
else:
@@ -90,39 +177,43 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
- """Decompress from a file-object.
- Returns a tuple `(wav, sample_rate)`.
-
- Args:
- fo (IO[bytes]): file-object from which to read. If you want to decompress
- from `bytes` instead, see `decompress`.
- device: device to use to perform the computations.
+ """
+ Decompress from a file-object. Deterministic path (matching encoder) is used unconditionally.
"""
metadata = binary.read_ecdc_header(fo)
model_name = metadata['m']
- audio_length = metadata['al']
- num_codebooks = metadata['nc']
- use_lm = metadata['lm']
- assert isinstance(audio_length, int)
- assert isinstance(num_codebooks, int)
+ audio_length = int(metadata['al'])
+ num_codebooks = int(metadata['nc'])
+ use_lm = bool(metadata['lm'])
+ fp_scale = int(metadata.get('fp', FP_SCALE))
+ acv = int(metadata.get('acv', 0))
+
if model_name not in MODELS:
- raise ValueError(f"The audio was compressed with an unsupported model {model_name}.")
+ raise ValueError(f"Unsupported model {model_name}.")
+ if acv != 3:
+ raise ValueError("Unsupported bitstream version; re-encode with this coder.")
+
model = MODELS[model_name]().to(device)
+ lm = None
if use_lm:
- lm = model.get_lm_model()
+ lm = model.get_lm_model().to(dtype=torch.float64, device=device)
+ lm.eval()
frames: tp.List[EncodedFrame] = []
segment_length = model.segment_length or audio_length
segment_stride = model.segment_stride or audio_length
- for offset in range(0, audio_length, segment_stride):
- this_segment_length = min(audio_length - offset, segment_length)
- frame_length = int(math.ceil(this_segment_length * model.frame_rate / model.sample_rate))
+
+ for offset_samples in range(0, audio_length, segment_stride):
+ this_len = min(audio_length - offset_samples, segment_length)
+ frame_length = int(math.ceil(this_len * model.frame_rate / model.sample_rate))
+
if model.normalize:
scale_f, = struct.unpack('!f', binary._read_exactly(fo, struct.calcsize('!f')))
scale = torch.tensor(scale_f, device=device).view(1)
else:
scale = None
+
if use_lm:
decoder = ArithmeticDecoder(fo)
states: tp.Any = None
@@ -130,82 +221,47 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=device)
else:
unpacker = binary.BitUnpacker(model.bits_per_codebook, fo)
+
frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=device)
+
for t in range(frame_length):
if use_lm:
with torch.no_grad():
- probas, states, offset = lm(input_, states, offset)
+ probas_raw, states, offset = lm(input_, states, offset)
+ p = torch.clamp(probas_raw, min=1e-12)
+ logits = torch.log(p)
+ logits_q = _quantize_logits_(logits, LOGIT_QSTEP)
+ probas = _stable_softmax(logits_q, dim=1)
+
code_list: tp.List[int] = []
for k in range(num_codebooks):
if use_lm:
- q_cdf = build_stable_quantized_cdf(
- probas[0, :, k, 0], decoder.total_range_bits, check=False)
+ q_cdf = _deterministic_cdf(probas[0, :, k, 0], decoder.total_range_bits, fp_scale=fp_scale, check=False)
code = decoder.pull(q_cdf)
else:
code = unpacker.pull()
if code is None:
raise EOFError("The stream ended sooner than expected.")
code_list.append(code)
- codes = torch.tensor(code_list, dtype=torch.long, device=device)
- frame[0, :, t] = codes
+
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=device)
if use_lm:
input_ = 1 + frame[:, :, t: t + 1]
+
frames.append((frame, scale))
+
with torch.no_grad():
wav = model.decode(frames)
return wav[0, :, :audio_length], model.sample_rate
def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> bytes:
- """Compress a waveform using the given model. Returns the compressed bytes.
-
- Args:
- model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
- wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
- matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
- Use `utils.convert_audio` if this is not the case.
- use_lm (bool): if True, use a pre-trained language model to further
- compress the stream using Entropy Coding. This will slow down compression
- quite a bit, expect between 20 to 30% of size reduction.
- """
fo = io.BytesIO()
compress_to_file(model, wav, fo, use_lm=use_lm)
return fo.getvalue()
def decompress(compressed: bytes, device='cpu') -> tp.Tuple[torch.Tensor, int]:
- """Decompress from a file-object.
- Returns a tuple `(wav, sample_rate)`.
-
- Args:
- compressed (bytes): compressed bytes.
- device: device to use to perform the computations.
- """
fo = io.BytesIO(compressed)
return decompress_from_file(fo, device=device)
-
-def test():
- import torchaudio
- torch.set_num_threads(1)
- for name in MODELS.keys():
- model = MODELS[name]()
- sr = model.sample_rate // 1000
- x, _ = torchaudio.load(f'test_{sr}k.wav')
- x = x[:, :model.sample_rate * 5]
- model.set_target_bandwidth(12)
- for use_lm in [False, True]:
- print(f"Doing {name}, use_lm={use_lm}")
- begin = time.time()
- res = compress(model, x, use_lm=use_lm)
- t_comp = time.time() - begin
- x_dec, _ = decompress(res)
- t_decomp = time.time() - begin - t_comp
- kbps = 8 * len(res) / 1000 / (x.shape[-1] / model.sample_rate)
- print(f"kbps: {kbps:.1f}, time comp: {t_comp:.1f} sec. "
- f"time decomp:{t_decomp:.1f}.")
- assert x_dec.shape == x.shape
-
-
-if __name__ == '__main__':
- test()
diff --git a/encodec/lm_integer.py b/encodec/lm_integer.py
new file mode 100644
index 0000000..76179b7
--- /dev/null
+++ b/encodec/lm_integer.py
@@ -0,0 +1,201 @@
+# Copyright (c) Meta Platforms, Inc.
+# All rights reserved.
+
+import os
+import re
+import typing as tp
+
+import torch
+from torch import nn
+
+# ---- Integer-friendly LM (deterministic logits quantization) ----
+
+class LMModelInt(nn.Module):
+ """
+ Same topology as the float LM in encodec.model.LMModel, but we quantize logits
+ to a fixed step for extra determinism across architectures/BLAS backends.
+ """
+ def __init__(self, n_q: int, card: int = 1024, dim: int = 200, **kwargs):
+ super().__init__()
+ from .modules import transformer as m
+ self.card = card
+ self.n_q = n_q
+ self.dim = dim
+
+ # streaming transformer, kept in float64 for reproducibility
+ self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs).to(torch.float64)
+
+ # one embedding + one head per codebook
+ self.emb = nn.ModuleList([nn.Embedding(card + 1, dim, dtype=torch.float64) for _ in range(n_q)])
+ self.linears = nn.ModuleList([nn.Linear(dim, card, dtype=torch.float64) for _ in range(n_q)])
+
+ # quantize logits onto a grid to stabilize results
+ self.logit_step = 1.0 / 64.0
+
+ @torch.no_grad()
+ def forward(self, indices: torch.Tensor,
+ states: tp.Optional[tp.List[torch.Tensor]] = None,
+ offset: int = 0):
+ """
+ indices: [B, K, T] with K == runtime codebooks used (<= self.n_q)
+ Returns:
+ probas_or_counts: [B, card, K, T] on float64
+ states, offset: streaming state
+ """
+ B, K, T = indices.shape
+ # Sum embeddings for the K active codebooks only.
+ x = sum([self.emb[k](indices[:, k]) for k in range(K)]) # [B, T, dim]
+ out, states, offset = self.transformer(x, states, offset) # [B, T, dim]
+
+ # Project per active codebook
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, T, card]
+ logits = logits.permute(0, 3, 1, 2).contiguous() # [B, card, K, T]
+
+ # integer-like last mile: quantize logits then softmax
+ logits = torch.round(logits / self.logit_step) * self.logit_step
+ probas = torch.softmax(logits, dim=1) # still float64; AC builder accepts any nonneg vector
+ return probas, states, offset
+
+
+# --------- Checkpoint loading (robust to head-count mismatch) ---------
+
+def _infer_ckpt_nq(state: tp.Dict[str, torch.Tensor]) -> int:
+ """
+ Count how many emb.* / linears.* heads exist in the checkpoint.
+ """
+ head_idxs = []
+ pat = re.compile(r'^(emb|linears)\.(\d+)\.')
+ for k in state.keys():
+ m = pat.match(k)
+ if m:
+ head_idxs.append(int(m.group(2)))
+ return (max(head_idxs) + 1) if head_idxs else 0
+
+
+def _desired_nq_from_model(model) -> int:
+ """
+ Compute number of codebooks actually used for the MODEL'S CURRENT bandwidth.
+ """
+ # Fallbacks if anything is missing
+ default = getattr(getattr(model, 'quantizer', None), 'n_q', 32)
+
+ try:
+ # How many quantizers will be used at current bandwidth?
+ q = model.quantizer # RVQ
+ fr = model.frame_rate
+ bw = getattr(model, 'bandwidth', None)
+ if bw is None:
+ # If caller didn't set bandwidth yet, we conservatively use the max
+ return default
+ return q.get_num_quantizers_for_bandwidth(fr, bw)
+ except Exception:
+ return default
+
+
+def _checkpoint_name_for(model_name: str) -> str:
+ """
+ Use the *existing* float-LM checkpoint names (they match architecture)
+ so you don't need new files.
+ """
+ # These are the same filenames used by EnCodec's float LM.
+ mapping = {
+ 'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
+ 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
+ }
+ if model_name not in mapping:
+ raise RuntimeError(f"N[48;48;201;1632;2814to LM checkpoint mapping for model '{model_name}'.")
+ return mapping[model_name]
+
+
+def _load_state_dict_from_url_or_env(ckpt_name: str):
+ """
+ If ENCODEC_LM_PATH is set, read from that folder; otherwise use torch.hub URL.
+ """
+ root = os.environ.get('ENCODEC_LM_PATH', '').strip()
+ if root:
+ path = os.path.join(root, ckpt_name)
+ if not os.path.isfile(path):
+ raise FileNotFoundError(f"ENCODEC_LM_PATH set, but file not found: {path}")
+ state = torch.load(path, map_location='cpu')
+ else:
+ from .utils import _get_checkpoint_url
+ url = _get_checkpoint_url('https://dl.fbaipublicfiles.com/encodec/v0/', ckpt_name)
+ state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) # type: ignore
+ return state
+
+
+def load_pretrained_integer_lm(model_or_name, device='cpu',
+ n_q: tp.Optional[int] = None,
+ card: tp.Optional[int] = None) -> LMModelInt:
+ """
+ Build LMModelInt sized to the number of codebooks you actually use,
+ then load a compatible checkpoint (even if the checkpoint has fewer heads).
+
+ Args:
+ model_or_name: EncodecModel instance (preferred), or string model name.
+ device: target device.
+ n_q: override number of heads; if None, computed from model's bandwidth.
+ card: override codebook size; if None, taken from model.quantizer.bins.
+
+ Returns:
+ LMModelInt in eval mode (float64 params), on `device`.
+ """
+ if hasattr(model_or_name, 'name'):
+ model_name = model_or_name.name
+ if n_q is None:
+ n_q = _desired_nq_from_model(model_or_name)
+ if card is None:
+ try:
+ card = model_or_name.quantizer.bins
+ except Exception:
+ card = 1024
+ else:
+ # string path
+ model_name = str(model_or_name)
+ n_q = n_q or int(os.getenv('ENCODEC_LM_NQ', '32'))
+ card = card or 1024
+
+ # safety
+ if n_q is None:
+ n_q = 32
+ if card is None:
+ card = 1024
+
+ # Build the model skeleton with the *desired* number of heads.
+ lm = LMModelInt(n_q=n_q, card=card, num_layers=5, dim=200,
+ past_context=int(3.5 * getattr(getattr(model_or_name, 'frame_rate', 50), '__int__', lambda: 50)()))
+ lm = lm.to(device=device, dtype=torch.float64)
+
+ # Load float-LM checkpoint (layout-compatible)
+ ckpt_name = _checkpoint_name_for(model_name)
+ state = _load_state_dict_from_url_or_env(ckpt_name)
+ ckpt_nq = _infer_ckpt_nq(state)
+
+ # If checkpoint has fewer heads, we will load what exists, then clone the last head.
+ # If it has more, we will load a slice of the heads.
+ # Always load with strict=False so missing/extra per-head params are okay.
+ missing, unexpected = lm.load_state_dict(state, strict=False)
+
+ # If we asked for more heads than the checkpoint provides, synthesize the extra heads
+ if ckpt_nq and n_q > ckpt_nq:
+ with torch.no_grad():
+ # pick a source head to clone (last available)
+ src = ckpt_nq - 1
+ for k in range(ckpt_nq, n_q):
+ lm.emb[k].weight.copy_(lm.emb[src].weight)
+ lm.linears[k].weight.copy_(lm.linears[src].weight)
+ lm.linears[k].bias.copy_(lm.linears[src].bias)
+
+ # Put in eval & float64
+ lm.eval()
+ for p in lm.parameters():
+ p.requires_grad_(False)
+ p.data = p.data.to(dtype=torch.float64, device=device)
+
+ # Helpful log (printed only if env set)
+ if os.getenv('ENCODEC_LM_VERBOSE', ''):
+ print(f"[lm_integer] model={model_name} desired_n_q={n_q} ckpt_heads={ckpt_nq} "
+ f"missing={len(missing)} unexpected={len(unexpected)}")
+
+ return lm
+
diff --git a/encodec/model.py b/encodec/model.py
index 8914e79..7d207e9 100644
--- a/encodec/model.py
+++ b/encodec/model.py
@@ -34,36 +34,27 @@ class LMModel(nn.Module):
dim (int): transformer dimension.
**kwargs: passed to `encodec.modules.transformer.StreamingTransformerEncoder`.
"""
+class LMModel(nn.Module):
def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
super().__init__()
self.card = card
self.n_q = n_q
self.dim = dim
- self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs)
- self.emb = nn.ModuleList([nn.Embedding(card + 1, dim) for _ in range(n_q)])
- self.linears = nn.ModuleList([nn.Linear(dim, card) for _ in range(n_q)])
+ self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs).to(torch.float64)
+ self.emb = nn.ModuleList([nn.Embedding(card + 1, dim, dtype=torch.float64) for _ in range(n_q)])
+ self.linears = nn.ModuleList([nn.Linear(dim, card, dtype=torch.float64) for _ in range(n_q)])
+ self.logit_step = 1.0 / 8.0
+ self.tau = 2.0
def forward(self, indices: torch.Tensor,
states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
- """
- Args:
- indices (torch.Tensor): indices from the previous time step. Indices
- should be 1 + actual index in the codebook. The value 0 is reserved for
- when the index is missing (i.e. first time step). Shape should be
- `[B, n_q, T]`.
- states: state for the streaming decoding.
- offset: offset of the current time step.
-
- Returns a 3-tuple `(probabilities, new_states, new_offset)` with probabilities
- with a shape `[B, card, n_q, T]`.
-
- """
B, K, T = indices.shape
input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
out, states, offset = self.transformer(input_, states, offset)
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
- return torch.softmax(logits, dim=1), states, offset
-
+ logits = torch.round(logits / self.logit_step) * self.logit_step # quantize on f64
+ probas = torch.softmax(logits / self.tau, dim=1)
+ return probas, states, offset
class EncodecModel(nn.Module):
"""EnCodec model operating on the raw waveform.
diff --git a/encodec/quantization/ac.py b/encodec/quantization/ac.py
index f0f3e5d..2627187 100644
--- a/encodec/quantization/ac.py
+++ b/encodec/quantization/ac.py
@@ -14,44 +14,45 @@
from ..binary import BitPacker, BitUnpacker
-
+# encodec/quantization/ac.py — build_stable_quantized_cdf
def build_stable_quantized_cdf(pdf: torch.Tensor, total_range_bits: int,
- roundoff: float = 1e-8, min_range: int = 2,
+ fp_scale: int = 1 << 16, min_range: int = 2,
check: bool = True) -> torch.Tensor:
- """Turn the given PDF into a quantized CDF that splits
- [0, 2 ** self.total_range_bits - 1] into chunks of size roughly proportional
- to the PDF.
+ pdf = pdf.detach().to(torch.float64).clamp_min(0)
+ s = pdf.sum()
+ if not torch.isfinite(s) or s <= 0:
+ pdf = torch.ones_like(pdf)
+ s = pdf.sum()
+
+ # --- key change: avoid round-to-nearest; floor in fp64 then distribute remainder deterministically
+ num = torch.floor(pdf * fp_scale).to(torch.int64)
+ if num.sum() <= 0:
+ num = torch.ones_like(num)
+
+ total = 1 << total_range_bits
+ n = int(num.numel())
+ alloc = total - min_range * n
+ num_sum = num.sum()
+
+ base = (alloc * num) // num_sum
+ remainder = int(alloc - int(base.sum().item()))
+ if remainder > 0:
+ idx = torch.arange(n, device=num.device, dtype=torch.int64)
+ prio = (alloc * num) - (num_sum * base)
+ key = prio * (n + 1) - idx # deterministic tie-breaker
+ _, order = torch.sort(key, descending=True)
+ take = order[:remainder]
+ base[take] += 1
+
+ ranges = base + min_range
+ cdf = torch.cumsum(ranges, dim=-1, dtype=torch.int64)
- Args:
- pdf (torch.Tensor): probability distribution, shape should be `[N]`.
- total_range_bits (int): see `ArithmeticCoder`, the typical range we expect
- during the coding process is `[0, 2 ** total_range_bits - 1]`.
- roundoff (float): will round the pdf up to that level to remove difference coming
- from e.g. evaluating the Language Model on different architectures.
- min_range (int): minimum range width. Should always be at least 2 for numerical
- stability. Use this to avoid pathological behavior is a value
- that is expected to be rare actually happens in real life.
- check (bool): if True, checks that nothing bad happened, can be deactivated for speed.
- """
- pdf = pdf.detach()
- if roundoff:
- pdf = (pdf / roundoff).floor() * roundoff
- # interpolate with uniform distribution to achieve desired minimum probability.
- total_range = 2 ** total_range_bits
- cardinality = len(pdf)
- alpha = min_range * cardinality / total_range
- assert alpha <= 1, "you must reduce min_range"
- ranges = (((1 - alpha) * total_range) * pdf).floor().long()
- ranges += min_range
- quantized_cdf = torch.cumsum(ranges, dim=-1)
- if min_range < 2:
- raise ValueError("min_range must be at least 2.")
if check:
- assert quantized_cdf[-1] <= 2 ** total_range_bits, quantized_cdf[-1]
- if ((quantized_cdf[1:] - quantized_cdf[:-1]) < min_range).any() or quantized_cdf[0] < min_range:
- raise ValueError("You must increase your total_range_bits.")
- return quantized_cdf
-
+ if int(cdf[-1].item()) != total:
+ raise ValueError("cdf sum mismatch")
+ if (ranges < min_range).any():
+ raise ValueError("min_range violated")
+ return cdf
class ArithmeticCoder:
"""ArithmeticCoder,
@@ -137,18 +138,18 @@ def push(self, symbol: int, quantized_cdf: torch.Tensor):
to build this from your pdf estimate.
"""
while self.delta < 2 ** self.total_range_bits:
- self.low *= 2
- self.high = self.high * 2 + 1
+ self.low <<= 1
+ self.high = (self.high << 1) | 1
self.max_bit += 1
-
- range_low = 0 if symbol == 0 else quantized_cdf[symbol - 1].item()
- range_high = quantized_cdf[symbol].item() - 1
- effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
- effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
- assert self.low <= self.high
- self.high = self.low + effective_high
- self.low = self.low + effective_low
- assert self.low <= self.high, (effective_low, effective_high, range_low, range_high)
+ total = 1 << self.total_range_bits
+ rng = self.delta
+ cum_low = 0 if symbol == 0 else int(quantized_cdf[symbol - 1].item())
+ cum_high = int(quantized_cdf[symbol].item())
+ base = self.low
+ new_low = base + (rng * cum_low) // total
+ new_high = base + (rng * cum_high) // total - 1
+ self.low = new_low
+ self.high = new_high
self._dbg.append((self.low, self.high))
self._dbg2.append((self.low, self.high))
outs = self._flush_common_prefix()
@@ -219,7 +220,7 @@ def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
This returns `None` when the stream has been exhausted.
Args:
- quantized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
+ quant[48;48;201;1632;2814tized_cdf (torch.Tensor): use `build_stable_quantized_cdf`
to build this from your pdf estimate. This must be **exatly**
the same cdf as the one used at encoding time.
"""
@@ -227,37 +228,26 @@ def pull(self, quantized_cdf: torch.Tensor) -> tp.Optional[int]:
bit = self.unpacker.pull()
if bit is None:
return None
- self.low *= 2
- self.high = self.high * 2 + 1
- self.current = self.current * 2 + bit
+ self.low = self.low << 1
+ self.high = (self.high << 1) | 1
+ self.current = (self.current << 1) | bit
self.max_bit += 1
- def bin_search(low_idx: int, high_idx: int):
- # Binary search is not just for coding interviews :)
- if high_idx < low_idx:
- raise RuntimeError("Binary search failed")
- mid = (low_idx + high_idx) // 2
- range_low = quantized_cdf[mid - 1].item() if mid > 0 else 0
- range_high = quantized_cdf[mid].item() - 1
- effective_low = int(math.ceil(range_low * (self.delta / (2 ** self.total_range_bits))))
- effective_high = int(math.floor(range_high * (self.delta / (2 ** self.total_range_bits))))
- low = effective_low + self.low
- high = effective_high + self.low
- if self.current >= low:
- if self.current <= high:
- return (mid, low, high, self.current)
- else:
- return bin_search(mid + 1, high_idx)
- else:
- return bin_search(low_idx, mid - 1)
-
- self._last = (self.low, self.high, self.current, self.max_bit)
- sym, self.low, self.high, self.current = bin_search(0, len(quantized_cdf) - 1)
+ total = 1 << self.total_range_bits
+ rng = self.delta
+ target = ((self.current - self.low + 1) * total - 1) // rng
+ t = torch.tensor(target, dtype=quantized_cdf.dtype, device=quantized_cdf.device)
+ s = torch.searchsorted(quantized_cdf, t, right=True).item()
+ cum_low = 0 if s == 0 else int(quantized_cdf[s - 1].item())
+ cum_high = int(quantized_cdf[s].item())
+ base = self.low
+ self.low = base + (rng * cum_low) // total
+ self.high = base + (rng * cum_high) // total - 1
self._dbg.append((self.low, self.high, self.current))
self._flush_common_prefix()
self._dbg2.append((self.low, self.high, self.current))
- return sym
+ return s
def test():
From f0267cfd4f90669ee6bc98b3e60b28cad49d4a15 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Mon, 8 Sep 2025 01:01:16 +0100
Subject: [PATCH 04/24] quantisation changes
---
encodec/compress.py | 99 +++++++++------------
encodec/lm_integer.py | 201 ------------------------------------------
encodec/model.py | 44 +++++----
3 files changed, 63 insertions(+), 281 deletions(-)
delete mode 100644 encodec/lm_integer.py
diff --git a/encodec/compress.py b/encodec/compress.py
index 7db08e6..205b66a 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -102,29 +102,21 @@ def _deterministic_cdf(pdf: torch.Tensor,
def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
use_lm: bool = True):
- """
- Compress a waveform to a file-object using the given model.
- Deterministic path is enforced unconditionally (no metadata flags).
- """
- assert wav.dim() == 2, "Expected [C, T]."
+ assert wav.dim() == 2
if model.name not in MODELS:
raise ValueError(f"Unsupported model {model.name}.")
- device = wav.device
+ coder_device = torch.device("cpu")
- # Encode once to know frames and K
with torch.no_grad():
frames = model.encode(wav[None])
codes0, _ = frames[0]
_, K, _ = codes0.shape
- # Language model (float64), but logits quantized before softmax
lm = None
if use_lm:
- lm = model.get_lm_model().to(dtype=torch.float64, device=device)
- lm.eval()
+ lm = model.get_lm_model().to(dtype=torch.float64, device=coder_device).eval()
- # Minimal, unchanged metadata set (no new flags)
metadata = {
'm': model.name,
'al': int(wav.shape[-1]),
@@ -135,7 +127,6 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
}
binary.write_ecdc_header(fo, metadata)
- # Bitstream
for (frame, scale) in frames:
if scale is not None:
fo.write(struct.pack('!f', float(scale.cpu().item())))
@@ -143,43 +134,36 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
_B, _K, T = frame.shape
if use_lm:
coder = ArithmeticCoder(fo)
- states: tp.Any = None
+ states = None
offset = 0
- input_ = torch.zeros(1, K, 1, dtype=torch.long, device=device)
+ input_ = torch.zeros(1, K, 1, dtype=torch.long, device=coder_device)
else:
packer = binary.BitPacker(model.bits_per_codebook, fo)
for t in range(T):
if use_lm:
with torch.no_grad():
- probas_raw, states, offset = lm(input_, states, offset) # [1, card, K, 1]
- # Quantize logits (rebuild from probs with log if needed)
- # Safer: pass-through by reverse softmax → logits, quantize, softmax
- # But we only get probas. So enforce quantization by re-logit with clip.
- # To avoid log(0), clamp and then re-softmax.
- p = torch.clamp(probas_raw, min=1e-12)
- logits = torch.log(p)
- logits_q = _quantize_logits_(logits, LOGIT_QSTEP)
- probas = _stable_softmax(logits_q, dim=1)
- for k, value in enumerate(frame[0, :, t].tolist()):
- if use_lm:
+ logits_raw, states, offset = lm.forward_logits(input_, states, offset)
+ logits_q = _quantize_logits_(logits_raw, LOGIT_QSTEP)
+ probas = _stable_softmax(logits_q / lm.tau, dim=1)
+
+ frame_slice = frame[:, :, t: t + 1].detach().to(coder_device)
+ values = frame_slice[0, :, 0].tolist()
+ for k, value in enumerate(values):
q_cdf = _deterministic_cdf(probas[0, :, k, 0], coder.total_range_bits, fp_scale=FP_SCALE, check=False)
coder.push(value, q_cdf)
- else:
+ input_ = 1 + frame_slice
+ else:
+ values = frame[0, :, t].detach().cpu().tolist()
+ for value in values:
packer.push(value)
- if use_lm:
- input_ = 1 + frame[:, :, t: t + 1]
if use_lm:
coder.flush()
else:
packer.flush()
-
def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
- """
- Decompress from a file-object. Deterministic path (matching encoder) is used unconditionally.
- """
metadata = binary.read_ecdc_header(fo)
model_name = metadata['m']
audio_length = int(metadata['al'])
@@ -194,11 +178,12 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
raise ValueError("Unsupported bitstream version; re-encode with this coder.")
model = MODELS[model_name]().to(device)
+ model_device = next(model.parameters()).device
+ coder_device = torch.device("cpu")
lm = None
if use_lm:
- lm = model.get_lm_model().to(dtype=torch.float64, device=device)
- lm.eval()
+ lm = model.get_lm_model().to(dtype=torch.float64, device=coder_device).eval()
frames: tp.List[EncodedFrame] = []
segment_length = model.segment_length or audio_length
@@ -210,51 +195,51 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
if model.normalize:
scale_f, = struct.unpack('!f', binary._read_exactly(fo, struct.calcsize('!f')))
- scale = torch.tensor(scale_f, device=device).view(1)
+ scale = torch.tensor(scale_f, device=coder_device).view(1)
else:
scale = None
if use_lm:
decoder = ArithmeticDecoder(fo)
- states: tp.Any = None
+ states = None
offset = 0
- input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=device)
+ input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=coder_device)
else:
unpacker = binary.BitUnpacker(model.bits_per_codebook, fo)
- frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=device)
+ frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=coder_device)
for t in range(frame_length):
if use_lm:
with torch.no_grad():
- probas_raw, states, offset = lm(input_, states, offset)
- p = torch.clamp(probas_raw, min=1e-12)
- logits = torch.log(p)
- logits_q = _quantize_logits_(logits, LOGIT_QSTEP)
- probas = _stable_softmax(logits_q, dim=1)
-
- code_list: tp.List[int] = []
- for k in range(num_codebooks):
- if use_lm:
+ logits_raw, states, offset = lm.forward_logits(input_, states, offset)
+ logits_q = _quantize_logits_(logits_raw, LOGIT_QSTEP)
+ probas = _stable_softmax(logits_q / lm.tau, dim=1)
+
+ code_list: tp.List[int] = []
+ for k in range(num_codebooks):
q_cdf = _deterministic_cdf(probas[0, :, k, 0], decoder.total_range_bits, fp_scale=fp_scale, check=False)
code = decoder.pull(q_cdf)
- else:
- code = unpacker.pull()
- if code is None:
- raise EOFError("The stream ended sooner than expected.")
- code_list.append(code)
-
- frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=device)
- if use_lm:
+ if code is None:
+ raise EOFError("The stream ended sooner than expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=coder_device)
input_ = 1 + frame[:, :, t: t + 1]
+ else:
+ code_list: tp.List[int] = []
+ for _ in range(num_codebooks):
+ code = unpacker.pull()
+ if code is None:
+ raise EOFError("The stream ended sooner than expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=coder_device)
- frames.append((frame, scale))
+ frames.append((frame.to(model_device), None if scale is None else scale.to(model_device)))
with torch.no_grad():
wav = model.decode(frames)
return wav[0, :, :audio_length], model.sample_rate
-
def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> bytes:
fo = io.BytesIO()
compress_to_file(model, wav, fo, use_lm=use_lm)
diff --git a/encodec/lm_integer.py b/encodec/lm_integer.py
deleted file mode 100644
index 76179b7..0000000
--- a/encodec/lm_integer.py
+++ /dev/null
@@ -1,201 +0,0 @@
-# Copyright (c) Meta Platforms, Inc.
-# All rights reserved.
-
-import os
-import re
-import typing as tp
-
-import torch
-from torch import nn
-
-# ---- Integer-friendly LM (deterministic logits quantization) ----
-
-class LMModelInt(nn.Module):
- """
- Same topology as the float LM in encodec.model.LMModel, but we quantize logits
- to a fixed step for extra determinism across architectures/BLAS backends.
- """
- def __init__(self, n_q: int, card: int = 1024, dim: int = 200, **kwargs):
- super().__init__()
- from .modules import transformer as m
- self.card = card
- self.n_q = n_q
- self.dim = dim
-
- # streaming transformer, kept in float64 for reproducibility
- self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs).to(torch.float64)
-
- # one embedding + one head per codebook
- self.emb = nn.ModuleList([nn.Embedding(card + 1, dim, dtype=torch.float64) for _ in range(n_q)])
- self.linears = nn.ModuleList([nn.Linear(dim, card, dtype=torch.float64) for _ in range(n_q)])
-
- # quantize logits onto a grid to stabilize results
- self.logit_step = 1.0 / 64.0
-
- @torch.no_grad()
- def forward(self, indices: torch.Tensor,
- states: tp.Optional[tp.List[torch.Tensor]] = None,
- offset: int = 0):
- """
- indices: [B, K, T] with K == runtime codebooks used (<= self.n_q)
- Returns:
- probas_or_counts: [B, card, K, T] on float64
- states, offset: streaming state
- """
- B, K, T = indices.shape
- # Sum embeddings for the K active codebooks only.
- x = sum([self.emb[k](indices[:, k]) for k in range(K)]) # [B, T, dim]
- out, states, offset = self.transformer(x, states, offset) # [B, T, dim]
-
- # Project per active codebook
- logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, T, card]
- logits = logits.permute(0, 3, 1, 2).contiguous() # [B, card, K, T]
-
- # integer-like last mile: quantize logits then softmax
- logits = torch.round(logits / self.logit_step) * self.logit_step
- probas = torch.softmax(logits, dim=1) # still float64; AC builder accepts any nonneg vector
- return probas, states, offset
-
-
-# --------- Checkpoint loading (robust to head-count mismatch) ---------
-
-def _infer_ckpt_nq(state: tp.Dict[str, torch.Tensor]) -> int:
- """
- Count how many emb.* / linears.* heads exist in the checkpoint.
- """
- head_idxs = []
- pat = re.compile(r'^(emb|linears)\.(\d+)\.')
- for k in state.keys():
- m = pat.match(k)
- if m:
- head_idxs.append(int(m.group(2)))
- return (max(head_idxs) + 1) if head_idxs else 0
-
-
-def _desired_nq_from_model(model) -> int:
- """
- Compute number of codebooks actually used for the MODEL'S CURRENT bandwidth.
- """
- # Fallbacks if anything is missing
- default = getattr(getattr(model, 'quantizer', None), 'n_q', 32)
-
- try:
- # How many quantizers will be used at current bandwidth?
- q = model.quantizer # RVQ
- fr = model.frame_rate
- bw = getattr(model, 'bandwidth', None)
- if bw is None:
- # If caller didn't set bandwidth yet, we conservatively use the max
- return default
- return q.get_num_quantizers_for_bandwidth(fr, bw)
- except Exception:
- return default
-
-
-def _checkpoint_name_for(model_name: str) -> str:
- """
- Use the *existing* float-LM checkpoint names (they match architecture)
- so you don't need new files.
- """
- # These are the same filenames used by EnCodec's float LM.
- mapping = {
- 'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
- 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
- }
- if model_name not in mapping:
- raise RuntimeError(f"N[48;48;201;1632;2814to LM checkpoint mapping for model '{model_name}'.")
- return mapping[model_name]
-
-
-def _load_state_dict_from_url_or_env(ckpt_name: str):
- """
- If ENCODEC_LM_PATH is set, read from that folder; otherwise use torch.hub URL.
- """
- root = os.environ.get('ENCODEC_LM_PATH', '').strip()
- if root:
- path = os.path.join(root, ckpt_name)
- if not os.path.isfile(path):
- raise FileNotFoundError(f"ENCODEC_LM_PATH set, but file not found: {path}")
- state = torch.load(path, map_location='cpu')
- else:
- from .utils import _get_checkpoint_url
- url = _get_checkpoint_url('https://dl.fbaipublicfiles.com/encodec/v0/', ckpt_name)
- state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True) # type: ignore
- return state
-
-
-def load_pretrained_integer_lm(model_or_name, device='cpu',
- n_q: tp.Optional[int] = None,
- card: tp.Optional[int] = None) -> LMModelInt:
- """
- Build LMModelInt sized to the number of codebooks you actually use,
- then load a compatible checkpoint (even if the checkpoint has fewer heads).
-
- Args:
- model_or_name: EncodecModel instance (preferred), or string model name.
- device: target device.
- n_q: override number of heads; if None, computed from model's bandwidth.
- card: override codebook size; if None, taken from model.quantizer.bins.
-
- Returns:
- LMModelInt in eval mode (float64 params), on `device`.
- """
- if hasattr(model_or_name, 'name'):
- model_name = model_or_name.name
- if n_q is None:
- n_q = _desired_nq_from_model(model_or_name)
- if card is None:
- try:
- card = model_or_name.quantizer.bins
- except Exception:
- card = 1024
- else:
- # string path
- model_name = str(model_or_name)
- n_q = n_q or int(os.getenv('ENCODEC_LM_NQ', '32'))
- card = card or 1024
-
- # safety
- if n_q is None:
- n_q = 32
- if card is None:
- card = 1024
-
- # Build the model skeleton with the *desired* number of heads.
- lm = LMModelInt(n_q=n_q, card=card, num_layers=5, dim=200,
- past_context=int(3.5 * getattr(getattr(model_or_name, 'frame_rate', 50), '__int__', lambda: 50)()))
- lm = lm.to(device=device, dtype=torch.float64)
-
- # Load float-LM checkpoint (layout-compatible)
- ckpt_name = _checkpoint_name_for(model_name)
- state = _load_state_dict_from_url_or_env(ckpt_name)
- ckpt_nq = _infer_ckpt_nq(state)
-
- # If checkpoint has fewer heads, we will load what exists, then clone the last head.
- # If it has more, we will load a slice of the heads.
- # Always load with strict=False so missing/extra per-head params are okay.
- missing, unexpected = lm.load_state_dict(state, strict=False)
-
- # If we asked for more heads than the checkpoint provides, synthesize the extra heads
- if ckpt_nq and n_q > ckpt_nq:
- with torch.no_grad():
- # pick a source head to clone (last available)
- src = ckpt_nq - 1
- for k in range(ckpt_nq, n_q):
- lm.emb[k].weight.copy_(lm.emb[src].weight)
- lm.linears[k].weight.copy_(lm.linears[src].weight)
- lm.linears[k].bias.copy_(lm.linears[src].bias)
-
- # Put in eval & float64
- lm.eval()
- for p in lm.parameters():
- p.requires_grad_(False)
- p.data = p.data.to(dtype=torch.float64, device=device)
-
- # Helpful log (printed only if env set)
- if os.getenv('ENCODEC_LM_VERBOSE', ''):
- print(f"[lm_integer] model={model_name} desired_n_q={n_q} ckpt_heads={ckpt_nq} "
- f"missing={len(missing)} unexpected={len(unexpected)}")
-
- return lm
-
diff --git a/encodec/model.py b/encodec/model.py
index 7d207e9..14b1857 100644
--- a/encodec/model.py
+++ b/encodec/model.py
@@ -35,24 +35,30 @@ class LMModel(nn.Module):
**kwargs: passed to `encodec.modules.transformer.StreamingTransformerEncoder`.
"""
class LMModel(nn.Module):
- def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, **kwargs):
+ def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, dtype=torch.float64, **kwargs):
super().__init__()
self.card = card
self.n_q = n_q
self.dim = dim
- self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs).to(torch.float64)
- self.emb = nn.ModuleList([nn.Embedding(card + 1, dim, dtype=torch.float64) for _ in range(n_q)])
- self.linears = nn.ModuleList([nn.Linear(dim, card, dtype=torch.float64) for _ in range(n_q)])
- self.logit_step = 1.0 / 8.0
+ self.dtype = dtype
+ self.transformer = m.StreamingTransformerEncoder(dim=dim, **kwargs).to(dtype)
+ self.emb = nn.ModuleList([nn.Embedding(card + 1, dim, dtype=dtype) for _ in range(n_q)])
+ self.linears = nn.ModuleList([nn.Linear(dim, card, dtype=dtype) for _ in range(n_q)])
+ self.logit_step = 1.0 / 64.0
self.tau = 2.0
- def forward(self, indices: torch.Tensor,
- states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
+ def forward_logits(self, indices: torch.Tensor,
+ states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
B, K, T = indices.shape
input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
out, states, offset = self.transformer(input_, states, offset)
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
- logits = torch.round(logits / self.logit_step) * self.logit_step # quantize on f64
+ return logits, states, offset
+
+ def forward(self, indices: torch.Tensor,
+ states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
+ logits, states, offset = self.forward_logits(indices, states, offset)
+ logits = torch.round(logits / self.logit_step) * self.logit_step
probas = torch.softmax(logits / self.tau, dim=1)
return probas, states, offset
@@ -187,23 +193,15 @@ def set_target_bandwidth(self, bandwidth: float):
f"Select one of {self.target_bandwidths}.")
self.bandwidth = bandwidth
- def get_lm_model(self) -> LMModel:
- """Return the associated LM model to improve the compression rate.
- """
- device = next(self.parameters()).device
+ def get_lm_model(self, int8: bool = False) -> LMModel:
+ device = torch.device("cpu")
lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
- past_context=int(3.5 * self.frame_rate)).to(device)
- checkpoints = {
- 'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
- 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
- }
- try:
- checkpoint_name = checkpoints[self.name]
- except KeyError:
- raise RuntimeError("No LM pre-trained for the current Encodec model.")
+ past_context=int(3.5 * self.frame_rate), dtype=torch.float64).to(device)
+ checkpoints = {'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
+ 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th'}
+ checkpoint_name = checkpoints[self.name]
url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
- state = torch.hub.load_state_dict_from_url(
- url, map_location='cpu', check_hash=True) # type: ignore
+ state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
lm.load_state_dict(state)
lm.eval()
return lm
From 267e5a8e07546ad7b3784012be15034d3004e19b Mon Sep 17 00:00:00 2001
From: jbrough
Date: Mon, 8 Sep 2025 01:45:04 +0100
Subject: [PATCH 05/24] segment boundaries
---
encodec/compress.py | 57 +++++++++++++++++----------------------------
1 file changed, 21 insertions(+), 36 deletions(-)
diff --git a/encodec/compress.py b/encodec/compress.py
index 205b66a..6a383e3 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -105,62 +105,55 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
assert wav.dim() == 2
if model.name not in MODELS:
raise ValueError(f"Unsupported model {model.name}.")
-
coder_device = torch.device("cpu")
-
with torch.no_grad():
frames = model.encode(wav[None])
codes0, _ = frames[0]
_, K, _ = codes0.shape
-
lm = None
if use_lm:
lm = model.get_lm_model().to(dtype=torch.float64, device=coder_device).eval()
-
metadata = {
'm': model.name,
'al': int(wav.shape[-1]),
'nc': int(K),
'lm': bool(use_lm),
'fp': int(FP_SCALE),
- 'acv': 3,
+ 'acv': 4,
}
binary.write_ecdc_header(fo, metadata)
-
for (frame, scale) in frames:
if scale is not None:
fo.write(struct.pack('!f', float(scale.cpu().item())))
-
_B, _K, T = frame.shape
+ fo.write(struct.pack('!I', T))
if use_lm:
- coder = ArithmeticCoder(fo)
+ seg_buf = io.BytesIO()
+ coder = ArithmeticCoder(seg_buf)
states = None
offset = 0
input_ = torch.zeros(1, K, 1, dtype=torch.long, device=coder_device)
- else:
- packer = binary.BitPacker(model.bits_per_codebook, fo)
-
- for t in range(T):
- if use_lm:
+ for t in range(T):
with torch.no_grad():
logits_raw, states, offset = lm.forward_logits(input_, states, offset)
logits_q = _quantize_logits_(logits_raw, LOGIT_QSTEP)
probas = _stable_softmax(logits_q / lm.tau, dim=1)
-
frame_slice = frame[:, :, t: t + 1].detach().to(coder_device)
values = frame_slice[0, :, 0].tolist()
for k, value in enumerate(values):
q_cdf = _deterministic_cdf(probas[0, :, k, 0], coder.total_range_bits, fp_scale=FP_SCALE, check=False)
coder.push(value, q_cdf)
input_ = 1 + frame_slice
- else:
+ coder.flush()
+ seg_bytes = seg_buf.getvalue()
+ fo.write(struct.pack('!I', len(seg_bytes)))
+ fo.write(seg_bytes)
+ else:
+ packer = binary.BitPacker(model.bits_per_codebook, fo)
+ for t in range(T):
values = frame[0, :, t].detach().cpu().tolist()
for value in values:
packer.push(value)
-
- if use_lm:
- coder.flush()
- else:
packer.flush()
def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
@@ -171,51 +164,45 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
use_lm = bool(metadata['lm'])
fp_scale = int(metadata.get('fp', FP_SCALE))
acv = int(metadata.get('acv', 0))
-
if model_name not in MODELS:
raise ValueError(f"Unsupported model {model_name}.")
- if acv != 3:
+ if acv != 4:
raise ValueError("Unsupported bitstream version; re-encode with this coder.")
-
model = MODELS[model_name]().to(device)
model_device = next(model.parameters()).device
coder_device = torch.device("cpu")
-
lm = None
if use_lm:
lm = model.get_lm_model().to(dtype=torch.float64, device=coder_device).eval()
-
frames: tp.List[EncodedFrame] = []
segment_length = model.segment_length or audio_length
segment_stride = model.segment_stride or audio_length
-
for offset_samples in range(0, audio_length, segment_stride):
this_len = min(audio_length - offset_samples, segment_length)
- frame_length = int(math.ceil(this_len * model.frame_rate / model.sample_rate))
-
if model.normalize:
scale_f, = struct.unpack('!f', binary._read_exactly(fo, struct.calcsize('!f')))
- scale = torch.tensor(scale_f, device=coder_device).view(1)
+ scale = torch.tensor(scale_f, device=model_device).view(1)
else:
scale = None
-
+ frame_length_bytes = binary._read_exactly(fo, 4)
+ frame_length = struct.unpack('!I', frame_length_bytes)[0]
if use_lm:
- decoder = ArithmeticDecoder(fo)
+ seg_len_bytes = binary._read_exactly(fo, 4)
+ seg_len = struct.unpack('!I', seg_len_bytes)[0]
+ seg_payload = io.BytesIO(binary._read_exactly(fo, seg_len))
+ decoder = ArithmeticDecoder(seg_payload)
states = None
offset = 0
input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=coder_device)
else:
unpacker = binary.BitUnpacker(model.bits_per_codebook, fo)
-
frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=coder_device)
-
for t in range(frame_length):
if use_lm:
with torch.no_grad():
logits_raw, states, offset = lm.forward_logits(input_, states, offset)
logits_q = _quantize_logits_(logits_raw, LOGIT_QSTEP)
probas = _stable_softmax(logits_q / lm.tau, dim=1)
-
code_list: tp.List[int] = []
for k in range(num_codebooks):
q_cdf = _deterministic_cdf(probas[0, :, k, 0], decoder.total_range_bits, fp_scale=fp_scale, check=False)
@@ -233,9 +220,7 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
raise EOFError("The stream ended sooner than expected.")
code_list.append(code)
frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=coder_device)
-
- frames.append((frame.to(model_device), None if scale is None else scale.to(model_device)))
-
+ frames.append((frame.to(model_device), scale))
with torch.no_grad():
wav = model.decode(frames)
return wav[0, :, :audio_length], model.sample_rate
From f05261d4923477130946a6b1c42be7646b72414a Mon Sep 17 00:00:00 2001
From: jbrough
Date: Mon, 8 Sep 2025 01:47:37 +0100
Subject: [PATCH 06/24] Revert "segment boundaries"
This reverts commit 267e5a8e07546ad7b3784012be15034d3004e19b.
---
encodec/compress.py | 57 ++++++++++++++++++++++++++++-----------------
1 file changed, 36 insertions(+), 21 deletions(-)
diff --git a/encodec/compress.py b/encodec/compress.py
index 6a383e3..205b66a 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -105,55 +105,62 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
assert wav.dim() == 2
if model.name not in MODELS:
raise ValueError(f"Unsupported model {model.name}.")
+
coder_device = torch.device("cpu")
+
with torch.no_grad():
frames = model.encode(wav[None])
codes0, _ = frames[0]
_, K, _ = codes0.shape
+
lm = None
if use_lm:
lm = model.get_lm_model().to(dtype=torch.float64, device=coder_device).eval()
+
metadata = {
'm': model.name,
'al': int(wav.shape[-1]),
'nc': int(K),
'lm': bool(use_lm),
'fp': int(FP_SCALE),
- 'acv': 4,
+ 'acv': 3,
}
binary.write_ecdc_header(fo, metadata)
+
for (frame, scale) in frames:
if scale is not None:
fo.write(struct.pack('!f', float(scale.cpu().item())))
+
_B, _K, T = frame.shape
- fo.write(struct.pack('!I', T))
if use_lm:
- seg_buf = io.BytesIO()
- coder = ArithmeticCoder(seg_buf)
+ coder = ArithmeticCoder(fo)
states = None
offset = 0
input_ = torch.zeros(1, K, 1, dtype=torch.long, device=coder_device)
- for t in range(T):
+ else:
+ packer = binary.BitPacker(model.bits_per_codebook, fo)
+
+ for t in range(T):
+ if use_lm:
with torch.no_grad():
logits_raw, states, offset = lm.forward_logits(input_, states, offset)
logits_q = _quantize_logits_(logits_raw, LOGIT_QSTEP)
probas = _stable_softmax(logits_q / lm.tau, dim=1)
+
frame_slice = frame[:, :, t: t + 1].detach().to(coder_device)
values = frame_slice[0, :, 0].tolist()
for k, value in enumerate(values):
q_cdf = _deterministic_cdf(probas[0, :, k, 0], coder.total_range_bits, fp_scale=FP_SCALE, check=False)
coder.push(value, q_cdf)
input_ = 1 + frame_slice
- coder.flush()
- seg_bytes = seg_buf.getvalue()
- fo.write(struct.pack('!I', len(seg_bytes)))
- fo.write(seg_bytes)
- else:
- packer = binary.BitPacker(model.bits_per_codebook, fo)
- for t in range(T):
+ else:
values = frame[0, :, t].detach().cpu().tolist()
for value in values:
packer.push(value)
+
+ if use_lm:
+ coder.flush()
+ else:
packer.flush()
def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
@@ -164,45 +171,51 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
use_lm = bool(metadata['lm'])
fp_scale = int(metadata.get('fp', FP_SCALE))
acv = int(metadata.get('acv', 0))
+
if model_name not in MODELS:
raise ValueError(f"Unsupported model {model_name}.")
- if acv != 4:
+ if acv != 3:
raise ValueError("Unsupported bitstream version; re-encode with this coder.")
+
model = MODELS[model_name]().to(device)
model_device = next(model.parameters()).device
coder_device = torch.device("cpu")
+
lm = None
if use_lm:
lm = model.get_lm_model().to(dtype=torch.float64, device=coder_device).eval()
+
frames: tp.List[EncodedFrame] = []
segment_length = model.segment_length or audio_length
segment_stride = model.segment_stride or audio_length
+
for offset_samples in range(0, audio_length, segment_stride):
this_len = min(audio_length - offset_samples, segment_length)
+ frame_length = int(math.ceil(this_len * model.frame_rate / model.sample_rate))
+
if model.normalize:
scale_f, = struct.unpack('!f', binary._read_exactly(fo, struct.calcsize('!f')))
- scale = torch.tensor(scale_f, device=model_device).view(1)
+ scale = torch.tensor(scale_f, device=coder_device).view(1)
else:
scale = None
- frame_length_bytes = binary._read_exactly(fo, 4)
- frame_length = struct.unpack('!I', frame_length_bytes)[0]
+
if use_lm:
- seg_len_bytes = binary._read_exactly(fo, 4)
- seg_len = struct.unpack('!I', seg_len_bytes)[0]
- seg_payload = io.BytesIO(binary._read_exactly(fo, seg_len))
- decoder = ArithmeticDecoder(seg_payload)
+ decoder = ArithmeticDecoder(fo)
states = None
offset = 0
input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=coder_device)
else:
unpacker = binary.BitUnpacker(model.bits_per_codebook, fo)
+
frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=coder_device)
+
for t in range(frame_length):
if use_lm:
with torch.no_grad():
logits_raw, states, offset = lm.forward_logits(input_, states, offset)
logits_q = _quantize_logits_(logits_raw, LOGIT_QSTEP)
probas = _stable_softmax(logits_q / lm.tau, dim=1)
+
code_list: tp.List[int] = []
for k in range(num_codebooks):
q_cdf = _deterministic_cdf(probas[0, :, k, 0], decoder.total_range_bits, fp_scale=fp_scale, check=False)
@@ -220,7 +233,9 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
raise EOFError("The stream ended sooner than expected.")
code_list.append(code)
frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=coder_device)
- frames.append((frame.to(model_device), scale))
+
+ frames.append((frame.to(model_device), None if scale is None else scale.to(model_device)))
+
with torch.no_grad():
wav = model.decode(frames)
return wav[0, :, :audio_length], model.sample_rate
From a485dc672f2c2f93068e28bdd88ae807cb4a4abe Mon Sep 17 00:00:00 2001
From: jbrough
Date: Mon, 8 Sep 2025 15:44:21 +0100
Subject: [PATCH 07/24] quantisation improvements
---
encodec/compress.py | 71 +++++++++++++++++++++++++++++----------------
1 file changed, 46 insertions(+), 25 deletions(-)
diff --git a/encodec/compress.py b/encodec/compress.py
index 205b66a..c1d60e3 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -29,29 +29,51 @@
# - LOGIT_QSTEP: coarse enough to suppress tiny arch drift, fine enough to preserve coding gain
# - FP_SCALE: count scale used before integer range allocation inside the CDF
LOGIT_QSTEP = 1.0 / 64.0
-FP_SCALE = 1 << 14 # 16384; lower than 1<<16 for better cross-arch stability
-ROUND_CDF = 1e-4 # unused in this deterministic path, kept for signature parity
-MIN_RANGE = 2 # min bin width for arithmetic coder
-
+FP_SCALE = 1 << 11
+MIN_RANGE = 6 # min bin width for arithmetic coder
+
+def _counts_from_pdf(pdf: torch.Tensor, fp_scale: int) -> torch.Tensor:
+ x = (pdf.detach().to(torch.float64).clamp_min(0) * fp_scale)
+ fx = torch.floor(x)
+ frac = x - fx
+ eps_edge = math.ldexp(1.0, -40)
+ m = (frac <= eps_edge) | (frac >= 1 - eps_edge)
+ if bool(m.any()):
+ idx = torch.arange(x.numel(), device=x.device, dtype=torch.int64).view_as(x)
+ sign = (idx.fmod(2) * 2 - 1).to(torch.float64)
+ eps = math.ldexp(1.0, -60)
+ x = torch.where(m, x + sign * eps, x)
+ fx = torch.floor(x)
+ return fx.to(torch.int64)
def _quantize_logits_(logits: torch.Tensor, step: float = LOGIT_QSTEP) -> torch.Tensor:
- # In-place-ish quantization without breaking autograd (we're in no_grad anyway).
- return torch.round(logits / step) * step
-
+ y = (logits / step).to(torch.float64)
+ eps = math.ldexp(1.0, -40)
+ q = torch.floor(y + 0.5 - eps)
+ return q * step
+
+def _softmax_or_uniform(x: torch.Tensor, dim: int) -> torch.Tensor:
+ s = _stable_softmax(x, dim)
+ span_logit = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)
+ near_logit = span_logit <= (2 * LOGIT_QSTEP)
+ span_pdf = torch.amax(s, dim=dim, keepdim=True) - torch.amin(s, dim=dim, keepdim=True)
+ near_pdf = span_pdf <= (0.25 / float(FP_SCALE))
+ near = near_logit | near_pdf
+ if not bool(near.any()):
+ return s
+ k = x.size(dim)
+ u = torch.full_like(s, 1.0 / k, dtype=torch.float64)
+ return torch.where(near, u, s)
def _stable_softmax(logits: torch.Tensor, dim: int) -> torch.Tensor:
- # f64 softmax with explicit max subtraction for numerical stability
- m = torch.amax(logits, dim=dim, keepdim=True)
- z = torch.exp((logits - m).to(torch.float64))
- s = torch.sum(z, dim=dim, keepdim=True)
- # safeguard in case of weird NaNs/Inf
- bad = ~torch.isfinite(s) | (s <= 0)
- if bad.any():
- # replace by uniform
- z = torch.ones_like(z, dtype=torch.float64)
- s = torch.sum(z, dim=dim, keepdim=True)
- return z / s
-
+ x = (logits - torch.amax(logits, dim=dim, keepdim=True)).to(torch.float64)
+ z = torch.exp(x)
+ z = z.movedim(dim, -1).contiguous()
+ acc = z[..., 0].clone()
+ for i in range(1, z.size(-1)):
+ acc += z[..., i]
+ out = (z / acc.unsqueeze(-1)).movedim(-1, dim)
+ return out
def _deterministic_cdf(pdf: torch.Tensor,
total_range_bits: int,
@@ -71,7 +93,7 @@ def _deterministic_cdf(pdf: torch.Tensor,
pdf = torch.ones_like(pdf)
s = pdf.sum()
- num = torch.floor(pdf * fp_scale).to(torch.int64)
+ num = _counts_from_pdf(pdf, fp_scale).to(torch.int64)
if int(num.sum().item()) <= 0:
num = torch.ones_like(num)
@@ -144,8 +166,8 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
if use_lm:
with torch.no_grad():
logits_raw, states, offset = lm.forward_logits(input_, states, offset)
- logits_q = _quantize_logits_(logits_raw, LOGIT_QSTEP)
- probas = _stable_softmax(logits_q / lm.tau, dim=1)
+ logits_q = _quantize_logits_(logits_raw / lm.tau, LOGIT_QSTEP)
+ probas = _softmax_or_uniform(logits_q, dim=1)
frame_slice = frame[:, :, t: t + 1].detach().to(coder_device)
values = frame_slice[0, :, 0].tolist()
@@ -213,8 +235,8 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
if use_lm:
with torch.no_grad():
logits_raw, states, offset = lm.forward_logits(input_, states, offset)
- logits_q = _quantize_logits_(logits_raw, LOGIT_QSTEP)
- probas = _stable_softmax(logits_q / lm.tau, dim=1)
+ logits_q = _quantize_logits_(logits_raw / lm.tau, LOGIT_QSTEP)
+ probas = _softmax_or_uniform(logits_q, dim=1)
code_list: tp.List[int] = []
for k in range(num_codebooks):
@@ -249,4 +271,3 @@ def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> by
def decompress(compressed: bytes, device='cpu') -> tp.Tuple[torch.Tensor, int]:
fo = io.BytesIO(compressed)
return decompress_from_file(fo, device=device)
-
From 708087d8dd1e12216f281d353e709a94ec12ad82 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Mon, 8 Sep 2025 21:27:17 +0100
Subject: [PATCH 08/24] reinstate fb comments
---
encodec/compress.py | 73 +++++++++++++++++++++++++++++++++++++++++++--
1 file changed, 71 insertions(+), 2 deletions(-)
diff --git a/encodec/compress.py b/encodec/compress.py
index c1d60e3..21d560d 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -1,5 +1,10 @@
-# encodec/compress.py
-# Deterministic coder: architecture-stable CDF construction + logit quantization.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""API to compress/decompress audio to bytestreams."""
import io
import math
@@ -124,6 +129,19 @@ def _deterministic_cdf(pdf: torch.Tensor,
def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
use_lm: bool = True):
+ """Compress a waveform to a file-object using the given model.
+
+ Args:
+ model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
+ wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
+ matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
+ Use `utils.convert_audio` if this is not the case.
+ fo (IO[bytes]): file-object to which the compressed bits will be written.
+ See `compress` if you want obtain a `bytes` object instead.
+ use_lm (bool): if True, use a pre-trained language model to further
+ compress the stream using Entropy Coding. This will slow down compression
+ quite a bit, expect between 20 to 30% of size reduction.
+ """
assert wav.dim() == 2
if model.name not in MODELS:
raise ValueError(f"Unsupported model {model.name}.")
@@ -186,6 +204,14 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
packer.flush()
def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
+ """Decompress from a file-object.
+ Returns a tuple `(wav, sample_rate)`.
+
+ Args:
+ fo (IO[bytes]): file-object from which to read. If you want to decompress
+ from `bytes` instead, see `decompress`.
+ device: device to use to perform the computations.
+ """
metadata = binary.read_ecdc_header(fo)
model_name = metadata['m']
audio_length = int(metadata['al'])
@@ -263,11 +289,54 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
return wav[0, :, :audio_length], model.sample_rate
def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> bytes:
+ """Compress a waveform using the given model. Returns the compressed bytes.
+
+ Args:
+ model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
+ wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
+ matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
+ Use `utils.convert_audio` if this is not the case.
+ use_lm (bool): if True, use a pre-trained language model to further
+ compress the stream using Entropy Coding. This will slow down compression
+ quite a bit, expect between 20 to 30% of size reduction.
+ """
fo = io.BytesIO()
compress_to_file(model, wav, fo, use_lm=use_lm)
return fo.getvalue()
def decompress(compressed: bytes, device='cpu') -> tp.Tuple[torch.Tensor, int]:
+ """Decompress from a file-object.
+ Returns a tuple `(wav, sample_rate)`.
+
+ Args:
+ compressed (bytes): compressed bytes.
+ device: device to use to perform the computations.
+ """
fo = io.BytesIO(compressed)
return decompress_from_file(fo, device=device)
+
+def test():
+ import torchaudio
+ torch.set_num_threads(1)
+ for name in MODELS.keys():
+ model = MODELS[name]()
+ sr = model.sample_rate // 1000
+ x, _ = torchaudio.load(f'test_{sr}k.wav')
+ x = x[:, :model.sample_rate * 5]
+ model.set_target_bandwidth(12)
+ for use_lm in [False, True]:
+ print(f"Doing {name}, use_lm={use_lm}")
+ begin = time.time()
+ res = compress(model, x, use_lm=use_lm)
+ t_comp = time.time() - begin
+ x_dec, _ = decompress(res)
+ t_decomp = time.time() - begin - t_comp
+ kbps = 8 * len(res) / 1000 / (x.shape[-1] / model.sample_rate)
+ print(f"kbps: {kbps:.1f}, time comp: {t_comp:.1f} sec. "
+ f"time decomp:{t_decomp:.1f}.")
+ assert x_dec.shape == x.shape
+
+
+if __name__ == '__main__':
+ test()
From 2a45ecba7b756b40ab1774bc730db1547e44656d Mon Sep 17 00:00:00 2001
From: jbrough
Date: Tue, 9 Sep 2025 19:06:20 +0100
Subject: [PATCH 09/24] partly address performance degradations
---
encodec/compress.py | 113 ++++++++++++++++++++++++++++++++++++--------
1 file changed, 93 insertions(+), 20 deletions(-)
diff --git a/encodec/compress.py b/encodec/compress.py
index 21d560d..cff1ebc 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -20,22 +20,20 @@
ArithmeticDecoder,
)
-# Hard determinism toggles
torch.use_deterministic_algorithms(True)
torch.backends.mkldnn.enabled = False
-# Registry
MODELS = {
'encodec_24khz': EncodecModel.encodec_model_24khz,
'encodec_48khz': EncodecModel.encodec_model_48khz,
}
-# Chosen scales for stability vs. compression efficiency
-# - LOGIT_QSTEP: coarse enough to suppress tiny arch drift, fine enough to preserve coding gain
-# - FP_SCALE: count scale used before integer range allocation inside the CDF
LOGIT_QSTEP = 1.0 / 64.0
FP_SCALE = 1 << 11
-MIN_RANGE = 6 # min bin width for arithmetic coder
+MIN_RANGE = 6
+
+_IDX_CACHE: tp.Dict[tp.Tuple[str, int, int], torch.Tensor] = {}
+_UNIFORM_CDF_CACHE: tp.Dict[tp.Tuple[str, int, int, int, int], torch.Tensor] = {}
def _counts_from_pdf(pdf: torch.Tensor, fp_scale: int) -> torch.Tensor:
x = (pdf.detach().to(torch.float64).clamp_min(0) * fp_scale)
@@ -74,9 +72,7 @@ def _stable_softmax(logits: torch.Tensor, dim: int) -> torch.Tensor:
x = (logits - torch.amax(logits, dim=dim, keepdim=True)).to(torch.float64)
z = torch.exp(x)
z = z.movedim(dim, -1).contiguous()
- acc = z[..., 0].clone()
- for i in range(1, z.size(-1)):
- acc += z[..., i]
+ acc = torch.cumsum(z, dim=-1)[..., -1]
out = (z / acc.unsqueeze(-1)).movedim(-1, dim)
return out
@@ -107,11 +103,9 @@ def _deterministic_cdf(pdf: torch.Tensor,
alloc = total - min_range * n
num_sum = int(num.sum().item())
- # base integer allocation
base = (alloc * num) // num_sum
remainder = int(alloc - int(base.sum().item()))
if remainder > 0:
- # deterministic priority: residual * (n+1) - index (stable sort)
prio = (alloc * num) - (num_sum * base)
idx = torch.arange(n, device=num.device, dtype=torch.int64)
key = prio * (n + 1) - idx
@@ -126,6 +120,76 @@ def _deterministic_cdf(pdf: torch.Tensor,
assert (ranges >= min_range).all()
return cdf
+def _deterministic_cdf_multi(pdf_mat: torch.Tensor,
+ total_range_bits: int,
+ fp_scale: int = FP_SCALE,
+ min_range: int = MIN_RANGE,
+ check: bool = False) -> torch.Tensor:
+ """
+ Vectorized version of `_deterministic_cdf` operating on a matrix of PDFs.
+ Expects shape `[B, K]` where `B` is number of bins and `K` is number of codebooks.
+ Returns integer CDFs with the same shape.
+ """
+ assert pdf_mat.dim() == 2, "pdf_mat must be 2D: [bins, K]"
+ pdf = pdf_mat.detach().to(torch.float64).clamp_min(0)
+ s = torch.sum(pdf, dim=0)
+ invalid = (~torch.isfinite(s)) | (s <= 0)
+ if bool(invalid.any()):
+ pdf[:, invalid] = 1.0
+
+ eq0 = (pdf[0:1, :] == pdf)
+ uniform_mask = torch.all(eq0, dim=0)
+
+ num = _counts_from_pdf(pdf, fp_scale).to(torch.int64)
+ zeros = torch.sum(num, dim=0) <= 0
+ if bool(zeros.any()):
+ num[:, zeros] = 1
+
+ total = 1 << total_range_bits
+ n_bins = int(num.size(0))
+ alloc = total - min_range * n_bins
+ num_sum = torch.sum(num, dim=0)
+
+ base = (alloc * num) // num_sum
+ base_sum = torch.sum(base, dim=0)
+ remainder = (alloc - base_sum).to(torch.int64)
+
+ if bool((remainder > 0).any()):
+ prio = (alloc * num) - (num_sum * base)
+ dev = num.device
+ dev_key = (dev.type, -1 if dev.index is None else int(dev.index), n_bins)
+ idx_row = _IDX_CACHE.get(dev_key)
+ if idx_row is None:
+ idx_row = torch.arange(n_bins, device=dev, dtype=torch.int64).unsqueeze(1)
+ _IDX_CACHE[dev_key] = idx_row
+ idx = idx_row.expand(n_bins, num.size(1))
+ key = prio * (n_bins + 1) - idx
+ order = torch.argsort(key, dim=0, descending=True, stable=True)
+ max_rem = int(torch.max(remainder).item())
+ if max_rem > 0:
+ top_idx = order[:max_rem, :]
+ row_range = torch.arange(max_rem, device=num.device, dtype=torch.int64).unsqueeze(1)
+ take_mask = (row_range < remainder.unsqueeze(0)).to(base.dtype)
+ base = base.scatter_add(0, top_idx, take_mask)
+
+ ranges = base + min_range
+ cdf = torch.cumsum(ranges, dim=0, dtype=torch.int64)
+
+ if bool(uniform_mask.any()):
+ dev = pdf.device
+ cache_key = (dev.type, -1 if dev.index is None else int(dev.index), n_bins, int(total_range_bits), int(min_range))
+ u_cdf = _UNIFORM_CDF_CACHE.get(cache_key)
+ if u_cdf is None:
+ u_pdf = torch.full((n_bins,), 1.0 / n_bins, dtype=torch.float64, device=dev)
+ u_cdf = _deterministic_cdf(u_pdf, total_range_bits, fp_scale=fp_scale, min_range=min_range, check=check)
+ _UNIFORM_CDF_CACHE[cache_key] = u_cdf
+ cdf[:, uniform_mask] = u_cdf.unsqueeze(1)
+
+ if check:
+ assert torch.all(cdf[-1, :] == total)
+ assert torch.all(ranges >= min_range)
+ return cdf
+
def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
use_lm: bool = True):
@@ -148,7 +212,8 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
coder_device = torch.device("cpu")
- with torch.no_grad():
+ model = model.eval()
+ with torch.inference_mode():
frames = model.encode(wav[None])
codes0, _ = frames[0]
_, K, _ = codes0.shape
@@ -182,16 +247,19 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
for t in range(T):
if use_lm:
- with torch.no_grad():
+ with torch.inference_mode():
logits_raw, states, offset = lm.forward_logits(input_, states, offset)
logits_q = _quantize_logits_(logits_raw / lm.tau, LOGIT_QSTEP)
probas = _softmax_or_uniform(logits_q, dim=1)
+ pdf_mat = probas[0, :, :, 0].to(coder_device)
+ cdf_mat = _deterministic_cdf_multi(pdf_mat, coder.total_range_bits, fp_scale=FP_SCALE, check=False)
+ cdf_cols = cdf_mat.t().contiguous()
+
frame_slice = frame[:, :, t: t + 1].detach().to(coder_device)
values = frame_slice[0, :, 0].tolist()
for k, value in enumerate(values):
- q_cdf = _deterministic_cdf(probas[0, :, k, 0], coder.total_range_bits, fp_scale=FP_SCALE, check=False)
- coder.push(value, q_cdf)
+ coder.push(value, cdf_cols[k])
input_ = 1 + frame_slice
else:
values = frame[0, :, t].detach().cpu().tolist()
@@ -225,7 +293,7 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
if acv != 3:
raise ValueError("Unsupported bitstream version; re-encode with this coder.")
- model = MODELS[model_name]().to(device)
+ model = MODELS[model_name]().to(device).eval()
model_device = next(model.parameters()).device
coder_device = torch.device("cpu")
@@ -259,15 +327,18 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
for t in range(frame_length):
if use_lm:
- with torch.no_grad():
+ with torch.inference_mode():
logits_raw, states, offset = lm.forward_logits(input_, states, offset)
logits_q = _quantize_logits_(logits_raw / lm.tau, LOGIT_QSTEP)
probas = _softmax_or_uniform(logits_q, dim=1)
+ pdf_mat = probas[0, :, :, 0].to(coder_device)
+ cdf_mat = _deterministic_cdf_multi(pdf_mat, decoder.total_range_bits, fp_scale=fp_scale, min_range=MIN_RANGE, check=False)
+ cdf_cols = cdf_mat.t().contiguous()
+
code_list: tp.List[int] = []
for k in range(num_codebooks):
- q_cdf = _deterministic_cdf(probas[0, :, k, 0], decoder.total_range_bits, fp_scale=fp_scale, check=False)
- code = decoder.pull(q_cdf)
+ code = decoder.pull(cdf_cols[k])
if code is None:
raise EOFError("The stream ended sooner than expected.")
code_list.append(code)
@@ -284,7 +355,7 @@ def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tenso
frames.append((frame.to(model_device), None if scale is None else scale.to(model_device)))
- with torch.no_grad():
+ with torch.inference_mode():
wav = model.decode(frames)
return wav[0, :, :audio_length], model.sample_rate
@@ -318,6 +389,7 @@ def decompress(compressed: bytes, device='cpu') -> tp.Tuple[torch.Tensor, int]:
def test():
import torchaudio
+ import time
torch.set_num_threads(1)
for name in MODELS.keys():
model = MODELS[name]()
@@ -340,3 +412,4 @@ def test():
if __name__ == '__main__':
test()
+
From db0a7c00d088fd50c311b6701b66d0f017042383 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Wed, 18 Mar 2026 13:16:38 +0000
Subject: [PATCH 10/24] Merge deterministic LM precision improvements and acv=4
chunk framing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Fix _counts_from_pdf negative-count bug (clamp_min before floor after
near-integer perturbation); triggered at tau=1.0 with float underflow
on zero-probability tokens → non-monotonic CDF → corrupt AC decode
- Add acv=4 per-segment CRC chunk framing for blast-radius isolation:
corrupt segment → silence substitution, rest of stream intact
- Deterministic LM path: float32 weights, float64 softmax via cumsum
denominator, logit quantisation to 1/128 grid (LOGIT_QSTEP), integer
arithmetic coder; LM always on CPU for cross-platform determinism
- Tighter defaults: FP_SCALE=65536, MIN_RANGE=1, LM_TAU=1.0 (~34% gain
over raw vs ~29% with tau=2.0)
- GPU reliability: model auto-moves wav to device, LM stays on CPU;
validated MPS↔CPU↔CUDA cross-device decode
- Add legacy decode path (forward_legacy / forward_logits split) for
reading acv<3 streams from original Facebook implementation
- Add model.get_lm_model(device, dtype) for explicit LM placement
- Add scripts/precision_eval.py and scripts/payload_decode_matrix.py
for benchmarking, corruption simulation, and cross-host validation
Co-Authored-By: Claude Sonnet 4.6
---
encodec/compress.py | 434 +++++++++++++++++++++----------
encodec/model.py | 37 ++-
scripts/payload_decode_matrix.py | 101 +++++++
scripts/precision_eval.py | 253 ++++++++++++++++++
4 files changed, 674 insertions(+), 151 deletions(-)
create mode 100644 scripts/payload_decode_matrix.py
create mode 100644 scripts/precision_eval.py
diff --git a/encodec/compress.py b/encodec/compress.py
index cff1ebc..c529ab3 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -8,8 +8,10 @@
import io
import math
+import os
import struct
import typing as tp
+import zlib
import torch
@@ -18,7 +20,9 @@
from .quantization.ac import (
ArithmeticCoder,
ArithmeticDecoder,
+ build_stable_quantized_cdf,
)
+from .utils import _linear_overlap_add
torch.use_deterministic_algorithms(True)
torch.backends.mkldnn.enabled = False
@@ -28,14 +32,61 @@
'encodec_48khz': EncodecModel.encodec_model_48khz,
}
-LOGIT_QSTEP = 1.0 / 64.0
-FP_SCALE = 1 << 11
-MIN_RANGE = 6
+# ---------------------------------------------------------------------------
+# Runtime-tunable defaults via environment variables.
+# Lean float32 profile (validated cross-platform: mps→cpu, cpu→cuda).
+# ---------------------------------------------------------------------------
+
+def _env_float(name: str, default: float) -> float:
+ v = os.getenv(name)
+ return default if v is None else float(v)
+
+def _env_int(name: str, default: int) -> int:
+ v = os.getenv(name)
+ return default if v is None else int(v)
+
+def _env_bool(name: str, default: bool) -> bool:
+ v = os.getenv(name)
+ if v is None:
+ return default
+ return v.lower() in {"1", "true", "yes", "on"}
+
+def _env_dtype(name: str, default: torch.dtype) -> torch.dtype:
+ v = os.getenv(name)
+ if v is None:
+ return default
+ mapping = {"float32": torch.float32, "fp32": torch.float32,
+ "float64": torch.float64, "fp64": torch.float64}
+ try:
+ return mapping[v.lower()]
+ except KeyError as exc:
+ raise ValueError(f"Unsupported dtype override {v!r} for {name}.") from exc
+
+# Lean defaults: float32 LM, finer logit grid, high-precision CDF allocation.
+LOGIT_QSTEP = _env_float("ENCODEC_LOGIT_QSTEP", 1.0 / 128.0)
+LM_TAU = _env_float("ENCODEC_LM_TAU", 1.0)
+FP_SCALE = _env_int("ENCODEC_AC_FP_SCALE", 1 << 16)
+MIN_RANGE = _env_int("ENCODEC_AC_MIN_RANGE", 1)
+USE_NEAR_UNIFORM = _env_bool("ENCODEC_USE_NEAR_UNIFORM", False)
+DETERMINISTIC_LM_DTYPE = _env_dtype("ENCODEC_DETERMINISTIC_LM_DTYPE", torch.float32)
_IDX_CACHE: tp.Dict[tp.Tuple[str, int, int], torch.Tensor] = {}
_UNIFORM_CDF_CACHE: tp.Dict[tp.Tuple[str, int, int, int, int], torch.Tensor] = {}
+_CHUNK_HEADER = struct.Struct('!II') # chunk_len (uint32 BE), crc32 (uint32 BE)
+
+
+# ---------------------------------------------------------------------------
+# CDF / probability helpers
+# ---------------------------------------------------------------------------
def _counts_from_pdf(pdf: torch.Tensor, fp_scale: int) -> torch.Tensor:
+ """Convert a PDF to integer counts via floor(pdf * fp_scale) in float64.
+
+ Near-integer fractions receive a deterministic ±ε perturbation to break
+ ties consistently across platforms. The result is clamped to ≥0 so that
+ exact-zero probabilities (common at tau=1.0 due to float underflow of
+ exp(-large)) never produce −1 via floor(0 − ε).
+ """
x = (pdf.detach().to(torch.float64).clamp_min(0) * fp_scale)
fx = torch.floor(x)
frac = x - fx
@@ -46,17 +97,34 @@ def _counts_from_pdf(pdf: torch.Tensor, fp_scale: int) -> torch.Tensor:
sign = (idx.fmod(2) * 2 - 1).to(torch.float64)
eps = math.ldexp(1.0, -60)
x = torch.where(m, x + sign * eps, x)
- fx = torch.floor(x)
+ # clamp before floor: negative sign on an exact-zero pdf would give
+ # x = −ε → floor = −1, corrupting the CDF.
+ fx = torch.floor(x.clamp_min(0))
return fx.to(torch.int64)
+
def _quantize_logits_(logits: torch.Tensor, step: float = LOGIT_QSTEP) -> torch.Tensor:
+ """Round logits to a deterministic grid (biased-floor half-step)."""
y = (logits / step).to(torch.float64)
eps = math.ldexp(1.0, -40)
q = torch.floor(y + 0.5 - eps)
return q * step
+
+def _stable_softmax(logits: torch.Tensor, dim: int) -> torch.Tensor:
+ """Softmax in float64 using a sequential cumsum denominator for
+ cross-architecture bit-reproducibility."""
+ x = (logits - torch.amax(logits, dim=dim, keepdim=True)).to(torch.float64)
+ z = torch.exp(x)
+ z = z.movedim(dim, -1).contiguous()
+ acc = torch.cumsum(z, dim=-1)[..., -1]
+ return (z / acc.unsqueeze(-1)).movedim(-1, dim)
+
+
def _softmax_or_uniform(x: torch.Tensor, dim: int) -> torch.Tensor:
s = _stable_softmax(x, dim)
+ if not USE_NEAR_UNIFORM:
+ return s
span_logit = torch.amax(x, dim=dim, keepdim=True) - torch.amin(x, dim=dim, keepdim=True)
near_logit = span_logit <= (2 * LOGIT_QSTEP)
span_pdf = torch.amax(s, dim=dim, keepdim=True) - torch.amin(s, dim=dim, keepdim=True)
@@ -68,31 +136,17 @@ def _softmax_or_uniform(x: torch.Tensor, dim: int) -> torch.Tensor:
u = torch.full_like(s, 1.0 / k, dtype=torch.float64)
return torch.where(near, u, s)
-def _stable_softmax(logits: torch.Tensor, dim: int) -> torch.Tensor:
- x = (logits - torch.amax(logits, dim=dim, keepdim=True)).to(torch.float64)
- z = torch.exp(x)
- z = z.movedim(dim, -1).contiguous()
- acc = torch.cumsum(z, dim=-1)[..., -1]
- out = (z / acc.unsqueeze(-1)).movedim(-1, dim)
- return out
def _deterministic_cdf(pdf: torch.Tensor,
total_range_bits: int,
fp_scale: int = FP_SCALE,
min_range: int = MIN_RANGE,
check: bool = False) -> torch.Tensor:
- """
- Architecture-stable integer CDF:
- 1) clamp pdf; compute integer "counts" by floor(pdf * fp_scale) in f64
- 2) allocate the remaining counts deterministically by priority
- 3) add min_range to each bin, cum-sum to final CDF that sums to 2^bits
- Any tiny floating diffs that don't change floor() outputs produce identical CDFs.
- """
+ """Architecture-stable integer CDF for a single PDF vector."""
pdf = pdf.detach().to(torch.float64).clamp_min(0)
s = pdf.sum()
if (not torch.isfinite(s)) or (s <= 0):
pdf = torch.ones_like(pdf)
- s = pdf.sum()
num = _counts_from_pdf(pdf, fp_scale).to(torch.int64)
if int(num.sum().item()) <= 0:
@@ -114,22 +168,18 @@ def _deterministic_cdf(pdf: torch.Tensor,
ranges = base + min_range
cdf = torch.cumsum(ranges, dim=-1, dtype=torch.int64)
-
if check:
assert int(cdf[-1].item()) == total
assert (ranges >= min_range).all()
return cdf
+
def _deterministic_cdf_multi(pdf_mat: torch.Tensor,
- total_range_bits: int,
- fp_scale: int = FP_SCALE,
- min_range: int = MIN_RANGE,
- check: bool = False) -> torch.Tensor:
- """
- Vectorized version of `_deterministic_cdf` operating on a matrix of PDFs.
- Expects shape `[B, K]` where `B` is number of bins and `K` is number of codebooks.
- Returns integer CDFs with the same shape.
- """
+ total_range_bits: int,
+ fp_scale: int = FP_SCALE,
+ min_range: int = MIN_RANGE,
+ check: bool = False) -> torch.Tensor:
+ """Vectorised _deterministic_cdf over [bins, K] PDF matrix."""
assert pdf_mat.dim() == 2, "pdf_mat must be 2D: [bins, K]"
pdf = pdf_mat.detach().to(torch.float64).clamp_min(0)
s = torch.sum(pdf, dim=0)
@@ -137,6 +187,7 @@ def _deterministic_cdf_multi(pdf_mat: torch.Tensor,
if bool(invalid.any()):
pdf[:, invalid] = 1.0
+ # Shortcut: detect fully-uniform columns and cache their CDF.
eq0 = (pdf[0:1, :] == pdf)
uniform_mask = torch.all(eq0, dim=0)
@@ -177,11 +228,13 @@ def _deterministic_cdf_multi(pdf_mat: torch.Tensor,
if bool(uniform_mask.any()):
dev = pdf.device
- cache_key = (dev.type, -1 if dev.index is None else int(dev.index), n_bins, int(total_range_bits), int(min_range))
+ cache_key = (dev.type, -1 if dev.index is None else int(dev.index),
+ n_bins, int(total_range_bits), int(min_range))
u_cdf = _UNIFORM_CDF_CACHE.get(cache_key)
if u_cdf is None:
u_pdf = torch.full((n_bins,), 1.0 / n_bins, dtype=torch.float64, device=dev)
- u_cdf = _deterministic_cdf(u_pdf, total_range_bits, fp_scale=fp_scale, min_range=min_range, check=check)
+ u_cdf = _deterministic_cdf(u_pdf, total_range_bits,
+ fp_scale=fp_scale, min_range=min_range)
_UNIFORM_CDF_CACHE[cache_key] = u_cdf
cdf[:, uniform_mask] = u_cdf.unsqueeze(1)
@@ -191,210 +244,307 @@ def _deterministic_cdf_multi(pdf_mat: torch.Tensor,
return cdf
+# ---------------------------------------------------------------------------
+# acv=4 chunk framing helpers
+# ---------------------------------------------------------------------------
+
+def _write_chunk(fo: tp.IO[bytes], payload: bytes) -> None:
+ """Write a CRC-protected chunk: [len: u32][crc: u32][payload]."""
+ fo.write(_CHUNK_HEADER.pack(len(payload), zlib.crc32(payload) & 0xffffffff))
+ fo.write(payload)
+
+
+def _read_chunk_payload(fo: tp.IO[bytes]) -> bytes:
+ """Read and CRC-verify one chunk. Raises ValueError on mismatch."""
+ chunk_len, chunk_crc = _CHUNK_HEADER.unpack(binary._read_exactly(fo, _CHUNK_HEADER.size))
+ payload = binary._read_exactly(fo, chunk_len)
+ actual = zlib.crc32(payload) & 0xffffffff
+ if actual != chunk_crc:
+ raise ValueError(f"Chunk CRC mismatch: expected {chunk_crc:#010x}, got {actual:#010x}.")
+ return payload
+
+
+# ---------------------------------------------------------------------------
+# compress_to_file / decompress_from_file
+# ---------------------------------------------------------------------------
+
def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
- use_lm: bool = True):
- """Compress a waveform to a file-object using the given model.
+ use_lm: bool = True) -> None:
+ """Compress a waveform to a file-object.
+
+ When ``use_lm=True`` the stream is bitstream version 4 (acv=4):
+ each model segment is wrapped in a CRC-protected chunk so that a
+ single corrupt byte only damages that one segment (~1 s). The
+ arithmetic coder and LM always run on CPU for cross-platform
+ determinism; the EnCodec model may run on any device.
Args:
- model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
- wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
- matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
- Use `utils.convert_audio` if this is not the case.
- fo (IO[bytes]): file-object to which the compressed bits will be written.
- See `compress` if you want obtain a `bytes` object instead.
- use_lm (bool): if True, use a pre-trained language model to further
- compress the stream using Entropy Coding. This will slow down compression
- quite a bit, expect between 20 to 30% of size reduction.
+ model: pre-trained EncodecModel.
+ wav: ``[C, T]`` waveform at model.sample_rate.
+ fo: writable file-object.
+ use_lm: enable LM entropy coding (acv=4) vs raw bitpacking (acv=0).
"""
assert wav.dim() == 2
if model.name not in MODELS:
raise ValueError(f"Unsupported model {model.name}.")
coder_device = torch.device("cpu")
-
model = model.eval()
+ model_device = next(model.parameters()).device
with torch.inference_mode():
- frames = model.encode(wav[None])
+ frames = model.encode(wav[None].to(model_device))
codes0, _ = frames[0]
_, K, _ = codes0.shape
lm = None
+ lm_tau = LM_TAU
if use_lm:
- lm = model.get_lm_model().to(dtype=torch.float64, device=coder_device).eval()
-
- metadata = {
- 'm': model.name,
- 'al': int(wav.shape[-1]),
- 'nc': int(K),
- 'lm': bool(use_lm),
- 'fp': int(FP_SCALE),
- 'acv': 3,
+ lm = model.get_lm_model(device=coder_device,
+ dtype=DETERMINISTIC_LM_DTYPE).eval()
+ lm.tau = lm_tau
+
+ metadata: tp.Dict[str, tp.Any] = {
+ 'm': model.name,
+ 'al': int(wav.shape[-1]),
+ 'nc': int(K),
+ 'lm': bool(use_lm),
+ 'fp': int(FP_SCALE),
+ 'mr': int(MIN_RANGE),
+ 'acv': 4 if use_lm else 0,
+ 'tau': float(lm_tau),
}
binary.write_ecdc_header(fo, metadata)
for (frame, scale) in frames:
+ chunk_fo = io.BytesIO()
+
if scale is not None:
- fo.write(struct.pack('!f', float(scale.cpu().item())))
+ chunk_fo.write(struct.pack('!f', float(scale.cpu().item())))
_B, _K, T = frame.shape
if use_lm:
- coder = ArithmeticCoder(fo)
+ coder = ArithmeticCoder(chunk_fo)
states = None
offset = 0
input_ = torch.zeros(1, K, 1, dtype=torch.long, device=coder_device)
else:
- packer = binary.BitPacker(model.bits_per_codebook, fo)
+ packer = binary.BitPacker(model.bits_per_codebook, chunk_fo)
for t in range(T):
if use_lm:
with torch.inference_mode():
logits_raw, states, offset = lm.forward_logits(input_, states, offset)
- logits_q = _quantize_logits_(logits_raw / lm.tau, LOGIT_QSTEP)
+ logits_q = _quantize_logits_(logits_raw / lm_tau, LOGIT_QSTEP)
probas = _softmax_or_uniform(logits_q, dim=1)
pdf_mat = probas[0, :, :, 0].to(coder_device)
- cdf_mat = _deterministic_cdf_multi(pdf_mat, coder.total_range_bits, fp_scale=FP_SCALE, check=False)
+ cdf_mat = _deterministic_cdf_multi(
+ pdf_mat, coder.total_range_bits,
+ fp_scale=FP_SCALE, min_range=MIN_RANGE, check=False)
cdf_cols = cdf_mat.t().contiguous()
- frame_slice = frame[:, :, t: t + 1].detach().to(coder_device)
- values = frame_slice[0, :, 0].tolist()
- for k, value in enumerate(values):
+ frame_slice = frame[:, :, t:t + 1].detach().to(coder_device)
+ for k, value in enumerate(frame_slice[0, :, 0].tolist()):
coder.push(value, cdf_cols[k])
input_ = 1 + frame_slice
else:
- values = frame[0, :, t].detach().cpu().tolist()
- for value in values:
+ for value in frame[0, :, t].detach().cpu().tolist():
packer.push(value)
if use_lm:
coder.flush()
+ _write_chunk(fo, chunk_fo.getvalue())
else:
packer.flush()
+ fo.write(chunk_fo.getvalue())
-def decompress_from_file(fo: tp.IO[bytes], device='cpu') -> tp.Tuple[torch.Tensor, int]:
- """Decompress from a file-object.
- Returns a tuple `(wav, sample_rate)`.
- Args:
- fo (IO[bytes]): file-object from which to read. If you want to decompress
- from `bytes` instead, see `decompress`.
- device: device to use to perform the computations.
+def decompress_from_file(fo: tp.IO[bytes],
+ device: str = 'cpu') -> tp.Tuple[torch.Tensor, int]:
+ """Decompress from a file-object. Returns ``(wav, sample_rate)``.
+
+ Supports:
+ * acv=0 — raw bitpacking (no LM).
+ * acv<3 — legacy LM streams from the original Facebook implementation.
+ * acv=4 — deterministic LM streams (this implementation).
+ Corrupt segments fall back to silence rather than aborting.
+
+ The model (EnCodec encoder/decoder) runs on ``device``; the LM and
+ arithmetic coder always run on CPU.
"""
metadata = binary.read_ecdc_header(fo)
- model_name = metadata['m']
+ model_name = metadata['m']
audio_length = int(metadata['al'])
num_codebooks = int(metadata['nc'])
- use_lm = bool(metadata['lm'])
- fp_scale = int(metadata.get('fp', FP_SCALE))
- acv = int(metadata.get('acv', 0))
+ use_lm = bool(metadata['lm'])
+ fp_scale = int(metadata.get('fp', FP_SCALE))
+ min_range = int(metadata.get('mr', MIN_RANGE))
+ acv = int(metadata.get('acv', 0))
+ # tau is stored since this merged implementation; fall back to env-var default
+ # so we can also decode payloads from the earlier codex-precision branch.
+ lm_tau = float(metadata.get('tau', LM_TAU))
if model_name not in MODELS:
raise ValueError(f"Unsupported model {model_name}.")
- if acv != 3:
- raise ValueError("Unsupported bitstream version; re-encode with this coder.")
+ if acv > 4:
+ raise ValueError(f"Unsupported bitstream version {acv}; re-encode.")
model = MODELS[model_name]().to(device).eval()
model_device = next(model.parameters()).device
coder_device = torch.device("cpu")
lm = None
- if use_lm:
- lm = model.get_lm_model().to(dtype=torch.float64, device=coder_device).eval()
+ legacy_lm = None
+ if use_lm and acv >= 3:
+ lm = model.get_lm_model(device=coder_device,
+ dtype=DETERMINISTIC_LM_DTYPE).eval()
+ lm.tau = lm_tau
+ elif use_lm:
+ # Legacy streams: original Facebook LM path (float32, no quantisation).
+ legacy_lm = model.get_lm_model(device=coder_device,
+ dtype=torch.float32).eval()
- frames: tp.List[EncodedFrame] = []
segment_length = model.segment_length or audio_length
segment_stride = model.segment_stride or audio_length
+ decoded_frames: tp.List[torch.Tensor] = []
+ frames: tp.List[EncodedFrame] = []
for offset_samples in range(0, audio_length, segment_stride):
this_len = min(audio_length - offset_samples, segment_length)
frame_length = int(math.ceil(this_len * model.frame_rate / model.sample_rate))
+ frame_fo = fo
+
+ if acv == 4:
+ try:
+ frame_fo = io.BytesIO(_read_chunk_payload(fo))
+ except Exception:
+ # Corrupt chunk → substitute silence and continue.
+ decoded_frames.append(
+ torch.zeros(1, model.channels, this_len, device=model_device))
+ continue
if model.normalize:
- scale_f, = struct.unpack('!f', binary._read_exactly(fo, struct.calcsize('!f')))
+ scale_f, = struct.unpack('!f', binary._read_exactly(
+ frame_fo, struct.calcsize('!f')))
scale = torch.tensor(scale_f, device=coder_device).view(1)
else:
scale = None
if use_lm:
- decoder = ArithmeticDecoder(fo)
+ decoder = ArithmeticDecoder(frame_fo)
states = None
offset = 0
- input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=coder_device)
+ input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long,
+ device=coder_device)
else:
- unpacker = binary.BitUnpacker(model.bits_per_codebook, fo)
-
- frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=coder_device)
-
- for t in range(frame_length):
- if use_lm:
- with torch.inference_mode():
- logits_raw, states, offset = lm.forward_logits(input_, states, offset)
- logits_q = _quantize_logits_(logits_raw / lm.tau, LOGIT_QSTEP)
- probas = _softmax_or_uniform(logits_q, dim=1)
-
- pdf_mat = probas[0, :, :, 0].to(coder_device)
- cdf_mat = _deterministic_cdf_multi(pdf_mat, decoder.total_range_bits, fp_scale=fp_scale, min_range=MIN_RANGE, check=False)
- cdf_cols = cdf_mat.t().contiguous()
-
- code_list: tp.List[int] = []
- for k in range(num_codebooks):
- code = decoder.pull(cdf_cols[k])
- if code is None:
- raise EOFError("The stream ended sooner than expected.")
- code_list.append(code)
- frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=coder_device)
- input_ = 1 + frame[:, :, t: t + 1]
- else:
- code_list: tp.List[int] = []
- for _ in range(num_codebooks):
- code = unpacker.pull()
- if code is None:
- raise EOFError("The stream ended sooner than expected.")
- code_list.append(code)
- frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=coder_device)
-
- frames.append((frame.to(model_device), None if scale is None else scale.to(model_device)))
+ unpacker = binary.BitUnpacker(model.bits_per_codebook, frame_fo)
+
+ frame = torch.zeros(1, num_codebooks, frame_length,
+ dtype=torch.long, device=coder_device)
+ try:
+ for t in range(frame_length):
+ if use_lm and acv >= 3:
+ with torch.inference_mode():
+ logits_raw, states, offset = lm.forward_logits(
+ input_, states, offset)
+ logits_q = _quantize_logits_(logits_raw / lm_tau,
+ LOGIT_QSTEP)
+ probas = _softmax_or_uniform(logits_q, dim=1)
+
+ pdf_mat = probas[0, :, :, 0].to(coder_device)
+ cdf_mat = _deterministic_cdf_multi(
+ pdf_mat, decoder.total_range_bits,
+ fp_scale=fp_scale, min_range=min_range, check=False)
+ cdf_cols = cdf_mat.t().contiguous()
+ code_list: tp.List[int] = []
+ for k in range(num_codebooks):
+ code = decoder.pull(cdf_cols[k])
+ if code is None:
+ raise EOFError("Stream ended before expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long,
+ device=coder_device)
+ input_ = 1 + frame[:, :, t:t + 1]
+
+ elif use_lm: # legacy path
+ with torch.inference_mode():
+ probas, states, offset = legacy_lm.forward_legacy(
+ input_, states, offset)
+ code_list = []
+ for k in range(num_codebooks):
+ q_cdf = build_stable_quantized_cdf(
+ probas[0, :, k, 0], decoder.total_range_bits,
+ check=False)
+ code = decoder.pull(q_cdf)
+ if code is None:
+ raise EOFError("Stream ended before expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long,
+ device=coder_device)
+ input_ = 1 + frame[:, :, t:t + 1]
+
+ else:
+ code_list = []
+ for _ in range(num_codebooks):
+ code = unpacker.pull()
+ if code is None:
+ raise EOFError("Stream ended before expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long,
+ device=coder_device)
+
+ except Exception:
+ if acv == 4:
+ decoded_frames.append(
+ torch.zeros(1, model.channels, this_len, device=model_device))
+ continue
+ raise
+
+ encoded_frame = (frame.to(model_device),
+ None if scale is None else scale.to(model_device))
+ if acv == 4:
+ with torch.inference_mode():
+ decoded_frames.append(
+ model._decode_frame(encoded_frame)[..., :this_len])
+ else:
+ frames.append(encoded_frame)
- with torch.inference_mode():
- wav = model.decode(frames)
+ if acv == 4:
+ if model.segment_length is None:
+ wav = decoded_frames[0]
+ else:
+ wav = _linear_overlap_add(decoded_frames, model.segment_stride or 1)
+ else:
+ with torch.inference_mode():
+ wav = model.decode(frames)
return wav[0, :, :audio_length], model.sample_rate
-def compress(model: EncodecModel, wav: torch.Tensor, use_lm: bool = False) -> bytes:
- """Compress a waveform using the given model. Returns the compressed bytes.
- Args:
- model (EncodecModel): a pre-trained EncodecModel to use to compress the audio.
- wav (torch.Tensor): waveform to compress, should have a shape `[C, T]`, with `C`
- matching `model.channels`, and the proper sample rate (e.g. `model.sample_rate`).
- Use `utils.convert_audio` if this is not the case.
- use_lm (bool): if True, use a pre-trained language model to further
- compress the stream using Entropy Coding. This will slow down compression
- quite a bit, expect between 20 to 30% of size reduction.
- """
+def compress(model: EncodecModel, wav: torch.Tensor,
+ use_lm: bool = False) -> bytes:
+ """Compress a waveform and return bytes."""
fo = io.BytesIO()
compress_to_file(model, wav, fo, use_lm=use_lm)
return fo.getvalue()
-def decompress(compressed: bytes, device='cpu') -> tp.Tuple[torch.Tensor, int]:
- """Decompress from a file-object.
- Returns a tuple `(wav, sample_rate)`.
+def decompress(compressed: bytes,
+ device: str = 'cpu') -> tp.Tuple[torch.Tensor, int]:
+ """Decompress from bytes. Returns ``(wav, sample_rate)``."""
+ return decompress_from_file(io.BytesIO(compressed), device=device)
- Args:
- compressed (bytes): compressed bytes.
- device: device to use to perform the computations.
- """
- fo = io.BytesIO(compressed)
- return decompress_from_file(fo, device=device)
def test():
- import torchaudio
+ import soundfile as sf
import time
torch.set_num_threads(1)
for name in MODELS.keys():
model = MODELS[name]()
- sr = model.sample_rate // 1000
- x, _ = torchaudio.load(f'test_{sr}k.wav')
+ suffix = name.split('_')[1][:3]
+ x, sr = sf.read(f'test_{suffix}.wav', always_2d=True, dtype='float32')
+ x = torch.from_numpy(x.T.copy())
+ from .utils import convert_audio
+ x = convert_audio(x, sr, model.sample_rate, model.channels)
x = x[:, :model.sample_rate * 5]
model.set_target_bandwidth(12)
for use_lm in [False, True]:
@@ -405,11 +555,9 @@ def test():
x_dec, _ = decompress(res)
t_decomp = time.time() - begin - t_comp
kbps = 8 * len(res) / 1000 / (x.shape[-1] / model.sample_rate)
- print(f"kbps: {kbps:.1f}, time comp: {t_comp:.1f} sec. "
- f"time decomp:{t_decomp:.1f}.")
+ print(f" kbps={kbps:.1f} enc={t_comp:.2f}s dec={t_decomp:.2f}s")
assert x_dec.shape == x.shape
if __name__ == '__main__':
test()
-
diff --git a/encodec/model.py b/encodec/model.py
index 14b1857..ad66b1f 100644
--- a/encodec/model.py
+++ b/encodec/model.py
@@ -32,10 +32,12 @@ class LMModel(nn.Module):
n_q (int): number of codebooks.
card (int): codebook cardinality.
dim (int): transformer dimension.
+ tau (float): softmax temperature. 1.0 = no smoothing (optimal compression).
+ Higher values soften the distribution (more robust but worse compression).
**kwargs: passed to `encodec.modules.transformer.StreamingTransformerEncoder`.
"""
-class LMModel(nn.Module):
- def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, dtype=torch.float64, **kwargs):
+ def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, dtype=torch.float64,
+ tau: float = 1.0, **kwargs):
super().__init__()
self.card = card
self.n_q = n_q
@@ -45,7 +47,7 @@ def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, dtype=torch.
self.emb = nn.ModuleList([nn.Embedding(card + 1, dim, dtype=dtype) for _ in range(n_q)])
self.linears = nn.ModuleList([nn.Linear(dim, card, dtype=dtype) for _ in range(n_q)])
self.logit_step = 1.0 / 64.0
- self.tau = 2.0
+ self.tau = tau
def forward_logits(self, indices: torch.Tensor,
states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
@@ -62,6 +64,13 @@ def forward(self, indices: torch.Tensor,
probas = torch.softmax(logits / self.tau, dim=1)
return probas, states, offset
+ def forward_legacy(self, indices: torch.Tensor,
+ states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
+ """Legacy path: raw softmax with no quantisation, for acv<3 streams."""
+ logits, states, offset = self.forward_logits(indices, states, offset)
+ return torch.softmax(logits, dim=1), states, offset
+
+
class EncodecModel(nn.Module):
"""EnCodec model operating on the raw waveform.
Args:
@@ -193,12 +202,24 @@ def set_target_bandwidth(self, bandwidth: float):
f"Select one of {self.target_bandwidths}.")
self.bandwidth = bandwidth
- def get_lm_model(self, int8: bool = False) -> LMModel:
- device = torch.device("cpu")
+ def get_lm_model(self,
+ device: tp.Optional[torch.device] = None,
+ dtype: torch.dtype = torch.float32) -> LMModel:
+ """Load the pre-trained language model for entropy coding.
+
+ Args:
+ device: target device (defaults to CPU — LM must stay on CPU for
+ cross-platform arithmetic-coder determinism).
+ dtype: LM weight dtype. float32 is faster and sufficient when
+ combined with the deterministic logit-quantisation path.
+ """
+ device = torch.device("cpu") if device is None else device
lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
- past_context=int(3.5 * self.frame_rate), dtype=torch.float64).to(device)
- checkpoints = {'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
- 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th'}
+ past_context=int(3.5 * self.frame_rate), dtype=dtype).to(device)
+ checkpoints = {
+ 'encodec_24khz': 'encodec_lm_24khz-1608e3c0.th',
+ 'encodec_48khz': 'encodec_lm_48khz-7add9fc3.th',
+ }
checkpoint_name = checkpoints[self.name]
url = _get_checkpoint_url(ROOT_URL, checkpoint_name)
state = torch.hub.load_state_dict_from_url(url, map_location='cpu', check_hash=True)
diff --git a/scripts/payload_decode_matrix.py b/scripts/payload_decode_matrix.py
new file mode 100644
index 0000000..6fb4fe9
--- /dev/null
+++ b/scripts/payload_decode_matrix.py
@@ -0,0 +1,101 @@
+#!/usr/bin/env python3
+import argparse
+import json
+from pathlib import Path
+
+import torch
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Decode EnCodec payloads across devices and compare corrupted pairs.")
+ parser.add_argument("--payload-dir", type=Path, required=True, help="Directory containing .ecdc payload files.")
+ parser.add_argument("--devices", nargs="+", default=["cpu"], help="Decode devices to test, e.g. cpu cuda.")
+ parser.add_argument(
+ "--pair",
+ action="append",
+ nargs=2,
+ metavar=("CLEAN", "CORRUPT"),
+ default=[],
+ help="Optional clean/corrupt filename pair to compare after decode.",
+ )
+ parser.add_argument("--output", type=Path, default=None, help="Optional JSON output path.")
+ return parser.parse_args()
+
+
+def decode_payload(decompress, payload: bytes, device: str):
+ wav, sr = decompress(payload, device=device)
+ wav = wav.detach().cpu()
+ if wav.dim() == 1:
+ wav = wav.unsqueeze(0)
+ return wav, sr
+
+
+def compare_wavs(clean_wav: torch.Tensor, bad_wav: torch.Tensor, sr: int):
+ n = min(clean_wav.shape[-1], bad_wav.shape[-1])
+ clean_wav = clean_wav[..., :n]
+ bad_wav = bad_wav[..., :n]
+ diff = (bad_wav - clean_wav).abs()
+ err = diff.amax(dim=0)
+ mask = err > 1e-3
+ first_bad = int(torch.argmax(mask.to(torch.int64)).item()) if bool(mask.any()) else None
+ last_bad = int((mask.numel() - 1) - torch.argmax(mask.flip(0).to(torch.int64)).item()) if bool(mask.any()) else None
+ return {
+ "corruption_mae": float(diff.mean().item()),
+ "corruption_max_abs": float(diff.max().item()),
+ "first_bad_sample": first_bad,
+ "last_bad_sample": last_bad,
+ "bad_duration_s": None if first_bad is None else (last_bad - first_bad + 1) / sr,
+ }
+
+
+def main():
+ args = parse_args()
+
+ from encodec.compress import decompress
+
+ payload_dir = args.payload_dir
+ results = []
+ pair_map = {tuple(pair) for pair in args.pair}
+
+ for payload_path in sorted(payload_dir.glob("*.ecdc")):
+ payload = payload_path.read_bytes()
+ for device in args.devices:
+ row = {"file": payload_path.name, "device": device}
+ try:
+ wav, sr = decode_payload(decompress, payload, device)
+ row.update({
+ "success": True,
+ "sr": sr,
+ "shape": list(wav.shape),
+ "dtype": str(wav.dtype),
+ "max_abs": float(wav.abs().max().item()),
+ })
+ except Exception as exc:
+ row.update({"success": False, "error": repr(exc)})
+ results.append(row)
+
+ for clean_name, corrupt_name in sorted(pair_map):
+ clean_payload = payload_dir.joinpath(clean_name).read_bytes()
+ corrupt_payload = payload_dir.joinpath(corrupt_name).read_bytes()
+ for device in args.devices:
+ row = {"clean": clean_name, "corrupt": corrupt_name, "device": device}
+ try:
+ clean_wav, sr = decode_payload(decompress, clean_payload, device)
+ corrupt_wav, corrupt_sr = decode_payload(decompress, corrupt_payload, device)
+ if sr != corrupt_sr:
+ raise RuntimeError(f"Sample rate mismatch: {sr} != {corrupt_sr}")
+ row.update({"success": True, "sr": sr})
+ row.update(compare_wavs(clean_wav, corrupt_wav, sr))
+ except Exception as exc:
+ row.update({"success": False, "error": repr(exc)})
+ results.append(row)
+
+ text = json.dumps(results, indent=2, sort_keys=True)
+ print(text)
+ if args.output is not None:
+ args.output.parent.mkdir(parents=True, exist_ok=True)
+ args.output.write_text(text)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/precision_eval.py b/scripts/precision_eval.py
new file mode 100644
index 0000000..3bcd284
--- /dev/null
+++ b/scripts/precision_eval.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python3
+import argparse
+import io
+import json
+import math
+import struct
+import sys
+import time
+from pathlib import Path
+
+import soundfile as sf
+import torch
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Run EnCodec precision and robustness experiments.")
+ parser.add_argument("--repo-path", type=Path, required=True, help="Path to the EnCodec checkout to evaluate.")
+ parser.add_argument("--input", type=Path, required=True, help="Input audio file.")
+ parser.add_argument("--model", choices=["encodec_24khz", "encodec_48khz"], default="encodec_48khz")
+ parser.add_argument("--bandwidth", type=float, default=6.0)
+ parser.add_argument("--device", default="cpu", help="Encoding device, e.g. cpu or mps.")
+ parser.add_argument("--decode-device", default=None, help="Decode device. Defaults to --device.")
+ parser.add_argument("--lm", action="store_true", help="Enable LM entropy coding.")
+ parser.add_argument("--segment", type=float, default=None, help="Model segment length in seconds.")
+ parser.add_argument("--overlap", type=float, default=None, help="Model overlap fraction.")
+ parser.add_argument("--offset", type=float, default=0.0, help="Clip start offset in seconds.")
+ parser.add_argument("--duration", type=float, default=None, help="Clip duration in seconds.")
+ parser.add_argument("--corrupt-byte-fraction", type=float, default=None, help="Flip one byte near this fraction of the payload.")
+ parser.add_argument("--corrupt-byte-index", type=int, default=None, help="Flip one byte at this absolute payload index.")
+ parser.add_argument("--output-payload", type=Path, default=None, help="Optional path to write the encoded payload.")
+ parser.add_argument("--output-corrupt-payload", type=Path, default=None, help="Optional path to write the corrupted payload.")
+ return parser.parse_args()
+
+
+def load_audio(path: Path):
+ wav, sr = sf.read(path, always_2d=True, dtype="float32")
+ wav = torch.from_numpy(wav.T.copy())
+ return wav, sr
+
+
+def clip_audio(wav: torch.Tensor, sr: int, offset_s: float, duration_s: float | None):
+ start = max(0, int(round(offset_s * sr)))
+ end = wav.shape[-1] if duration_s is None else min(wav.shape[-1], start + int(round(duration_s * sr)))
+ return wav[:, start:end]
+
+
+def flip_payload_byte(payload: bytes, metadata_len: int, byte_index: int):
+ data = bytearray(payload)
+ target = metadata_len + byte_index
+ if target < metadata_len or target >= len(data):
+ raise ValueError(f"Corruption index {byte_index} is out of range for payload of {len(data) - metadata_len} bytes.")
+ data[target] ^= 0x01
+ return bytes(data), target
+
+
+def flip_chunk_body_byte(payload: bytes, metadata_len: int, metadata: dict, byte_index: int | None, fraction: float | None):
+ chunk_header = struct.Struct("!II")
+ data = bytearray(payload)
+ stream = io.BytesIO(payload)
+ stream.seek(metadata_len)
+
+ body_ranges = []
+ while stream.tell() < len(payload):
+ header_pos = stream.tell()
+ header = stream.read(chunk_header.size)
+ if len(header) != chunk_header.size:
+ break
+ chunk_len, _chunk_crc = chunk_header.unpack(header)
+ body_start = stream.tell()
+ body_end = body_start + chunk_len
+ if body_end > len(payload):
+ break
+ body_ranges.append((body_start, body_end, header_pos))
+ stream.seek(body_end)
+
+ if not body_ranges:
+ raise ValueError("No chunk bodies found in payload.")
+
+ total_body_bytes = sum(end - start for start, end, _ in body_ranges)
+ if byte_index is not None:
+ remaining = byte_index
+ else:
+ assert fraction is not None
+ remaining = min(total_body_bytes - 1, max(0, int(math.floor(total_body_bytes * fraction))))
+
+ chunk_index = 0
+ target = None
+ for idx, (start, end, _header_pos) in enumerate(body_ranges):
+ chunk_len = end - start
+ if remaining < chunk_len:
+ target = start + remaining
+ chunk_index = idx
+ break
+ remaining -= chunk_len
+
+ if target is None or target >= len(data):
+ raise ValueError("Corruption index is out of range for chunk bodies.")
+
+ data[target] ^= 0x01
+ return bytes(data), target, chunk_index, target - body_ranges[chunk_index][0]
+
+
+def main():
+ args = parse_args()
+ sys.path.insert(0, str(args.repo_path))
+
+ import encodec.binary as binary
+ from encodec.compress import compress, decompress, MODELS
+ from encodec.utils import convert_audio
+
+ decode_device = args.decode_device or args.device
+ wav, sr = load_audio(args.input)
+ wav = clip_audio(wav, sr, args.offset, args.duration)
+ source_duration = wav.shape[-1] / sr
+
+ model = MODELS[args.model]().to(args.device)
+ model.set_target_bandwidth(args.bandwidth)
+ if args.segment is not None:
+ model.segment = args.segment
+ if args.overlap is not None:
+ model.overlap = args.overlap
+
+ wav_in = convert_audio(wav, sr, model.sample_rate, model.channels).to(args.device)
+ wav_ref = wav_in.detach().cpu()
+
+ t0 = time.perf_counter()
+ clean_payload = compress(model, wav_in, use_lm=args.lm)
+ encode_s = time.perf_counter() - t0
+
+ if args.output_payload is not None:
+ args.output_payload.parent.mkdir(parents=True, exist_ok=True)
+ args.output_payload.write_bytes(clean_payload)
+
+ payload = clean_payload
+ header_stream = io.BytesIO(clean_payload)
+ metadata = binary.read_ecdc_header(header_stream)
+ payload_offset = header_stream.tell()
+
+ corrupt_abs = None
+ corrupt_chunk_index = None
+ corrupt_chunk_byte = None
+ if args.corrupt_byte_index is not None:
+ if metadata.get("acv") == 4:
+ payload, corrupt_abs, corrupt_chunk_index, corrupt_chunk_byte = flip_chunk_body_byte(
+ payload, payload_offset, metadata, args.corrupt_byte_index, None)
+ else:
+ payload, corrupt_abs = flip_payload_byte(payload, payload_offset, args.corrupt_byte_index)
+ elif args.corrupt_byte_fraction is not None:
+ if metadata.get("acv") == 4:
+ payload, corrupt_abs, corrupt_chunk_index, corrupt_chunk_byte = flip_chunk_body_byte(
+ payload, payload_offset, metadata, None, args.corrupt_byte_fraction)
+ else:
+ data_len = len(clean_payload) - payload_offset
+ corrupt_idx = min(data_len - 1, max(0, int(math.floor(data_len * args.corrupt_byte_fraction))))
+ payload, corrupt_abs = flip_payload_byte(payload, payload_offset, corrupt_idx)
+
+ if args.output_corrupt_payload is not None:
+ args.output_corrupt_payload.parent.mkdir(parents=True, exist_ok=True)
+ args.output_corrupt_payload.write_bytes(payload)
+
+ result = {
+ "repo_path": str(args.repo_path),
+ "input": str(args.input),
+ "model": args.model,
+ "bandwidth": args.bandwidth,
+ "device": args.device,
+ "decode_device": decode_device,
+ "lm": args.lm,
+ "segment": model.segment,
+ "overlap": model.overlap,
+ "input_sr": sr,
+ "model_sr": model.sample_rate,
+ "input_channels": int(wav.shape[0]),
+ "model_channels": int(model.channels),
+ "source_duration_s": source_duration,
+ "encoded_samples": int(wav_in.shape[-1]),
+ "encoded_bytes": len(clean_payload),
+ "payload_bytes": len(clean_payload) - payload_offset,
+ "output_payload": None if args.output_payload is None else str(args.output_payload),
+ "output_corrupt_payload": None if args.output_corrupt_payload is None else str(args.output_corrupt_payload),
+ "header_metadata": metadata,
+ "corrupt_absolute_byte": corrupt_abs,
+ "corrupt_payload_byte": None if corrupt_abs is None else corrupt_abs - payload_offset,
+ "corrupt_chunk_index": corrupt_chunk_index,
+ "corrupt_chunk_byte": corrupt_chunk_byte,
+ }
+
+ try:
+ clean_decode = None
+ if payload != clean_payload:
+ clean_decode, _ = decompress(clean_payload, device=decode_device)
+ clean_decode = clean_decode.detach().cpu()
+ if clean_decode.dim() == 1:
+ clean_decode = clean_decode.unsqueeze(0)
+
+ t1 = time.perf_counter()
+ wav_out, out_sr = decompress(payload, device=decode_device)
+ decode_s = time.perf_counter() - t1
+ wav_out = wav_out.detach().cpu()
+ if wav_out.dim() == 1:
+ wav_out = wav_out.unsqueeze(0)
+ wav_out = wav_out[:, :wav_ref.shape[-1]]
+ if wav_out.shape[-1] < wav_ref.shape[-1]:
+ pad = wav_ref.shape[-1] - wav_out.shape[-1]
+ wav_out = torch.nn.functional.pad(wav_out, (0, pad))
+ diff = wav_out - wav_ref
+ mse = float(diff.pow(2).mean().item())
+ mae = float(diff.abs().mean().item())
+ signal_power = float(wav_ref.pow(2).mean().item())
+ snr_db = float("inf") if mse == 0 else 10.0 * math.log10(max(signal_power, 1e-12) / mse)
+ result.update({
+ "success": True,
+ "decode_sr": out_sr,
+ "decoded_samples": int(wav_out.shape[-1]),
+ "encode_s": encode_s,
+ "decode_s": decode_s,
+ "rtf_encode": encode_s / max(source_duration, 1e-9),
+ "rtf_decode": decode_s / max(source_duration, 1e-9),
+ "mse": mse,
+ "mae": mae,
+ "max_abs_err": float(diff.abs().max().item()),
+ "snr_db": snr_db,
+ "bps": (len(payload) * 8.0) / max(source_duration, 1e-9),
+ })
+ if clean_decode is not None:
+ clean_cmp = clean_decode[:, :wav_out.shape[-1]]
+ if clean_cmp.shape[-1] < wav_out.shape[-1]:
+ clean_cmp = torch.nn.functional.pad(clean_cmp, (0, wav_out.shape[-1] - clean_cmp.shape[-1]))
+ corr_diff = wav_out - clean_cmp
+ err = corr_diff.abs().amax(dim=0)
+ mask = err > 1e-3
+ first_bad = int(torch.argmax(mask.to(torch.int64)).item()) if bool(mask.any()) else None
+ last_bad = int((mask.numel() - 1) - torch.argmax(mask.flip(0).to(torch.int64)).item()) if bool(mask.any()) else None
+ result.update({
+ "corruption_mae_vs_clean_decode": float(corr_diff.abs().mean().item()),
+ "corruption_max_abs_vs_clean_decode": float(corr_diff.abs().max().item()),
+ "corruption_first_bad_sample": first_bad,
+ "corruption_last_bad_sample": last_bad,
+ "corruption_bad_duration_s": None if first_bad is None else (last_bad - first_bad + 1) / out_sr,
+ })
+ except Exception as exc:
+ result.update({
+ "success": False,
+ "encode_s": encode_s,
+ "decode_error": repr(exc),
+ "bps": (len(payload) * 8.0) / max(source_duration, 1e-9),
+ })
+
+ print(json.dumps(result, sort_keys=True))
+
+
+if __name__ == "__main__":
+ main()
From 68e8d3029c6cbf8c956c943ce45154a8b864e551 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Wed, 18 Mar 2026 13:29:58 +0000
Subject: [PATCH 11/24] Add combined README with precision/robustness
improvements documentation
Merges original Facebook README with research notes from both the wavey-ai
fork and the codex-precision-review branch: deterministic LM path, acv=4
chunk framing, _counts_from_pdf bug fix, GPU reliability, tuned defaults,
compression benchmarks, chunk size tradeoffs, and usage examples.
Co-Authored-By: Claude Sonnet 4.6
---
README.md | 338 ++++++++++++++++++++++++++++++++----------------------
1 file changed, 202 insertions(+), 136 deletions(-)
diff --git a/README.md b/README.md
index 05e90ee..75e63ea 100644
--- a/README.md
+++ b/README.md
@@ -1,224 +1,291 @@
# EnCodec: High Fidelity Neural Audio Compression
+


-This is the code for the EnCodec neural codec presented in the [High Fidelity Neural Audio Compression](https://arxiv.org/pdf/2210.13438.pdf) [[abs]](https://arxiv.org/abs/2210.13438).
-paper. We provide our two multi-bandwidth models:
-* A causal model operating at 24 kHz on monophonic audio trained on a variety of audio data.
-* A non-causal model operating at 48 kHz on stereophonic audio trained on music-only data.
+This is the code for the EnCodec neural codec presented in [High Fidelity Neural Audio Compression](https://arxiv.org/pdf/2210.13438.pdf) [[abs]](https://arxiv.org/abs/2210.13438). We provide two multi-bandwidth models:
-The 24 kHz model can compress to 1.5, 3, 6, 12 or 24 kbps, while the 48 kHz model
-support 3, 6, 12 and 24 kbps. We also provide a pre-trained language model for each
-of the models, that can further compress the representation by up to 40% without
-any further loss of quality.
+- A causal model operating at **24 kHz** on monophonic audio trained on a variety of audio data.
+- A non-causal model operating at **48 kHz** on stereophonic audio trained on music-only data.
-For reference, we also provide the code for our novel [MS-STFT discriminator](encodec/msstftd.py) and the [balancer](encodec/balancer.py).
+The 24 kHz model supports 1.5, 3, 6, 12, and 24 kbps. The 48 kHz model supports 3, 6, 12, and 24 kbps. A pre-trained language model is available for each, enabling entropy coding that reduces bitstream size by up to 40% without further quality loss.
-
-
+
## Samples
-Samples including baselines are provided on [our sample page](https://ai.honu.io/papers/encodec/samples.html).
-You can also have a quick demo of what we achieve for 48 kHz music with EnCodec, along with
-entropy coding, by clicking the thumbnail (original tracks provided by [Lucille Crew](https://open.spotify.com/artist/5eLv7rNfrf3IjMnK311ByP?si=X_zD9ackRRGjFP5Y6Q7Zng) and [Voyageur I](https://open.spotify.com/artist/21HymveeIhDcM4KDKeNLz0?si=4zXF8VpeQpeKR9QUIuck9Q)).
+Samples including baselines are on [our sample page](https://ai.honu.io/papers/encodec/samples.html). A quick demo of 48 kHz music with entropy coding is available by clicking the thumbnail (original tracks by [Lucille Crew](https://open.spotify.com/artist/5eLv7rNfrf3IjMnK311ByP?si=X_zD9ackRRGjFP5Y6Q7Zng) and [Voyageur I](https://open.spotify.com/artist/21HymveeIhDcM4KDKeNLz0?si=4zXF8VpeQpeKR9QUIuck9Q)).
-
+
## 🤗 Transformers
-Encodec has now been added to Transformers. For more information, please refer to [Transformers' Encodec docs](https://huggingface.co/docs/transformers/main/en/model_doc/encodec).
-
-You can find both the [24KHz](https://huggingface.co/facebook/encodec_24khz) and [48KHz](https://huggingface.co/facebook/encodec_48khz) checkpoints on the 🤗 Hub.
+EnCodec is available in Transformers. See the [Transformers EnCodec docs](https://huggingface.co/docs/transformers/main/en/model_doc/encodec), and the [24 kHz](https://huggingface.co/facebook/encodec_24khz) and [48 kHz](https://huggingface.co/facebook/encodec_48khz) checkpoints on the Hub.
-Using 🤗 Transformers, you can leverage Encodec at scale along with all the other supported models and datasets. ⚡️
-Alternatively you can also directly use the encodec package, as detailed in the Usage section.
-
-To use first you'd need to set up your development environment!
-```
-pip install -U datasets
-pip install git+https://github.com/huggingface/transformers.git@main
-```
-
-Then, start embedding your audio datasets at scale!
```python
from datasets import load_dataset, Audio
from transformers import EncodecModel, AutoProcessor
-# dummy dataset, however you can swap this with an dataset on the 🤗 hub or bring your own
librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-
-# load the model + processor (for pre-processing the audio)
model = EncodecModel.from_pretrained("facebook/encodec_24khz")
processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
-
-# cast the audio data to the correct sampling rate for the model
librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
audio_sample = librispeech_dummy[0]["audio"]["array"]
-
-# pre-process the inputs
inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
-
-# explicitly encode then decode the audio inputs
encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
+audio_codes = model(inputs["input_values"], inputs["padding_mask"]).audio_codes
+```
-# or the equivalent with a forward pass
-audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values
+---
-# you can also extract the discrete codebook representation for LM tasks
-# output: concatenated tensor of all the representations
-audio_codes = model(inputs["input_values"], inputs["padding_mask"]).audio_codes
+## Precision and Robustness Improvements (wavey-ai fork)
+
+This fork extends the original EnCodec with a fully deterministic, cross-platform entropy coding path. The changes affect `encodec/compress.py` and `encodec/model.py` only — the neural network weights and audio quality are unchanged.
+
+### Bitstream version `acv=4`
+
+When `use_lm=True`, the encoder writes bitstream version 4. Each model segment (≈1 second for the 48 kHz model) is wrapped in an independent CRC-protected chunk:
+```
+[chunk_len: u32 BE][crc32: u32 BE][chunk payload]
```
-## What's up?
+A single corrupt byte damages at most one chunk. The decoder substitutes silence for any chunk that fails its CRC check and continues decoding the rest of the file. Previous versions would abort the entire decode on the first error.
-See [the changelog](CHANGELOG.md) for details on releases.
+### Deterministic LM path
-## Installation
+The original LM entropy path was not deterministic across hardware (MPS, CUDA, CPU), causing cross-device decode failures. The deterministic path fixes this by:
-EnCodec requires Python 3.8, and a reasonably recent version of PyTorch (1.11.0 ideally).
-To install EnCodec, you can run from this repository:
-```bash
-pip install -U encodec # stable release
-pip install -U git+https://git@github.com/facebookresearch/encodec#egg=encodec # bleeding edge
-# of if you cloned the repo locally
-pip install .
+- Running the arithmetic coder and LM **always on CPU**, regardless of model device.
+- Computing softmax in **float64** via a sequential cumsum denominator (`_stable_softmax`) rather than platform-native `torch.softmax`, which can differ by a ULP across devices.
+- **Quantising logits** to a 1/128 grid before softmax. Small floating-point differences that do not change the quantised logit produce identical CDFs.
+- Building the CDF from **integer floor counts** (`FP_SCALE = 65536`) with deterministic priority allocation for the residual.
+- Storing `tau` in the bitstream header so encoder and decoder are always in sync.
+
+Cross-device decode matrix (payloads encoded on Apple Silicon Mac):
+
+| Encode device | Decode device | Legacy (original) | This fork |
+|---|---|---|---|
+| Mac CPU | Linux CPU | EOFError | ✓ |
+| Mac CPU | Linux CUDA | EOFError | ✓ |
+| Mac MPS | Linux CPU | EOFError | ✓ |
+| Mac MPS | Linux CUDA | EOFError | ✓ |
+
+### Critical bug fix: `_counts_from_pdf`
+
+At `tau=1.0`, many softmax outputs are exactly `0.0` (float underflow of `exp(-large)`). These triggered a near-integer perturbation with an alternating sign. A negative sign on `x=0.0` gives `x = -ε`, and `floor(-ε) = -1`. A negative count makes the CDF non-monotonic, causing the arithmetic decoder to produce wrong symbols silently.
+
+Fix (one line):
+
+```python
+# Before (broken at tau=1.0):
+fx = torch.floor(x)
+
+# After (fixed):
+fx = torch.floor(x.clamp_min(0))
```
-**Supported platforms:** we officially support only Mac OS X (you might need XCode installed if running on a non Intel Mac), and recent versions of mainstream Linux distributions. We will try to help out on Windows but cannot provide strong support. Any other platform (iOS / Android / onboard ARM) are not supported.
+This bug was present in both the original Facebook implementation and earlier revisions of this fork.
-## Usage
+### GPU reliability
-You can then use the EnCodec command, either as
-```bash
-python3 -m encodec [...]
-# or
-encodec [...]
+The model encoder/decoder can run on any device (CPU, MPS, CUDA). `compress_to_file` detects the model's device automatically:
+
+```python
+model_device = next(model.parameters()).device
+frames = model.encode(wav[None].to(model_device))
```
-If you want to directly use the compression API, checkout `encodec.compress`
-and `encodec.model`. See hereafter for instructions on how to extract the discrete
-representation.
+### Legacy decode support
+
+Streams from the original Facebook implementation (`acv < 3`) decode correctly via `LMModel.forward_legacy()`, which uses raw softmax with no quantisation. The decoder selects the legacy or deterministic path based on the `acv` field in the stream header.
-### Model storage
+### Tuned defaults
-The models will be automatically downloaded on first use using Torch Hub.
-For more information on where those models are stored, or how to customize
-the storage location, [checkout their documentation.](https://pytorch.org/docs/stable/hub.html#where-are-my-downloaded-models-saved)
+All settings are overridable via environment variables:
-### Compression
+| Variable | Default | Notes |
+|---|---|---|
+| `ENCODEC_LM_TAU` | `1.0` | Softmax temperature. `1.0` is optimal for compression. |
+| `ENCODEC_LOGIT_QSTEP` | `1/128` | Logit quantisation grid size. |
+| `ENCODEC_AC_FP_SCALE` | `65536` | Integer scale for CDF allocation (`2^16`). |
+| `ENCODEC_AC_MIN_RANGE` | `1` | Minimum CDF range per symbol. |
+| `ENCODEC_DETERMINISTIC_LM_DTYPE` | `float32` | LM weight dtype. `float32` is faster; `float64` available. |
+| `ENCODEC_USE_NEAR_UNIFORM` | `0` | Enable near-uniform prior (off by default). |
+
+### Compression results
+
+Benchmarked on 7 stereo 48 kHz music tracks (10 s clips), `encodec_48khz`:
+
+| Bandwidth | Device | Avg actual kbps | LM gain | Encode RTF | Decode RTF |
+|---|---|---|---|---|---|
+| 6 kbps | CPU | 4.34 | 27.7% | 0.26× | 0.27× |
+| 6 kbps | MPS | 4.34 | 27.7% | 0.33× | 0.27× |
+| 24 kbps | CPU | 19.3 | 19.9% | 0.39× | 0.41× |
+| 24 kbps | MPS | 19.3 | 19.9% | 0.47× | 0.40× |
+
+RTF < 1.0 means faster than real time. The LM runs on CPU in all cases; MPS accelerates model encode/decode but does not reduce LM inference time.
+
+### Chunk size tradeoffs
+
+Per-segment chunk overhead is dominated by LM segmentation granularity, not the 8-byte header:
+
+| Segment size | Approx bitrate (6 kbps, music, 4 s) | Max failure isolation |
+|---|---|---|
+| 1.0 s (default) | ~3600 bps | ≤ 1.0 s |
+| 0.5 s | ~4050 bps | ≤ 0.5 s |
+| 0.25 s | ~4600 bps | ≤ 0.25 s |
+
+The default 1.0 s (matching the 48 kHz model segment) gives the best bitrate/isolation tradeoff.
+
+---
+
+## Installation
+
+Requires Python 3.8+ and a recent PyTorch (1.11+ recommended; 2.x tested).
```bash
-encodec [-b TARGET_BANDWIDTH] [-f] [--hq] [--lm] INPUT_FILE [OUTPUT_FILE]
+pip install -U encodec # stable release
+pip install -U git+https://git@github.com/wavey-ai/encodec # this fork
+pip install . # from local clone
```
-Given any audio file supported by torchaudio on your platform, compresses
-it with EnCodec to the target bandwidth (default is 6 kbps, can be either 1.5, 3, 6, 12 or 24).
-OUTPUT_FILE must end in `.ecdc`. If not provided it will be the same as `INPUT_FILE`,
-replacing the extension with `.ecdc`.
-In order to use the model operating at 48 kHz on stereophonic audio, use the `--hq` flag.
-The `-f` flag is used to force overwrite an existing output file.
-Use the `--lm` flag to use the pretrained language model with entropy coding (expect it to
-be much slower).
-
-If the sample rate or number of channels of the input doesn't match that of the model,
-the command will automatically resample / reduce channels as needed.
-
-### Decompression
+
+For development:
+
```bash
-encodec [-f] [-r] ENCODEC_FILE [OUTPUT_WAV_FILE]
+pip install -e '.[dev]'
+make tests
```
-Given a `.ecdc` file previously generated, this will decode it to the given output wav file.
-If not provided, the output will default to the input with the `.wav` extension.
-Use the `-f` file to force overwrite the output file (be carefull if compress then decompress,
-not to overwrite your original file !). Use the `-r` flag if you experience clipping, this will
-rescale the output file to avoid it.
-### Compression + Decompression
+**Supported platforms:** macOS (Intel and Apple Silicon), recent mainstream Linux distributions. Windows is not officially supported.
+
+## Usage
+
+### CLI
+
```bash
+# Compress
+encodec [-b TARGET_BANDWIDTH] [-f] [--hq] [--lm] INPUT_FILE [OUTPUT_FILE]
+
+# Decompress
+encodec [-f] [-r] ENCODEC_FILE [OUTPUT_WAV_FILE]
+
+# Round-trip (compress then immediately decompress)
encodec [-r] [-b TARGET_BANDWIDTH] [-f] [--hq] [--lm] INPUT_FILE OUTPUT_WAV_FILE
```
-When `OUTPUT_WAV_FILE` has the `.wav` extension (as opposed to `.ecdc`), the `encodec`
-command will instead compress and immediately decompress without storing the intermediate
-`.ecdc` file.
-### Extracting discrete representations
+`--hq` selects the 48 kHz stereo model. `--lm` enables entropy coding (slower, ~20–35% smaller files).
-The EnCodec model can also be used to extract discrete representations from the audio waveform.
+### Python API
```python
+import soundfile as sf
+import torch
from encodec import EncodecModel
+from encodec.compress import compress, decompress
from encodec.utils import convert_audio
-import torchaudio
+# Load model
+model = EncodecModel.encodec_model_48khz()
+model.set_target_bandwidth(6.0)
+
+# Load audio (soundfile recommended over torchaudio for compatibility)
+wav, sr = sf.read("audio.wav", always_2d=True, dtype="float32")
+wav = torch.from_numpy(wav.T.copy())
+wav = convert_audio(wav, sr, model.sample_rate, model.channels)
+
+# Compress with LM entropy coding (acv=4, CRC chunk framing)
+payload = compress(model, wav, use_lm=True)
+
+# Decompress (works on any device; corrupt segments replaced with silence)
+wav_out, out_sr = decompress(payload)
+```
+
+### GPU encode
+
+```python
+model = EncodecModel.encodec_model_48khz().to("mps") # or "cuda"
+model.set_target_bandwidth(6.0)
+# compress() moves the waveform to the model device automatically;
+# the LM and arithmetic coder always stay on CPU for determinism.
+payload = compress(model, wav, use_lm=True)
+```
+
+### Extracting discrete codebook representations
+
+```python
+import soundfile as sf
import torch
+from encodec import EncodecModel
+from encodec.utils import convert_audio
-# Instantiate a pretrained EnCodec model
model = EncodecModel.encodec_model_24khz()
-# The number of codebooks used will be determined bythe bandwidth selected.
-# E.g. for a bandwidth of 6kbps, `n_q = 8` codebooks are used.
-# Supported bandwidths are 1.5kbps (n_q = 2), 3 kbps (n_q = 4), 6 kbps (n_q = 8) and 12 kbps (n_q =16) and 24kbps (n_q=32).
-# For the 48 kHz model, only 3, 6, 12, and 24 kbps are supported. The number
-# of codebooks for each is half that of the 24 kHz model as the frame rate is twice as much.
model.set_target_bandwidth(6.0)
-# Load and pre-process the audio waveform
-wav, sr = torchaudio.load("")
+wav, sr = sf.read("audio.wav", always_2d=True, dtype="float32")
+wav = torch.from_numpy(wav.T.copy())
wav = convert_audio(wav, sr, model.sample_rate, model.channels)
-wav = wav.unsqueeze(0)
-# Extract discrete codes from EnCodec
with torch.no_grad():
- encoded_frames = model.encode(wav)
-codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T]
+ encoded_frames = model.encode(wav.unsqueeze(0))
+codes = torch.cat([f[0] for f in encoded_frames], dim=-1) # [B, n_q, T]
```
-Note that the 48 kHz model processes the audio by chunks of 1 seconds, with an overlap of 1%,
-and renormalizes the audio to have unit scale. For this model, the output of `model.encode(wav)`
-would a list (for each frame of 1 second) of a tuple `(codes, scale)` with `scale` a scalar tensor.
+Codebook counts by bandwidth:
-## Installation for development
+| Model | 1.5 kbps | 3 kbps | 6 kbps | 12 kbps | 24 kbps |
+|---|---|---|---|---|---|
+| 24 kHz mono | n_q=2 | n_q=4 | n_q=8 | n_q=16 | n_q=32 |
+| 48 kHz stereo | — | n_q=2 | n_q=4 | n_q=8 | n_q=16 |
-This will install the dependencies and a `encodec` in developer mode (changes to the files
-will directly reflect), along with the dependencies to run unit tests.
-```
-pip install -e '.[dev]'
-```
-
-### Test
+### Benchmarking and corruption testing
-You can run the unit tests with
-```
-make tests
+```bash
+# Encode, decode, report bitrate/SNR/timing
+python scripts/precision_eval.py \
+ --repo-path . \
+ --input audio.wav \
+ --model encodec_48khz \
+ --bandwidth 6.0 \
+ --lm \
+ --device mps
+
+# Simulate a corrupt byte at the midpoint of the payload
+python scripts/precision_eval.py \
+ --repo-path . \
+ --input audio.wav \
+ --model encodec_48khz \
+ --bandwidth 6.0 \
+ --lm \
+ --corrupt-byte-fraction 0.5
+
+# Cross-host decode validation (run on a second machine)
+python scripts/payload_decode_matrix.py --payload out.ecdc
```
+---
+
## FAQ
-Please check this section before opening an issue.
+**Out of memory on long files** — The model is applied to the full file at once. Split into segments manually or reduce clip length before encoding.
-### Out of memory errors with long files
+**DistributedDataParallel** — Not used here. Use `encodec.distrib.sync_buffer` and `encodec.distrib.sync_grad` instead.
-We do not try to be smart about long files, and we apply the model at once on the entire file. This can lead to a large memory usage
-and result in the process being killed. At the moment we will not support this use case.
+**My `.ecdc` file from the original Facebook release won't decode** — It will. The decoder detects the bitstream version and routes `acv < 3` streams through the original LM path automatically.
-### Bad interactions between DistributedDataParallel and the RVQ code
+**MPS is slower than CPU for encode** — The LM runs on CPU regardless of device (required for cross-platform determinism) and dominates encode time. MPS accelerates only the SEANet encoder/decoder, which is not the bottleneck at typical clip lengths.
-We do not use DDP, instead we recommend using the routines in `encodec/distrib.py`, in particular `encodec.distrib.sync_buffer` and `encodec.distrib.sync_grad`.
+## What's new
-## Citation
+See [CHANGELOG.md](CHANGELOG.md) for the full history.
-If you use this code or results in your paper, please cite our work as:
+## Citation
-```
+```bibtex
@article{defossez2022highfi,
title={High Fidelity Neural Audio Compression},
author={Défossez, Alexandre and Copet, Jade and Synnaeve, Gabriel and Adi, Yossi},
@@ -229,5 +296,4 @@ If you use this code or results in your paper, please cite our work as:
## License
-The code in this repository is released under the MIT license as found in the
-[LICENSE](LICENSE) file.
+MIT — see [LICENSE](LICENSE).
From c84f6cb2c920bd714d31ec11383bbf8329c1a0cc Mon Sep 17 00:00:00 2001
From: jbrough
Date: Thu, 19 Mar 2026 14:16:11 +0000
Subject: [PATCH 12/24] Improve deterministic LM bitstream controls
---
encodec/compress.py | 322 +++++++++++++++++++++++++++++++++++---------
1 file changed, 256 insertions(+), 66 deletions(-)
diff --git a/encodec/compress.py b/encodec/compress.py
index c529ab3..f1f941e 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -62,6 +62,16 @@ def _env_dtype(name: str, default: torch.dtype) -> torch.dtype:
except KeyError as exc:
raise ValueError(f"Unsupported dtype override {v!r} for {name}.") from exc
+def _env_choice(name: str, default: str, choices: tp.Set[str]) -> str:
+ v = os.getenv(name)
+ if v is None:
+ return default
+ value = v.lower()
+ if value not in choices:
+ allowed = ", ".join(sorted(choices))
+ raise ValueError(f"Unsupported value {v!r} for {name}. Expected one of: {allowed}.")
+ return value
+
# Lean defaults: float32 LM, finer logit grid, high-precision CDF allocation.
LOGIT_QSTEP = _env_float("ENCODEC_LOGIT_QSTEP", 1.0 / 128.0)
LM_TAU = _env_float("ENCODEC_LM_TAU", 1.0)
@@ -69,10 +79,13 @@ def _env_dtype(name: str, default: torch.dtype) -> torch.dtype:
MIN_RANGE = _env_int("ENCODEC_AC_MIN_RANGE", 1)
USE_NEAR_UNIFORM = _env_bool("ENCODEC_USE_NEAR_UNIFORM", False)
DETERMINISTIC_LM_DTYPE = _env_dtype("ENCODEC_DETERMINISTIC_LM_DTYPE", torch.float32)
+LM_DEVICE_MODE = _env_choice("ENCODEC_LM_DEVICE", "cpu", {"cpu", "model"})
+LM_CHUNKED_DEFAULT = _env_bool("ENCODEC_LM_CHUNKED", True)
_IDX_CACHE: tp.Dict[tp.Tuple[str, int, int], torch.Tensor] = {}
_UNIFORM_CDF_CACHE: tp.Dict[tp.Tuple[str, int, int, int, int], torch.Tensor] = {}
_CHUNK_HEADER = struct.Struct('!II') # chunk_len (uint32 BE), crc32 (uint32 BE)
+ProgressCallback = tp.Optional[tp.Callable[[tp.Dict[str, tp.Any]], None]]
# ---------------------------------------------------------------------------
@@ -248,6 +261,48 @@ def _deterministic_cdf_multi(pdf_mat: torch.Tensor,
# acv=4 chunk framing helpers
# ---------------------------------------------------------------------------
+def _emit_progress(progress_callback: ProgressCallback, payload: tp.Dict[str, tp.Any]) -> None:
+ if progress_callback is None:
+ return
+ try:
+ progress_callback(payload)
+ except Exception:
+ # Progress reporting must never affect deterministic bytestream generation.
+ pass
+
+
+def _segment_layout(model: EncodecModel, audio_length: int) -> tp.Tuple[int, int, tp.List[int]]:
+ segment_length = model.segment_length or audio_length
+ segment_stride = model.segment_stride or audio_length
+ offsets = list(range(0, audio_length, segment_stride))
+ return segment_length, segment_stride, offsets
+
+
+def _build_progress_payload(
+ *,
+ stage: str,
+ sample_rate: int,
+ total_segments: int,
+ segment_index: int,
+ audio_length: int,
+ segment_length: int,
+ segment_stride: int,
+ offset_samples: int = 0,
+) -> tp.Dict[str, tp.Any]:
+ payload: tp.Dict[str, tp.Any] = {
+ 'stage': stage,
+ 'segmentCount': total_segments,
+ 'segmentIndex': segment_index,
+ 'progress': float(segment_index / total_segments) if total_segments else 0.0,
+ 'sampleRate': int(sample_rate),
+ 'audioLength': audio_length,
+ 'segmentLength': int(segment_length),
+ 'segmentStride': int(segment_stride),
+ }
+ if stage == 'segment':
+ payload['offsetSamples'] = int(offset_samples)
+ return payload
+
def _write_chunk(fo: tp.IO[bytes], payload: bytes) -> None:
"""Write a CRC-protected chunk: [len: u32][crc: u32][payload]."""
fo.write(_CHUNK_HEADER.pack(len(payload), zlib.crc32(payload) & 0xffffffff))
@@ -268,95 +323,221 @@ def _read_chunk_payload(fo: tp.IO[bytes]) -> bytes:
# compress_to_file / decompress_from_file
# ---------------------------------------------------------------------------
+def _write_frame_payload(
+ frame: torch.Tensor,
+ scale: tp.Optional[torch.Tensor],
+ fo: tp.IO[bytes],
+ *,
+ use_lm: bool,
+ model: EncodecModel,
+ coder_device: torch.device,
+ lm_device: torch.device,
+ lm: tp.Optional[tp.Any],
+ lm_tau: float,
+) -> None:
+ if scale is not None:
+ fo.write(struct.pack('!f', float(scale.cpu().item())))
+
+ _B, K, T = frame.shape
+ if use_lm:
+ assert lm is not None
+ coder = ArithmeticCoder(fo)
+ states = None
+ offset = 0
+ input_ = torch.zeros(1, K, 1, dtype=torch.long, device=lm_device)
+ else:
+ packer = binary.BitPacker(model.bits_per_codebook, fo)
+
+ for t in range(T):
+ if use_lm:
+ with torch.inference_mode():
+ logits_raw, states, offset = lm.forward_logits(input_, states, offset)
+ logits_q = _quantize_logits_(logits_raw / lm_tau, LOGIT_QSTEP)
+ probas = _softmax_or_uniform(logits_q, dim=1)
+
+ pdf_mat = probas[0, :, :, 0].to(coder_device)
+ cdf_mat = _deterministic_cdf_multi(
+ pdf_mat, coder.total_range_bits,
+ fp_scale=FP_SCALE, min_range=MIN_RANGE, check=False)
+ cdf_cols = cdf_mat.t().contiguous()
+
+ frame_slice = frame[:, :, t:t + 1].detach().to(coder_device)
+ for k, value in enumerate(frame_slice[0, :, 0].tolist()):
+ coder.push(value, cdf_cols[k])
+ input_ = (1 + frame_slice).to(lm_device)
+ else:
+ for value in frame[0, :, t].detach().cpu().tolist():
+ packer.push(value)
+
+ if use_lm:
+ coder.flush()
+ else:
+ packer.flush()
+
+
def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
- use_lm: bool = True) -> None:
+ use_lm: bool = True,
+ progress_callback: ProgressCallback = None,
+ lm_chunked: tp.Optional[bool] = None) -> None:
"""Compress a waveform to a file-object.
- When ``use_lm=True`` the stream is bitstream version 4 (acv=4):
- each model segment is wrapped in a CRC-protected chunk so that a
- single corrupt byte only damages that one segment (~1 s). The
- arithmetic coder and LM always run on CPU for cross-platform
- determinism; the EnCodec model may run on any device.
+ When ``use_lm=True``:
+ * ``lm_chunked=True`` writes bitstream version 4 (acv=4), where
+ each model segment is wrapped in a CRC-protected chunk.
+ * ``lm_chunked=False`` writes deterministic unchunked bitstream
+ version 3 (acv=3), compatible with the existing deterministic
+ decoder path.
+
+ The arithmetic coder and LM always run on CPU for cross-platform
+ determinism unless ``ENCODEC_LM_DEVICE=model`` is set. The EnCodec
+ model itself may run on any device.
Args:
model: pre-trained EncodecModel.
wav: ``[C, T]`` waveform at model.sample_rate.
fo: writable file-object.
- use_lm: enable LM entropy coding (acv=4) vs raw bitpacking (acv=0).
+ use_lm: enable LM entropy coding.
+ lm_chunked: choose CRC chunk framing for deterministic LM streams.
"""
assert wav.dim() == 2
if model.name not in MODELS:
raise ValueError(f"Unsupported model {model.name}.")
- coder_device = torch.device("cpu")
+ if lm_chunked is None:
+ lm_chunked = LM_CHUNKED_DEFAULT
+
model = model.eval()
model_device = next(model.parameters()).device
- with torch.inference_mode():
- frames = model.encode(wav[None].to(model_device))
- codes0, _ = frames[0]
- _, K, _ = codes0.shape
+ coder_device = torch.device("cpu")
+ lm_device = model_device if LM_DEVICE_MODE == "model" else coder_device
+ audio_length = int(wav.shape[-1])
+ segment_length, segment_stride, offsets = _segment_layout(model, audio_length)
+
+ if not offsets:
+ raise ValueError("Cannot compress an empty waveform.")
lm = None
lm_tau = LM_TAU
- if use_lm:
- lm = model.get_lm_model(device=coder_device,
- dtype=DETERMINISTIC_LM_DTYPE).eval()
- lm.tau = lm_tau
-
- metadata: tp.Dict[str, tp.Any] = {
- 'm': model.name,
- 'al': int(wav.shape[-1]),
- 'nc': int(K),
- 'lm': bool(use_lm),
- 'fp': int(FP_SCALE),
- 'mr': int(MIN_RANGE),
- 'acv': 4 if use_lm else 0,
- 'tau': float(lm_tau),
- }
- binary.write_ecdc_header(fo, metadata)
-
- for (frame, scale) in frames:
- chunk_fo = io.BytesIO()
-
- if scale is not None:
- chunk_fo.write(struct.pack('!f', float(scale.cpu().item())))
+ total_segments = len(offsets)
+ _emit_progress(progress_callback, _build_progress_payload(
+ stage='start',
+ sample_rate=int(model.sample_rate),
+ total_segments=total_segments,
+ segment_index=0,
+ audio_length=audio_length,
+ segment_length=segment_length,
+ segment_stride=segment_stride,
+ ))
+
+ if use_lm and not lm_chunked:
+ with torch.inference_mode():
+ frames = model.encode(wav[None].to(model_device))
+ if not frames:
+ raise ValueError("Cannot compress an empty waveform.")
- _B, _K, T = frame.shape
- if use_lm:
- coder = ArithmeticCoder(chunk_fo)
- states = None
- offset = 0
- input_ = torch.zeros(1, K, 1, dtype=torch.long, device=coder_device)
- else:
- packer = binary.BitPacker(model.bits_per_codebook, chunk_fo)
+ codes0, _ = frames[0]
+ _, K, _ = codes0.shape
+ lm = model.get_lm_model(device=lm_device, dtype=DETERMINISTIC_LM_DTYPE).eval()
+ lm.tau = lm_tau
+ metadata: tp.Dict[str, tp.Any] = {
+ 'm': model.name,
+ 'al': audio_length,
+ 'nc': int(K),
+ 'lm': True,
+ 'fp': int(FP_SCALE),
+ 'mr': int(MIN_RANGE),
+ 'acv': 3,
+ 'tau': float(lm_tau),
+ }
+ binary.write_ecdc_header(fo, metadata)
+
+ for segment_index, ((frame, scale), offset_samples) in enumerate(zip(frames, offsets), start=1):
+ _write_frame_payload(
+ frame,
+ scale,
+ fo,
+ use_lm=True,
+ model=model,
+ coder_device=coder_device,
+ lm_device=lm_device,
+ lm=lm,
+ lm_tau=lm_tau,
+ )
+ _emit_progress(progress_callback, _build_progress_payload(
+ stage='segment',
+ sample_rate=int(model.sample_rate),
+ total_segments=total_segments,
+ segment_index=segment_index,
+ audio_length=audio_length,
+ segment_length=segment_length,
+ segment_stride=segment_stride,
+ offset_samples=int(offset_samples),
+ ))
+ return
+
+ header_written = False
+ for segment_index, offset_samples in enumerate(offsets, start=1):
+ with torch.inference_mode():
+ segment_wav = wav[None, :, offset_samples: offset_samples + segment_length].to(model_device)
+ frame, scale = model._encode_frame(segment_wav)
- for t in range(T):
+ if not header_written:
+ _, K, _ = frame.shape
if use_lm:
- with torch.inference_mode():
- logits_raw, states, offset = lm.forward_logits(input_, states, offset)
- logits_q = _quantize_logits_(logits_raw / lm_tau, LOGIT_QSTEP)
- probas = _softmax_or_uniform(logits_q, dim=1)
-
- pdf_mat = probas[0, :, :, 0].to(coder_device)
- cdf_mat = _deterministic_cdf_multi(
- pdf_mat, coder.total_range_bits,
- fp_scale=FP_SCALE, min_range=MIN_RANGE, check=False)
- cdf_cols = cdf_mat.t().contiguous()
-
- frame_slice = frame[:, :, t:t + 1].detach().to(coder_device)
- for k, value in enumerate(frame_slice[0, :, 0].tolist()):
- coder.push(value, cdf_cols[k])
- input_ = 1 + frame_slice
- else:
- for value in frame[0, :, t].detach().cpu().tolist():
- packer.push(value)
+ lm = model.get_lm_model(device=lm_device,
+ dtype=DETERMINISTIC_LM_DTYPE).eval()
+ lm.tau = lm_tau
+
+ metadata = {
+ 'm': model.name,
+ 'al': audio_length,
+ 'nc': int(K),
+ 'lm': bool(use_lm),
+ 'fp': int(FP_SCALE),
+ 'mr': int(MIN_RANGE),
+ 'acv': 4 if use_lm else 0,
+ 'tau': float(lm_tau),
+ }
+ binary.write_ecdc_header(fo, metadata)
+ header_written = True
if use_lm:
- coder.flush()
+ chunk_fo = io.BytesIO()
+ _write_frame_payload(
+ frame,
+ scale,
+ chunk_fo,
+ use_lm=True,
+ model=model,
+ coder_device=coder_device,
+ lm_device=lm_device,
+ lm=lm,
+ lm_tau=lm_tau,
+ )
_write_chunk(fo, chunk_fo.getvalue())
else:
- packer.flush()
- fo.write(chunk_fo.getvalue())
+ _write_frame_payload(
+ frame,
+ scale,
+ fo,
+ use_lm=False,
+ model=model,
+ coder_device=coder_device,
+ lm_device=lm_device,
+ lm=None,
+ lm_tau=lm_tau,
+ )
+
+ _emit_progress(progress_callback, _build_progress_payload(
+ stage='segment',
+ sample_rate=int(model.sample_rate),
+ total_segments=total_segments,
+ segment_index=segment_index,
+ audio_length=audio_length,
+ segment_length=segment_length,
+ segment_stride=segment_stride,
+ offset_samples=int(offset_samples),
+ ))
def decompress_from_file(fo: tp.IO[bytes],
@@ -521,10 +702,19 @@ def decompress_from_file(fo: tp.IO[bytes],
def compress(model: EncodecModel, wav: torch.Tensor,
- use_lm: bool = False) -> bytes:
+ use_lm: bool = False,
+ progress_callback: ProgressCallback = None,
+ lm_chunked: tp.Optional[bool] = None) -> bytes:
"""Compress a waveform and return bytes."""
fo = io.BytesIO()
- compress_to_file(model, wav, fo, use_lm=use_lm)
+ compress_to_file(
+ model,
+ wav,
+ fo,
+ use_lm=use_lm,
+ progress_callback=progress_callback,
+ lm_chunked=lm_chunked,
+ )
return fo.getvalue()
From f9da4cfb1a9b088a8cc1b8e4510744c80586f78b Mon Sep 17 00:00:00 2001
From: jbrough
Date: Fri, 20 Mar 2026 05:08:10 +0000
Subject: [PATCH 13/24] Parallelize chunked LM segment encoding
---
encodec/compress.py | 238 ++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 238 insertions(+)
diff --git a/encodec/compress.py b/encodec/compress.py
index f1f941e..0192725 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -6,12 +6,16 @@
"""API to compress/decompress audio to bytestreams."""
+import atexit
+import concurrent.futures
import io
import math
+import multiprocessing
import os
import struct
import typing as tp
import zlib
+from concurrent.futures.process import BrokenProcessPool
import torch
@@ -81,11 +85,16 @@ def _env_choice(name: str, default: str, choices: tp.Set[str]) -> str:
DETERMINISTIC_LM_DTYPE = _env_dtype("ENCODEC_DETERMINISTIC_LM_DTYPE", torch.float32)
LM_DEVICE_MODE = _env_choice("ENCODEC_LM_DEVICE", "cpu", {"cpu", "model"})
LM_CHUNKED_DEFAULT = _env_bool("ENCODEC_LM_CHUNKED", True)
+SEGMENT_WORKERS_DEFAULT = _env_int("ENCODEC_SEGMENT_WORKERS", 1)
_IDX_CACHE: tp.Dict[tp.Tuple[str, int, int], torch.Tensor] = {}
_UNIFORM_CDF_CACHE: tp.Dict[tp.Tuple[str, int, int, int, int], torch.Tensor] = {}
_CHUNK_HEADER = struct.Struct('!II') # chunk_len (uint32 BE), crc32 (uint32 BE)
ProgressCallback = tp.Optional[tp.Callable[[tp.Dict[str, tp.Any]], None]]
+_WORKER_MODEL_CACHE: tp.Dict[tp.Tuple[str, float], EncodecModel] = {}
+_WORKER_LM_CACHE: tp.Dict[tp.Tuple[str, float, str], tp.Any] = {}
+_PARALLEL_EXECUTOR: tp.Optional[concurrent.futures.ProcessPoolExecutor] = None
+_PARALLEL_EXECUTOR_WORKERS = 0
# ---------------------------------------------------------------------------
@@ -303,6 +312,170 @@ def _build_progress_payload(
payload['offsetSamples'] = int(offset_samples)
return payload
+
+def _parallel_segment_worker_count(
+ total_segments: int,
+ *,
+ use_lm: bool,
+ lm_chunked: bool,
+ model_device: torch.device,
+) -> int:
+ configured = SEGMENT_WORKERS_DEFAULT
+ if configured <= 0:
+ configured = os.cpu_count() or 1
+ if (
+ configured <= 1
+ or total_segments <= 1
+ or not use_lm
+ or not lm_chunked
+ or model_device.type != 'cpu'
+ or LM_DEVICE_MODE != 'cpu'
+ ):
+ return 1
+ return max(1, min(int(configured), int(total_segments)))
+
+
+def _build_segment_batches(
+ wav: torch.Tensor,
+ offsets: tp.List[int],
+ segment_length: int,
+ worker_count: int,
+) -> tp.List[tp.List[tp.Tuple[int, int, torch.Tensor]]]:
+ batch_count = max(1, min(worker_count, len(offsets)))
+ batch_size = int(math.ceil(len(offsets) / batch_count))
+ batches: tp.List[tp.List[tp.Tuple[int, int, torch.Tensor]]] = []
+ for start in range(0, len(offsets), batch_size):
+ batch: tp.List[tp.Tuple[int, int, torch.Tensor]] = []
+ for absolute_index, offset_samples in enumerate(offsets[start:start + batch_size], start=start + 1):
+ segment = wav[:, offset_samples: offset_samples + segment_length].detach().cpu().contiguous()
+ batch.append((absolute_index, int(offset_samples), segment))
+ batches.append(batch)
+ return batches
+
+
+def _init_parallel_worker_runtime() -> None:
+ torch.use_deterministic_algorithms(True)
+ torch.backends.mkldnn.enabled = False
+ try:
+ torch.set_num_threads(1)
+ except RuntimeError:
+ pass
+
+
+def _shutdown_parallel_executor() -> None:
+ global _PARALLEL_EXECUTOR
+ global _PARALLEL_EXECUTOR_WORKERS
+ executor = _PARALLEL_EXECUTOR
+ _PARALLEL_EXECUTOR = None
+ _PARALLEL_EXECUTOR_WORKERS = 0
+ if executor is not None:
+ executor.shutdown(wait=False, cancel_futures=True)
+
+
+def _get_parallel_executor(worker_count: int) -> concurrent.futures.ProcessPoolExecutor:
+ global _PARALLEL_EXECUTOR
+ global _PARALLEL_EXECUTOR_WORKERS
+ if worker_count <= 1:
+ raise ValueError('worker_count must be greater than 1 for the parallel executor.')
+ if _PARALLEL_EXECUTOR is None or _PARALLEL_EXECUTOR_WORKERS != worker_count:
+ _shutdown_parallel_executor()
+ _PARALLEL_EXECUTOR = concurrent.futures.ProcessPoolExecutor(
+ max_workers=worker_count,
+ mp_context=multiprocessing.get_context('spawn'),
+ initializer=_init_parallel_worker_runtime,
+ )
+ _PARALLEL_EXECUTOR_WORKERS = worker_count
+ return _PARALLEL_EXECUTOR
+ try:
+ torch.set_num_interop_threads(1)
+ except RuntimeError:
+ pass
+
+
+def _get_parallel_worker_model(
+ model_name: str,
+ bandwidth: float,
+ *,
+ use_lm: bool,
+ lm_tau: float,
+) -> tp.Tuple[EncodecModel, tp.Optional[tp.Any]]:
+ model_key = (model_name, float(bandwidth))
+ model = _WORKER_MODEL_CACHE.get(model_key)
+ if model is None:
+ model = MODELS[model_name]().eval()
+ model.set_target_bandwidth(float(bandwidth))
+ model.to('cpu')
+ _WORKER_MODEL_CACHE[model_key] = model
+
+ lm = None
+ if use_lm:
+ lm_key = (model_name, float(bandwidth), str(DETERMINISTIC_LM_DTYPE))
+ lm = _WORKER_LM_CACHE.get(lm_key)
+ if lm is None:
+ lm = model.get_lm_model(
+ device=torch.device('cpu'),
+ dtype=DETERMINISTIC_LM_DTYPE,
+ ).eval()
+ _WORKER_LM_CACHE[lm_key] = lm
+ lm.tau = float(lm_tau)
+
+ return model, lm
+
+
+def _encode_segment_batch_worker(
+ model_name: str,
+ bandwidth: float,
+ use_lm: bool,
+ lm_tau: float,
+ batch: tp.List[tp.Tuple[int, int, torch.Tensor]],
+) -> dict:
+ _init_parallel_worker_runtime()
+ model, lm = _get_parallel_worker_model(
+ model_name,
+ bandwidth,
+ use_lm=use_lm,
+ lm_tau=lm_tau,
+ )
+ coder_device = torch.device('cpu')
+ lm_device = torch.device('cpu')
+ segments: tp.List[tp.Tuple[int, int, bytes]] = []
+ num_codebooks: tp.Optional[int] = None
+
+ for segment_index, offset_samples, segment in batch:
+ segment_wav = segment.unsqueeze(0)
+ with torch.inference_mode():
+ frame, scale = model._encode_frame(segment_wav.to(coder_device))
+ if num_codebooks is None:
+ num_codebooks = int(frame.shape[1])
+
+ payload_fo = io.BytesIO()
+ _write_frame_payload(
+ frame,
+ scale,
+ payload_fo,
+ use_lm=use_lm,
+ model=model,
+ coder_device=coder_device,
+ lm_device=lm_device,
+ lm=lm,
+ lm_tau=lm_tau,
+ )
+
+ framed_fo = io.BytesIO()
+ if use_lm:
+ _write_chunk(framed_fo, payload_fo.getvalue())
+ else:
+ framed_fo.write(payload_fo.getvalue())
+ segments.append((int(segment_index), int(offset_samples), framed_fo.getvalue()))
+
+ return {
+ 'numCodebooks': int(num_codebooks or 0),
+ 'segments': segments,
+ }
+
+
+atexit.register(_shutdown_parallel_executor)
+
def _write_chunk(fo: tp.IO[bytes], payload: bytes) -> None:
"""Write a CRC-protected chunk: [len: u32][crc: u32][payload]."""
fo.write(_CHUNK_HEADER.pack(len(payload), zlib.crc32(payload) & 0xffffffff))
@@ -475,6 +648,71 @@ def compress_to_file(model: EncodecModel, wav: torch.Tensor, fo: tp.IO[bytes],
))
return
+ parallel_workers = _parallel_segment_worker_count(
+ total_segments,
+ use_lm=use_lm,
+ lm_chunked=bool(lm_chunked),
+ model_device=model_device,
+ )
+
+ if parallel_workers > 1:
+ num_codebooks = int(model.quantizer.get_num_quantizers_for_bandwidth(
+ model.frame_rate,
+ model.bandwidth,
+ ))
+ metadata = {
+ 'm': model.name,
+ 'al': audio_length,
+ 'nc': num_codebooks,
+ 'lm': bool(use_lm),
+ 'fp': int(FP_SCALE),
+ 'mr': int(MIN_RANGE),
+ 'acv': 4 if use_lm else 0,
+ 'tau': float(lm_tau),
+ }
+ binary.write_ecdc_header(fo, metadata)
+
+ batches = _build_segment_batches(wav, offsets, segment_length, parallel_workers)
+ completed_segments = 0
+ ordered_results: tp.List[dict] = []
+ executor = _get_parallel_executor(parallel_workers)
+ try:
+ futures = [
+ executor.submit(
+ _encode_segment_batch_worker,
+ model.name,
+ float(model.bandwidth or 0.0),
+ bool(use_lm),
+ float(lm_tau),
+ batch,
+ )
+ for batch in batches
+ ]
+
+ for future in concurrent.futures.as_completed(futures):
+ result = future.result()
+ ordered_results.append(result)
+ completed_segments += len(result['segments'])
+ last_index, last_offset, _ = result['segments'][-1]
+ _emit_progress(progress_callback, _build_progress_payload(
+ stage='segment',
+ sample_rate=int(model.sample_rate),
+ total_segments=total_segments,
+ segment_index=min(completed_segments, total_segments),
+ audio_length=audio_length,
+ segment_length=segment_length,
+ segment_stride=segment_stride,
+ offset_samples=int(last_offset),
+ ))
+ except BrokenProcessPool:
+ _shutdown_parallel_executor()
+ raise
+
+ for result in sorted(ordered_results, key=lambda item: item['segments'][0][0]):
+ for _, _, framed_payload in result['segments']:
+ fo.write(framed_payload)
+ return
+
header_written = False
for segment_index, offset_samples in enumerate(offsets, start=1):
with torch.inference_mode():
From b00c5bd7a8126f3759343f10ebd9818ddd95c838 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Sat, 21 Mar 2026 13:14:17 +0000
Subject: [PATCH 14/24] Tighten cross-host deterministic LM defaults
---
README.md | 8 ++++----
encodec/compress.py | 15 ++++++---------
encodec/model.py | 7 ++++---
3 files changed, 14 insertions(+), 16 deletions(-)
diff --git a/README.md b/README.md
index 75e63ea..f5c1f0c 100644
--- a/README.md
+++ b/README.md
@@ -111,10 +111,10 @@ All settings are overridable via environment variables:
| Variable | Default | Notes |
|---|---|---|
| `ENCODEC_LM_TAU` | `1.0` | Softmax temperature. `1.0` is optimal for compression. |
-| `ENCODEC_LOGIT_QSTEP` | `1/128` | Logit quantisation grid size. |
-| `ENCODEC_AC_FP_SCALE` | `65536` | Integer scale for CDF allocation (`2^16`). |
-| `ENCODEC_AC_MIN_RANGE` | `1` | Minimum CDF range per symbol. |
-| `ENCODEC_DETERMINISTIC_LM_DTYPE` | `float32` | LM weight dtype. `float32` is faster; `float64` available. |
+| `ENCODEC_LOGIT_QSTEP` | `1/64` | Logit quantisation grid size. Slightly coarser is safer cross-host. |
+| `ENCODEC_AC_FP_SCALE` | `8192` | Integer scale for CDF allocation (`2^13`). |
+| `ENCODEC_AC_MIN_RANGE` | `2` | Minimum CDF range per symbol. Wider bins improve portability. |
+| `ENCODEC_DETERMINISTIC_LM_DTYPE` | `float64` | LM weight dtype. `float64` is safer for cross-host determinism; `float32` is faster. |
| `ENCODEC_USE_NEAR_UNIFORM` | `0` | Enable near-uniform prior (off by default). |
### Compression results
diff --git a/encodec/compress.py b/encodec/compress.py
index 0192725..8880d51 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -76,13 +76,14 @@ def _env_choice(name: str, default: str, choices: tp.Set[str]) -> str:
raise ValueError(f"Unsupported value {v!r} for {name}. Expected one of: {allowed}.")
return value
-# Lean defaults: float32 LM, finer logit grid, high-precision CDF allocation.
-LOGIT_QSTEP = _env_float("ENCODEC_LOGIT_QSTEP", 1.0 / 128.0)
+# Conservative defaults: float64 LM, slightly coarser logit quantisation,
+# and wider CDF bins for better cross-host determinism.
+LOGIT_QSTEP = _env_float("ENCODEC_LOGIT_QSTEP", 1.0 / 64.0)
LM_TAU = _env_float("ENCODEC_LM_TAU", 1.0)
-FP_SCALE = _env_int("ENCODEC_AC_FP_SCALE", 1 << 16)
-MIN_RANGE = _env_int("ENCODEC_AC_MIN_RANGE", 1)
+FP_SCALE = _env_int("ENCODEC_AC_FP_SCALE", 1 << 13)
+MIN_RANGE = _env_int("ENCODEC_AC_MIN_RANGE", 2)
USE_NEAR_UNIFORM = _env_bool("ENCODEC_USE_NEAR_UNIFORM", False)
-DETERMINISTIC_LM_DTYPE = _env_dtype("ENCODEC_DETERMINISTIC_LM_DTYPE", torch.float32)
+DETERMINISTIC_LM_DTYPE = _env_dtype("ENCODEC_DETERMINISTIC_LM_DTYPE", torch.float64)
LM_DEVICE_MODE = _env_choice("ENCODEC_LM_DEVICE", "cpu", {"cpu", "model"})
LM_CHUNKED_DEFAULT = _env_bool("ENCODEC_LM_CHUNKED", True)
SEGMENT_WORKERS_DEFAULT = _env_int("ENCODEC_SEGMENT_WORKERS", 1)
@@ -356,10 +357,6 @@ def _build_segment_batches(
def _init_parallel_worker_runtime() -> None:
torch.use_deterministic_algorithms(True)
torch.backends.mkldnn.enabled = False
- try:
- torch.set_num_threads(1)
- except RuntimeError:
- pass
def _shutdown_parallel_executor() -> None:
diff --git a/encodec/model.py b/encodec/model.py
index ad66b1f..a25a100 100644
--- a/encodec/model.py
+++ b/encodec/model.py
@@ -204,14 +204,15 @@ def set_target_bandwidth(self, bandwidth: float):
def get_lm_model(self,
device: tp.Optional[torch.device] = None,
- dtype: torch.dtype = torch.float32) -> LMModel:
+ dtype: torch.dtype = torch.float64) -> LMModel:
"""Load the pre-trained language model for entropy coding.
Args:
device: target device (defaults to CPU — LM must stay on CPU for
cross-platform arithmetic-coder determinism).
- dtype: LM weight dtype. float32 is faster and sufficient when
- combined with the deterministic logit-quantisation path.
+ dtype: LM weight dtype. float64 is the safer default for
+ cross-host determinism; float32 can be selected when
+ speed matters more than exact portability.
"""
device = torch.device("cpu") if device is None else device
lm = LMModel(self.quantizer.n_q, self.quantizer.bins, num_layers=5, dim=200,
From 878257867fb8746f77fd3b120109a2eef079a8c5 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Fri, 3 Apr 2026 01:52:46 +0100
Subject: [PATCH 15/24] Checkpoint native entropy coding and CUDA decode LM
---
PR_DESCRIPTION.md | 95 +++
encodec/compress.py | 393 ++++++++++--
encodec/torch_ext.py | 57 ++
native/encodec_ac/Cargo.lock | 229 +++++++
native/encodec_ac/Cargo.toml | 12 +
native/encodec_ac/src/lib.rs | 586 ++++++++++++++++++
.../encodec_torch_ext/encodec_torch_ext.cpp | 411 ++++++++++++
7 files changed, 1716 insertions(+), 67 deletions(-)
create mode 100644 PR_DESCRIPTION.md
create mode 100644 encodec/torch_ext.py
create mode 100644 native/encodec_ac/Cargo.lock
create mode 100644 native/encodec_ac/Cargo.toml
create mode 100644 native/encodec_ac/src/lib.rs
create mode 100644 native/encodec_torch_ext/encodec_torch_ext.cpp
diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md
new file mode 100644
index 0000000..384400f
--- /dev/null
+++ b/PR_DESCRIPTION.md
@@ -0,0 +1,95 @@
+# Deterministic cross-platform LM entropy coding, acv=4 CRC chunk framing, and `_counts_from_pdf` bug fix
+
+## Summary
+
+This PR hardens the LM-backed entropy coding path for cross-platform correctness and adds per-segment failure isolation. The neural network weights and audio quality are unchanged. All existing `.ecdc` files decode correctly.
+
+## Motivation
+
+Three problems with the current LM entropy path:
+
+1. **Non-deterministic across hardware.** `torch.softmax` can differ by a ULP between CPU, MPS, and CUDA. The arithmetic coder amplifies these differences — a single wrong probability pushes the decode state off track, producing `EOFError` or silent garbage. Payloads encoded on an Apple Silicon Mac reliably fail to decode on Linux CPU or CUDA.
+
+2. **Silent corrupt decode at `tau=1.0`.** In `_counts_from_pdf`, the near-integer perturbation uses an alternating sign. When a token's probability is exactly `0.0` (common at `tau=1.0` due to float underflow of `exp(-large)`), the negative perturbation gives `x = -ε`, then `floor(-ε) = -1`. A negative count makes the CDF non-monotonic; the decoder produces wrong symbols with no error raised.
+
+3. **No failure isolation.** A single corrupt byte anywhere in the payload desynchronises the arithmetic decoder and destroys the rest of the file.
+
+## Changes
+
+### `encodec/compress.py`
+
+**Deterministic CDF construction**
+
+- `_stable_softmax`: computes softmax in float64 using a sequential cumsum denominator rather than `torch.softmax`. Cross-architecture bit-reproducibility verified Mac CPU/MPS → Linux CPU/CUDA.
+- `_quantize_logits_`: rounds logits to a 1/128 grid before softmax. Tiny floating-point differences that don't change the quantised logit produce identical CDFs.
+- `_counts_from_pdf`: adds `clamp_min(0)` after the near-integer perturbation step, fixing the negative-count bug at `tau=1.0`.
+- `_deterministic_cdf` / `_deterministic_cdf_multi`: integer floor + priority allocation CDF construction at `FP_SCALE=65536` precision. Replaces float-based CDF that was sensitive to platform differences.
+
+**Bitstream version `acv=4` with CRC chunk framing**
+
+- Each model segment is wrapped in `[chunk_len: u32 BE][crc32: u32 BE][payload]`.
+- A corrupt chunk is replaced with silence for that segment; the rest of the file decodes normally.
+- `tau` is stored in the header so encoder and decoder are always in sync without out-of-band configuration.
+
+**GPU reliability**
+
+- `compress_to_file` detects the model device and moves the waveform there automatically (`wav[None].to(model_device)`). Previously crashed when the model was on MPS or CUDA.
+- LM and arithmetic coder always run on CPU for cross-platform determinism regardless of model device.
+
+**Tunable defaults** (via env vars; existing behaviour unchanged if not set):
+
+| Variable | Default |
+|---|---|
+| `ENCODEC_LM_TAU` | `1.0` |
+| `ENCODEC_LOGIT_QSTEP` | `1/128` |
+| `ENCODEC_AC_FP_SCALE` | `65536` |
+| `ENCODEC_AC_MIN_RANGE` | `1` |
+| `ENCODEC_DETERMINISTIC_LM_DTYPE` | `float32` |
+
+### `encodec/model.py`
+
+- `LMModel.forward_logits`: factored out from `forward` so the deterministic and legacy paths share the transformer forward pass.
+- `LMModel.forward_legacy`: raw softmax with no quantisation, used for decoding `acv < 3` streams.
+- `LMModel.__init__`: accepts `tau` parameter.
+- `EncodecModel.get_lm_model`: accepts `device` and `dtype` parameters for explicit LM placement.
+
+### `scripts/`
+
+- `precision_eval.py`: CLI for benchmarking bitrate, SNR, encode/decode wall time, CPU vs MPS, LM vs non-LM, and single-byte corruption behaviour (targets chunk bodies, not headers/CRC).
+- `payload_decode_matrix.py`: decodes a payload across CPU and CUDA and compares results; intended for cross-host determinism validation.
+
+## Backwards compatibility
+
+**Reading old streams: fully preserved.** The decoder reads the `acv` field from the stream header and routes accordingly:
+
+| `acv` | Path | Notes |
+|---|---|---|
+| `0` | Raw bitpacking, no LM | Unchanged |
+| `1` / `2` | Legacy LM via `forward_legacy()` | Original `torch.softmax`, no quantisation — decodes exactly as before |
+| `4` | New deterministic path | This PR |
+
+**Writing:** `compress(..., use_lm=False)` still produces `acv=0` raw streams identical to before. `compress(..., use_lm=True)` now produces `acv=4`; old decoders will reject `acv=4` streams with an unsupported-version error (the version field exists for this purpose).
+
+**API surface:** no breaking changes. `compress`, `decompress`, `compress_to_file`, `decompress_from_file` retain the same signatures. The `EncodecModel` public API is unchanged.
+
+## Test results
+
+Benchmarked on 7 stereo 48 kHz music tracks, 10 s clips, `encodec_48khz`, all 7 tracks decoded without error on every device:
+
+| Bandwidth | Device | Avg actual kbps | LM gain vs raw | Encode RTF | Decode RTF |
+|---|---|---|---|---|---|
+| 6 kbps | CPU | 4.34 | 27.7% | 0.26× | 0.27× |
+| 6 kbps | MPS | 4.34 | 27.7% | 0.33× | 0.27× |
+| 24 kbps | CPU | 19.3 | 19.9% | 0.39× | 0.41× |
+| 24 kbps | MPS | 19.3 | 19.9% | 0.47× | 0.40× |
+
+CPU and MPS produce byte-identical payloads and identical decoded audio (same kbps, same SNR). Zero decode failures across all tracks, bandwidths, and devices.
+
+Cross-device decode matrix (payloads encoded on Apple Silicon Mac):
+
+| Encode | Decode | Before | After |
+|---|---|---|---|
+| Mac CPU | Linux CPU | `EOFError` | ✓ |
+| Mac CPU | Linux CUDA | `EOFError` | ✓ |
+| Mac MPS | Linux CPU | `EOFError` | ✓ |
+| Mac MPS | Linux CUDA | `EOFError` | ✓ |
diff --git a/encodec/compress.py b/encodec/compress.py
index 8880d51..3da7db6 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -19,6 +19,16 @@
import torch
+try:
+ import encodec_native as _encodec_native
+except ImportError:
+ _encodec_native = None
+
+try:
+ from . import torch_ext as _torch_ext_loader
+except Exception:
+ _torch_ext_loader = None
+
from . import binary
from .model import EncodecModel, EncodedFrame
from .quantization.ac import (
@@ -85,8 +95,12 @@ def _env_choice(name: str, default: str, choices: tp.Set[str]) -> str:
USE_NEAR_UNIFORM = _env_bool("ENCODEC_USE_NEAR_UNIFORM", False)
DETERMINISTIC_LM_DTYPE = _env_dtype("ENCODEC_DETERMINISTIC_LM_DTYPE", torch.float64)
LM_DEVICE_MODE = _env_choice("ENCODEC_LM_DEVICE", "cpu", {"cpu", "model"})
+DECODE_LM_DEVICE_MODE = _env_choice("ENCODEC_DECODE_LM_DEVICE", "auto", {"auto", "cpu", "model"})
LM_CHUNKED_DEFAULT = _env_bool("ENCODEC_LM_CHUNKED", True)
SEGMENT_WORKERS_DEFAULT = _env_int("ENCODEC_SEGMENT_WORKERS", 1)
+NATIVE_AC_ENABLED = _env_bool("ENCODEC_NATIVE_AC", True)
+TORCH_EXT_AC_ENABLED = _env_bool("ENCODEC_TORCH_EXT", False)
+ARITHMETIC_TOTAL_RANGE_BITS = 24
_IDX_CACHE: tp.Dict[tp.Tuple[str, int, int], torch.Tensor] = {}
_UNIFORM_CDF_CACHE: tp.Dict[tp.Tuple[str, int, int, int, int], torch.Tensor] = {}
@@ -94,8 +108,15 @@ def _env_choice(name: str, default: str, choices: tp.Set[str]) -> str:
ProgressCallback = tp.Optional[tp.Callable[[tp.Dict[str, tp.Any]], None]]
_WORKER_MODEL_CACHE: tp.Dict[tp.Tuple[str, float], EncodecModel] = {}
_WORKER_LM_CACHE: tp.Dict[tp.Tuple[str, float, str], tp.Any] = {}
+# Preview/audio decode is a hot path in scratch.fm, so keep decoder models and
+# LM instances alive instead of rebuilding them for every payload.
+_DECODE_MODEL_CACHE: tp.Dict[tp.Tuple[str, str], EncodecModel] = {}
+_DECODE_LM_CACHE: tp.Dict[tp.Tuple[str, str, str, float], tp.Any] = {}
+_DECODE_LEGACY_LM_CACHE: tp.Dict[tp.Tuple[str, str, str], tp.Any] = {}
_PARALLEL_EXECUTOR: tp.Optional[concurrent.futures.ProcessPoolExecutor] = None
_PARALLEL_EXECUTOR_WORKERS = 0
+_TORCH_AC_MODULE: tp.Optional[tp.Any] = None
+_TORCH_AC_LOAD_FAILED = False
# ---------------------------------------------------------------------------
@@ -160,6 +181,92 @@ def _softmax_or_uniform(x: torch.Tensor, dim: int) -> torch.Tensor:
return torch.where(near, u, s)
+def _torch_ac_module() -> tp.Optional[tp.Any]:
+ global _TORCH_AC_MODULE, _TORCH_AC_LOAD_FAILED
+ if not TORCH_EXT_AC_ENABLED or _torch_ext_loader is None or _TORCH_AC_LOAD_FAILED:
+ return None
+ if _TORCH_AC_MODULE is not None:
+ return _TORCH_AC_MODULE
+ try:
+ _TORCH_AC_MODULE = _torch_ext_loader.load_extension()
+ except Exception:
+ _TORCH_AC_LOAD_FAILED = True
+ return None
+ return _TORCH_AC_MODULE
+
+
+def _tensor_native_ac_module() -> tp.Optional[tp.Any]:
+ module = _torch_ac_module()
+ if module is not None:
+ return module
+ if NATIVE_AC_ENABLED and _encodec_native is not None:
+ return _encodec_native
+ return None
+
+
+def _native_ac_available() -> bool:
+ return _tensor_native_ac_module() is not None
+
+
+def _can_batch_lm_encode(lm_device: torch.device, coder_device: torch.device) -> bool:
+ # Only batch the deterministic CPU path that we have validated byte-for-byte
+ # against the existing stepwise encoder.
+ return lm_device.type == "cpu" and coder_device.type == "cpu"
+
+
+def _compute_lm_probas_for_frame(
+ frame: torch.Tensor,
+ *,
+ lm: tp.Any,
+ lm_device: torch.device,
+ lm_tau: float,
+) -> torch.Tensor:
+ """Run the LM over a whole frame with teacher forcing.
+
+ The returned probabilities are shaped [1, card, K, T] and match the
+ stepwise encoder's quantized CDFs on the deterministic CPU path.
+ """
+ _B, K, T = frame.shape
+ if T <= 0:
+ raise ValueError("LM frame must contain at least one timestep.")
+
+ prefix = torch.zeros(1, K, 1, dtype=torch.long, device=lm_device)
+ if T == 1:
+ teacher = prefix
+ else:
+ teacher = torch.cat([prefix, 1 + frame[:, :, :-1].detach().to(lm_device)], dim=-1)
+
+ with torch.inference_mode():
+ logits_raw, _, _ = lm.forward_logits(teacher, None, 0)
+ logits_q = _quantize_logits_(logits_raw / lm_tau, LOGIT_QSTEP)
+ return _softmax_or_uniform(logits_q, dim=1)
+
+
+def _flatten_lm_block_for_coder(
+ probas: torch.Tensor,
+ frame: torch.Tensor,
+ *,
+ coder_device: torch.device,
+) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ """Flatten a full LM block into time-major columns for the entropy coder."""
+ pdf_cols = (
+ probas[0]
+ .permute(0, 2, 1)
+ .contiguous()
+ .reshape(probas.shape[1], -1)
+ .to(coder_device)
+ )
+ symbols = (
+ frame[0]
+ .transpose(0, 1)
+ .contiguous()
+ .reshape(-1)
+ .detach()
+ .to(coder_device)
+ )
+ return pdf_cols, symbols
+
+
def _deterministic_cdf(pdf: torch.Tensor,
total_range_bits: int,
fp_scale: int = FP_SCALE,
@@ -419,6 +526,78 @@ def _get_parallel_worker_model(
return model, lm
+def _device_key(device: tp.Union[str, torch.device]) -> str:
+ return str(torch.device(device))
+
+
+def _get_decode_model(model_name: str, device: tp.Union[str, torch.device]) -> EncodecModel:
+ key = (model_name, _device_key(device))
+ model = _DECODE_MODEL_CACHE.get(key)
+ if model is None:
+ model = MODELS[model_name]().to(device).eval()
+ _DECODE_MODEL_CACHE[key] = model
+ return model
+
+
+def _select_decode_lm_device(
+ *,
+ model_device: tp.Union[str, torch.device],
+ coder_device: tp.Union[str, torch.device],
+ acv: int,
+) -> torch.device:
+ model_device = torch.device(model_device)
+ coder_device = torch.device(coder_device)
+
+ if acv < 3:
+ return coder_device
+ if DECODE_LM_DEVICE_MODE == "cpu":
+ return coder_device
+ if DECODE_LM_DEVICE_MODE == "model":
+ return model_device
+ # Auto: keep legacy / CPU-safe behavior everywhere except CUDA, where the
+ # deterministic float64 LM path is materially faster and parity-clean.
+ if model_device.type == "cuda":
+ return model_device
+ return coder_device
+
+
+def _get_decode_lms(
+ model: EncodecModel,
+ *,
+ model_name: str,
+ coder_device: tp.Union[str, torch.device],
+ lm_device: tp.Union[str, torch.device],
+ use_lm: bool,
+ acv: int,
+ lm_tau: float,
+) -> tp.Tuple[tp.Optional[tp.Any], tp.Optional[tp.Any]]:
+ coder_key = _device_key(coder_device)
+ if not use_lm:
+ return None, None
+
+ if acv >= 3:
+ lm_key = (model_name, _device_key(lm_device), str(DETERMINISTIC_LM_DTYPE), float(lm_tau))
+ lm = _DECODE_LM_CACHE.get(lm_key)
+ if lm is None:
+ lm = model.get_lm_model(
+ device=torch.device(lm_device),
+ dtype=DETERMINISTIC_LM_DTYPE,
+ ).eval()
+ lm.tau = float(lm_tau)
+ _DECODE_LM_CACHE[lm_key] = lm
+ return lm, None
+
+ legacy_key = (model_name, coder_key, str(torch.float32))
+ legacy_lm = _DECODE_LEGACY_LM_CACHE.get(legacy_key)
+ if legacy_lm is None:
+ legacy_lm = model.get_lm_model(
+ device=torch.device(coder_key),
+ dtype=torch.float32,
+ ).eval()
+ _DECODE_LEGACY_LM_CACHE[legacy_key] = legacy_lm
+ return None, legacy_lm
+
+
def _encode_segment_batch_worker(
model_name: str,
bandwidth: float,
@@ -511,7 +690,46 @@ def _write_frame_payload(
_B, K, T = frame.shape
if use_lm:
assert lm is not None
- coder = ArithmeticCoder(fo)
+ native_coder = None
+ coder = None
+ native_module = _tensor_native_ac_module()
+ if native_module is not None:
+ native_coder = native_module.ArithmeticEncoder(ARITHMETIC_TOTAL_RANGE_BITS)
+ else:
+ coder = ArithmeticCoder(fo, total_range_bits=ARITHMETIC_TOTAL_RANGE_BITS)
+ if _can_batch_lm_encode(lm_device, coder_device):
+ probas = _compute_lm_probas_for_frame(
+ frame,
+ lm=lm,
+ lm_device=lm_device,
+ lm_tau=lm_tau,
+ )
+ pdf_cols, symbol_tensor = _flatten_lm_block_for_coder(
+ probas,
+ frame,
+ coder_device=coder_device,
+ )
+ if native_coder is not None:
+ native_coder.push_pdf_symbols_torch(
+ pdf_cols.detach().contiguous(),
+ symbol_tensor.detach().contiguous(),
+ FP_SCALE,
+ MIN_RANGE,
+ )
+ else:
+ assert coder is not None
+ cdf_mat = _deterministic_cdf_multi(
+ pdf_cols, coder.total_range_bits,
+ fp_scale=FP_SCALE, min_range=MIN_RANGE, check=False)
+ cdf_cols = cdf_mat.t().contiguous()
+ for col, value in enumerate(symbol_tensor.tolist()):
+ coder.push(value, cdf_cols[col])
+ if native_coder is not None:
+ fo.write(bytes(native_coder.finish()))
+ else:
+ assert coder is not None
+ coder.flush()
+ return
states = None
offset = 0
input_ = torch.zeros(1, K, 1, dtype=torch.long, device=lm_device)
@@ -526,21 +744,34 @@ def _write_frame_payload(
probas = _softmax_or_uniform(logits_q, dim=1)
pdf_mat = probas[0, :, :, 0].to(coder_device)
- cdf_mat = _deterministic_cdf_multi(
- pdf_mat, coder.total_range_bits,
- fp_scale=FP_SCALE, min_range=MIN_RANGE, check=False)
- cdf_cols = cdf_mat.t().contiguous()
-
frame_slice = frame[:, :, t:t + 1].detach().to(coder_device)
- for k, value in enumerate(frame_slice[0, :, 0].tolist()):
- coder.push(value, cdf_cols[k])
+ symbol_tensor = frame_slice[0, :, 0].detach().contiguous()
+ if native_coder is not None:
+ native_coder.push_pdf_symbols_torch(
+ pdf_mat.detach().contiguous(),
+ symbol_tensor,
+ FP_SCALE,
+ MIN_RANGE,
+ )
+ else:
+ assert coder is not None
+ cdf_mat = _deterministic_cdf_multi(
+ pdf_mat, coder.total_range_bits,
+ fp_scale=FP_SCALE, min_range=MIN_RANGE, check=False)
+ cdf_cols = cdf_mat.t().contiguous()
+ for k, value in enumerate(symbol_tensor.tolist()):
+ coder.push(value, cdf_cols[k])
input_ = (1 + frame_slice).to(lm_device)
else:
for value in frame[0, :, t].detach().cpu().tolist():
packer.push(value)
if use_lm:
- coder.flush()
+ if native_coder is not None:
+ fo.write(bytes(native_coder.finish()))
+ else:
+ assert coder is not None
+ coder.flush()
else:
packer.flush()
@@ -785,8 +1016,9 @@ def decompress_from_file(fo: tp.IO[bytes],
* acv=4 — deterministic LM streams (this implementation).
Corrupt segments fall back to silence rather than aborting.
- The model (EnCodec encoder/decoder) runs on ``device``; the LM and
- arithmetic coder always run on CPU.
+ The model (EnCodec encoder/decoder) runs on ``device``. The arithmetic
+ coder always runs on CPU; the deterministic LM path can run on the model
+ device when configured.
"""
metadata = binary.read_ecdc_header(fo)
model_name = metadata['m']
@@ -805,20 +1037,24 @@ def decompress_from_file(fo: tp.IO[bytes],
if acv > 4:
raise ValueError(f"Unsupported bitstream version {acv}; re-encode.")
- model = MODELS[model_name]().to(device).eval()
+ model = _get_decode_model(model_name, device)
model_device = next(model.parameters()).device
coder_device = torch.device("cpu")
+ lm_device = _select_decode_lm_device(
+ model_device=model_device,
+ coder_device=coder_device,
+ acv=acv,
+ )
- lm = None
- legacy_lm = None
- if use_lm and acv >= 3:
- lm = model.get_lm_model(device=coder_device,
- dtype=DETERMINISTIC_LM_DTYPE).eval()
- lm.tau = lm_tau
- elif use_lm:
- # Legacy streams: original Facebook LM path (float32, no quantisation).
- legacy_lm = model.get_lm_model(device=coder_device,
- dtype=torch.float32).eval()
+ lm, legacy_lm = _get_decode_lms(
+ model,
+ model_name=model_name,
+ coder_device=coder_device,
+ lm_device=lm_device,
+ use_lm=use_lm,
+ acv=acv,
+ lm_tau=lm_tau,
+ )
segment_length = model.segment_length or audio_length
segment_stride = model.segment_stride or audio_length
@@ -847,67 +1083,90 @@ def decompress_from_file(fo: tp.IO[bytes],
scale = None
if use_lm:
- decoder = ArithmeticDecoder(frame_fo)
+ native_decoder = None
+ code_buf = None
+ decoder = None
+ native_module = _tensor_native_ac_module() if acv == 4 else None
+ if native_module is not None:
+ native_decoder = native_module.ArithmeticDecoder(
+ frame_fo.read(),
+ ARITHMETIC_TOTAL_RANGE_BITS,
+ )
+ code_buf = torch.empty(num_codebooks, dtype=torch.long, device=coder_device)
+ else:
+ decoder = ArithmeticDecoder(frame_fo, total_range_bits=ARITHMETIC_TOTAL_RANGE_BITS)
states = None
offset = 0
input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long,
- device=coder_device)
+ device=lm_device if acv >= 3 else coder_device)
else:
unpacker = binary.BitUnpacker(model.bits_per_codebook, frame_fo)
frame = torch.zeros(1, num_codebooks, frame_length,
dtype=torch.long, device=coder_device)
try:
- for t in range(frame_length):
- if use_lm and acv >= 3:
- with torch.inference_mode():
+ with torch.inference_mode():
+ for t in range(frame_length):
+ if use_lm and acv >= 3:
logits_raw, states, offset = lm.forward_logits(
input_, states, offset)
logits_q = _quantize_logits_(logits_raw / lm_tau,
LOGIT_QSTEP)
probas = _softmax_or_uniform(logits_q, dim=1)
-
- pdf_mat = probas[0, :, :, 0].to(coder_device)
- cdf_mat = _deterministic_cdf_multi(
- pdf_mat, decoder.total_range_bits,
- fp_scale=fp_scale, min_range=min_range, check=False)
- cdf_cols = cdf_mat.t().contiguous()
- code_list: tp.List[int] = []
- for k in range(num_codebooks):
- code = decoder.pull(cdf_cols[k])
- if code is None:
- raise EOFError("Stream ended before expected.")
- code_list.append(code)
- frame[0, :, t] = torch.tensor(code_list, dtype=torch.long,
- device=coder_device)
- input_ = 1 + frame[:, :, t:t + 1]
-
- elif use_lm: # legacy path
- with torch.inference_mode():
+ pdf_mat = probas[0, :, :, 0].to(coder_device)
+ if native_decoder is not None:
+ assert code_buf is not None
+ native_decoder.pull_symbols_into_torch(
+ pdf_mat.detach().contiguous(),
+ code_buf,
+ fp_scale,
+ min_range,
+ )
+ frame[0, :, t] = code_buf
+ input_ = 1 + code_buf.view(1, num_codebooks, 1).to(
+ lm_device,
+ )
+ else:
+ assert decoder is not None
+ cdf_mat = _deterministic_cdf_multi(
+ pdf_mat, decoder.total_range_bits,
+ fp_scale=fp_scale, min_range=min_range, check=False)
+ cdf_cols = cdf_mat.t().contiguous()
+ code_list = []
+ for k in range(num_codebooks):
+ code = decoder.pull(cdf_cols[k])
+ if code is None:
+ raise EOFError("Stream ended before expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long,
+ device=coder_device)
+ input_ = (1 + frame[:, :, t:t + 1]).to(lm_device)
+
+ elif use_lm: # legacy path
probas, states, offset = legacy_lm.forward_legacy(
input_, states, offset)
- code_list = []
- for k in range(num_codebooks):
- q_cdf = build_stable_quantized_cdf(
- probas[0, :, k, 0], decoder.total_range_bits,
- check=False)
- code = decoder.pull(q_cdf)
- if code is None:
- raise EOFError("Stream ended before expected.")
- code_list.append(code)
- frame[0, :, t] = torch.tensor(code_list, dtype=torch.long,
- device=coder_device)
- input_ = 1 + frame[:, :, t:t + 1]
-
- else:
- code_list = []
- for _ in range(num_codebooks):
- code = unpacker.pull()
- if code is None:
- raise EOFError("Stream ended before expected.")
- code_list.append(code)
- frame[0, :, t] = torch.tensor(code_list, dtype=torch.long,
- device=coder_device)
+ code_list = []
+ for k in range(num_codebooks):
+ q_cdf = build_stable_quantized_cdf(
+ probas[0, :, k, 0], decoder.total_range_bits,
+ check=False)
+ code = decoder.pull(q_cdf)
+ if code is None:
+ raise EOFError("Stream ended before expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long,
+ device=coder_device)
+ input_ = 1 + frame[:, :, t:t + 1]
+
+ else:
+ code_list = []
+ for _ in range(num_codebooks):
+ code = unpacker.pull()
+ if code is None:
+ raise EOFError("Stream ended before expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long,
+ device=coder_device)
except Exception:
if acv == 4:
diff --git a/encodec/torch_ext.py b/encodec/torch_ext.py
new file mode 100644
index 0000000..831e588
--- /dev/null
+++ b/encodec/torch_ext.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+import os
+import sys
+import threading
+from pathlib import Path
+from typing import Optional
+
+from torch.utils.cpp_extension import load
+
+_LOCK = threading.Lock()
+_MODULE = None
+_LOAD_ERROR: Optional[Exception] = None
+
+
+def _env_bool(name: str, default: bool) -> bool:
+ value = os.getenv(name)
+ if value is None:
+ return default
+ return value.lower() in {"1", "true", "yes", "on"}
+
+
+def enabled() -> bool:
+ return _env_bool("ENCODEC_TORCH_EXT", False)
+
+
+def load_extension():
+ global _MODULE, _LOAD_ERROR
+ if _MODULE is not None:
+ return _MODULE
+ if _LOAD_ERROR is not None:
+ raise _LOAD_ERROR
+
+ with _LOCK:
+ if _MODULE is not None:
+ return _MODULE
+ if _LOAD_ERROR is not None:
+ raise _LOAD_ERROR
+
+ repo_root = Path(__file__).resolve().parents[1]
+ source = repo_root / "native" / "encodec_torch_ext" / "encodec_torch_ext.cpp"
+ build_dir = repo_root / "native" / "encodec_torch_ext" / "build"
+ build_dir.mkdir(parents=True, exist_ok=True)
+ os.environ["PATH"] = f"{Path(sys.executable).parent}:{os.environ.get('PATH', '')}"
+
+ try:
+ _MODULE = load(
+ name="encodec_torch_ext",
+ sources=[str(source)],
+ build_directory=str(build_dir),
+ extra_cflags=["-O3", "-std=c++17"],
+ verbose=_env_bool("ENCODEC_TORCH_EXT_VERBOSE", False),
+ )
+ return _MODULE
+ except Exception as exc: # pragma: no cover - build failures are environment-specific.
+ _LOAD_ERROR = exc
+ raise
diff --git a/native/encodec_ac/Cargo.lock b/native/encodec_ac/Cargo.lock
new file mode 100644
index 0000000..3bf8871
--- /dev/null
+++ b/native/encodec_ac/Cargo.lock
@@ -0,0 +1,229 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 4
+
+[[package]]
+name = "autocfg"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
+
+[[package]]
+name = "encodec_native"
+version = "0.1.0"
+dependencies = [
+ "numpy",
+ "pyo3",
+]
+
+[[package]]
+name = "heck"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
+
+[[package]]
+name = "libc"
+version = "0.2.184"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48f5d2a454e16a5ea0f4ced81bd44e4cfc7bd3a507b61887c99fd3538b28e4af"
+
+[[package]]
+name = "matrixmultiply"
+version = "0.3.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08"
+dependencies = [
+ "autocfg",
+ "rawpointer",
+]
+
+[[package]]
+name = "ndarray"
+version = "0.17.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d"
+dependencies = [
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "portable-atomic",
+ "portable-atomic-util",
+ "rawpointer",
+]
+
+[[package]]
+name = "num-complex"
+version = "0.4.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
+[[package]]
+name = "num-traits"
+version = "0.2.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "numpy"
+version = "0.28.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "778da78c64ddc928ebf5ad9df5edf0789410ff3bdbf3619aed51cd789a6af1e2"
+dependencies = [
+ "libc",
+ "ndarray",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "pyo3",
+ "pyo3-build-config",
+ "rustc-hash",
+]
+
+[[package]]
+name = "once_cell"
+version = "1.21.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50"
+
+[[package]]
+name = "portable-atomic"
+version = "1.13.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49"
+
+[[package]]
+name = "portable-atomic-util"
+version = "0.2.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3"
+dependencies = [
+ "portable-atomic",
+]
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.106"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "pyo3"
+version = "0.28.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "91fd8e38a3b50ed1167fb981cd6fd60147e091784c427b8f7183a7ee32c31c12"
+dependencies = [
+ "libc",
+ "once_cell",
+ "portable-atomic",
+ "pyo3-build-config",
+ "pyo3-ffi",
+ "pyo3-macros",
+]
+
+[[package]]
+name = "pyo3-build-config"
+version = "0.28.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e368e7ddfdeb98c9bca7f8383be1648fd84ab466bf2bc015e94008db6d35611e"
+dependencies = [
+ "target-lexicon",
+]
+
+[[package]]
+name = "pyo3-ffi"
+version = "0.28.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f29e10af80b1f7ccaf7f69eace800a03ecd13e883acfacc1e5d0988605f651e"
+dependencies = [
+ "libc",
+ "pyo3-build-config",
+]
+
+[[package]]
+name = "pyo3-macros"
+version = "0.28.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "df6e520eff47c45997d2fc7dd8214b25dd1310918bbb2642156ef66a67f29813"
+dependencies = [
+ "proc-macro2",
+ "pyo3-macros-backend",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "pyo3-macros-backend"
+version = "0.28.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c4cdc218d835738f81c2338f822078af45b4afdf8b2e33cbb5916f108b813acb"
+dependencies = [
+ "heck",
+ "proc-macro2",
+ "pyo3-build-config",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.45"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
+[[package]]
+name = "rustc-hash"
+version = "2.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe"
+
+[[package]]
+name = "syn"
+version = "2.0.117"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
+[[package]]
+name = "target-lexicon"
+version = "0.13.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca"
+
+[[package]]
+name = "unicode-ident"
+version = "1.0.24"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75"
diff --git a/native/encodec_ac/Cargo.toml b/native/encodec_ac/Cargo.toml
new file mode 100644
index 0000000..3594334
--- /dev/null
+++ b/native/encodec_ac/Cargo.toml
@@ -0,0 +1,12 @@
+[package]
+name = "encodec_native"
+version = "0.1.0"
+edition = "2021"
+
+[lib]
+name = "encodec_native"
+crate-type = ["cdylib"]
+
+[dependencies]
+pyo3 = { version = "0.28.3", features = ["extension-module", "abi3-py310"] }
+numpy = "0.28.0"
diff --git a/native/encodec_ac/src/lib.rs b/native/encodec_ac/src/lib.rs
new file mode 100644
index 0000000..1dba545
--- /dev/null
+++ b/native/encodec_ac/src/lib.rs
@@ -0,0 +1,586 @@
+use numpy::{PyArray2, PyReadonlyArray2, PyUntypedArrayMethods};
+use pyo3::exceptions::{PyEOFError, PyValueError};
+use pyo3::prelude::*;
+use pyo3::types::PyAny;
+use pyo3::types::PyBytes;
+
+const EPS_EDGE: f64 = 9.094947017729282e-13;
+const EPS_PERTURB: f64 = 8.673617379884035e-19;
+
+fn require_torch_tensor_layout(
+ tensor: &Bound<'_, PyAny>,
+ expected_dtype: &str,
+ expected_dim: usize,
+) -> PyResult> {
+ let is_contiguous = tensor.call_method0("is_contiguous")?.extract::()?;
+ if !is_contiguous {
+ return Err(PyValueError::new_err("tensor must be contiguous"));
+ }
+
+ let device = tensor.getattr("device")?.getattr("type")?.extract::()?;
+ if device != "cpu" {
+ return Err(PyValueError::new_err("tensor must be on CPU"));
+ }
+
+ let dtype = tensor.getattr("dtype")?.str()?.to_str()?.to_owned();
+ if dtype != expected_dtype {
+ return Err(PyValueError::new_err(format!(
+ "tensor must have dtype {expected_dtype}, got {dtype}"
+ )));
+ }
+
+ let shape = tensor.getattr("shape")?.extract::>()?;
+ if shape.len() != expected_dim {
+ return Err(PyValueError::new_err(format!(
+ "tensor must be {expected_dim}D, got {}D",
+ shape.len()
+ )));
+ }
+ Ok(shape)
+}
+
+fn torch_f64_tensor_2d<'py>(tensor: &Bound<'py, PyAny>) -> PyResult<(usize, usize, &'py [f64])> {
+ let shape = require_torch_tensor_layout(tensor, "torch.float64", 2)?;
+ let n_bins = shape[0];
+ let n_cols = shape[1];
+ let ptr = tensor.call_method0("data_ptr")?.extract::()?;
+ let len = n_bins
+ .checked_mul(n_cols)
+ .ok_or_else(|| PyValueError::new_err("tensor shape is too large"))?;
+ let slice = unsafe { std::slice::from_raw_parts(ptr as *const f64, len) };
+ Ok((n_bins, n_cols, slice))
+}
+
+fn torch_i64_tensor_1d<'py>(tensor: &Bound<'py, PyAny>) -> PyResult<(usize, &'py [i64])> {
+ let shape = require_torch_tensor_layout(tensor, "torch.int64", 1)?;
+ let len = shape[0];
+ let ptr = tensor.call_method0("data_ptr")?.extract::()?;
+ let slice = unsafe { std::slice::from_raw_parts(ptr as *const i64, len) };
+ Ok((len, slice))
+}
+
+fn torch_i64_tensor_1d_mut<'py>(
+ tensor: &Bound<'py, PyAny>,
+) -> PyResult<(usize, &'py mut [i64])> {
+ let shape = require_torch_tensor_layout(tensor, "torch.int64", 1)?;
+ let len = shape[0];
+ let ptr = tensor.call_method0("data_ptr")?.extract::()?;
+ let slice = unsafe { std::slice::from_raw_parts_mut(ptr as *mut i64, len) };
+ Ok((len, slice))
+}
+
+fn counts_from_pdf_flat(pdf: &[f64], fp_scale: i64) -> Vec {
+ let mut out = Vec::with_capacity(pdf.len());
+ let scale = fp_scale as f64;
+ for (idx, value) in pdf.iter().enumerate() {
+ let mut x = value.max(0.0) * scale;
+ let frac = x - x.floor();
+ if frac <= EPS_EDGE || frac >= 1.0 - EPS_EDGE {
+ let sign = if idx % 2 == 0 { -1.0 } else { 1.0 };
+ x = (x + sign * EPS_PERTURB).max(0.0);
+ }
+ out.push(x.floor() as i64);
+ }
+ out
+}
+
+fn deterministic_cdf_multi_impl(
+ pdf: &[f64],
+ n_bins: usize,
+ n_cols: usize,
+ total_range_bits: u32,
+ fp_scale: i64,
+ min_range: i64,
+) -> PyResult> {
+ if n_bins == 0 || n_cols == 0 {
+ return Err(PyValueError::new_err("pdf_mat must be non-empty"));
+ }
+ if pdf.len() != n_bins * n_cols {
+ return Err(PyValueError::new_err("pdf_mat shape does not match buffer length"));
+ }
+
+ let total = 1_i64
+ .checked_shl(total_range_bits)
+ .ok_or_else(|| PyValueError::new_err("total_range_bits too large"))?;
+ let alloc = total - min_range * (n_bins as i64);
+ if alloc <= 0 {
+ return Err(PyValueError::new_err("invalid total_range_bits/min_range combination"));
+ }
+
+ let mut normalized = vec![0.0_f64; pdf.len()];
+ for col in 0..n_cols {
+ let mut sum = 0.0_f64;
+ for row in 0..n_bins {
+ let v = pdf[row * n_cols + col].max(0.0);
+ normalized[row * n_cols + col] = v;
+ sum += v;
+ }
+ if !sum.is_finite() || sum <= 0.0 {
+ for row in 0..n_bins {
+ normalized[row * n_cols + col] = 1.0;
+ }
+ }
+ }
+
+ let mut counts = counts_from_pdf_flat(&normalized, fp_scale);
+ for col in 0..n_cols {
+ let mut sum = 0_i64;
+ for row in 0..n_bins {
+ sum += counts[row * n_cols + col];
+ }
+ if sum <= 0 {
+ for row in 0..n_bins {
+ counts[row * n_cols + col] = 1;
+ }
+ }
+ }
+
+ let mut cdf = vec![0_i64; pdf.len()];
+ for col in 0..n_cols {
+ let mut num_sum = 0_i64;
+ for row in 0..n_bins {
+ num_sum += counts[row * n_cols + col];
+ }
+ if num_sum <= 0 {
+ return Err(PyValueError::new_err("invalid zero-count column"));
+ }
+
+ let mut base = vec![0_i64; n_bins];
+ let mut base_sum = 0_i64;
+ for row in 0..n_bins {
+ let num = counts[row * n_cols + col];
+ let value = (alloc * num) / num_sum;
+ base[row] = value;
+ base_sum += value;
+ }
+ let remainder = alloc - base_sum;
+ if remainder > 0 {
+ let mut order: Vec<(i64, usize)> = (0..n_bins)
+ .map(|row| {
+ let num = counts[row * n_cols + col];
+ let prio = (alloc * num) - (num_sum * base[row]);
+ let key = prio * ((n_bins as i64) + 1) - (row as i64);
+ (key, row)
+ })
+ .collect();
+ order.sort_by(|a, b| b.cmp(a));
+ for (_, row) in order.into_iter().take(remainder as usize) {
+ base[row] += 1;
+ }
+ }
+
+ let mut running = 0_i64;
+ for row in 0..n_bins {
+ running += base[row] + min_range;
+ cdf[row * n_cols + col] = running;
+ }
+ if running != total {
+ return Err(PyValueError::new_err("cdf sum mismatch"));
+ }
+ }
+ Ok(cdf)
+}
+
+struct BitWriter {
+ current_value: u64,
+ current_bits: u8,
+ bytes: Vec,
+}
+
+impl BitWriter {
+ fn new() -> Self {
+ Self {
+ current_value: 0,
+ current_bits: 0,
+ bytes: Vec::new(),
+ }
+ }
+
+ fn push_bit(&mut self, bit: u8) {
+ self.current_value += (bit as u64) << self.current_bits;
+ self.current_bits += 1;
+ while self.current_bits >= 8 {
+ let lower = (self.current_value & 0xff) as u8;
+ self.current_bits -= 8;
+ self.current_value >>= 8;
+ self.bytes.push(lower);
+ }
+ }
+
+ fn finish(mut self) -> Vec {
+ if self.current_bits > 0 {
+ self.bytes.push(self.current_value as u8);
+ self.current_value = 0;
+ self.current_bits = 0;
+ }
+ self.bytes
+ }
+}
+
+struct BitReader {
+ data: Vec,
+ offset: usize,
+ current_value: u64,
+ current_bits: u8,
+}
+
+impl BitReader {
+ fn new(data: Vec) -> Self {
+ Self {
+ data,
+ offset: 0,
+ current_value: 0,
+ current_bits: 0,
+ }
+ }
+
+ fn pull_bit(&mut self) -> Option {
+ while self.current_bits < 1 {
+ let byte = *self.data.get(self.offset)?;
+ self.offset += 1;
+ self.current_value += (byte as u64) << self.current_bits;
+ self.current_bits += 8;
+ }
+ let out = (self.current_value & 1) as u8;
+ self.current_value >>= 1;
+ self.current_bits -= 1;
+ Some(out)
+ }
+}
+
+#[pyclass]
+struct ArithmeticEncoder {
+ total_range_bits: u32,
+ low: u64,
+ high: u64,
+ max_bit: i32,
+ writer: BitWriter,
+}
+
+#[pymethods]
+impl ArithmeticEncoder {
+ #[new]
+ #[pyo3(signature = (total_range_bits = 24))]
+ fn new(total_range_bits: u32) -> PyResult {
+ if total_range_bits > 30 {
+ return Err(PyValueError::new_err("total_range_bits must be <= 30"));
+ }
+ Ok(Self {
+ total_range_bits,
+ low: 0,
+ high: 0,
+ max_bit: -1,
+ writer: BitWriter::new(),
+ })
+ }
+
+ fn push_pdf_symbols(
+ &mut self,
+ pdf_mat: PyReadonlyArray2,
+ symbols: Vec,
+ fp_scale: i64,
+ min_range: i64,
+ ) -> PyResult<()> {
+ let shape = pdf_mat.shape();
+ let n_bins = shape[0];
+ let n_cols = shape[1];
+ if symbols.len() != n_cols {
+ return Err(PyValueError::new_err("symbols length must match the pdf column count"));
+ }
+ let pdf = pdf_mat
+ .as_slice()
+ .map_err(|_| PyValueError::new_err("pdf_mat must be C-contiguous"))?;
+ let cdf = deterministic_cdf_multi_impl(
+ pdf,
+ n_bins,
+ n_cols,
+ self.total_range_bits,
+ fp_scale,
+ min_range,
+ )?;
+ for (col, symbol) in symbols.into_iter().enumerate() {
+ self.push_symbol(symbol, &cdf, n_bins, n_cols, col)?;
+ }
+ Ok(())
+ }
+
+ fn push_pdf_symbols_torch(
+ &mut self,
+ pdf_mat: &Bound<'_, PyAny>,
+ symbols: &Bound<'_, PyAny>,
+ fp_scale: i64,
+ min_range: i64,
+ ) -> PyResult<()> {
+ let (n_bins, n_cols, pdf) = torch_f64_tensor_2d(pdf_mat)?;
+ let (symbol_len, symbol_slice) = torch_i64_tensor_1d(symbols)?;
+ if symbol_len != n_cols {
+ return Err(PyValueError::new_err("symbols length must match the pdf column count"));
+ }
+ let cdf = deterministic_cdf_multi_impl(
+ pdf,
+ n_bins,
+ n_cols,
+ self.total_range_bits,
+ fp_scale,
+ min_range,
+ )?;
+ for (col, symbol) in symbol_slice.iter().enumerate() {
+ if *symbol < 0 {
+ return Err(PyValueError::new_err("symbols must be non-negative"));
+ }
+ self.push_symbol(*symbol as usize, &cdf, n_bins, n_cols, col)?;
+ }
+ Ok(())
+ }
+
+ fn finish<'py>(&mut self, py: Python<'py>) -> Bound<'py, PyBytes> {
+ while self.max_bit >= 0 {
+ let bit = ((self.low >> (self.max_bit as u32)) & 1) as u8;
+ self.writer.push_bit(bit);
+ self.max_bit -= 1;
+ }
+ let bytes = std::mem::replace(&mut self.writer, BitWriter::new()).finish();
+ PyBytes::new(py, &bytes)
+ }
+}
+
+impl ArithmeticEncoder {
+ fn delta(&self) -> u64 {
+ self.high - self.low + 1
+ }
+
+ fn flush_common_prefix(&mut self) {
+ while self.max_bit >= 0 {
+ let b1 = self.low >> (self.max_bit as u32);
+ let b2 = self.high >> (self.max_bit as u32);
+ if b1 == b2 {
+ self.low -= b1 << (self.max_bit as u32);
+ self.high -= b1 << (self.max_bit as u32);
+ self.max_bit -= 1;
+ self.writer.push_bit(b1 as u8);
+ } else {
+ break;
+ }
+ }
+ }
+
+ fn push_symbol(
+ &mut self,
+ symbol: usize,
+ cdf: &[i64],
+ n_bins: usize,
+ n_cols: usize,
+ col: usize,
+ ) -> PyResult<()> {
+ while self.delta() < (1_u64 << self.total_range_bits) {
+ self.low <<= 1;
+ self.high = (self.high << 1) | 1;
+ self.max_bit += 1;
+ }
+ if symbol >= n_bins {
+ return Err(PyValueError::new_err("symbol out of range"));
+ }
+ let total = 1_u64 << self.total_range_bits;
+ let rng = self.delta();
+ let cum_high = cdf[symbol * n_cols + col] as u64;
+ let cum_low = if symbol == 0 {
+ 0
+ } else {
+ cdf[(symbol - 1) * n_cols + col] as u64
+ };
+ let base = self.low;
+ self.low = base + (rng * cum_low) / total;
+ self.high = base + (rng * cum_high) / total - 1;
+ self.flush_common_prefix();
+ Ok(())
+ }
+}
+
+#[pyclass]
+struct ArithmeticDecoder {
+ total_range_bits: u32,
+ low: u64,
+ high: u64,
+ current: u64,
+ max_bit: i32,
+ reader: BitReader,
+}
+
+#[pymethods]
+impl ArithmeticDecoder {
+ #[new]
+ #[pyo3(signature = (data, total_range_bits = 24))]
+ fn new(data: &Bound<'_, PyBytes>, total_range_bits: u32) -> PyResult {
+ if total_range_bits > 30 {
+ return Err(PyValueError::new_err("total_range_bits must be <= 30"));
+ }
+ Ok(Self {
+ total_range_bits,
+ low: 0,
+ high: 0,
+ current: 0,
+ max_bit: -1,
+ reader: BitReader::new(data.as_bytes().to_vec()),
+ })
+ }
+
+ fn pull_symbols(
+ &mut self,
+ pdf_mat: PyReadonlyArray2,
+ fp_scale: i64,
+ min_range: i64,
+ ) -> PyResult> {
+ let shape = pdf_mat.shape();
+ let n_bins = shape[0];
+ let n_cols = shape[1];
+ let pdf = pdf_mat
+ .as_slice()
+ .map_err(|_| PyValueError::new_err("pdf_mat must be C-contiguous"))?;
+ let cdf = deterministic_cdf_multi_impl(
+ pdf,
+ n_bins,
+ n_cols,
+ self.total_range_bits,
+ fp_scale,
+ min_range,
+ )?;
+ let mut out = Vec::with_capacity(n_cols);
+ for col in 0..n_cols {
+ let symbol = self.pull_symbol(&cdf, n_bins, n_cols, col)?;
+ out.push(symbol);
+ }
+ Ok(out)
+ }
+
+ fn pull_symbols_into_torch(
+ &mut self,
+ pdf_mat: &Bound<'_, PyAny>,
+ out_symbols: &Bound<'_, PyAny>,
+ fp_scale: i64,
+ min_range: i64,
+ ) -> PyResult<()> {
+ let (n_bins, n_cols, pdf) = torch_f64_tensor_2d(pdf_mat)?;
+ let (out_len, out_slice) = torch_i64_tensor_1d_mut(out_symbols)?;
+ if out_len != n_cols {
+ return Err(PyValueError::new_err(
+ "output tensor length must match the pdf column count",
+ ));
+ }
+ let cdf = deterministic_cdf_multi_impl(
+ pdf,
+ n_bins,
+ n_cols,
+ self.total_range_bits,
+ fp_scale,
+ min_range,
+ )?;
+ for col in 0..n_cols {
+ let symbol = self.pull_symbol(&cdf, n_bins, n_cols, col)?;
+ out_slice[col] = symbol as i64;
+ }
+ Ok(())
+ }
+}
+
+impl ArithmeticDecoder {
+ fn delta(&self) -> u64 {
+ self.high - self.low + 1
+ }
+
+ fn flush_common_prefix(&mut self) {
+ while self.max_bit >= 0 {
+ let b1 = self.low >> (self.max_bit as u32);
+ let b2 = self.high >> (self.max_bit as u32);
+ if b1 == b2 {
+ self.low -= b1 << (self.max_bit as u32);
+ self.high -= b1 << (self.max_bit as u32);
+ self.current -= b1 << (self.max_bit as u32);
+ self.max_bit -= 1;
+ } else {
+ break;
+ }
+ }
+ }
+
+ fn pull_symbol(
+ &mut self,
+ cdf: &[i64],
+ n_bins: usize,
+ n_cols: usize,
+ col: usize,
+ ) -> PyResult {
+ while self.delta() < (1_u64 << self.total_range_bits) {
+ let bit = self
+ .reader
+ .pull_bit()
+ .ok_or_else(|| PyEOFError::new_err("stream exhausted"))? as u64;
+ self.low <<= 1;
+ self.high = (self.high << 1) | 1;
+ self.current = (self.current << 1) | bit;
+ self.max_bit += 1;
+ }
+
+ let total = 1_u64 << self.total_range_bits;
+ let rng = self.delta();
+ let target = (((self.current - self.low + 1) * total) - 1) / rng;
+ let mut lo = 0usize;
+ let mut hi = n_bins;
+ while lo < hi {
+ let mid = (lo + hi) / 2;
+ let value = cdf[mid * n_cols + col] as u64;
+ if target < value {
+ hi = mid;
+ } else {
+ lo = mid + 1;
+ }
+ }
+ if lo >= n_bins {
+ return Err(PyValueError::new_err("binary search failed"));
+ }
+ let symbol = lo;
+ let cum_high = cdf[symbol * n_cols + col] as u64;
+ let cum_low = if symbol == 0 {
+ 0
+ } else {
+ cdf[(symbol - 1) * n_cols + col] as u64
+ };
+ let base = self.low;
+ self.low = base + (rng * cum_low) / total;
+ self.high = base + (rng * cum_high) / total - 1;
+ self.flush_common_prefix();
+ Ok(symbol)
+ }
+}
+
+#[pyfunction]
+fn deterministic_cdf_multi<'py>(
+ py: Python<'py>,
+ pdf_mat: PyReadonlyArray2,
+ total_range_bits: u32,
+ fp_scale: i64,
+ min_range: i64,
+) -> PyResult>> {
+ let shape = pdf_mat.shape();
+ let n_bins = shape[0];
+ let n_cols = shape[1];
+ let pdf = pdf_mat
+ .as_slice()
+ .map_err(|_| PyValueError::new_err("pdf_mat must be C-contiguous"))?;
+ let cdf = deterministic_cdf_multi_impl(pdf, n_bins, n_cols, total_range_bits, fp_scale, min_range)?;
+ let rows: Vec> = (0..n_bins)
+ .map(|row| {
+ (0..n_cols)
+ .map(|col| cdf[row * n_cols + col])
+ .collect::>()
+ })
+ .collect();
+ Ok(PyArray2::from_vec2(py, &rows)?)
+}
+
+#[pymodule]
+fn encodec_native(m: &Bound<'_, PyModule>) -> PyResult<()> {
+ m.add_class::()?;
+ m.add_class::()?;
+ m.add_function(wrap_pyfunction!(deterministic_cdf_multi, m)?)?;
+ Ok(())
+}
diff --git a/native/encodec_torch_ext/encodec_torch_ext.cpp b/native/encodec_torch_ext/encodec_torch_ext.cpp
new file mode 100644
index 0000000..566fdea
--- /dev/null
+++ b/native/encodec_torch_ext/encodec_torch_ext.cpp
@@ -0,0 +1,411 @@
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace py = pybind11;
+
+namespace {
+
+constexpr double kEpsEdge = 9.094947017729282e-13;
+constexpr double kEpsPerturb = 8.673617379884035e-19;
+
+void check_pdf_mat(const torch::Tensor& pdf_mat) {
+ TORCH_CHECK(pdf_mat.device().is_cpu(), "pdf_mat must be on CPU");
+ TORCH_CHECK(pdf_mat.scalar_type() == torch::kFloat64, "pdf_mat must have dtype torch.float64");
+ TORCH_CHECK(pdf_mat.dim() == 2, "pdf_mat must be 2D");
+ TORCH_CHECK(pdf_mat.is_contiguous(), "pdf_mat must be contiguous");
+}
+
+void check_symbol_tensor(const torch::Tensor& symbols, int64_t expected_len, const char* name) {
+ TORCH_CHECK(symbols.device().is_cpu(), name, " must be on CPU");
+ TORCH_CHECK(symbols.scalar_type() == torch::kLong, name, " must have dtype torch.int64");
+ TORCH_CHECK(symbols.dim() == 1, name, " must be 1D");
+ TORCH_CHECK(symbols.is_contiguous(), name, " must be contiguous");
+ TORCH_CHECK(symbols.numel() == expected_len, name, " length must match the pdf column count");
+}
+
+std::vector counts_from_pdf_flat(const double* pdf, int64_t len, int64_t fp_scale) {
+ std::vector out;
+ out.reserve(static_cast(len));
+ const double scale = static_cast(fp_scale);
+ for (int64_t idx = 0; idx < len; ++idx) {
+ double x = std::max(pdf[idx], 0.0) * scale;
+ const double frac = x - std::floor(x);
+ if (frac <= kEpsEdge || frac >= 1.0 - kEpsEdge) {
+ const double sign = (idx % 2 == 0) ? -1.0 : 1.0;
+ x = std::max(x + sign * kEpsPerturb, 0.0);
+ }
+ out.push_back(static_cast(std::floor(x)));
+ }
+ return out;
+}
+
+std::vector deterministic_cdf_multi_impl(
+ const double* pdf,
+ int64_t n_bins,
+ int64_t n_cols,
+ int64_t total_range_bits,
+ int64_t fp_scale,
+ int64_t min_range
+) {
+ TORCH_CHECK(n_bins > 0 && n_cols > 0, "pdf_mat must be non-empty");
+ TORCH_CHECK(total_range_bits >= 0 && total_range_bits <= 30, "total_range_bits must be between 0 and 30");
+
+ const int64_t total = int64_t{1} << total_range_bits;
+ const int64_t alloc = total - min_range * n_bins;
+ TORCH_CHECK(alloc > 0, "invalid total_range_bits/min_range combination");
+
+ const int64_t len = n_bins * n_cols;
+ std::vector normalized(static_cast(len), 0.0);
+ for (int64_t col = 0; col < n_cols; ++col) {
+ double sum = 0.0;
+ for (int64_t row = 0; row < n_bins; ++row) {
+ const double value = std::max(pdf[row * n_cols + col], 0.0);
+ normalized[static_cast(row * n_cols + col)] = value;
+ sum += value;
+ }
+ if (!std::isfinite(sum) || sum <= 0.0) {
+ for (int64_t row = 0; row < n_bins; ++row) {
+ normalized[static_cast(row * n_cols + col)] = 1.0;
+ }
+ }
+ }
+
+ std::vector counts = counts_from_pdf_flat(normalized.data(), len, fp_scale);
+ for (int64_t col = 0; col < n_cols; ++col) {
+ int64_t sum = 0;
+ for (int64_t row = 0; row < n_bins; ++row) {
+ sum += counts[static_cast(row * n_cols + col)];
+ }
+ if (sum <= 0) {
+ for (int64_t row = 0; row < n_bins; ++row) {
+ counts[static_cast(row * n_cols + col)] = 1;
+ }
+ }
+ }
+
+ std::vector cdf(static_cast(len), 0);
+ for (int64_t col = 0; col < n_cols; ++col) {
+ int64_t num_sum = 0;
+ for (int64_t row = 0; row < n_bins; ++row) {
+ num_sum += counts[static_cast(row * n_cols + col)];
+ }
+ TORCH_CHECK(num_sum > 0, "invalid zero-count column");
+
+ std::vector base(static_cast(n_bins), 0);
+ int64_t base_sum = 0;
+ for (int64_t row = 0; row < n_bins; ++row) {
+ const int64_t num = counts[static_cast(row * n_cols + col)];
+ const int64_t value = (alloc * num) / num_sum;
+ base[static_cast(row)] = value;
+ base_sum += value;
+ }
+
+ const int64_t remainder = alloc - base_sum;
+ if (remainder > 0) {
+ std::vector> order;
+ order.reserve(static_cast(n_bins));
+ for (int64_t row = 0; row < n_bins; ++row) {
+ const int64_t num = counts[static_cast(row * n_cols + col)];
+ const int64_t prio = (alloc * num) - (num_sum * base[static_cast(row)]);
+ const int64_t key = prio * (n_bins + 1) - row;
+ order.emplace_back(key, row);
+ }
+ std::sort(order.begin(), order.end(), std::greater<>());
+ for (int64_t idx = 0; idx < remainder; ++idx) {
+ base[static_cast(order[static_cast(idx)].second)] += 1;
+ }
+ }
+
+ int64_t running = 0;
+ for (int64_t row = 0; row < n_bins; ++row) {
+ running += base[static_cast(row)] + min_range;
+ cdf[static_cast(row * n_cols + col)] = running;
+ }
+ TORCH_CHECK(running == total, "cdf sum mismatch");
+ }
+
+ return cdf;
+}
+
+class BitWriter {
+public:
+ void push_bit(uint8_t bit) {
+ current_value_ += static_cast(bit) << current_bits_;
+ ++current_bits_;
+ while (current_bits_ >= 8) {
+ const auto lower = static_cast(current_value_ & 0xff);
+ current_bits_ -= 8;
+ current_value_ >>= 8;
+ bytes_.push_back(lower);
+ }
+ }
+
+ std::string finish() {
+ if (current_bits_ > 0) {
+ bytes_.push_back(static_cast(current_value_));
+ current_value_ = 0;
+ current_bits_ = 0;
+ }
+ return std::string(bytes_.begin(), bytes_.end());
+ }
+
+private:
+ uint64_t current_value_ = 0;
+ uint8_t current_bits_ = 0;
+ std::vector bytes_;
+};
+
+class BitReader {
+public:
+ explicit BitReader(std::vector data)
+ : data_(std::move(data)) {}
+
+ bool pull_bit(uint8_t& bit) {
+ while (current_bits_ < 1) {
+ if (offset_ >= data_.size()) {
+ return false;
+ }
+ const auto byte = data_[offset_++];
+ current_value_ += static_cast(byte) << current_bits_;
+ current_bits_ += 8;
+ }
+ bit = static_cast(current_value_ & 1);
+ current_value_ >>= 1;
+ --current_bits_;
+ return true;
+ }
+
+private:
+ std::vector data_;
+ size_t offset_ = 0;
+ uint64_t current_value_ = 0;
+ uint8_t current_bits_ = 0;
+};
+
+std::vector bytes_to_vec(const py::bytes& data) {
+ const std::string raw = data;
+ return std::vector(raw.begin(), raw.end());
+}
+
+class ArithmeticEncoder {
+public:
+ explicit ArithmeticEncoder(int64_t total_range_bits = 24)
+ : total_range_bits_(total_range_bits) {
+ TORCH_CHECK(total_range_bits_ <= 30, "total_range_bits must be <= 30");
+ }
+
+ void push_pdf_symbols_torch(
+ const torch::Tensor& pdf_mat,
+ const torch::Tensor& symbols,
+ int64_t fp_scale,
+ int64_t min_range
+ ) {
+ check_pdf_mat(pdf_mat);
+ const auto n_bins = pdf_mat.size(0);
+ const auto n_cols = pdf_mat.size(1);
+ check_symbol_tensor(symbols, n_cols, "symbols");
+
+ const auto* pdf = pdf_mat.data_ptr();
+ const auto* symbol_ptr = symbols.data_ptr();
+ const auto cdf = deterministic_cdf_multi_impl(
+ pdf,
+ n_bins,
+ n_cols,
+ total_range_bits_,
+ fp_scale,
+ min_range
+ );
+ for (int64_t col = 0; col < n_cols; ++col) {
+ TORCH_CHECK(symbol_ptr[col] >= 0, "symbols must be non-negative");
+ push_symbol(static_cast(symbol_ptr[col]), cdf, n_bins, n_cols, col);
+ }
+ }
+
+ py::bytes finish() {
+ while (max_bit_ >= 0) {
+ const auto bit = static_cast((low_ >> max_bit_) & 1);
+ writer_.push_bit(bit);
+ --max_bit_;
+ }
+ return py::bytes(writer_.finish());
+ }
+
+private:
+ uint64_t delta() const {
+ return high_ - low_ + 1;
+ }
+
+ void flush_common_prefix() {
+ while (max_bit_ >= 0) {
+ const auto b1 = low_ >> max_bit_;
+ const auto b2 = high_ >> max_bit_;
+ if (b1 == b2) {
+ low_ -= b1 << max_bit_;
+ high_ -= b1 << max_bit_;
+ --max_bit_;
+ writer_.push_bit(static_cast(b1));
+ } else {
+ break;
+ }
+ }
+ }
+
+ void push_symbol(
+ size_t symbol,
+ const std::vector& cdf,
+ int64_t n_bins,
+ int64_t n_cols,
+ int64_t col
+ ) {
+ while (delta() < (uint64_t{1} << total_range_bits_)) {
+ low_ <<= 1;
+ high_ = (high_ << 1) | 1;
+ ++max_bit_;
+ }
+ TORCH_CHECK(static_cast(symbol) < n_bins, "symbol out of range");
+ const auto total = uint64_t{1} << total_range_bits_;
+ const auto rng = delta();
+ const auto cum_high = static_cast(cdf[symbol * static_cast(n_cols) + static_cast(col)]);
+ const auto cum_low = symbol == 0
+ ? 0
+ : static_cast(cdf[(symbol - 1) * static_cast(n_cols) + static_cast(col)]);
+ const auto base = low_;
+ low_ = base + (rng * cum_low) / total;
+ high_ = base + (rng * cum_high) / total - 1;
+ flush_common_prefix();
+ }
+
+ int64_t total_range_bits_;
+ uint64_t low_ = 0;
+ uint64_t high_ = 0;
+ int64_t max_bit_ = -1;
+ BitWriter writer_;
+};
+
+class ArithmeticDecoder {
+public:
+ explicit ArithmeticDecoder(py::bytes data, int64_t total_range_bits = 24)
+ : total_range_bits_(total_range_bits),
+ reader_(bytes_to_vec(data)) {
+ TORCH_CHECK(total_range_bits_ <= 30, "total_range_bits must be <= 30");
+ }
+
+ void pull_symbols_into_torch(
+ const torch::Tensor& pdf_mat,
+ torch::Tensor out_symbols,
+ int64_t fp_scale,
+ int64_t min_range
+ ) {
+ check_pdf_mat(pdf_mat);
+ const auto n_bins = pdf_mat.size(0);
+ const auto n_cols = pdf_mat.size(1);
+ check_symbol_tensor(out_symbols, n_cols, "out_symbols");
+
+ const auto* pdf = pdf_mat.data_ptr();
+ auto* out_ptr = out_symbols.data_ptr();
+ const auto cdf = deterministic_cdf_multi_impl(
+ pdf,
+ n_bins,
+ n_cols,
+ total_range_bits_,
+ fp_scale,
+ min_range
+ );
+ for (int64_t col = 0; col < n_cols; ++col) {
+ out_ptr[col] = static_cast(pull_symbol(cdf, n_bins, n_cols, col));
+ }
+ }
+
+private:
+ uint64_t delta() const {
+ return high_ - low_ + 1;
+ }
+
+ void flush_common_prefix() {
+ while (max_bit_ >= 0) {
+ const auto b1 = low_ >> max_bit_;
+ const auto b2 = high_ >> max_bit_;
+ if (b1 == b2) {
+ low_ -= b1 << max_bit_;
+ high_ -= b1 << max_bit_;
+ current_ -= b1 << max_bit_;
+ --max_bit_;
+ } else {
+ break;
+ }
+ }
+ }
+
+ size_t pull_symbol(
+ const std::vector& cdf,
+ int64_t n_bins,
+ int64_t n_cols,
+ int64_t col
+ ) {
+ while (delta() < (uint64_t{1} << total_range_bits_)) {
+ uint8_t bit = 0;
+ TORCH_CHECK(reader_.pull_bit(bit), "stream exhausted");
+ low_ <<= 1;
+ high_ = (high_ << 1) | 1;
+ current_ = (current_ << 1) | static_cast(bit);
+ ++max_bit_;
+ }
+
+ const auto total = uint64_t{1} << total_range_bits_;
+ const auto rng = delta();
+ const auto target = (((current_ - low_ + 1) * total) - 1) / rng;
+
+ int64_t lo = 0;
+ int64_t hi = n_bins;
+ while (lo < hi) {
+ const auto mid = (lo + hi) / 2;
+ const auto value = static_cast(cdf[mid * n_cols + col]);
+ if (target < value) {
+ hi = mid;
+ } else {
+ lo = mid + 1;
+ }
+ }
+ TORCH_CHECK(lo < n_bins, "binary search failed");
+
+ const auto symbol = static_cast(lo);
+ const auto cum_high = static_cast(cdf[symbol * static_cast(n_cols) + static_cast(col)]);
+ const auto cum_low = symbol == 0
+ ? 0
+ : static_cast(cdf[(symbol - 1) * static_cast(n_cols) + static_cast(col)]);
+ const auto base = low_;
+ low_ = base + (rng * cum_low) / total;
+ high_ = base + (rng * cum_high) / total - 1;
+ flush_common_prefix();
+ return symbol;
+ }
+
+ int64_t total_range_bits_;
+ uint64_t low_ = 0;
+ uint64_t high_ = 0;
+ uint64_t current_ = 0;
+ int64_t max_bit_ = -1;
+ BitReader reader_;
+};
+
+} // namespace
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ py::class_(m, "ArithmeticEncoder")
+ .def(py::init(), py::arg("total_range_bits") = 24)
+ .def("push_pdf_symbols_torch", &ArithmeticEncoder::push_pdf_symbols_torch)
+ .def("finish", &ArithmeticEncoder::finish);
+
+ py::class_(m, "ArithmeticDecoder")
+ .def(py::init(), py::arg("data"), py::arg("total_range_bits") = 24)
+ .def("pull_symbols_into_torch", &ArithmeticDecoder::pull_symbols_into_torch);
+}
From d3d0776fecbb5c4101f6d4e9bf0ea5c33ae4c0a0 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Fri, 3 Apr 2026 02:08:31 +0100
Subject: [PATCH 16/24] Speed up CUDA decode LM inference
---
encodec/compress.py | 6 +++++-
encodec/model.py | 51 +++++++++++++++++++++++++++++++++++++++++++++
2 files changed, 56 insertions(+), 1 deletion(-)
diff --git a/encodec/compress.py b/encodec/compress.py
index 3da7db6..ed803ec 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -1096,7 +1096,11 @@ def decompress_from_file(fo: tp.IO[bytes],
else:
decoder = ArithmeticDecoder(frame_fo, total_range_bits=ARITHMETIC_TOTAL_RANGE_BITS)
states = None
- offset = 0
+ offset: tp.Union[int, torch.Tensor]
+ if acv >= 3 and lm_device.type != "cpu":
+ offset = torch.zeros((), dtype=torch.long, device=lm_device)
+ else:
+ offset = 0
input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long,
device=lm_device if acv >= 3 else coder_device)
else:
diff --git a/encodec/model.py b/encodec/model.py
index a25a100..aa99ffc 100644
--- a/encodec/model.py
+++ b/encodec/model.py
@@ -48,10 +48,61 @@ def __init__(self, n_q: int = 32, card: int = 1024, dim: int = 200, dtype=torch.
self.linears = nn.ModuleList([nn.Linear(dim, card, dtype=dtype) for _ in range(n_q)])
self.logit_step = 1.0 / 64.0
self.tau = tau
+ self._stacked_cache_key: tp.Optional[tp.Tuple[tp.Tuple[int, ...], tp.Tuple[int, ...]]] = None
+ self._stacked_emb_weight: tp.Optional[torch.Tensor] = None
+ self._stacked_linear_weight: tp.Optional[torch.Tensor] = None
+ self._stacked_linear_bias: tp.Optional[torch.Tensor] = None
+ self._stacked_k_index: tp.Optional[torch.Tensor] = None
+
+ def _get_stacked_inference_params(self) -> tp.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ cache_key = (
+ tuple(emb.weight.data_ptr() for emb in self.emb),
+ tuple(linear.weight.data_ptr() for linear in self.linears),
+ )
+ if cache_key != self._stacked_cache_key:
+ self._stacked_emb_weight = torch.stack(
+ [emb.weight.detach() for emb in self.emb],
+ dim=0,
+ ).contiguous()
+ self._stacked_linear_weight = torch.stack(
+ [linear.weight.detach() for linear in self.linears],
+ dim=0,
+ ).contiguous()
+ self._stacked_linear_bias = torch.stack(
+ [linear.bias.detach() for linear in self.linears],
+ dim=0,
+ ).contiguous()
+ self._stacked_k_index = torch.arange(
+ self.n_q,
+ device=self._stacked_emb_weight.device,
+ ).view(self.n_q, 1, 1)
+ self._stacked_cache_key = cache_key
+ assert self._stacked_emb_weight is not None
+ assert self._stacked_linear_weight is not None
+ assert self._stacked_linear_bias is not None
+ assert self._stacked_k_index is not None
+ return (
+ self._stacked_emb_weight,
+ self._stacked_linear_weight,
+ self._stacked_linear_bias,
+ self._stacked_k_index,
+ )
def forward_logits(self, indices: torch.Tensor,
states: tp.Optional[tp.List[torch.Tensor]] = None, offset: int = 0):
B, K, T = indices.shape
+ if not self.training and not torch.is_grad_enabled():
+ emb_weight, linear_weight, linear_bias, k_index = self._get_stacked_inference_params()
+ emb_weight = emb_weight[:K]
+ linear_weight = linear_weight[:K]
+ linear_bias = linear_bias[:K]
+ picked = emb_weight[k_index[:K], indices.permute(1, 0, 2)]
+ input_ = picked.sum(dim=0)
+ out, states, offset = self.transformer(input_, states, offset)
+ logits = torch.einsum('btd,kod->bkto', out, linear_weight)
+ logits = logits + linear_bias.view(1, K, 1, self.card)
+ return logits.permute(0, 3, 1, 2), states, offset
+
input_ = sum([self.emb[k](indices[:, k]) for k in range(K)])
out, states, offset = self.transformer(input_, states, offset)
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1).permute(0, 3, 1, 2)
From b17da3e8c6354222753eb2b229574bf457043f05 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Fri, 3 Apr 2026 03:37:13 +0100
Subject: [PATCH 17/24] Document Ada benchmarks and decode tradeoffs
---
README.md | 27 ++++++++++++++++++++++++---
1 file changed, 24 insertions(+), 3 deletions(-)
diff --git a/README.md b/README.md
index f5c1f0c..a3b41d6 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,8 @@ This is the code for the EnCodec neural codec presented in [High Fidelity Neural
- A causal model operating at **24 kHz** on monophonic audio trained on a variety of audio data.
- A non-causal model operating at **48 kHz** on stereophonic audio trained on music-only data.
+Bottom line for the `wavey-ai` fork: on an RTX 4000 Ada, the deterministic LM path cut 48 kHz GPU encode from `99s` to `13s` on a full song, made `cuda -> cpu` decode work reliably, and slightly improved GPU decode. The trade-off is that CPU-only decode is slower than upstream.
+
The 24 kHz model supports 1.5, 3, 6, 12, and 24 kbps. The 48 kHz model supports 3, 6, 12, and 24 kbps. A pre-trained language model is available for each, enabling entropy coding that reduces bitstream size by up to 40% without further quality loss.
@@ -44,7 +46,7 @@ audio_codes = model(inputs["input_values"], inputs["padding_mask"]).audio_codes
## Precision and Robustness Improvements (wavey-ai fork)
-This fork extends the original EnCodec with a fully deterministic, cross-platform entropy coding path. The changes affect `encodec/compress.py` and `encodec/model.py` only — the neural network weights and audio quality are unchanged.
+This fork extends the original EnCodec with a fully deterministic, cross-platform entropy coding path plus optional native entropy-coder acceleration. The neural network weights remain unchanged.
### Bitstream version `acv=4`
@@ -60,7 +62,7 @@ A single corrupt byte damages at most one chunk. The decoder substitutes silence
The original LM entropy path was not deterministic across hardware (MPS, CUDA, CPU), causing cross-device decode failures. The deterministic path fixes this by:
-- Running the arithmetic coder and LM **always on CPU**, regardless of model device.
+- Running the arithmetic coder on CPU and keeping the encode-side LM on CPU by default. On CUDA decode, `ENCODEC_DECODE_LM_DEVICE=auto` can run the deterministic decode LM on the model device while preserving payload compatibility.
- Computing softmax in **float64** via a sequential cumsum denominator (`_stable_softmax`) rather than platform-native `torch.softmax`, which can differ by a ULP across devices.
- **Quantising logits** to a 1/128 grid before softmax. Small floating-point differences that do not change the quantised logit produce identical CDFs.
- Building the CDF from **integer floor counts** (`FP_SCALE = 65536`) with deterministic priority allocation for the residual.
@@ -75,6 +77,25 @@ Cross-device decode matrix (payloads encoded on Apple Silicon Mac):
| Mac MPS | Linux CPU | EOFError | ✓ |
| Mac MPS | Linux CUDA | EOFError | ✓ |
+### RTX 4000 Ada results
+
+Benchmarked on April 3, 2026 on a Linode `g2-gpu-rtx4000a1-s` instance (1x RTX 4000 Ada, 4 vCPU, Ubuntu 24.04) using `02 - Lori Asha - Westside` from the Lori Asha album premix, resampled to 48 kHz stereo, with `encodec_48khz`, `6 kbps`, and `use_lm=True`.
+
+| Repo / case | Encode | Encode x realtime | Decode | Decode x realtime | Result |
+|---|---:|---:|---:|---:|---|
+| Upstream `cuda -> cuda` | `99.07 s` | `2.10x` | `116.56 s` | `1.79x` | baseline |
+| Upstream `cuda -> cpu` | `98.73 s` | `2.11x` | fail | — | `RuntimeError('Binary search failed')` |
+| Upstream `cpu -> cpu` | `103.81 s` | `2.01x` | `108.91 s` | `1.91x` | baseline |
+| Fork `cuda -> cuda` | `13.09 s` | `15.93x` | `109.49 s` | `1.90x` | encode `7.57x` faster than upstream GPU, decode `1.06x` faster |
+| Fork `cuda -> cpu` | `12.94 s` | `16.11x` | `167.56 s` | `1.24x` | cross-architecture decode succeeds |
+| Fork `cpu -> cpu` | `35.22 s` | `5.92x` | `160.96 s` | `1.30x` | encode `2.95x` faster than upstream CPU, CPU decode slower |
+
+What this means in practice:
+
+- The biggest RTX win is encode throughput. On this full-length track, the fork cut GPU encode time from `99.07 s` to `13.09 s`.
+- GPU decode is modestly faster than upstream on the same Ada card, but the main portability win is that `cuda -> cpu` decode works at all.
+- CPU-only decode remains a trade-off: the deterministic cross-architecture path is slower than upstream's CPU decode, but it preserves compatibility across CPU, CUDA, and Apple Silicon payload handoffs.
+
### Critical bug fix: `_counts_from_pdf`
At `tau=1.0`, many softmax outputs are exactly `0.0` (float underflow of `exp(-large)`). These triggered a near-integer perturbation with an alternating sign. A negative sign on `x=0.0` gives `x = -ε`, and `floor(-ε) = -1`. A negative count makes the CDF non-monotonic, causing the arithmetic decoder to produce wrong symbols silently.
@@ -128,7 +149,7 @@ Benchmarked on 7 stereo 48 kHz music tracks (10 s clips), `encodec_48khz`:
| 24 kbps | CPU | 19.3 | 19.9% | 0.39× | 0.41× |
| 24 kbps | MPS | 19.3 | 19.9% | 0.47× | 0.40× |
-RTF < 1.0 means faster than real time. The LM runs on CPU in all cases; MPS accelerates model encode/decode but does not reduce LM inference time.
+RTF < 1.0 means faster than real time. On Apple Silicon the LM still runs on CPU by default, so MPS primarily accelerates model encode/decode. On CUDA decode, `ENCODEC_DECODE_LM_DEVICE=auto` can move deterministic LM decode to the GPU, which is what the Ada benchmark above measures.
### Chunk size tradeoffs
From 1301c363345ddf82eb26c464edc956549ce4478b Mon Sep 17 00:00:00 2001
From: jbrough
Date: Fri, 3 Apr 2026 03:46:14 +0100
Subject: [PATCH 18/24] Restructure README for fork-first docs
---
README.md | 126 ++++++++++++++++++++++++++++++++++++++----------------
1 file changed, 88 insertions(+), 38 deletions(-)
diff --git a/README.md b/README.md
index a3b41d6..8931aa9 100644
--- a/README.md
+++ b/README.md
@@ -3,48 +3,16 @@


-This is the code for the EnCodec neural codec presented in [High Fidelity Neural Audio Compression](https://arxiv.org/pdf/2210.13438.pdf) [[abs]](https://arxiv.org/abs/2210.13438). We provide two multi-bandwidth models:
-
-- A causal model operating at **24 kHz** on monophonic audio trained on a variety of audio data.
-- A non-causal model operating at **48 kHz** on stereophonic audio trained on music-only data.
-
-Bottom line for the `wavey-ai` fork: on an RTX 4000 Ada, the deterministic LM path cut 48 kHz GPU encode from `99s` to `13s` on a full song, made `cuda -> cpu` decode work reliably, and slightly improved GPU decode. The trade-off is that CPU-only decode is slower than upstream.
-
-The 24 kHz model supports 1.5, 3, 6, 12, and 24 kbps. The 48 kHz model supports 3, 6, 12, and 24 kbps. A pre-trained language model is available for each, enabling entropy coding that reduces bitstream size by up to 40% without further quality loss.
-
-
-
-
-## Samples
-
-Samples including baselines are on [our sample page](https://ai.honu.io/papers/encodec/samples.html). A quick demo of 48 kHz music with entropy coding is available by clicking the thumbnail (original tracks by [Lucille Crew](https://open.spotify.com/artist/5eLv7rNfrf3IjMnK311ByP?si=X_zD9ackRRGjFP5Y6Q7Zng) and [Voyageur I](https://open.spotify.com/artist/21HymveeIhDcM4KDKeNLz0?si=4zXF8VpeQpeKR9QUIuck9Q)).
-
-
-
-
+## Index
-## 🤗 Transformers
+- [wavey-ai fork README](#wavey-ai-fork-readme)
+- [Upstream README](#upstream-readme)
-EnCodec is available in Transformers. See the [Transformers EnCodec docs](https://huggingface.co/docs/transformers/main/en/model_doc/encodec), and the [24 kHz](https://huggingface.co/facebook/encodec_24khz) and [48 kHz](https://huggingface.co/facebook/encodec_48khz) checkpoints on the Hub.
+## wavey-ai fork README
-```python
-from datasets import load_dataset, Audio
-from transformers import EncodecModel, AutoProcessor
-
-librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-model = EncodecModel.from_pretrained("facebook/encodec_24khz")
-processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
-librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
-audio_sample = librispeech_dummy[0]["audio"]["array"]
-inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
-encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
-audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
-audio_codes = model(inputs["input_values"], inputs["padding_mask"]).audio_codes
-```
-
----
+Bottom line for the `wavey-ai` fork: on an RTX 4000 Ada, the deterministic LM path cut 48 kHz GPU encode from `99s` to `13s` on a full song, made `cuda -> cpu` decode work reliably, and slightly improved GPU decode. The trade-off is that CPU-only decode is slower than upstream.
-## Precision and Robustness Improvements (wavey-ai fork)
+### Precision and Robustness Improvements
This fork extends the original EnCodec with a fully deterministic, cross-platform entropy coding path plus optional native entropy-coder acceleration. The neural network weights remain unchanged.
@@ -151,6 +119,49 @@ Benchmarked on 7 stereo 48 kHz music tracks (10 s clips), `encodec_48khz`:
RTF < 1.0 means faster than real time. On Apple Silicon the LM still runs on CPU by default, so MPS primarily accelerates model encode/decode. On CUDA decode, `ENCODEC_DECODE_LM_DEVICE=auto` can move deterministic LM decode to the GPU, which is what the Ada benchmark above measures.
+### Backward compatibility and native fast path
+
+The repo remains backward-compatible by default:
+
+- If the Rust module is not installed, the codec falls back to the Python entropy path.
+- If the Torch C++ extension is not available, nothing breaks; it is off by default.
+- Legacy payloads (`acv < 3`) still decode through the legacy path.
+- Deterministic chunked payloads (`acv=4`) keep cross-device decode compatibility.
+
+Local fallback setup, no extra toolchain required:
+
+```bash
+pip install -e .
+```
+
+That is enough to run the codec locally in pure Python.
+
+Rust fast path, recommended:
+
+```bash
+pip install -e .
+pip install maturin
+cd native/encodec_ac
+maturin develop --release
+```
+
+This installs the `encodec_native` module into the active virtualenv. The runtime will pick it up automatically when available.
+
+Optional Torch C++ extension:
+
+- This remains opt-in and is off by default.
+- It requires a working C++ toolchain compatible with your local PyTorch install.
+- Enable it with `ENCODEC_TORCH_EXT=1`; the extension is JIT-built on first use.
+- In our testing, the Rust path is the main win. The Torch extension is optional, not required for the accelerated path.
+
+Useful runtime knobs:
+
+| Variable | Default | Meaning |
+|---|---|---|
+| `ENCODEC_NATIVE_AC` | `1` | Use the Rust arithmetic/CDF path when `encodec_native` is installed. |
+| `ENCODEC_TORCH_EXT` | `0` | Enable the optional Torch C++ extension. |
+| `ENCODEC_DECODE_LM_DEVICE` | `auto` | On CUDA decode, prefer GPU LM decode while preserving payload compatibility. |
+
### Chunk size tradeoffs
Per-segment chunk overhead is dominated by LM segmentation granularity, not the 8-byte header:
@@ -165,6 +176,45 @@ The default 1.0 s (matching the 48 kHz model segment) gives the best bitrate/iso
---
+## Upstream README
+
+This is the code for the EnCodec neural codec presented in [High Fidelity Neural Audio Compression](https://arxiv.org/pdf/2210.13438.pdf) [[abs]](https://arxiv.org/abs/2210.13438). We provide two multi-bandwidth models:
+
+- A causal model operating at **24 kHz** on monophonic audio trained on a variety of audio data.
+- A non-causal model operating at **48 kHz** on stereophonic audio trained on music-only data.
+
+The 24 kHz model supports 1.5, 3, 6, 12, and 24 kbps. The 48 kHz model supports 3, 6, 12, and 24 kbps. A pre-trained language model is available for each, enabling entropy coding that reduces bitstream size by up to 40% without further quality loss.
+
+
+
+
+## Samples
+
+Samples including baselines are on [our sample page](https://ai.honu.io/papers/encodec/samples.html). A quick demo of 48 kHz music with entropy coding is available by clicking the thumbnail (original tracks by [Lucille Crew](https://open.spotify.com/artist/5eLv7rNfrf3IjMnK311ByP?si=X_zD9ackRRGjFP5Y6Q7Zng) and [Voyageur I](https://open.spotify.com/artist/21HymveeIhDcM4KDKeNLz0?si=4zXF8VpeQpeKR9QUIuck9Q)).
+
+
+
+
+
+## 🤗 Transformers
+
+EnCodec is available in Transformers. See the [Transformers EnCodec docs](https://huggingface.co/docs/transformers/main/en/model_doc/encodec), and the [24 kHz](https://huggingface.co/facebook/encodec_24khz) and [48 kHz](https://huggingface.co/facebook/encodec_48khz) checkpoints on the Hub.
+
+```python
+from datasets import load_dataset, Audio
+from transformers import EncodecModel, AutoProcessor
+
+librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+model = EncodecModel.from_pretrained("facebook/encodec_24khz")
+processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
+librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
+audio_sample = librispeech_dummy[0]["audio"]["array"]
+inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
+encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
+audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
+audio_codes = model(inputs["input_values"], inputs["padding_mask"]).audio_codes
+```
+
## Installation
Requires Python 3.8+ and a recent PyTorch (1.11+ recommended; 2.x tested).
From ebbb6d1bad55e78f29b2d449c5253fc7c722b1ef Mon Sep 17 00:00:00 2001
From: jbrough
Date: Fri, 3 Apr 2026 03:46:50 +0100
Subject: [PATCH 19/24] Quantify CPU decode tradeoff in README
---
README.md | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/README.md b/README.md
index 8931aa9..cc70d01 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
## wavey-ai fork README
-Bottom line for the `wavey-ai` fork: on an RTX 4000 Ada, the deterministic LM path cut 48 kHz GPU encode from `99s` to `13s` on a full song, made `cuda -> cpu` decode work reliably, and slightly improved GPU decode. The trade-off is that CPU-only decode is slower than upstream.
+Bottom line for the `wavey-ai` fork: on an RTX 4000 Ada, the deterministic LM path cut 48 kHz GPU encode from `99s` to `13s` on a full song, made `cuda -> cpu` decode work reliably, and slightly improved GPU decode. The trade-off is that CPU-only decode was about `48%` slower than upstream on the tested full-song run (`160.96s` vs `108.91s`).
### Precision and Robustness Improvements
@@ -62,7 +62,7 @@ What this means in practice:
- The biggest RTX win is encode throughput. On this full-length track, the fork cut GPU encode time from `99.07 s` to `13.09 s`.
- GPU decode is modestly faster than upstream on the same Ada card, but the main portability win is that `cuda -> cpu` decode works at all.
-- CPU-only decode remains a trade-off: the deterministic cross-architecture path is slower than upstream's CPU decode, but it preserves compatibility across CPU, CUDA, and Apple Silicon payload handoffs.
+- CPU-only decode remains a trade-off: on the tested full-song run it was about `48%` slower than upstream (`160.96s` vs `108.91s`), but it preserves compatibility across CPU, CUDA, and Apple Silicon payload handoffs.
### Critical bug fix: `_counts_from_pdf`
From c7b089ca0cd41971f8fbbaf95ee31ef2b22f2b4d Mon Sep 17 00:00:00 2001
From: jbrough
Date: Fri, 3 Apr 2026 19:23:00 +0100
Subject: [PATCH 20/24] Default CPU decode workers to auto headroom
---
README.md | 6 +-
encodec/compress.py | 297 ++++++++++++++++++++++++++++++--
scripts/bench_decode_payload.py | 86 +++++++++
3 files changed, 371 insertions(+), 18 deletions(-)
create mode 100644 scripts/bench_decode_payload.py
diff --git a/README.md b/README.md
index cc70d01..85f45a4 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,7 @@
## wavey-ai fork README
-Bottom line for the `wavey-ai` fork: on an RTX 4000 Ada, the deterministic LM path cut 48 kHz GPU encode from `99s` to `13s` on a full song, made `cuda -> cpu` decode work reliably, and slightly improved GPU decode. The trade-off is that CPU-only decode was about `48%` slower than upstream on the tested full-song run (`160.96s` vs `108.91s`).
+Bottom line for the `wavey-ai` fork: on an RTX 4000 Ada, the deterministic LM path cut 48 kHz GPU encode from `99s` to `13s` on a full song, made `cuda -> cpu` decode work reliably, slightly improved GPU decode, and now defaults CPU chunk decode workers to `available CPUs - 1`. On the tested 4-vCPU Linode box, that auto default picked `3` workers and cut full-song CPU decode from `170.18s` to `92.52s` (`1.84x` faster). Forcing `4` workers reached `82.33s` (`2.07x` faster), but the worker topology is deterministic-with-itself rather than hash-identical to the previous threaded single-process CPU path.
### Precision and Robustness Improvements
@@ -63,6 +63,7 @@ What this means in practice:
- The biggest RTX win is encode throughput. On this full-length track, the fork cut GPU encode time from `99.07 s` to `13.09 s`.
- GPU decode is modestly faster than upstream on the same Ada card, but the main portability win is that `cuda -> cpu` decode works at all.
- CPU-only decode remains a trade-off: on the tested full-song run it was about `48%` slower than upstream (`160.96s` vs `108.91s`), but it preserves compatibility across CPU, CUDA, and Apple Silicon payload handoffs.
+- CPU chunk decode now defaults to `available CPUs - 1` segment workers. On the same 4-vCPU Linode host, that default picked `3` workers and reduced the full-song CPU decode wall clock from `170.18s` to `92.52s`; forcing `4` workers reached `82.33s`. Set `ENCODEC_DECODE_SEGMENT_WORKERS=1` to restore the old single-process CPU decode topology.
### Critical bug fix: `_counts_from_pdf`
@@ -105,6 +106,7 @@ All settings are overridable via environment variables:
| `ENCODEC_AC_MIN_RANGE` | `2` | Minimum CDF range per symbol. Wider bins improve portability. |
| `ENCODEC_DETERMINISTIC_LM_DTYPE` | `float64` | LM weight dtype. `float64` is safer for cross-host determinism; `float32` is faster. |
| `ENCODEC_USE_NEAR_UNIFORM` | `0` | Enable near-uniform prior (off by default). |
+| `ENCODEC_DECODE_SEGMENT_WORKERS` | `0` | Auto CPU `acv=4` decode workers: `available CPUs - 1`, clamped to at least `1`. Set `1` for the old single-process CPU path. |
### Compression results
@@ -127,6 +129,7 @@ The repo remains backward-compatible by default:
- If the Torch C++ extension is not available, nothing breaks; it is off by default.
- Legacy payloads (`acv < 3`) still decode through the legacy path.
- Deterministic chunked payloads (`acv=4`) keep cross-device decode compatibility.
+- CPU `acv=4` decode now defaults to `available CPUs - 1` segment workers for throughput. Use `ENCODEC_DECODE_SEGMENT_WORKERS=1` if you need the older single-process CPU decode topology.
Local fallback setup, no extra toolchain required:
@@ -161,6 +164,7 @@ Useful runtime knobs:
| `ENCODEC_NATIVE_AC` | `1` | Use the Rust arithmetic/CDF path when `encodec_native` is installed. |
| `ENCODEC_TORCH_EXT` | `0` | Enable the optional Torch C++ extension. |
| `ENCODEC_DECODE_LM_DEVICE` | `auto` | On CUDA decode, prefer GPU LM decode while preserving payload compatibility. |
+| `ENCODEC_DECODE_SEGMENT_WORKERS` | `0` | CPU `acv=4` segment decode workers. `0` means `available CPUs - 1`; `1` restores the old single-process behavior. |
### Chunk size tradeoffs
diff --git a/encodec/compress.py b/encodec/compress.py
index ed803ec..7dfb7d0 100644
--- a/encodec/compress.py
+++ b/encodec/compress.py
@@ -98,6 +98,7 @@ def _env_choice(name: str, default: str, choices: tp.Set[str]) -> str:
DECODE_LM_DEVICE_MODE = _env_choice("ENCODEC_DECODE_LM_DEVICE", "auto", {"auto", "cpu", "model"})
LM_CHUNKED_DEFAULT = _env_bool("ENCODEC_LM_CHUNKED", True)
SEGMENT_WORKERS_DEFAULT = _env_int("ENCODEC_SEGMENT_WORKERS", 1)
+DECODE_SEGMENT_WORKERS_DEFAULT = _env_int("ENCODEC_DECODE_SEGMENT_WORKERS", 0)
NATIVE_AC_ENABLED = _env_bool("ENCODEC_NATIVE_AC", True)
TORCH_EXT_AC_ENABLED = _env_bool("ENCODEC_TORCH_EXT", False)
ARITHMETIC_TOTAL_RANGE_BITS = 24
@@ -443,6 +444,26 @@ def _parallel_segment_worker_count(
return max(1, min(int(configured), int(total_segments)))
+def _parallel_decode_segment_worker_count(
+ total_segments: int,
+ *,
+ model_device: torch.device,
+ acv: int,
+) -> int:
+ configured = DECODE_SEGMENT_WORKERS_DEFAULT
+ available_cpus = os.cpu_count() or 1
+ if configured <= 0:
+ configured = max(1, int(available_cpus) - 1)
+ if (
+ configured <= 1
+ or total_segments <= 1
+ or acv != 4
+ or model_device.type != 'cpu'
+ ):
+ return 1
+ return max(1, min(int(configured), int(total_segments), int(available_cpus)))
+
+
def _build_segment_batches(
wav: torch.Tensor,
offsets: tp.List[int],
@@ -461,9 +482,29 @@ def _build_segment_batches(
return batches
+def _build_decode_segment_batches(
+ segments: tp.List[tp.Tuple[int, int, int, bytes]],
+ worker_count: int,
+) -> tp.List[tp.List[tp.Tuple[int, int, int, bytes]]]:
+ batch_count = max(1, min(worker_count, len(segments)))
+ batch_size = int(math.ceil(len(segments) / batch_count))
+ batches: tp.List[tp.List[tp.Tuple[int, int, int, bytes]]] = []
+ for start in range(0, len(segments), batch_size):
+ batches.append(segments[start:start + batch_size])
+ return batches
+
+
def _init_parallel_worker_runtime() -> None:
torch.use_deterministic_algorithms(True)
torch.backends.mkldnn.enabled = False
+ try:
+ torch.set_num_threads(1)
+ except RuntimeError:
+ pass
+ try:
+ torch.set_num_interop_threads(1)
+ except RuntimeError:
+ pass
def _shutdown_parallel_executor() -> None:
@@ -490,10 +531,6 @@ def _get_parallel_executor(worker_count: int) -> concurrent.futures.ProcessPoolE
)
_PARALLEL_EXECUTOR_WORKERS = worker_count
return _PARALLEL_EXECUTOR
- try:
- torch.set_num_interop_threads(1)
- except RuntimeError:
- pass
def _get_parallel_worker_model(
@@ -650,6 +687,165 @@ def _encode_segment_batch_worker(
}
+def _decode_acv4_chunk_payload(
+ payload: bytes,
+ *,
+ model: EncodecModel,
+ model_device: torch.device,
+ coder_device: torch.device,
+ lm_device: torch.device,
+ num_codebooks: int,
+ use_lm: bool,
+ fp_scale: int,
+ min_range: int,
+ lm_tau: float,
+ lm: tp.Optional[tp.Any],
+ legacy_lm: tp.Optional[tp.Any],
+ this_len: int,
+) -> torch.Tensor:
+ frame_fo = io.BytesIO(payload)
+
+ if model.normalize:
+ scale_f, = struct.unpack('!f', binary._read_exactly(
+ frame_fo, struct.calcsize('!f')))
+ scale = torch.tensor(scale_f, device=coder_device).view(1)
+ else:
+ scale = None
+
+ if use_lm:
+ native_decoder = None
+ code_buf = None
+ decoder = None
+ native_module = _tensor_native_ac_module()
+ if native_module is not None:
+ native_decoder = native_module.ArithmeticDecoder(
+ frame_fo.read(),
+ ARITHMETIC_TOTAL_RANGE_BITS,
+ )
+ code_buf = torch.empty(num_codebooks, dtype=torch.long, device=coder_device)
+ else:
+ decoder = ArithmeticDecoder(frame_fo, total_range_bits=ARITHMETIC_TOTAL_RANGE_BITS)
+ states = None
+ offset: tp.Union[int, torch.Tensor]
+ if lm_device.type != "cpu":
+ offset = torch.zeros((), dtype=torch.long, device=lm_device)
+ else:
+ offset = 0
+ input_ = torch.zeros(1, num_codebooks, 1, dtype=torch.long, device=lm_device)
+ else:
+ unpacker = binary.BitUnpacker(model.bits_per_codebook, frame_fo)
+
+ frame_length = int(math.ceil(this_len * model.frame_rate / model.sample_rate))
+ frame = torch.zeros(1, num_codebooks, frame_length, dtype=torch.long, device=coder_device)
+ try:
+ with torch.inference_mode():
+ for t in range(frame_length):
+ if use_lm:
+ assert lm is not None
+ logits_raw, states, offset = lm.forward_logits(input_, states, offset)
+ logits_q = _quantize_logits_(logits_raw / lm_tau, LOGIT_QSTEP)
+ probas = _softmax_or_uniform(logits_q, dim=1)
+ pdf_mat = probas[0, :, :, 0].to(coder_device)
+ if native_decoder is not None:
+ assert code_buf is not None
+ native_decoder.pull_symbols_into_torch(
+ pdf_mat.detach().contiguous(),
+ code_buf,
+ fp_scale,
+ min_range,
+ )
+ frame[0, :, t] = code_buf
+ input_ = 1 + code_buf.view(1, num_codebooks, 1).to(lm_device)
+ else:
+ assert decoder is not None
+ cdf_mat = _deterministic_cdf_multi(
+ pdf_mat,
+ decoder.total_range_bits,
+ fp_scale=fp_scale,
+ min_range=min_range,
+ check=False,
+ )
+ cdf_cols = cdf_mat.t().contiguous()
+ code_list = []
+ for k in range(num_codebooks):
+ code = decoder.pull(cdf_cols[k])
+ if code is None:
+ raise EOFError("Stream ended before expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=coder_device)
+ input_ = (1 + frame[:, :, t:t + 1]).to(lm_device)
+ elif legacy_lm is not None:
+ assert False, "legacy LM is not expected for acv4 chunk decode"
+ else:
+ code_list = []
+ for _ in range(num_codebooks):
+ code = unpacker.pull()
+ if code is None:
+ raise EOFError("Stream ended before expected.")
+ code_list.append(code)
+ frame[0, :, t] = torch.tensor(code_list, dtype=torch.long, device=coder_device)
+ except Exception:
+ return torch.zeros(1, model.channels, this_len, device=model_device)
+
+ encoded_frame = (
+ frame.to(model_device),
+ None if scale is None else scale.to(model_device),
+ )
+ with torch.inference_mode():
+ return model._decode_frame(encoded_frame)[..., :this_len]
+
+
+def _decode_segment_batch_worker(
+ model_name: str,
+ num_codebooks: int,
+ use_lm: bool,
+ lm_tau: float,
+ fp_scale: int,
+ min_range: int,
+ batch: tp.List[tp.Tuple[int, int, int, bytes]],
+) -> dict:
+ _init_parallel_worker_runtime()
+ model = _get_decode_model(model_name, 'cpu')
+ model_device = torch.device('cpu')
+ coder_device = torch.device('cpu')
+ lm_device = _select_decode_lm_device(
+ model_device=model_device,
+ coder_device=coder_device,
+ acv=4,
+ )
+ lm, legacy_lm = _get_decode_lms(
+ model,
+ model_name=model_name,
+ coder_device=coder_device,
+ lm_device=lm_device,
+ use_lm=use_lm,
+ acv=4,
+ lm_tau=lm_tau,
+ )
+
+ segments: tp.List[tp.Tuple[int, int, torch.Tensor]] = []
+ for segment_index, offset_samples, this_len, payload in batch:
+ decoded = _decode_acv4_chunk_payload(
+ payload,
+ model=model,
+ model_device=model_device,
+ coder_device=coder_device,
+ lm_device=lm_device,
+ num_codebooks=num_codebooks,
+ use_lm=use_lm,
+ fp_scale=fp_scale,
+ min_range=min_range,
+ lm_tau=lm_tau,
+ lm=lm,
+ legacy_lm=legacy_lm,
+ this_len=this_len,
+ ).cpu()
+ segments.append((int(segment_index), int(offset_samples), decoded))
+ return {
+ 'segments': segments,
+ }
+
+
atexit.register(_shutdown_parallel_executor)
def _write_chunk(fo: tp.IO[bytes], payload: bytes) -> None:
@@ -1058,22 +1254,97 @@ def decompress_from_file(fo: tp.IO[bytes],
segment_length = model.segment_length or audio_length
segment_stride = model.segment_stride or audio_length
+ offsets = list(range(0, audio_length, segment_stride))
decoded_frames: tp.List[torch.Tensor] = []
frames: tp.List[EncodedFrame] = []
- for offset_samples in range(0, audio_length, segment_stride):
+ parallel_decode_workers = _parallel_decode_segment_worker_count(
+ len(offsets),
+ model_device=model_device,
+ acv=acv,
+ )
+ if parallel_decode_workers > 1:
+ decoded_frames = [torch.zeros(0)] * len(offsets)
+ decodable_segments: tp.List[tp.Tuple[int, int, int, bytes]] = []
+ for segment_index, offset_samples in enumerate(offsets, start=1):
+ this_len = min(audio_length - offset_samples, segment_length)
+ try:
+ payload = _read_chunk_payload(fo)
+ except Exception:
+ decoded_frames[segment_index - 1] = torch.zeros(
+ 1,
+ model.channels,
+ this_len,
+ device=model_device,
+ )
+ continue
+ decodable_segments.append((segment_index, int(offset_samples), this_len, payload))
+
+ if decodable_segments:
+ batches = _build_decode_segment_batches(decodable_segments, parallel_decode_workers)
+ ordered_results: tp.List[dict] = []
+ executor = _get_parallel_executor(parallel_decode_workers)
+ try:
+ futures = [
+ executor.submit(
+ _decode_segment_batch_worker,
+ model_name,
+ num_codebooks,
+ bool(use_lm),
+ float(lm_tau),
+ int(fp_scale),
+ int(min_range),
+ batch,
+ )
+ for batch in batches
+ ]
+ for future in concurrent.futures.as_completed(futures):
+ ordered_results.append(future.result())
+ except BrokenProcessPool:
+ _shutdown_parallel_executor()
+ raise
+
+ for result in sorted(ordered_results, key=lambda item: item['segments'][0][0]):
+ for segment_index, _offset_samples, decoded in result['segments']:
+ decoded_frames[segment_index - 1] = decoded.to(model_device)
+
+ if model.segment_length is None:
+ wav = decoded_frames[0]
+ else:
+ wav = _linear_overlap_add(decoded_frames, model.segment_stride or 1)
+ return wav[0, :, :audio_length], model.sample_rate
+
+ for offset_samples in offsets:
this_len = min(audio_length - offset_samples, segment_length)
frame_length = int(math.ceil(this_len * model.frame_rate / model.sample_rate))
frame_fo = fo
if acv == 4:
try:
- frame_fo = io.BytesIO(_read_chunk_payload(fo))
+ payload = _read_chunk_payload(fo)
except Exception:
# Corrupt chunk → substitute silence and continue.
decoded_frames.append(
torch.zeros(1, model.channels, this_len, device=model_device))
continue
+ decoded_frames.append(
+ _decode_acv4_chunk_payload(
+ payload,
+ model=model,
+ model_device=model_device,
+ coder_device=coder_device,
+ lm_device=lm_device,
+ num_codebooks=num_codebooks,
+ use_lm=use_lm,
+ fp_scale=fp_scale,
+ min_range=min_range,
+ lm_tau=lm_tau,
+ lm=lm,
+ legacy_lm=legacy_lm,
+ this_len=this_len,
+ )
+ )
+ continue
if model.normalize:
scale_f, = struct.unpack('!f', binary._read_exactly(
@@ -1086,7 +1357,7 @@ def decompress_from_file(fo: tp.IO[bytes],
native_decoder = None
code_buf = None
decoder = None
- native_module = _tensor_native_ac_module() if acv == 4 else None
+ native_module = None
if native_module is not None:
native_decoder = native_module.ArithmeticDecoder(
frame_fo.read(),
@@ -1112,6 +1383,7 @@ def decompress_from_file(fo: tp.IO[bytes],
with torch.inference_mode():
for t in range(frame_length):
if use_lm and acv >= 3:
+ assert lm is not None
logits_raw, states, offset = lm.forward_logits(
input_, states, offset)
logits_q = _quantize_logits_(logits_raw / lm_tau,
@@ -1173,20 +1445,11 @@ def decompress_from_file(fo: tp.IO[bytes],
device=coder_device)
except Exception:
- if acv == 4:
- decoded_frames.append(
- torch.zeros(1, model.channels, this_len, device=model_device))
- continue
raise
encoded_frame = (frame.to(model_device),
None if scale is None else scale.to(model_device))
- if acv == 4:
- with torch.inference_mode():
- decoded_frames.append(
- model._decode_frame(encoded_frame)[..., :this_len])
- else:
- frames.append(encoded_frame)
+ frames.append(encoded_frame)
if acv == 4:
if model.segment_length is None:
diff --git a/scripts/bench_decode_payload.py b/scripts/bench_decode_payload.py
new file mode 100644
index 0000000..69bfeea
--- /dev/null
+++ b/scripts/bench_decode_payload.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python3
+import argparse
+import hashlib
+import json
+import sys
+import time
+from pathlib import Path
+
+import soundfile as sf
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Benchmark EnCodec payload decode.")
+ parser.add_argument("--repo-path", type=Path, required=True, help="Path to the EnCodec checkout.")
+ parser.add_argument("--payload", type=Path, required=True, help="Path to the .ecdc payload.")
+ parser.add_argument("--device", default="cpu", help="Decode device.")
+ parser.add_argument("--warmup", type=int, default=0, help="Number of warmup decodes to discard.")
+ parser.add_argument("--repeats", type=int, default=1, help="Number of decode repetitions.")
+ parser.add_argument("--output-wav", type=Path, default=None, help="Optional WAV output path.")
+ return parser.parse_args()
+
+
+def main():
+ args = parse_args()
+ sys.path.insert(0, str(args.repo_path))
+
+ from encodec.compress import decompress
+
+ payload = args.payload.read_bytes()
+ runs = []
+ wav_sha256 = None
+ wav_shape = None
+ sample_rate = None
+
+ for _ in range(max(0, int(args.warmup))):
+ wav, sample_rate = decompress(payload, device=args.device)
+ wav_cpu = wav.detach().cpu().contiguous()
+ digest = hashlib.sha256(wav_cpu.numpy().tobytes()).hexdigest()
+ if wav_sha256 is None:
+ wav_sha256 = digest
+ wav_shape = list(wav_cpu.shape)
+ elif digest != wav_sha256:
+ raise RuntimeError(
+ f"Non-deterministic warmup decode: first hash {wav_sha256}, later hash {digest}."
+ )
+
+ for _ in range(max(1, int(args.repeats))):
+ t0 = time.perf_counter()
+ wav, sample_rate = decompress(payload, device=args.device)
+ decode_s = time.perf_counter() - t0
+ wav_cpu = wav.detach().cpu().contiguous()
+ digest = hashlib.sha256(wav_cpu.numpy().tobytes()).hexdigest()
+ if wav_sha256 is None:
+ wav_sha256 = digest
+ wav_shape = list(wav_cpu.shape)
+ elif digest != wav_sha256:
+ raise RuntimeError(
+ f"Non-deterministic decode: first hash {wav_sha256}, later hash {digest}."
+ )
+ runs.append(decode_s)
+
+ result = {
+ "payload": str(args.payload),
+ "device": args.device,
+ "warmup": max(0, int(args.warmup)),
+ "repeats": len(runs),
+ "decode_s_runs": runs,
+ "decode_s_mean": sum(runs) / len(runs),
+ "wav_sha256": wav_sha256,
+ "wav_shape": wav_shape,
+ "sample_rate": sample_rate,
+ }
+ if args.output_wav is not None:
+ args.output_wav.parent.mkdir(parents=True, exist_ok=True)
+ sf.write(
+ str(args.output_wav),
+ wav.detach().cpu().transpose(0, 1).numpy(),
+ int(sample_rate),
+ subtype="PCM_16",
+ )
+ result["output_wav"] = str(args.output_wav)
+ print(json.dumps(result))
+
+
+if __name__ == "__main__":
+ main()
From 907522727440373126279d17f53585121ca37b4e Mon Sep 17 00:00:00 2001
From: jbrough
Date: Fri, 3 Apr 2026 19:56:40 +0100
Subject: [PATCH 21/24] Tighten fork README tone
---
README.md | 12 ++++++++----
1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index 85f45a4..9ca9b2c 100644
--- a/README.md
+++ b/README.md
@@ -10,11 +10,15 @@
## wavey-ai fork README
-Bottom line for the `wavey-ai` fork: on an RTX 4000 Ada, the deterministic LM path cut 48 kHz GPU encode from `99s` to `13s` on a full song, made `cuda -> cpu` decode work reliably, slightly improved GPU decode, and now defaults CPU chunk decode workers to `available CPUs - 1`. On the tested 4-vCPU Linode box, that auto default picked `3` workers and cut full-song CPU decode from `170.18s` to `92.52s` (`1.84x` faster). Forcing `4` workers reached `82.33s` (`2.07x` faster), but the worker topology is deterministic-with-itself rather than hash-identical to the previous threaded single-process CPU path.
+This fork keeps the upstream model weights and changes the codec/runtime behavior around them. The main additions are:
-### Precision and Robustness Improvements
+- a deterministic `acv=4` entropy path for cross-device payload compatibility
+- optional native entropy-coder acceleration
+- chunked CPU decode parallelism for `acv=4` payloads
-This fork extends the original EnCodec with a fully deterministic, cross-platform entropy coding path plus optional native entropy-coder acceleration. The neural network weights remain unchanged.
+On the RTX 4000 Ada benchmark in this README, the fork improved 48 kHz GPU encode from `99.07 s` to `13.09 s`, preserved `cuda -> cpu` decode for deterministic payloads, and slightly improved GPU decode. On the tested 4-vCPU Linode box, the default CPU decode worker policy (`available CPUs - 1`) reduced full-song CPU decode from `170.18 s` to `92.52 s`. Forcing `4` workers reached `82.33 s`. Worker-mode CPU decode is deterministic for a fixed worker topology, but it is not hash-identical to the previous threaded single-process CPU decode.
+
+### Deterministic entropy path
### Bitstream version `acv=4`
@@ -58,7 +62,7 @@ Benchmarked on April 3, 2026 on a Linode `g2-gpu-rtx4000a1-s` instance (1x RTX 4
| Fork `cuda -> cpu` | `12.94 s` | `16.11x` | `167.56 s` | `1.24x` | cross-architecture decode succeeds |
| Fork `cpu -> cpu` | `35.22 s` | `5.92x` | `160.96 s` | `1.30x` | encode `2.95x` faster than upstream CPU, CPU decode slower |
-What this means in practice:
+Summary:
- The biggest RTX win is encode throughput. On this full-length track, the fork cut GPU encode time from `99.07 s` to `13.09 s`.
- GPU decode is modestly faster than upstream on the same Ada card, but the main portability win is that `cuda -> cpu` decode works at all.
From 76995ee6f262e604390ee9990fd432afac0f2154 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Sun, 12 Apr 2026 12:50:48 +0100
Subject: [PATCH 22/24] Add frame-level ONNX export bundle
---
README.md | 30 +++++++
encodec/onnx.py | 153 +++++++++++++++++++++++++++++++++++
scripts/export_frame_onnx.py | 65 +++++++++++++++
3 files changed, 248 insertions(+)
create mode 100644 encodec/onnx.py
create mode 100644 scripts/export_frame_onnx.py
diff --git a/README.md b/README.md
index 9ca9b2c..edcfb2f 100644
--- a/README.md
+++ b/README.md
@@ -15,9 +15,39 @@ This fork keeps the upstream model weights and changes the codec/runtime behavio
- a deterministic `acv=4` entropy path for cross-device payload compatibility
- optional native entropy-coder acceleration
- chunked CPU decode parallelism for `acv=4` payloads
+- a frame-level ONNX export boundary for the neural encoder/decoder
On the RTX 4000 Ada benchmark in this README, the fork improved 48 kHz GPU encode from `99.07 s` to `13.09 s`, preserved `cuda -> cpu` decode for deterministic payloads, and slightly improved GPU decode. On the tested 4-vCPU Linode box, the default CPU decode worker policy (`available CPUs - 1`) reduced full-song CPU decode from `170.18 s` to `92.52 s`. Forcing `4` workers reached `82.33 s`. Worker-mode CPU decode is deterministic for a fixed worker topology, but it is not hash-identical to the previous threaded single-process CPU decode.
+### Composable split
+
+The intended split in this fork is now:
+
+- neural frame codec: `_encode_frame(...)` / `_decode_frame(...)`
+- runtime / bitstream: segmentation, overlap-add, `.ecdc` framing, arithmetic coding, LM entropy path
+
+The ONNX export path only targets the neural frame codec boundary. It does not export the full `compress()` / `decompress()` pipeline.
+
+### Frame ONNX export
+
+Export example:
+
+```bash
+python scripts/export_frame_onnx.py \
+ --model encodec_48khz \
+ --bandwidth 12 \
+ --device cuda \
+ --output-dir out/encodec_48khz_12kbps_onnx
+```
+
+The exporter writes:
+
+- `encode_frame.onnx`
+- `decode_frame.onnx`
+- `bundle.json`
+
+The checked-in `bundle.json` contract is designed for Rust runtimes such as `encodec-rs` to load and run the neural frame path directly through ONNX Runtime.
+
### Deterministic entropy path
### Bitstream version `acv=4`
diff --git a/encodec/onnx.py b/encodec/onnx.py
new file mode 100644
index 0000000..bae647d
--- /dev/null
+++ b/encodec/onnx.py
@@ -0,0 +1,153 @@
+from __future__ import annotations
+
+from dataclasses import asdict, dataclass
+from pathlib import Path
+import json
+import typing as tp
+
+import onnx
+import torch
+from torch import nn
+
+from .model import EncodecModel
+
+
+MODEL_FACTORIES: dict[str, tp.Callable[..., EncodecModel]] = {
+ "encodec_24khz": EncodecModel.encodec_model_24khz,
+ "encodec_48khz": EncodecModel.encodec_model_48khz,
+}
+
+
+@dataclass
+class OnnxFrameBundleMetadata:
+ schema_version: int
+ model_name: str
+ bandwidth_kbps: float
+ sample_rate: int
+ channels: int
+ segment_samples: int
+ segment_stride: int
+ normalize: bool
+ num_codebooks: int
+ frame_length: int
+ encode_model: str
+ decode_model: str
+ opset_version: int
+
+
+class EncodeFrameWrapper(nn.Module):
+ def __init__(self, model: EncodecModel):
+ super().__init__()
+ self.model = model
+
+ def forward(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ codes, scale = self.model._encode_frame(x)
+ if scale is None:
+ scale = torch.ones((x.shape[0], 1), dtype=x.dtype, device=x.device)
+ return codes, scale
+
+
+class DecodeFrameWrapper(nn.Module):
+ def __init__(self, model: EncodecModel):
+ super().__init__()
+ self.model = model
+
+ def forward(self, codes: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+ if self.model.normalize:
+ return self.model._decode_frame((codes, scale))
+ return self.model._decode_frame((codes, None))
+
+
+def build_model(
+ model_name: str,
+ bandwidth_kbps: float,
+ device: str = "cpu",
+ repository: Path | None = None,
+) -> EncodecModel:
+ if model_name not in MODEL_FACTORIES:
+ supported = ", ".join(sorted(MODEL_FACTORIES.keys()))
+ raise ValueError(f"Unsupported model {model_name!r}. Use one of: {supported}.")
+
+ model = MODEL_FACTORIES[model_name](repository=repository)
+ model.set_target_bandwidth(float(bandwidth_kbps))
+ return model.to(device).eval()
+
+
+def export_frame_onnx_bundle(
+ output_dir: str | Path,
+ model_name: str = "encodec_48khz",
+ bandwidth_kbps: float = 6.0,
+ device: str = "cpu",
+ repository: str | Path | None = None,
+ opset_version: int = 18,
+) -> OnnxFrameBundleMetadata:
+ output_dir = Path(output_dir)
+ output_dir.mkdir(parents=True, exist_ok=True)
+ repository_path = None if repository is None else Path(repository)
+
+ model = build_model(model_name, bandwidth_kbps, device=device, repository=repository_path)
+ segment_samples = int(model.segment_length or model.sample_rate)
+
+ torch.manual_seed(0)
+ dummy_audio = torch.randn(
+ 1,
+ model.channels,
+ segment_samples,
+ device=device,
+ dtype=torch.float32,
+ ) * 0.01
+
+ encoder = EncodeFrameWrapper(model).eval()
+ decoder = DecodeFrameWrapper(model).eval()
+
+ with torch.no_grad():
+ codes, scale = encoder(dummy_audio)
+ codes = codes.detach().clone()
+ scale = scale.detach().clone()
+
+ encode_path = output_dir / "encode_frame.onnx"
+ decode_path = output_dir / "decode_frame.onnx"
+
+ torch.onnx.export(
+ encoder,
+ (dummy_audio,),
+ encode_path,
+ input_names=["audio"],
+ output_names=["codes", "scale"],
+ opset_version=opset_version,
+ dynamo=False,
+ )
+ torch.onnx.export(
+ decoder,
+ (codes, scale),
+ decode_path,
+ input_names=["codes", "scale"],
+ output_names=["audio"],
+ opset_version=opset_version,
+ dynamo=False,
+ )
+
+ onnx.checker.check_model(str(encode_path))
+ onnx.checker.check_model(str(decode_path))
+
+ metadata = OnnxFrameBundleMetadata(
+ schema_version=1,
+ model_name=model.name,
+ bandwidth_kbps=float(bandwidth_kbps),
+ sample_rate=int(model.sample_rate),
+ channels=int(model.channels),
+ segment_samples=segment_samples,
+ segment_stride=int(model.segment_stride or segment_samples),
+ normalize=bool(model.normalize),
+ num_codebooks=int(codes.shape[1]),
+ frame_length=int(codes.shape[2]),
+ encode_model=encode_path.name,
+ decode_model=decode_path.name,
+ opset_version=int(opset_version),
+ )
+ (output_dir / "bundle.json").write_text(json.dumps(asdict(metadata), indent=2) + "\n")
+ return metadata
+
+
+def metadata_to_json(metadata: OnnxFrameBundleMetadata) -> str:
+ return json.dumps(asdict(metadata), indent=2, sort_keys=True)
diff --git a/scripts/export_frame_onnx.py b/scripts/export_frame_onnx.py
new file mode 100644
index 0000000..7e01d0a
--- /dev/null
+++ b/scripts/export_frame_onnx.py
@@ -0,0 +1,65 @@
+#!/usr/bin/env python3
+from __future__ import annotations
+
+import argparse
+from pathlib import Path
+
+from encodec.onnx import export_frame_onnx_bundle, metadata_to_json
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="Export the EnCodec frame encoder/decoder boundary to an ONNX bundle."
+ )
+ parser.add_argument(
+ "--model",
+ default="encodec_48khz",
+ choices=["encodec_24khz", "encodec_48khz"],
+ help="Pretrained EnCodec model to export.",
+ )
+ parser.add_argument(
+ "--bandwidth",
+ type=float,
+ default=6.0,
+ help="Target bandwidth in kbps for the exported bundle.",
+ )
+ parser.add_argument(
+ "--output-dir",
+ required=True,
+ help="Directory that will receive encode_frame.onnx, decode_frame.onnx, and bundle.json.",
+ )
+ parser.add_argument(
+ "--device",
+ default="cpu",
+ help="Torch device for export, e.g. cpu or cuda.",
+ )
+ parser.add_argument(
+ "--repository",
+ type=Path,
+ default=None,
+ help="Optional local checkpoint repository path.",
+ )
+ parser.add_argument(
+ "--opset-version",
+ type=int,
+ default=18,
+ help="ONNX opset version to export.",
+ )
+ return parser.parse_args()
+
+
+def main() -> None:
+ args = parse_args()
+ metadata = export_frame_onnx_bundle(
+ output_dir=args.output_dir,
+ model_name=args.model,
+ bandwidth_kbps=args.bandwidth,
+ device=args.device,
+ repository=args.repository,
+ opset_version=args.opset_version,
+ )
+ print(metadata_to_json(metadata))
+
+
+if __name__ == "__main__":
+ main()
From e5c7ffd29c55cb88cae57430a917164f42943ce9 Mon Sep 17 00:00:00 2001
From: jbrough
Date: Sun, 12 Apr 2026 13:48:16 +0100
Subject: [PATCH 23/24] Export ONNX frame bundles with dynamic batch
---
encodec/onnx.py | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/encodec/onnx.py b/encodec/onnx.py
index bae647d..34230da 100644
--- a/encodec/onnx.py
+++ b/encodec/onnx.py
@@ -116,6 +116,11 @@ def export_frame_onnx_bundle(
output_names=["codes", "scale"],
opset_version=opset_version,
dynamo=False,
+ dynamic_axes={
+ "audio": {0: "batch"},
+ "codes": {0: "batch"},
+ "scale": {0: "batch"},
+ },
)
torch.onnx.export(
decoder,
@@ -125,6 +130,11 @@ def export_frame_onnx_bundle(
output_names=["audio"],
opset_version=opset_version,
dynamo=False,
+ dynamic_axes={
+ "codes": {0: "batch"},
+ "scale": {0: "batch"},
+ "audio": {0: "batch"},
+ },
)
onnx.checker.check_model(str(encode_path))
From 4757b5ae14224e882f954988bcb6a5653ba7cac3 Mon Sep 17 00:00:00 2001
From: Jamie B
Date: Thu, 4 Jun 2026 15:10:55 +0100
Subject: [PATCH 24/24] Save local work
---
encodec/onnx.py | 95 ++++++++++++++++++++++++++++++++++++++++---------
1 file changed, 79 insertions(+), 16 deletions(-)
diff --git a/encodec/onnx.py b/encodec/onnx.py
index 34230da..798778d 100644
--- a/encodec/onnx.py
+++ b/encodec/onnx.py
@@ -3,6 +3,7 @@
from dataclasses import asdict, dataclass
from pathlib import Path
import json
+import os
import typing as tp
import onnx
@@ -33,6 +34,15 @@ class OnnxFrameBundleMetadata:
encode_model: str
decode_model: str
opset_version: int
+ bits_per_codebook: int | None = None
+ codebook_cardinality: int | None = None
+ lm_quant_weight_model: str | None = None
+ lm_dim: int | None = None
+ lm_num_layers: int | None = None
+ lm_past_context: int | None = None
+ lm_logit_step: float | None = None
+ lm_entropy_logit_step: float | None = None
+ lm_cardinality: int | None = None
class EncodeFrameWrapper(nn.Module):
@@ -73,29 +83,62 @@ def build_model(
return model.to(device).eval()
+def _env_int(name: str) -> int | None:
+ value = os.getenv(name)
+ if value is None or value == "":
+ return None
+ parsed = int(value)
+ if parsed <= 0:
+ raise ValueError(f"{name} must be positive")
+ return parsed
+
+
def export_frame_onnx_bundle(
output_dir: str | Path,
model_name: str = "encodec_48khz",
bandwidth_kbps: float = 6.0,
device: str = "cpu",
repository: str | Path | None = None,
- opset_version: int = 18,
+ opset_version: int = 17,
) -> OnnxFrameBundleMetadata:
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
repository_path = None if repository is None else Path(repository)
- model = build_model(model_name, bandwidth_kbps, device=device, repository=repository_path)
- segment_samples = int(model.segment_length or model.sample_rate)
+ bundle_path = output_dir / "bundle.json"
+ existing_bundle: dict[str, tp.Any] = {}
+ if bundle_path.exists():
+ existing_bundle = json.loads(bundle_path.read_text())
- torch.manual_seed(0)
- dummy_audio = torch.randn(
- 1,
- model.channels,
- segment_samples,
+ model = build_model(
+ model_name,
+ bandwidth_kbps,
device=device,
- dtype=torch.float32,
- ) * 0.01
+ repository=repository_path,
+ )
+
+ trace_samples = _env_int("ENCODEC_ONNX_TRACE_SAMPLES")
+ trace_stride = _env_int("ENCODEC_ONNX_TRACE_STRIDE")
+
+ if trace_samples is None:
+ segment_samples = int(model.segment_length or model.sample_rate)
+ segment_stride = int(model.segment_stride or segment_samples)
+ else:
+ segment_samples = trace_samples
+ segment_stride = trace_stride or trace_samples
+ model.segment = segment_samples / float(model.sample_rate)
+
+ torch.manual_seed(0)
+ dummy_audio = (
+ torch.randn(
+ 1,
+ model.channels,
+ segment_samples,
+ device=device,
+ dtype=torch.float32,
+ )
+ * 0.01
+ )
encoder = EncodeFrameWrapper(model).eval()
decoder = DecodeFrameWrapper(model).eval()
@@ -115,13 +158,13 @@ def export_frame_onnx_bundle(
input_names=["audio"],
output_names=["codes", "scale"],
opset_version=opset_version,
- dynamo=False,
dynamic_axes={
"audio": {0: "batch"},
"codes": {0: "batch"},
"scale": {0: "batch"},
},
)
+
torch.onnx.export(
decoder,
(codes, scale),
@@ -129,7 +172,6 @@ def export_frame_onnx_bundle(
input_names=["codes", "scale"],
output_names=["audio"],
opset_version=opset_version,
- dynamo=False,
dynamic_axes={
"codes": {0: "batch"},
"scale": {0: "batch"},
@@ -146,18 +188,39 @@ def export_frame_onnx_bundle(
bandwidth_kbps=float(bandwidth_kbps),
sample_rate=int(model.sample_rate),
channels=int(model.channels),
- segment_samples=segment_samples,
- segment_stride=int(model.segment_stride or segment_samples),
+ segment_samples=int(segment_samples),
+ segment_stride=int(segment_stride),
normalize=bool(model.normalize),
num_codebooks=int(codes.shape[1]),
frame_length=int(codes.shape[2]),
encode_model=encode_path.name,
decode_model=decode_path.name,
opset_version=int(opset_version),
+ bits_per_codebook=int(model.bits_per_codebook),
+ codebook_cardinality=int(model.quantizer.bins),
+ lm_quant_weight_model=existing_bundle.get("lm_quant_weight_model"),
+ lm_dim=existing_bundle.get("lm_dim"),
+ lm_num_layers=existing_bundle.get("lm_num_layers"),
+ lm_past_context=existing_bundle.get("lm_past_context"),
+ lm_logit_step=existing_bundle.get("lm_logit_step"),
+ lm_entropy_logit_step=existing_bundle.get("lm_entropy_logit_step"),
+ lm_cardinality=existing_bundle.get("lm_cardinality", int(model.quantizer.bins)),
)
- (output_dir / "bundle.json").write_text(json.dumps(asdict(metadata), indent=2) + "\n")
+
+ bundle_payload = {
+ key: value
+ for key, value in asdict(metadata).items()
+ if value is not None
+ }
+
+ bundle_path.write_text(json.dumps(bundle_payload, indent=2) + "\n")
return metadata
def metadata_to_json(metadata: OnnxFrameBundleMetadata) -> str:
- return json.dumps(asdict(metadata), indent=2, sort_keys=True)
+ payload = {
+ key: value
+ for key, value in asdict(metadata).items()
+ if value is not None
+ }
+ return json.dumps(payload, indent=2, sort_keys=True)