@@ -4,6 +4,7 @@ array-record
cloud-accelerator-diagnostics
cloud-tpu-diagnostics
datasets
drjax
flax
gcsfs
google-api-python-client

@@ -40,6 +40,7 @@ dill>=0.4.0
distlib>=0.4.0
dm-tree>=0.1.9
docstring-parser>=0.17.0
drjax>=0.1.4
editdistance>=0.8.1
einops>=0.8.1
einshape>=1.0

@@ -40,6 +40,7 @@ dill>=0.4.0
distlib>=0.4.0
dm-tree>=0.1.9
docstring-parser>=0.17.0
drjax>=0.1.4
editdistance>=0.8.1
einops>=0.8.1
einshape>=1.0

1 change: 1 addition & 0 deletions dependencies/requirements/requirements.txt
@@ -4,6 +4,7 @@ array-record
cloud-accelerator-diagnostics
cloud-tpu-diagnostics
datasets
drjax>=0.1.4
flax
gcsfs
google-api-python-client

10 changes: 9 additions & 1 deletion src/MaxText/configs/base.yml
@@ -383,7 +383,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'

# Parallelism
shard_mode: "auto" # can be either auto or explicit
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
logical_axis_rules: [
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
@@ -460,6 +460,7 @@ logical_axis_rules: [
['paged_kv_head_dim_size', []],
['dense_layers', []],
['moe_layers', []],
['diloco', 'diloco'],
]
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
@@ -472,6 +473,7 @@ sharding_tolerance: 0.02
# value to auto-shard based on available slices and devices.
# By default, product of the DCN axes should equal number of slices
# and product of the ICI axes should equal number of devices per slice.
dcn_diloco_parallelism: 1
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
dcn_fsdp_parallelism: 1
dcn_fsdp_transpose_parallelism: 1
Expand All @@ -484,6 +486,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended
dcn_pipeline_parallelism: 1
dcn_expert_parallelism: 1
dcn_autoregressive_parallelism: 1 # never recommended
ici_diloco_parallelism: 1
ici_data_parallelism: 1
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
ici_fsdp_transpose_parallelism: 1
@@ -696,6 +699,11 @@ enable_data_shuffling: True
data_shuffle_seed: 0
init_weights_seed: 0

# DiLoCo params.
diloco_sync_period: 36
diloco_outer_lr: 0.3
diloco_outer_momentum: 0.9

# You may disable clipping by setting gradient_clipping_threshold to zero.
gradient_clipping_threshold: 1.0
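
The sketch below is editorial orientation, not part of this change: it shows how the new `diloco` mesh axis relates to the knobs added above. The replica count is the product of `ici_diloco_parallelism` and `dcn_diloco_parallelism`, the input batch gains a leading dimension sharded over `diloco` (matching the `['diloco', 'diloco']` logical-axis rule), and every `diloco_sync_period` inner steps the replicas are averaged into the outer model. The axis sizes and the single `fsdp` axis used here are illustrative assumptions, not MaxText's actual mesh construction.

# Illustrative sketch only (assumed values); not MaxText's own mesh-building code.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

num_diloco_replicas = 2  # e.g. ici_diloco_parallelism=1 * dcn_diloco_parallelism=2
devices = np.asarray(jax.devices())
assert devices.size % num_diloco_replicas == 0

# Leading 'diloco' axis; the remaining devices are folded into one 'fsdp' axis here.
mesh = Mesh(
    devices.reshape(num_diloco_replicas, devices.size // num_diloco_replicas),
    axis_names=("diloco", "fsdp"),
)
# Batches carrying a leading replica dimension are sharded over ('diloco', 'fsdp').
batch_sharding = NamedSharding(mesh, P("diloco", "fsdp"))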

23 changes: 23 additions & 0 deletions src/MaxText/configs/types.py
@@ -733,6 +733,7 @@ class LayoutAndSharding(BaseModel):
class DcnParallelism(BaseModel):
"""Parallelism dimensions across the DCN (Data Center Network)."""

dcn_diloco_parallelism: int = Field(1, description="DCN axis for DiLoCo parallelism.")
dcn_data_parallelism: int = Field(-1, description="DCN axis for data parallelism.")
dcn_fsdp_parallelism: int = Field(1, description="DCN axis for FSDP.")
dcn_fsdp_transpose_parallelism: int = Field(1, description="DCN axis for FSDP transpose.")
@@ -752,6 +753,7 @@ class DcnParallelism(BaseModel):
class IciParallelism(BaseModel):
"""Parallelism dimensions within the ICI (Inter-Chip Interconnect)."""

ici_diloco_parallelism: int = Field(1, description="ICI axis for DiLoCo parallelism.")
ici_data_parallelism: int = Field(1, description="ICI axis for data parallelism.")
ici_fsdp_parallelism: int = Field(-1, description="ICI axis for FSDP.")
ici_fsdp_transpose_parallelism: int = Field(1, description="ICI axis for FSDP transpose.")
@@ -1000,6 +1002,14 @@ class TrainingLoop(BaseModel):
init_weights_seed: int = Field(0, description="Seed for model weight initialization.")


class DilocoParams(BaseModel):
"""Diloco Hyperparameters"""

diloco_sync_period: int = Field(36, description="Number of inner optimizer steps between DiLoCo synchronizations.")
diloco_outer_lr: float = Field(0.3, description="Learning rate for the outer optimizer.")
diloco_outer_momentum: float = Field(0.9, description="Momentum for the outer optimizer.")


class Optimizer(BaseModel):
"""Configuration for the optimizer and learning rate schedule."""

@@ -1486,6 +1496,11 @@ class DerivedValues(BaseModel):
description="Effective number of query heads, scaled by `global_parameter_scale`.",
)

num_diloco_replicas: None | int = Field(
None,
description="The number of diloco replicas, derived from ICI and DCN values.",
)

ici_parallelism: None | list[int] = Field(
None,
description="Aggregated list of all ICI parallelism values for legacy compatibility.",
@@ -1631,6 +1646,7 @@ class MaxTextConfig(
# Training, Optimization, and Fine-Tuning
RematAndOffload,
TrainingLoop,
DilocoParams,
Optimizer,
AdamW,
Muon,
@@ -2152,6 +2168,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
self.ici_parallelism = [
self.ici_diloco_parallelism,
self.ici_pipeline_parallelism,
self.ici_data_parallelism,
self.ici_fsdp_parallelism,
@@ -2166,6 +2183,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.ici_autoregressive_parallelism,
]
self.dcn_parallelism = [
self.dcn_diloco_parallelism,
self.dcn_pipeline_parallelism,
self.dcn_data_parallelism,
self.dcn_fsdp_parallelism,
@@ -2181,6 +2199,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
]
else:
ici_map = {
"diloco": self.ici_diloco_parallelism,
"data": self.ici_data_parallelism,
"stage": self.ici_pipeline_parallelism,
"fsdp": self.ici_fsdp_parallelism,
@@ -2198,6 +2217,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]

dcn_map = {
"diloco": self.dcn_diloco_parallelism,
"data": self.dcn_data_parallelism,
"stage": self.dcn_pipeline_parallelism,
"fsdp": self.dcn_fsdp_parallelism,
@@ -2214,6 +2234,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
}
self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]

# Diloco params
self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)
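# For example, ici_diloco_parallelism=1 with dcn_diloco_parallelism=2 (one replica
# per slice across two slices) yields num_diloco_replicas = 2.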

# Final string-to-enum conversions if they haven't been coerced by pydantic yet.
if isinstance(self.decoder_block, str):
self.decoder_block = DecoderBlockType(self.decoder_block.lower())

201 changes: 201 additions & 0 deletions src/MaxText/diloco.py
@@ -0,0 +1,201 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""An implementation of Distributed Low-Communication (DiLoCo) training.

This module contains implementations of:

- DiLoCo: Distributed Low-Communication Training of Language Models
https://arxiv.org/abs/2311.08105
- Streaming DiLoCo with overlapping communication: Towards a Distributed Free Lunch
https://arxiv.org/abs/2501.18512
"""

from collections.abc import Sequence
from typing import Any, Callable

import drjax
from flax import struct
from flax.training import train_state
import jax
import jax.numpy as jnp
from jaxtyping import Array, Int32, Key, PyTree, UInt32
import optax

from MaxText import pyconfig

Batch = Any
Params = PyTree
Metrics = PyTree
OptState = optax.OptState
InnerOptStates = optax.OptState
PRNGKey = Key[Array, ""] | UInt32[Array, "2"]
Step = Int32[Array, ""]


class DiLoCoTrainState(struct.PyTreeNode):
"""The state of the DiLoCo training process.

Attributes:
inner_state: A `flax.training.train_state.TrainState` holding the per-replica
state for each step of the inner optimization. All arrays are expected to have
a leading dimension whose size is the number of DiLoCo replicas, so that
training steps can be mapped over this dimension.
outer_params: A PyTree of the global model weights. These mirror the `params`
sub-PyTree of `inner_state`, but without the leading replica dimension.
outer_opt_state: The state for the outer Nesterov momentum optimizer.
step: The step counter of the training process.
"""

inner_state: train_state.TrainState
outer_params: Params
outer_opt_state: OptState
step: Step


def reshape_first_axis_with_diloco(num_diloco_replicas: int, pytree: PyTree) -> PyTree:
"""Reshapes the first dimension of each array in the PyTree to include a DiLoCo axis.

This function takes a batch of data represented as a PyTree
and reshapes the leading dimension of each array within it. The purpose is
to introduce a new 'diloco' axis, which is used for distributing data
across DiLoCo replicas.

Args:
num_diloco_replicas: The number of DiLoCo replicas. This determines the
size of the new leading dimension.
pytree: The input PyTree, where each array is expected to have a batch
dimension as its first axis.

Returns:
A new PyTree with the same structure as the input, but with each array's
first dimension reshaped to `(num_diloco_replicas, original_batch_dim // num_diloco_replicas, ...)`.
The sharding specification is also updated to include the 'diloco' axis.
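
Example:
  With `num_diloco_replicas=2`, an array of shape `(8, 2048)` sharded with
  `PartitionSpec(('data', 'fsdp'))` is reshaped to `(2, 4, 2048)` and resharded
  with `PartitionSpec('diloco', ('data', 'fsdp'))`.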
"""

def extend_pspec(pspec: jax.sharding.PartitionSpec | Sequence[str | Sequence[str]] = ()) -> jax.sharding.PartitionSpec:
  entries = tuple(pspec)
  first = entries[0] if entries else None
  if first == "diloco" or (isinstance(first, (tuple, list)) and first and first[0] == "diloco"):
    # The 'diloco' axis is already leading; pull it out into its own dimension
    # and keep the remaining axes of that entry grouped together.
    rest = tuple(first[1:]) if isinstance(first, (tuple, list)) else None
    return jax.sharding.PartitionSpec("diloco", rest, *entries[1:])
  # Otherwise simply prepend the 'diloco' axis for the new leading dimension.
  return jax.sharding.PartitionSpec("diloco", *entries)

def reshape_for_diloco(arr):
batch_dim, *example_shape = arr.shape
diloco_shape = (num_diloco_replicas, batch_dim // num_diloco_replicas, *example_shape)
s = arr.sharding
s = jax.sharding.NamedSharding(mesh=s.mesh, spec=extend_pspec(s.spec))
return jax.lax.with_sharding_constraint(jnp.reshape(arr, shape=diloco_shape), s)

return jax.tree.map(reshape_for_diloco, pytree)


def build_diloco_state(
config: "pyconfig.HyperParameters",
initialize_state: Callable[[], train_state.TrainState],
) -> tuple[DiLoCoTrainState, PyTree]:
"""Given a non-DiLoCo train state, construct a DiLoCo training state."""
outer_optimizer = optax.sgd(
config.diloco_outer_lr,
momentum=config.diloco_outer_momentum,
nesterov=True,
)

@drjax.program(placements={"diloco": config.num_diloco_replicas})
def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]:
state = initialize_state()
# Inner state must be broadcast across clients.
inner_state = drjax.broadcast(state)
# Outer state retains a single copy of the model parameters and optimizer state.
outer_params = state.params
outer_opt_state = outer_optimizer.init(outer_params)
outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state)
return (
DiLoCoTrainState(
inner_state=inner_state, outer_params=outer_params, outer_opt_state=outer_opt_state, step=state.step
),
outer_opt_state_sharding,
)

return init_diloco_state()


def build_diloco_train_step(
config: pyconfig.HyperParameters,
train_step: Callable[[train_state.TrainState, Batch, PRNGKey], tuple[train_state.TrainState, Metrics]],
) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]:
"""Convert a local state and train step into DiLoCo-compatible versions.

This is an implementation of the original (non-streaming) DiLoCo algorithm
which syncs all model parameters across the replicas every
`config.diloco_sync_period` steps, treating the difference accumulated over
non-sync steps as a pseudo gradient and applying SGD with Nesterov momentum on
the "global" model.

Args:
config: The config used to set up training.
train_step: A local train step. This will be executed independently within
each replica.
"""
outer_optimizer = optax.sgd(
config.diloco_outer_lr,
momentum=config.diloco_outer_momentum,
nesterov=True,
)

def synchronize(state):
# Calculate the delta between the current replica's state and the global
# state (since last synchronization).
broadcast_outer_params = drjax.broadcast(state.outer_params)
model_delta = jax.tree.map(lambda x, y: y - x, state.inner_state.params, broadcast_outer_params)
# Treat the average delta as the outer optimizer's gradient and apply to
# the global (outer) model params.
averaged_pseudo_grad = drjax.reduce_mean(model_delta)
updates, new_opt_state = outer_optimizer.update(averaged_pseudo_grad, state.outer_opt_state, state.outer_params)
new_outer_params = optax.apply_updates(state.outer_params, updates)
# Replace inner model params with the new global model params.
# NOTE: inner optimizer state is retained despite the change in parameters,
# see section 6.1 in https://arxiv.org/pdf/2311.08105.
new_inner_state = drjax.map_fn(lambda state: state.replace(params=new_outer_params), state.inner_state)
return state.replace(
outer_params=new_outer_params,
outer_opt_state=new_opt_state,
inner_state=new_inner_state,
)

def typed_reduce_mean(in_tree):
total = drjax.reduce_sum(in_tree)
avg = jax.tree.map(lambda x: (x / config.num_diloco_replicas).astype(x.dtype), total)
return avg

@drjax.program(placements={"diloco": config.num_diloco_replicas})
def diloco_train_step(state, batch, prng):
# Broadcast the RNG across replicas.
broadcast_rng = drjax.broadcast(prng)
inner_state, metrics = drjax.map_fn(train_step, (state.inner_state, batch, broadcast_rng))
avg_metrics = typed_reduce_mean(metrics)
state = state.replace(
inner_state=inner_state,
step=inner_state.step[0],
)
# Either synchronize the model, or no-op, depending on whether the current
# step falls on the synchronization period.
state = jax.lax.cond(
inner_state.step[0] % config.diloco_sync_period == 0,
synchronize,
lambda x: x, # no-op
state,
)
return state, avg_metrics

return diloco_train_step
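
For orientation, here is a minimal, hypothetical sketch of how the two entry points above might be wired into a training loop. Only `build_diloco_state`, `build_diloco_train_step`, `reshape_first_axis_with_diloco`, and the config fields introduced in this change come from the diff; `config`, `initialize_state`, `local_train_step`, `data_iterator`, and the jitted call are assumed placeholders, and in practice everything would run under the MaxText device mesh.

# Hypothetical usage sketch; names other than the MaxText.diloco helpers are assumed.
import jax
from MaxText import diloco

state, _ = diloco.build_diloco_state(config, initialize_state)
diloco_step = jax.jit(diloco.build_diloco_train_step(config, local_train_step))
rng = jax.random.PRNGKey(config.init_weights_seed)

for batch in data_iterator:
  # Give each batch a leading (num_diloco_replicas, per_replica_batch, ...) dimension.
  batch = diloco.reshape_first_axis_with_diloco(config.num_diloco_replicas, batch)
  rng, step_rng = jax.random.split(rng)
  # Inner steps run independently per replica; every config.diloco_sync_period steps
  # the averaged delta is applied to the outer params with Nesterov-momentum SGD.
  state, metrics = diloco_step(state, batch, step_rng)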