diff --git a/penzai/nn/geometric_attention.py b/penzai/nn/geometric_attention.py new file mode 100644 index 0000000..848d480 --- /dev/null +++ b/penzai/nn/geometric_attention.py @@ -0,0 +1,509 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Geometric Sparse Attention primitives for Penzai. + +This module provides AETHER (Adaptive Event-driven Threshold Hybrid Entangled +Rendering) geometric sparse attention, which replaces standard O(n²) attention +with a geometric block-sparse approximation. + +The key insight is that for a block of keys with centroid μ and radius r, +the maximum attention score for any key k in the block is bounded by: + + max(q · k for k in block) ≤ q · μ + ||q|| × r + +This Cauchy-Schwarz upper bound allows us to safely skip entire blocks when +the upper bound is below a threshold, achieving sub-linear complexity on +structured data. 
def compute_block_geometry(
    keys: jax.Array,
    block_size: int = 64,
) -> tuple[jax.Array, jax.Array]:
  """Summarize fixed-size blocks of keys by centroid and radius.

  The key sequence is chopped into blocks of ``block_size`` tokens (the
  final block is zero-padded to full size). Each block is summarized by
  its mean vector and by the largest L2 distance from that mean to any
  member; both feed the Cauchy-Schwarz score bound downstream. Zero
  padding pulls a ragged final block's centroid toward the origin; the
  resulting bound stays valid for the real keys, merely looser.

  Args:
    keys: Key tensor of shape (..., seq_len, dim).
    block_size: Number of tokens per block.

  Returns:
    A ``(centroids, radii)`` pair with shapes (..., n_blocks, dim) and
    (..., n_blocks) respectively.
  """
  *leading, seq_len, dim = keys.shape
  n_blocks = -(-seq_len // block_size)  # ceiling division
  pad_amount = n_blocks * block_size - seq_len

  if pad_amount:
    pad_spec = [(0, 0)] * len(leading) + [(0, pad_amount), (0, 0)]
    padded = jnp.pad(keys, pad_spec, mode='constant', constant_values=0)
  else:
    padded = keys

  # (..., n_blocks, block_size, dim) view of the (padded) sequence.
  grouped = padded.reshape(*leading, n_blocks, block_size, dim)
  centroids = grouped.mean(axis=-2)
  offsets = grouped - centroids[..., None, :]
  radii = jnp.linalg.norm(offsets, axis=-1).max(axis=-1)
  return centroids, radii


def geometric_upper_bound(
    query: jax.Array,
    centroid: jax.Array,
    radius: jax.Array,
) -> jax.Array:
  """Upper-bound the best attention score inside one key block.

  For every key k in a block with centroid mu and radius r,
  q . k = q . mu + q . (k - mu) <= q . mu + ||q|| * r by Cauchy-Schwarz,
  so the returned value dominates the true maximum dot product over the
  block (with equality only if some key attains it).

  Args:
    query: Query tensor, shape (..., dim).
    centroid: Block centroid, shape (..., dim).
    radius: Block radius, shape (...).

  Returns:
    The bound q . mu + ||q|| * r, shape (...).
  """
  centroid_dot = jnp.sum(query * centroid, axis=-1)
  return centroid_dot + jnp.linalg.norm(query, axis=-1) * radius
+ """ + # q · μ (dot product with centroid) + centroid_score = jnp.sum(query * centroid, axis=-1) + + # ||q|| (query norm) + query_norm = jnp.linalg.norm(query, axis=-1) + + # Upper bound: q · μ + ||q|| × r + return centroid_score + query_norm * radius + + +def _compute_block_mask( + query: jax.Array, + centroids: jax.Array, + radii: jax.Array, + threshold: float | jax.Array, + block_size: int, + seq_len: int, +) -> jax.Array: + """Compute which blocks to include based on geometric upper bound. + + Args: + query: Query tensor, shape (..., seq, dim). + centroids: Block centroids, shape (..., n_blocks, dim). + radii: Block radii, shape (..., n_blocks). + threshold: Sparsity threshold (blocks with upper bound < τ are skipped). + block_size: Number of tokens per block. + seq_len: Original sequence length. + + Returns: + Block mask, shape (..., seq, n_blocks), True where blocks should + be included in attention. + """ + # query: (..., seq, dim) -> (..., seq, 1, dim) for broadcasting + q_expanded = query[..., :, None, :] # (..., seq, 1, dim) + + # centroids: (..., n_blocks, dim) -> (..., 1, n_blocks, dim) + c_expanded = centroids[..., None, :, :] # (..., 1, n_blocks, dim) + + # radii: (..., n_blocks) -> (..., 1, n_blocks) + r_expanded = radii[..., None, :] # (..., 1, n_blocks) + + # Compute upper bounds: (..., seq, n_blocks) + centroid_scores = jnp.sum(q_expanded * c_expanded, axis=-1) + query_norms = jnp.linalg.norm(query, axis=-1, keepdims=True) # (..., seq, 1) + upper_bounds = centroid_scores + query_norms * r_expanded + + # Blocks with upper bound >= threshold are included + block_mask = upper_bounds >= threshold + + return block_mask + + +def _expand_block_mask( + block_mask: jax.Array, + block_size: int, + seq_len: int, +) -> jax.Array: + """Expand block-level mask to token-level mask. + + Args: + block_mask: Block mask, shape (..., seq, n_blocks). + block_size: Number of tokens per block. + seq_len: Original sequence length (for truncation). 
+ + Returns: + Token-level mask, shape (..., seq, seq_len). + """ + *batch_dims, query_len, n_blocks = block_mask.shape + + # Expand each block to block_size tokens + # (..., seq, n_blocks) -> (..., seq, n_blocks, 1) -> (..., seq, n_blocks, block_size) + expanded = jnp.broadcast_to( + block_mask[..., None], + (*batch_dims, query_len, n_blocks, block_size) + ) + + # Flatten to (..., seq, n_blocks * block_size) + token_mask = expanded.reshape(*batch_dims, query_len, n_blocks * block_size) + + # Truncate to original sequence length + token_mask = token_mask[..., :seq_len] + + return token_mask + + +def _phi_rotation_update( + epsilon: jax.Array, + phi: jax.Array, + sparsity: jax.Array, + target_sparsity: float = 0.3, + learning_rate: float = 0.01, +) -> tuple[jax.Array, jax.Array]: + """Update adaptive threshold using phi-rotation dynamics. + + This implements AETHER's self-tuning mechanism, adjusting the sparsity + threshold based on observed sparsity levels. + + Args: + epsilon: Current threshold value. + phi: Current phase parameter. + sparsity: Observed sparsity (fraction of blocks pruned). + target_sparsity: Target sparsity level (default 0.3 = 70% blocks kept). + learning_rate: Step size for threshold adjustment. + + Returns: + Tuple of (new_epsilon, new_phi). 
+ """ + # Error signal: positive if too dense, negative if too sparse + error = target_sparsity - sparsity + + # Rotation-based update (smooth, bounded dynamics) + new_phi = phi + learning_rate * error + new_phi = jnp.clip(new_phi, -jnp.pi, jnp.pi) + + # Threshold adjustment via sine of phase + delta = learning_rate * jnp.sin(new_phi) + new_epsilon = jnp.clip(epsilon + delta, 0.01, 0.9) + + return new_epsilon, new_phi + + +def geometric_sparse_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + block_size: int = 64, + threshold: float | jax.Array = 0.15, + causal_mask: jax.Array | None = None, + scale: float | None = None, +) -> tuple[jax.Array, jax.Array]: + """Compute geometric sparse attention. + + This function computes attention using geometric block scoring to + identify and skip irrelevant key/value blocks. + + Args: + query: Query tensor, shape (..., q_seq, dim). + key: Key tensor, shape (..., kv_seq, dim). + value: Value tensor, shape (..., kv_seq, v_dim). + block_size: Number of keys per block for geometry computation. + threshold: Sparsity threshold. Blocks with geometric upper bound + below this are masked out. Set to 0 for exact attention. + causal_mask: Optional causal or other attention mask, + shape (..., q_seq, kv_seq), True where attention is allowed. + scale: Optional scale factor for attention scores. + Defaults to 1/sqrt(dim). + + Returns: + Tuple of (output, sparsity): + - output: Attention output, shape (..., q_seq, v_dim) + - sparsity: Fraction of blocks that were pruned + """ + # Handle 2D inputs by adding batch dimension + original_ndim = query.ndim + if original_ndim == 2: + query = query[None, :, :] + key = key[None, :, :] + value = value[None, :, :] + + *batch_dims, q_seq, dim = query.shape + kv_seq = key.shape[-2] + v_dim = value.shape[-1] + + if scale is None: + scale = 1.0 / jnp.sqrt(dim) + + # 1. Compute block geometry + centroids, radii = compute_block_geometry(key, block_size) + + # 2. 
Compute geometric block mask + block_mask = _compute_block_mask( + query, centroids, radii, threshold, block_size, kv_seq + ) + + # 3. Expand to token-level mask + geo_mask = _expand_block_mask(block_mask, block_size, kv_seq) + + # 4. Compute sparsity (fraction pruned) + sparsity = 1.0 - jnp.mean(geo_mask.astype(jnp.float32)) + + # 5. Compute attention scores + scores = jnp.einsum('...qd,...kd->...qk', query, key) * scale + + # 6. Apply geometric mask + masked_scores = jnp.where(geo_mask, scores, -1e9) + + # 7. Apply causal mask if provided + if causal_mask is not None: + masked_scores = jnp.where(causal_mask, masked_scores, -1e9) + + # 8. Softmax and weighted sum + attn_weights = jax.nn.softmax(masked_scores, axis=-1) + output = jnp.einsum('...qk,...kv->...qv', attn_weights, value) + + # Remove batch dimension if input was 2D + if original_ndim == 2: + output = output[0] + + return output, sparsity + + +@struct.pytree_dataclass +class GeometricSparseAttention(layer_base.Layer): + """Geometric Sparse Attention layer using AETHER block scoring. + + This layer is a drop-in replacement for standard attention that uses + geometric upper bounds to prune irrelevant key/value blocks, achieving + O(n × sparse) complexity instead of O(n²). + + The mathematical guarantee is based on the Cauchy-Schwarz inequality: + for any query q and block with centroid μ and radius r, + + max(q · k for k in block) ≤ q · μ + ||q|| × r + + Blocks where this upper bound is below the threshold are safely skipped. + + Attributes: + block_size: Number of tokens per sparsity block. + threshold: Fixed sparsity threshold (used when adaptive=False). + adaptive: Whether to self-tune the threshold at runtime. + target_sparsity: Target sparsity level for adaptive mode. + epsilon: Mutable state for adaptive threshold value. + phi: Mutable state for phase parameter in adaptive updates. 
+ + Example:: + + # Create with config + attn = GeometricSparseAttention.from_config( + block_size=64, + adaptive=True + ) + + # Forward pass: input is (query, key, value) tuple + output = attn((query, key, value)) + + # Convert from existing Attention layer + geo_attn = GeometricSparseAttention.from_attention(existing_attn) + """ + + block_size: int = dataclasses.field( + default=64, metadata={"pytree_node": False} + ) + threshold: float = dataclasses.field( + default=0.15, metadata={"pytree_node": False} + ) + adaptive: bool = dataclasses.field( + default=True, metadata={"pytree_node": False} + ) + target_sparsity: float = dataclasses.field( + default=0.3, metadata={"pytree_node": False} + ) + + # Mutable state for adaptive threshold + epsilon: variables.StateVariable[jax.Array] | None = None + phi: variables.StateVariable[jax.Array] | None = None + + @classmethod + def from_config( + cls, + block_size: int = 64, + threshold: float = 0.15, + adaptive: bool = True, + target_sparsity: float = 0.3, + ) -> GeometricSparseAttention: + """Factory method to create a GeometricSparseAttention layer. + + Args: + block_size: Number of tokens per sparsity block. + threshold: Initial sparsity threshold. + adaptive: Whether to enable adaptive threshold tuning. + target_sparsity: Target sparsity for adaptive mode. + + Returns: + Configured GeometricSparseAttention instance. 
+ """ + if adaptive: + epsilon = variables.StateVariable( + value=jnp.array(threshold, dtype=jnp.float32), + label="aether_epsilon", + ) + phi = variables.StateVariable( + value=jnp.array(0.0, dtype=jnp.float32), + label="aether_phi", + ) + else: + epsilon = None + phi = None + + return cls( + block_size=block_size, + threshold=threshold, + adaptive=adaptive, + target_sparsity=target_sparsity, + epsilon=epsilon, + phi=phi, + ) + + @classmethod + def from_attention( + cls, + attn: Any, + block_size: int = 64, + adaptive: bool = True, + ) -> GeometricSparseAttention: + """Convert an existing Attention layer to GeometricSparseAttention. + + This is a convenience method for use with Penzai selectors:: + + model = pz.select(model).at_instances_of( + pz.nn.Attention + ).apply( + GeometricSparseAttention.from_attention + ) + + Args: + attn: The original Attention layer (ignored, just for API compat). + block_size: Number of tokens per sparsity block. + adaptive: Whether to enable adaptive threshold tuning. + + Returns: + A new GeometricSparseAttention instance. + """ + return cls.from_config(block_size=block_size, adaptive=adaptive) + + def __call__( + self, + argument: tuple[ + named_axes.NamedArray, + named_axes.NamedArray, + named_axes.NamedArray, + ], + **side_inputs: Any, + ) -> named_axes.NamedArray: + """Compute geometric sparse attention. + + Args: + argument: Tuple of (query, key, value) NamedArrays. + **side_inputs: Side inputs (e.g., attention mask). Currently + supports "attention_mask" for explicit masking. + + Returns: + Attention output as a NamedArray. 
+ """ + query, key, value = argument + + # Get raw arrays for computation + # We use nmap to handle named axes properly + def _geo_attn_impl(q, k, v): + # Determine current threshold + if self.adaptive and self.epsilon is not None: + current_threshold = self.epsilon.value + else: + current_threshold = self.threshold + + # Compute geometric sparse attention + output, sparsity = geometric_sparse_attention( + q, k, v, + block_size=self.block_size, + threshold=current_threshold, + ) + + # Update adaptive threshold + if self.adaptive and self.epsilon is not None and self.phi is not None: + new_epsilon, new_phi = _phi_rotation_update( + self.epsilon.value, + self.phi.value, + sparsity, + self.target_sparsity, + ) + self.epsilon.value = new_epsilon + self.phi.value = new_phi + + return output + + # Apply using nmap for named axis compatibility + output = named_axes.nmap(_geo_attn_impl)(query, key, value) + + return output + + def treescope_color(self): + """Custom color for Treescope visualization.""" + return "oklch(0.75 0.15 280 / 1.0)" # Purple for geometric attention diff --git a/penzai/pz/nn.py b/penzai/pz/nn.py index fcd79fc..6d70857 100644 --- a/penzai/pz/nn.py +++ b/penzai/pz/nn.py @@ -23,6 +23,12 @@ Attention, KVCachingAttention, ) +from penzai.nn.geometric_attention import ( + GeometricSparseAttention, + compute_block_geometry, + geometric_upper_bound, + geometric_sparse_attention, +) from penzai.nn.basic_ops import ( CastToDType, Elementwise, diff --git a/tests/nn/geometric_attention_test.py b/tests/nn/geometric_attention_test.py new file mode 100644 index 0000000..5e5665e --- /dev/null +++ b/tests/nn/geometric_attention_test.py @@ -0,0 +1,299 @@ +# Copyright 2024 The Penzai Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for geometric sparse attention.""" + +from absl.testing import absltest +import jax +import jax.numpy as jnp +import numpy as np +from penzai import pz +from penzai.nn.geometric_attention import ( + GeometricSparseAttention, + compute_block_geometry, + geometric_upper_bound, + geometric_sparse_attention, +) + + +def _standard_attention(query, key, value, scale=None): + """Reference implementation of standard dense attention.""" + dim = query.shape[-1] + if scale is None: + scale = 1.0 / jnp.sqrt(dim) + scores = jnp.einsum('...qd,...kd->...qk', query, key) * scale + weights = jax.nn.softmax(scores, axis=-1) + return jnp.einsum('...qk,...kv->...qv', weights, value) + + +class ComputeBlockGeometryTest(absltest.TestCase): + + def test_centroid_computation(self): + """Verify centroids are correctly computed as block means.""" + # Create simple keys where blocks have known means + keys = jnp.array([ + [1.0, 0.0], # Block 0 + [3.0, 0.0], # Block 0 + [5.0, 2.0], # Block 1 + [5.0, 4.0], # Block 1 + ]) + centroids, radii = compute_block_geometry(keys, block_size=2) + + expected_centroids = jnp.array([ + [2.0, 0.0], # Mean of [1,0] and [3,0] + [5.0, 3.0], # Mean of [5,2] and [5,4] + ]) + np.testing.assert_allclose(centroids, expected_centroids, rtol=1e-5) + + def test_radius_computation(self): + """Verify radii are max L2 distance from centroid.""" + keys = jnp.array([ + [0.0, 0.0], + [2.0, 0.0], # Distance 1.0 from centroid [1,0] + ]) + centroids, radii = compute_block_geometry(keys, block_size=2) + + # Centroid = [1, 0], both keys are distance 1.0 from 
it + expected_radius = 1.0 + np.testing.assert_allclose(radii[0], expected_radius, rtol=1e-5) + + def test_batch_dimensions(self): + """Verify block geometry works with batch dimensions.""" + batch_size = 3 + seq_len = 8 + dim = 4 + block_size = 2 + + keys = jax.random.normal(jax.random.key(0), (batch_size, seq_len, dim)) + centroids, radii = compute_block_geometry(keys, block_size=block_size) + + expected_n_blocks = seq_len // block_size + self.assertEqual(centroids.shape, (batch_size, expected_n_blocks, dim)) + self.assertEqual(radii.shape, (batch_size, expected_n_blocks)) + + +class GeometricUpperBoundTest(absltest.TestCase): + + def test_upper_bound_never_violated(self): + """Verify Cauchy-Schwarz bound is never violated.""" + key = jax.random.key(42) + for _ in range(20): + key, q_key, b_key = jax.random.split(key, 3) + + # Random query + query = jax.random.normal(q_key, (16,)) + + # Random block of keys + block = jax.random.normal(b_key, (64, 16)) + centroid = jnp.mean(block, axis=0) + diffs = block - centroid + radius = jnp.max(jnp.linalg.norm(diffs, axis=-1)) + + # Compute upper bound + upper_bound = geometric_upper_bound(query, centroid, radius) + + # Compute actual max score + actual_scores = jnp.dot(block, query) + actual_max = jnp.max(actual_scores) + + # Upper bound must never be less than actual max + self.assertGreaterEqual( + float(upper_bound) + 1e-5, # Small tolerance + float(actual_max), + "Cauchy-Schwarz upper bound was violated!" 
+ ) + + +class GeometricSparseAttentionFunctionTest(absltest.TestCase): + + def test_exact_match_with_zero_threshold(self): + """With threshold=0, geometric attention matches dense attention.""" + key = jax.random.key(123) + q_key, k_key, v_key = jax.random.split(key, 3) + + query = jax.random.normal(q_key, (4, 32, 64)) + keys = jax.random.normal(k_key, (4, 128, 64)) + values = jax.random.normal(v_key, (4, 128, 64)) + + # Standard attention + expected = _standard_attention(query, keys, values) + + # Geometric attention with threshold=0 (no pruning) + actual, sparsity = geometric_sparse_attention( + query, keys, values, + block_size=16, + threshold=0.0, + ) + + np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-5) + self.assertEqual(sparsity, 0.0) # No blocks pruned + + def test_sparsity_on_clustered_data(self): + """Verify sparsity is achieved on clustered data.""" + # Create highly clustered keys - attention should be sparse + key = jax.random.key(456) + + # Queries in cluster 0 region + query = jnp.ones((1, 4, 16)) * 10.0 + + # Keys: cluster 0 (high values) and cluster 1 (low values) + cluster_0 = jnp.ones((1, 64, 16)) * 10.0 # Similar to queries + cluster_1 = jnp.ones((1, 64, 16)) * -10.0 # Very different + + keys = jnp.concatenate([cluster_0, cluster_1], axis=1) + values = jnp.ones_like(keys) + + _, sparsity = geometric_sparse_attention( + query, keys, values, + block_size=16, + threshold=0.5, + ) + + # Should prune some blocks (cluster 1 is far from queries) + self.assertGreater(float(sparsity), 0.1) + + +class GeometricSparseAttentionLayerTest(absltest.TestCase): + + def test_from_config_creates_state(self): + """Verify from_config initializes adaptive state correctly.""" + attn = GeometricSparseAttention.from_config( + block_size=32, + threshold=0.2, + adaptive=True, + ) + + self.assertEqual(attn.block_size, 32) + self.assertTrue(attn.adaptive) + self.assertIsNotNone(attn.epsilon) + self.assertIsNotNone(attn.phi) + 
np.testing.assert_allclose(attn.epsilon.value, 0.2, rtol=1e-5) + np.testing.assert_allclose(attn.phi.value, 0.0, rtol=1e-5) + + def test_from_config_no_state_when_not_adaptive(self): + """Verify non-adaptive mode has no state variables.""" + attn = GeometricSparseAttention.from_config( + block_size=64, + adaptive=False, + ) + + self.assertFalse(attn.adaptive) + self.assertIsNone(attn.epsilon) + self.assertIsNone(attn.phi) + + def test_layer_is_pytree(self): + """Verify the layer is a valid JAX pytree.""" + attn = GeometricSparseAttention.from_config(adaptive=True) + + # Flatten and unflatten should work + leaves, treedef = jax.tree_util.tree_flatten(attn) + reconstructed = jax.tree_util.tree_unflatten(treedef, leaves) + + self.assertEqual(reconstructed.block_size, attn.block_size) + + def test_jit_compatibility(self): + """Verify the layer works under jax.jit.""" + attn = GeometricSparseAttention.from_config( + block_size=16, + adaptive=False, # Non-adaptive for pure JIT + ) + + @jax.jit + def forward(q, k, v): + return attn((q, k, v)) + + key = jax.random.key(789) + q = jax.random.normal(key, (2, 16, 32)) + k = jax.random.normal(key, (2, 64, 32)) + v = jax.random.normal(key, (2, 64, 32)) + + # Should compile and run without error + output = forward(q, k, v) + # Handle NamedArrayView output + if hasattr(output, 'data_array'): + out_shape = output.data_array.shape + else: + out_shape = output.shape + self.assertEqual(out_shape, (2, 16, 32)) + + def test_vmap_compatibility(self): + """Verify the layer vectorizes correctly.""" + attn = GeometricSparseAttention.from_config( + block_size=8, + adaptive=False, + ) + + def forward(qkv): + q, k, v = qkv + return attn((q, k, v)) + + key = jax.random.key(101) + batch_size = 4 + q = jax.random.normal(key, (batch_size, 8, 16)) + k = jax.random.normal(key, (batch_size, 32, 16)) + v = jax.random.normal(key, (batch_size, 32, 16)) + + # vmap over the batch dimension + batched_forward = jax.vmap(forward) + outputs = 
batched_forward((q, k, v)) + + # Handle NamedArrayView output + if hasattr(outputs, 'data_array'): + out_shape = outputs.data_array.shape + else: + out_shape = outputs.shape + self.assertEqual(out_shape, (batch_size, 8, 16)) + + def test_treescope_color(self): + """Verify custom treescope color is defined.""" + attn = GeometricSparseAttention.from_config() + color = attn.treescope_color() + self.assertIsInstance(color, str) + self.assertIn("oklch", color) + + +class AdaptiveThresholdTest(absltest.TestCase): + + def test_adaptive_updates_state(self): + """Verify adaptive mode updates epsilon and phi.""" + attn = GeometricSparseAttention.from_config( + block_size=16, + threshold=0.5, + adaptive=True, + ) + + initial_epsilon = float(attn.epsilon.value) + initial_phi = float(attn.phi.value) + + # Run forward pass + key = jax.random.key(999) + q = jax.random.normal(key, (8, 32)) + k = jax.random.normal(key, (64, 32)) + v = jax.random.normal(key, (64, 32)) + + _ = attn((q, k, v)) + + # State should have changed + new_epsilon = float(attn.epsilon.value) + new_phi = float(attn.phi.value) + + # At least one should change (unless sparsity exactly matches target) + self.assertTrue( + new_epsilon != initial_epsilon or new_phi != initial_phi, + "Adaptive state should update after forward pass" + ) + + +if __name__ == "__main__": + absltest.main()