Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@
feeding training data.
* [Training Tuple Format](training_tuple.md) — A description of the training
tuple format used in the project. This is an interface between data loader and
model training.
model training.
* [TensorBoard Metrics](tensorboard.md) — Instructions for enabling and
browsing training metrics that are exported during training.

40 changes: 40 additions & 0 deletions docs/tensorboard.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# TensorBoard Metrics

Training can export metrics that are compatible with TensorBoard. To enable
logging, set the `tensorboard_path` field inside the `export` section of the
root configuration. The trainer writes TensorBoard event files into the
specified directory while training runs.

```protobuf
export {
path: "checkpoints"
tensorboard_path: "logs/train"
}
```

## Recorded data

The trainer records the following statistics:

* **Per step**
* Learning rate that was applied for the step.
* Weighted loss value returned by the loss function.
* Unweighted loss components for value, policy, and moves left heads.
* Global gradient norm after clipping.
* **Per epoch**
* Histogram of all model weights together with mean, standard deviation,
minimum, and maximum scalars.
* Configuration scalars: batch size, steps per network, and chunks per
network.

## Viewing the dashboard

Launch TensorBoard and point it to the directory configured above:

```bash
tensorboard --logdir logs/train
```

The dashboard contains grouped scalars for configuration parameters, losses,
and gradient norms, and a histogram of the weight distribution after each
training epoch.
2 changes: 2 additions & 0 deletions proto/export_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ message ExportConfig {
optional string path = 1;
// Training run ID for uploading to training website. Only uploads when set.
optional int32 upload_training_run = 2;
// Directory where TensorBoard event files should be written.
optional string tensorboard_path = 3;
}
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
"python-dotenv>=1.1.1",
"requests[socks]>=2.32.5",
"matplotlib>=3.10.6",
"tensorboardx>=2.6.4",
]

[project.optional-dependencies]
Expand Down Expand Up @@ -65,6 +66,10 @@ ignore_missing_imports = true
module = "optax"
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = "tensorboardX"
ignore_missing_imports = true

