diff --git a/.github/workflows/python-package-conda.yml b/.github/workflows/python-package-conda.yml
new file mode 100644
index 000000000..f3586044a
--- /dev/null
+++ b/.github/workflows/python-package-conda.yml
@@ -0,0 +1,34 @@
+name: Python Package using Conda
+
+on: [push]
+
+jobs:
+  build-linux:
+    runs-on: ubuntu-latest
+    strategy:
+      max-parallel: 5
+
+    steps:
+    - uses: actions/checkout@v4
+    - name: Set up Python 3.10
+      uses: actions/setup-python@v3
+      with:
+        python-version: '3.10'
+    - name: Add conda to system path
+      run: |
+        # $CONDA is an environment variable pointing to the root of the miniconda directory
+        echo $CONDA/bin >> $GITHUB_PATH
+    - name: Install dependencies
+      run: |
+        conda env update --file environment.yml --name base
+    - name: Lint with flake8
+      run: |
+        conda install flake8
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Test with pytest
+      run: |
+        conda install pytest
+        pytest
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/causal-conv1d-main.zip b/causal-conv1d-main.zip
new file mode 100644
index 000000000..808e377d2
Binary files /dev/null and b/causal-conv1d-main.zip differ
diff --git a/config_mamba.py b/config_mamba.py
new file mode 100644
index 000000000..646c9e1e8
--- /dev/null
+++ b/config_mamba.py
@@ -0,0 +1,18 @@
+from dataclasses import dataclass, field
+
+
+@dataclass
+class MambaConfig:
+    """Hyper-parameters of a Mamba language model; every field has a default."""
+    d_model: int = 2560
+    d_intermediate: int = 0
+    n_layer: int = 64
+    vocab_size: int = 50277
+    ssm_cfg: dict = field(default_factory=dict)
+    attn_layer_idx: list = field(default_factory=list)
+    attn_cfg: dict = field(default_factory=dict)
+    rms_norm: bool = True
+    residual_in_fp32: bool = True
+    fused_add_norm: bool = True
+    pad_vocab_size_multiple: int = 8
+    tie_embeddings: bool = True
diff --git a/determinism.py b/determinism.py
new file mode 100644
index 000000000..c6066f80d
--- /dev/null
+++ b/determinism.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2024, Tri Dao, Albert Gu.
+
+import os
+import warnings
+from packaging import version
+
+import torch
+
+try:
+    import triton
+    TRITON_VERSION = version.parse(triton.__version__)
+except ImportError:
+    TRITON_VERSION = version.parse("0.0.0")
+
+# Cached autotune results are only relied on for Triton >= 3.4.0.
+TRITON_HAS_CACHE_RESULTS = TRITON_VERSION >= version.parse("3.4.0")
+# Makes autotune_configs() warn at most once per process.
+_autotune_warning_issued = False
+
+# Explicit override set through set_deterministic_mode(); None means "not set".
+_deterministic_override = None
+
+
+def use_deterministic_mode():
+    """Return whether deterministic kernels should be used.
+
+    Precedence: the override from ``set_deterministic_mode``, then the
+    ``MAMBA_DETERMINISTIC`` environment variable (enabled when its first
+    character is ``1``), then ``torch.are_deterministic_algorithms_enabled()``.
+    """
+    if _deterministic_override is not None:
+        return _deterministic_override
+    env = os.environ.get('MAMBA_DETERMINISTIC')
+    if env:
+        return env[0] == '1'
+    return torch.are_deterministic_algorithms_enabled()
+
+
+def set_deterministic_mode(value):
+    """Force deterministic mode on or off; pass ``None`` to clear the override."""
+    global _deterministic_override
+    _deterministic_override = value
+
+
+def _estimate_config_cost(cfg):
+    """Estimate shared memory cost of a config. Lower is cheaper.
+
+    The cost is the product of all ``BLOCK_SIZE_*`` kwargs times ``num_stages``
+    (a missing or falsy ``num_stages`` counts as 1).
+    """
+    block_product = 1
+    for key, val in cfg.kwargs.items():
+        if key.startswith('BLOCK_SIZE_'):
+            block_product *= val
+    return block_product * (getattr(cfg, 'num_stages', 1) or 1)
+
+
+def _filter_configs_by_block_sizes(configs):
+    """Filter configs by TRITON_AUTOTUNE_BLOCK_SIZE_* env vars.
+
+    Returns a one-element list holding the first config whose kwargs match
+    every block size requested through the environment, or ``None`` when no
+    such variable is set or no config matches.
+
+    Raises:
+        ValueError: if one of the environment variables is not an integer.
+    """
+    env_filters = {}
+    for suffix in ('M', 'N', 'K', 'DSTATE'):
+        env_name = f"TRITON_AUTOTUNE_BLOCK_SIZE_{suffix}"
+        env_val = os.environ.get(env_name)
+        if env_val is None:
+            continue
+        try:
+            env_filters[f'BLOCK_SIZE_{suffix}'] = int(env_val)
+        except ValueError as err:
+            # Name the offending variable instead of surfacing a bare int() failure.
+            raise ValueError(f"{env_name} must be an integer, got {env_val!r}") from err
+    if not env_filters:
+        return None
+    matching = [
+        c for c in configs
+        if all(c.kwargs.get(key) == target for key, target in env_filters.items())
+    ]
+    return matching[:1] if matching else None
+
+
+def autotune_configs(configs):
+    """Select autotune configs for deterministic mode.
+
+    Uses cached autotuning (TRITON_CACHE_AUTOTUNING=1) if Triton >= 3.4.0,
+    otherwise auto-selects the cheapest config by block size * stages.
+    Outside deterministic mode (or for an empty list) ``configs`` is returned
+    unchanged. A one-time warning explains how to get cached autotuning.
+    """
+    if not configs or not use_deterministic_mode():
+        return configs
+    if TRITON_HAS_CACHE_RESULTS and os.environ.get("TRITON_CACHE_AUTOTUNING") == "1":
+        return configs
+    global _autotune_warning_issued
+    if not _autotune_warning_issued:
+        _autotune_warning_issued = True
+        msg = "Deterministic mode: set TRITON_CACHE_AUTOTUNING=1 for cached autotuning." if TRITON_HAS_CACHE_RESULTS else "Deterministic mode: upgrade to Triton >= 3.4.0 for cached autotuning."
+        # stacklevel=2 attributes the warning to the caller rather than this helper.
+        warnings.warn(msg, stacklevel=2)
+    filtered = _filter_configs_by_block_sizes(configs)
+    if filtered:
+        return filtered
+    return [min(configs, key=_estimate_config_cost)]
+
+
+def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True):
+    """Allocate buffer for deterministic per-program reductions.
+
+    Returns ``(tensor, stride)``. In deterministic mode the tensor gets an
+    extra trailing dimension of size ``tile_dim`` (zero-filled unless
+    ``zero_init`` is False) and ``stride`` is that dimension's stride.
+    Otherwise an uninitialised tensor of ``base_shape`` and stride 0 are
+    returned. ``base_shape=None`` yields ``(None, 0)``.
+    """
+    if base_shape is None:
+        return None, 0
+    if deterministic:
+        factory = torch.zeros if zero_init else torch.empty
+        tensor = factory(*base_shape, tile_dim, device=device, dtype=dtype)
+        return tensor, tensor.stride(-1)
+    return torch.empty(*base_shape, device=device, dtype=dtype), 0
+
+
+def finalize_tile_workspace(tensor, deterministic):
+    """Sum over the last dimension in deterministic mode; pass ``None`` through."""
+    if tensor is None:
+        return None
+    if deterministic:
+        tensor = tensor.sum(dim=-1)
+    return tensor
diff --git a/distributed_utils.py b/distributed_utils.py
new file mode 100644
index 000000000..74c552796
--- /dev/null
+++ b/distributed_utils.py
@@ -0,0 +1,171 @@
+from typing import Optional
+
+import torch
+from torch import Tensor
+from torch.distributed import ProcessGroup
+
+# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
+# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
+# version of PyTorch. The following 4 lines are for backward compatibility with
+# older PyTorch.
+if not hasattr(torch.distributed, "all_gather_into_tensor"):
+    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
+if not hasattr(torch.distributed, "reduce_scatter_tensor"):
+    torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
+
+
+# Raw operation, does not support autograd, but does support async
+def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    """Gather `input_` from all ranks of `process_group` along dim 0.
+
+    Returns `(output, handle)`; `output` has first dimension
+    `world_size * input_.shape[0]` and `handle` is the value returned by
+    `all_gather_into_tensor` for the given `async_op`.
+    """
+    world_size = torch.distributed.get_world_size(process_group)
+    output = torch.empty(
+        world_size * input_.shape[0], *input_.shape[1:], dtype=input_.dtype, device=input_.device
+    )
+    handle = torch.distributed.all_gather_into_tensor(
+        output, input_.contiguous(), group=process_group, async_op=async_op
+    )
+    return output, handle
+
+
+# Raw operation, does not support autograd, but does support async
+def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    """Reduce-scatter `input_` along dim 0 across `process_group`.
+
+    `input_.shape[0]` must be divisible by the world size. Returns
+    `(output, handle)` where `output` has first dimension
+    `input_.shape[0] // world_size`.
+    """
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    output = torch.empty(
+        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
+    )
+    handle = torch.distributed.reduce_scatter_tensor(
+        output, input_.contiguous(), group=process_group, async_op=async_op
+    )
+    return output, handle
+
+
+# Raw operation, does not support autograd, but does support async
+def all_reduce_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
+    """All-reduce `input_` across `process_group`; returns `(tensor, handle)`.
+
+    The reduction runs in place on `input_.contiguous()`, so an already
+    contiguous `input_` is itself overwritten with the reduced values.
+    """
+    input_ = input_.contiguous()
+    handle = torch.distributed.all_reduce(input_, group=process_group, async_op=async_op)
+    return input_, handle
+
+
+class AllGatherFunc(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = all_gather_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        # Backward of an all-gather is a reduce-scatter of the incoming gradient.
+        grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
+        return grad_input, None
+
+
+# Supports autograd, but does not support async
+all_gather = AllGatherFunc.apply
+
+
+class ReduceScatterFunc(torch.autograd.Function):
+    """Reduce scatter the input from the sequence parallel region and concatenate."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = reduce_scatter_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        # Backward of a reduce-scatter is an all-gather of the incoming gradient.
+        grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
+        return grad_input, None
+
+
+# Supports autograd, but does not support async
+reduce_scatter = ReduceScatterFunc.apply
+
+
+class AllReduceFunc(torch.autograd.Function):
+    """All-reduce the input across the process group."""
+
+    @staticmethod
+    def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
+        ctx.process_group = process_group
+        output, _ = all_reduce_raw(input_, process_group)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor):
+        # The incoming gradient is passed through unchanged (no collective here).
+        return grad_output, None
+
+
+# Supports autograd, but does not support async
+all_reduce = AllReduceFunc.apply
+
+
+def sync_shared_params(model: torch.nn.Module, process_group: ProcessGroup):
+    """Broadcast every parameter flagged `_shared_params` from group rank 0."""
+    # We want to iterate over parameters with _shared_params=True in the same order,
+    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+    params_shared = {
+        name: p for name, p in model.named_parameters() if getattr(p, "_shared_params", False)
+    }
+    for _, p in sorted(params_shared.items()):
+        with torch.no_grad():
+            # Broadcast needs src to be global rank, not group rank
+            torch.distributed.broadcast(
+                p, src=torch.distributed.get_global_rank(process_group, 0), group=process_group
+            )
+
+
+# Ref: https://github.com/NVIDIA/Megatron-LM/blob/52e636888cccc41e931251c417a7181fc36de926/megatron/optimizer/optimizer.py#L256
+def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
+    """All-reduce the gradients of parameters flagged `_sequence_parallel`.
+
+    The gradients are flattened into one buffer, reduced with a single
+    collective, and copied back in place.
+    """
+    # We want to iterate over parameters with _sequence_parallel=True in the same order,
+    # as different ranks might have different number of parameters (e.g., only rank 0 has bias).
+    params_seqparallel = {
+        name: p for name, p in model.named_parameters() if getattr(p, "_sequence_parallel", False)
+    }
+    grads = [p.grad for _, p in sorted(params_seqparallel.items())]
+    if grads:
+        with torch.no_grad():
+            coalesced = torch._utils._flatten_dense_tensors(grads)
+            torch.distributed.all_reduce(coalesced, group=process_group)
+            for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
+                buf.copy_(synced)
+
+
+def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
+    """Get the dim for the local rank derived from splitting dim on world_size processes.
+
+    The split may not be even across the world_size processes: the first
+    `(dim // multiple_of) % world_size` ranks get one extra chunk of `multiple_of`.
+    """
+    multiple = dim // multiple_of
+    div = multiple // world_size
+    mod = multiple % world_size
+    local_multiple = div + int(local_rank < mod)
+    return local_multiple * multiple_of
diff --git a/einops-main.zip b/einops-main.zip
new file mode 100644
index 000000000..1f1cd88c2
Binary files /dev/null and b/einops-main.zip differ
diff --git a/generation.py b/generation.py
new file mode 100644
index 000000000..e4a7a78bf
--- /dev/null
+++ b/generation.py
@@ -0,0 +1,389 @@
+# Copyright (c) 2023, Albert Gu, Tri Dao.
+import gc
+import time
+from collections import namedtuple
+from dataclasses import dataclass, field
+from functools import partial
+from typing import Callable, Optional, Sequence, Union
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import Tensor
+from torch.profiler import ProfilerActivity, profile, record_function
+from transformers.generation import GenerateDecoderOnlyOutput, TextStreamer
+
+
+@dataclass
+class InferenceParams:
+    """Inference parameters that are passed to the main model in order
+    to efficiently calculate and store the context during inference."""
+
+    max_seqlen: int
+    max_batch_size: int
+    seqlen_offset: int = 0
+    batch_size_offset: int = 0
+    key_value_memory_dict: dict = field(default_factory=dict)
+    lengths_per_sample: Optional[Tensor] = None
+
+    def reset(self, max_seqlen, max_batch_size):
+        """Prepare the object for a new generation run.
+
+        Updates the size limits, rewinds ``seqlen_offset`` to 0 and zeroes
+        ``lengths_per_sample`` in place when it is set. ``batch_size_offset``
+        and ``key_value_memory_dict`` are left untouched.
+        """
+        self.max_seqlen = max_seqlen
+        self.max_batch_size = max_batch_size
+        self.seqlen_offset = 0
+        if self.lengths_per_sample is not None:
+            self.lengths_per_sample.zero_()
+
+
+def modify_logits_for_min_p_filtering(logits, min_p):
+    """Set the logits for none min_p values to -inf. Done in-place.
+
+    A token is removed when its probability is below ``min_p`` times the
+    probability of the most likely token of the same row. Nothing is done when
+    ``min_p`` is outside the open interval (0, 1).
+    """
+    if min_p <= 0.0 or min_p >= 1.0:
+        return
+    # Compare probabilities, not raw logits: the threshold is relative to the
+    # per-row top probability, which also keeps this valid for batch size > 1.
+    probs = torch.softmax(logits, dim=-1)
+    top_probs = probs.max(dim=-1, keepdim=True).values
+    indices_to_remove = probs < min_p * top_probs
+    logits.masked_fill_(indices_to_remove, float("-Inf"))
+
+
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L231
+def modify_logits_for_top_k_filtering(logits, top_k):
+    """Set the logits for none top-k values to -inf. Done in-place."""
+    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+    logits.masked_fill_(indices_to_remove, float("-Inf"))
+
+
+# https://github.com/NVIDIA/Megatron-LM/blob/0bb597b42c53355a567aba2a1357cc34b9d99ddd/megatron/text_generation/sampling.py
+# https://github.com/huggingface/transformers/blob/a44985b41cfa2de48a5e1de7f1f93b7483da25d1/src/transformers/generation/logits_process.py#L170
+def modify_logits_for_top_p_filtering(logits, top_p):
+    """Set the logits for none top-p values to -inf. Done in-place.
+
+    No-op when ``top_p`` is outside the open interval (0, 1). ``logits`` is
+    scattered along dim 1, i.e. it is expected to be (batch_size, vocab_size).
+    """
+    if top_p <= 0.0 or top_p >= 1.0:
+        return
+    # First sort and calculate cumulative sum of probabilities.
+    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
+    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+    # Remove tokens with cumulative top_p above the threshold (token with 0 are kept)
+    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
+    # scatter sorted tensors to original indexing
+    indices_to_remove = sorted_indices_to_remove.scatter(
+        1, sorted_indices, sorted_indices_to_remove
+    )
+    logits.masked_fill_(indices_to_remove, float("-inf"))
+
+
+def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
+    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
+    logits: (batch_size, vocab_size)
+    prev_output_tokens: (batch_size, seq_len)
+
+    ``logits`` is modified in place (and returned) unless the penalty is 1.0.
+    """
+    if repetition_penalty == 1.0:
+        return logits
+    score = torch.gather(logits, 1, prev_output_tokens)
+    # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
+    logits.scatter_(1, prev_output_tokens, score)
+    return logits
+
+
+def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
+    """Sample from top-k logits.
+    Arguments:
+        logits: Tensor of shape (batch_size, vocab_size)
+    Returns a tensor of token indices of shape (batch_size,).
+
+    ``top_k == 1`` is greedy decoding. With ``top_k > 0`` the top-k logits are
+    kept, scaled by the temperature and filtered by top-p. With ``top_k == 0``,
+    min-p filtering is used when ``min_p > 0`` (top-p is then ignored),
+    otherwise top-p filtering over the full vocabulary.
+    """
+    if top_k == 1:  # Short-circuit for greedy decoding
+        return logits.argmax(dim=-1)
+    if top_p > 0.0:
+        assert top_p <= 1.0, "top-p should be in (0, 1]."
+    if top_k > 0:
+        top_k = min(top_k, logits.size(-1))  # Safety check
+        logits_top, indices = torch.topk(logits, top_k, dim=-1)
+        if temperature != 1.0:
+            logits_top /= temperature
+        modify_logits_for_top_p_filtering(logits_top, top_p)
+        # Map the sampled position within the top-k slice back to a vocabulary index.
+        return indices[
+            torch.arange(indices.shape[0], device=indices.device),
+            torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
+        ]
+    if min_p > 0.0:
+        # Clone so that the caller's logits are not modified by the in-place filter.
+        logits_top = logits.clone()
+        modify_logits_for_min_p_filtering(logits_top, min_p)
+        # Temperature is applied after filtering, as in the previous implementation.
+        if temperature != 1.0:
+            logits_top /= temperature
+        return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
+    # Clone so that when we modify for top_p we don't change the original logits
+    logits_top = logits / temperature if temperature != 1.0 else logits.clone()
+    modify_logits_for_top_p_filtering(logits_top, top_p)
+    return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
+        dim=-1
+    )
+
+
+@torch.inference_mode()
+def decode(
+    input_ids,
+    model,
+    max_length,
+    top_k=1,
+    top_p=0.0,
+    min_p=0.0,
+    temperature=1.0,
+    repetition_penalty=1.0,
+    eos_token_id=None,
+    teacher_outputs=None,
+    vocab_size=None,
+    cg=False,
+    enable_timing=False,
+    output_scores=False,
+    streamer: Optional[TextStreamer] = None
+):
+    """Decoding, either greedy or with top-k or top-p sampling.
+    If top-k = 0, don't limit the number of candidates (pure sampling).
+    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied first,
+    then top-p.
+    We assume that all sequences in the same batch have the same length.
+
+    Arguments:
+        input_ids: (batch, seq_len)
+        max_length: int
+        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
+            logits, the next token is taken from the teacher_outputs. Useful for testing.
+        vocab_size (optional): if given, logits are truncated to the first vocab_size entries.
+        cg: if True, decode single tokens through the CUDA-graph cache stored on the model.
+        enable_timing: if True, time the whole call with CUDA events and print the result.
+        streamer (optional): receives the prompt and every sampled token (moved to CPU).
+    Returns: GenerateDecoderOnlyOutput, with the following fields:
+        sequences: (batch, max_length)
+        scores: tuples of (batch, vocab_size)
+    """
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
+
+    batch_size, seqlen_og = input_ids.shape
+    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
+    if cg:
+        if not hasattr(model, "_decoding_cache"):
+            model._decoding_cache = None
+        model._decoding_cache = update_graph_cache(
+            model,
+            model._decoding_cache,
+            batch_size,
+            seqlen_og,
+            max_length,
+        )
+        inference_params = model._decoding_cache.inference_params
+        inference_params.reset(max_length, batch_size)
+    else:
+        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)
+
+    def get_logits(input_ids, inference_params):
+        # seqlen_offset == 0 means the prompt has not been processed yet.
+        decoding = inference_params.seqlen_offset > 0
+        if decoding:
+            position_ids = torch.full(
+                (batch_size, 1),
+                inference_params.seqlen_offset,
+                dtype=torch.long,
+                device=input_ids.device,
+            )
+        else:
+            position_ids = None
+        if not cg or not decoding:
+            logits = model(
+                input_ids,
+                position_ids=position_ids,
+                inference_params=inference_params,
+                num_last_tokens=1,
+            ).logits.squeeze(dim=1)
+        else:
+            logits = model._decoding_cache.run(
+                input_ids, position_ids, inference_params.seqlen_offset
+            ).squeeze(dim=1)
+        return logits[..., :vocab_size] if vocab_size is not None else logits
+
+    def sample_tokens(logits, inference_params):
+        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
+            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
+        else:
+            token = teacher_outputs[:, inference_params.seqlen_offset]
+        # return rearrange(token, "b -> b 1")
+        return token.unsqueeze(1)
+
+    def should_stop(current_token, inference_params):
+        if inference_params.seqlen_offset == 0:
+            return False
+        if eos_token_id is not None and (current_token == eos_token_id).all():
+            return True
+        if inference_params.seqlen_offset >= max_length - 1:
+            return True
+        return False
+
+    if enable_timing:
+        # CUDA events are only created when timing is requested, so CPU-only
+        # runs never touch the CUDA event API.
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        start.record()
+    scores, sequences = [], [input_ids]
+    sequences_cat = input_ids
+    while not should_stop(sequences[-1], inference_params):
+        logits = get_logits(sequences[-1], inference_params)
+        if output_scores:
+            scores.append(logits.clone())
+        inference_params.seqlen_offset += sequences[-1].shape[1]
+        if repetition_penalty == 1.0:
+            sampled_tokens = sample_tokens(logits, inference_params)
+        else:
+            logits = modify_logit_for_repetition_penalty(
+                logits, sequences_cat, repetition_penalty
+            )
+            sampled_tokens = sample_tokens(logits, inference_params)
+            # The full history is only tracked when the penalty needs it.
+            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
+        sequences.append(sampled_tokens)
+        if streamer is not None:
+            streamer.put(sampled_tokens.cpu())
+    if streamer is not None:
+        streamer.end()
+    if enable_timing:
+        end.record()
+        torch.cuda.synchronize()
+        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
+    return GenerateDecoderOnlyOutput(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
+
+
+class GenerationMixin:
+    """Adds ``generate`` to a model; the model must provide ``allocate_inference_cache``."""
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        raise NotImplementedError
+
+    def generate(
+        self,
+        input_ids,
+        max_length,
+        top_k=1,
+        top_p=0.0,
+        min_p=0.0,
+        temperature=1.0,
+        return_dict_in_generate=False,
+        output_scores=False,
+        **kwargs,
+    ):
+        """Run ``decode`` on this model.
+
+        Returns the full output object when ``return_dict_in_generate`` is True,
+        otherwise only ``output.sequences``. ``scores`` is set to None unless
+        ``output_scores`` is True. Extra keyword arguments go to ``decode``.
+        """
+        output = decode(
+            input_ids,
+            self,
+            max_length,
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            temperature=temperature,
+            output_scores=output_scores,
+            **kwargs,
+        )
+        if not output_scores:
+            output.scores = None
+        return output if return_dict_in_generate else output.sequences
+
+
+@dataclass
+class DecodingCGCache:
+    """State kept between calls to ``update_graph_cache``."""
+
+    max_batch_size: int = 0
+    max_seqlen: int = 0
+    # device, dtype and mempool carry no annotation, so they are plain class
+    # attributes rather than dataclass fields; they are assigned per instance
+    # by update_graph_cache.
+    device = None
+    dtype = None
+    callables: dict = field(default_factory=dict)
+    mempool = None
+    inference_params: Optional[InferenceParams] = None
+    run: Optional[Callable] = None
+
+
+@torch.inference_mode()
+def update_graph_cache(
+    model,
+    cache,
+    batch_size,
+    seqlen_og,
+    max_seqlen,
+    decoding_seqlens=(1,),
+    dtype=None,
+    n_warmups=2,
+):
+    """Create or refresh the CUDA-graph decoding cache for ``model``.
+
+    The cache is rebuilt when the device or dtype changed, or when
+    ``batch_size`` / ``max_seqlen`` exceed what it was built for. A graph is
+    captured for every ``(batch_size, decoding_seqlen)`` pair not yet present,
+    and ``cache.run`` is set to a dispatcher keyed on the input shape.
+    Returns the (possibly new) cache.
+    """
+    if cache is None:
+        cache = DecodingCGCache()
+    param_example = next(iter(model.parameters()))
+    device = param_example.device
+    if dtype is None:
+        dtype = param_example.dtype
+    if (
+        (device, dtype) != (cache.device, cache.dtype)
+        or batch_size > cache.max_batch_size
+        or max_seqlen > cache.max_seqlen
+    ):  # Invalidate the cache
+        cache.callables = {}
+        cache.mempool = None
+        cache.inference_params = None
+        gc.collect()
+        cache.device, cache.dtype = device, dtype
+        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
+        assert hasattr(model, "allocate_inference_cache"), "CUDA graph decoding requires that the model has a method allocate_inference_cache"
+        inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
+        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
+        cache.inference_params = InferenceParams(
+            max_seqlen=max_seqlen,
+            max_batch_size=batch_size,
+            seqlen_offset=seqlen_og,
+            key_value_memory_dict=inf_cache,
+            lengths_per_sample=lengths_per_sample,
+        )
+        cache.mempool = torch.cuda.graphs.graph_pool_handle()
+    for decoding_seqlen in decoding_seqlens:
+        if (batch_size, decoding_seqlen) not in cache.callables:
+            cache.callables[batch_size, decoding_seqlen] = capture_graph(
+                model,
+                cache.inference_params,
+                batch_size,
+                max_seqlen,
+                decoding_seqlen=decoding_seqlen,
+                mempool=cache.mempool,
+                n_warmups=n_warmups,
+            )
+
+    def dispatch(input_ids, position_ids, seqlen):
+        batch_size, decoding_seqlen = input_ids.shape[:2]
+        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)
+
+    cache.run = dispatch
+    cache.inference_params.seqlen_offset = 0  # Reset so it's not confusing
+    return cache
+
+
+def capture_graph(
+    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
+):
+    """Warm up ``model`` and capture one decoding forward pass in a CUDA graph.
+
+    Returns ``run(new_input_ids, new_position_ids, seqlen)``, which copies the
+    new inputs into the captured buffers, replays the graph and returns a clone
+    of the captured logits. ``inference_params.seqlen_offset`` is restored
+    before returning; ``lengths_per_sample`` is overwritten.
+    """
+    device = next(iter(model.parameters())).device
+    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
+    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
+    seqlen_offset_og = inference_params.seqlen_offset
+    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
+    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset
+
+    # Warmup before capture
+    s = torch.cuda.Stream()
+    s.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(s):
+        for _ in range(n_warmups):
+            logits = model(
+                input_ids,
+                position_ids=position_ids,
+                inference_params=inference_params,
+                num_last_tokens=decoding_seqlen,
+            ).logits
+        s.synchronize()
+        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
+        # which requires that graph launch and non-captured launch to not overlap (I think,
+        # that's how I interpret the documentation). I'm not sure if this is required.
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
+    torch.cuda.current_stream().wait_stream(s)
+    # Captures the graph
+    # To allow capture, automatically sets a side stream as the current stream in the context
+    graph = torch.cuda.CUDAGraph()
+    with torch.cuda.graph(graph, pool=mempool):
+        logits = model(
+            input_ids,
+            position_ids=position_ids,
+            inference_params=inference_params,
+            num_last_tokens=decoding_seqlen,
+        ).logits
+
+    def run(new_input_ids, new_position_ids, seqlen):
+        inference_params.lengths_per_sample[:] = seqlen
+        input_ids.copy_(new_input_ids)
+        position_ids.copy_(new_position_ids)
+        graph.replay()
+        return logits.clone()
+
+    inference_params.seqlen_offset = seqlen_offset_og
+    return run
diff --git a/mamba-main.zip b/mamba-main.zip
new file mode 100644
index 000000000..fab8933e1
Binary files /dev/null and b/mamba-main.zip differ
diff --git a/mamba2-torch-main.zip b/mamba2-torch-main.zip
new file mode 100644
index 000000000..2e1bc1df3
Binary files /dev/null and b/mamba2-torch-main.zip differ
diff --git a/mixer_seq_simple (1).py b/mixer_seq_simple (1).py
new file mode 100644
index 000000000..fae2257a9
--- /dev/null
+++ b/mixer_seq_simple (1).py
@@ -0,0 +1,309 @@
+# Copyright (c) 2023, Albert Gu, Tri Dao.
+
+import math
+from functools import partial
+import json
+import os
+import copy
+
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+
+from mamba_ssm.models.config_mamba import MambaConfig
+from mamba_ssm.modules.mamba_simple import Mamba
+from mamba_ssm.modules.mamba2 import Mamba2
+from mamba_ssm.modules.mha import MHA
+from mamba_ssm.modules.mlp import GatedMLP
+from mamba_ssm.modules.block import Block
+from mamba_ssm.utils.generation import GenerationMixin
+from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
+
+try:
+    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn
+except ImportError:
+    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
+
+
+def create_block(
+    d_model,
+    d_intermediate,
+    ssm_cfg=None,
+    attn_layer_idx=None,
+    attn_cfg=None,
+    norm_epsilon=1e-5,
+    rms_norm=False,
+    residual_in_fp32=False,
+    fused_add_norm=False,
+    layer_idx=None,
+    device=None,
+    dtype=None,
+):
+    """Build one ``Block`` with its mixer, MLP and norm factories.
+
+    The mixer is ``MHA`` when ``layer_idx`` is listed in ``attn_layer_idx``;
+    otherwise it is ``Mamba`` or ``Mamba2``, chosen by ``ssm_cfg["layer"]``
+    (default ``"Mamba1"``). The MLP is ``nn.Identity`` when
+    ``d_intermediate == 0`` and ``GatedMLP`` otherwise.
+
+    Raises:
+        ValueError: if ``ssm_cfg["layer"]`` is neither "Mamba1" nor "Mamba2".
+    """
+    if ssm_cfg is None:
+        ssm_cfg = {}
+    if attn_layer_idx is None:
+        attn_layer_idx = []
+    if attn_cfg is None:
+        attn_cfg = {}
+    factory_kwargs = {"device": device, "dtype": dtype}
+    if layer_idx not in attn_layer_idx:
+        # Create a copy of the config to modify
+        ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {}
+        # "layer" selects the mixer class and is removed before the remaining
+        # entries are forwarded to the mixer constructor.
+        ssm_layer = ssm_cfg.pop("layer", "Mamba1")
+        if ssm_layer not in ["Mamba1", "Mamba2"]:
+            raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2")
+        mixer_cls = partial(
+            Mamba2 if ssm_layer == "Mamba2" else Mamba,
+            layer_idx=layer_idx,
+            **ssm_cfg,
+            **factory_kwargs
+        )
+    else:
+        mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs)
+    # NOTE(review): RMSNorm is None when the Triton import above failed, so
+    # rms_norm=True would fail here in that case -- confirm this is acceptable.
+    norm_cls = partial(
+        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
+    )
+    if d_intermediate == 0:
+        mlp_cls = nn.Identity
+    else:
+        mlp_cls = partial(
+            GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs
+        )
+    block = Block(
+        d_model,
+        mixer_cls,
+        mlp_cls,
+        norm_cls=norm_cls,
+        fused_add_norm=fused_add_norm,
+        residual_in_fp32=residual_in_fp32,
+    )
+    block.layer_idx = layer_idx
+    return block
+
+
+# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
+def _init_weights(
+    module,
+    n_layer,
+    initializer_range=0.02,  # Now only used for embedding layer.
+    rescale_prenorm_residual=True,
+    n_residuals_per_layer=1,  # Change to 2 if we have MLP
+):
+    """Initialise ``module`` in place; meant to be used with ``nn.Module.apply``.
+
+    Linear biases are zeroed unless flagged ``_no_reinit``; embedding weights
+    are drawn from a normal distribution with std ``initializer_range``. When
+    ``rescale_prenorm_residual`` is set, parameters named ``out_proj.weight`` or
+    ``fc2.weight`` are re-initialised and divided by
+    ``sqrt(n_residuals_per_layer * n_layer)``.
+    """
+    if isinstance(module, nn.Linear):
+        if module.bias is not None:
+            if not getattr(module.bias, "_no_reinit", False):
+                nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.Embedding):
+        nn.init.normal_(module.weight, std=initializer_range)
+
+    if rescale_prenorm_residual:
+        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
+        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
+        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
+        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
+        #
+        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
+        for name, p in module.named_parameters():
+            if name in ["out_proj.weight", "fc2.weight"]:
+                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
+                # We need to reinit p since this code could be called multiple times
+                # Having just p *= scale would repeatedly scale it down
+                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
+                with torch.no_grad():
+                    p /= math.sqrt(n_residuals_per_layer * n_layer)
+
+
+class MixerModel(nn.Module):
+    """Embedding, a stack of ``n_layer`` blocks from ``create_block``, and a final norm."""
+
+    def __init__(
+        self,
+        d_model: int,
+        n_layer: int,
+        d_intermediate: int,
+        vocab_size: int,
+        ssm_cfg=None,
+        attn_layer_idx=None,
+        attn_cfg=None,
+        norm_epsilon: float = 1e-5,
+        rms_norm: bool = False,
+        initializer_cfg=None,
+        fused_add_norm=False,
+        residual_in_fp32=False,
+        device=None,
+        dtype=None,
+    ) -> None:
+        """Build the layers and initialise their weights with ``_init_weights``.
+
+        Raises:
+            ImportError: if ``fused_add_norm`` is set but the Triton norm
+                kernels could not be imported.
+        """
+        factory_kwargs = {"device": device, "dtype": dtype}
+        super().__init__()
+        self.residual_in_fp32 = residual_in_fp32
+
+        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)
+
+        # We change the order of residual and layer norm:
+        # Instead of LN -> Attn / MLP -> Add, we do:
+        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
+        # the main branch (output of MLP / Mixer). The model definition is unchanged.
+        # This is for performance reason: we can fuse add + layer_norm.
+        self.fused_add_norm = fused_add_norm
+        if self.fused_add_norm:
+            if layer_norm_fn is None or rms_norm_fn is None:
+                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
+
+        self.layers = nn.ModuleList(
+            [
+                create_block(
+                    d_model,
+                    d_intermediate=d_intermediate,
+                    ssm_cfg=ssm_cfg,
+                    attn_layer_idx=attn_layer_idx,
+                    attn_cfg=attn_cfg,
+                    norm_epsilon=norm_epsilon,
+                    rms_norm=rms_norm,
+                    residual_in_fp32=residual_in_fp32,
+                    fused_add_norm=fused_add_norm,
+                    layer_idx=i,
+                    **factory_kwargs,
+                )
+                for i in range(n_layer)
+            ]
+        )
+
+        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
+            d_model, eps=norm_epsilon, **factory_kwargs
+        )
+
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=n_layer,
+                **(initializer_cfg if initializer_cfg is not None else {}),
+                n_residuals_per_layer=1 if d_intermediate == 0 else 2,  # 2 if we have MLP
+            )
+        )
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        """Return a dict mapping each layer index to that layer's inference cache."""
+        return {
+            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+            for i, layer in enumerate(self.layers)
+        }
+
+    def forward(self, input_ids, inference_params=None, **mixer_kwargs):
+        """Embed ``input_ids``, run every block, and return the final-normed hidden states."""
+        hidden_states = self.embedding(input_ids)
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                hidden_states, residual, inference_params=inference_params, **mixer_kwargs
+            )
+        if not self.fused_add_norm:
+            # Unfused path: do the last residual add here, then the final norm.
+            residual = (hidden_states + residual) if residual is not None else hidden_states
+            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
+        else:
+            # Set prenorm=False here since we don't need the residual
+            hidden_states = layer_norm_fn(
+                hidden_states,
+                self.norm_f.weight,
+                self.norm_f.bias,
+                eps=self.norm_f.eps,
+                residual=residual,
+                prenorm=False,
+                residual_in_fp32=self.residual_in_fp32,
+                is_rms_norm=isinstance(self.norm_f, RMSNorm)
+            )
+        return hidden_states
+
+
+class MambaLMHeadModel(nn.Module, GenerationMixin):
+    """``MixerModel`` backbone plus a bias-free linear language-model head."""
+
+    def __init__(
+        self,
+        config: MambaConfig,
+        initializer_cfg=None,
+        device=None,
+        dtype=None,
+    ) -> None:
+        """Build backbone and head from ``config``, initialise, then tie weights.
+
+        ``vocab_size`` is rounded up to a multiple of
+        ``config.pad_vocab_size_multiple`` before the layers are created.
+        """
+        self.config = config
+        d_model = config.d_model
+        n_layer = config.n_layer
+        d_intermediate = config.d_intermediate
+        vocab_size = config.vocab_size
+        ssm_cfg = config.ssm_cfg
+        attn_layer_idx = config.attn_layer_idx
+        attn_cfg = config.attn_cfg
+        rms_norm = config.rms_norm
+        residual_in_fp32 = config.residual_in_fp32
+        fused_add_norm = config.fused_add_norm
+        pad_vocab_size_multiple = config.pad_vocab_size_multiple
+        factory_kwargs = {"device": device, "dtype": dtype}
+
+        super().__init__()
+        if vocab_size % pad_vocab_size_multiple != 0:
+            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
+        self.backbone = MixerModel(
+            d_model=d_model,
+            n_layer=n_layer,
+            d_intermediate=d_intermediate,
+            vocab_size=vocab_size,
+            ssm_cfg=ssm_cfg,
+            attn_layer_idx=attn_layer_idx,
+            attn_cfg=attn_cfg,
+            rms_norm=rms_norm,
+            initializer_cfg=initializer_cfg,
+            fused_add_norm=fused_add_norm,
+            residual_in_fp32=residual_in_fp32,
+            **factory_kwargs,
+        )
+        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
+
+        # Initialize weights and apply final processing
+        self.apply(
+            partial(
+                _init_weights,
+                n_layer=n_layer,
+                **(initializer_cfg if initializer_cfg is not None else {}),
+            )
+        )
+        self.tie_weights()
+
+    def tie_weights(self):
+        """Share the embedding weight with ``lm_head`` when ``config.tie_embeddings`` is set."""
+        if self.config.tie_embeddings:
+            self.lm_head.weight = self.backbone.embedding.weight
+
+    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
+        """Delegate to the backbone's ``allocate_inference_cache``."""
+        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
+
+    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs):
+        """
+        "position_ids" is just to be compatible with Transformer generation. We don't use it.
+        num_last_tokens: if > 0, only return the logits for the last n tokens
+
+        Returns a ``CausalLMOutput`` namedtuple with a single ``logits`` field.
+        """
+        hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
+        if num_last_tokens > 0:
+            hidden_states = hidden_states[:, -num_last_tokens:]
+        lm_logits = self.lm_head(hidden_states)
+        # The namedtuple type is created on every call.
+        CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
+        return CausalLMOutput(logits=lm_logits)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
+        """Build a model from the config and state dict loaded via the ``hf`` helpers.
+
+        Extra keyword arguments are forwarded to the constructor.
+        """
+        config_data = load_config_hf(pretrained_model_name)
+        config = MambaConfig(**config_data)
+        model = cls(config, device=device, dtype=dtype, **kwargs)
+        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
+        return model
+
+    def save_pretrained(self, save_directory):
+        """
+        Minimal implementation of save_pretrained for MambaLMHeadModel.
+        Save the model and its configuration file to a directory.
+
+        Writes ``pytorch_model.bin`` (the state dict) and ``config.json``.
+        """
+        # Ensure save_directory exists
+        os.makedirs(save_directory, exist_ok=True)
+
+        # Save the model's state_dict
+        model_path = os.path.join(save_directory, 'pytorch_model.bin')
+        torch.save(self.state_dict(), model_path)
+
+        # Save the configuration of the model
+        config_path = os.path.join(save_directory, 'config.json')
+        with open(config_path, 'w') as f:
+            json.dump(self.config.__dict__, f, indent=4)
diff --git a/mixer_seq_simple.py b/mixer_seq_simple.py
new file mode 100644
index 000000000..fae2257a9
--- /dev/null
+++ b/mixer_seq_simple.py
@@ -0,0 +1,309 @@
+# Copyright (c) 2023, Albert Gu, Tri Dao.
+ +import math +from functools import partial +import json +import os +import copy + +from collections import namedtuple + +import torch +import torch.nn as nn + +from mamba_ssm.models.config_mamba import MambaConfig +from mamba_ssm.modules.mamba_simple import Mamba +from mamba_ssm.modules.mamba2 import Mamba2 +from mamba_ssm.modules.mha import MHA +from mamba_ssm.modules.mlp import GatedMLP +from mamba_ssm.modules.block import Block +from mamba_ssm.utils.generation import GenerationMixin +from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf + +try: + from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn +except ImportError: + RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None + + +def create_block( + d_model, + d_intermediate, + ssm_cfg=None, + attn_layer_idx=None, + attn_cfg=None, + norm_epsilon=1e-5, + rms_norm=False, + residual_in_fp32=False, + fused_add_norm=False, + layer_idx=None, + device=None, + dtype=None, +): + if ssm_cfg is None: + ssm_cfg = {} + if attn_layer_idx is None: + attn_layer_idx = [] + if attn_cfg is None: + attn_cfg = {} + factory_kwargs = {"device": device, "dtype": dtype} + if layer_idx not in attn_layer_idx: + # Create a copy of the config to modify + ssm_cfg = copy.deepcopy(ssm_cfg) if ssm_cfg is not None else {} + ssm_layer = ssm_cfg.pop("layer", "Mamba1") + if ssm_layer not in ["Mamba1", "Mamba2"]: + raise ValueError(f"Invalid ssm_layer: {ssm_layer}, only support Mamba1 and Mamba2") + mixer_cls = partial( + Mamba2 if ssm_layer == "Mamba2" else Mamba, + layer_idx=layer_idx, + **ssm_cfg, + **factory_kwargs + ) + else: + mixer_cls = partial(MHA, layer_idx=layer_idx, **attn_cfg, **factory_kwargs) + norm_cls = partial( + nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs + ) + if d_intermediate == 0: + mlp_cls = nn.Identity + else: + mlp_cls = partial( + GatedMLP, hidden_features=d_intermediate, out_features=d_model, **factory_kwargs + ) + block = Block( + d_model, + 
mixer_cls, + mlp_cls, + norm_cls=norm_cls, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + ) + block.layer_idx = layer_idx + return block + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights( + module, + n_layer, + initializer_range=0.02, # Now only used for embedding layer. + rescale_prenorm_residual=True, + n_residuals_per_layer=1, # Change to 2 if we have MLP +): + if isinstance(module, nn.Linear): + if module.bias is not None: + if not getattr(module.bias, "_no_reinit", False): + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + # Following Pytorch init, except scale by 1/sqrt(2 * n_layer) + # We need to reinit p since this code could be called multiple times + # Having just p *= scale would repeatedly scale it down + nn.init.kaiming_uniform_(p, a=math.sqrt(5)) + with torch.no_grad(): + p /= math.sqrt(n_residuals_per_layer * n_layer) + + +class MixerModel(nn.Module): + def __init__( + self, + d_model: int, + n_layer: int, + d_intermediate: int, + vocab_size: int, + ssm_cfg=None, + attn_layer_idx=None, + attn_cfg=None, + norm_epsilon: float = 1e-5, + rms_norm: bool = False, + initializer_cfg=None, + fused_add_norm=False, + residual_in_fp32=False, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.residual_in_fp32 = residual_in_fp32 + + self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs) + + # We change the order of residual and layer norm: + # Instead of LN -> Attn / MLP -> Add, we do: + # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and + # the main branch (output of MLP / Mixer). The model definition is unchanged. + # This is for performance reason: we can fuse add + layer_norm. 
+ self.fused_add_norm = fused_add_norm + if self.fused_add_norm: + if layer_norm_fn is None or rms_norm_fn is None: + raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels") + + self.layers = nn.ModuleList( + [ + create_block( + d_model, + d_intermediate=d_intermediate, + ssm_cfg=ssm_cfg, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, + norm_epsilon=norm_epsilon, + rms_norm=rms_norm, + residual_in_fp32=residual_in_fp32, + fused_add_norm=fused_add_norm, + layer_idx=i, + **factory_kwargs, + ) + for i in range(n_layer) + ] + ) + + self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)( + d_model, eps=norm_epsilon, **factory_kwargs + ) + + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + n_residuals_per_layer=1 if d_intermediate == 0 else 2, # 2 if we have MLP + ) + ) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return { + i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + for i, layer in enumerate(self.layers) + } + + def forward(self, input_ids, inference_params=None, **mixer_kwargs): + hidden_states = self.embedding(input_ids) + residual = None + for layer in self.layers: + hidden_states, residual = layer( + hidden_states, residual, inference_params=inference_params, **mixer_kwargs + ) + if not self.fused_add_norm: + residual = (hidden_states + residual) if residual is not None else hidden_states + hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + hidden_states = layer_norm_fn( + hidden_states, + self.norm_f.weight, + self.norm_f.bias, + eps=self.norm_f.eps, + residual=residual, + prenorm=False, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm_f, RMSNorm) + ) + return hidden_states + + +class MambaLMHeadModel(nn.Module, GenerationMixin): + + def __init__( + self, + config: 
MambaConfig, + initializer_cfg=None, + device=None, + dtype=None, + ) -> None: + self.config = config + d_model = config.d_model + n_layer = config.n_layer + d_intermediate = config.d_intermediate + vocab_size = config.vocab_size + ssm_cfg = config.ssm_cfg + attn_layer_idx = config.attn_layer_idx + attn_cfg = config.attn_cfg + rms_norm = config.rms_norm + residual_in_fp32 = config.residual_in_fp32 + fused_add_norm = config.fused_add_norm + pad_vocab_size_multiple = config.pad_vocab_size_multiple + factory_kwargs = {"device": device, "dtype": dtype} + + super().__init__() + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + self.backbone = MixerModel( + d_model=d_model, + n_layer=n_layer, + d_intermediate=d_intermediate, + vocab_size=vocab_size, + ssm_cfg=ssm_cfg, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, + rms_norm=rms_norm, + initializer_cfg=initializer_cfg, + fused_add_norm=fused_add_norm, + residual_in_fp32=residual_in_fp32, + **factory_kwargs, + ) + self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) + + # Initialize weights and apply final processing + self.apply( + partial( + _init_weights, + n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}), + ) + ) + self.tie_weights() + + def tie_weights(self): + if self.config.tie_embeddings: + self.lm_head.weight = self.backbone.embedding.weight + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0, **mixer_kwargs): + """ + "position_ids" is just to be compatible with Transformer generation. We don't use it. 
+ num_last_tokens: if > 0, only return the logits for the last n tokens + """ + hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs) + if num_last_tokens > 0: + hidden_states = hidden_states[:, -num_last_tokens:] + lm_logits = self.lm_head(hidden_states) + CausalLMOutput = namedtuple("CausalLMOutput", ["logits"]) + return CausalLMOutput(logits=lm_logits) + + @classmethod + def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs): + config_data = load_config_hf(pretrained_model_name) + config = MambaConfig(**config_data) + model = cls(config, device=device, dtype=dtype, **kwargs) + model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype)) + return model + + def save_pretrained(self, save_directory): + """ + Minimal implementation of save_pretrained for MambaLMHeadModel. + Save the model and its configuration file to a directory. + """ + # Ensure save_directory exists + os.makedirs(save_directory, exist_ok=True) + + # Save the model's state_dict + model_path = os.path.join(save_directory, 'pytorch_model.bin') + torch.save(self.state_dict(), model_path) + + # Save the configuration of the model + config_path = os.path.join(save_directory, 'config.json') + with open(config_path, 'w') as f: + json.dump(self.config.__dict__, f, indent=4) diff --git a/rocm6_0.patch b/rocm6_0.patch new file mode 100644 index 000000000..e1fa60d42 --- /dev/null +++ b/rocm6_0.patch @@ -0,0 +1,56 @@ +--- /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h 2023-12-12 20:11:48.000000000 +0000 ++++ rocm_update_files/amd_hip_bf16.h 2024-05-20 17:40:26.983349079 +0000 +@@ -137,7 +137,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_CONV + * \brief Converts float to bfloat16 + */ +-__HOST_DEVICE__ __hip_bfloat16 __float2bfloat16(float f) { ++__HOST_DEVICE__ static inline __hip_bfloat16 __float2bfloat16(float f) { + __hip_bfloat16 ret; + union { + float fp32; +@@ -181,7 +181,7 @@ + * \ingroup 
HIP_INTRINSIC_BFLOAT162_CONV + * \brief Converts and moves bfloat162 to float2 + */ +-__HOST_DEVICE__ float2 __bfloat1622float2(const __hip_bfloat162 a) { ++__HOST_DEVICE__ static inline float2 __bfloat1622float2(const __hip_bfloat162 a) { + return float2{__bfloat162float(a.x), __bfloat162float(a.y)}; + } + +@@ -209,7 +209,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Convert double to __hip_bfloat16 + */ +-__HOST_DEVICE__ __hip_bfloat16 __double2bfloat16(const double a) { ++__HOST_DEVICE__ static inline __hip_bfloat16 __double2bfloat16(const double a) { + return __float2bfloat16((float)a); + } + +@@ -217,7 +217,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Convert float2 to __hip_bfloat162 + */ +-__HOST_DEVICE__ __hip_bfloat162 __float22bfloat162_rn(const float2 a) { ++__HOST_DEVICE__ static inline __hip_bfloat162 __float22bfloat162_rn(const float2 a) { + return __hip_bfloat162{__float2bfloat16(a.x), __float2bfloat16(a.y)}; + } + +@@ -247,7 +247,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Converts high 16 bits of __hip_bfloat162 to float and returns the result + */ +-__HOST_DEVICE__ float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); } ++__HOST_DEVICE__ static inline float __high2float(const __hip_bfloat162 a) { return __bfloat162float(a.y); } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV +@@ -275,7 +275,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Converts low 16 bits of __hip_bfloat162 to float and returns the result + */ +-__HOST_DEVICE__ float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); } ++__HOST_DEVICE__ static inline float __low2float(const __hip_bfloat162 a) { return __bfloat162float(a.x); } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV diff --git a/ssd_minimal.py b/ssd_minimal.py new file mode 100644 index 000000000..9632ebd43 --- /dev/null +++ b/ssd_minimal.py @@ -0,0 +1,103 @@ +# Copyright (c) 2024, Albert Gu and Tri Dao. 
+"""Minimal implementation of SSD. + +This is the same as Listing 1 from the paper. +""" + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined + + +def segsum_unstable(x): + """Naive segment sum calculation.""" + T = x.size(-1) + x_cumsum = torch.cumsum(x, dim=-1) + x_segsum = x_cumsum[..., :, None] - x_cumsum[..., None, :] + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + +def segsum(x): + """More stable segment sum calculation.""" + T = x.size(-1) + x = repeat(x, "... d -> ... d e", e=T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + +def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None): + """ + Arguments: + X: (batch, length, n_heads, d_head) + A: (batch, length, n_heads) + B: (batch, length, n_heads, d_state) + C: (batch, length, n_heads, d_state) + Return: + Y: (batch, length, n_heads, d_head) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + X, A, B, C = [rearrange(x, "b (c l) ... -> b c l ...", l=block_len) for x in (X, A, B, C)] + + A = rearrange(A, "b c l h -> b h c l") + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + L = torch.exp(segsum(A)) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. 
Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + Y = rearrange(Y_diag+Y_off, "b c l h p -> b (c l) h p") + return Y, final_state + + +# Simple test +def test_correctness(): + torch.manual_seed(42) + + ## Dimensions + # Denoted (B, T, Q, D, P) in the paper + batch, seqlen, chunk_size, dim, headdim = 1, 2048, 64, 2048, 64 + nheads = dim // headdim # (H) in the paper + ngroups = 1 # (G) in the paper + dstate = 64 # (N) in the paper + dtype = torch.float32 + device = "cuda" + + x = torch.randn(batch, seqlen, nheads, headdim, dtype=dtype, device=device) + dt = F.softplus(torch.randn(batch, seqlen, nheads, dtype=torch.float32, device=device) - 4).requires_grad_() + A = (-torch.exp(torch.rand(nheads, dtype=torch.float32, device=device))).requires_grad_() + B = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device) + C = torch.randn(batch, seqlen, ngroups, dstate, dtype=dtype, device=device) + D = torch.randn(nheads, dtype=dtype, 
device=device) + + # Comparing fused version and minimal version + y = mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None) + y_min, _ = ssd_minimal_discrete(x*dt.unsqueeze(-1), A*dt, B, C, chunk_size) diff --git a/tensor_parallel.py b/tensor_parallel.py new file mode 100644 index 000000000..2d67b5304 --- /dev/null +++ b/tensor_parallel.py @@ -0,0 +1,296 @@ +# Copyright (c) 2024, Tri Dao. +# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed import ProcessGroup +from mamba_ssm.utils.torch import custom_bwd, custom_fwd + +from einops import rearrange + +from mamba_ssm.distributed.distributed_utils import ( + all_gather_raw, + all_reduce, + all_reduce_raw, + reduce_scatter, + reduce_scatter_raw, +) + + +class ParallelLinearFunc(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True): + """ + If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel + with sequence parallelism: we do an all_gather_raw of x before doing the matmul. 
+ """ + ctx.compute_weight_gradient = weight.requires_grad + ctx.process_group = process_group + ctx.sequence_parallel = sequence_parallel + + if torch.is_autocast_enabled(): + x = x.to(dtype=torch.get_autocast_gpu_dtype()) + x = x.contiguous() + if process_group is not None and sequence_parallel: + # We want to kick off the all_gather early, before weight dtype conversion + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + + if torch.is_autocast_enabled(): + weight = weight.to(dtype=torch.get_autocast_gpu_dtype()) + bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None + weight = weight.contiguous() + if process_group is not None and sequence_parallel: + handle_x.wait() + batch_shape, n = total_x.shape[:-1], total_x.shape[-1] + batch_dim = batch_shape.numel() + # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174 + output = F.linear(total_x, weight, bias) + if ctx.compute_weight_gradient: + ctx.save_for_backward(x, weight) + else: + ctx.save_for_backward(weight) + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + grad_output = grad_output.contiguous() + process_group = ctx.process_group + sequence_parallel = ctx.sequence_parallel + if ctx.compute_weight_gradient: + x, weight = ctx.saved_tensors + if process_group is not None and sequence_parallel: + total_x, handle_x = all_gather_raw(x, process_group, async_op=True) + else: + total_x = x + else: + (weight,) = ctx.saved_tensors + total_x = None + batch_shape = grad_output.shape[:-1] + batch_dim = batch_shape.numel() + grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1]) + if ctx.needs_input_grad[0]: + grad_input = F.linear(grad_output, weight.t()) + grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1]) + if process_group is not None: + reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw + 
grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True) + else: + grad_input = None + if ctx.needs_input_grad[1]: + assert ctx.compute_weight_gradient + if process_group is not None and sequence_parallel: + handle_x.wait() + grad_weight = torch.einsum( + "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1]) + ) + else: + grad_weight = None + grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None + if process_group is not None and ctx.needs_input_grad[0]: + handle_grad_input.wait() + return grad_input, grad_weight, grad_bias, None, None + + +def parallel_linear_func( + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + process_group: Optional[ProcessGroup] = None, + sequence_parallel: bool = True, +): + return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel) + + +class ColumnParallelLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + multiple_of=1, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + if out_features % multiple_of: + raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}") + multiple = out_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) + super().__init__( + in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism: + # we do an all_gather of x before doing the matmul. 
+ # If not, then the input is already gathered. + return parallel_linear_func( + x, + self.weight, + self.bias, + process_group=self.process_group, + sequence_parallel=self.sequence_parallel, + ) + + +class RowParallelLinear(nn.Linear): + def __init__( + self, + in_features: int, + out_features: int, + process_group: ProcessGroup, + bias: bool = True, + sequence_parallel=True, + multiple_of=1, + device=None, + dtype=None, + ) -> None: + world_size = torch.distributed.get_world_size(process_group) + rank = torch.distributed.get_rank(process_group) + if in_features % multiple_of: + raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") + multiple = in_features // multiple_of + # We want to split @multiple across world_size, but it could be an uneven split + div = multiple // world_size + mod = multiple % world_size + # The first @mod ranks get @div + 1 copies, the rest get @div copies + local_multiple = div + int(torch.distributed.get_rank(process_group) < mod) + # Only rank 0 will have bias + super().__init__( + local_multiple * multiple_of, + out_features, + bias=bias and rank == 0, + device=device, + dtype=dtype, + ) + self.process_group = process_group + self.sequence_parallel = sequence_parallel + + def forward(self, x): + """ + We're doing Tensor Parallel with sequence parallelism: we do the matmul and then + a reduce_scatter of the result. 
+ """ + out = parallel_linear_func(x, self.weight, self.bias) + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return reduce_fn(out, self.process_group) + + +class VocabParallelEmbedding(nn.Embedding): + def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs): + self.process_group = process_group + if process_group is not None: + world_size = torch.distributed.get_world_size(process_group) + if num_embeddings % world_size != 0: + raise ValueError( + f"num_embeddings ({num_embeddings}) must be divisible by " + f"world_size ({world_size})" + ) + if world_size > 1 and padding_idx is not None: + raise RuntimeError("ParallelEmbedding does not support padding_idx") + else: + world_size = 1 + super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs) + + def forward(self, input: Tensor) -> Tensor: + if self.process_group is None: + return super().forward(input) + else: + rank = torch.distributed.get_rank(self.process_group) + vocab_size = self.num_embeddings + vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size + # Create a mask of valid vocab ids (1 means it needs to be masked). 
+ input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index) + input = input - vocab_start_index + input[input_ids_mask] = 0 + embeddings = super().forward(input) + embeddings[input_ids_mask] = 0.0 + return embeddings + + +class ColumnParallelEmbedding(nn.Embedding): + def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs): + self.process_group = process_group + if process_group is not None: + world_size = torch.distributed.get_world_size(process_group) + if embedding_dim % world_size != 0: + raise ValueError( + f"embedding_dim ({embedding_dim}) must be divisible by " + f"world_size ({world_size})" + ) + else: + world_size = 1 + super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs) + + +class ParallelEmbeddings(nn.Module): + def __init__( + self, + embed_dim, + vocab_size, + max_position_embeddings, + process_group, + padding_idx=None, + sequence_parallel=True, + device=None, + dtype=None, + ): + """ + If max_position_embeddings <= 0, there's no position embeddings + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.word_embeddings = VocabParallelEmbedding( + vocab_size, + embed_dim, + padding_idx=padding_idx, + process_group=process_group, + **factory_kwargs, + ) + self.max_position_embeddings = max_position_embeddings + if self.max_position_embeddings > 0: + self.position_embeddings = ColumnParallelEmbedding( + max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs + ) + + def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False): + """ + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) + """ + batch_size, seqlen = input_ids.shape + world_size = torch.distributed.get_world_size(self.process_group) + embeddings = self.word_embeddings(input_ids) + if self.max_position_embeddings > 0: + if position_ids is None: + 
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + if world_size <= 1: + embeddings = embeddings + position_embeddings + else: + partition_dim = self.position_embeddings.embedding_dim + rank = torch.distributed.get_rank(self.process_group) + embeddings[ + ..., rank * partition_dim : (rank + 1) * partition_dim + ] += position_embeddings + if combine_batch_seqlen_dim: + embeddings = rearrange(embeddings, "b s d -> (b s) d") + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group) diff --git a/torch.py b/torch.py new file mode 100644 index 000000000..37df47c8e --- /dev/null +++ b/torch.py @@ -0,0 +1,21 @@ +import torch +from functools import partial +from typing import Callable + +def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool): + def decorator(*args, **kwargs): + if cuda_amp_deprecated: + kwargs["device_type"] = "cuda" + return dec(*args, **kwargs) + return decorator + + +if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined] + deprecated = True + from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined] +else: + deprecated = False + from torch.cuda.amp import custom_fwd, custom_bwd + +custom_fwd = custom_amp_decorator(custom_fwd, deprecated) +custom_bwd = custom_amp_decorator(custom_bwd, deprecated)