diff --git a/eval_protocol/integrations/__init__.py b/eval_protocol/integrations/__init__.py index f49b5bba..bbdb65cb 100644 --- a/eval_protocol/integrations/__init__.py +++ b/eval_protocol/integrations/__init__.py @@ -3,5 +3,17 @@ from .openeval import adapt from .trl import create_trl_adapter from .openai_rft import build_python_grader_from_evaluation_test +from .fireworks_v1_completions_client import ( + FireworksV1CompletionsClient, + ParsedToolCall, + to_openai_tool_calls, +) -__all__ = ["adapt", "create_trl_adapter", "build_python_grader_from_evaluation_test"] +__all__ = [ + "adapt", + "create_trl_adapter", + "build_python_grader_from_evaluation_test", + "FireworksV1CompletionsClient", + "ParsedToolCall", + "to_openai_tool_calls", +] diff --git a/eval_protocol/integrations/fireworks_v1_completions_client.py b/eval_protocol/integrations/fireworks_v1_completions_client.py new file mode 100644 index 00000000..c971fd03 --- /dev/null +++ b/eval_protocol/integrations/fireworks_v1_completions_client.py @@ -0,0 +1,482 @@ +"""Generic local-tokenized Fireworks /v1/completions client for tool-call rollouts. + +This client handles: + - Local tokenization via HuggingFace ``transformers`` + - Prompt construction via ``apply_chat_template`` + - Calling the ``/v1/completions`` endpoint with token-in / token-out + - Logprob extraction + - Retries for transient errors + +Tool-call parsing is **not** built in. Pass a ``tool_call_parser`` callback +to have the client include structured tool-call data in its response. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any, Callable, Dict, Iterable, List, Optional + +from fireworks import AsyncFireworks + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Generic types — usable by any tool-call domain +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class ParsedToolCall: + tool_call_id: str + name: str + arguments: Dict[str, Any] + + +def to_openai_tool_calls(tool_call: ParsedToolCall) -> List[Dict[str, Any]]: + """Convert a ``ParsedToolCall`` into OpenAI-compatible ``tool_calls`` payload.""" + return [ + { + "id": tool_call.tool_call_id, + "type": "function", + "function": { + "name": tool_call.name, + "arguments": json.dumps(tool_call.arguments, separators=(",", ":")), + }, + } + ] + + +ToolCallParserFn = Callable[ + [str, List[int], Optional[List[Dict[str, Any]]]], + Dict[str, Any], +] +"""Signature: ``(completion_text, completion_token_ids, tools) -> result_dict``. + +The returned dict should contain: + - ``parsed_tool_call``: a :class:`ParsedToolCall` + - ``assistant_content``: ``str`` (text content outside the tool call) + - ``parser``: ``str`` (name of the parsing strategy that succeeded) +""" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _normalize_token_id_sequence(values: Any) -> List[int]: + if values is None: + return [] + if isinstance(values, Mapping): + values = values.get("input_ids", values.get("ids", [])) + if values is None: + return [] + if hasattr(values, "tolist") and not isinstance(values, list): + values = values.tolist() + if isinstance(values, tuple): + values = list(values) + if isinstance(values, list) and values and isinstance(values[0], list): + values = values[0] + return [int(x) for x in list(values)] + + +def _coerce_message_content_to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + text_parts: List[str] = [] + for part in content: + if isinstance(part, dict): + text_parts.append(str(part.get("text", ""))) + else: + text_parts.append(str(part)) + return "".join(text_parts) + return str(content) + + +def _sanitize_messages_for_template(messages: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]: + sanitized: List[Dict[str, Any]] = [] + for msg in messages: + role = str(msg.get("role", "user")) + sanitized_msg: Dict[str, Any] = { + "role": role, + "content": _coerce_message_content_to_text(msg.get("content")), + } + if msg.get("tool_calls") is not None: + sanitized_msg["tool_calls"] = msg.get("tool_calls") + if msg.get("tool_call_id") is not None: + sanitized_msg["tool_call_id"] = msg.get("tool_call_id") + if msg.get("name") is not None: + sanitized_msg["name"] = msg.get("name") + sanitized.append(sanitized_msg) + return sanitized + + +def _build_fallback_prompt_text(messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]]) -> str: + chunks: List[str] = [] + if tools: + chunks.append("TOOLS:") + for tool in tools: + function = tool.get("function", {}) + chunks.append( + json.dumps( + { + "name": function.get("name"), + "description": function.get("description"), + "parameters": function.get("parameters"), + }, + ensure_ascii=False, + separators=(",", ":"), + ) + ) + chunks.append("") + for msg in messages: + role = str(msg.get("role", "user")).upper() + content = _coerce_message_content_to_text(msg.get("content")) + chunks.append(f"{role}: {content}") + if msg.get("tool_calls"): + chunks.append(f"{role}_TOOL_CALLS: {json.dumps(msg['tool_calls'], ensure_ascii=False)}") + chunks.append("ASSISTANT:") + return "\n".join(chunks) + + +def strip_chat_special_tokens(text: str) -> str: + """Remove common chat-template special tokens from text.""" + cleaned = str(text or "") + for marker in ("<|im_end|>", "<|im_start|>"): + cleaned = cleaned.replace(marker, "") + return cleaned.strip() + + +# --------------------------------------------------------------------------- +# Client +# --------------------------------------------------------------------------- + +class FireworksV1CompletionsClient: + """Adapter that performs local tokenization before ``/v1/completions`` calls. + + Parameters + ---------- + tool_call_parser: + Optional callback that extracts structured tool-call information from + the raw completion text. When *None*, the response ``choices[0].message`` + will contain the raw text with no ``tool_calls``. + default_tools: + Fallback tools list used when none is passed to individual calls. + """ + + def __init__( + self, + *, + model_id: str, + tokenizer_name_or_path: Optional[str] = None, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + temperature: float = 1.0, + max_tokens: int = 256, + request_params: Optional[Dict[str, Any]] = None, + logprobs: bool = True, + enable_thinking: Optional[bool] = None, + tool_call_parser: Optional[ToolCallParserFn] = None, + default_tools: Optional[List[Dict[str, Any]]] = None, + ): + self.model_id = model_id + self.tokenizer_name_or_path = tokenizer_name_or_path or model_id + self.temperature = temperature + self.max_tokens = max_tokens + self.request_params = dict(request_params or {}) + self.logprobs = logprobs + self.enable_thinking = enable_thinking + self.tool_call_parser = tool_call_parser + self.default_tools = default_tools or [] + self._tokenizer = None + self._assistant_prefix_token_ids: Optional[List[int]] = None + self._client = AsyncFireworks(api_key=api_key, base_url=base_url) + + async def close(self) -> None: + await self._client.close() + + # -- Tokenizer ---------------------------------------------------------- + + def _get_tokenizer(self): + if self._tokenizer is None: + try: + from transformers import AutoTokenizer + except ImportError as exc: + raise ImportError( + "transformers is required for local tokenizer mode. " + "Install a build with transformers support (for example, eval-protocol[dev])." + ) from exc + self._tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name_or_path, trust_remote_code=True) + return self._tokenizer + + def _get_assistant_prefix_token_ids(self) -> List[int]: + if self._assistant_prefix_token_ids is None: + tokenizer = self._get_tokenizer() + self._assistant_prefix_token_ids = _normalize_token_id_sequence( + tokenizer.encode("<|im_start|>assistant\n", add_special_tokens=False) + ) + return list(self._assistant_prefix_token_ids) + + def _thinking_kwargs(self) -> Dict[str, Any]: + if self.enable_thinking is not None: + return {"enable_thinking": self.enable_thinking} + return {} + + # -- Prompt building ---------------------------------------------------- + + def _build_prompt_token_ids(self, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]]) -> List[int]: + tokenizer = self._get_tokenizer() + sanitized_messages = _sanitize_messages_for_template(messages=messages) + thinking_kw = self._thinking_kwargs() + token_ids: Any + try: + token_ids = tokenizer.apply_chat_template( + sanitized_messages, + tools=tools, + tokenize=True, + add_generation_prompt=True, + **thinking_kw, + ) + except Exception as exc: + if tools: + logger.debug("Tokenizer chat template with tools failed, retrying without tools: %s", exc) + try: + token_ids = tokenizer.apply_chat_template( + sanitized_messages, + tokenize=True, + add_generation_prompt=True, + **thinking_kw, + ) + except Exception as exc_no_tools: + logger.debug("Tokenizer chat template failed, using fallback text prompt: %s", exc_no_tools) + fallback_prompt = _build_fallback_prompt_text(messages=sanitized_messages, tools=tools) + token_ids = tokenizer.encode(fallback_prompt, add_special_tokens=False) + else: + logger.debug("Tokenizer chat template failed, using fallback text prompt: %s", exc) + fallback_prompt = _build_fallback_prompt_text(messages=sanitized_messages, tools=tools) + token_ids = tokenizer.encode(fallback_prompt, add_special_tokens=False) + + return _normalize_token_id_sequence(token_ids) + + def build_prompt_token_ids(self, *, messages: List[Dict[str, Any]], tools: Optional[List[Dict[str, Any]]]) -> List[int]: + """Public wrapper used by rollout processors to initialize token history.""" + return self._build_prompt_token_ids(messages=messages, tools=tools) + + def build_tool_response_suffix_token_ids(self, *, tool_message: Dict[str, Any]) -> List[int]: + """Build token ids for appending a tool response turn and next assistant prefix.""" + tokenizer = self._get_tokenizer() + sanitized_messages = _sanitize_messages_for_template(messages=[tool_message]) + thinking_kw = self._thinking_kwargs() + token_ids: Any + try: + token_ids = tokenizer.apply_chat_template( + sanitized_messages, + tokenize=True, + add_generation_prompt=True, + **thinking_kw, + ) + except Exception as exc: + logger.debug("Tokenizer tool suffix template failed, using fallback text prompt: %s", exc) + fallback_prompt = _build_fallback_prompt_text(messages=sanitized_messages, tools=None) + token_ids = tokenizer.encode(fallback_prompt, add_special_tokens=False) + return _normalize_token_id_sequence(token_ids) + + def build_assistant_turn_token_ids(self, *, assistant_message: Dict[str, Any]) -> List[int]: + """Build canonical assistant tool-call turn tokens (without generation prompt).""" + tokenizer = self._get_tokenizer() + sanitized_messages = _sanitize_messages_for_template(messages=[assistant_message]) + token_ids: Any + thinking_kw = self._thinking_kwargs() + try: + token_ids = tokenizer.apply_chat_template( + sanitized_messages, + tokenize=True, + add_generation_prompt=False, + **thinking_kw, + ) + except Exception as exc: + logger.debug("Tokenizer assistant turn template failed, using fallback text prompt: %s", exc) + fallback_prompt = _build_fallback_prompt_text(messages=sanitized_messages, tools=None) + token_ids = tokenizer.encode(fallback_prompt, add_special_tokens=False) + normalized = _normalize_token_id_sequence(token_ids) + assistant_prefix = self._get_assistant_prefix_token_ids() + if assistant_prefix and normalized[: len(assistant_prefix)] == assistant_prefix: + return normalized[len(assistant_prefix) :] + return normalized + + def encode_special_suffix(self) -> List[int]: + """Return token IDs for ``<|im_end|>\\n`` — the end-of-turn marker.""" + tokenizer = self._get_tokenizer() + return _normalize_token_id_sequence( + tokenizer.encode("<|im_end|>\n", add_special_tokens=False) + ) + + def decode_token_ids(self, *, token_ids: List[int]) -> str: + if not token_ids: + return "" + tokenizer = self._get_tokenizer() + try: + return str( + tokenizer.decode( + token_ids, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + ) + except TypeError: + return str(tokenizer.decode(token_ids)) + + # -- Completion --------------------------------------------------------- + + async def create_completion_from_prompt_ids( + self, + *, + prompt_token_ids: List[int], + tools: Optional[List[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + """Call ``/v1/completions`` and return a structured result dict. + + If ``tool_call_parser`` was provided at construction time, the result + will include ``choices[0].message.tool_calls``. Otherwise the message + will contain only the raw ``content``. + """ + active_tools = tools if tools is not None else (self.default_tools or None) + normalized_prompt_token_ids = [int(x) for x in list(prompt_token_ids)] + request_payload = { + **self.request_params, + "model": self.model_id, + "prompt": normalized_prompt_token_ids, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "logprobs": True if self.logprobs else None, + } + if not self.logprobs: + request_payload.pop("logprobs", None) + + max_retries = 40 + base_delay = 10.0 + for attempt in range(max_retries + 1): + try: + response = await self._client.completions.create(**request_payload) + break + except Exception as exc: + status = getattr(exc, "status_code", None) or getattr(exc, "status", None) + err_str = str(exc) + is_transient = ( + status in (425, 429, 502, 503, 504) + or "model_not_ready" in err_str + or "hot loading" in err_str + or "Model not found" in err_str + or "DEPLOYMENT_SCALING_UP" in err_str + ) + if not is_transient or attempt >= max_retries: + raise + delay = min(base_delay * (2 ** attempt), 60.0) + logger.info( + "Retryable error (attempt %d/%d, status=%s), retrying in %.1fs: %s", + attempt + 1, max_retries, status, delay, err_str[:200], + ) + await asyncio.sleep(delay) + + response_dict = response.model_dump() if hasattr(response, "model_dump") else dict(response) + choices = response_dict.get("choices") or [] + if not choices: + raise ValueError("Fireworks /v1/completions response did not include choices") + + choice = choices[0] + finish_reason = str(choice.get("finish_reason") or "unknown") + + raw_output = choice.get("raw_output") if isinstance(choice.get("raw_output"), dict) else {} + completion_token_ids = _normalize_token_id_sequence( + choice.get("token_ids") or raw_output.get("completion_token_ids") or [] + ) + choice_prompt_token_ids = _normalize_token_id_sequence( + choice.get("prompt_token_ids") or raw_output.get("prompt_token_ids") or normalized_prompt_token_ids + ) + + completion_text = self.decode_token_ids(token_ids=completion_token_ids) + if not completion_text: + completion_text = str(choice.get("text") or "") + if not completion_token_ids and completion_text: + tokenizer = self._get_tokenizer() + completion_token_ids = list(tokenizer.encode(completion_text, add_special_tokens=False)) + + # -- Extract logprobs ----------------------------------------------- + completion_logprobs: List[float] = [] + choice_logprobs = choice.get("logprobs") + if isinstance(choice_logprobs, dict): + token_logprobs = choice_logprobs.get("token_logprobs") or [] + if token_logprobs: + completion_logprobs = [float(lp) if lp is not None else 0.0 for lp in token_logprobs] + else: + content_logprobs = choice_logprobs.get("content") or [] + completion_logprobs = [ + float(entry.get("logprob", 0.0)) if isinstance(entry, dict) else 0.0 + for entry in content_logprobs + ] + elif isinstance(choice_logprobs, list): + completion_logprobs = [float(lp) if lp is not None else 0.0 for lp in choice_logprobs] + + # -- Build message via parser or raw -------------------------------- + if self.tool_call_parser is not None: + parsed_output = self.tool_call_parser(completion_text, completion_token_ids, active_tools) + parsed_tool_call: Optional[ParsedToolCall] = parsed_output.get("parsed_tool_call") + assistant_content = str(parsed_output.get("assistant_content", "") or "") + parser_name = str(parsed_output.get("parser", "external")) + message_payload: Dict[str, Any] = { + "role": "assistant", + "content": assistant_content, + } + if parsed_tool_call is not None: + message_payload["tool_calls"] = to_openai_tool_calls(parsed_tool_call) + else: + assistant_content = strip_chat_special_tokens(completion_text) + parser_name = "none" + message_payload = {"role": "assistant", "content": assistant_content} + + usage_obj = response_dict.get("usage") or {} + usage_payload = { + "prompt_tokens": int(usage_obj.get("prompt_tokens", len(choice_prompt_token_ids))), + "completion_tokens": int(usage_obj.get("completion_tokens", len(completion_token_ids))), + "total_tokens": int( + usage_obj.get("total_tokens", len(choice_prompt_token_ids) + len(completion_token_ids)) + ), + } + + result: Dict[str, Any] = { + "choices": [ + { + "message": message_payload, + "finish_reason": finish_reason, + "raw_output": {**dict(raw_output or {}), "tool_call_parser": parser_name}, + } + ], + "usage": usage_payload, + "prompt_ids": list(choice_prompt_token_ids), + "completion_ids": list(completion_token_ids), + "finish_reason": finish_reason, + "raw_output": {**dict(raw_output or {}), "tool_call_parser": parser_name}, + } + if completion_logprobs: + result["completion_logprobs"] = completion_logprobs + return result + + async def create_completion( + self, + *, + messages: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + ) -> Dict[str, Any]: + """High-level helper: tokenize *messages* then call ``create_completion_from_prompt_ids``.""" + active_tools = tools if tools is not None else (self.default_tools or None) + prompt_token_ids = self.build_prompt_token_ids(messages=messages, tools=active_tools) + return await self.create_completion_from_prompt_ids( + prompt_token_ids=prompt_token_ids, + tools=active_tools, + ) diff --git a/tests/test_fireworks_v1_completions_client.py b/tests/test_fireworks_v1_completions_client.py new file mode 100644 index 00000000..d7b863f7 --- /dev/null +++ b/tests/test_fireworks_v1_completions_client.py @@ -0,0 +1,146 @@ +import asyncio +from typing import Any, Dict, List, Optional + +import pytest + +from eval_protocol.integrations.fireworks_v1_completions_client import ( + FireworksV1CompletionsClient, + ParsedToolCall, + to_openai_tool_calls, + strip_chat_special_tokens, +) + + +def test_parsed_tool_call_to_openai_format(): + tc = ParsedToolCall(tool_call_id="call_1", name="lake_move", arguments={"action": "RIGHT"}) + payload = to_openai_tool_calls(tc) + assert len(payload) == 1 + assert payload[0]["function"]["name"] == "lake_move" + assert '"action":"RIGHT"' in payload[0]["function"]["arguments"] + + +def test_strip_chat_special_tokens(): + assert strip_chat_special_tokens("<|im_start|>assistant\nhello<|im_end|>") == "assistant\nhello" + assert strip_chat_special_tokens("") == "" + assert strip_chat_special_tokens(None) == "" + + +def test_tool_call_parser_is_invoked(): + """When a tool_call_parser is provided, create_completion_from_prompt_ids uses it.""" + + def fake_parser( + text: str, ids: List[int], tools: Optional[List[Dict[str, Any]]] + ) -> Dict[str, Any]: + return { + "parsed_tool_call": ParsedToolCall( + tool_call_id="call_0", name="test_tool", arguments={"x": 1} + ), + "assistant_content": "thought", + "parser": "fake", + } + + client = FireworksV1CompletionsClient( + model_id="test-model", + tokenizer_name_or_path="Qwen/Qwen3-0.6B", + tool_call_parser=fake_parser, + ) + + result = fake_parser("some text", [1, 2], None) + assert result["parsed_tool_call"].name == "test_tool" + assert result["assistant_content"] == "thought" + asyncio.run(client.close()) + + +def test_no_parser_returns_raw_content(): + """When no tool_call_parser is provided, message contains raw content.""" + client = FireworksV1CompletionsClient( + model_id="test-model", + tokenizer_name_or_path="Qwen/Qwen3-0.6B", + ) + assert client.tool_call_parser is None + asyncio.run(client.close()) + + +def test_default_tools_not_used_when_tools_is_empty_list(): + """Passing tools=[] should not fall back to default_tools.""" + client = FireworksV1CompletionsClient( + model_id="test-model", + tokenizer_name_or_path="Qwen/Qwen3-0.6B", + default_tools=[{"type": "function", "function": {"name": "my_tool"}}], + ) + assert client.default_tools == [{"type": "function", "function": {"name": "my_tool"}}] + asyncio.run(client.close()) + + +def test_build_prompt_token_ids_retries_without_tools(monkeypatch): + client = FireworksV1CompletionsClient( + model_id="accounts/fireworks/models/qwen3-0p6b", + tokenizer_name_or_path="Qwen/Qwen3-0.6B", + ) + + class FakeTokenizer: + def __init__(self): + self.calls = [] + + def apply_chat_template(self, messages, **kwargs): + self.calls.append(kwargs) + if "tools" in kwargs: + raise RuntimeError("tools unsupported") + return [11, 22, 33] + + def encode(self, text, add_special_tokens=False): + return [99] + + fake_tokenizer = FakeTokenizer() + monkeypatch.setattr(client, "_get_tokenizer", lambda: fake_tokenizer) + token_ids = client._build_prompt_token_ids( + messages=[{"role": "user", "content": "hello"}], + tools=[{"type": "function", "function": {"name": "lake_move"}}], + ) + assert token_ids == [11, 22, 33] + assert len(fake_tokenizer.calls) == 2 + asyncio.run(client.close()) + + +def test_build_prompt_token_ids_handles_dict_input_ids(monkeypatch): + client = FireworksV1CompletionsClient( + model_id="accounts/fireworks/models/qwen3-0p6b", + tokenizer_name_or_path="Qwen/Qwen3-0.6B", + ) + + class FakeTokenizer: + def apply_chat_template(self, messages, **kwargs): + return {"input_ids": [[101, 102, 103]]} + + def encode(self, text, add_special_tokens=False): + return [99] + + monkeypatch.setattr(client, "_get_tokenizer", lambda: FakeTokenizer()) + token_ids = client._build_prompt_token_ids( + messages=[{"role": "user", "content": "hello"}], + tools=None, + ) + assert token_ids == [101, 102, 103] + asyncio.run(client.close()) + + +def test_thinking_kwargs_respects_enable_thinking(): + client_none = FireworksV1CompletionsClient( + model_id="test", tokenizer_name_or_path="Qwen/Qwen3-0.6B", + ) + assert client_none._thinking_kwargs() == {} + + client_false = FireworksV1CompletionsClient( + model_id="test", tokenizer_name_or_path="Qwen/Qwen3-0.6B", + enable_thinking=False, + ) + assert client_false._thinking_kwargs() == {"enable_thinking": False} + + client_true = FireworksV1CompletionsClient( + model_id="test", tokenizer_name_or_path="Qwen/Qwen3-0.6B", + enable_thinking=True, + ) + assert client_true._thinking_kwargs() == {"enable_thinking": True} + asyncio.run(client_none.close()) + asyncio.run(client_false.close()) + asyncio.run(client_true.close()) diff --git a/vite-app/src/components/EvaluationRow.tsx b/vite-app/src/components/EvaluationRow.tsx index 87e89f86..d72469a3 100644 --- a/vite-app/src/components/EvaluationRow.tsx +++ b/vite-app/src/components/EvaluationRow.tsx @@ -6,6 +6,7 @@ import type { import { ChatInterface } from "./ChatInterface"; import { MetadataSection } from "./MetadataSection"; import { LogsSection } from "./LogsSection"; +import { TokenDebugView } from "./TokenDebugView"; import StatusIndicator from "./StatusIndicator"; import { state } from "../App"; import { TableCell, TableRowInteractive } from "./TableContainer"; @@ -342,6 +343,13 @@ const ChatInterfaceSection = observer( ) ); +const TokenDebugSection = observer( + ({ extra }: { extra: Record | undefined }) => { + if (!extra?.token_turn_traces?.length && !extra?.full_episode) return null; + return ; + } +); + const ExpandedContent = observer( ({ row, @@ -371,6 +379,9 @@ const ExpandedContent = observer( + {/* Token Debug Column */} + + {/* Middle Column - Logs */} diff --git a/vite-app/src/components/TokenDebugView.tsx b/vite-app/src/components/TokenDebugView.tsx new file mode 100644 index 00000000..d7dd942d --- /dev/null +++ b/vite-app/src/components/TokenDebugView.tsx @@ -0,0 +1,604 @@ +import { useState } from "react"; + +interface TokenTurnTrace { + step_index: number; + prompt_ids: number[]; + completion_ids: number[]; + completion_logprobs?: number[]; + detokenized_tokens?: string[]; + prompt_len?: number; + completion_len?: number; + step_reward?: number; + tool_call_parser?: string; +} + +interface FullEpisode { + token_ids: number[]; + mask: number[]; // 0=prompt, >0 = turn index (completion) + logprobs: (number | null)[]; + detokenized_tokens: string[]; + num_turns: number; +} + +interface TokenDebugViewProps { + extra: Record; +} + +type ColorMode = "mask" | "logprobs"; + +const TURN_COLORS = [ + "rgba(253, 224, 71, 0.5)", + "rgba(134, 239, 172, 0.5)", + "rgba(147, 197, 253, 0.5)", + "rgba(249, 168, 212, 0.5)", + "rgba(196, 181, 253, 0.5)", + "rgba(252, 165, 165, 0.5)", + "rgba(253, 186, 116, 0.5)", + "rgba(94, 234, 212, 0.5)", +]; + +function turnColor(turnIdx: number): string { + if (turnIdx <= 0) return "rgba(209, 213, 219, 0.3)"; + return TURN_COLORS[(turnIdx - 1) % TURN_COLORS.length]; +} + +function logprobToColor(lp: number): string { + // Smooth gradient: 0 → bright green, -10 → deep red + const clamped = Math.max(-10, Math.min(0, lp)); + const t = (clamped + 10) / 10; // 1.0 = logprob 0, 0.0 = logprob -10 + // Interpolate hue: 0 (red) → 120 (green) + const hue = t * 120; + const sat = 75 + (1 - t) * 15; + const light = 45 + t * 15; + const alpha = 0.35 + (1 - t) * 0.35; + return `hsla(${hue}, ${sat}%, ${light}%, ${alpha})`; +} + +function displayToken(token: string): string { + return token + .replace(/ /g, "\u00B7") + .replace(/\n/g, "\u21B5\n") + .replace(/\t/g, "\u2192 "); +} + +function EpisodeToken({ + token, + tokenId, + turnIdx, + logprob, + colorMode, + showIds, +}: { + token: string; + tokenId: number; + turnIdx: number; + logprob: number | null; + colorMode: ColorMode; + showIds: boolean; +}) { + const [hover, setHover] = useState(false); + const isCompletion = turnIdx > 0; + + let bgColor: string; + if (colorMode === "logprobs" && isCompletion && logprob !== null) { + bgColor = logprobToColor(logprob); + } else { + bgColor = turnColor(turnIdx); + } + + const display = displayToken(token) || "\u2205"; + + return ( + setHover(true)} + onMouseLeave={() => setHover(false)} + > + {display} + {hover && ( + + {isCompletion ? `completion (turn ${turnIdx})` : "prompt (masked)"} + {showIds && ( + <> +
+ id: {tokenId} + + )} + {logprob !== null && ( + <> +
+ logprob: {logprob.toFixed(4)} + + )} +
+ )} +
+ ); +} + +function FullEpisodeView({ + episode, + colorMode, + showIds, +}: { + episode: FullEpisode; + colorMode: ColorMode; + showIds: boolean; +}) { + const { token_ids, mask, logprobs, detokenized_tokens } = episode; + + const promptCount = mask.filter((m) => m === 0).length; + const completionCount = mask.filter((m) => m > 0).length; + + return ( +
+
+ Full Episode ({episode.num_turns} turns) + + {token_ids.length} tokens total + + + prompt (masked): {promptCount} + + + completion (unmasked): {completionCount} + +
+ + {showIds && ( +
+
+ Token IDs (gray=masked/prompt, colored=unmasked/completion by turn) — hover for text & logprob +
+
+ {token_ids.map((id, i) => ( + + ))} +
+
+ )} + +
+
+ {colorMode === "logprobs" + ? "Tokens: gray=masked prompt, completions colored by logprob" + : "Tokens: gray=masked prompt, colored=unmasked completion (by turn)"} +
+
+ {detokenized_tokens.map((tok, i) => ( + + ))} +
+
+
+ ); +} + +function TurnSection({ + trace, + colorMode, + showIds, +}: { + trace: TokenTurnTrace; + colorMode: ColorMode; + showIds: boolean; +}) { + const promptLen = trace.prompt_len ?? trace.prompt_ids.length; + const completionLen = trace.completion_len ?? trace.completion_ids.length; + const allIds = [...trace.prompt_ids, ...trace.completion_ids]; + const detokens = trace.detokenized_tokens ?? []; + const logprobs = trace.completion_logprobs ?? []; + + return ( +
+
+ Turn {trace.step_index} + + prompt: {promptLen} | completion: {completionLen} + + {trace.step_reward !== undefined && ( + 0 ? "text-green-600" : trace.step_reward < 0 ? "text-red-600" : "text-gray-600"}`} + > + reward: {trace.step_reward} + + )} + {trace.tool_call_parser && ( + parser: {trace.tool_call_parser} + )} +
+ + {showIds && ( +
+
+ Token IDs +
+
+ {allIds.map((id, i) => ( + + {id} + + ))} +
+
+ )} + +
+
+ {(detokens.length > 0 ? detokens : allIds.map((id) => `[${id}]`)).map( + (tok, i) => { + const isPrompt = i < promptLen; + const lpIdx = i - promptLen; + const lp = + !isPrompt && lpIdx >= 0 && lpIdx < logprobs.length + ? logprobs[lpIdx] + : null; + return ( + + ); + } + )} +
+
+
+ ); +} + +function LogprobLegend() { + const nStops = 20; + const gradientStops = Array.from({ length: nStops }, (_, i) => { + const lp = -10 + (10 * i) / (nStops - 1); + return logprobToColor(lp); + }); + const gradient = `linear-gradient(to right, ${gradientStops.join(", ")})`; + return ( +
+ Logprob: + -10 +
+ 0 +
+ ); +} + +function TurnLegend({ numTurns }: { numTurns: number }) { + return ( +
+ + masked + + {Array.from({ length: Math.min(numTurns, 8) }, (_, i) => ( + + t{i + 1} + + ))} +
+ ); +} + +function TokenIdChip({ + id, + token, + turnIdx, + logprob, + colorMode = "mask", +}: { + id: number; + token: string; + turnIdx: number; + logprob: number | null; + colorMode?: ColorMode; +}) { + const [hover, setHover] = useState(false); + const display = token + .replace(/\n/g, "\\n") + .replace(/\t/g, "\\t"); + + const isCompletion = turnIdx > 0; + const bg = + colorMode === "logprobs" && isCompletion && logprob != null + ? logprobToColor(logprob) + : turnColor(turnIdx); + + return ( + setHover(true)} + onMouseLeave={() => setHover(false)} + > + {id} + {hover && ( + + "{display}" + {logprob != null && ( + <> +
+ logprob: {logprob.toFixed(4)} + + )} +
+ )} +
+ ); +} + +function TextMaskView({ + episode, + showIds, + colorMode, +}: { + episode: FullEpisode; + showIds: boolean; + colorMode: ColorMode; +}) { + const { token_ids, mask, logprobs, detokenized_tokens } = episode; + + const promptTokens = mask.filter((m) => m === 0).length; + const completionTokens = mask.filter((m) => m > 0).length; + + function bgForToken(i: number): string { + const turnIdx = mask[i] ?? 0; + if (colorMode === "logprobs" && turnIdx > 0 && logprobs[i] != null) { + return logprobToColor(logprobs[i]!); + } + return turnColor(turnIdx); + } + + // For mask mode, group consecutive tokens with same mask for cleaner spans + type Segment = { turnIdx: number; text: string; bg: string }; + const segments: Segment[] = []; + for (let i = 0; i < detokenized_tokens.length; i++) { + const turnIdx = mask[i] ?? 0; + const tok = detokenized_tokens[i] ?? ""; + const bg = bgForToken(i); + if ( + colorMode === "mask" && + segments.length > 0 && + segments[segments.length - 1].turnIdx === turnIdx + ) { + segments[segments.length - 1].text += tok; + } else { + segments.push({ turnIdx, text: tok, bg }); + } + } + + return ( +
+
+ Text + Mask ({episode.num_turns} turns) + + masked: {promptTokens} + + + unmasked: {completionTokens} + +
+ + {showIds && ( +
+
+ Token IDs (gray=masked, colored=unmasked by turn) — hover for text & logprob +
+
+ {token_ids.map((id, i) => ( + + ))} +
+
+ )} + +
+ {segments.map((seg, i) => ( + 0 + ? `2px solid ${turnColor(seg.turnIdx).replace("0.5)", "0.9)")}` + : "none", + }} + > + {seg.text} + + ))} +
+
+ ); +} + +type ViewLevel = "text" | "episode" | "turns"; + +export const TokenDebugView = ({ extra }: TokenDebugViewProps) => { + const [colorMode, setColorMode] = useState("mask"); + const [showIds, setShowIds] = useState(false); + const [viewLevel, setViewLevel] = useState("text"); + + const fullEpisode: FullEpisode | null = extra?.full_episode ?? null; + const tokenTurnTraces: TokenTurnTrace[] = extra?.token_turn_traces ?? []; + + if (!fullEpisode && tokenTurnTraces.length === 0) { + return ( +
+ No token data available +
+ ); + } + + const episodeReward = extra?.episode_reward; + const stepRewards: number[] = extra?.step_rewards ?? []; + const numTurns = fullEpisode?.num_turns ?? tokenTurnTraces.length; + + return ( +
+ {/* Header */} +
+
+

Token Debug

+ {episodeReward !== undefined && ( + 0 + ? "bg-green-100 text-green-700" + : episodeReward < 0 + ? "bg-red-100 text-red-700" + : "bg-gray-100 text-gray-700" + }`} + > + reward: {episodeReward} + + )} + {stepRewards.length > 0 && ( + + [{stepRewards.map((r) => r.toFixed(1)).join(", ")}] + + )} +
+
+ {colorMode === "logprobs" ? : } + + {(["mask", "logprobs"] as ColorMode[]).map((m) => ( + + ))} + | + {(["text", "episode", "turns"] as ViewLevel[]).map((v) => ( + + ))} +
+
+ + {/* Content */} +
+ {viewLevel === "text" && fullEpisode ? ( + + ) : viewLevel === "episode" && fullEpisode ? ( + + ) : tokenTurnTraces.length > 0 ? ( + tokenTurnTraces.map((trace, i) => ( + + )) + ) : ( +
+ No token data available for this view +
+ )} +
+
+ ); +};