Skip to content
Open

X #849

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions .github/workflows/python-package-conda.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
name: Python Package using Conda

on: [push]

jobs:
build-linux:
runs-on: ubuntu-latest
strategy:
max-parallel: 5

steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Add conda to system path
run: |
# $CONDA is an environment variable pointing to the root of the miniconda directory
echo $CONDA/bin >> $GITHUB_PATH
- name: Install dependencies
run: |
conda env update --file environment.yml --name base
- name: Lint with flake8
run: |
conda install flake8
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
conda install pytest
pytest
Empty file added __init__.py
Empty file.
Binary file added causal-conv1d-main.zip
Binary file not shown.
18 changes: 18 additions & 0 deletions config_mamba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from dataclasses import dataclass, field


@dataclass
class MambaConfig:

d_model: int = 2560
d_intermediate: int = 0
n_layer: int = 64
vocab_size: int = 50277
ssm_cfg: dict = field(default_factory=dict)
attn_layer_idx: list = field(default_factory=list)
attn_cfg: dict = field(default_factory=dict)
rms_norm: bool = True
residual_in_fp32: bool = True
fused_add_norm: bool = True
pad_vocab_size_multiple: int = 8
tie_embeddings: bool = True
96 changes: 96 additions & 0 deletions determinism.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) 2024, Tri Dao, Albert Gu.

import os
import warnings
from packaging import version

import torch

try:
import triton
TRITON_VERSION = version.parse(triton.__version__)
except ImportError:
TRITON_VERSION = version.parse("0.0.0")

TRITON_HAS_CACHE_RESULTS = TRITON_VERSION >= version.parse("3.4.0")
_autotune_warning_issued = False

_deterministic_override = None


def use_deterministic_mode():
if _deterministic_override is not None:
return _deterministic_override
env = os.environ.get('MAMBA_DETERMINISTIC')
if env:
return env[0] == '1'
return torch.are_deterministic_algorithms_enabled()


def set_deterministic_mode(value):
global _deterministic_override
_deterministic_override = value


def _estimate_config_cost(cfg):
"""Estimate shared memory cost of a config. Lower is cheaper."""
block_product = 1
for key, val in cfg.kwargs.items():
if key.startswith('BLOCK_SIZE_'):
block_product *= val
return block_product * (getattr(cfg, 'num_stages', 1) or 1)


def _filter_configs_by_block_sizes(configs):
"""Filter configs by TRITON_AUTOTUNE_BLOCK_SIZE_* env vars."""
env_filters = {}
for suffix in ('M', 'N', 'K', 'DSTATE'):
env_val = os.environ.get(f"TRITON_AUTOTUNE_BLOCK_SIZE_{suffix}")
if env_val is not None:
env_filters[f'BLOCK_SIZE_{suffix}'] = int(env_val)
if not env_filters:
return None
matching = configs
for key, target in env_filters.items():
matching = [c for c in matching if c.kwargs.get(key) == target]
return matching[:1] if matching else None


def autotune_configs(configs):
"""Select autotune configs for deterministic mode.

Uses cached autotuning (TRITON_CACHE_AUTOTUNING=1) if Triton >= 3.4.0,
otherwise auto-selects the cheapest config by block size * stages.
"""
if not configs or not use_deterministic_mode():
return configs
if TRITON_HAS_CACHE_RESULTS and os.environ.get("TRITON_CACHE_AUTOTUNING") == "1":
return configs
global _autotune_warning_issued
if not _autotune_warning_issued:
_autotune_warning_issued = True
msg = "Deterministic mode: set TRITON_CACHE_AUTOTUNING=1 for cached autotuning." if TRITON_HAS_CACHE_RESULTS else "Deterministic mode: upgrade to Triton >= 3.4.0 for cached autotuning."
warnings.warn(msg)
filtered = _filter_configs_by_block_sizes(configs)
if filtered:
return filtered
return [min(configs, key=_estimate_config_cost)]


def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True):
"""Allocate buffer for deterministic per-program reductions."""
if base_shape is None:
return None, 0
if deterministic:
factory = torch.zeros if zero_init else torch.empty
tensor = factory(*base_shape, tile_dim, device=device, dtype=dtype)
return tensor, tensor.stride(-1)
return torch.empty(*base_shape, device=device, dtype=dtype), 0


def finalize_tile_workspace(tensor, deterministic):
if tensor is None:
return None
if deterministic:
tensor = tensor.sum(dim=-1)
return tensor
144 changes: 144 additions & 0 deletions distributed_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from typing import Optional

import torch
from torch import Tensor
from torch.distributed import ProcessGroup

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
# version of PyTorch. The following 4 lines are for backward compatibility with
# older PyTorch.
if "all_gather_into_tensor" not in dir(torch.distributed):
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
if "reduce_scatter_tensor" not in dir(torch.distributed):
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base


# Raw operation, does not support autograd, but does support async
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
output = torch.empty(
world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
)
handle = torch.distributed.all_gather_into_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle


# Raw operation, does not support autograd, but does support async
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
world_size = torch.distributed.get_world_size(process_group)
assert input_.shape[0] % world_size == 0
output = torch.empty(
input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
)
handle = torch.distributed.reduce_scatter_tensor(
output, input_.contiguous(), group=process_group, async_op=async_op
)
return output, handle


# Raw operation, does not support autograd, but does support async
def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
input_ = input_.contiguous()
handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
return input_, handle


class AllGatherFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""

@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = all_gather_raw(input_, process_group)
return output

@staticmethod
def backward(ctx, grad_output: Tensor):
grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
return grad_input, None


# Supports autograd, but does not support async
all_gather = AllGatherFunc.apply


class ReduceScatterFunc(torch.autograd.Function):
"""Reduce scatter the input from the sequence parallel region and concatenate."""

@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = reduce_scatter_raw(input_, process_group)
return output

@staticmethod
def backward(ctx, grad_output: Tensor):
grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
return grad_input, None


# Supports autograd, but does not support async
reduce_scatter = ReduceScatterFunc.apply


class AllReduceFunc(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatenate."""

@staticmethod
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
ctx.process_group = process_group
output, _ = all_reduce_raw(input_, process_group)
return output

@staticmethod
def backward(ctx, grad_output: Tensor):
return grad_output, None


# Supports autograd, but does not support async
all_reduce = AllReduceFunc.apply


def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _shared_params=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
pamams_shared = {
name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
}
for _, p in sorted(pamams_shared.items()):
with torch.no_grad():
# Broadcast needs src to be global rank, not group rank
torch.distributed.broadcast(
p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
)


# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
# We want to iterate over parameters with _sequence_parallel=True in the same order,
# as different ranks might have different number of parameters (e.g., only rank 0 has bias).
params_seqparallel = {
name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
}
grads = [p.grad for _, p in sorted(params_seqparallel.items())]
if grads:
with torch.no_grad():
coalesced = torch._utils._flatten_dense_tensors(grads)
torch.distributed.all_reduce(coalesced, group=process_group)
for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
buf.copy_(synced)


def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
"""Get the dim for the local rank derived from splitting dim on world_size processes.

The split may not be even across the world_size processes.
"""
multiple = dim // multiple_of
div = multiple // world_size
mod = multiple % world_size
local_multiple = div + int(local_rank < mod)
return local_multiple * multiple_of
Binary file added einops-main.zip
Binary file not shown.
Loading