Skip to content

Commit 9bf6e09

Browse files
committed
Update
[ghstack-poisoned]
1 parent 11d8d07 commit 9bf6e09

1 file changed

Lines changed: 15 additions & 0 deletions

File tree

examples/models/eagle3/eager_reference.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,8 +115,23 @@ def speculative_decode(draft, target, prompt_ids, num_gen, chain_len):
115115
accepted = [0] * chain_len
116116
accept_lengths = []
117117

118+
# This reference recomputes the whole sequence each round through the
119+
# stateful gemma target, whose sliding layers assert positions fit one ring
120+
# (2*sliding_window). It is a short-prompt correctness reference, not a
121+
# long-context path, so fail early with a clear message instead of letting
122+
# the RingKVCache assertion fire mid-run.
123+
max_ctx = 2 * target.model.config.sliding_window
124+
118125
while len(emitted) < num_gen:
119126
L = len(seq)
127+
if L + chain_len > max_ctx:
128+
raise RuntimeError(
129+
f"eager reference is limited to 2*sliding_window={max_ctx} "
130+
f"positions (seq={L} + chain={chain_len} exceeds it); it "
131+
f"recomputes through the stateful RingKVCache and does not "
132+
f"support long context. Use a shorter prompt or smaller "
133+
f"--num-gen."
134+
)
120135
_, taps = target.forward(seq)
121136
proposals = draft_chain(draft, seq, taps, chain_len)
122137

0 commit comments

Comments
 (0)