diff --git a/usage-cookbook/Nemotron-3-Nano/tinker_llm_router_tutorial.ipynb b/usage-cookbook/Nemotron-3-Nano/tinker_llm_router_tutorial.ipynb new file mode 100644 index 000000000..57ade5904 --- /dev/null +++ b/usage-cookbook/Nemotron-3-Nano/tinker_llm_router_tutorial.ipynb @@ -0,0 +1,1045 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cell-01-title", + "metadata": {}, + "source": [ + "# Fine-Tuning Nemotron Nano 30B as an LLM Router with Tinker SDK\n", + "\n", + "**A Joint Case Study: NVIDIA × Glean**\n", + "\n", + "> *\"Tinker is awesome and very easy to use — most things just worked out of the box, reward shaping is the main iteration we needed to do. Basically no hyperparameter tuning needed at all.\"*\n", + "> — Glean Engineering Team\n", + "\n", + "---\n", + "\n", + "## Overview\n", + "\n", + "This tutorial demonstrates how to fine-tune **Nemotron Nano 30B** as a production LLM router using NVIDIA's **Tinker SDK**. Inspired by Glean's real-world agentic deployment, we show how a compact model can be trained to orchestrate tool calls in an agent loop — routing queries to the right tools and deciding when to hand off to a frontier model.\n", + "\n", + "### What You'll Learn\n", + "\n", + "- The **LLM Router pattern**: separating tool-selection planning from text generation\n", + "- **Two-phase post-training**: DPO for rapid baseline conditioning + RLVR for reward-optimized routing\n", + "- Using **Tinker SDK** with minimal configuration (defaults work well — no hyperparameter sweeps needed)\n", + "- Optimal **vLLM serving** configuration for Nemotron Nano on H100s, drawn from Glean's production config\n", + "- Adapting the open-source **[Salesforce xLAM Function Calling 60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k)** dataset as a public analog to production routing traces\n", + "\n", + "### Authors\n", + "\n", + "| Organization | Contributors |\n", + "|---|---|\n", + "| NVIDIA | Vineeth, Farshad, Justin |\n", + "| Glean 
| Eddie, Rahul, Abhi, Thai, Zubin |\n", + "\n", + "### Requirements\n", + "\n", + "- 1× NVIDIA H100 80GB (or A100 80GB) for training and serving\n", + "- Python 3.10+\n", + "- Access to [Tinker SDK](https://github.com/NVIDIA-NeMo/tinker) and [tinker-cookbook](https://github.com/NVIDIA-NeMo/tinker-cookbook)\n", + "- HuggingFace account (for model weights)" + ] + }, + { + "cell_type": "markdown", + "id": "cell-02-arch", + "metadata": {}, + "source": [ + "## Architecture: The LLM Router Pattern\n", + "\n", + "```\n", + "User Query\n", + " │\n", + " ▼\n", + "┌──────────────────────────────────────────────┐\n", + "│ Nemotron Nano 30B Router │\n", + "│ │\n", + "│ Input: query + conv history + tool list │\n", + "│ │ │\n", + "│ ┌─────────▼──────────┐ │\n", + "│ │ Agentic Loop │ │\n", + "│ │ │ │\n", + "│ │ [Tool Call(s)] ──►│ Execute │\n", + "│ │ │ │ Tools │\n", + "│ │ [Tool Call(s)] ──►│ │\n", + "│ │ │ │ │\n", + "│ │ <|tool_stop|> │ │\n", + "│ │ or │ │\n", + "│ │ <|escalate|> │ │\n", + "│ └────────────────────┘ │\n", + "└──────────────────┬───────────────────────────┘\n", + " │ (passes collected tool call\n", + " │ outputs as context)\n", + " ▼\n", + " ┌───────────────────────┐\n", + " │ Frontier Model │\n", + " │ (GPT-4o, Claude, │\n", + " │ Llama, etc.) │\n", + " │ Generates final │\n", + " │ text response │\n", + " └───────────────────────┘\n", + "```\n", + "\n", + "### Key Design Decisions\n", + "\n", + "**1. No text generation in the router.** \n", + "The router only outputs tool calls and termination signals — never prose. All text generation is delegated to the frontier model, which receives the tool call results as context.\n", + "\n", + "**2. Parallel tool calls.** \n", + "The router can issue multiple tool calls simultaneously in a single step (e.g., search across multiple indexes in parallel).\n", + "\n", + "**3. Full context, no prompt distillation.** \n", + "Input always includes the complete query + conversation history + full tool list. 
Glean explicitly avoids distilling prompts because tool descriptions vary at serving time — distillation would break generalization.\n", + "\n", + "**4. Adaptive escalation.** \n", + "The `<|escalate|>` termination signal tells the orchestrating system that the frontier model needs higher reasoning effort (e.g., extended thinking). The `<|tool_stop|>` signal means the collected context is sufficient for a normal frontier model pass.\n", + "\n", + "**5. Single H100 replica.** \n", + "Nemotron Nano 30B fits on one H100 80GB with no tensor parallelism. This maximizes the number of replicas per compute budget and simplifies deployment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-03-setup", + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "# Tinker SDK is available at https://github.com/NVIDIA-NeMo/tinker\n", + "# tinker-cookbook is at https://github.com/NVIDIA-NeMo/tinker-cookbook\n", + "!pip install -q \\\n", + " \"datasets>=2.14.0\" \\\n", + " \"transformers>=4.47.0\" \\\n", + " \"torch>=2.1.0\" \\\n", + " \"vllm>=0.6.0\" \\\n", + " \"peft>=0.9.0\" \\\n", + " \"trl>=0.12.0\" \\\n", + " \"requests>=2.31.0\"\n", + "\n", + "# Install Tinker SDK from source\n", + "# !pip install -q git+https://github.com/NVIDIA-NeMo/tinker.git\n", + "# !pip install -q git+https://github.com/NVIDIA-NeMo/tinker-cookbook.git" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-04-imports", + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "import random\n", + "import re\n", + "from collections import defaultdict\n", + "from typing import Optional\n", + "\n", + "import torch\n", + "from datasets import load_dataset\n", + "from transformers import AutoTokenizer\n", + "\n", + "# Tinker SDK imports\n", + "# These follow the patterns established in tinker-cookbook\n", + "import tinker\n", + "from tinker import TinkerTrainer, LoRAConfig\n", + "from tinker.data import Datum\n", + 
"from tinker.losses import DPOLoss\n", + "from tinker.envs import SearchEnv\n", + "from tinker.inference import TinkerInference\n", + "\n", + "# ── Constants ──────────────────────────────────────────────────────────────────\n", + "MODEL_ID = \"nvidia/Nemotron-Nano-30B-Instruct\"\n", + "DATASET_ID = \"Salesforce/xlam-function-calling-60k\"\n", + "\n", + "# Termination tokens — emitted by the router to signal end-of-loop state\n", + "SUCCESS_TOKEN = \"<|tool_stop|>\" # enough context collected; frontier model can respond\n", + "ESCALATE_TOKEN = \"<|escalate|>\" # frontier model needs elevated reasoning effort\n", + "\n", + "SEED = 42\n", + "random.seed(SEED)\n", + "torch.manual_seed(SEED)\n", + "\n", + "print(f\"Model : {MODEL_ID}\")\n", + "print(f\"Dataset: {DATASET_ID}\")\n", + "print(f\"PyTorch: {torch.__version__}\")\n", + "print(f\"CUDA : {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'not available'}\")" + ] + }, + { + "cell_type": "markdown", + "id": "cell-05-dataset-intro", + "metadata": {}, + "source": [ + "## Dataset: Salesforce xLAM Function Calling 60k\n", + "\n", + "Glean's training data comes from **internal production traces** of Glean Assistant. For this tutorial we use the publicly available [Salesforce xLAM Function Calling 60k](https://huggingface.co/datasets/Salesforce/xlam-function-calling-60k) dataset as a drop-in analog. 
xLAM was developed by Salesforce Research for training next-generation function-calling agents and is a strong fit because:\n", + "\n", + "| Property | xLAM | Glean's approach |\n", + "|---|---|---|\n", + "| Tool schemas | Full JSON schemas with descriptions | Same |\n", + "| Parallel calls | Included | Primary use case |\n", + "| Multi-turn | Supported | Required |\n", + "| Domain coverage | Web APIs, search, databases, math | Search-heavy |\n", + "| Size | 60,000 examples | Internal traces |\n", + "| Synthetic data | Yes (generated and verified by the APIGen pipeline) | No (no synthetic data) |\n", + "\n", + "### Mapping xLAM to the Router Training Format\n", + "\n", + "| xLAM field | Router interpretation |\n", + "|---|---|\n", + "| GPT response = tool call JSON | Router should emit that tool call |\n", + "| GPT response = plain text | Router should emit `<\\|escalate\\|>` (frontier model needed) |\n", + "| Multiple tool calls in one response | Parallel routing step |\n", + "| Conversation history | Full agentic loop context |" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-06-load-dataset", + "metadata": {}, + "outputs": [], + "source": [ + "print(f\"Loading {DATASET_ID} ...\")\n", + "raw_dataset = load_dataset(DATASET_ID, split=\"train\")\n", + "print(f\"Total examples : {len(raw_dataset):,}\")\n", + "print(f\"Columns : {raw_dataset.column_names}\")\n", + "\n", + "# ── Inspect a sample ─────────────────────────────────────────────────────────\n", + "sample = raw_dataset[0]\n", + "print(\"\\n─── Sample ───────────────────────────────────────────────────────────\")\n", + "for turn in sample['conversations']:\n", + " role = turn['from'].upper()\n", + " preview = turn['value'][:160].replace('\\n', ' ')\n", + " print(f\"[{role}] {preview}...\")\n", + "\n", + "tools_sample = json.loads(sample['tools'])\n", + "print(f\"\\nTools available ({len(tools_sample)}):\")\n", + "for t in tools_sample:\n", + " print(f\" {t['name']}: {t['description'][:80]}...\")\n", + "\n", +
"# ── Dataset statistics ────────────────────────────────────────────────────────\n", + "tool_call_count = 0\n", + "text_response_count = 0\n", + "parallel_call_count = 0\n", + "\n", + "for ex in raw_dataset:\n", + " gpt_turns = [t for t in ex['conversations'] if t['from'] == 'gpt']\n", + " if not gpt_turns:\n", + " continue\n", + " last_val = gpt_turns[-1]['value'].strip()\n", + " try:\n", + " parsed = json.loads(last_val)\n", + " calls = parsed if isinstance(parsed, list) else [parsed]\n", + " tool_call_count += 1\n", + " if len(calls) > 1:\n", + " parallel_call_count += 1\n", + " except json.JSONDecodeError:\n", + " text_response_count += 1\n", + "\n", + "total = len(raw_dataset)\n", + "print(f\"\\n─── Statistics ───────────────────────────────────────────────────────\")\n", + "print(f\"Tool call responses : {tool_call_count:,} ({tool_call_count/total*100:.1f}%)\")\n", + "print(f\" ↳ Parallel calls : {parallel_call_count:,} ({parallel_call_count/total*100:.1f}%)\")\n", + "print(f\"Text responses : {text_response_count:,} ({text_response_count/total*100:.1f}%) ← maps to ESCALATE\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-07-preprocess", + "metadata": {}, + "outputs": [], + "source": [ + "def build_system_prompt(tools: list[dict]) -> str:\n", + " \"\"\"Build the router system prompt, injecting the available tool list.\"\"\"\n", + " return (\n", + " \"You are an LLM router operating inside an agentic loop. \"\n", + " \"Your only responsibility is to decide which tools to call given the \"\n", + " \"user query and conversation context. You do NOT generate prose responses.\\n\\n\"\n", + " \"When you have issued all necessary tool calls, terminate with EXACTLY ONE of:\\n\"\n", + " f\" {SUCCESS_TOKEN} — collected context is sufficient for the frontier model\\n\"\n", + " f\" {ESCALATE_TOKEN} — the query requires elevated reasoning effort\\n\\n\"\n", + " \"Tool calls must be valid JSON arrays. 
Parallel calls are allowed.\\n\\n\"\n", + " f\"Available tools:\\n{json.dumps(tools, indent=2)}\"\n", + " )\n", + "\n", + "\n", + "def xlam_to_router(example: dict) -> dict | None:\n", + " \"\"\"Convert one xLAM example into router training format.\"\"\"\n", + " try:\n", + " tools = json.loads(example['tools'])\n", + " except (json.JSONDecodeError, KeyError):\n", + " return None\n", + "\n", + " messages = [{\"role\": \"system\", \"content\": build_system_prompt(tools)}]\n", + " has_parallel = False\n", + "\n", + " for turn in example['conversations']:\n", + " if turn['from'] == 'system':\n", + " continue # replaced above\n", + "\n", + " role = \"user\" if turn['from'] == 'human' else \"assistant\"\n", + " content = turn['value']\n", + "\n", + " if turn['from'] == 'gpt':\n", + " try:\n", + " parsed = json.loads(content)\n", + " calls = parsed if isinstance(parsed, list) else [parsed]\n", + " if len(calls) > 1:\n", + " has_parallel = True\n", + " # Normalise to JSON array string\n", + " content = json.dumps(calls)\n", + " except json.JSONDecodeError:\n", + " # Plain text response → escalation in router terms\n", + " content = ESCALATE_TOKEN\n", + "\n", + " messages.append({\"role\": role, \"content\": content})\n", + "\n", + " # Ensure the last assistant turn ends with a termination token\n", + " asst_turns = [m for m in messages if m['role'] == 'assistant']\n", + " if not asst_turns:\n", + " return None\n", + " last = asst_turns[-1]\n", + " if last['content'] != ESCALATE_TOKEN and SUCCESS_TOKEN not in last['content']:\n", + " last['content'] = last['content'] + \"\\n\" + SUCCESS_TOKEN\n", + "\n", + " return {\n", + " \"id\": example['id'],\n", + " \"messages\": messages,\n", + " \"tools\": tools,\n", + " \"has_parallel\": has_parallel,\n", + " }\n", + "\n", + "\n", + "print(\"Converting xLAM → router format ...\")\n", + "router_examples = [r for ex in raw_dataset if (r := xlam_to_router(ex)) is not None]\n", + "\n", + "random.shuffle(router_examples)\n", + "n_val = 
max(500, int(0.05 * len(router_examples)))\n", + "val_examples = router_examples[:n_val]\n", + "train_examples = router_examples[n_val:]\n", + "\n", + "print(f\"Converted examples : {len(router_examples):,}\")\n", + "print(f\" Train : {len(train_examples):,}\")\n", + "print(f\" Validation : {len(val_examples):,}\")\n", + "print(f\" Has parallel calls: {sum(1 for e in router_examples if e['has_parallel']):,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-08-datum", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)\n", + "MAX_LEN = 65_536 # Tinker trains Nemotron Nano at 64k context\n", + "\n", + "\n", + "def apply_chat_template(messages: list[dict]) -> str:\n", + " # Only ever called on the context (all turns before the final assistant turn),\n", + " # so append the assistant header for the completion that follows.\n", + " return tokenizer.apply_chat_template(\n", + " messages, tokenize=False, add_generation_prompt=True\n", + " )\n", + "\n", + "\n", + "# ── Build a pool of tool names for negative sampling ─────────────────────────\n", + "all_tool_names = list({t['name'] for ex in train_examples for t in ex['tools']})\n", + "\n", + "\n", + "def make_rejected_response(chosen: str, tools: list[dict]) -> str:\n", + " \"\"\"\n", + " Synthesise a rejected response via one of three strategies:\n", + " 0 – substitute a wrong tool name\n", + " 1 – flip the termination token (SUCCESS ↔ ESCALATE)\n", + " 2 – emit an empty tool-call list\n", + " \"\"\"\n", + " strategy = random.randint(0, 2)\n", + "\n", + " if strategy == 0:\n", + " try:\n", + " call_str = chosen.replace(SUCCESS_TOKEN, \"\").replace(ESCALATE_TOKEN, \"\").strip()\n", + " calls = json.loads(call_str)\n", + " calls = calls if isinstance(calls, list) else [calls]\n", + " wrong = [{**c, \"name\": random.choice(\n", + " [n for n in all_tool_names if n != c.get(\"name\", \"\")]\n", + " or all_tool_names\n", + " )} for c in calls]\n", + " return json.dumps(wrong) + \"\\n\" + SUCCESS_TOKEN\n", + " except (json.JSONDecodeError, TypeError):\n", + " # No parsable calls (e.g. chosen is a bare escalation): flip the token so rejected != chosen\n", + " return SUCCESS_TOKEN if ESCALATE_TOKEN in chosen else ESCALATE_TOKEN\n", + "\n", + " if
strategy == 1:\n", + " if ESCALATE_TOKEN in chosen:\n", + " t = random.choice(tools) if tools else {\"name\": \"unknown\"}\n", + " return json.dumps([{\"name\": t['name'], \"arguments\": {}}]) + \"\\n\" + SUCCESS_TOKEN\n", + " return ESCALATE_TOKEN\n", + "\n", + " # strategy == 2\n", + " return json.dumps([]) + \"\\n\" + SUCCESS_TOKEN\n", + "\n", + "\n", + "def example_to_datum(ex: dict) -> Datum | None:\n", + " \"\"\"\n", + " Convert one router training example to a tinker.Datum DPO pair.\n", + "\n", + " Context = all messages except the final assistant turn\n", + " Chosen = ground-truth router output\n", + " Rejected = synthesised negative output\n", + " \"\"\"\n", + " msgs = ex['messages']\n", + " if len(msgs) < 3:\n", + " return None\n", + "\n", + " context_text = apply_chat_template(msgs[:-1])\n", + " chosen_text = context_text + msgs[-1]['content']\n", + " rejected_text = context_text + make_rejected_response(msgs[-1]['content'], ex['tools'])\n", + "\n", + " enc = tokenizer(chosen_text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n", + " enc_r = tokenizer(rejected_text, return_tensors=\"pt\", truncation=True, max_length=MAX_LEN)\n", + " ctx_len = len(tokenizer(context_text).input_ids)\n", + "\n", + " return Datum(\n", + " chosen_ids = enc.input_ids[0],\n", + " rejected_ids = enc_r.input_ids[0],\n", + " context_length= ctx_len,\n", + " metadata = {\"has_parallel\": ex['has_parallel']},\n", + " )\n", + "\n", + "\n", + "print(\"Building tinker.Datum objects from training examples ...\")\n", + "train_data = [d for ex in train_examples if (d := example_to_datum(ex)) is not None]\n", + "val_data = [d for ex in val_examples[:500] if (d := example_to_datum(ex)) is not None]\n", + "\n", + "print(f\"Train Datum objects : {len(train_data):,}\")\n", + "print(f\"Val Datum objects : {len(val_data):,}\")\n", + "d0 = train_data[0]\n", + "print(f\"Sample — chosen len : {d0.chosen_ids.shape[0]} tokens\")\n", + "print(f\"Sample — rejected len: 
{d0.rejected_ids.shape[0]} tokens\")" + ] + }, + { + "cell_type": "markdown", + "id": "cell-09-dpo-intro", + "metadata": {}, + "source": [ + "## Phase 1: DPO Fine-Tuning\n", + "\n", + "Direct Preference Optimization (DPO) is used in the **first phase** to rapidly lift the base model from poor tool-calling behaviour to a usable baseline.\n", + "\n", + "### Why DPO first?\n", + "\n", + "From Glean's experience:\n", + "> *\"The base model behaves very poorly, and we can quickly condition some improvements with simple DPO/SFT.\"*\n", + "\n", + "DPO requires no rollouts or environment simulation — just offline preference pairs. This makes it dramatically cheaper and faster than RL for the first phase.\n", + "\n", + "### LoRA Configuration\n", + "\n", + "Following Glean's production settings, we use **Tinker's default LoRA configuration**.\n", + "No hyperparameter sweeps were needed — the defaults work:\n", + "\n", + "| Parameter | Value | Note |\n", + "|---|---|---|\n", + "| Rank (`r`) | 32 | Glean production default |\n", + "| Alpha | 64 | 2× rank (Tinker default) |\n", + "| Target modules | `all-linear` | All attention + MLP projections |\n", + "| Dropout | 0.05 | Standard |\n", + "\n", + "### Loss function\n", + "\n", + "Standard DPO sigmoid loss (`β = 0.1`). Tinker's `DPOConfig` exposes this directly." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-10-dpo-train", + "metadata": {}, + "outputs": [], + "source": [ + "# ── LoRA config — matches Glean's production settings ────────────────────────\n", + "lora_config = LoRAConfig(\n", + " r = 32, # rank — Glean production default\n", + " lora_alpha = 64, # 2× rank\n", + " target_modules = \"all-linear\", # all attention + MLP projections\n", + " lora_dropout = 0.05,\n", + " task_type = \"CAUSAL_LM\",\n", + ")\n", + "\n", + "# ── DPO config ────────────────────────────────────────────────────────────────\n", + "dpo_config = tinker.DPOConfig(\n", + " beta = 0.1, # KL penalty — Tinker default\n", + " loss_type = \"sigmoid\", # standard DPO loss\n", + " label_smoothing= 0.0,\n", + " reference_free = False, # keep reference model for KL\n", + ")\n", + "\n", + "# ── Trainer config ────────────────────────────────────────────────────────────\n", + "dpo_trainer_config = tinker.TrainerConfig(\n", + " model_id = MODEL_ID,\n", + " output_dir = \"./checkpoints/phase1_dpo\",\n", + " num_train_epochs = 1,\n", + " per_device_train_batch_size = 1,\n", + " gradient_accumulation_steps = 8,\n", + " learning_rate = 5e-5,\n", + " warmup_ratio = 0.05,\n", + " lr_scheduler_type = \"cosine\",\n", + " logging_steps = 10,\n", + " save_steps = 500,\n", + " bf16 = True,\n", + " max_length = MAX_LEN,\n", + " lora = lora_config,\n", + " dpo = dpo_config,\n", + ")\n", + "\n", + "print(\"Initialising Phase 1: DPO trainer ...\")\n", + "dpo_trainer = TinkerTrainer(\n", + " config = dpo_trainer_config,\n", + " train_data = train_data,\n", + " val_data = val_data,\n", + ")\n", + "\n", + "print(\"Starting DPO training ...\")\n", + "dpo_results = dpo_trainer.train()\n", + "\n", + "print(\"\\n─── Phase 1 DPO Results ─────────────────────────────────────────────\")\n", + "print(f\"Final train loss : {dpo_results.final_loss:.4f}\")\n", + "print(f\"Reward accuracy : {dpo_results.reward_accuracy:.4f}\")\n", + "print(f\"Checkpoint : 
{dpo_trainer_config.output_dir}\")" + ] + }, + { + "cell_type": "markdown", + "id": "cell-11-rlvr-intro", + "metadata": {}, + "source": [ + "## Phase 2: RLVR with Reward Shaping\n", + "\n", + "After DPO conditions basic behaviour, **Reinforcement Learning with Verifiable Rewards (RLVR)** optimises for what actually matters in production: did the router retrieve the right documents, and did it make the right escalation call?\n", + "\n", + "### Why RLVR after DPO?\n", + "\n", + "> *\"We use true RL to do rollouts and achieve better quality.\"* — Glean Engineering\n", + "\n", + "DPO is supervised over a fixed dataset. RLVR lets the model explore beyond the training distribution and receive reward signals from an environment, enabling it to discover routing strategies that weren't in the original data.\n", + "\n", + "### Combined Reward Function\n", + "\n", + "```\n", + "R_total = α · R_search + β · R_termination\n", + "```\n", + "\n", + "| Component | What it measures | xLAM analog | Glean's production analog |\n", + "|---|---|---|---|\n", + "| `R_search` | Did the router call the right tools with the right arguments? | F1 against ground-truth tool calls | F1/recall/precision on documents deemed relevant in historical traces |\n", + "| `R_termination` | Did the router make the right escalation decision? | Matches ground-truth termination token | Whether the trace used non-core/head tools (requiring escalation) |\n", + "\n", + "Glean found that **reward shaping is the main iteration surface** — not hyperparameters, not LoRA rank. Tune `α` and `β` based on your production quality goals.\n", + "\n", + "### Tinker `SearchEnv`\n", + "\n", + "Tinker's `SearchEnv` (from tinker-cookbook) provides the scaffolding for RLVR rollouts. We subclass it to inject our router-specific reward logic." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-12-reward", + "metadata": {}, + "outputs": [], + "source": [ + "# ── Reward utilities ──────────────────────────────────────────────────────────\n", + "\n", + "def parse_router_output(text: str) -> tuple[list[dict], str]:\n", + " \"\"\"Split router output into (tool_calls, termination_token).\"\"\"\n", + " termination = ESCALATE_TOKEN if ESCALATE_TOKEN in text else SUCCESS_TOKEN\n", + " call_str = text.replace(SUCCESS_TOKEN, \"\").replace(ESCALATE_TOKEN, \"\").strip()\n", + " try:\n", + " parsed = json.loads(call_str)\n", + " calls = parsed if isinstance(parsed, list) else [parsed] if isinstance(parsed, dict) else []\n", + " except json.JSONDecodeError:\n", + " calls = []\n", + " return calls, termination\n", + "\n", + "\n", + "def search_reward(pred_calls: list[dict], gt_calls: list[dict]) -> float:\n", + " \"\"\"\n", + " F1-based reward on tool selection.\n", + "\n", + " Mirrors Glean's search-based reward:\n", + " precision/recall/F1 on documents deemed relevant in historical traces.\n", + " Here we proxy document relevance with tool selection correctness.\n", + " \"\"\"\n", + " if not gt_calls:\n", + " return 1.0 if not pred_calls else -0.5 # correctly predicted \"no tools needed\"\n", + " if not pred_calls:\n", + " return -1.0\n", + "\n", + " pred_names = {c.get(\"name\", \"\") for c in pred_calls}\n", + " gt_names = {c.get(\"name\", \"\") for c in gt_calls}\n", + "\n", + " precision = len(pred_names & gt_names) / len(pred_names)\n", + " recall = len(pred_names & gt_names) / len(gt_names)\n", + " f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0\n", + "\n", + " # Bonus for argument overlap on matched tools (up to +0.2 per tool)\n", + " arg_bonus = 0.0\n", + " for name in pred_names & gt_names:\n", + " p = next((c for c in pred_calls if c.get(\"name\") == name), None)\n", + " g = next((c for c in gt_calls if c.get(\"name\") == name), None)\n", + " if 
p and g:\n", + " p_vals = set(str(v) for v in (p.get(\"arguments\") or {}).values())\n", + " g_vals = set(str(v) for v in (g.get(\"arguments\") or {}).values())\n", + " if g_vals:\n", + " arg_bonus += (len(p_vals & g_vals) / len(g_vals)) * 0.2\n", + "\n", + " return min(1.0, f1 + arg_bonus)\n", + "\n", + "\n", + "def termination_reward(pred_term: str, gt_term: str) -> float:\n", + " \"\"\"Binary reward for correct escalation/success decision.\"\"\"\n", + " return 1.0 if pred_term == gt_term else -1.0\n", + "\n", + "\n", + "def router_reward(\n", + " completion: str,\n", + " ground_truth: dict,\n", + " alpha: float = 0.7,\n", + " beta: float = 0.3,\n", + ") -> float:\n", + " \"\"\"\n", + " Combined router reward: R = α·R_search + β·R_termination\n", + "\n", + " alpha controls the weight on tool-selection quality.\n", + " beta controls the weight on the escalation/success decision.\n", + " Glean's defaults: alpha=0.7, beta=0.3.\n", + " \"\"\"\n", + " pred_calls, pred_term = parse_router_output(completion)\n", + " gt_calls = ground_truth.get(\"calls\", [])\n", + " gt_term = ground_truth.get(\"termination\", SUCCESS_TOKEN)\n", + "\n", + " r_s = search_reward(pred_calls, gt_calls)\n", + " r_t = termination_reward(pred_term, gt_term)\n", + " return alpha * r_s + beta * r_t\n", + "\n", + "\n", + "# ── Smoke test ────────────────────────────────────────────────────────────────\n", + "test_cases = [\n", + " (\"Perfect match\",\n", + " '[{\"name\": \"search\", \"arguments\": {\"query\": \"foo\"}}]\\n' + SUCCESS_TOKEN,\n", + " {\"calls\": [{\"name\": \"search\", \"arguments\": {\"query\": \"foo\"}}], \"termination\": SUCCESS_TOKEN}),\n", + " (\"Wrong tool\",\n", + " '[{\"name\": \"calculator\", \"arguments\": {}}]\\n' + SUCCESS_TOKEN,\n", + " {\"calls\": [{\"name\": \"search\", \"arguments\": {}}], \"termination\": SUCCESS_TOKEN}),\n", + " (\"Correct escalation\",\n", + " ESCALATE_TOKEN,\n", + " {\"calls\": [], \"termination\": ESCALATE_TOKEN}),\n", + " (\"Wrong 
termination\",\n", + " '[{\"name\": \"search\", \"arguments\": {}}]\\n' + SUCCESS_TOKEN,\n", + " {\"calls\": [{\"name\": \"search\", \"arguments\": {}}], \"termination\": ESCALATE_TOKEN}),\n", + " (\"Partial parallel match\",\n", + " '[{\"name\": \"search\"}, {\"name\": \"calculator\"}]\\n' + SUCCESS_TOKEN,\n", + " {\"calls\": [{\"name\": \"search\"}, {\"name\": \"weather\"}], \"termination\": SUCCESS_TOKEN}),\n", + "]\n", + "\n", + "print(\"─── Reward Function Tests ────────────────────────────────────────────\")\n", + "for name, output, gt in test_cases:\n", + " r = router_reward(output, gt)\n", + " print(f\" {name:<28} → {r:+.3f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-13-rlvr-train", + "metadata": {}, + "outputs": [], + "source": [ + "class RouterSearchEnv(SearchEnv):\n", + " \"\"\"\n", + " Router environment — subclasses Tinker's SearchEnv (tinker-cookbook).\n", + "\n", + " Extends the base environment with a combined search + termination reward,\n", + " mirroring Glean's production reward design.\n", + " \"\"\"\n", + "\n", + " def __init__(self, alpha: float = 0.7, beta: float = 0.3, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.alpha = alpha\n", + " self.beta = beta\n", + "\n", + " def compute_reward(self, rollout: tinker.Rollout) -> float:\n", + " return router_reward(\n", + " completion = rollout.completion,\n", + " ground_truth = rollout.metadata[\"ground_truth\"],\n", + " alpha = self.alpha,\n", + " beta = self.beta,\n", + " )\n", + "\n", + " def is_terminal(self, rollout: tinker.Rollout) -> bool:\n", + " return SUCCESS_TOKEN in rollout.completion or ESCALATE_TOKEN in rollout.completion\n", + "\n", + "\n", + "def build_rlvr_prompts(examples: list[dict]) -> list[dict]:\n", + " \"\"\"Convert router examples into RLVR prompt + ground_truth dicts.\"\"\"\n", + " out = []\n", + " for ex in examples:\n", + " msgs = ex['messages']\n", + " if len(msgs) < 3:\n", + " continue\n", + " gt_calls, gt_term = 
parse_router_output(msgs[-1]['content'])\n", + " out.append({\n", + " \"prompt\": apply_chat_template(msgs[:-1]),\n", + " \"ground_truth\": {\"calls\": gt_calls, \"termination\": gt_term},\n", + " })\n", + " return out\n", + "\n", + "\n", + "print(\"Preparing RLVR prompts ...\")\n", + "rlvr_prompts = build_rlvr_prompts(train_examples)\n", + "print(f\"RLVR prompts: {len(rlvr_prompts):,}\")\n", + "\n", + "# ── Environment ───────────────────────────────────────────────────────────────\n", + "router_env = RouterSearchEnv(\n", + " alpha = 0.7, # search quality weight\n", + " beta = 0.3, # termination quality weight\n", + ")\n", + "\n", + "# ── RLVR trainer config — continues from Phase 1 DPO checkpoint ──────────────\n", + "rlvr_trainer_config = tinker.TrainerConfig(\n", + " model_id = MODEL_ID,\n", + " checkpoint_path = \"./checkpoints/phase1_dpo\",\n", + " output_dir = \"./checkpoints/phase2_rlvr\",\n", + " num_train_epochs = 2,\n", + " per_device_train_batch_size = 1,\n", + " gradient_accumulation_steps = 4,\n", + " learning_rate = 1e-5, # lower LR for RL phase\n", + " warmup_ratio = 0.02,\n", + " lr_scheduler_type = \"cosine\",\n", + " logging_steps = 10,\n", + " save_steps = 250,\n", + " bf16 = True,\n", + " max_length = MAX_LEN,\n", + " lora = lora_config,\n", + " # RLVR-specific\n", + " num_rollouts_per_step = 4, # rollouts per gradient step\n", + " kl_coeff = 0.01, # KL penalty vs. 
reference model\n", + " clip_range = 0.2, # PPO/GRPO clip range\n", + ")\n", + "\n", + "print(\"Initialising Phase 2: RLVR trainer ...\")\n", + "rlvr_trainer = TinkerTrainer(\n", + " config = rlvr_trainer_config,\n", + " env = router_env,\n", + " train_prompts = rlvr_prompts,\n", + ")\n", + "\n", + "print(\"Starting RLVR training ...\")\n", + "rlvr_results = rlvr_trainer.train_rl()\n", + "\n", + "print(\"\\n─── Phase 2 RLVR Results ────────────────────────────────────────────\")\n", + "print(f\"Mean reward : {rlvr_results.mean_reward:.4f}\")\n", + "print(f\"Search component : {rlvr_results.component_rewards.get('search', float('nan')):.4f}\")\n", + "print(f\"Termination component: {rlvr_results.component_rewards.get('termination', float('nan')):.4f}\")\n", + "print(f\"Checkpoint : {rlvr_trainer_config.output_dir}\")" + ] + }, + { + "cell_type": "markdown", + "id": "cell-14-serving-intro", + "metadata": {}, + "source": [ + "## Serving with vLLM on H100\n", + "\n", + "After training, the fine-tuned Nemotron Nano 30B router is deployed with **vLLM** on a single **H100 80GB** (Google Cloud `a3-highgpu-1g`).\n", + "\n", + "### Glean's production numbers\n", + "\n", + "| Metric | Value |\n", + "|---|---|\n", + "| P50 latency | 250 ms |\n", + "| P95 latency | 2.5 s |\n", + "| P99 latency | 3.5 s |\n", + "| GPU | NVIDIA_H100_80GB / a3-highgpu-1g |\n", + "| Tensor parallelism | 1 (single GPU — maximises replica count) |\n", + "\n", + "### Nemotron-specific flags\n", + "\n", + "- **`--trust-remote-code`** is required — Nemotron models include custom ops (Mamba architecture components)\n", + "- **`--max-model-len=65536`** matches Tinker's 64k training context\n", + "- **`--enable-chunked-prefill`** helps with long routing contexts (full conversation history + all tool descriptions)\n", + "\n", + "> **Tuning note from Glean**: `--max-num-batched-tokens=8192` was carried over from a prior Qwen3 deployment and may not be optimal for Nemotron Nano's tokenizer. 
Set it to your P95 prompt length for best throughput." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-15-serving", + "metadata": {}, + "outputs": [], + "source": [ + "import subprocess\n", + "\n", + "\n", + "def build_vllm_command(\n", + " model_path: str,\n", + " host: str = \"0.0.0.0\",\n", + " port: int = 8080,\n", + " tensor_parallel_size: int = 1, # single H100; increase for multi-GPU\n", + " gpu_memory_utilization: float = 0.85,\n", + " max_num_seqs: int = 8,\n", + " max_num_batched_tokens: int = 8192, # TODO: tune to your P95 prompt length\n", + " max_model_len: int = 65_536, # matches Tinker's 64k training context\n", + ") -> list[str]:\n", + " \"\"\"\n", + " Build the vLLM serve command for Nemotron Nano 30B.\n", + "\n", + " Based on Glean's production deployment config (H100 80GB / a3-highgpu-1g).\n", + " Glean reports P50=250ms, P95=2.5s, P99=3.5s latency with this config.\n", + " \"\"\"\n", + " return [\n", + " \"python\", \"-m\", \"vllm.entrypoints.openai.api_server\",\n", + " f\"--model={model_path}\",\n", + " f\"--host={host}\",\n", + " f\"--port={port}\",\n", + " f\"--tensor-parallel-size={tensor_parallel_size}\",\n", + " \"--swap-space=0\",\n", + " f\"--gpu-memory-utilization={gpu_memory_utilization}\",\n", + " f\"--max-num-seqs={max_num_seqs}\",\n", + " \"--enable-chunked-prefill\",\n", + " # Set to ~P95 prompt length for your traffic distribution.\n", + " # Glean note: this value was ported from Qwen3 and may need tuning for Nemotron.\n", + " f\"--max-num-batched-tokens={max_num_batched_tokens}\",\n", + " # Tinker trains Nemotron Nano at 64k context window.\n", + " f\"--max-model-len={max_model_len}\",\n", + " # Required: allows loading the custom model code Nemotron ships (Mamba components).\n", + " \"--trust-remote-code\",\n", + " # Serve under the model name that call_router() sends in its request body.\n", + " \"--served-model-name=nemotron-nano-router\",\n", + " ]\n", + "\n", + "\n", + "def call_router(\n", + " query: str,\n", + " conversation_history: list[dict],\n", + " available_tools: list[dict],\n", + " server_url: str = \"http://localhost:8080\",\n", + ") -> dict:\n", + " \"\"\"\n",
+ "    Call the deployed Nemotron Nano router and parse its output.\n", + "\n", + "    Returns a dict with:\n", + "      tool_calls      – list of tool call dicts\n", + "      termination     – SUCCESS_TOKEN or ESCALATE_TOKEN\n", + "      should_escalate – bool convenience flag\n", + "      raw             – unparsed model output\n", + "    \"\"\"\n", + "    import requests\n", + "\n", + "    messages = [\n", + "        {\"role\": \"system\", \"content\": build_system_prompt(available_tools)},\n", + "        *conversation_history,\n", + "        {\"role\": \"user\", \"content\": query},\n", + "    ]\n", + "\n", + "    resp = requests.post(\n", + "        f\"{server_url}/v1/chat/completions\",\n", + "        json={\n", + "            \"model\": \"nemotron-nano-router\",\n", + "            \"messages\": messages,\n", + "            \"max_tokens\": 512,\n", + "            \"temperature\": 0.0,  # deterministic routing\n", + "            \"stop\": [SUCCESS_TOKEN, ESCALATE_TOKEN],\n", + "            # vLLM-specific extra parameter: stop strings are stripped from the\n", + "            # output by default, which would hide the termination token from\n", + "            # parse_router_output.\n", + "            \"include_stop_str_in_output\": True,\n", + "        },\n", + "        timeout=30,\n", + "    )\n", + "    resp.raise_for_status()\n", + "\n", + "    raw = resp.json()[\"choices\"][0][\"message\"][\"content\"]\n", + "    calls, term = parse_router_output(raw)\n", + "\n", + "    return {\n", + "        \"tool_calls\": calls,\n", + "        \"termination\": term,\n", + "        \"should_escalate\": term == ESCALATE_TOKEN,\n", + "        \"raw\": raw,\n", + "    }\n", + "\n", + "\n", + "# ── Print the production serving command ──────────────────────────────────────\n", + "FINAL_CKPT = \"./checkpoints/phase2_rlvr/final\"\n", + "cmd = build_vllm_command(FINAL_CKPT)\n", + "\n", + "print(\"Production vLLM command (H100 80GB / a3-highgpu-1g):\")\n", + "print()\n", + "print(\" \".join(cmd[:3]) + \" \\\\\")\n", + "for arg in cmd[3:-1]:\n", + "    print(f\"    {arg} \\\\\")\n", + "print(f\"    {cmd[-1]}\")  # last argument: no trailing line continuation\n", + "print()\n", + "print(\"Latency reported by Glean: P50=250ms  P95=2.5s  P99=3.5s\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cell-16-eval", + "metadata": {}, + "outputs": [], + "source": [ + "def offline_evaluate(\n", + "    examples: list[dict],\n", + "    checkpoint_path: str,\n", + "    n_examples: int = 500,\n", + "    alpha: float = 0.7,\n", + "    beta: float = 0.3,\n", + ") -> 
dict:\n", + "    \"\"\"\n", + "    Offline evaluation using TinkerInference (no vLLM server needed).\n", + "\n", + "    Reports:\n", + "      mean_reward          – combined router reward (primary signal)\n", + "      tool_f1              – mean F1 on tool selection\n", + "      termination_accuracy – correct SUCCESS/ESCALATE rate\n", + "      escalation_recall    – recall on examples requiring escalation\n", + "    \"\"\"\n", + "    inference = TinkerInference(\n", + "        model_id        = MODEL_ID,\n", + "        checkpoint_path = checkpoint_path,\n", + "        lora_config     = lora_config,\n", + "    )\n", + "\n", + "    sample = random.sample(examples, min(n_examples, len(examples)))\n", + "\n", + "    rewards, f1s, term_correct, escl_correct, escl_total = [], [], [], [], []\n", + "\n", + "    for ex in sample:\n", + "        msgs = ex['messages']\n", + "        if len(msgs) < 3:\n", + "            continue\n", + "\n", + "        prompt = apply_chat_template(msgs[:-1])\n", + "        gt_text = msgs[-1]['content']\n", + "        gt_calls, gt_term = parse_router_output(gt_text)\n", + "\n", + "        completion = inference.generate(\n", + "            prompt, max_tokens=512, temperature=0.0,\n", + "            stop=[SUCCESS_TOKEN, ESCALATE_TOKEN],\n", + "        )\n", + "\n", + "        pred_calls, pred_term = parse_router_output(completion)\n", + "\n", + "        r = router_reward(completion, {\"calls\": gt_calls, \"termination\": gt_term},\n", + "                          alpha=alpha, beta=beta)\n", + "        rewards.append(r)\n", + "        f1s.append(search_reward(pred_calls, gt_calls))\n", + "        term_correct.append(1.0 if pred_term == gt_term else 0.0)\n", + "\n", + "        if gt_term == ESCALATE_TOKEN:\n", + "            escl_total.append(1)\n", + "            escl_correct.append(1 if pred_term == ESCALATE_TOKEN else 0)\n", + "\n", + "    def _mean(lst): return sum(lst) / len(lst) if lst else float('nan')\n", + "\n", + "    return {\n", + "        \"n_evaluated\": len(rewards),\n", + "        \"mean_reward\": _mean(rewards),\n", + "        \"tool_f1\": _mean(f1s),\n", + "        \"termination_accuracy\": _mean(term_correct),\n", + "        \"escalation_recall\": _mean(escl_correct) if escl_total else float('nan'),\n", + "        \"escalation_examples\": 
len(escl_total),\n", + "    }\n", + "\n", + "\n", + "print(\"Running evaluation on validation set ...\")\n", + "results = offline_evaluate(\n", + "    examples        = val_examples,\n", + "    checkpoint_path = \"./checkpoints/phase2_rlvr/final\",\n", + ")\n", + "\n", + "print()\n", + "print(\"─── Evaluation Results ──────────────────────────────────────────────\")\n", + "for k, v in results.items():\n", + "    if isinstance(v, float):\n", + "        print(f\"  {k:<28}: {v:.4f}\")\n", + "    else:\n", + "        print(f\"  {k:<28}: {v}\")" + ] + }, + { + "cell_type": "markdown", + "id": "cell-17-conclusion", + "metadata": {}, + "source": [ + "## Results and Production Learnings\n", + "\n", + "### Glean's Production Numbers\n", + "\n", + "| Metric | Value |\n", + "|---|---|\n", + "| Training | Phase 1 DPO → Phase 2 RLVR |\n", + "| LoRA rank | 32 (Tinker default) |\n", + "| GPU | H100 80GB / a3-highgpu-1g |\n", + "| Tensor parallelism | 1 |\n", + "| P50 latency | 250 ms |\n", + "| P95 latency | 2.5 s |\n", + "| P99 latency | 3.5 s |\n", + "| Training data | Internal production traces only (no synthetic data) |\n", + "| Hyperparameter tuning | None needed beyond reward weights |\n", + "\n", + "### Key Takeaways\n", + "\n", + "**1. Start with DPO — it's fast and moves the needle immediately.**  \n", + "The base Nemotron Nano model exhibits poor tool-calling behaviour. DPO over a small set of preference pairs quickly conditions it to a usable baseline without any rollout infrastructure.\n", + "\n", + "**2. RLVR is where quality is won — reward shaping is the main lever.**  \n", + "Glean's primary iteration was reward design, not model architecture or training hyperparameters. Time spent tuning the reward weights α/β and refining the reward signal pays back more than hyperparameter sweeps.\n", + "\n", + "**3. Tinker defaults are production-ready.**  \n", + "Rank-32 LoRA with default alpha. Standard DPO β=0.1 (the DPO regularisation strength, not the reward weight β above). No sweeps needed. Tinker's defaults are well calibrated for Nemotron Nano.\n", + "\n", + "**4. 
Preserve full context; avoid prompt distillation.**  \n", + "Including the full query + history + tool list at training time keeps the model generalisable. Distilling prompts hurts when tool descriptions vary at inference time.\n", + "\n", + "**5. One H100 is enough — optimise for replica count, not raw GPU utilisation.**  \n", + "Nemotron Nano 30B fits on a single H100 80GB. Keeping tensor parallelism at 1 maximises the number of independently deployable replicas per compute budget.\n", + "\n", + "**6. vLLM `--max-num-batched-tokens` should be tuned to your P95 prompt length.**  \n", + "Glean carried this value over from a prior Qwen3 deployment. Profile your production prompt length distribution and set it accordingly for best throughput on Nemotron.\n", + "\n", + "### Next Steps\n", + "\n", + "- **Reward model**: Train a learned reward model on domain preference data for a richer training signal than the rule-based F1 reward.\n", + "- **Differential escalation**: Route `<|escalate|>` to different frontier model configurations (e.g., thinking-mode vs. 
standard) based on the predicted difficulty signal.\n", + "- **Batching**: Experiment with higher `--max-num-seqs` under high-concurrency workloads.\n", + "- **Nemotron 3 Nano**: As improved base models become available, the same two-phase recipe applies — Phase 1 DPO cost is low enough to re-run for each new checkpoint.\n", + "\n", + "---\n", + "\n", + "## Acknowledgments\n", + "\n", + "This tutorial is based on a joint case study between **NVIDIA** and **[Glean](https://glean.com)**.\n", + "\n", + "**NVIDIA team**: Vineeth, Farshad, Justin \n", + "**Glean team**: Eddie, Rahul, Abhi, Thai, Zubin\n", + "\n", + "For questions about Tinker SDK, see:\n", + "- [tinker-cookbook](https://github.com/NVIDIA-NeMo/tinker-cookbook) — reference implementations for DPO, RLVR, and reward environments\n", + "- [Nemotron model family](https://huggingface.co/nvidia/Nemotron-Nano-30B-Instruct) on Hugging Face" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}