[tool.pytest.ini_options]
testpaths = ["src"]
python_files = ["test_*.py", "*_test.py"]
Expand Down
20 changes: 8 additions & 12 deletions src/lczero_training/training/overfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,15 @@ def run_phase(
) -> None:
nonlocal jit_state
for _ in range(num_steps):
jit_state, (loss, unweighted_losses) = training.train_step(
jit_state, metrics = training.train_step(
optimizer_tx,
jit_state,
train_batch,
)
loss_value, unweighted_host = jax.device_get(
(loss, unweighted_losses)
)
loss_value = float(np.asarray(loss_value))
loss_value = float(np.asarray(metrics["loss"]))
unweighted_host = tree_util.tree_map(
lambda x: float(np.asarray(x)), unweighted_host
lambda x: float(np.asarray(x)),
metrics["unweighted_losses"],
)

eval_loss, eval_unweighted = eval_step(
Expand Down Expand Up @@ -252,17 +250,15 @@ def run_phase(
else:
logger.info("Starting overfit loop for %d steps", num_steps)
for _ in range(num_steps):
jit_state, (loss, unweighted_losses) = training.train_step(
jit_state, metrics = training.train_step(
optimizer_tx,
jit_state,
prepared_batch_a,
)
loss_value, unweighted_host = jax.device_get(
(loss, unweighted_losses)
)
loss_value = float(np.asarray(loss_value))
loss_value = float(np.asarray(metrics["loss"]))
unweighted_host = tree_util.tree_map(
lambda x: float(np.asarray(x)), unweighted_host
lambda x: float(np.asarray(x)),
metrics["unweighted_losses"],
)
step_value = int(
np.asarray(jax.device_get(jit_state.step)).flat[0]
Expand Down
152 changes: 152 additions & 0 deletions src/lczero_training/training/tensorboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
"""Utility helpers for writing TensorBoard summaries during training."""

from dataclasses import dataclass
from typing import Mapping, Optional

import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from jax import tree_util
from tensorboardX import SummaryWriter

from proto.data_loader_config_pb2 import DataLoaderConfig
from proto.training_config_pb2 import ScheduleConfig


@dataclass
class StepMetrics:
"""Metrics collected for a single optimization step."""

step: int
learning_rate: Optional[float]
weighted_loss: float
unweighted_losses: Mapping[str, float]
gradient_norm: Optional[float]


class TensorboardLogger:
"""Writes training metrics and configuration to TensorBoard."""

def __init__(
self,
logdir: Optional[str],
*,
data_loader_config: Optional[DataLoaderConfig] = None,
schedule_config: Optional[ScheduleConfig] = None,
) -> None:
if logdir:
self._writer = SummaryWriter(logdir)
else:
self._writer = None

self._batch_size = _extract_batch_size(data_loader_config)
self._steps_per_network = (
schedule_config.steps_per_network
if schedule_config is not None
else None
)
self._chunks_per_network = (
schedule_config.chunks_per_network
if schedule_config is not None
else None
)

def close(self) -> None:
if self._writer is not None:
self._writer.close()

def log_step(self, metrics: StepMetrics) -> None:
if self._writer is None:
return

if metrics.learning_rate is not None:
self._writer.add_scalar(
"config/learning_rate", metrics.learning_rate, metrics.step
)

self._writer.add_scalar(
"loss/weighted", metrics.weighted_loss, metrics.step
)

for name, value in sorted(metrics.unweighted_losses.items()):
self._writer.add_scalar(
f"loss/unweighted/{name}", value, metrics.step
)

if metrics.gradient_norm is not None:
self._writer.add_scalar(
"gradients/global_norm", metrics.gradient_norm, metrics.step
)

self._writer.flush()

def log_epoch(self, step: int, model_state: nnx.State) -> None:
if self._writer is None:
return

weights = _collect_weights(model_state)
if weights is not None and weights.size:
self._writer.add_histogram("weights/distribution", weights, step)
self._writer.add_scalar(
"weights/mean", float(np.mean(weights)), step
)
self._writer.add_scalar("weights/std", float(np.std(weights)), step)
self._writer.add_scalar("weights/min", float(np.min(weights)), step)
self._writer.add_scalar("weights/max", float(np.max(weights)), step)

if self._batch_size is not None:
self._writer.add_scalar("config/batch_size", self._batch_size, step)
if self._steps_per_network is not None:
self._writer.add_scalar(
"config/steps_per_network", self._steps_per_network, step
)
if self._chunks_per_network is not None:
self._writer.add_scalar(
"config/chunks_per_network", self._chunks_per_network, step
)

self._writer.flush()


def _extract_batch_size(
config: Optional[DataLoaderConfig],
) -> Optional[int]:
if config is None:
return None

for stage in config.stage:
if stage.HasField("tensor_generator"):
generator = stage.tensor_generator
if generator.HasField("batch_size"):
return int(generator.batch_size)
return None


def _collect_weights(model_state: nnx.State) -> Optional[np.ndarray]:
leaves = tree_util.tree_leaves(
model_state, is_leaf=lambda node: hasattr(node, "value")
)
arrays: list[np.ndarray] = []
for leaf in leaves:
array = _leaf_to_array(leaf)
if array is None:
continue
flat = np.asarray(jnp.ravel(array))
if flat.size:
arrays.append(flat)

if not arrays:
return None
return np.concatenate(arrays)


def _leaf_to_array(value: object) -> Optional[jax.Array]:
if isinstance(value, jax.Array):
return value
if isinstance(value, np.ndarray):
return jnp.asarray(value)
maybe_value = getattr(value, "value", None)
if isinstance(maybe_value, (jax.Array, np.ndarray)):
return jnp.asarray(maybe_value)
return None
Loading