From 5edfa6fb5c78c3d00c0c0308d99bbb9209d449a2 Mon Sep 17 00:00:00 2001 From: Alexander Lyashuk Date: Wed, 15 Oct 2025 13:52:34 +0200 Subject: [PATCH] Return dict of metrics from training step --- docs/index.md | 5 +- docs/tensorboard.md | 40 ++++++ proto/export_config.proto | 2 + pyproject.toml | 5 + src/lczero_training/training/overfit.py | 20 ++- src/lczero_training/training/tensorboard.py | 152 ++++++++++++++++++++ src/lczero_training/training/training.py | 85 +++++++++-- uv.lock | 18 ++- 8 files changed, 301 insertions(+), 26 deletions(-) create mode 100644 docs/tensorboard.md create mode 100644 src/lczero_training/training/tensorboard.py diff --git a/docs/index.md b/docs/index.md index 130bb4b3..a696f8c4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -6,4 +6,7 @@ feeding training data. * [Training Tuple Format](training_tuple.md) — A description of the training tuple format used in the project. This is an interface between data loader and - model training. \ No newline at end of file + model training. +* [TensorBoard Metrics](tensorboard.md) — Instructions for enabling and + browsing training metrics that are exported during training. + diff --git a/docs/tensorboard.md b/docs/tensorboard.md new file mode 100644 index 00000000..4d4c47e6 --- /dev/null +++ b/docs/tensorboard.md @@ -0,0 +1,40 @@ +# TensorBoard Metrics + +Training can export metrics that are compatible with TensorBoard. To enable +logging, set the `tensorboard_path` field inside the `export` section of the +root configuration. The trainer writes TensorBoard event files into the +specified directory while training runs. + +```protobuf +export { + path: "checkpoints" + tensorboard_path: "logs/train" +} +``` + +## Recorded data + +The trainer records the following statistics: + +* **Per step** + * Learning rate that was applied for the step. + * Weighted loss value returned by the loss function. + * Unweighted loss components for value, policy, and moves left heads. + * Global gradient norm after clipping. +* **Per epoch** + * Histogram of all model weights together with mean, standard deviation, + minimum, and maximum scalars. + * Configuration scalars: batch size, steps per network, and chunks per + network. + +## Viewing the dashboard + +Launch TensorBoard and point it to the directory configured above: + +```bash +tensorboard --logdir logs/train +``` + +The dashboard contains grouped scalars for configuration parameters, losses, +and gradient norms, and a histogram of the weight distribution after each +training epoch. diff --git a/proto/export_config.proto b/proto/export_config.proto index 8a26fdf1..83b87621 100644 --- a/proto/export_config.proto +++ b/proto/export_config.proto @@ -8,4 +8,6 @@ message ExportConfig { optional string path = 1; // Training run ID for uploading to training website. Only uploads when set. optional int32 upload_training_run = 2; + // Directory where TensorBoard event files should be written. + optional string tensorboard_path = 3; } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0814f7b3..fcaa0708 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "python-dotenv>=1.1.1", "requests[socks]>=2.32.5", "matplotlib>=3.10.6", + "tensorboardx>=2.6.4", ] [project.optional-dependencies] @@ -65,6 +66,10 @@ ignore_missing_imports = true module = "optax" ignore_missing_imports = true +[[tool.mypy.overrides]] +module = "tensorboardX" +ignore_missing_imports = true + [tool.pytest.ini_options] testpaths = ["src"] python_files = ["test_*.py", "*_test.py"] diff --git a/src/lczero_training/training/overfit.py b/src/lczero_training/training/overfit.py index 5d6998b2..24b91d56 100644 --- a/src/lczero_training/training/overfit.py +++ b/src/lczero_training/training/overfit.py @@ -210,17 +210,15 @@ def run_phase( ) -> None: nonlocal jit_state for _ in range(num_steps): - jit_state, (loss, unweighted_losses) = training.train_step( + jit_state, metrics = training.train_step( optimizer_tx, jit_state, train_batch, ) - loss_value, unweighted_host = jax.device_get( - (loss, unweighted_losses) - ) - loss_value = float(np.asarray(loss_value)) + loss_value = float(np.asarray(metrics["loss"])) unweighted_host = tree_util.tree_map( - lambda x: float(np.asarray(x)), unweighted_host + lambda x: float(np.asarray(x)), + metrics["unweighted_losses"], ) eval_loss, eval_unweighted = eval_step( @@ -252,17 +250,15 @@ def run_phase( else: logger.info("Starting overfit loop for %d steps", num_steps) for _ in range(num_steps): - jit_state, (loss, unweighted_losses) = training.train_step( + jit_state, metrics = training.train_step( optimizer_tx, jit_state, prepared_batch_a, ) - loss_value, unweighted_host = jax.device_get( - (loss, unweighted_losses) - ) - loss_value = float(np.asarray(loss_value)) + loss_value = float(np.asarray(metrics["loss"])) unweighted_host = tree_util.tree_map( - lambda x: float(np.asarray(x)), unweighted_host + lambda x: float(np.asarray(x)), + metrics["unweighted_losses"], ) step_value = int( np.asarray(jax.device_get(jit_state.step)).flat[0] diff --git a/src/lczero_training/training/tensorboard.py b/src/lczero_training/training/tensorboard.py new file mode 100644 index 00000000..7126a2b3 --- /dev/null +++ b/src/lczero_training/training/tensorboard.py @@ -0,0 +1,152 @@ +"""Utility helpers for writing TensorBoard summaries during training.""" + +from dataclasses import dataclass +from typing import Mapping, Optional + +import jax +import jax.numpy as jnp +import numpy as np +from flax import nnx +from jax import tree_util +from tensorboardX import SummaryWriter + +from proto.data_loader_config_pb2 import DataLoaderConfig +from proto.training_config_pb2 import ScheduleConfig + + +@dataclass +class StepMetrics: + """Metrics collected for a single optimization step.""" + + step: int + learning_rate: Optional[float] + weighted_loss: float + unweighted_losses: Mapping[str, float] + gradient_norm: Optional[float] + + +class TensorboardLogger: + """Writes training metrics and configuration to TensorBoard.""" + + def __init__( + self, + logdir: Optional[str], + *, + data_loader_config: Optional[DataLoaderConfig] = None, + schedule_config: Optional[ScheduleConfig] = None, + ) -> None: + if logdir: + self._writer = SummaryWriter(logdir) + else: + self._writer = None + + self._batch_size = _extract_batch_size(data_loader_config) + self._steps_per_network = ( + schedule_config.steps_per_network + if schedule_config is not None + else None + ) + self._chunks_per_network = ( + schedule_config.chunks_per_network + if schedule_config is not None + else None + ) + + def close(self) -> None: + if self._writer is not None: + self._writer.close() + + def log_step(self, metrics: StepMetrics) -> None: + if self._writer is None: + return + + if metrics.learning_rate is not None: + self._writer.add_scalar( + "config/learning_rate", metrics.learning_rate, metrics.step + ) + + self._writer.add_scalar( + "loss/weighted", metrics.weighted_loss, metrics.step + ) + + for name, value in sorted(metrics.unweighted_losses.items()): + self._writer.add_scalar( + f"loss/unweighted/{name}", value, metrics.step + ) + + if metrics.gradient_norm is not None: + self._writer.add_scalar( + "gradients/global_norm", metrics.gradient_norm, metrics.step + ) + + self._writer.flush() + + def log_epoch(self, step: int, model_state: nnx.State) -> None: + if self._writer is None: + return + + weights = _collect_weights(model_state) + if weights is not None and weights.size: + self._writer.add_histogram("weights/distribution", weights, step) + self._writer.add_scalar( + "weights/mean", float(np.mean(weights)), step + ) + self._writer.add_scalar("weights/std", float(np.std(weights)), step) + self._writer.add_scalar("weights/min", float(np.min(weights)), step) + self._writer.add_scalar("weights/max", float(np.max(weights)), step) + + if self._batch_size is not None: + self._writer.add_scalar("config/batch_size", self._batch_size, step) + if self._steps_per_network is not None: + self._writer.add_scalar( + "config/steps_per_network", self._steps_per_network, step + ) + if self._chunks_per_network is not None: + self._writer.add_scalar( + "config/chunks_per_network", self._chunks_per_network, step + ) + + self._writer.flush() + + +def _extract_batch_size( + config: Optional[DataLoaderConfig], +) -> Optional[int]: + if config is None: + return None + + for stage in config.stage: + if stage.HasField("tensor_generator"): + generator = stage.tensor_generator + if generator.HasField("batch_size"): + return int(generator.batch_size) + return None + + +def _collect_weights(model_state: nnx.State) -> Optional[np.ndarray]: + leaves = tree_util.tree_leaves( + model_state, is_leaf=lambda node: hasattr(node, "value") + ) + arrays: list[np.ndarray] = [] + for leaf in leaves: + array = _leaf_to_array(leaf) + if array is None: + continue + flat = np.asarray(jnp.ravel(array)) + if flat.size: + arrays.append(flat) + + if not arrays: + return None + return np.concatenate(arrays) + + +def _leaf_to_array(value: object) -> Optional[jax.Array]: + if isinstance(value, jax.Array): + return value + if isinstance(value, np.ndarray): + return jnp.asarray(value) + maybe_value = getattr(value, "value", None) + if isinstance(maybe_value, (jax.Array, np.ndarray)): + return jnp.asarray(maybe_value) + return None diff --git a/src/lczero_training/training/training.py b/src/lczero_training/training/training.py index b81c6206..44a28fd2 100644 --- a/src/lczero_training/training/training.py +++ b/src/lczero_training/training/training.py @@ -4,7 +4,7 @@ import os import sys from functools import partial -from typing import Callable, Dict, Generator, Tuple, cast +from typing import Callable, Dict, Generator, Optional, Tuple, TypedDict, cast import jax import jax.numpy as jnp @@ -24,8 +24,12 @@ from lczero_training.dataloader import DataLoader, make_dataloader from lczero_training.model.loss_function import LczeroLoss from lczero_training.model.model import LczeroModel -from lczero_training.training.optimizer import make_gradient_transformation +from lczero_training.training.optimizer import ( + make_gradient_transformation, + make_lr_schedule, +) from lczero_training.training.state import JitTrainingState, TrainingState +from lczero_training.training.tensorboard import StepMetrics, TensorboardLogger from proto.root_config_pb2 import RootConfig logger = logging.getLogger(__name__) @@ -38,11 +42,17 @@ def from_dataloader( yield loader.get_next() +class TrainStepMetrics(TypedDict): + loss: jax.Array + unweighted_losses: Dict[str, jax.Array] + grad_norm: jax.Array + + class Training: optimizer_tx: optax.GradientTransformation train_step: Callable[ [optax.GradientTransformation, JitTrainingState, dict], - Tuple[JitTrainingState, Tuple[jax.Array, Dict[str, jax.Array]]], + Tuple[JitTrainingState, TrainStepMetrics], ] def __init__( @@ -76,7 +86,7 @@ def _step( optimizer_tx: optax.GradientTransformation, jit_state: JitTrainingState, batch: dict, - ) -> Tuple[JitTrainingState, Tuple[jax.Array, Dict[str, jax.Array]]]: + ) -> Tuple[JitTrainingState, TrainStepMetrics]: model = nnx.merge(graphdef, jit_state.model_state) def loss_for_grad( @@ -123,12 +133,18 @@ def mean_loss_for_grad( ) mean_unweighted = tree_util.tree_map(jnp.mean, unweighted_losses) - return new_jit_state, (mean_loss, mean_unweighted) + grad_norm = optax.global_norm(mean_grads) + + return new_jit_state, { + "loss": mean_loss, + "unweighted_losses": mean_unweighted, + "grad_norm": grad_norm, + } self.train_step = cast( Callable[ [optax.GradientTransformation, JitTrainingState, dict], - Tuple[JitTrainingState, Tuple[jax.Array, Dict[str, jax.Array]]], + Tuple[JitTrainingState, TrainStepMetrics], ], _step, ) @@ -138,6 +154,9 @@ def run( jit_state: JitTrainingState, datagen: Generator[Tuple[np.ndarray, ...], None, None], num_steps: int, + *, + tensorboard_logger: Optional[TensorboardLogger] = None, + lr_schedule: Optional[Callable[[int], jax.Array]] = None, ) -> JitTrainingState: assert jit_state.opt_state is not None for _ in range(num_steps): @@ -145,7 +164,7 @@ def run( batch = next(datagen) b_inputs, b_policy, b_values, _, b_movesleft = batch logger.info("Fetched batch from dataloader") - jit_state, (loss, unweighted_losses) = self.train_step( + jit_state, metrics = self.train_step( self.optimizer_tx, jit_state, { @@ -155,10 +174,35 @@ def run( "movesleft_targets": b_movesleft, }, ) + loss = metrics["loss"] + unweighted_losses = metrics["unweighted_losses"] + grad_norm = metrics["grad_norm"] logger.info( f"Step {jit_state.step}, Loss: {loss}, Unweighted losses:" f" {unweighted_losses}" ) + if tensorboard_logger is not None: + step_value = int(jit_state.step) + lr_value: Optional[float] = None + if lr_schedule is not None: + lr_value = float(lr_schedule(step_value)) + tensorboard_logger.log_step( + StepMetrics( + step=step_value, + learning_rate=lr_value, + weighted_loss=float(loss), + unweighted_losses={ + name: float(value) + for name, value in unweighted_losses.items() + }, + gradient_norm=float(grad_norm), + ) + ) + if tensorboard_logger is not None: + tensorboard_logger.log_epoch( + step=int(jit_state.step), + model_state=jit_state.model_state, + ) return jit_state @@ -206,16 +250,33 @@ def train(config_filename: str) -> None: config.training.optimizer, max_grad_norm=getattr(config.training, "max_grad_norm", 0.0), ) + lr_schedule = make_lr_schedule(config.training.optimizer) training = Training( optimizer_tx=optimizer_tx, graphdef=model, loss_fn=LczeroLoss(config=config.training.losses), ) - new_state = training.run( - jit_state, - from_dataloader(make_dataloader(config.data_loader)), - config.training.schedule.steps_per_network, - ) + tensorboard_logger: Optional[TensorboardLogger] = None + if config.export.HasField("tensorboard_path"): + data_loader_config = ( + config.data_loader if config.HasField("data_loader") else None + ) + tensorboard_logger = TensorboardLogger( + config.export.tensorboard_path, + data_loader_config=data_loader_config, + schedule_config=config.training.schedule, + ) + try: + new_state = training.run( + jit_state, + from_dataloader(make_dataloader(config.data_loader)), + config.training.schedule.steps_per_network, + tensorboard_logger=tensorboard_logger, + lr_schedule=lr_schedule, + ) + finally: + if tensorboard_logger is not None: + tensorboard_logger.close() if config.export.HasField("path"): date_str = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") diff --git a/uv.lock b/uv.lock index f3a5f020..f6a57d08 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.11" resolution-markers = [ "python_full_version >= '3.13'", @@ -778,6 +778,7 @@ dependencies = [ { name = "pytest" }, { name = "python-dotenv" }, { name = "requests", extra = ["socks"] }, + { name = "tensorboardx" }, { name = "textual" }, ] @@ -817,6 +818,7 @@ requires-dist = [ { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, { name = "python-dotenv", specifier = ">=1.1.1" }, { name = "requests", extras = ["socks"], specifier = ">=2.32.5" }, + { name = "tensorboardx", specifier = ">=2.6.4" }, { name = "textual", extras = ["dev"], specifier = ">=0.47.0" }, { name = "typing-extensions", marker = "extra == 'dev'", specifier = ">=4.0.0" }, ] @@ -1960,6 +1962,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e9/44/75a9c9421471a6c4805dbf2356f7c181a29c1879239abab1ea2cc8f38b40/sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2", size = 10235, upload-time = "2024-02-25T23:20:01.196Z" }, ] +[[package]] +name = "tensorboardx" +version = "2.6.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "protobuf" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2b/c5/d4cc6e293fb837aaf9f76dd7745476aeba8ef7ef5146c3b3f9ee375fe7a5/tensorboardx-2.6.4.tar.gz", hash = "sha256:b163ccb7798b31100b9f5fa4d6bc22dad362d7065c2f24b51e50731adde86828", size = 4769801, upload-time = "2025-06-10T22:37:07.419Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/1d/b5d63f1a6b824282b57f7b581810d20b7a28ca951f2d5b59f1eb0782c12b/tensorboardx-2.6.4-py3-none-any.whl", hash = "sha256:5970cf3a1f0a6a6e8b180ccf46f3fe832b8a25a70b86e5a237048a7c0beb18e2", size = 87201, upload-time = "2025-06-10T22:37:05.44Z" }, +] + [[package]] name = "tensorstore" version = "0.1.76"