diff --git a/dependencies/requirements/base_requirements/requirements.txt b/dependencies/requirements/base_requirements/requirements.txt index 582d99c3d7..c40252cfc1 100644 --- a/dependencies/requirements/base_requirements/requirements.txt +++ b/dependencies/requirements/base_requirements/requirements.txt @@ -4,6 +4,7 @@ array-record cloud-accelerator-diagnostics cloud-tpu-diagnostics datasets +drjax flax gcsfs google-api-python-client diff --git a/dependencies/requirements/generated_requirements/cuda12-requirements.txt b/dependencies/requirements/generated_requirements/cuda12-requirements.txt index 00efbc3b1c..9879536ab1 100644 --- a/dependencies/requirements/generated_requirements/cuda12-requirements.txt +++ b/dependencies/requirements/generated_requirements/cuda12-requirements.txt @@ -40,6 +40,7 @@ dill>=0.4.0 distlib>=0.4.0 dm-tree>=0.1.9 docstring-parser>=0.17.0 +drjax>=0.1.4 editdistance>=0.8.1 einops>=0.8.1 einshape>=1.0 diff --git a/dependencies/requirements/generated_requirements/tpu-requirements.txt b/dependencies/requirements/generated_requirements/tpu-requirements.txt index 4569a54438..7cc76ac137 100644 --- a/dependencies/requirements/generated_requirements/tpu-requirements.txt +++ b/dependencies/requirements/generated_requirements/tpu-requirements.txt @@ -40,6 +40,7 @@ dill>=0.4.0 distlib>=0.4.0 dm-tree>=0.1.9 docstring-parser>=0.17.0 +drjax>=0.1.4 editdistance>=0.8.1 einops>=0.8.1 einshape>=1.0 diff --git a/dependencies/requirements/requirements.txt b/dependencies/requirements/requirements.txt index 439e0e3a75..7ae9f9114a 100644 --- a/dependencies/requirements/requirements.txt +++ b/dependencies/requirements/requirements.txt @@ -4,6 +4,7 @@ array-record cloud-accelerator-diagnostics cloud-tpu-diagnostics datasets +drjax>=0.1.4 flax gcsfs google-api-python-client diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 7cc3cb5b1b..0781c34665 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -383,7 +383,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' # Parallelism shard_mode: "auto" # can be either auto or explicit -mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] +mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], @@ -460,6 +460,7 @@ logical_axis_rules: [ ['paged_kv_head_dim_size', []], ['dense_layers', []], ['moe_layers', []], + ['diloco', 'diloco'], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']] @@ -472,6 +473,7 @@ sharding_tolerance: 0.02 # value to auto-shard based on available slices and devices. # By default, product of the DCN axes should equal number of slices # and product of the ICI axes should equal number of devices per slice. 
+dcn_diloco_parallelism: 1 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded dcn_fsdp_parallelism: 1 dcn_fsdp_transpose_parallelism: 1 @@ -484,6 +486,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended dcn_pipeline_parallelism: 1 dcn_expert_parallelism: 1 dcn_autoregressive_parallelism: 1 # never recommended +ici_diloco_parallelism: 1 ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded ici_fsdp_transpose_parallelism: 1 @@ -696,6 +699,11 @@ enable_data_shuffling: True data_shuffle_seed: 0 init_weights_seed: 0 +# DiLoCo params. +diloco_sync_period: 36 +diloco_outer_lr: 0.3 +diloco_outer_momentum: 0.9 + # You may disable clipping by setting gradient_clipping_threshold to zero. gradient_clipping_threshold: 1.0 diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index e37cfdeef0..048b86bc49 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -733,6 +733,7 @@ class LayoutAndSharding(BaseModel): class DcnParallelism(BaseModel): """Parallelism dimensions across the DCN (Data Center Network).""" + dcn_diloco_parallelism: int = Field(-1, description="DCN axis for Diloco parallelism.") dcn_data_parallelism: int = Field(-1, description="DCN axis for data parallelism.") dcn_fsdp_parallelism: int = Field(1, description="DCN axis for FSDP.") dcn_fsdp_transpose_parallelism: int = Field(1, description="DCN axis for FSDP transpose.") @@ -752,6 +753,7 @@ class DcnParallelism(BaseModel): class IciParallelism(BaseModel): """Parallelism dimensions within the ICI (Inter-Chip Interconnect).""" + ici_diloco_parallelism: int = Field(-1, description="ICI axis for Diloco parallelism.") ici_data_parallelism: int = Field(1, description="ICI axis for data parallelism.") ici_fsdp_parallelism: int = Field(-1, description="ICI axis for FSDP.") ici_fsdp_transpose_parallelism: int = Field(1, description="ICI axis for FSDP transpose.") @@ -1000,6 +1002,14 @@ class TrainingLoop(BaseModel): init_weights_seed: int = Field(0, description="Seed for model weight initialization.") +class DilocoParams(BaseModel): + """Diloco Hyperparameters""" + + diloco_sync_period: int = Field(36, description="Diloco sync period.") + diloco_outer_lr: float = Field(0.3, description="learning rate for outer optimizer.") + diloco_outer_momentum: float = Field(0.9, description="momentum for outer optimizer.") + + class Optimizer(BaseModel): """Configuration for the optimizer and learning rate schedule.""" @@ -1486,6 +1496,11 @@ class DerivedValues(BaseModel): description="Effective number of query heads, scaled by `global_parameter_scale`.", ) + num_diloco_replicas: None | int = Field( + None, + description="The number of diloco replicas, derived from ICI and DCN values.", + ) + ici_parallelism: None | list[int] = Field( None, description="Aggregated list of all ICI parallelism values for legacy compatibility.", @@ -1631,6 +1646,7 @@ class MaxTextConfig( # Training, Optimization, and Fine-Tuning RematAndOffload, TrainingLoop, + DilocoParams, Optimizer, AdamW, Muon, @@ -2152,6 +2168,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility. 
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage": self.ici_parallelism = [ + self.ici_diloco_parallelism, self.ici_pipeline_parallelism, self.ici_data_parallelism, self.ici_fsdp_parallelism, @@ -2166,6 +2183,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de self.ici_autoregressive_parallelism, ] self.dcn_parallelism = [ + self.dcn_diloco_parallelism, self.dcn_pipeline_parallelism, self.dcn_data_parallelism, self.dcn_fsdp_parallelism, @@ -2181,6 +2199,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de ] else: ici_map = { + "diloco": self.ici_diloco_parallelism, "data": self.ici_data_parallelism, "stage": self.ici_pipeline_parallelism, "fsdp": self.ici_fsdp_parallelism, @@ -2198,6 +2217,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes] dcn_map = { + "diloco": self.dcn_diloco_parallelism, "data": self.dcn_data_parallelism, "stage": self.dcn_pipeline_parallelism, "fsdp": self.dcn_fsdp_parallelism, @@ -2214,6 +2234,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de } self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes] + # Diloco params + self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism) + # Final string-to-enum conversions if they haven't been coerced by pydantic yet. if isinstance(self.decoder_block, str): self.decoder_block = DecoderBlockType(self.decoder_block.lower()) diff --git a/src/MaxText/diloco.py b/src/MaxText/diloco.py new file mode 100644 index 0000000000..7d137e466e --- /dev/null +++ b/src/MaxText/diloco.py @@ -0,0 +1,201 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An implementation of Distributed Low-Communication (DiLoCo) training. + +This module contains implementations of: + +- DiLoCo: Distributed Low-Communication Training of Language Models + https://arxiv.org/abs/2311.08105 +- Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch + https://arxiv.org/abs/2501.18512 +""" + +from collections.abc import Sequence +from typing import Any, Callable + +import drjax +from flax import struct +from flax.training import train_state +import jax +import jax.numpy as jnp +from jaxtyping import Array, Int32, Key, PyTree, UInt32 +import optax + +from MaxText import pyconfig + +Batch = Any +Params = PyTree +Metrics = PyTree +OptState = optax.OptState +InnerOptStates = optax.OptState +PRNGKey = Key[Array, ""] | UInt32[Array, "2"] +Step = Int32[Array, ""] + + +class DiLoCoTrainState(struct.PyTreeNode): + """The state of the DiLoCo training process. + + Attributes: + inner_state: A `flax.training.train_state.TrainState` of the state for each + step of the inner optimization. All arrays are expected to have a leading + dimension with size of the number of diloco replicas so that training + steps can be mapped over this dimension. 
+    outer_params: A PyTree of the global model weights. These mirror the
+      `params` sub-PyTree of `inner_state`, but without the leading replica
+      dimension (one rank lower).
+    outer_opt_state: The state for the outer Nesterov momentum optimizer.
+    step: The step counter of the training process.
+  """
+
+  inner_state: train_state.TrainState
+  outer_params: Params
+  outer_opt_state: OptState
+  step: Step
+
+
+def reshape_first_axis_with_diloco(num_diloco_replicas: int, pytree: PyTree) -> PyTree:
+  """Reshapes the first dimension of each array in the PyTree to include a DiLoCo axis.
+
+  This function takes a batch of data represented as a PyTree
+  and reshapes the leading dimension of each array within it. The purpose is
+  to introduce a new 'diloco' axis, which is used for distributing data
+  across DiLoCo replicas.
+
+  Args:
+    num_diloco_replicas: The number of DiLoCo replicas. This determines the
+      size of the new leading dimension.
+    pytree: The input PyTree, where each array is expected to have a batch
+      dimension as its first axis.
+
+  Returns:
+    A new PyTree with the same structure as the input, but with each array's
+    first dimension reshaped to `(num_diloco_replicas, original_batch_dim // num_diloco_replicas, ...)`.
+    The sharding specification is also updated to include the 'diloco' axis.
+  """
+
+  def extend_pspec(pspec: jax.sharding.PartitionSpec | Sequence[str | Sequence[str]] = ()) -> jax.sharding.PartitionSpec:
+    pspec = tuple(pspec)
+    first = pspec[0] if pspec else None
+    first_names = tuple(first) if isinstance(first, (tuple, list)) else (first,)
+    if first_names[0] == "diloco":
+      # 'diloco' already shards the leading (batch) dim; hoist it onto the new
+      # leading dim and keep any remaining names on the per-replica batch dim.
+      rest_of_first = tuple(first_names[1:]) or None
+      return jax.sharding.PartitionSpec("diloco", rest_of_first, *pspec[1:])
+    return jax.sharding.PartitionSpec("diloco", *pspec)
+
+  def reshape_for_diloco(arr):
+    batch_dim, *example_shape = arr.shape
+    diloco_shape = (num_diloco_replicas, batch_dim // num_diloco_replicas, *example_shape)
+    s = arr.sharding
+    s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
+    return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)
+
+  return jax.tree.map(reshape_for_diloco, pytree)
+
+
+def build_diloco_state(
+    config: "pyconfig.HyperParameters",
+    initialize_state: Callable[[], train_state.TrainState],
+) -> tuple[DiLoCoTrainState, PyTree]:
+  """Given a non-DiLoCo train state, construct a DiLoCo training state."""
+  outer_optimizer = optax.sgd(
+      config.diloco_outer_lr,
+      momentum=config.diloco_outer_momentum,
+      nesterov=True,
+  )
+
+  @drjax.program(placements={"diloco": config.num_diloco_replicas})
+  def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
+    state = initialize_state()
+    # Inner state must be broadcast across clients.
+    inner_state = drjax.broadcast(state)
+    # Outer state retains a single copy of the model parameters and optimizer state.
+    outer_params = state.params
+    outer_opt_state = outer_optimizer.init(outer_params)
+    outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state)
+    return (
+        DiLoCoTrainState(
+            inner_state=inner_state, outer_params=outer_params, outer_opt_state=outer_opt_state, step=state.step
+        ),
+        outer_opt_state_sharding,
+    )
+
+  return init_diloco_state()
+
+
+def build_diloco_train_step(
+    config: pyconfig.HyperParameters,
+    train_step: Callable[[train_state.TrainState, Batch, PRNGKey], tuple[train_state.TrainState, Metrics]],
+) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]:
+  """Convert a local state and train step into DiLoCo-compatible versions.
+ + This is an implementation of the original (non-streaming) DiLoCo algorithm + which syncs all model parameters across the replicas every + `config.diloco_sync_period` steps, treating the difference accumulated over + non-sync steps as a pseudo gradient and applying SGD with Nesterov momentum on + the "global" model. + + Args: + config: The config used to set up training. + train_step: A local train step. This will be executed independently within + each replica. + """ + outer_optimizer = optax.sgd( + config.diloco_outer_lr, + momentum=config.diloco_outer_momentum, + nesterov=True, + ) + + def synchronize(state): + # Calculate the delta between the current replica's state and the global + # state (since last synchronization). + broadcast_outer_params = drjax.broadcast(state.outer_params) + model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params) + # Treat the average delta as the outer optimizer's gradient and apply to + # the global (outer) model params. + averaged_pseudo_grad = drjax.reduce_mean(model_delta) + updates, new_opt_state = outer_optimizer.update(averaged_pseudo_grad, state.outer_opt_state, state.outer_params) + new_outer_params = optax.apply_updates(state.outer_params, updates) + # Replace inner model params with the new global model params. + # NOTE: inner optimizer state is retained despite the change in parameters, + # see section 6.1 in https://arxiv.org/pdf/2311.08105. + new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state) + return state.replace( + outer_params=new_outer_params, + outer_opt_state=new_opt_state, + inner_state=new_inner_state, + ) + + def typed_reduce_mean(in_tree): + total = drjax.reduce_sum(in_tree) + avg = jax.tree.map(lambda x: (x / config.num_diloco_replicas).astype(x.dtype), total) + return avg + + @drjax.program(placements={"diloco": config.num_diloco_replicas}) + def diloco_train_step(state, batch, prng): + # Broadcast the RNG across replicas. + broadcast_rng = drjax.broadcast(prng) + inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng)) + avg_metrics = typed_reduce_mean(metrics) + state = state.replace( + inner_state=inner_state, + step=inner_state.step[0], + ) + # Either synchronize the model, or no-op, depending on whether the current + # step falls on the synchronization period. + state = jax.lax.cond( + inner_state.step[0] % config.diloco_sync_period == 0, + synchronize, + lambda x: x, # no-op + state, + ) + return state, avg_metrics + + return diloco_train_step diff --git a/tests/diloco_test.py b/tests/diloco_test.py new file mode 100644 index 0000000000..2922a3fa67 --- /dev/null +++ b/tests/diloco_test.py @@ -0,0 +1,290 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Tests for the DiLoCo implementation in diloco.py"""
+
+
+import os
+import unittest
+
+import chex
+from flax import nnx
+from flax.training import train_state
+import jax
+import jax.numpy as jnp
+import jax.sharding
+import numpy as np
+import optax
+import pytest
+
+from MaxText import diloco
+from MaxText.pyconfig import initialize_pydantic
+from MaxText.globals import MAXTEXT_REPO_ROOT
+
+_BASE_CONFIG_PATH = os.path.join(MAXTEXT_REPO_ROOT, "src", "MaxText", "configs", "base.yml")
+
+
+class SimpleNNXModel(nnx.Module):
+  """A minimal NNX model for testing."""
+
+  def __init__(self, *, rngs: nnx.Rngs):
+    self.dense = nnx.Linear(
+        2,
+        1,
+        kernel_init=nnx.initializers.constant(jnp.asarray([[2.0], [1.0]])),
+        bias_init=nnx.initializers.ones_init(),
+        rngs=rngs,
+    )
+
+  def __call__(self, x):
+    return self.dense(x)
+
+
+class DiLoCoTest(unittest.TestCase):
+
+  @pytest.mark.tpu_only
+  def test_diloco_training_simulation_with_mesh(self):
+    """Runs a simulation of DiLoCo training on a mesh and asserts correctness."""
+    num_replicas = 2
+    num_steps = 4
+
+    devices = jax.devices()
+    if len(devices) < num_replicas:
+      self.skipTest(f"Test requires {num_replicas} devices, but only {len(devices)} are available.")
+
+    mesh_devices = np.array(devices[:num_replicas]).reshape(1, num_replicas)
+    mesh = jax.sharding.Mesh(mesh_devices, axis_names=("data", "diloco"))
+
+    test_config = initialize_pydantic(
+        [
+            "",
+            _BASE_CONFIG_PATH,
+            f"dcn_diloco_parallelism={num_replicas}",
+            "ici_diloco_parallelism=1",
+            "diloco_outer_momentum=0.9",
+            "diloco_outer_lr=1.0",
+            f"diloco_sync_period={num_steps-1}",
+        ]
+    )
+
+    with mesh:
+      tx = optax.sgd(learning_rate=0.1)
+      rngs = nnx.Rngs(params=jax.random.key(seed=42))
+      model = SimpleNNXModel(rngs=rngs)
+      graphdef, params = nnx.split(model)
+
+      def nnx_apply_fn(params, inputs):
+        model_replica = nnx.merge(graphdef, params)
+        return model_replica(inputs)
+
+      # Vmap the wrapper over the per-replica batch dimension.
+      vmapped_apply = jax.vmap(nnx_apply_fn, in_axes=(None, 0))
+
+      def _test_train_step(state: train_state.TrainState, batch, prng_key: diloco.PRNGKey):
+        """A simple MSE loss train step to enable numerics testing."""
+        del prng_key
+
+        def loss_fn(params, batch):
+          inputs, labels = batch
+          logits = vmapped_apply(params, inputs)
+          residual = logits - labels
+          sq_residual = jnp.square(residual)
+          msq_residual = jnp.mean(sq_residual)
+          return msq_residual
+
+        loss, grad = jax.value_and_grad(loss_fn)(state.params, batch)
+        return state.apply_gradients(grads=grad), loss
+
+      initial_test_state = train_state.TrainState.create(
+          apply_fn=vmapped_apply,
+          params=params,
+          tx=tx,
+      )
+
+      diloco_test_state, _ = diloco.build_diloco_state(test_config, lambda: initial_test_state)
+      chex.assert_equal(diloco_test_state.step, 0)
+      chex.assert_trees_all_equal(diloco_test_state.outer_params, initial_test_state.params)
+
+      diloco_train_step = diloco.build_diloco_train_step(test_config, _test_train_step)
+      inputs = jnp.array(
+          [
+              [[0.0, 1.0], [1.0, 0.0]],  # First replica inputs.
+              [[1.0, 0.0], [0.0, 1.0]],  # Second replica inputs.
+ ] + ) + labels = jnp.array( + [ + [[1.0], [2.0]], # First replica labels. + [[2.0], [3.0]], # Second replica labels. + ] + ) + + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "diloco")) + inputs = jax.device_put(inputs, sharding) + labels = jax.device_put(labels, sharding) + + # Run the first step (no synchronization). + # Replica 0: + # Data: [[0, 1], [1, 0]] + # Labels: [[1], [2]] + # Weights: w = [[2], [1]] + # Bias: b = [1] + # Loss = mean((y - pred)^2) = + # = mean( ([[1], [2]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[1], [2]] - ([[0, 1], [1, 0]] . [[2], [1]] + [1])) ^ 2 ) + # = mean( ([[1], [2]] - [[2], [3]]) ^ 2 ) + # = mean( ([-1, 1]) ^ 2 ) = mean( [1, 1] ) + # = 1.0 + # + # Replica 1: + # Data: [[1, 0], [0, 1]] + # Labels: [[2], [3]] + # Weights: w = [[2], [1]] + # Bias: b = [1] + # Loss = mean((y - pred)^2) = + # = mean( ([[2], [3]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[2], [3]] - ([[1, 0], [0, 1]] . [[2], [1]] + [1])) ^ 2 ) + # = mean( ([[2], [3]] - [[3], [2]]) ^ 2 ) + # = mean( ([-1, 1]) ^ 2 ) = mean( [1, 1] ) + # = 1.0 + diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42)) + chex.assert_equal(diloco_test_state.step, 1.0) + chex.assert_equal(loss, 1.0) + # Assert no updates to the global model yet (no synchronization) + chex.assert_trees_all_equal(diloco_test_state.outer_params, initial_test_state.params) + + # Run the second step (no synchronization). + # Replica 0: + # Data: [[0, 1], [1, 0]] + # Labels: [[1], [2]] + # Weights: w = [[1.9], [0.9]] + # Bias: b = [0.8] + # Loss = mean((y - pred)^2) = + # = mean( ([[1], [2]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[1], [2]] - ([[0, 1], [1, 0]] . [[1.9], [0.9]] + [0.8])) ^ 2 ) + # = mean( ([[1], [2]] - [[1.7], [2.7]]) ^ 2 ) + # = mean( ([-0.7, 0.7]) ^ 2 ) = mean( [0.49, 0.49] ) + # = 0.49 + # + # Replica 1: + # Data: [[1, 0], [0, 1]] + # Labels: [[2], [3]] + # Weights: w = [[1.9], [1.1]] + # Bias: b = [1] + # Loss = mean((y - pred)^2) = + # = mean( ([[2], [3]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[2], [3]] - ([[1, 0], [0, 1]] . [[1.9], [1.1]] + [1])) ^ 2 ) + # = mean( ([[2], [3]] - [[2.9], [2.1]]) ^ 2 ) + # = mean( ([-0.9, 0.9]) ^ 2 ) = mean( [0.81, 0.81] ) + # = 0.81 + diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42)) + chex.assert_equal(diloco_test_state.step, 2.0) + chex.assert_trees_all_close(loss, 0.65) + # Assert no updates to the global model yet (no synchronization) + chex.assert_trees_all_equal(diloco_test_state.outer_params, initial_test_state.params) + + # Run the third step, which synchronizes afterwards. + # Replica 0: + # Data: [[0, 1], [1, 0]] + # Labels: [[1], [2]] + # Weights: w = [[1.83], [0.83]] + # Bias: b = [0.66] + # Loss = mean((y - pred)^2) = + # = mean( ([[1], [2]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[1], [2]] - ([[0, 1], [1, 0]] . [[1.83], [0.83]] + [0.66])) ^ 2 ) + # = mean( ([[1], [2]] - [[1.49], [2.49]]) ^ 2 ) + # = mean( ([-0.49, 0.49]) ^ 2 ) = mean( [0.2401, 0.2401] ) + # = 0.2401 + # + # Replica 1: + # Data: [[1, 0], [0, 1]] + # Labels: [[2], [3]] + # Weights: w = [[1.81], [1.19]] + # Bias: b = [1.] + # Loss = mean((y - pred)^2) = + # = mean( ([[2], [3]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[2], [3]] - ([[1, 0], [0, 1]] . 
[[1.81], [1.19]] + [1])) ^ 2 ) + # = mean( ([[2], [3]] - [[2.81], [2.19]]) ^ 2 ) + # = mean( ([-0.81, 0.81]) ^ 2 ) = mean( [0.6561, 0.6561] ) + # = 0.6561 + # + # After these are averaged, the model differences are computed to create a + # pseudo-gradient update to the outer_params and applied via a momentum + # based outer optimizer. + diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42)) + chex.assert_equal(diloco_test_state.step, 3.0) + chex.assert_trees_all_close(loss, 0.4481) + # Assert that inner and outer parameters are all equal now that + # synchronization has happened. + chex.assert_trees_all_equal( + diloco_test_state.outer_params, + jax.tree.map(lambda arr: arr[0, ...], diloco_test_state.inner_state.params), + ) + chex.assert_trees_all_equal( + diloco_test_state.outer_params, + jax.tree.map(lambda arr: arr[1, ...], diloco_test_state.inner_state.params), + ) + + # Run the fourth step (no synchronization). + # Replica 0: + # Data: [[0, 1], [1, 0]] + # Labels: [[1], [2]] + # Weights: w = [[1.5345], [1.0494]] + # Bias: b = [0.5839] + # Loss = mean((y - pred)^2) = + # = mean( ([[1], [2]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[1], [2]] - ([[0, 1], [1, 0]] . [[1.5345], [1.0494]]] + [0.5839])) ^ 2 ) + # = mean( ([[1], [2]] - [[1.6333], [2.1184]]) ^ 2 ) + # = mean( ([-0.6333, 0.1184]) ^ 2 ) = mean( [0.4010, 0.0140] ) + # ~ 0.2075 + # + # Replica 1: + # Data: [[1, 0], [0, 1]] + # Labels: [[2], [3]] + # Weights: w = [[1.5345], [1.0494]] + # Bias: b = [0.5839] + # Loss = mean((y - pred)^2) = + # = mean( ([[2], [3]] - (x . w + b)) ^ 2 ) ) + # = mean( ([[2], [3]] - ([[1, 0], [0, 1]] . [[1.5345], [1.0494]] + [0.5839])) ^ 2 ) + # = mean( ([[2], [3]] - [[2.1184], [1.6333]]) ^ 2 ) + # = mean( ([-0.1184, 1.3667]) ^ 2 ) = mean( [0.0140, 1.8678] ) + # ~ 0.94 + step_three_outer_params = diloco_test_state.outer_params + diloco_test_state, loss = diloco_train_step(diloco_test_state, (inputs, labels), jax.random.key(seed=42)) + chex.assert_equal(diloco_test_state.step, 4.0) + chex.assert_trees_all_close(loss, 0.574244) + # Assert no updates to the global model since previous step (no + # synchronization). + chex.assert_trees_all_equal(diloco_test_state.outer_params, step_three_outer_params)
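
Editor's note: the three sketches below are not part of the diff. They are small illustrations of how the new pieces fit together, with every identifier that does not appear in the change above called out as an assumption.

First, how the new `diloco` mesh axis and the derived `num_diloco_replicas = ici_diloco_parallelism * dcn_diloco_parallelism` (computed in types.py) relate to the device mesh. The CPU device-count flag and the axis names chosen here are illustrative only; real runs rely on MaxText's own mesh construction from base.yml.

# Illustrative sketch only (not MaxText code): a mesh whose leading axis is 'diloco'.
import os
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")  # fake 8 CPU devices

import jax
import numpy as np

ici_diloco_parallelism = 2  # hypothetical: DiLoCo replicas within a slice
dcn_diloco_parallelism = 1  # hypothetical: DiLoCo replicas across slices
num_diloco_replicas = ici_diloco_parallelism * dcn_diloco_parallelism  # mirrors the types.py derivation

devices = np.array(jax.devices()).reshape(num_diloco_replicas, -1)
mesh = jax.sharding.Mesh(devices, axis_names=("diloco", "fsdp"))
print(mesh.shape)  # {'diloco': 2, 'fsdp': 4} with the fake 8-device CPU backend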
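Second, a standalone numeric sketch of the outer update performed by `synchronize` in diloco.py: the averaged difference between the global (outer) parameters and each replica's inner parameters is treated as a gradient for SGD with Nesterov momentum. The toy parameter values are made up; the optimizer hyperparameters are the defaults added to base.yml.

# Runnable sketch (not from the diff) of the DiLoCo outer step on a toy PyTree.
import jax
import jax.numpy as jnp
import optax

outer_lr, outer_momentum = 0.3, 0.9  # diloco_outer_lr / diloco_outer_momentum defaults
outer_params = {"w": jnp.array([1.0, 1.0])}
# Two replicas that drifted apart over diloco_sync_period inner steps (leading dim = replica).
inner_params = {"w": jnp.array([[0.8, 1.2], [0.6, 1.0]])}

outer_opt = optax.sgd(outer_lr, momentum=outer_momentum, nesterov=True)
opt_state = outer_opt.init(outer_params)

# Pseudo-gradient: outer - mean(inner), matching jax.tree.map(lambda x, y: y - x, inner, outer)
# followed by drjax.reduce_mean in build_diloco_train_step.
pseudo_grad = jax.tree.map(lambda o, i: o - i.mean(axis=0), outer_params, inner_params)
updates, opt_state = outer_opt.update(pseudo_grad, opt_state, outer_params)
new_outer_params = optax.apply_updates(outer_params, updates)
print(new_outer_params)  # moves toward the replica average [0.7, 1.1]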
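Third, a skeleton of how the new helpers would compose in a training loop. The train.py integration is not part of this diff, so `make_train_state`, `local_train_step`, `data_iterator`, and `num_steps` are undefined stand-ins for MaxText's existing setup, and the loop assumes a mesh with a leading 'diloco' axis is already active; only the `diloco.*` and `initialize_pydantic` calls come from the change above.

# Hypothetical glue (NOT in this diff); stand-in names are placeholders, not real identifiers.
import jax
from MaxText import diloco
from MaxText.pyconfig import initialize_pydantic

config = initialize_pydantic(
    ["", "src/MaxText/configs/base.yml", "ici_diloco_parallelism=1", "dcn_diloco_parallelism=2"]
)

state, _ = diloco.build_diloco_state(config, make_train_state)          # replicate the inner TrainState
diloco_step = diloco.build_diloco_train_step(config, local_train_step)  # wrap the per-replica step

for step in range(num_steps):
  batch = next(data_iterator)
  # Split the global batch so each DiLoCo replica sees its own shard along a new leading axis.
  batch = diloco.reshape_first_axis_with_diloco(config.num_diloco_replicas, batch)
  state, metrics = diloco_step(state, batch, jax.random.key(step))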