For practical usage examples showing how to integrate these APIs with upstream encoders, see the Integration guide.
SemiMarkovCRFHead is the high-level module for most use cases. It uses the streaming Triton kernel internally, which is the recommended approach for both training and inference:
from flash_semicrf import SemiMarkovCRFHead
class SemiMarkovCRFHead(nn.Module):
    def __init__(
        self,
        num_classes: int,  # Number of label classes (C)
        max_duration: int,  # Maximum segment duration (K)
        hidden_dim: int = None,  # Optional: projection from encoder dim
        init_scale: float = 0.1,  # Parameter initialization scale
        duration_distribution: str = None,  # "learned", "geometric", "poisson", etc.
        edge_memory_threshold: float = 8e9,  # Memory threshold for non-log/max semirings (8GB)
        num_warps: int = 4,  # Triton kernel parallelism (2-8 recommended)
    ):
        """
        CRF head for Semi-Markov sequence labeling.
        Compatible with DDP - gradients sync automatically via standard PyTorch.
        Memory: O(KC) independent of sequence length T.
        Transition Matrix Convention:
            transition[i, j] = score for transitioning FROM label i TO label j.
        Note:
            For T > 100K, use float32 precision for numerical stability.
        """
    def forward(
        self,
        hidden_states,  # (batch, T, hidden_dim) or (batch, T, C)
        lengths,  # (batch,) sequence lengths
        use_triton=True,
        backend="auto",  # "auto", "streaming", "exact", or "binary_tree_sharded"
    ) -> dict:
        """
        Compute partition function.
        Args:
            backend: Backend selection mode:
                - "auto": Streaming for log/max semirings, exact for others (default)
                - "streaming": Force streaming backend (genome-scale)
                - "exact": Force exact backend via semimarkov.py
                - "binary_tree_sharded": Memory-efficient reference with checkpointing
        Returns:
            dict with 'partition' (batch,) and 'cum_scores' (batch, T+1, C)
        """
    def compute_loss(
        self,
        hidden_states,
        lengths,
        labels,  # (batch, T) per-position labels
        use_triton=True,
        backend="auto",  # "auto", "streaming", "exact", or "binary_tree_sharded"
        reduction="mean",
    ) -> Tensor:
        """Compute negative log-likelihood loss."""
    def decode(
        self,
        hidden_states,
        lengths,
        use_triton=True,
        backend="auto",  # "auto", "streaming", "exact", or "binary_tree_sharded"
    ) -> Tensor:
        """Viterbi decoding - returns best score (batch,)."""
    def decode_with_traceback(
        self,
        hidden_states,
        lengths,
        max_traceback_length=10000,
        use_triton=True,
    ) -> ViterbiResult:
        """Viterbi with path reconstruction. Returns (scores, segments)."""
    def parameter_penalty(self, p: float = 2.0) -> Tensor:
        """
        Compute Lp penalty on CRF parameters for regularization.
        Returns: ||transition||_p^p + ||duration_bias||_p^p
        """
Example usage:
import torch
from flash_semicrf import SemiMarkovCRFHead
# Create CRF head
crf = SemiMarkovCRFHead(
    num_classes=24,
    max_duration=100,
    hidden_dim=512,
    duration_distribution="geometric",  # or "learned", "poisson", etc.
)
# Encoder output
hidden = encoder(x)  # (batch, T, 512)
lengths = torch.full((batch,), T)
# Forward pass
result = crf(hidden, lengths)
partition = result['partition']  # (batch,)
# Training with labels
labels = torch.randint(0, 24, (batch, T))
loss = crf.compute_loss(hidden, lengths, labels)
loss.backward()
# Viterbi decoding with traceback
result = crf.decode_with_traceback(hidden, lengths)
for seg in result.segments[0]:
    print(f"[{seg.start}, {seg.end}] label={seg.label}")
# Regularization (backpropagate reg_loss instead of loss to apply the penalty)
reg_loss = loss + 0.01 * crf.parameter_penalty()
Understanding how data flows through the CRF head helps with debugging and optimization.
hidden_states (batch, T, hidden_dim)
    │
    ▼ [projection layer, if hidden_dim provided]
scores (batch, T, C)
    │
    ▼ [zero-center for numerical stability]
scores_centered = scores - scores.mean(dim=1, keepdim=True)
    │
    ▼ [cumulative sum in float32]
cum_scores (batch, T+1, C)
    │
    ▼ [backend selection: streaming for log/max, exact for others]
    │
    ├── streaming (log/max semirings, always the default)
    │       │
    │       ▼
    │   semi_crf_streaming_forward()
    │       │
    │       ▼ [Triton kernel on CUDA, PyTorch reference on CPU]
    │   partition (batch,)
    │
    └── exact (non-log/max semirings or explicitly requested)
            │
            ▼
        _build_edge_tensor() → edge (batch, T, K, C, C)
            │
            ▼
        SemiMarkov.logpartition()
            │
            ▼
        partition (batch,)
    │
    ▼ [score gold segmentation via label changes]
gold_score = score_gold_vectorized(cum_scores, labels, ...)
    │
    ▼
NLL = partition - gold_score
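The streaming branch of this flow can be illustrated with a small pure-Python sketch. It is not the library's kernel: the names `build_cum_scores` and `streaming_log_partition` are hypothetical, and the edge decomposition used here (content score from a `cum_scores` difference, plus `duration_bias[k-1][c]`, plus `transition[src][dst]`, with no incoming transition for the first segment) is an assumption made for illustration. It shows why prefix sums turn each segment score into an O(1) lookup and why only the last K forward vectors need to be kept.

```python
import math

def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))

def build_cum_scores(scores):
    """scores: T rows of C floats -> T+1 rows of zero-centered prefix sums."""
    T, C = len(scores), len(scores[0])
    means = [sum(row[c] for row in scores) / T for c in range(C)]
    cum = [[0.0] * C]
    for row in scores:
        cum.append([cum[-1][c] + row[c] - means[c] for c in range(C)])
    return cum

def streaming_log_partition(cum, transition, duration_bias, K):
    """Log partition over all segmentations, storing only K forward vectors."""
    T, C = len(cum) - 1, len(cum[0])
    ring = [[-math.inf] * C for _ in range(K)]  # slot t % K holds alpha[t]
    alpha_t = [-math.inf] * C
    for t in range(1, T + 1):
        alpha_t = []
        for c in range(C):
            terms = []
            for k in range(1, min(K, t) + 1):
                start = t - k
                # O(1) segment content score from the prefix sums
                seg = cum[t][c] - cum[start][c] + duration_bias[k - 1][c]
                if start == 0:
                    terms.append(seg)  # first segment: no incoming transition
                else:
                    prev = ring[start % K]
                    terms.append(seg + logsumexp(
                        [prev[s] + transition[s][c] for s in range(C)]))
            alpha_t.append(logsumexp(terms))
        ring[t % K] = alpha_t  # overwrites alpha[t - K], which is no longer needed
    return logsumexp(alpha_t)

scores = [[0.5, -0.2], [1.0, 0.3], [-0.4, 0.8]]
cum = build_cum_scores(scores)
zeros = [[0.0, 0.0], [0.0, 0.0]]
print(streaming_log_partition(cum, zeros, zeros, K=2))
```

The ring buffer holds K vectors of C values regardless of T, which is the property the O(KC) memory figure refers to; the real kernels additionally batch and parallelize this recursion.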
hidden_states (batch, T, hidden_dim)
    │
    ▼ [same preprocessing as training]
cum_scores (batch, T+1, C)
    │
    ▼ [streaming Viterbi with backpointer storage]
semi_crf_streaming_viterbi_with_backpointers()
    │
    ├── max_scores (batch,)
    ├── bp_k (batch, T, C) ← best duration at each (position, label)
    ├── bp_c (batch, T, C) ← best source label at each (position, label)
    └── final_labels (batch,)
    │
    ▼ [O(T) traceback using backpointers]
_traceback_from_backpointers()
    │
    ▼
ViterbiResult(scores, List[List[Segment]])
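The decoding flow above can be sketched for a single sequence in plain Python. This is an illustrative reconstruction, not the library's implementation: `Seg`, `viterbi_with_backpointers` and `traceback` are hypothetical names, the score decomposition is the same assumed one as in the forward pass (prefix-sum difference plus duration bias plus transition), and for simplicity the sketch keeps a full score table with T+1 rows rather than a streaming buffer.

```python
import math
from typing import List, NamedTuple

class Seg(NamedTuple):  # hypothetical stand-in for the library's Segment
    start: int  # inclusive
    end: int    # inclusive
    label: int

def viterbi_with_backpointers(cum, transition, duration_bias, K):
    """Max-semiring recursion that records the best duration and source label."""
    T, C = len(cum) - 1, len(cum[0])
    best = [[-math.inf] * C for _ in range(T + 1)]
    bp_k = [[0] * C for _ in range(T + 1)]   # best duration ending at (t, c)
    bp_c = [[-1] * C for _ in range(T + 1)]  # best previous label, -1 = sequence start
    for t in range(1, T + 1):
        for c in range(C):
            for k in range(1, min(K, t) + 1):
                start = t - k
                seg = cum[t][c] - cum[start][c] + duration_bias[k - 1][c]
                if start == 0:
                    cand, src = seg, -1
                else:
                    src = max(range(C), key=lambda s: best[start][s] + transition[s][c])
                    cand = seg + best[start][src] + transition[src][c]
                if cand > best[t][c]:
                    best[t][c], bp_k[t][c], bp_c[t][c] = cand, k, src
    final_label = max(range(C), key=lambda c: best[T][c])
    return best[T][final_label], bp_k, bp_c, final_label

def traceback(bp_k, bp_c, final_label, T) -> List[Seg]:
    """Walk the backpointers from position T to 0; at most T steps."""
    segments, t, c = [], T, final_label
    while t > 0:
        k = bp_k[t][c]
        segments.append(Seg(t - k, t - 1, c))
        t, c = t - k, bp_c[t][c]
    return segments[::-1]

# Label 0 is favored at positions 0-1, label 1 at positions 2-3.
cum = [[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [2.0, 1.0], [2.0, 2.0]]
penalty = [[-0.5, -0.5], [-0.5, -0.5]]  # every segment change costs 0.5
no_bias = [[0.0, 0.0]] * 4
score, bp_k, bp_c, last = viterbi_with_backpointers(cum, penalty, no_bias, K=4)
for seg in traceback(bp_k, bp_c, last, T=4):
    print(f"[{seg.start}, {seg.end}] label={seg.label}")
```

Each traceback step jumps back by one whole segment, so the loop runs once per decoded segment and never more than T times.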
| Component | File |
|---|---|
| CRF head module | nn.py |
| Streaming kernels | streaming/autograd.py |
| Triton forward | streaming/triton_forward.py |
| Triton backward | streaming/triton_backward.py |
| Gold scoring | helpers.py |
Extended CRF head with uncertainty quantification:
from flash_semicrf import UncertaintySemiMarkovCRFHead
class UncertaintySemiMarkovCRFHead(SemiMarkovCRFHead):
    """
    SemiMarkovCRFHead with uncertainty methods.
    Additional methods for boundary confidence and active learning.
    Inherits all parameters from SemiMarkovCRFHead including accum_dtype and num_warps.
    """
    def compute_boundary_marginals(
        self,
        hidden_states,
        lengths,
        backend="auto",  # "auto", "streaming", or "exact"
        normalize=True,
    ) -> Tensor:
        """
        P(boundary at position t) for each position.
        Args:
            backend: Backend selection mode:
                - "auto": Streaming for log/max semirings, exact for others (default)
                - "streaming": Force streaming forward-backward algorithm
                - "exact": Force exact marginals via edge tensor
        Returns: (batch, T) boundary probabilities
        Note: use_streaming parameter is deprecated, use backend instead.
        """
    def compute_position_marginals(
        self,
        hidden_states,
        lengths,
    ) -> Tensor:
        """
        P(label=c at position t) for each position.
        Returns: (batch, T, C) label probabilities
        """
    def compute_entropy_streaming(
        self,
        hidden_states,
        lengths,
    ) -> Tensor:
        """
        Approximate entropy from marginals (works for T > 10K).
        Returns: (batch,) entropy estimates
        """
    def compute_entropy_exact(
        self,
        hidden_states,
        lengths,
    ) -> Tensor:
        """
        Exact entropy via EntropySemiring (T < 10K only).
        Computes H(P) = -sum_y P(y) log P(y) using the entropy semiring.
        Requires building the full edge tensor, so only works for short sequences.
        Returns: (batch,) exact entropy values
        """
    def compute_loss_uncertainty_weighted(
        self,
        hidden_states,
        lengths,
        labels,
        uncertainty_weight=1.0,
        focus_mode="high_uncertainty",  # or "boundary_regions"
        use_triton=False,
        reduction="mean",
    ) -> Tensor:
        """
        Uncertainty-weighted loss for active learning.
        L_weighted = (1 + lambda * uncertainty) * NLL
        """
Example usage:
from flash_semicrf import UncertaintySemiMarkovCRFHead
model = UncertaintySemiMarkovCRFHead(num_classes=24, max_duration=100, hidden_dim=512)
# Boundary confidence for decision support (auto-selects streaming for large T)
boundary_probs = model.compute_boundary_marginals(hidden, lengths)
# Force streaming backend for genome-scale sequences
boundary_probs = model.compute_boundary_marginals(hidden, lengths, backend="streaming")
# Entropy computation
entropy_approx = model.compute_entropy_streaming(hidden, lengths) # T > 10K
entropy_exact = model.compute_entropy_exact(hidden, lengths) # T < 10K
# Uncertainty-weighted training for active learning
loss = model.compute_loss_uncertainty_weighted(hidden, lengths, labels)
The streaming API is the recommended approach for training on GPU. It uses Triton kernels that compute edges on-the-fly, providing both memory efficiency and reliable gradient computation.
For very long sequences (T = 100K - 400K+) where the edge tensor cannot fit in memory:
from flash_semicrf import semi_crf_streaming_forward
def semi_crf_streaming_forward(
    cum_scores,  # (batch, T+1, C) cumulative projected scores
    transition,  # (C, C) or (K, C, C) transition matrix
    duration_bias,  # (K, C) duration bias
    lengths,  # (batch,) sequence lengths
    K,  # max segment duration
    semiring="log",  # "log" (partition) or "max" (Viterbi)
    use_triton=True,  # Use Triton kernel if available
    num_warps=4,  # Triton kernel parallelism (2-8 recommended)
) -> Tensor:
    """
    Memory-efficient Semi-CRF forward with on-the-fly edge computation.
    Memory: O(T*C) for cum_scores vs O(T*K*C^2) for edge tensor
        - T=400K, K=3K, C=24: 38 MB vs 2.76 TB
    Args:
        num_warps: Number of warps per block for Triton kernels.
            Higher values increase parallelism but also register pressure.
            Recommended range: 2-8. Default: 4
    Returns:
        partition: (batch,) log partition function or Viterbi score
    """
Example usage:
import torch
from flash_semicrf import semi_crf_streaming_forward
# Pre-project features (outside kernel)
projected = hidden @ W_content  # (batch, T, C)
projected = projected - projected.mean(dim=1, keepdim=True)  # Zero-center!
# Build cumulative scores
cum_scores = torch.zeros(batch, T+1, C, dtype=torch.float32)
cum_scores[:, 1:] = torch.cumsum(projected.float(), dim=1)
# Streaming forward (edges computed on-the-fly)
partition = semi_crf_streaming_forward(
    cum_scores, transition, duration_bias, lengths, max_duration
)
Low-level components of the streaming API for advanced use cases.
Compute edge potentials for a single (position, duration) pair on-the-fly:
from flash_semicrf.streaming import compute_edge_block_streaming
def compute_edge_block_streaming(
    cum_scores,  # (batch, T+1, C)
    transition,  # (C, C)
    duration_bias,  # (K, C)
    start: int,  # Segment start position
    k: int,  # Segment duration
) -> Tensor:
    """
    Compute edge block for segment [start, start+k) on-the-fly.
    Returns: (batch, C_dest, C_src) edge potentials
    """
PyTorch autograd functions for streaming Semi-CRF with custom backward passes:
from flash_semicrf.streaming import SemiCRFStreaming, SemiCRFStreamingTriton
# Pure PyTorch (CPU or GPU without Triton)
class SemiCRFStreaming(torch.autograd.Function):
    """
    O(KC) memory streaming forward-backward.
    Uses gradient checkpointing to recompute alpha values during backward,
    trading compute for memory.
    """
# Triton-accelerated (GPU with Triton)
class SemiCRFStreamingTriton(torch.autograd.Function):
    """
    O(KC) memory with hand-written Triton forward and backward kernels.
    """
When Triton is installed, these low-level kernel launchers are available:
from flash_semicrf.streaming import HAS_TRITON
if HAS_TRITON:
    from flash_semicrf.streaming import (
        launch_streaming_triton_kernel,  # Forward pass
        launch_streaming_triton_backward,  # Backward pass
        launch_streaming_triton_kernel_max_bp,  # Viterbi with backpointers
        semi_crf_streaming_viterbi_triton,  # Full Viterbi decoding
    )
from flash_semicrf.streaming import NEG_INF
NEG_INF  # Large negative value for log-space operations (-1e38)
Flexible parameterization for segment duration priors:
from flash_semicrf import (
    create_duration_distribution,
    LearnedDuration,
    GeometricDuration,
    NegativeBinomialDuration,
    PoissonDuration,
    UniformDuration,
    CallableDuration,
)
# Factory function
dur = create_duration_distribution(
    "geometric",  # or "learned", "negative_binomial", "poisson", "uniform"
    max_duration=100,
    num_classes=24,
    init_logit=-1.0,  # Distribution-specific kwargs
)
bias = dur()  # Returns (K, C) tensor
Available distributions:
| Distribution | Formula | Use case |
|---|---|---|
| LearnedDuration | Fully learned | Most flexible, default |
| GeometricDuration | P(k) ~ p(1-p)^(k-1) | Exponential decay, numerically stable |
| NegativeBinomialDuration | Generalizes geometric | Peaked distributions |
| PoissonDuration | P(k) ~ lambda^k/k! | Characteristic segment length |
| UniformDuration | P(k) = const | No duration preference |
| CallableDuration | User-defined | Full customization |
Numerical stability note:
NegativeBinomialDuration with very small r values (init_log_r < -10) can cause
numerical instability. A runtime warning is emitted when non-finite values are detected.
Use GeometricDuration as a stable alternative.
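To make the geometric row of the table concrete, here is a minimal sketch of how a log-space duration bias could be derived from P(k) ~ p(1-p)^(k-1). The function name `geometric_duration_bias`, the mapping p = sigmoid(logit), and the renormalization over k = 1..K are assumptions for illustration; GeometricDuration may parameterize or normalize differently.

```python
import math

def geometric_duration_bias(logit, max_duration):
    """Return log P(k) for k = 1..K, with P(k) proportional to p * (1 - p)**(k - 1)."""
    p = 1.0 / (1.0 + math.exp(-logit))  # assumed parameterization: p = sigmoid(logit)
    log_unnorm = [math.log(p) + (k - 1) * math.log1p(-p)
                  for k in range(1, max_duration + 1)]
    # Renormalize over the truncated support 1..K, staying in log space.
    m = max(log_unnorm)
    log_norm = m + math.log(sum(math.exp(v - m) for v in log_unnorm))
    return [v - log_norm for v in log_unnorm]

bias = geometric_duration_bias(logit=-1.0, max_duration=100)
print(bias[:3])  # log-probabilities for durations 1, 2, 3
```

Because consecutive entries differ by the constant log(1-p), the values decay linearly in log space and stay finite for any finite logit, which is consistent with the table describing this distribution as numerically stable.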
Low-level API with semiring abstraction for advanced use cases:
from flash_semicrf import SemiMarkov
from flash_semicrf.semirings import LogSemiring, MaxSemiring
class SemiMarkov(semiring):
    def logpartition(
        self,
        log_potentials,  # (batch, N-1, K, C, C) edge potentials
        lengths=None,  # (batch,) sequence lengths
        force_grad=False,
        use_linear_scan=None,  # None=auto, True=O(N) scan, False=O(log N) tree
        use_vectorized=False,  # If True, O(TKC) memory but 2-3x faster
        use_banded=False,  # Prototype: banded matrix optimization
    ) -> Tuple[Tensor, List[Tensor], None]:
        """
        Compute log partition function.
        Algorithm selection (use_linear_scan):
            - None (default): Auto-select based on KC > 200
            - True: O(N) linear scan with O(KC) ring buffer memory
            - False: O(log N) binary tree (WARNING: O((KC)³) memory per matmul)
        Memory modes:
            - use_vectorized=False (default): O(KC) streaming ring buffer
            - use_vectorized=True: O(TKC) but 2-3x faster
        Returns:
            v: (ssize, batch,) log partition values
            edges: list containing input potentials for gradient computation
            charts: None (streaming scan does not store charts)
        """
    def marginals(
        self,
        log_potentials,
        lengths=None,
    ) -> Tensor:
        """
        Compute edge marginals via backward pass.
        Returns: (batch, N-1, K, C, C) marginal probabilities
        """
    @staticmethod
    def hsmm(init, trans_z, trans_l, emission) -> Tensor:
        """
        Convert HSMM parameters to Semi-Markov edge potentials.
        Args:
            init: (C,) initial state distribution
            trans_z: (C, C) state transition matrix
            trans_l: (C, K) duration distribution per state
            emission: (batch, N, K, C) emission scores
        Returns:
            edge: (batch, N, K, C, C) edge potentials
        """
from flash_semicrf.semirings import (
    LogSemiring,  # Standard log-space (sum-product)
    MaxSemiring,  # Viterbi decoding (max-product)
    StdSemiring,  # Standard arithmetic
    KMaxSemiring,  # Top-k paths
    EntropySemiring,  # Entropy computation
    KLDivergenceSemiring,  # KL divergence D_KL(P || Q)
    CrossEntropySemiring,  # Cross-entropy H(P, Q)
)
from flash_semicrf.semirings.checkpoint import (
    CheckpointSemiring,  # Gradient checkpointing
    CheckpointShardSemiring,  # Sharded checkpointing
)
For pre-computed edge tensors, use SemiMarkov.logpartition directly. This unlocks
all 7 semirings (not just log/max as in the streaming API).
Note: For most use cases, the streaming API (SemiMarkovCRFHead or semi_crf_streaming_forward) is recommended. It computes edges on-the-fly with O(KC) memory and supports both training and inference.
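As background on what switching semirings changes, the sketch below runs one and the same forward recursion with two different "plus" operations. It uses a plain linear chain rather than the semi-Markov structure, and `chain_score`, `log_plus` and `max_plus` are hypothetical names, not part of flash_semicrf. With log-sum-exp the recursion yields the log partition; with max it yields the Viterbi score.

```python
import math

def log_plus(values):
    """Log semiring 'plus': log(sum(exp(v)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def max_plus(values):
    """Max semiring 'plus': keep only the best score."""
    return max(values)

def chain_score(edges, plus):
    """edges[t][i][j]: log-potential for moving from state i to state j at step t.

    Semiring 'times' is ordinary addition in log space for both semirings.
    """
    C = len(edges[0])
    alpha = [0.0] * C  # semiring 'one' in log space
    for step in edges:
        alpha = [plus([alpha[i] + step[i][j] for i in range(C)]) for j in range(C)]
    return plus(alpha)

edges = [
    [[0.0, 1.0], [2.0, 0.0]],
    [[1.0, 0.0], [0.0, 3.0]],
]
log_z = chain_score(edges, log_plus)   # aggregates all 8 paths
best = chain_score(edges, max_plus)    # score of the single best path
print(log_z, best)
```

The library example that follows applies the same idea to pre-computed semi-Markov edge tensors by passing a different semiring class to SemiMarkov.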
from flash_semicrf import SemiMarkov
from flash_semicrf.semirings import LogSemiring, MaxSemiring, EntropySemiring
# Pre-computed edge tensor (batch, T-1, K, C, C) - must fit in memory
crf = SemiMarkov(LogSemiring)
log_Z, _, _ = crf.logpartition(edge, lengths=lengths)  # returns (v, edges, charts)
# Viterbi score (max semiring)
crf_max = SemiMarkov(MaxSemiring)
best_score, _, _ = crf_max.logpartition(edge, lengths=lengths)
# Entropy (not available via streaming API)
crf_ent = SemiMarkov(EntropySemiring)
entropy, _, _ = crf_ent.logpartition(edge, lengths=lengths)
from flash_semicrf import Segment, ViterbiResult
@dataclass
class Segment:
    """A single segment from Viterbi decoding."""
    start: int  # Start position (inclusive)
    end: int  # End position (inclusive)
    label: int  # Label class
    score: float  # Segment score contribution
class ViterbiResult(NamedTuple):
    """Result from decode_with_traceback."""
    scores: Tensor  # (batch,) best scores
    segments: List[List[Segment]]  # Per-batch segment lists
These classes provide memory-efficient sparse representations for Semi-Markov structures. They are primarily used for benchmarking and experimentation.
Lightweight banded matrix representation for CPU/PyTorch operations:
from flash_semicrf import BandedMatrix
@dataclass
class BandedMatrix:
    """
    Banded matrix representation for memory-efficient sparse operations.
    Stores only the non-zero diagonals of a banded matrix in a compact format.
    Supports log-semiring and max-semiring matrix multiplication.
    """
    data: Tensor  # (batch, n, lu+ld+1)
    lu: int  # Number of upper diagonals
    ld: int  # Number of lower diagonals
    fill: float = 0.0
    @classmethod
    def from_dense(cls, dense, lu, ld, fill=0.0) -> BandedMatrix:
        """Extract banded view from dense square matrix."""
    def to_dense(self) -> Tensor:
        """Expand back to dense matrix."""
    def transpose(self) -> BandedMatrix:
        """Transpose the banded matrix."""
    def multiply_log(self, other) -> BandedMatrix:
        """Log-semiring matrix multiplication."""
    def multiply_max(self, other) -> BandedMatrix:
        """Max-semiring matrix multiplication."""
Example:
import torch
from flash_semicrf import BandedMatrix
# Create banded matrix from dense
dense = torch.randn(2, 10, 10)
banded = BandedMatrix.from_dense(dense, lu=2, ld=2)
print(banded.data.shape)  # (2, 10, 5)
# Convert back to dense
reconstructed = banded.to_dense()
Block-triangular sparse representation exploiting duration constraint k1 + k2 <= span:
from flash_semicrf import BlockTriangularMatrix, block_triang_matmul
@dataclass
class BlockTriangularMatrix:
    """
    Block-triangular matrix over duration states.
    Stores only blocks (k1, k2) satisfying the duration constraint,
    reducing memory from O(K²C²) to O(K(K+1)/2 * C²).
    """
    values: Tensor  # (batch, num_blocks, C, C)
    block_indices: Tensor  # (num_blocks, 2) - (k1, k2) coordinates
    K: int
    C: int
    @classmethod
    def from_dense(cls, dense, K, C, span, duration_mask=None) -> BlockTriangularMatrix:
        """Compress dense to block-triangular representation."""
    def to_dense(self, semiring=None) -> Tensor:
        """Expand back to dense matrix."""
def block_triang_matmul(left, right, semiring, span) -> BlockTriangularMatrix:
    """Sparse semiring matrix multiplication."""
Example:
import torch
from flash_semicrf import BlockTriangularMatrix
dense = torch.randn(2, 12, 12)  # K=4, C=3
bt = BlockTriangularMatrix.from_dense(dense, K=4, C=3, span=4)
print(bt.values.shape)  # (2, 10, 3, 3) - only 10 blocks satisfy k1+k2 <= 4
Utilities for analyzing and optimizing banded matrix structures:
from flash_semicrf import (
    measure_effective_bandwidth,
    snake_ordering,
    rcm_ordering_from_adjacency,
    apply_permutation,
)
def measure_effective_bandwidth(adj, fill_value=None) -> int:
    """
    Compute maximum distance from diagonal of any non-fill entry.
    Args:
        adj: Adjacency matrix (n, n) or (batch, n, n), or BandedMatrix
        fill_value: Value representing empty/non-edges (auto-detected if None)
    Returns:
        Maximum distance from diagonal
    """
def snake_ordering(K, C) -> Tensor:
    """
    Generate snake ordering permutation for (K, C) state space.
    Interleaves low and high duration states to reduce bandwidth:
    [0, K-1, 1, K-2, 2, K-3, ...] for each label.
    Returns: Permutation tensor of shape (K*C,)
    """
def rcm_ordering_from_adjacency(adj) -> tuple[Tensor, bool]:
    """
    Compute Reverse Cuthill-McKee ordering (requires SciPy).
    Minimizes bandwidth of sparse matrices.
    Returns: (permutation, success) - success=False if SciPy unavailable
    """
def apply_permutation(potentials, perm) -> Tensor:
    """Apply permutation to both dimensions of a matrix."""
Example:
import torch
from flash_semicrf import measure_effective_bandwidth, snake_ordering
adj = torch.eye(5)
print(measure_effective_bandwidth(adj))  # 0
adj[0, 4] = 1  # Add off-diagonal entry
print(measure_effective_bandwidth(adj))  # 4
# Generate snake ordering for K=10, C=3
perm = snake_ordering(10, 3)
| Scenario | Edge tensor size | Streaming API size |
|---|---|---|
| T=1K, K=32, C=24 | 74 MB | 96 KB |
| T=10K, K=100, C=24 | 2.3 GB | 960 KB |
| T=400K, K=3K, C=24 | 2.76 TB | 38 MB |
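The largest scenario in the table can be reproduced with simple arithmetic. The sketch below assumes float32 storage (4 bytes per element), a single batch element, and decimal units (1 MB = 10^6 bytes); the helper names are hypothetical.

```python
def edge_tensor_bytes(T, K, C, bytes_per_element=4):
    """Dense edge tensor of shape (T, K, C, C)."""
    return T * K * C * C * bytes_per_element

def cum_scores_bytes(T, C, bytes_per_element=4):
    """Cumulative score tensor of shape (T+1, C) used by the streaming API."""
    return (T + 1) * C * bytes_per_element

T, K, C = 400_000, 3_000, 24
print(f"edge tensor: {edge_tensor_bytes(T, K, C) / 1e12:.2f} TB")
print(f"cum_scores:  {cum_scores_bytes(T, C) / 1e6:.0f} MB")
```

The edge tensor grows with T·K·C², while the streaming input grows only with T·C, so the gap widens by a factor of K·C as the maximum duration and label count increase.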
Recommended usage:
| Use case | API | Why |
|---|---|---|
| Training (any T) | SemiMarkovCRFHead or semi_crf_streaming_forward | Hand-written Triton forward and backward kernels |
| Inference (recommended) | SemiMarkovCRFHead or semi_crf_streaming_forward | O(KC) memory, faster than pre-computed edges |
| Pre-computed edges / all semirings | SemiMarkov.logpartition | All 7 semirings, edge tensor must fit in memory |
| Very long sequences (T > 10K) | Streaming API | Edge tensor cannot fit in memory |
- Backends and Triton kernel - Detailed kernel behavior and backend selection
- Integration guide - End-to-end training examples with BERT, Mamba, CNNs
- Streaming internals - Low-level algorithm details and numerical stability
- Uncertainty and focused learning - Boundary confidence and active learning
- Parameter guide: T, K, C - Understanding sequence length, duration, and state dimensions
- Semirings guide - Mathematical context for semiring operations