diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..abe93c90 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,32 @@ +name: CI Pipeline + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + name: Run Tests + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Install system build tools + run: | + sudo apt-get update + sudo apt-get install -y build-essential python3-dev + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - name: Install tox + run: uv tool install tox --with tox-uv + + - name: Run unit tests + run: tox \ No newline at end of file diff --git a/azure-pipelines.yml b/azure-pipelines.yml new file mode 100644 index 00000000..f8a2f6de --- /dev/null +++ b/azure-pipelines.yml @@ -0,0 +1,21 @@ +# Starter pipeline +# Start with a minimal pipeline that you can customize to build and deploy your code. +# Add steps that build, run tests, deploy, and more: +# https://aka.ms/yaml + +trigger: +- main + +pool: + vmImage: ubuntu-latest + +steps: +- script: | + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/main; + conda tos accept --override-channels --channel https://repo.anaconda.com/pkgs/r; + pip install tox + pip install tox-conda + displayName: 'Install tox' +- script: | + tox + displayName: 'Run unit tests' diff --git a/benchmark/README.md b/benchmark/README.md new file mode 100644 index 00000000..1cd09a68 --- /dev/null +++ b/benchmark/README.md @@ -0,0 +1,357 @@ +# Ring Attention Performance Benchmarks + +This directory contains a unified performance benchmarking framework for all Ring Attention variants, built using a shared architecture that eliminates code duplication and provides consistent interfaces. + +## ๐Ÿ—๏ธ Architecture + +The benchmark framework consists of: + +### Core Framework +- **`benchmark_base.py`**: Shared benchmark framework extending the test framework +- **Configuration System**: Unified configuration management via `../tests/customized_ops/ring_attn/configs.py` + +### Attention Implementations +- **`benchmark_ring_attn.py`**: Standard Ring Attention benchmarks +- **`benchmark_ring_attn_varlen.py`**: Variable Length Ring Attention benchmarks +- **`benchmark_zigzag_attn.py`**: Zigzag Ring Attention benchmarks (causal-only) + +## ๐Ÿš€ Quick Start + +### 1. List Available Configurations + +```bash +cd benchmark + +# List configurations for any benchmark variant +python benchmark_ring_attn_varlen.py --list-configs +python benchmark_ring_attn.py --list-configs +python benchmark_zigzag_attn.py --list-configs +``` + +### 2. Run Basic Benchmarks + +```bash +# Ring Attention Variable Length +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config medium + +# Standard Ring Attention +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config small + +# Zigzag Ring Attention (causal-only) +torchrun --nproc_per_node=2 benchmark_zigzag_attn.py --config tiny +``` + +### 3. Advanced Usage + +```bash +# Custom timing parameters +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config medium --timing-method warmup --warmup-runs 5 --timing-runs 10 + +# Detailed profiling +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config large --timing-method profiler + +# Custom configurations (legacy support) +torchrun --nproc_per_node=2 benchmark_ring_attn.py --seqlen 8192 --nheads 16 --head-dim 128 --batch-size 4 +``` + +## ๐Ÿ“‹ Available Configurations + +The benchmark framework uses a comprehensive configuration system with predefined configurations for different testing scenarios. + +### Configuration Categories + +#### Small Configs (Quick Testing) +- **`tiny`**: 2ร—8ร—64, seq=1024, tokens=1K, bf16 [Causal] +- **`small`**: 4ร—12ร—128, seq=4096, tokens=4K, bf16 [Causal] +- **`small_fp16`**: 4ร—12ร—128, seq=4096, tokens=4K, fp16 [Non-causal] +- **`small_window`**: 4ร—12ร—128, seq=4096, tokens=4K, bf16 [Causal] [Window=512,0] + +#### Medium Configs (Standard Testing) +- **`medium`**: 4ร—24ร—128, seq=8192, tokens=8K, bf16 [Causal] +- **`medium_large_head`**: 4ร—12ร—256, seq=8192, tokens=8K, bf16 [Non-causal] +- **`medium_many_heads`**: 4ร—32ร—128, seq=8192, tokens=8K, bf16 [Causal] +- **`medium_fp16`**: 4ร—24ร—128, seq=8192, tokens=8K, fp16 [Causal] +- **`medium_window`**: 4ร—24ร—128, seq=8192, tokens=8K, bf16 [Causal] [Window=512,0] + +#### Large Configs (Performance Testing) +- **`large`**: 4ร—32ร—128, seq=16384, tokens=16K, bf16 [Causal] +- **`large_seq`**: 4ร—24ร—128, seq=32768, tokens=32K, bf16 [Causal] +- **`large_head`**: 4ร—24ร—256, seq=16384, tokens=16K, bf16 [Non-causal] +- **`xlarge`**: 8ร—32ร—128, seq=32768, tokens=32K, bf16 [Causal] +- **`large_window`**: 4ร—32ร—128, seq=16384, tokens=16K, bf16 [Causal] [Window=512,0] + +#### GQA Configs (Grouped Query Attention) +- **`qwen3_235b_a22b`**: 2ร—64ร—64, seq=16384, tokens=16K, bf16 (GQA 64โ†’4) [Causal] +- **`qwen3_30b_a3b`**: 4ร—32ร—64, seq=16384, tokens=16K, bf16 (GQA 32โ†’4) [Causal] +- **`qwen3_4b`**: 4ร—32ร—80, seq=16384, tokens=16K, bf16 (GQA 32โ†’4) [Causal] +- **`qwen3_32b`**: 2ร—64ร—128, seq=16384, tokens=16K, bf16 (GQA 64โ†’8) [Causal] +- **`qwen3_14b`**: 4ร—40ร—128, seq=16384, tokens=16K, bf16 (GQA 40โ†’8) [Causal] + +#### Zigzag Configs (Causal-Only) +- **`zigzag_tiny`**: 2ร—8ร—64, seq=1024, tokens=1K, bf16 [Causal] +- **`zigzag_small`**: 4ร—12ร—128, seq=4096, tokens=4K, bf16 [Causal] +- **`zigzag_medium`**: 4ร—24ร—128, seq=8192, tokens=8K, bf16 [Causal] +- **`zigzag_large`**: 4ร—32ร—128, seq=16384, tokens=16K, bf16 [Causal] +- **`zigzag_fp16`**: 4ร—12ร—128, seq=4096, tokens=4K, fp16 [Causal] +- **`zigzag_gqa`**: 4ร—32ร—128, seq=8192, tokens=8K, bf16 (GQA 32โ†’8) [Causal] + +### Default Configuration Sets +- **Correctness Testing**: `["tiny", "small", "medium"]` +- **Performance Testing**: `["medium", "large"]` +- **Multi-GPU Testing**: `["small", "medium"]` +- **GQA Testing**: `["qwen3_4b", "qwen3_14b", "qwen3_32b"]` +- **Zigzag Testing**: `["zigzag_tiny", "zigzag_small", "zigzag_medium"]` + +## ๐Ÿ”ง Features + +### Unified Framework +- **Shared Base Class**: All benchmarks extend `RingAttnBenchmarkBase` for consistency +- **Code Reuse**: Leverages test framework components (`test_base.py`, `runner_base.py`) +- **Consistent Interface**: Same command-line options across all attention variants + +### Multiple Timing Methods +- **`simple`**: Basic CUDA timing measurements (fastest) +- **`warmup`**: Multiple runs with warm-up (recommended for accurate results) +- **`profiler`**: torch.profiler with detailed kernel analysis + +### Comprehensive Metrics +- **Performance**: Forward/backward timing, throughput (tokens/sec) +- **Scalability**: Speedup analysis, parallel efficiency +- **Memory**: GPU memory usage tracking +- **Comparative**: Single vs. parallel mode analysis + +### Configuration Support +- **Predefined Configs**: 20+ predefined configurations covering different scales +- **Legacy Parameters**: Backward compatibility with custom parameters +- **Attention Variants**: Support for standard, variable-length, and zigzag attention +- **GQA Support**: Grouped Query Attention configurations based on Qwen models + +## ๐Ÿงช Usage Examples + +### Basic Performance Testing +```bash +# Quick benchmarks with different attention types +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config tiny --timing-method simple +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config small --timing-method warmup +torchrun --nproc_per_node=2 benchmark_zigzag_attn.py --config medium --dtype fp16 +``` + +### Comparative Analysis +```bash +# Compare different attention mechanisms on same config +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config medium --timing-method warmup +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config medium --timing-method warmup +torchrun --nproc_per_node=2 benchmark_zigzag_attn.py --config medium --timing-method warmup +``` + +### Advanced Profiling +```bash +# Detailed profiler analysis +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config large --timing-method profiler + +# Custom timing parameters for high precision +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config medium --timing-method warmup --warmup-runs 10 --timing-runs 20 +``` + +### GQA Performance Testing +```bash +# Test Grouped Query Attention configurations +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config qwen3_4b --timing-method warmup +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config qwen3_14b --timing-method warmup +``` + +### Legacy Support (Custom Parameters) +```bash +# Override specific parameters while using predefined base +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config medium --seqlen 16384 --nheads 32 + +# Full custom configuration +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --seqlen 8192 --nheads 16 --head-dim 128 --batch-size 4 --dtype bf16 +``` + +## ๐Ÿ“ˆ Output Interpretation + +The benchmark framework provides comprehensive performance analysis: + +### Performance Metrics +``` +================================================================================ +RING ATTENTION VARIABLE LENGTH PERFORMANCE BENCHMARK (WARMUP METHOD) +Configuration: medium - medium + Sequence length: 8192 + Batch size: 4 + Heads: 24 + Head dim: 128 + Data type: bf16 + World size: 2 GPUs + Total tokens: 8,192 + (Warmup runs: 3, Timing runs: 5) +================================================================================ +Single Mode: + Forward time: 0.001234 seconds + Backward time: 0.002345 seconds + Total time: 0.003579 seconds + Throughput: 2288764 tokens/sec + +Parallel Mode: + Forward time: 0.000987 seconds + Backward time: 0.001654 seconds + Total time: 0.002641 seconds + Throughput: 3102234 tokens/sec + +Speedup: + Forward speedup: 1.25x + Backward speedup: 1.42x + Total speedup: 1.35x + Throughput improvement: 1.35x + +Efficiency: + Theoretical speedup: 2x + Actual speedup: 1.35x + Parallel efficiency: 67.7% +================================================================================ +``` + +### Key Metrics Explained +- **Forward/Backward Time**: Separate timing for forward and backward passes +- **Throughput**: Tokens processed per second (higher = better) +- **Speedup**: Performance ratio vs single GPU (higher = better) +- **Parallel Efficiency**: Actual speedup / theoretical speedup (closer to 100% = better) + +### Profiler Output (when using `--timing-method profiler`) +When using the profiler method, you get additional detailed analysis: +- Kernel-level timing breakdown +- Memory bandwidth utilization +- CUDA kernel execution patterns +- Optimization recommendations + +## ๐ŸŽฏ Attention Variant Characteristics + +### Ring Attention (`benchmark_ring_attn.py`) +- **Format**: Standard batch format `[batch_size, seq_len, num_heads, head_dim]` +- **Use Case**: General purpose attention for standard transformer models +- **Constraints**: Supports both causal and non-causal attention, sliding windows + +### Ring Attention Variable Length (`benchmark_ring_attn_varlen.py`) +- **Format**: Packed format `[total_tokens, num_heads, head_dim]` with `cu_seqlens` +- **Use Case**: Optimized for variable-length sequences, eliminates padding waste +- **Constraints**: Supports causal/non-causal attention, sliding windows + +### Zigzag Attention (`benchmark_zigzag_attn.py`) +- **Format**: Standard batch format `[batch_size, seq_len, num_heads, head_dim]` +- **Use Case**: Specialized for causal attention with optimized communication pattern +- **Constraints**: **Only supports causal=True and window_size=(-1,-1)** + +## ๐Ÿ”— Integration with Test Framework + +The benchmark framework is tightly integrated with the correctness test framework: + +### Shared Components +- **Configuration System**: Same `configs.py` used for both correctness and performance testing +- **Base Classes**: Reuses `RingAttnRunnerBase` from `runner_base.py` +- **Distributed Setup**: Shared GPU detection and distributed initialization +- **Error Handling**: Consistent tolerance and validation logic + +### Workflow Integration +```bash +# 1. Run correctness tests first +cd /path/to/MagicCube +pytest tests/customized_ops/ring_attn/test_ring_attn_varlen.py --config tiny + +# 2. Then run performance benchmarks +cd benchmark +torchrun --nproc_per_node=2 benchmark_ring_attn_varlen.py --config tiny +``` + +## โš ๏ธ Requirements & Setup + +### System Requirements +- **Multi-GPU Setup**: Most benchmarks require 2+ GPUs (use `torchrun --nproc_per_node=N`) +- **GPU Memory**: Large configs may require high-memory GPUs (A100, H100 recommended) +- **CUDA**: Compatible CUDA installation (11.8+ recommended) +- **Python Environment**: PyTorch with NCCL support for distributed training + +### Optional Components +- **TransformerEngine**: Install TE 2.2.0+ for optimal performance (auto-detected) +- **Flash Attention**: Required for base attention implementations +- **InfiniBand**: Recommended for multi-node setups (reduces communication latency) + +### Environment Setup +```bash +# From MagicCube root directory +cd benchmark + +# Verify imports work correctly +python -c " +from benchmark_base import RingAttnBenchmarkBase +print('โœ“ Benchmark framework ready') +" + +# Test configuration system +python benchmark_ring_attn_varlen.py --list-configs +``` + +## ๐Ÿšจ Troubleshooting + +### Common Issues + +#### GPU/Memory Issues +```bash +# OOM errors: Use smaller configs or reduce batch size +torchrun --nproc_per_node=2 benchmark_ring_attn.py --config tiny # Instead of large + +# Insufficient GPUs: Check available GPUs +python -c "import torch; print(f'Available GPUs: {torch.cuda.device_count()}')" +``` + +#### Import/Path Issues +```bash +# Import errors: Ensure running from correct directory +cd /path/to/MagicCube/benchmark +python benchmark_ring_attn.py --help + +# Configuration import errors +python -c " +import sys, os +sys.path.insert(0, '../tests/customized_ops/ring_attn') +from configs import get_config +print('โœ“ Config system working') +" +``` + +#### Distributed Training Issues +```bash +# NCCL errors: Check GPU compatibility and CUDA setup +export NCCL_DEBUG=INFO # For detailed NCCL debugging + +# Port conflicts: Use different port +torchrun --master_port=29501 --nproc_per_node=2 benchmark_ring_attn.py --config tiny +``` + +### Performance Debugging +```bash +# Test basic functionality without distributed training +CUDA_VISIBLE_DEVICES=0 python -c " +from benchmark_ring_attn import RingAttnBenchmark +print('โœ“ Benchmark classes load correctly') +" + +# Verify attention implementations work +cd ../tests/customized_ops/ring_attn +pytest test_ring_attn.py::TestRingAttn::test_ring_attn_tiny -v +``` + +**Note**: Actual efficiency depends on hardware, network, and system configuration. + +## ๐Ÿ“š Related Documentation + +### Core Documentation +- **Ring Attention Implementation**: `../nnscaler/customized_ops/ring_attention/README.md` +- **Test Framework**: `../tests/customized_ops/ring_attn/README.md` +- **Development Guide**: `../dev_docs/README_refactoring.md` +- **Testing Results**: `../dev_docs/benchmark_testing_results.md` + +--- + +**For implementation details**: See `../nnscaler/customized_ops/ring_attention/` +**For correctness testing**: See `../tests/customized_ops/ring_attn/` \ No newline at end of file diff --git a/benchmark/benchmark_base.py b/benchmark/benchmark_base.py new file mode 100644 index 00000000..4226d2ed --- /dev/null +++ b/benchmark/benchmark_base.py @@ -0,0 +1,426 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Base benchmark framework for ring attention performance tests. +This module extends the test framework to support performance benchmarking. +""" + +import os +import sys +import time +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Tuple, Callable + +import torch +import torch.distributed as dist +from torch.profiler import profile, ProfilerActivity + +# Add tests directory to path to import test framework +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) + +from runner_base import RingAttnRunnerBase +from configs import get_config, get_configs_by_category, DEFAULT_PERFORMANCE_CONFIGS + + +class RingAttnBenchmarkBase(RingAttnRunnerBase): + """Base class for ring attention performance benchmarks""" + + def __init__(self): + super().__init__() + self.timing_method = "warmup" + self.warmup_runs = 3 + self.timing_runs = 5 + + @abstractmethod + def get_benchmark_name(self) -> str: + """Return the benchmark name for display""" + pass + + def run_timing_with_warmup(self, forward_fn: Callable, backward_fn: Callable, + warmup_runs: int = None, timing_runs: int = None) -> Tuple[float, float, Any]: + """Run timing with warm-up runs to get accurate measurements.""" + warmup_runs = warmup_runs or self.warmup_runs + timing_runs = timing_runs or self.timing_runs + + # Warm-up runs + for _ in range(warmup_runs): + torch.cuda.synchronize() + output = forward_fn() + torch.cuda.synchronize() + backward_fn(output) + torch.cuda.synchronize() + + # Timing runs + forward_times = [] + backward_times = [] + + for _ in range(timing_runs): + # Forward timing + torch.cuda.synchronize() + start = time.perf_counter() + output = forward_fn() + torch.cuda.synchronize() + forward_time = time.perf_counter() - start + forward_times.append(forward_time) + + # Backward timing + torch.cuda.synchronize() + start = time.perf_counter() + backward_fn(output) + torch.cuda.synchronize() + backward_time = time.perf_counter() - start + backward_times.append(backward_time) + + # Return average times + avg_forward = sum(forward_times) / len(forward_times) + avg_backward = sum(backward_times) / len(backward_times) + return avg_forward, avg_backward, output + + def run_timing_with_profiler(self, forward_fn: Callable, backward_fn: Callable, + rank_id: int = 0) -> Tuple[float, float, Any]: + """Run timing using torch.profiler for detailed analysis.""" + activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] + + # Run profiler with timing + torch.cuda.synchronize() + + with profile(activities=activities, record_shapes=True, with_stack=True) as prof: + torch.cuda.synchronize() + forward_start = time.perf_counter() + output = forward_fn() + torch.cuda.synchronize() + forward_end = time.perf_counter() + + torch.cuda.synchronize() + backward_start = time.perf_counter() + backward_fn(output) + torch.cuda.synchronize() + backward_end = time.perf_counter() + + torch.cuda.synchronize() + + # Calculate timing from our measurements + forward_time = forward_end - forward_start + backward_time = backward_end - backward_start + + if rank_id == 0: + self._print_profiler_results(prof) + + return forward_time, backward_time, output + + def run_timing_simple(self, forward_fn: Callable, backward_fn: Callable) -> Tuple[float, float, Any]: + """Run simple timing without warmup or profiling.""" + torch.cuda.synchronize() + forward_start = time.perf_counter() + output = forward_fn() + torch.cuda.synchronize() + forward_time = time.perf_counter() - forward_start + + torch.cuda.synchronize() + backward_start = time.perf_counter() + backward_fn(output) + torch.cuda.synchronize() + backward_time = time.perf_counter() - backward_start + + return forward_time, backward_time, output + + def _print_profiler_results(self, prof): + """Print profiler results with fallback for different PyTorch versions.""" + print("\n" + "="*60) + print("TORCH PROFILER RESULTS") + print("="*60) + + try: + # Try the most common sorting options + events = prof.key_averages() + table_str = events.table(sort_by="self_cuda_time_total", row_limit=20) + print(table_str) + except Exception as e1: + try: + table_str = events.table(sort_by="cuda_time_total", row_limit=20) + print(table_str) + except Exception as e2: + try: + table_str = events.table(sort_by="self_cpu_time_total", row_limit=20) + print(table_str) + except Exception as e3: + print(f"Warning: Could not generate profiler table due to API differences") + print(f"Errors: {e1}, {e2}, {e3}") + + # Fallback: print basic event info + print("Available profiler events:") + for i, event in enumerate(events): + if i >= 10: # Limit output + break + try: + print(f" {event.key}: CPU time = {getattr(event, 'cpu_time_total', 'N/A')} us") + except: + print(f" {event.key}: [timing info unavailable]") + + print("="*60 + "\n") + + def create_timing_functions(self, inputs, config, dout_tensor): + """Create timing functions for single and parallel execution.""" + # Single mode functions + def single_forward(): + single_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + single_inputs[k] = v.detach().clone().requires_grad_() + else: + single_inputs[k] = v.detach().clone() + else: + single_inputs[k] = v + + # Run single GPU reference + output, grad_tensors = self.run_single_gpu_reference(single_inputs, config) + return output, (single_inputs, grad_tensors) + + def single_backward(outputs): + output, (single_inputs, grad_tensors) = outputs + output.backward(dout_tensor) + return dout_tensor + + # Parallel mode functions + model = self.create_test_module(config) + dummy_args = self.get_dummy_forward_args(inputs) + + from nnscaler.parallel import parallelize, ComputeConfig, ReuseType + world_size = dist.get_world_size() + + parallel_model = parallelize( + model, + dummy_forward_args=dummy_args, + pas_policy=self.create_policy(), + compute_config=ComputeConfig(world_size, world_size), + reuse=ReuseType.OVERRIDE + ) + parallel_model = parallel_model.cuda() + parallel_model.train() + + def parallel_forward(): + para_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + para_inputs[k] = v.detach().clone().requires_grad_() + else: + para_inputs[k] = v.detach().clone() + else: + para_inputs[k] = v + + output = parallel_model(**para_inputs) + return output, para_inputs + + def parallel_backward(outputs): + output, para_inputs = outputs + output.backward(dout_tensor) + parallel_model.sync_grad() + return dout_tensor + + return single_forward, single_backward, parallel_forward, parallel_backward + + def calculate_throughput_metrics(self, config, forward_time: float, backward_time: float) -> Dict[str, float]: + """Calculate throughput and efficiency metrics.""" + total_time = forward_time + backward_time + + # Calculate total tokens processed + if hasattr(config, 'total_tokens'): + total_tokens = config.total_tokens + else: + total_tokens = config.batch_size * config.max_seqlen + + throughput = total_tokens / total_time if total_time > 0 else 0 + + return { + 'total_tokens': total_tokens, + 'throughput_tokens_per_sec': throughput, + 'total_time': total_time, + 'forward_time': forward_time, + 'backward_time': backward_time + } + + def print_benchmark_results(self, config_name: str, config, dtype: str, + single_metrics: Dict[str, float], + parallel_metrics: Dict[str, float], + world_size: int, rank_id: int): + """Print comprehensive benchmark results.""" + if rank_id != 0: + return + + print("\n" + "="*80) + print(f"{self.get_benchmark_name().upper()} PERFORMANCE BENCHMARK ({self.timing_method.upper()} METHOD)") + print(f"Configuration: {config_name} - {config.name}") + print(f" Sequence length: {config.max_seqlen}") + print(f" Batch size: {config.batch_size}") + print(f" Heads: {config.num_heads}") + print(f" Head dim: {config.head_dim}") + print(f" Data type: {dtype}") + print(f" World size: {world_size} GPUs") + print(f" Total tokens: {single_metrics['total_tokens']:,}") + + if self.timing_method == "warmup": + print(f" (Warmup runs: {self.warmup_runs}, Timing runs: {self.timing_runs})") + print("="*80) + + # Timing results + print(f"Single Mode:") + print(f" Forward time: {single_metrics['forward_time']:.6f} seconds") + print(f" Backward time: {single_metrics['backward_time']:.6f} seconds") + print(f" Total time: {single_metrics['total_time']:.6f} seconds") + print(f" Throughput: {single_metrics['throughput_tokens_per_sec']:.0f} tokens/sec") + + print(f"\nParallel Mode:") + print(f" Forward time: {parallel_metrics['forward_time']:.6f} seconds") + print(f" Backward time: {parallel_metrics['backward_time']:.6f} seconds") + print(f" Total time: {parallel_metrics['total_time']:.6f} seconds") + print(f" Throughput: {parallel_metrics['throughput_tokens_per_sec']:.0f} tokens/sec") + + # Speedup calculations + forward_speedup = single_metrics['forward_time'] / parallel_metrics['forward_time'] if parallel_metrics['forward_time'] > 0 else 0 + backward_speedup = single_metrics['backward_time'] / parallel_metrics['backward_time'] if parallel_metrics['backward_time'] > 0 else 0 + total_speedup = single_metrics['total_time'] / parallel_metrics['total_time'] if parallel_metrics['total_time'] > 0 else 0 + throughput_improvement = parallel_metrics['throughput_tokens_per_sec'] / single_metrics['throughput_tokens_per_sec'] if single_metrics['throughput_tokens_per_sec'] > 0 else 0 + + print(f"\nSpeedup:") + print(f" Forward speedup: {forward_speedup:.2f}x") + print(f" Backward speedup: {backward_speedup:.2f}x") + print(f" Total speedup: {total_speedup:.2f}x") + print(f" Throughput improvement: {throughput_improvement:.2f}x") + + # Efficiency metrics + theoretical_speedup = world_size + efficiency = total_speedup / theoretical_speedup * 100 if theoretical_speedup > 0 else 0 + print(f"\nEfficiency:") + print(f" Theoretical speedup: {theoretical_speedup:.0f}x") + print(f" Actual speedup: {total_speedup:.2f}x") + print(f" Parallel efficiency: {efficiency:.1f}%") + print("="*80 + "\n") + + def run_performance_benchmark(self, config_name: str = None, dtype: str = "bf16", + timing_method: str = "warmup", warmup_runs: int = 3, + timing_runs: int = 5, **legacy_kwargs): + """Run performance benchmark for the specific attention implementation.""" + # Setup timing parameters + self.timing_method = timing_method + self.warmup_runs = warmup_runs + self.timing_runs = timing_runs + + # Initialize distributed environment + world_size, rank = self.initialize_distributed() + rank_id = dist.get_rank() + + # Get configuration + config = get_config(config_name) if config_name else self._create_legacy_config(**legacy_kwargs) + torch_dtype = torch.bfloat16 if dtype == "bf16" else torch.float16 + + if rank_id == 0: + print(f"Running {self.get_benchmark_name()} performance benchmark...") + print(f"Configuration: {config.name if hasattr(config, 'name') else 'custom'}") + + # Prepare inputs + device = torch.device(f"cuda:{rank_id}") + inputs = self.prepare_inputs(config, device, torch_dtype) + + # Broadcast inputs to ensure consistency + for tensor in inputs.values(): + if isinstance(tensor, torch.Tensor): + dist.broadcast(tensor, src=0) + dist.barrier() + + # Pre-generate dout tensor for timing consistency + with torch.no_grad(): + dummy_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + dummy_inputs[k] = v.detach() + else: + dummy_inputs[k] = v + dummy_out, _ = self.run_single_gpu_reference(dummy_inputs, config) + dout_tensor = torch.randn_like(dummy_out, device=device, dtype=torch_dtype) + dist.broadcast(dout_tensor, src=0) + + # Create timing functions + single_forward, single_backward, parallel_forward, parallel_backward = self.create_timing_functions( + inputs, config, dout_tensor + ) + + if rank_id == 0: + print(f"Running performance benchmark using {timing_method} method...", end="") + + # Run timing based on method + if timing_method == "profiler": + single_forward_time, single_backward_time, _ = self.run_timing_with_profiler( + single_forward, single_backward, rank_id + ) + parallel_forward_time, parallel_backward_time, _ = self.run_timing_with_profiler( + parallel_forward, parallel_backward, rank_id + ) + elif timing_method == "warmup": + single_forward_time, single_backward_time, _ = self.run_timing_with_warmup( + single_forward, single_backward, warmup_runs, timing_runs + ) + parallel_forward_time, parallel_backward_time, _ = self.run_timing_with_warmup( + parallel_forward, parallel_backward, warmup_runs, timing_runs + ) + else: # simple + single_forward_time, single_backward_time, _ = self.run_timing_simple( + single_forward, single_backward + ) + parallel_forward_time, parallel_backward_time, _ = self.run_timing_simple( + parallel_forward, parallel_backward + ) + + if rank_id == 0: + print(" Done!") + + # Calculate metrics and print results + single_metrics = self.calculate_throughput_metrics(config, single_forward_time, single_backward_time) + parallel_metrics = self.calculate_throughput_metrics(config, parallel_forward_time, parallel_backward_time) + + self.print_benchmark_results( + config_name or "custom", config, dtype, + single_metrics, parallel_metrics, world_size, rank_id + ) + + # Cleanup + dist.destroy_process_group() + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters.""" + class LegacyConfig: + def __init__(self, **kwargs): + self.name = "legacy_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.total_tokens = self.batch_size * self.max_seqlen + self.dtype = "bf16" + # Add other default attributes as needed + + return LegacyConfig(**kwargs) + + def list_configurations(self): + """List all available configurations for benchmarking.""" + print("Available Ring Attention Configurations:") + print("=" * 50) + + for category in ["small", "medium", "large", "gqa"]: + print(f"\n{category.upper()} CONFIGS:") + configs = get_configs_by_category(category) + if configs: + for name, config in configs.items(): + tokens_k = config.total_tokens // 1000 + gqa_info = f" (GQA {config.num_heads}->{config.num_kv_heads})" if config.is_gqa else "" + causal_info = " [Causal]" if config.causal else " [Non-causal]" + window_info = f" [Window={config.window_size[0]},{config.window_size[1]}]" if config.window_size != (-1, -1) else "" + print(f" {name:20s} - {config.batch_size}x{config.num_heads}x{config.head_dim}, seq={config.max_seqlen}, tokens={tokens_k}K, {config.dtype}{gqa_info}{causal_info}{window_info}") + else: + print(" No configurations in this category") + + print(f"\nDEFAULT PERFORMANCE CONFIGS: {DEFAULT_PERFORMANCE_CONFIGS}") + print(f"\nUsage: Use --config to specify a configuration") \ No newline at end of file diff --git a/benchmark/benchmark_ring_attn.py b/benchmark/benchmark_ring_attn.py new file mode 100644 index 00000000..50929729 --- /dev/null +++ b/benchmark/benchmark_ring_attn.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Performance Benchmark +Uses the shared benchmark framework to reduce code duplication. +""" + +import argparse +import sys +import os +import torch + +# Import the benchmark base class +from benchmark_base import RingAttnBenchmarkBase + +# Import ring attention implementation +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func +from nnscaler.customized_ops.ring_attention.core.utils import set_seed + +# Import test configuration (via the base class path setup) +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) +from configs import DEFAULT_PERFORMANCE_CONFIGS + + +class RingAttnBenchmark(RingAttnBenchmarkBase): + """Benchmark for standard Ring Attention""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn.wrap_ring_attn_func' + + @property + def function_name(self) -> str: + return "ring_attn" + + def get_benchmark_name(self) -> str: + return "Ring Attention" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for standard ring attention.""" + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v): + return wrap_ring_attn_func( + q, k, v, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return TestModule() + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors for standard ring attention.""" + set_seed(42) + + # Create input tensors with standard batch format + q = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + k = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + v = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation.""" + q, k, v = inputs['q'], inputs['k'], inputs['v'] + + output = wrap_ring_attn_func( + q, k, v, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return output, [q, k, v] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization.""" + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + return dummy_args + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters for standard ring attention.""" + class LegacyRingAttnConfig: + def __init__(self, **kwargs): + self.name = "legacy_ring_attn_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.total_tokens = self.batch_size * self.max_seqlen + self.dtype = "bf16" + self.causal = True + self.window_size = (-1, -1) + + return LegacyRingAttnConfig(**kwargs) + + +def main(): + """Main entry point for the benchmark.""" + parser = argparse.ArgumentParser(description="Ring Attention Performance Benchmark") + parser.add_argument( + "--config", + type=str, + default=None, + help="Predefined configuration name. Use --list-configs to see available options.", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List all available predefined configurations", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + # Legacy parameters for custom configurations + parser.add_argument( + "--seqlen", + type=int, + default=None, + help="Sequence length (overridden by --config)", + ) + parser.add_argument( + "--nheads", + type=int, + default=None, + help="Number of attention heads (overridden by --config)", + ) + parser.add_argument( + "--head-dim", + type=int, + default=None, + help="Head dimension (overridden by --config)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (overridden by --config)", + ) + # Timing parameters + parser.add_argument( + "--timing-method", + type=str, + default="warmup", + choices=["simple", "profiler", "warmup"], + help="Timing method: simple (basic timing), profiler (torch.profiler with detailed analysis), warmup (recommended: warm-up + multiple runs)", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=3, + help="Number of warm-up runs before timing (for warmup method)", + ) + parser.add_argument( + "--timing-runs", + type=int, + default=5, + help="Number of timing runs to average (for warmup method)", + ) + + args = parser.parse_args() + + # Create benchmark instance + benchmark = RingAttnBenchmark() + + if args.list_configs: + benchmark.list_configurations() + else: + benchmark.run_performance_benchmark( + config_name=args.config, + dtype=args.dtype, + timing_method=args.timing_method, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + # Legacy parameters + seqlen=args.seqlen, + nheads=args.nheads, + head_dim=args.head_dim, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/benchmark_ring_attn_varlen.py b/benchmark/benchmark_ring_attn_varlen.py new file mode 100644 index 00000000..97c4c6fc --- /dev/null +++ b/benchmark/benchmark_ring_attn_varlen.py @@ -0,0 +1,222 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Variable Length Performance Benchmark +Uses the shared benchmark framework to reduce code duplication. +""" + +import argparse +import sys +import os +import torch + +# Import the benchmark base class +from benchmark_base import RingAttnBenchmarkBase + +# Import ring attention implementation +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_varlen_func +from nnscaler.customized_ops.ring_attention.core.utils import set_seed + +# Import test configuration (via the base class path setup) +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) +from configs import DEFAULT_PERFORMANCE_CONFIGS + + +class RingAttnVarlenBenchmark(RingAttnBenchmarkBase): + """Benchmark for Ring Attention Variable Length""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn_varlen.wrap_ring_attn_varlen_func' + + @property + def function_name(self) -> str: + return "ring_attn_varlen" + + def get_benchmark_name(self) -> str: + return "Ring Attention Variable Length" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for variable length ring attention.""" + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v, cu_seqlens_q, cu_seqlens_k): + return wrap_ring_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, None, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return TestModule() + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors for variable length sequence attention.""" + set_seed(42) + + # Get cu_seqlens from config or create default + if hasattr(config, 'cu_seqlens'): + cu_seqlens = config.cu_seqlens + else: + # Create default variable length sequences + seqlen = config.max_seqlen + cu_seqlens = [0, seqlen // 8, seqlen // 4, seqlen // 2, seqlen] + + cu_seqlens_tensor = torch.tensor(cu_seqlens, dtype=torch.int32, device=device) + total_tokens = cu_seqlens[-1] + + # Create input tensors + q = torch.randn(total_tokens, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + k = torch.randn(total_tokens, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + v = torch.randn(total_tokens, config.num_heads, config.head_dim, device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v, + 'cu_seqlens_q': cu_seqlens_tensor, + 'cu_seqlens_k': cu_seqlens_tensor + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation.""" + q, k, v = inputs['q'], inputs['k'], inputs['v'] + cu_seqlens_q = inputs['cu_seqlens_q'] + cu_seqlens_k = inputs['cu_seqlens_k'] + + output = wrap_ring_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, None, + causal=getattr(config, 'causal', True), + window_size=getattr(config, 'window_size', (-1, -1)) + ) + + return output, [q, k, v] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization.""" + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + return dummy_args + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters for varlen.""" + class LegacyVarlenConfig: + def __init__(self, **kwargs): + self.name = "legacy_varlen_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.dtype = "bf16" + self.causal = True + self.window_size = (-1, -1) + + # Create variable length sequences + seqlen = self.max_seqlen + self.cu_seqlens = kwargs.get('cu_seqlens', [0, seqlen // 8, seqlen // 4, seqlen // 2, seqlen]) + self.total_tokens = self.cu_seqlens[-1] + + return LegacyVarlenConfig(**kwargs) + + +def main(): + """Main entry point for the benchmark.""" + parser = argparse.ArgumentParser(description="Ring Attention Variable Length Performance Benchmark") + parser.add_argument( + "--config", + type=str, + default=None, + help="Predefined configuration name. Use --list-configs to see available options.", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List all available predefined configurations", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + # Legacy parameters for custom configurations + parser.add_argument( + "--seqlen", + type=int, + default=None, + help="Total sequence length (overridden by --config)", + ) + parser.add_argument( + "--nheads", + type=int, + default=None, + help="Number of attention heads (overridden by --config)", + ) + parser.add_argument( + "--head-dim", + type=int, + default=None, + help="Head dimension (overridden by --config)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (number of sequences) (overridden by --config)", + ) + # Timing parameters + parser.add_argument( + "--timing-method", + type=str, + default="warmup", + choices=["simple", "profiler", "warmup"], + help="Timing method: simple (basic timing), profiler (torch.profiler with detailed analysis), warmup (recommended: warm-up + multiple runs)", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=3, + help="Number of warm-up runs before timing (for warmup method)", + ) + parser.add_argument( + "--timing-runs", + type=int, + default=5, + help="Number of timing runs to average (for warmup method)", + ) + + args = parser.parse_args() + + # Create benchmark instance + benchmark = RingAttnVarlenBenchmark() + + if args.list_configs: + benchmark.list_configurations() + else: + benchmark.run_performance_benchmark( + config_name=args.config, + dtype=args.dtype, + timing_method=args.timing_method, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + # Legacy parameters + seqlen=args.seqlen, + nheads=args.nheads, + head_dim=args.head_dim, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/benchmark/benchmark_zigzag_attn.py b/benchmark/benchmark_zigzag_attn.py new file mode 100644 index 00000000..94e99521 --- /dev/null +++ b/benchmark/benchmark_zigzag_attn.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Zigzag Attention Performance Benchmark +Uses the shared benchmark framework to reduce code duplication. +""" + +import argparse +import sys +import os +import torch + +# Import the benchmark base class +from benchmark_base import RingAttnBenchmarkBase + +# Import zigzag attention implementation +from nnscaler.customized_ops.ring_attention import wrap_zigzag_attn_func +from nnscaler.customized_ops.ring_attention.core.utils import set_seed + +# Import test configuration (via the base class path setup) +tests_dir = os.path.join(os.path.dirname(__file__), "../tests/customized_ops/ring_attn") +sys.path.insert(0, tests_dir) +from configs import DEFAULT_PERFORMANCE_CONFIGS, ZIGZAG_CONFIGS + + +class ZigzagAttnBenchmark(RingAttnBenchmarkBase): + """Benchmark for Zigzag Attention""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.zigzag_attn.wrap_zigzag_attn_func' + + @property + def function_name(self) -> str: + return "zigzag_attn" + + def get_benchmark_name(self) -> str: + return "Zigzag Attention" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for zigzag attention.""" + class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, q, k, v): + # Zigzag attention only supports causal=True and window_size=(-1,-1) + return wrap_zigzag_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) + ) + + return TestModule() + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors for zigzag attention.""" + set_seed(42) + + # Create input tensors with standard batch format + q = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + k = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + v = torch.randn(config.batch_size, config.max_seqlen, config.num_heads, config.head_dim, + device=device, dtype=torch_dtype) + + return { + 'q': q, + 'k': k, + 'v': v + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation.""" + q, k, v = inputs['q'], inputs['k'], inputs['v'] + + # Zigzag attention constraints + output = wrap_zigzag_attn_func( + q, k, v, + causal=True, + window_size=(-1, -1) + ) + + return output, [q, k, v] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization.""" + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + return dummy_args + + def _create_legacy_config(self, **kwargs): + """Create a legacy configuration from individual parameters for zigzag attention.""" + class LegacyZigzagAttnConfig: + def __init__(self, **kwargs): + self.name = "legacy_zigzag_attn_custom" + self.max_seqlen = kwargs.get('seqlen', 16384) + self.num_heads = kwargs.get('nheads', 24) + self.head_dim = kwargs.get('head_dim', 128) + self.batch_size = kwargs.get('batch_size', 4) + self.total_tokens = self.batch_size * self.max_seqlen + self.dtype = "bf16" + # Zigzag attention constraints + self.causal = True + self.window_size = (-1, -1) + + return LegacyZigzagAttnConfig(**kwargs) + + def run_performance_benchmark(self, config_name: str = None, dtype: str = "bf16", + timing_method: str = "warmup", warmup_runs: int = 3, + timing_runs: int = 5, **legacy_kwargs): + """Override to validate zigzag attention constraints.""" + # Validate configuration for zigzag constraints + if config_name: + from configs import get_config + config = get_config(config_name) + if not config.causal: + print(f"WARNING: Config '{config_name}' has causal=False, but zigzag attention requires causal=True") + print("Proceeding with causal=True for zigzag attention...") + if config.window_size != (-1, -1): + print(f"WARNING: Config '{config_name}' has window_size={config.window_size}, but zigzag attention requires (-1, -1)") + print("Proceeding with window_size=(-1, -1) for zigzag attention...") + + # Call parent implementation + super().run_performance_benchmark( + config_name=config_name, dtype=dtype, timing_method=timing_method, + warmup_runs=warmup_runs, timing_runs=timing_runs, **legacy_kwargs + ) + + def list_configurations(self): + """List configurations suitable for zigzag attention.""" + print("Available Zigzag Attention Configurations:") + print("=" * 50) + print("NOTE: Zigzag attention only supports causal=True and window_size=(-1,-1)") + print("Configurations listed below will be automatically adjusted for these constraints.\n") + + # Call parent method but with zigzag-specific note + super().list_configurations() + + print(f"\nZIGZAG-SPECIFIC CONFIGS: {list(ZIGZAG_CONFIGS.keys())}") + print("These configs are specifically designed for zigzag attention.") + + +def main(): + """Main entry point for the benchmark.""" + parser = argparse.ArgumentParser(description="Zigzag Attention Performance Benchmark") + parser.add_argument( + "--config", + type=str, + default=None, + help="Predefined configuration name. Use --list-configs to see available options.", + ) + parser.add_argument( + "--list-configs", + action="store_true", + help="List all available predefined configurations", + ) + parser.add_argument( + "--dtype", + type=str, + default="bf16", + choices=["fp16", "bf16"], + help="Data type for inputs", + ) + # Legacy parameters for custom configurations + parser.add_argument( + "--seqlen", + type=int, + default=None, + help="Sequence length (overridden by --config)", + ) + parser.add_argument( + "--nheads", + type=int, + default=None, + help="Number of attention heads (overridden by --config)", + ) + parser.add_argument( + "--head-dim", + type=int, + default=None, + help="Head dimension (overridden by --config)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=None, + help="Batch size (overridden by --config)", + ) + # Timing parameters + parser.add_argument( + "--timing-method", + type=str, + default="warmup", + choices=["simple", "profiler", "warmup"], + help="Timing method: simple (basic timing), profiler (torch.profiler with detailed analysis), warmup (recommended: warm-up + multiple runs)", + ) + parser.add_argument( + "--warmup-runs", + type=int, + default=3, + help="Number of warm-up runs before timing (for warmup method)", + ) + parser.add_argument( + "--timing-runs", + type=int, + default=5, + help="Number of timing runs to average (for warmup method)", + ) + + args = parser.parse_args() + + # Create benchmark instance + benchmark = ZigzagAttnBenchmark() + + if args.list_configs: + benchmark.list_configurations() + else: + benchmark.run_performance_benchmark( + config_name=args.config, + dtype=args.dtype, + timing_method=args.timing_method, + warmup_runs=args.warmup_runs, + timing_runs=args.timing_runs, + # Legacy parameters + seqlen=args.seqlen, + nheads=args.nheads, + head_dim=args.head_dim, + batch_size=args.batch_size, + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/dev.md b/dev.md index f42fcb04..0960bde9 100644 --- a/dev.md +++ b/dev.md @@ -92,4 +92,4 @@ Another trick is, if you want to step into pakcage source code, you can add the ### Write Unit Tests 1. If you need to use torchrun, please refer to `unit_test/launch_torchrun.py`, and you can find examples in `unit_tests/runtime/test_runtime_collectives.py`. Please note that `torchrun` is very slow, you should reduce its usage as possible. 2. If you want to mock up any functions/methods, please use pytest-mock. -3. **NOTE**: The name of test files and test functions must start with `test_` +3. **NOTE**: The name of test files and test functions must start with `test_` \ No newline at end of file diff --git a/docs/source/einops.md b/docs/source/einops.md new file mode 100644 index 00000000..e8103add --- /dev/null +++ b/docs/source/einops.md @@ -0,0 +1,38 @@ +# einops Support in NnScaler +================================= + +Tracing einops Functions are challenging due to their dynamic nature and heavy reliance on string-based patterns and runtime shape manipulations. It is challenging to statically analyze and trace these operations accurately, because tracing doesn't work well with complex python logic (e.g. string parsing, dynamic shape computations, loops, etc) involved in einops functions. + +To make things easier, we skip tracing the internal logic of einops functions and directly use the resolved transformation recipes. + +This is done by skipping tracing internal einops function: `_prepare_transformation_recipe`. In future, if einops changes their internal implementation, we may need to update our patching logic accordingly. + +For nnscaler, we may skip more functions in the future if needed. For exmaple, `_reconstruct_from_shape_uncached` and `_reconstruct_from_shape` are also candidates for skipping tracing, but currently we haven't found issues without skipping them. Once we find issues related to them, we will skip tracing them as well. + +As a result, when you use einops functions in your model, we can't guarantee that the traced recipe will be valid when their parameters are changed (e.g. input shapes or pattern strings. `compute_config.constant_folding=False` doesn't help here). + +Currently we haven't encountered problems in our tests, but it's still possible in some corner cases. If you encounter any problems, please report an issue to us. + +Here is an example of using einops in a model with NnScaler: + +```python +import torch +import torch.nn as nn +import einops +from nnscaler import nnscaler, ComputeConfig + +class EinopsModel(nn.Module): + def __init__(self): + ... + + def forward(self, x): + # this is good, because the pattern and the input shape is static (h/w/c are fixed) + x = einops.rearrange(x, 'b (h w c) -> b c h w', h=4, w=4, c=1) + ... + y = ... + # this depends on y + # although dependence maintains properly if you set `compute_config.constant_folding=False`, + # This can be changed in future. So be cautious when using such patterns. + x = einops.rearrange(x, 'b c h w -> b (h w c)', b=y) + ... +``` diff --git a/docs/source/installation.rst b/docs/source/installation.rst index 5e5deb7d..2e26a9af 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -12,7 +12,7 @@ The wheel package is hosted on `GitHub release `_. + +If you are familiar with Azure stuffs, you can follow DevOps' guide to set up the repository. + +Or if you prefer the simpler way, download the ``.whl`` file in the "Files" section of the website, +and install it locally: + +:: + + python -m pip install nnscaler-*.whl + +********** +Quickstart +********** + +The next step depends on your choice of the training framework. + +- **No framework**: if you write your own training code and do not use a framework, + see :ref:`Parallelize API` section. +- **Fairseq**: if you use fairseq, see :ref:`Fairseq` section. +- **Lightning**: TODO + +.. _Parallelize API: + +Parallelize API +=============== + +TODO: write a hello world example, assigned to Zhe Liu + +If you write your own training code, you can use the *parallelize* API to make your model parallel: + +.. code-block:: python + + import torch + from nnscaler import parallelize, ComputeConfig, build_optimizer + + class LLM(torch.nn.Module): + def __init__(self, ...): + ... + def forward(self, x): + ... + + llm_sample_input = ... # dummpy input will be used to do tracing + pas_policy = ... # the PAS policy, you can use autodist pas + compute_config = ComputeConfig( + plan_ngpus=..., + runtime_ngpus=..., + use_zero=..., + ..., + ) # compute environment config + ParallelizedLLM = parallelize( + LLM, + {'x': llm_sample_input}, + pas_policy, + compute_config, + ) + +Example +------- + +An example of the parallelize API is provided in the repo: +`train.py `_ + +You can download and try it: :: + + torchrun --nproc_per_node=4 --nnodes=1 train.py + +Documentation +------------- + +If the example works for you, you can now follow the documentation to parallelize your model: +:doc:`parallel_module` + +.. _Fairseq: + +Fairseq (To be retired) +======================= + +.. TODO: + + nnScaler provides `fairseq integration `_. + + TODO: refine the example (and its doc), assigned to Youshan Miao + + TODO (long term): write an example using unmodified fairseq + + Installation + ------------ + + To use fairseq, clone the fork and install it: :: + + python -m pip uninstall fairseq + + git clone https://msrasrg.visualstudio.com/SuperScaler/_git/Fairseq + cd Fairseq + python -m pip install -e . + + Example + ------- + + Follow the example + `here `_. + diff --git a/examples/llama/create_mini_model.py b/examples/llama/create_mini_model.py index 1151771e..514ac3ac 100644 --- a/examples/llama/create_mini_model.py +++ b/examples/llama/create_mini_model.py @@ -3,9 +3,11 @@ import argparse from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +import torch def main(args): + torch.manual_seed(0) # Ensure deterministic initialization config = AutoConfig.from_pretrained(args.model_id) config.num_hidden_layers = 4 config.use_cache = False diff --git a/examples/warmup_scheduler.py b/examples/warmup_scheduler.py new file mode 100644 index 00000000..54e8aa7f --- /dev/null +++ b/examples/warmup_scheduler.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import math +from torch.optim.lr_scheduler import LRScheduler, Optimizer, _warn_get_lr_called_within_step + + +class WarmupCosineAnnealingLR(LRScheduler): + r""" + torch.optim.lr_scheduler.CosineAnnealingLR with warmup. + + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_steps (int): Number of warmup steps. + T_max (int): Maximum number of iterations. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + verbose (bool | str): If ``True``, prints a message to stdout for + each update. Default: ``False``. + + .. deprecated:: 2.2 + ``verbose`` is deprecated. Please use ``get_last_lr()`` to access the + learning rate. + """ + + def __init__( + self, + optimizer: Optimizer, + warmup_steps: int, + T_max: int, + eta_min=0.0, + last_epoch=-1, + ): # noqa: D107 + self.warmup_steps = warmup_steps + self.T_max = T_max - warmup_steps + 1 + self.eta_min = eta_min + super().__init__(optimizer, last_epoch) + + def get_lr(self): + """Retrieve the learning rate of each parameter group.""" + _warn_get_lr_called_within_step(self) + + last_epoch_wo_warmup = self.last_epoch - self.warmup_steps + 1 + if last_epoch_wo_warmup < 0: + return [base_lr * (self.last_epoch + 1) / self.warmup_steps for base_lr in self.base_lrs] + elif last_epoch_wo_warmup == 0: + return [base_lr for base_lr in self.base_lrs] + elif self._step_count == 1 and last_epoch_wo_warmup > 0: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos((last_epoch_wo_warmup) * math.pi / self.T_max)) + / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + elif (last_epoch_wo_warmup - 1 - self.T_max) % (2 * self.T_max) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + return [ + (1 + math.cos(math.pi * last_epoch_wo_warmup / self.T_max)) + / (1 + math.cos(math.pi * (last_epoch_wo_warmup - 1) / self.T_max)) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self): + last_epoch_wo_warmup = self.last_epoch - self.warmup_steps + 1 + if last_epoch_wo_warmup < 0: + return [base_lr * (self.last_epoch + 1) / self.warmup_steps for base_lr in self.base_lrs] + else: + return [ + self.eta_min + + (base_lr - self.eta_min) + * (1 + math.cos(math.pi * last_epoch_wo_warmup / self.T_max)) + / 2 + for base_lr in self.base_lrs + ] diff --git a/nnscaler/__init__.py b/nnscaler/__init__.py index 2bf5867a..b3a18165 100644 --- a/nnscaler/__init__.py +++ b/nnscaler/__init__.py @@ -16,6 +16,8 @@ broadcast_weights, load_sharded_state_dict, sync_grad_when, + trimmed_broadcast_merged_state_dict, + load_merged_state_dict_from_rank, ) from nnscaler.graph.parser.register import register_op from nnscaler.runtime.function.function import ( @@ -24,6 +26,12 @@ no_constant_folding, fold_constant, ) +from nnscaler.runtime.f16_optimizer import MixedPrecisionAdam, MixedPrecisionAdamW +from nnscaler.runtime.hybrid_optimizer import HybridLRScheduler, HybridOptimizer +from nnscaler.utils import ( + mark_dynamic, + get_dynamic, +) def init(): diff --git a/nnscaler/autodist/autodist_config.py b/nnscaler/autodist/autodist_config.py index 796d35e2..de3eb351 100644 --- a/nnscaler/autodist/autodist_config.py +++ b/nnscaler/autodist/autodist_config.py @@ -72,6 +72,12 @@ class AutoDistConfig: `x.module1` will match `x.module1` but not `y.module1`. Due to constraint of the tracer, you can pass `ROOT` to recompute_modules if you want the whole module to be recomputed. + - recompute_ratio ('float`, *optional*, defaults to `1.0`): + When `recompute_modules` only contains one name (excluding `ROOT`), this specify the ratio of modules + to be recomputed. For example, if `module1` is specified in `recompute_modules` and `recompute_ratio` is `0.8`, + only 80% of `module1` instances will be recomputed. + If there are multiple module names in `recompute_modules`, this field will be ignored and all specified modules + will be recomputed. - memory_constraint (`float`, *optional*, defaults to `32`): The memory constraint in each device in GB. - memory_granularity (`int`, *optional*, defaults to `1`): @@ -115,6 +121,10 @@ class AutoDistConfig: `transient_mem_size = opt_transient_coef * (1st_largest_infer_mem + 2nd_largest_infer_mem)`. This formula is useful in many cases, but it may be too strict when some operators consume or generate a large tensor (>= 4GB). In this case, you can set `transient_mem_coef` to a smaller value to relax the constraint. + - disable_shared_param_constraint (`bool`, *optional*, defaults to `False`): + Whether to disable the shared parameter constraint in spmd solver. When a parameter is shared by multiple modules, + the spmd solver will force the parameter to be replicated to complicated adapter generation. However, user can disable + it and provide customized partition constraints for those shared parameters. """ def __init__(self, @@ -133,6 +143,7 @@ def __init__(self, mesh_row=1, mesh_col=1, recompute_modules='', + recompute_ratio=1.0, memory_constraint=32, memory_granularity=1, micro_batch_size=1, @@ -150,6 +161,7 @@ def __init__(self, solver='dp', parallel_profile=True, transient_mem_coef=2, + disable_shared_param_constraint=False, **kwargs): self.pc_path = partition_constraints_path self.profile_dir = profile_dir @@ -166,6 +178,7 @@ def __init__(self, self.is_train = is_train self.mesh_desc = MeshDesc(mesh_row, mesh_col) self.recompute_modules = recompute_modules + self.recompute_ratio = recompute_ratio # from GB to Byte self.memory_constraint = int(memory_constraint * 1024 * 1024 * 1024) self.memory_granularity = memory_granularity @@ -192,6 +205,7 @@ def __init__(self, self.solver = 'dp' self.parallel_profile = parallel_profile self.transient_mem_coef = transient_mem_coef + self.disable_shared_param_constraint = disable_shared_param_constraint ignored_keys = list(kwargs.keys()) if ignored_keys: @@ -244,7 +258,7 @@ def _validate_config(self): scale_factor = self.world_size // self.mesh_desc.ngpus if scale_factor % self.zero_ngroups != 0: raise ValueError( - f'world size {self.world_size} must be divisible by zero num groups {self.zero_ngroups}' + f'scale_factor {scale_factor} must be divisible by zero num groups {self.zero_ngroups}' ) if not self.solver in [ diff --git a/nnscaler/autodist/cost_database.py b/nnscaler/autodist/cost_database.py index 2e003359..97c446c3 100644 --- a/nnscaler/autodist/cost_database.py +++ b/nnscaler/autodist/cost_database.py @@ -119,27 +119,37 @@ def _load_comm_data(profile_dir: Path, plan_ngpus: int) -> Dict[str, Dict[str, L load intra_2.json, intra_4.json, and intra_8.json from the profile directory. If any of the files is not found, we will use the default data as well. ''' - def loader(path: Path): + def loader(path: Path, strict: bool): if not os.path.exists(path): return False, None info = {} dev = 2 + prev_info = None while dev <= plan_ngpus: fname = f'intra_{dev}.json' if not (path / fname).exists(): - return False, None - with open(path / fname, 'r') as f: - info[fname] = json.load(f) + if strict or prev_info is None: + return False, None + else: + content = prev_info + _logger.warning(f'{dev} devices communication profile data not found, using previous data') + else: + with open(path / fname, 'r') as f: + content = json.load(f) + prev_info = content + info[fname] = content dev *= 2 return True, info comm_path = profile_dir / 'comm' - success, comm_info = loader(comm_path) + success, comm_info = loader(comm_path, strict=True) if not success: + # When communication profile data is not found, use the default data. If the input `plan_ngpus` is greater + # than the devices in the profile data, the data with largest device count (16 for mi200) will be used. This + # is helpful when user wants to generate a distributed plan spanning over multiple nodes. _logger.warning(f'Communication profile data not found, using default data at {_DEFAULT_COMM_DATA_PATH}') - success, comm_info = loader(Path(_DEFAULT_COMM_DATA_PATH)) - if not success: - raise RuntimeError(f'Communication profile data is not compatible with plan_ngpus {plan_ngpus}') + success, comm_info = loader(Path(_DEFAULT_COMM_DATA_PATH), strict=False) + assert success, f'Failed to load default communication profile data from {_DEFAULT_COMM_DATA_PATH}, please check nnscaler\'s installation' return comm_info @@ -337,10 +347,14 @@ def query_single_mem(self, obj, memory_type, round=True) -> int: from .op_partition import OpPartition from .cube_operator import CubeOperator if isinstance(obj, OpPartition): - masks = self.gen_masks(obj.operator) + query_obj = obj.operator else: assert isinstance(obj, CubeOperator) - masks = self.gen_masks(obj) + query_obj = obj + try: + masks = self.gen_masks(query_obj) + except Exception as e: + raise RuntimeError(f"Failed to generate masks for {query_obj} with {self.query_profiled_metrics(query_obj)}: {e}") if memory_type == 'full_weight' and isinstance(obj, OpPartition): profiled_metrics = self.query_profiled_metrics(obj.operator) else: diff --git a/nnscaler/autodist/cube_operator.py b/nnscaler/autodist/cube_operator.py index f25f1ab5..b7d3c2dc 100644 --- a/nnscaler/autodist/cube_operator.py +++ b/nnscaler/autodist/cube_operator.py @@ -108,7 +108,7 @@ def collect_anno_info(self): for idx_dim, dim_anno in enumerate(shape_anno.dims): for idx_id, identifier in enumerate(dim_anno.identifiers): reduce_type = dim_anno.reduces[idx_id] - if reduce_type != DimAnno.ReduceType.Freeze: + if reduce_type != DimAnno.ReduceType.Freeze and self.ir_cell.input(idx_shape).dim_tracks[idx_dim].is_constant: self.parallelable_dims.add(identifier) if reduce_type == DimAnno.ReduceType.Sum: self._has_sum_dim = True diff --git a/nnscaler/autodist/dp_solver.cpp b/nnscaler/autodist/dp_solver.cpp index f0b8e6d7..3b8bc015 100644 --- a/nnscaler/autodist/dp_solver.cpp +++ b/nnscaler/autodist/dp_solver.cpp @@ -67,7 +67,7 @@ void ThreadPool::waitFinished() { cv_finished.wait(lock, [this]() { return tasks.empty() && (busy == 0); }); } -const int MAX_CONCURRENCY = std::thread::hardware_concurrency(); +int MAX_CONCURRENCY = std::thread::hardware_concurrency(); ThreadPool pool(MAX_CONCURRENCY); std::vector> split_work(int num, int base) { @@ -118,6 +118,11 @@ class DPSolver { queries.clear(); id2node.clear(); search_results.clear(); + if (verbose) { + MAX_CONCURRENCY = 1; + std::cout << "set MAX_CONCURRENCY to 1 for verbose mode" + << std::endl; + } } void add_interval(int start, int end) { @@ -230,6 +235,31 @@ class DPSolver { } } + int encode_ir(const std::vector> &cur_ir) { + int val = 0; + for (std::size_t j = 0; j < cur_ir.size(); ++j) { + val += cur_ir[j].second; + if (j + 1 < cur_ir.size()) { + val *= id2node[cur_ir[j + 1].first]->p_num; + } + } + return val; + } + + void print_ir(const std::vector> &cur_ir) { + for (std::size_t j = 0; j < cur_ir.size(); ++j) { + std::cout << "(" << cur_ir[j].first << ", " << cur_ir[j].second << ") "; + } + std::cout << std::endl; + } + + void print_states(DPNode *dp_node) { + for (std::size_t i = 0; i < dp_node->state.size(); ++i) { + UnitDPState state = dp_node->state[i]; + std::cout << "state " << i << ": " << state.to_string() << std::endl; + } + } + // lazy build edge void buildInEdges(DPNode *dp_node) { if (!dp_node->in_edges.empty()) { @@ -361,15 +391,14 @@ class DPSolver { break; } } + bool need_add_pre_node = false; if (!find_pre_id) { Node *pre_node = id2node[node->id - 1]; if (pre_node->father_id != node->father_id) { - // do nothing, means the pre_node's output is not used - // we select the 1st partition of the pre_node - // need to be careful when the graph has multiple outputs if (!has_found_follow && !follow_candidates.empty()) { cur_ir.push_back(*follow_candidates.rbegin()); } + need_add_pre_node = true; } else if (pre_node->father_id == pre_node->id) { assert(follow_candidates.rbegin()->first == pre_node->id); cur_ir.push_back(*follow_candidates.rbegin()); @@ -391,15 +420,36 @@ class DPSolver { } } std::sort(cur_ir.begin(), cur_ir.end()); - val = 0; - for (std::size_t j = 0; j < cur_ir.size(); ++j) { - val += cur_ir[j].second; - if (j + 1 < cur_ir.size()) { - val *= id2node[cur_ir[j + 1].first]->p_num; + if (verbose) { + std::cout << "need_add_pre_node: " << need_add_pre_node << std::endl; + } + if (need_add_pre_node) { + // means the pre_node's output is not used by later nodes, + // so we need to enumerate all the partition states of pre_node + if (verbose) { + std::cout << "p_num " << id2node[node->id - 1]->p_num << std::endl; + } + for (int pred_p = 0; pred_p < id2node[node->id - 1]->p_num; + ++pred_p) { + cur_ir.push_back(std::make_pair(node->id - 1, pred_p)); + int val = encode_ir(cur_ir); + dp_node->in_edges.push_back( + std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); + if (verbose) { + print_ir(cur_ir); + print_states(id2node[node->id - 1]->dp_nodes[val]); + } + cur_ir.pop_back(); + } + } else { + int val = encode_ir(cur_ir); + dp_node->in_edges.push_back( + std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); + if (verbose) { + print_ir(cur_ir); + print_states(id2node[node->id - 1]->dp_nodes[val]); } } - dp_node->in_edges.push_back( - std::make_pair(id2node[node->id - 1]->dp_nodes[val], cost)); } } @@ -414,6 +464,9 @@ class DPSolver { return; } + if (verbose) { + std::cout << "before update, cur_p " << cur_p << std::endl; + } // storing edges takes space, so we build edges when needed buildInEdges(dp_node); if (dp_node->in_edges.empty()) { @@ -468,6 +521,10 @@ class DPSolver { } } } + if (verbose) { + std::cout << "after update" << std::endl; + print_states(dp_node); + } } void do_dp(int start_level, int end_level) { diff --git a/nnscaler/autodist/model_graph.py b/nnscaler/autodist/model_graph.py index 44be8ca2..7c9eec82 100644 --- a/nnscaler/autodist/model_graph.py +++ b/nnscaler/autodist/model_graph.py @@ -718,6 +718,9 @@ def fetch_module(scope_node: ScopeNode, prefix: List[str]): modules = [self.scope_tree_root] else: modules = fetch_module(self.scope_tree_root, []) + if len(recompute_modules) == 1 and self.autodist_config.recompute_ratio < 1.0: + boundary = max(1, int(len(modules) * self.autodist_config.recompute_ratio)) + modules = modules[:boundary] train_mem = 0 for module in modules: diff --git a/nnscaler/autodist/spmd_solver.py b/nnscaler/autodist/spmd_solver.py index 1a1e0ec8..608e1fd2 100644 --- a/nnscaler/autodist/spmd_solver.py +++ b/nnscaler/autodist/spmd_solver.py @@ -261,6 +261,9 @@ def should_force_replica(operator: CubeOperator) -> bool: if len(consumers) == 1: continue _logger.info(f'find shared parameter {param} in {consumers}') + if self.autodist_config.disable_shared_param_constraint: + _logger.info(f'disable shared parameter constraint for {param}') + continue for consumer in consumers: if not isinstance(consumer, IRDimops): # always replicate non-dimops @@ -358,10 +361,15 @@ def is_valid_partition(operator: CubeOperator, p_ids: List[Any], if not selected_pc.replica_allowed: return False else: - allowed_pids = [ - operator.pos2dim_id(pos) - for pos in selected_pc.allowed_partition_dims - ] + allowed_pids = list() + for pos in selected_pc.allowed_partition_dims: + # When allowed dims in provided partition constraints are not correct generate warning + # If there is no valid partitions for the operator, the solver will throw exception later. + try: + cur_allowed_pid = operator.pos2dim_id(pos) + allowed_pids.append(cur_allowed_pid) + except Exception as e: + _logger.warning(f"Failed to get allowed partition id for {selected_pc}'s {pos}: {e}") if u not in allowed_pids: return False @@ -681,11 +689,10 @@ def calc_partition_cost(self, op_idx: int, partition_idx: int): bw_comm_time = 0 intra_time = micro_batch_num * (fw_comm_time + bw_comm_time) # double check the follow chain - if self.get_father_id(op_idx) == self.get_father_id( - producer) and intra_time == 0: - if src_p.operator.ir_cell.mirror is not None: - if self.p_fathers[op_idx][ - partition_idx] != self.p_fathers[producer][k]: + # if `intra_time` (forward + backward) is 0, we assume both partitions are in the same follow chain + if self.get_father_id(op_idx) == self.get_father_id(producer) and intra_time == 0: + if src_p.operator.ir_cell.mirror is not None and tgt_p.operator.ir_cell.mirror is not None: + if self.p_fathers[op_idx][partition_idx] != self.p_fathers[producer][k]: _logger.warning( f'Unexpected comm cost, set to inf: {src_p.ir_cell} to {tgt_p.ir_cell}' ) diff --git a/nnscaler/cli/__init__.py b/nnscaler/cli/__init__.py index 958e874f..d218f6f9 100644 --- a/nnscaler/cli/__init__.py +++ b/nnscaler/cli/__init__.py @@ -17,4 +17,6 @@ AggregatedOutputs, ) +from nnscaler.cli.serialization import register_format + from nnscaler.parallel import ComputeConfig diff --git a/nnscaler/cli/arg_parser.py b/nnscaler/cli/arg_parser.py index f5caab0d..1adf6b9c 100644 --- a/nnscaler/cli/arg_parser.py +++ b/nnscaler/cli/arg_parser.py @@ -3,6 +3,7 @@ import os import copy +import logging from typing import List, Optional, Tuple, Dict, Any, Union from dataclasses import dataclass, field, is_dataclass, asdict @@ -16,6 +17,7 @@ except ImportError: UnionType = None # for python < 3.10 +logger = logging.getLogger(__name__) _TYPE_KEY = '__type' _VALUE_TYPE_KEY = '__value_type' @@ -390,6 +392,8 @@ def _deserialize_object(value, value_type): else: raise ValueError(f"Failed to deserialize {value} to {value_type}") if _is_primitive_type(value_type): + if callable(value): + logger.warning(f'{value} is callable, converting to {value_type} may not work as expected.') return value_type(value) except Exception as ex: raise ValueError(f"Failed to deserialize {value} to {value_type}") from ex diff --git a/nnscaler/cli/checkpoint.py b/nnscaler/cli/checkpoint.py new file mode 100644 index 00000000..4d080620 --- /dev/null +++ b/nnscaler/cli/checkpoint.py @@ -0,0 +1,163 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This script provides functionality to convert a merged checkpoint or directory containing per-rank checkpoints into sharded checkpoints +suitable for distributed training with multiple GPUs. +Run this script with: + python -m nnscaler.cli.checkpoint distribute -f +where is the path to the merged checkpoint file or directory containing per-rank checkpoints, +and is the directory to save the sharded checkpoints. + +This script only for command line. +""" + +import logging +import os +import sys +from pathlib import Path + +import torch.distributed + +import nnscaler +from nnscaler.cli.trainer import Trainer, TrainerArgs +from nnscaler.parallel import _trim_module_merged_state_dict, _trim_optimizer_merged_state_dict + + +logger = logging.getLogger(__name__) + + +def _patch_distributed(): + groups = {} + + def is_initialized(): + return bool(groups) + + torch.distributed.is_initialized = is_initialized + + def init_process_group(*args, **kwargs): + world_size = int(os.environ['WORLD_SIZE']) + groups[None] = list(range(world_size)) + + def get_rank(group=None): + if group not in groups: + raise ValueError(f"Unknown group: {group}") + try: + return groups[group].index(int(os.environ['RANK'])) + except ValueError: + return -1 + + def get_world_size(group=None): + if group not in groups: + raise ValueError(f"Unknown group: {group}") + return len(groups[group]) + + def new_group(ranks=None, *args, **kwargs): + world_size = int(os.environ['WORLD_SIZE']) + if ranks is None or len(ranks) == world_size: + return + group_id = tuple(sorted(ranks)) + if group_id in groups: + return group_id + groups[group_id] = ranks + return group_id + + torch.distributed.get_rank = get_rank + torch.distributed.get_world_size = get_world_size + torch.distributed.init_process_group = init_process_group + torch.distributed.destroy_process_group = lambda: None + torch.distributed.new_group = new_group + torch.distributed.barrier = lambda *args, **kwargs: None + torch.distributed.all_gather = lambda *args, **kwargs: None + torch.distributed.broadcast_object_list = lambda *args, **kwargs: None + + +def _trim_merged_checkpoint(train_args: TrainerArgs, merged_state_dict, rank: int): + os.environ['RANK'] = str(rank) + os.environ['LOCAL_RANK'] = '0' + os.environ['WORLD_SIZE'] = str(train_args.compute_config.runtime_ngpus) + os.environ['GROUP_RANK'] = str(rank) + os.environ['LOCAL_WORLD_SIZE'] = '1' + os.environ['TORCHELASTIC_RUN_ID'] = '0' # fake torchrun env + + sharded_state_dict = {k: v for k, v in merged_state_dict.items()} + + trainer = Trainer(train_args=train_args) + # enforce run mode to load module and optimizer + trainer.train_args.run_mode = 'run' + trainer._setup() + + sharded_state_dict['model'] = _trim_module_merged_state_dict( + trainer.model, merged_state_dict['model'], + device='cpu' + ) + sharded_state_dict['optimizer'] = _trim_optimizer_merged_state_dict( + trainer.model, trainer.optimizer._extra_state, merged_state_dict['optimizer'], + device='cpu' + ) + sharded_state_dict['train_args'] = train_args.to_dict() + sharded_state_dict['train_args'].setdefault('checkpoint', {})['save_type'] = 'sharded' + # discard rng_states for merged state dict + sharded_state_dict.pop('rng_states', None) + if 'dataloader' in sharded_state_dict and sharded_state_dict['dataloader'] is not None: + # keep dataloader state only when all ranks have the same state + dataloader_states = sharded_state_dict['dataloader'] + if all(dataloader_states[i] == dataloader_states[0] for i in range(1, len(dataloader_states))): + sharded_state_dict['dataloader'] = dataloader_states[0] + else: + sharded_state_dict.pop('dataloader') + + # make it sharded checkpoint + for module_path, m in trainer.model.named_modules(): + prefix = module_path + '.' if module_path else '' + if isinstance(m, nnscaler.ParallelModule): + m._add_extra_state(sharded_state_dict['model'], prefix) + return sharded_state_dict + + +def _distribute_checkpoint(train_args: TrainerArgs, from_: str, to_: str): + nnscaler.utils.set_default_logger_level(level=logging.INFO) + _patch_distributed() + resume_from = Path(from_) + save_to = Path(to_) + save_to.mkdir(parents=True, exist_ok=True) + checkpointer = train_args.create_checkpointer() + + if resume_from.is_file(): + state_dict = checkpointer.load(resume_from, device='cpu') + if convert_fn := train_args.checkpoint.resolved_convert_fn: + state_dict = convert_fn(state_dict) + else: + ckpt_files = checkpointer.list_checkpoints(resume_from) + rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} + if set(rank_ckpt_files.keys()) != set(range(len(rank_ckpt_files))): + raise ValueError(f"Checkpoint files in {resume_from} are not complete: {rank_ckpt_files.keys()}") + state_dict = Trainer._merge_checkpoint(list(rank_ckpt_files.values()), checkpointer=checkpointer) + + for i in range(train_args.compute_config.runtime_ngpus): + sharded_state_dict = _trim_merged_checkpoint(train_args, state_dict, i) + checkpointer.save_for_rank(sharded_state_dict, save_to, i) + + checkpointer.flush() + + +if __name__ == '__main__': + argv = sys.argv[1:] + if len(argv) == 0: + raise ValueError("No command specified. Expected `distribute -f `") + if argv[0] == 'distribute': + if len(argv) < 5: + raise ValueError("Not enough arguments. Expected at least `distribute -f `") + from_ = argv[1] + to_ = argv[2] + train_args = TrainerArgs.from_cli(argv[3:]) + # never broadcast generated files. + train_args.broadcast_strategy = 'none' + train_args.checkpoint.resume_from = None + _distribute_checkpoint(train_args, from_, to_) + else: + raise ValueError(f"Unknown command: {argv[0]}") +else: + # we have patched too many things. + # please run this script with `python -m nnscaler.cli.checkpoint` + raise ImportError("checkpoint.py should be run as a script.") diff --git a/nnscaler/cli/mixed_module.py b/nnscaler/cli/mixed_module.py index d7354617..ef4268ca 100644 --- a/nnscaler/cli/mixed_module.py +++ b/nnscaler/cli/mixed_module.py @@ -19,6 +19,7 @@ TrainerArgs, PrecisionMixin, PolicyMixin, ModuleParallelizeConfig, ComputeConfig, load_type ) +from .serialization import Checkpointer logger = logging.getLogger(__name__) @@ -41,7 +42,8 @@ class ModuleParallelizeConfigAdapter(PrecisionMixin, PolicyMixin): def __init__( self, trainer_args: TrainerArgs, parallel_module: Optional[ModuleParallelizeConfig] = None, - tracing_weights: Optional[dict[str, Any]] = None + tracing_weights: Optional[dict[str, Any]] = None, + checkpointer: Optional[Checkpointer] = None, ): """ Args: @@ -52,6 +54,7 @@ def __init__( self.trainer_args = trainer_args self.parallel_module = parallel_module self.tracing_weights = tracing_weights + self.checkpointer = checkpointer or Checkpointer() # we don't want to load the tracing weights every time # It should be loaded only once outside, and passed to the adapter @@ -132,10 +135,10 @@ def load_tracing_weights(self) -> Optional[dict[str, Any]]: # try to reuse the weights from the tracing weights tracing_weights = self.tracing_weights if self.tracing_from_weights and tracing_weights is None: - tracing_weights = torch.load(self.tracing_from_weights) + tracing_weights = self.checkpointer.load(self.tracing_from_weights) else: if self.tracing_from_weights: - tracing_weights = torch.load(self.tracing_from_weights) + tracing_weights = self.checkpointer.load(self.tracing_from_weights) elif self.parallel_module.tracing_from_weights_prefix: leading_key = self.parallel_module.tracing_from_weights_prefix + '.' tracing_weights = {} @@ -166,21 +169,29 @@ def create_model(self, module_args: Optional[tuple[tuple, dict]]=None) -> torch. def create_dummy_forward_args(self, dummy_input) -> dict[str, Any]: if self.parallel_module: - return self.fix_input( + forward_args = self.fix_input( self.parallel_module.create_dummy_forward_args(self.trainer_args) ) - - # forward args of whole model - arg_names = list( - inspect.signature( - inspect.unwrap(getattr(self.model_type, 'forward')) - ).parameters.keys() - ) - return {arg_names[1]: self.fix_input(dummy_input)} # arg_names[0] is self + if self.parallel_module.forward_args_post_process_fn: + forward_args = self.parallel_module.forward_args_post_process_fn(self.trainer_args, forward_args) + return forward_args + else: + # forward args of whole model + arg_names = list( + inspect.signature( + inspect.unwrap(getattr(self.model_type, 'forward')) + ).parameters.keys() + ) + # dummy input is already fixed and post processed by trainer + forward_args = {arg_names[1]: dummy_input} # arg_names[0] is self + return forward_args def resolve_compute_config(self): compute_config = copy.deepcopy(self.compute_config) - compute_config.pas_config['__pas_name'] = self.pas_policy + compute_config.pas_config['__pas_name'] = \ + self.pas_policy \ + if not callable(self.pas_policy) \ + else f'{self.pas_policy.__module__}.{self.pas_policy.__qualname__}' # autodist configs compute_config.pas_config['update_freq'] = self.trainer_args.update_freq compute_config.pas_config['use_bf16'] = self.param_dtype == torch.bfloat16 @@ -197,6 +208,7 @@ def resolve_compute_config(self): def parallelize(self, dummy_input: Optional[dict[str, Any]] = None, *, load_module: bool = True, + build_buckets: bool = True, module_args: Optional[tuple[tuple, dict]] = None ): pmodel_class = nnscaler.parallelize( @@ -212,7 +224,7 @@ def parallelize(self, load_module=load_module, ) if load_module: - return pmodel_class() + return pmodel_class(build_buckets=build_buckets) return pmodel_class @@ -279,24 +291,32 @@ def parameters_for_calc_gnorm(self): return model -def parallelize_model(trainer_args: TrainerArgs, dummy_input: dict[str, Any], load_module: bool): +def parallelize_model( + trainer_args: TrainerArgs, + dummy_input: dict[str, Any], + load_module: bool, + build_buckets: bool, + checkpointer: Checkpointer +): tracing_weights = None + checkpointer = checkpointer or Checkpointer() if trainer_args.tracing_from_weights: - tracing_weights = torch.load(trainer_args.tracing_from_weights) + tracing_weights = checkpointer.load(trainer_args.tracing_from_weights) def _new_adapter(parallel_module=None): return ModuleParallelizeConfigAdapter( trainer_args, parallel_module, - tracing_weights=tracing_weights + tracing_weights=tracing_weights, + checkpointer=checkpointer, ) if not trainer_args.model.parallel_modules: # parallelize the whole model - return _new_adapter().parallelize(dummy_input, load_module=load_module) + return _new_adapter().parallelize(dummy_input, load_module=load_module, build_buckets=build_buckets) if not load_module and all(pm.args is not None for pm in trainer_args.model.parallel_modules): for m in trainer_args.model.parallel_modules: - _new_adapter(m).parallelize(dummy_input, load_module=False) + _new_adapter(m).parallelize(dummy_input, load_module=False, build_buckets=build_buckets) return parallel_sub_modules = { @@ -346,7 +366,7 @@ def __parallel__new__(cls, *args, **kwargs): # This is a trade-off to make sure the parallelized module is consistent. # Maybe we can use torch.distributed.broadcast to sync the random state in all devices. with fork_rng(): - return adapter.parallelize(dummy_input, load_module=load_module, module_args=(args, kwargs)) + return adapter.parallelize(dummy_input, load_module=load_module, build_buckets=build_buckets, module_args=(args, kwargs)) finally: _patch_new() diff --git a/nnscaler/cli/serialization.py b/nnscaler/cli/serialization.py new file mode 100644 index 00000000..3fb281f9 --- /dev/null +++ b/nnscaler/cli/serialization.py @@ -0,0 +1,483 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any, Callable, Protocol, Type +from pathlib import Path +import shutil +import time +import logging + +import torch + +from nnscaler.runtime.serialization import load, save + + +logger = logging.getLogger(__name__) + + +class _LoadProc(Protocol): + def __call__(self, f: str | Path, *, device='cpu') -> Any: ... + + +class _SaveProc(Protocol): + def __call__(self, obj: Any, f: str | Path) -> None: ... + + +class CheckpointFormat(Protocol): + """ + A placeholder class for new serialization formats. + """ + name: str + suffix: str + + @classmethod + def load(cls, f: str | Path, *, device='cpu') -> Any: + ... + + @classmethod + def save(cls, obj: Any, f: str | Path) -> None: + ... + + +class SerializationRunner(Protocol): + name: str + + def run_load( + self, + load_func: _LoadProc, + f: str | Path, + *, + device='cpu' + ) -> Any: + ... + + def run_save( + self, + save_func: _SaveProc, + obj: Any, + f: str | Path + ) -> None: + ... + + def flush(self) -> None: + """ + Flushes any pending operations for saving. + Loading operations are assumed to be synchronous. + """ + ... + + +class _DefaultSerializationRunner: + name: str = '' + + def run_load( + self, + load_func: _LoadProc, + f: str | Path, + *, + device='cpu' + ) -> Any: + return load_func(f, device=device) + + def run_save( + self, + save_func: _SaveProc, + obj: Any, + f: str | Path + ) -> None: + save_func(obj, f) + + def flush(self) -> None: + pass + + +def make_hybrid_serialization_runner( + load_serializer: Type[SerializationRunner], + save_serializer: Type[SerializationRunner] +) -> Type[SerializationRunner]: + """ + Creates a hybrid serialization runner that uses different runners for loading and saving. + """ + class HybridSerializationRunner(SerializationRunner): + name = f"{load_serializer.name}:{save_serializer.name}" + + def __init__(self, load_args=None, save_args=None): + self._load_runner = load_serializer(**(load_args or {})) + self._save_runner = save_serializer(**(save_args or {})) + + def run_load( + self, + load_func: _LoadProc, + f: str | Path, + *, + device='cpu' + ) -> Any: + return self._load_runner.run_load(load_func, f, device=device) + + def run_save( + self, + save_func: _SaveProc, + obj: Any, + f: str | Path + ) -> None: + self._save_runner.run_save(save_func, obj, f) + + def flush(self) -> None: + self._save_runner.flush() + + return HybridSerializationRunner + + +def _torch_load(f: str | Path, *, device='cpu') -> Any: + return torch.load(f, map_location=device, weights_only=False) + + +def _torch_save(obj: Any, f: str | Path) -> None: + torch.save(obj, f) + + +class Checkpointer: + # the format of the checkpoint file + # keys: epoch, step, rank + # currently it is not configurable + # TODO: make it configurable + CHECKPOINT_FILE_NAME_FORMAT: str = '{rank}{suffix}' + CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/' + CHECKPOINT_FILE_NAME_FORMAT + CHECKPOINT_LAST_DIR_NAME: str = 'last' + CHECKPOINT_BEST_DIR_NAME: str = 'best' + CHECKPOINT_MERGED_FILE_NAME: str = 'merged{suffix}' + CHECKPOINT_LAST_FILE_FORMAT: str = 'last/{rank}{suffix}' + CHECKPOINT_BEST_FILE_FORMAT: str = 'best/{rank}{suffix}' + NAME_MAP: dict[str, str] = { + 'pt': '.ckpt', + 'safetensors': '.safetensors' + } + SUFFIX_MAP: dict[str, str] = {v: k for k, v in NAME_MAP.items()} + # will use torch.load and torch.save for other suffixes + SUFFIX_HANDLERS: dict[str, tuple[_LoadProc, _SaveProc]] = { + '.safetensors': (load, save), + } + REGISTERED_RUNNERS: dict[str, Type[SerializationRunner]] = { + '': _DefaultSerializationRunner, + } + + def __init__(self, format: str = 'pt', serializer: str = None, serializer_args: dict[str, Any] = None): + """ + Args: + format (`str`, *optional*, defaults to `"pt"`): + The checkpoint format to use. Builtin formats are: + - `"pt"`: PyTorch checkpoint format. + - `"safetensors"`: Safetensors format. + serializer (`str`, *optional*): + The serialization runner to use. Builtin runners are: + - `""` (empty string): Default runner that directly uses the load and save functions. + You can also specify a hybrid runner by using the format `load_serializer:save_serializer`, + e.g., `"split:async"`. + serializer_args (`dict`, *optional*): + args for the serialization runner. + """ + if format not in self.NAME_MAP: + raise ValueError(f"Unsupported checkpoint format: {format}") + self.format = format + self.suffix = self.NAME_MAP[format] + + self.runner: SerializationRunner + serializer = serializer or '' + + if ':' in serializer: + parts = serializer.split(':') + if len(parts) != 2: + raise ValueError(f"Invalid hybrid serialization runner: {serializer}") + load_serializer_name = parts[0] + save_serializer_name = parts[1] + if load_serializer_name not in self.REGISTERED_RUNNERS: + raise ValueError(f"Unsupported serialization runner: {load_serializer_name}") + if save_serializer_name not in self.REGISTERED_RUNNERS: + raise ValueError(f"Unsupported serialization runner: {save_serializer_name}") + load_serializer_type = self.REGISTERED_RUNNERS[load_serializer_name] + save_serializer_type = self.REGISTERED_RUNNERS[save_serializer_name] + runner_cls = make_hybrid_serialization_runner( + load_serializer_type, + save_serializer_type + ) + else: + if serializer not in self.REGISTERED_RUNNERS: + raise ValueError(f"Unsupported serialization runner: {serializer}") + runner_cls = self.REGISTERED_RUNNERS[serializer] + + self.runner = runner_cls(**(serializer_args or {})) + + def get_checkpoint_file_path(self, epoch: int, step: int, rank: int) -> str: + return self.CHECKPOINT_FILE_FORMAT.format(epoch=epoch, step=step, rank=rank, suffix=self.suffix) + + def get_last_checkpoint_file_path(self, rank: int) -> str: + return self.CHECKPOINT_LAST_FILE_FORMAT.format(rank=rank, suffix=self.suffix) + + def get_best_checkpoint_file_path(self, rank: int) -> str: + return self.CHECKPOINT_BEST_FILE_FORMAT.format(rank=rank, suffix=self.suffix) + + def get_merged_checkpoint_file_name(self) -> str: + return self.CHECKPOINT_MERGED_FILE_NAME.format(suffix=self.suffix) + + def get_last_dir_name(self) -> str: + return self.CHECKPOINT_LAST_DIR_NAME + + def get_best_dir_name(self) -> str: + return self.CHECKPOINT_BEST_DIR_NAME + + def load(self, f: str | Path, *, device='cpu') -> Any: + """ + Loads a checkpoint file + + Args: + f: filename of the checkpoint file. + if the suffix is .safetensors, it will be loaded as safetensors file. + otherwise, it will be loaded as a PyTorch checkpoint file. + device (`str`, *optional*, defaults to `"cpu"`): + The device on which you want the tensors. + """ + suffix = Path(f).suffix + if suffix in self.SUFFIX_HANDLERS: + load_func, _ = self.SUFFIX_HANDLERS[suffix] + else: + load_func = _torch_load + + return self.runner.run_load(load_func, f, device=device) + + def save(self, obj: Any, f: str | Path) -> None: + """ + Saves a checkpoint file + + Args: + obj (`Any`): + The object to save. + f: filename of the checkpoint file. + if the suffix is .safetensors, it will be saved as safetensors file. + otherwise, it will be saved as a PyTorch checkpoint file. + """ + suffix = Path(f).suffix + if suffix in self.SUFFIX_HANDLERS: + _, save_func = self.SUFFIX_HANDLERS[suffix] + else: + save_func = _torch_save + + self.runner.run_save(save_func, obj, f) + + def load_for_rank(self, dir: str | Path, rank: int, device='cpu') -> Any: + """ + Loads a checkpoint file for a specific rank + + Args: + dir (`str`): + The directory where the checkpoint files are stored. + rank (`int`): + The rank of the checkpoint file to load. + device (`str`, `int`, *optional*): + The device on which you want the tensors. + """ + for suffix in self.NAME_MAP.values(): + f = Path(dir) / f"{rank}{suffix}" + if f.exists(): + return self.load(f, device=device) + raise FileNotFoundError(f"No checkpoint file found for rank {rank} in directory {dir}") + + def save_for_rank(self, obj: Any, dir: str | Path, rank: int) -> None: + """ + Saves a checkpoint file for a specific rank + + Args: + obj (`Any`): + The object to save. + dir (`str`): + The directory where the checkpoint files are stored. + rank (`int`): + The rank of the checkpoint file to save. + """ + f = Path(dir) / self.CHECKPOINT_FILE_NAME_FORMAT.format(rank=rank, suffix=self.suffix) + self.save(obj, f) + + def remove_for_rank(self, dir: str | Path, rank: int) -> None: + """ + Removes a checkpoint file for a specific rank. + Args: + dir (`str`): + The directory where the checkpoint files are stored. + rank (`int`): + The rank of the checkpoint file to remove. + """ + suffixes = set(list(self.NAME_MAP.values()) + [self.suffix]) + for suffix in suffixes: + f = Path(dir) / f"{rank}{suffix}" + f.unlink(missing_ok=True) + for extra_file in Path(dir).glob(f"{rank}{suffix}.*"): + extra_file.unlink(missing_ok=True) + + def copy_for_rank(self, src: str | Path, dst: str | Path, rank: int, symlink: bool = False) -> None: + """ + Copies a checkpoint file for a specific rank from one directory to another. + Args: + src (`str`): + The source directory where the checkpoint files are stored. + dst (`str`): + The destination directory where the checkpoint files will be copied. + rank (`int`): + The rank of the checkpoint file to copy. + symlink (`bool`, *optional*, defaults to `False`): + Whether to create a symbolic link instead of copying the file. + """ + self.remove_for_rank(dst, rank) + + src = Path(src).resolve() + dst = Path(dst).resolve() + dst.mkdir(parents=True, exist_ok=True) + + src_f = Path(src) / f"{rank}{self.suffix}" + dst_f = Path(dst) / f"{rank}{self.suffix}" + + if not src_f.exists(): + raise FileNotFoundError(f"No checkpoint file found for rank {rank} in directory {src}") + + if symlink: + # this restricts symlink creation within the same directory + # so we can create relative symlink safely + if src.parent != dst.parent: + raise ValueError("Cannot create symlink when source and destination are not in the same directory.") + + if symlink: + self._create_symlink_with_retry(Path('..') / src.name / src_f.name, dst_f) + for extra_file in src.glob(f"{rank}{self.suffix}.*"): + dst_extra_file = Path(dst) / extra_file.name + self._create_symlink_with_retry(Path('..') / src.name / extra_file.name, dst_extra_file) + else: + shutil.copy2(src_f, dst_f) + for extra_file in src.glob(f"{rank}{self.suffix}.*"): + dst_extra_file = Path(dst) / extra_file.name + shutil.copy2(extra_file, dst_extra_file) + + @classmethod + def _create_symlink_with_retry(cls, src: str | Path, dst: str | Path) -> None: + dst = Path(dst) + dst.unlink(missing_ok=True) + + # deletion in blobfuse is not immediate sometimes + # so we retry until success + while True: + try: + dst.symlink_to(Path(src)) + break + except FileExistsError: + logger.warning(f"Creating symlink {dst} failed. Retrying...") + dst.unlink(missing_ok=True) + time.sleep(0.1) + + logger.info(f"Symlink {dst} created.") + + def list_checkpoints(self, dir: str | Path) -> list[Path]: + """ + List the main checkpoint files in a directory + Args: + dir (`str`): + The directory where the checkpoint files are stored. + Returns: + (`list[Path]`): + The list of checkpoint files in the directory. + """ + p = Path(dir) + files = [] + for suffix in self.NAME_MAP.values(): + fs = list(p.glob(f"*{suffix}")) + if fs: + if files: + raise ValueError(f"Mixed checkpoint file formats in directory {dir}") + else: + files.extend(fs) + return files + + def flush(self) -> None: + """ + Flushes any pending operations. + """ + self.runner.flush() + + @classmethod + def get_format(cls, suffix: str) -> str: + """ + Gets the format name from the suffix. + """ + suffix = '.' + suffix.lstrip('.') + if suffix not in Checkpointer.SUFFIX_MAP: + raise ValueError(f"Unsupported checkpoint suffix: {suffix}") + return Checkpointer.SUFFIX_MAP[suffix] + + +def register_format(format: Type[CheckpointFormat]) -> None: + """ + Registers a new serialization format. + """ + suffix = '.' + format.suffix.lstrip('.') + Checkpointer.NAME_MAP[format.name] = suffix + Checkpointer.SUFFIX_MAP[suffix] = format.name + Checkpointer.SUFFIX_HANDLERS[suffix] = (format.load, format.save) + + +def register_serialization_runner(runner: Type[SerializationRunner]) -> None: + """ + Register a new serialization runner, which can intercept the load and save process. + For example, file redirection, chunking, asynchronous IO or other logic. + + Please note if you create extra files during saving, + you must make sure + 1. the suffix of the main checkpoint file must match registered formats. + 2. the name of extra files should start with the main checkpoint file name + '.', + but the suffix should not conflict with registered formats. + + For example, if the input checkpoint file is `model.ckpt`, + you must create a file called 'model.ckpt', + and you can use extra file names like 'model.ckpt.1', 'model.ckpt.meta', 'model.ckpt.opt' etc. + """ + if ':' in runner.name: + raise ValueError("Serialization runner name cannot contain ':'") + Checkpointer.REGISTERED_RUNNERS[runner.name] = runner + + +def convert_format( + src: str | Path, + dst: str | Path, + *, + src_serializer: str = None, + src_serializer_args: dict = None, + dst_serializer: str = None, + dst_serializer_args: dict = None, + device: str = 'cpu' +) -> None: + """ + Converts a checkpoint file from one format to another. + + Args: + src (`str` or `Path`): + The input checkpoint file. + dst (`str` or `Path`): + The output checkpoint file. + src_serializer (`str`, *optional*): + The serialization runner of the input checkpoint file. + src_serializer_args (`dict`, *optional*): + The arguments for the serialization runner of the input checkpoint file. + dst_serializer (`str`, *optional*): + The serialization runner of the output checkpoint file. + dst_serializer_args (`dict`, *optional*): + The arguments for the serialization runner of the output checkpoint file. + device (`str`, *optional*, defaults to `"cpu"`): + The device on which you want the tensors. + """ + src_format = Checkpointer.get_format(Path(src).suffix) + dst_format = Checkpointer.get_format(Path(dst).suffix) + + if src_format == dst_format and src_serializer == dst_serializer: + raise ValueError("Input and output formats and serializers are the same, no conversion needed.") + + src_checkpointer = Checkpointer(format=src_format, serializer=src_serializer, serializer_args=src_serializer_args) + dst_checkpointer = Checkpointer(format=dst_format, serializer=dst_serializer, serializer_args=dst_serializer_args) + + obj = src_checkpointer.load(src, device=device) + dst_checkpointer.save(obj, dst) + dst_checkpointer.flush() diff --git a/nnscaler/cli/train_hook.py b/nnscaler/cli/train_hook.py index 76abeb8b..afccf2da 100644 --- a/nnscaler/cli/train_hook.py +++ b/nnscaler/cli/train_hook.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. from typing import Any, Dict, List, TYPE_CHECKING, Literal, TypedDict, Optional +from pathlib import Path import torch @@ -213,6 +214,18 @@ def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> checkpoint: the checkpoint to be saved """ + def on_expire_checkpoint(self, trainer: 'Trainer', step: int, checkpoint_dir: Path) -> None: + """ + Called before expiring (deleting) checkpoint. + If you want to do something before a checkpoint is deleted, you can do it here. + + Note: only local-rank 0 will call this hook. + + Args: + step: the overall training step of the checkpoint to be expired + checkpoint_dir: the directory that holds the checkpoint to be expired + """ + class AggregatedTrainHook(TrainHook): def __init__(self, hooks: List[TrainHook]): @@ -333,3 +346,35 @@ def after_load_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) def on_save_checkpoint(self, trainer: 'Trainer', checkpoint: Dict[str, Any]) -> None: for hook in self.hooks: hook.on_save_checkpoint(trainer, checkpoint) + + def on_expire_checkpoint(self, trainer: 'Trainer', step: int, checkpoint_dir: Path) -> None: + for hook in self.hooks: + hook.on_expire_checkpoint(trainer, step, checkpoint_dir) + + +class TrainHookHost: + def _get_hook_objects(self) -> List[Any]: + """ + Return a list of objects that can be hooks (but not necessarily hooks) + """ + ... + + def get_hooks(self) -> List[TrainHook]: + """ + Return a list of TrainHook objects + """ + hooks = {} + visited = set() + def _get_hooks(obj): + if id(obj) in visited: + return + visited.add(id(obj)) + + if isinstance(obj, TrainHook): + hooks[id(obj)] = obj + if isinstance(obj, TrainHookHost): + for o in obj._get_hook_objects(): + _get_hooks(o) + + _get_hooks(self) + return list(hooks.values()) diff --git a/nnscaler/cli/trainer.py b/nnscaler/cli/trainer.py index d167b95f..828498d7 100644 --- a/nnscaler/cli/trainer.py +++ b/nnscaler/cli/trainer.py @@ -22,28 +22,18 @@ from tqdm import tqdm import nnscaler -from nnscaler.utils import enforce_zero_num_worker, is_running_distributed +from nnscaler.runtime.device import DeviceGroup +from nnscaler.utils import broadcast_mixed_data, is_running_distributed from .trainer_args import AggregatedOutputs, TrainerArgs, fix_input -from .train_hook import AggregatedTrainHook, TrainHook +from .train_hook import AggregatedTrainHook, TrainHook, TrainHookHost from .mixed_module import parallelize_model, mixin_module +from .serialization import Checkpointer logger = logging.getLogger(__name__) -# the format of the checkpoint file -# keys: epoch, step, rank -# currently it is not configurable -# TODO: make it configurable -CHECKPOINT_FILE_FORMAT: str = '{epoch:04d}-{step:04d}/{rank}.ckpt' -CHECKPOINT_LAST_DIR_NAME: str = 'last' -CHECKPOINT_BEST_DIR_NAME: str = 'best' -CHECKPOINT_MERGED_FILE_NAME: str = 'merged.ckpt' -CHECKPOINT_LAST_FILE_FORMAT: str = 'last/{rank}.ckpt' -CHECKPOINT_BEST_FILE_FORMAT: str = 'best/{rank}.ckpt' - - @dataclass class TrainStatus: best_loss = float('inf') @@ -102,6 +92,7 @@ def __init__(self, self.max_train_steps = None self.loggers = [] self.hook = None + self.checkpointer = None # RNG states pending resume; reset to None after resuming self.rng_states_from_resume: dict[str, torch.Tensor] | None = None @@ -111,6 +102,8 @@ def run(self): if not self.train_args.compile_mode: self._train() finally: + if self.checkpointer: + self.checkpointer.flush() for stage in ['train', 'val', 'test']: if self.dataloader[stage] is not None and (close_fn := getattr(self.dataloader[stage], 'close', None)): close_fn() @@ -126,28 +119,16 @@ def run(self): def _fix_input(self, input): return fix_input(input, self.train_args.input_dtype) - def _load_dummy_input(self): - if dummy_sample_gen_fn := self.train_args.resolved_dummy_sample_gen_fn: - return dummy_sample_gen_fn(self.train_args) - - with enforce_zero_num_worker(DataLoader): - dataset = self.train_args.create_dataset('train') - dataloader = self.train_args.create_dataloader('train', dataset) - assert dataloader.num_workers == 0, "The dataloader must have `num_workers=0`." - value = next(iter(dataloader)) - if close_fn := getattr(dataloader, 'close', None): - close_fn() - return value - def _setup(self): if is_running_distributed(): nnscaler.init() - if torch.distributed.get_rank() == 0: + if DeviceGroup().local_rank == 0: logging.getLogger().setLevel(logging.INFO) else: logging.getLogger().setLevel(logging.WARNING) self.train_args.init_env(self) + self.checkpointer = self.train_args.create_checkpointer() # make sure all ranks are synchronized after init_env if is_running_distributed(): @@ -156,10 +137,14 @@ def _setup(self): compile_only = self.train_args.compile_mode # load a dummy input from training dataset - self.dummy_input = self._load_dummy_input() - self.dummy_input = self._fix_input(self.dummy_input) + self.dummy_input = self.train_args.dummy_input - pmodel = parallelize_model(self.train_args, self.dummy_input, load_module=not compile_only) + pmodel = parallelize_model( + self.train_args, self.dummy_input, + load_module=not compile_only, + build_buckets=not self.train_args.should_delay_bucket_building(), + checkpointer=self.checkpointer, + ) if compile_only: return @@ -186,6 +171,17 @@ def _setup(self): self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE')) self.local_rank = int(os.environ.get('LOCAL_RANK')) self.node_rank = int(os.environ.get('GROUP_RANK')) + assert self.rank // self.local_world_size == self.node_rank + self.local_ranks = list( + range( + self.node_rank * self.local_world_size, + (self.node_rank + 1) * self.local_world_size + ) + ) + self.local_rank0 = self.local_ranks[0] + # create local process groups + for local_rank0 in range(0, self.world_size, self.local_world_size): + DeviceGroup().get_group(list(range(local_rank0, local_rank0 + self.local_world_size))) self.total_train_steps_per_epoch = len(self.dataloader['train']) // self.train_args.update_freq if len(self.dataloader['train']) % self.train_args.update_freq != 0: @@ -214,8 +210,9 @@ def _setup(self): # (see `train_args.optimizer.grad_reduction`` handling in `train_epoch`). # This is useful to avoid overflow when the gradients are large. def reducer_pre_hook(reducer, grad): - grad.div_(self.train_args.scaling_factor) + grad.div_(self.train_args.optimizer.grad_reduce_divisor or self.train_args.scaling_factor) self.optimizer.register_reducer_pre_hook(reducer_pre_hook) + # Currently we never pass `last_epoch` to its constructor self.lr_scheduler = self.train_args.create_lr_scheduler(self.optimizer) self.loggers = self.train_args.create_loggers() @@ -223,9 +220,20 @@ def reducer_pre_hook(reducer, grad): self.model, self.optimizer, self.lr_scheduler, + self.checkpointer, ] + component_hooks = [] + for component in supported_hook_components: + if isinstance(component, TrainHook): + component_hooks.append(component) + if isinstance(component, TrainHookHost): + component_hooks.extend(component.get_hooks()) + + # dedup hooks + component_hooks = list({id(hook): hook for hook in component_hooks}.values()) + self.hook = AggregatedTrainHook( - [x for x in supported_hook_components if isinstance(x, TrainHook)] + component_hooks + [self.train_args.create_hook()] ) @@ -235,8 +243,19 @@ def reducer_pre_hook(reducer, grad): self.hook.after_setup(self) @classmethod - def _merge_checkpoint(cls, checkpoint_files: List[str]): - state_dicts = [torch.load(f, map_location='cpu', weights_only=False) for f in checkpoint_files] + def _merge_checkpoint(cls, checkpoint_files: List[str], + *, + model_only: bool = False, + checkpointer: Optional[Checkpointer] = None, + ): + checkpointer = checkpointer or Checkpointer() + state_dicts = [] + for f in checkpoint_files: + state_dict = checkpointer.load(f) + if model_only: + # we pop optimizer state to save cpu memory + state_dict.pop('optimizer', None) + state_dicts.append(state_dict) for i in range(1, len(state_dicts)): if state_dicts[i]['train_args'] != state_dicts[0]['train_args']: raise ValueError(f"train_args in {checkpoint_files[i]} is different from {checkpoint_files[0]}") @@ -245,14 +264,16 @@ def _merge_checkpoint(cls, checkpoint_files: List[str]): module_state_dict, opt_state_dict = nnscaler.merge_state_dicts( [s['model'] for s in state_dicts], - [s['optimizer'] for s in state_dicts] + [s['optimizer'] for s in state_dicts] if not model_only else None, ) + if model_only: + return {'model': module_state_dict} train_args = copy.deepcopy(state_dicts[0]['train_args']) train_args['checkpoint']['save_type'] = 'merged' global_keys = { 'model', 'optimizer', 'train_args', - 'train_status', 'lr_scheduler', 'rank' + 'train_status', 'lr_scheduler', 'rank', 'nnscaler' } # for extra keys (including `dataloader` and `rng_states`), we will not merge them. # Intead we will collect them from all state_dicts @@ -271,77 +292,55 @@ def _merge_checkpoint(cls, checkpoint_files: List[str]): 'lr_scheduler': state_dicts[0].get('lr_scheduler', None), 'train_status': state_dicts[0]['train_status'], 'train_args': train_args, + 'nnscaler': state_dicts[0]['nnscaler'], **extra_keys, } return merged_state_dict - def _broadcast_merged_state_dict(self, state_dict: Dict[str, Any]): + def _broadcast_merged_state_dict( + self, + state_dict: Dict[str, Any], + src_rank: int = 0, + dst_ranks: Optional[list[int]] = None, + ): """ Broadcast the merged state dict to all ranks. - We can't broadcast the whole state_dict at once, because it may be too large, and leads to OOM. - Here we will break the model and optimizer state_dict into smaller pieces and broadcast them one by one. - Please note we use `torch.distributed.broadcast_object_list` to broadcast the state_dict (including tensors inside). """ + dst_ranks = dst_ranks or list(range(torch.distributed.get_world_size())) + if src_rank not in dst_ranks or self.rank not in dst_ranks: + raise ValueError(f"src_rank and current rank must be in dst_ranks: {dst_ranks}") + pg = DeviceGroup().get_group(dst_ranks) - def _broadcast_keys(sdict: Dict[str, Any], set_keys=True): - if self.rank == 0: - state_keys = list(sdict.keys()) - else: - state_keys = None - state_key_list = [state_keys] - torch.distributed.broadcast_object_list(state_key_list, src=0) - state_keys = state_key_list[0] - if set_keys and self.rank != 0: - for key in state_keys: - sdict[key] = {} # assume the values are empty dicts - return state_keys - - def _broadcast_value(sdict, key): - if self.rank == 0: - value_list = [sdict[key]] - else: - value_list = [None] - torch.distributed.broadcast_object_list(value_list, src=0) - if self.rank != 0: - sdict[key] = value_list[0] - - def _broadcast_values(sdict, keys): - for key in keys: - _broadcast_value(sdict, key) - - if self.rank == 0: + if self.rank == src_rank: if state_dict is None: raise ValueError("state_dict should not be None in rank 0 when broadcasting") else: if state_dict is not None: raise ValueError("state_dict should be None in other ranks when broadcasting") - state_dict = {} - - state_keys = _broadcast_keys(state_dict) - - for skey in state_keys: - logger.info(f"Broadcasting {skey}.") - if skey == 'optimizer': - opt_keys = _broadcast_keys(state_dict['optimizer']) - opt_keys_without_state = [ - k for k in opt_keys if k != 'state' - ] - _broadcast_values(state_dict['optimizer'], opt_keys_without_state) - idxs = _broadcast_keys(state_dict['optimizer']['state']) - for idx in idxs: - idx_keys = _broadcast_keys(state_dict['optimizer']['state'][idx]) - _broadcast_values(state_dict['optimizer']['state'][idx], idx_keys) - elif skey == 'model': - model_keys = _broadcast_keys(state_dict['model']) - _broadcast_values(state_dict['model'], model_keys) - else: - _broadcast_value(state_dict, skey) - return state_dict + + return broadcast_mixed_data(state_dict, src_rank=src_rank, group=pg, device='cpu') @classmethod - def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str): - merged_state_dict = cls._merge_checkpoint(checkpoint_files) - torch.save(merged_state_dict, output_file) + def merge_checkpoint(cls, checkpoint_files: List[str], output_file: str, + *, + model_only: bool = False, + checkpointer: Optional[Checkpointer] = None, + serializer: Optional[str] = None, + serializer_args: Optional[dict[str, Any]] = None, + ): + if checkpointer is not None: + if serializer is not None or serializer_args is not None: + raise ValueError("serializer and serializer_args should not be specified when checkpointer is given") + else: + checkpointer = Checkpointer(serializer=serializer, serializer_args=serializer_args) + + merged_state_dict = cls._merge_checkpoint( + checkpoint_files, + model_only=model_only, + checkpointer=checkpointer, + ) + checkpointer.save(merged_state_dict, output_file) + checkpointer.flush() def _log_finalize(self): for logger in self.loggers: @@ -361,13 +360,21 @@ def _load_checkpoint(self): if not resume_from: return logger.info(f"Resuming from {resume_from}") + trimmed_broadcast_required = False + load_from_merged = False + if resume_from.is_file(): - resume_from = resume_from # when we load from merged checkpoint - state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) - if convert_fn := self.train_args.checkpoint.resolved_convert_fn: - state_dict = convert_fn(state_dict) + # when we load from merged checkpoint + load_from_merged = True + trimmed_broadcast_required = self.train_args.checkpoint.resume_from.save_memory + if not self.train_args.checkpoint.resume_from.save_memory or self.local_rank == 0: + state_dict = self.checkpointer.load(resume_from) + if convert_fn := self.train_args.checkpoint.resolved_convert_fn: + state_dict = convert_fn(state_dict) + else: + state_dict = None else: - ckpt_files = list(resume_from.glob('*.ckpt')) + ckpt_files = self.checkpointer.list_checkpoints(resume_from) rank_ckpt_files = {int(f.stem): f for f in ckpt_files if f.stem.isdigit()} if set(rank_ckpt_files.keys()) != set(range(len(rank_ckpt_files))): raise ValueError(f"Checkpoint files in {resume_from} are not complete: {rank_ckpt_files.keys()}") @@ -378,24 +385,59 @@ def _load_checkpoint(self): if len(rank_ckpt_files) != self.world_size or self.train_args.checkpoint.resume_from.with_merged: # merge the checkpoint files from all ranks and broadcast to all ranks torch.distributed.barrier() - if self.rank == 0: + if self.local_rank == 0: logger.info(f"Merging checkpoint files from {resume_from}") - state_dict = self._merge_checkpoint(list(rank_ckpt_files.values())) + state_dict = self._merge_checkpoint(list(rank_ckpt_files.values()), checkpointer=self.checkpointer) else: state_dict = None - logger.info(f"Broadcasting merged checkpoint to all ranks.") - state_dict = self._broadcast_merged_state_dict(state_dict) - logger.info(f"Broadcasted merged checkpoint to all ranks.") + + load_from_merged = True + trimmed_broadcast_required = self.train_args.checkpoint.resume_from.save_memory + if not self.train_args.checkpoint.resume_from.save_memory: + logger.info(f"Broadcasting merged checkpoint to all ranks.") + state_dict = self._broadcast_merged_state_dict( + state_dict, src_rank=self.local_rank0, dst_ranks=self.local_ranks + ) + logger.info(f"Broadcasted merged checkpoint to all ranks.") else: - resume_from = resume_from / f'{self.rank}.ckpt' - state_dict = torch.load(resume_from, map_location='cpu', weights_only=False) + state_dict = self.checkpointer.load_for_rank(resume_from, self.rank) + if state_dict['train_args']['compute_config'] != asdict(self.train_args.compute_config): + logger.warning( + f"compute_config is changed, and loading checkpoint may fail. " + f"If it fails, please try with merged checkpoint." + ) + + if trimmed_broadcast_required: + logger.info("Broadcasting trimmed checkpoint to all ranks.") + state_dict = state_dict or {} + state_dict['model'], state_dict['optimizer'] = nnscaler.trimmed_broadcast_merged_state_dict( + self.model, + state_dict['model'] if self.local_rank == 0 else None, + self.optimizer, + state_dict['optimizer'] if self.local_rank == 0 else None, + src_rank=self.local_rank0, + dst_ranks=self.local_ranks, + ) + remaining_state_dict = self._broadcast_merged_state_dict( + {k: v for k, v in state_dict.items() if k not in ('model', 'optimizer')} + if self.local_rank == 0 else None, + src_rank=self.local_rank0, + dst_ranks=self.local_ranks, + ) + if self.local_rank != 0: + state_dict.update(remaining_state_dict) + logger.info("Broadcasted trimmed checkpoint to all ranks.") + + # trimmed checkpoint is sharded + ckpt_save_type = 'sharded' + else: + # if it is not a well-formed state_dict (from third party) + # we will treat it as a merged state_dict + ckpt_save_type = state_dict.get('train_args', {}) \ + .get('checkpoint', {}) \ + .get('save_type', 'merged') self.hook.on_load_checkpoint(self, state_dict) - # if it is not a well-formed state_dict (from third party) - # we will treat it as a merged state_dict - ckpt_save_type = state_dict.get('train_args', {}) \ - .get('checkpoint', {}) \ - .get('save_type', 'merged') if ckpt_save_type == 'merged': # it is a merged state dict nnscaler.load_merged_state_dict( @@ -420,10 +462,11 @@ def _load_checkpoint(self): raise ValueError("lr_scheduler is not set in the current trainer") if self.lr_scheduler: self.lr_scheduler.load_state_dict(state_dict['lr_scheduler']) + if 'dataloader' in state_dict and state_dict['dataloader'] is not None: if not self._is_resumable_dataloader(): raise ValueError("dataloader is not resumable, but checkpoint contains dataloader state") - if ckpt_save_type == 'merged': + if load_from_merged: dataloader_states = state_dict['dataloader'] # only load dataloader state when all ranks have the same state # TODO: is this reasonable? @@ -443,7 +486,7 @@ def _load_checkpoint(self): self.train_status = TrainStatus(**state_dict['train_status']) # we don't resume rng states when loading merged checkpoint, - if ckpt_save_type != 'merged': + if not load_from_merged: self.rng_states_from_resume = state_dict.get('rng_states') # resumed in _global_batch_iterator() else: logger.warning("RNG states are not resumed when loading merged checkpoint.") @@ -514,7 +557,7 @@ def _save_checkpoint(self, loss): current_epoch -= 1 if checkpoint_config.save_type == 'sharded': - model_state_dict= self.model.state_dict() + model_state_dict = self.model.state_dict() optimizer_state_dict = self.optimizer.state_dict() elif checkpoint_config.save_type == 'deduped': model_state_dict, optimizer_state_dict = nnscaler.deduped_state_dict( @@ -541,50 +584,37 @@ def _save_checkpoint(self, loss): self.hook.on_save_checkpoint(self, state_dict) - ckpt_file = save_dir / CHECKPOINT_FILE_FORMAT.format( + ckpt_file = save_dir / self.checkpointer.get_checkpoint_file_path( epoch=current_epoch, step=self.train_status.finished_train_steps, rank=self.rank, ) logger.info(f"Saving checkpoint to {str(ckpt_file.parent)}") ckpt_file.parent.mkdir(parents=True, exist_ok=True) - torch.save(state_dict, ckpt_file) + self.checkpointer.save(state_dict, ckpt_file) # save last if checkpoint_config.save_last: logger.info(f"Saving checkpoint as the last checkpoint.") - last_file = save_dir / CHECKPOINT_LAST_FILE_FORMAT.format( - rank=self.rank + + self.checkpointer.copy_for_rank( + ckpt_file.parent, + save_dir / self.checkpointer.get_last_dir_name(), + self.rank, + checkpoint_config.symlink_best_and_last ) - last_file.parent.mkdir(parents=True, exist_ok=True) - if checkpoint_config.symlink_best_and_last: - # remove the old symlink or file - if last_file.is_symlink() or last_file.exists(): - last_file.unlink() - # symblink as relative path - last_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) - # last_file.symlink_to(ckpt_file) - else: - shutil.copy(ckpt_file, last_file) # save best if checkpoint_config.save_best and loss <= self.train_status.best_loss: logger.info(f"Best loss updated: {self.train_status.best_loss:.3f} -> {loss:.3f}") logger.info(f"Saving checkpoint as the best checkpoint.") - best_file = save_dir / CHECKPOINT_BEST_FILE_FORMAT.format( - epoch=current_epoch, - step=self.train_status.finished_train_steps, - rank=self.rank, + + self.checkpointer.copy_for_rank( + ckpt_file.parent, + save_dir / self.checkpointer.get_best_dir_name(), + self.rank, + checkpoint_config.symlink_best_and_last ) - best_file.parent.mkdir(parents=True, exist_ok=True) - if checkpoint_config.symlink_best_and_last: - # symblink as relative path - if best_file.is_symlink() or best_file.exists(): - best_file.unlink() - best_file.symlink_to(Path('..') / ckpt_file.parent.name / ckpt_file.name) - # best_file.symlink_to(ckpt_file) - else: - shutil.copy(ckpt_file, best_file) torch.distributed.barrier() # remove old checkpoints @@ -597,6 +627,14 @@ def _save_checkpoint(self, loss): torch.distributed.barrier() + @classmethod + def _get_dependent_dirs(cls, ckpt_dir): + target_dirs = set() + for p in Path(ckpt_dir).glob('*'): + if p.is_symlink(): + target_dirs.add(p.resolve().parent.name) + return target_dirs + def _expire_checkpoints(self): if not self.train_args.checkpoint.keep_last_n_checkpoints: # keep all return @@ -604,36 +642,46 @@ def _expire_checkpoints(self): save_dir = Path(self.train_args.checkpoint.save_dir) checkpoints = [ p.name for p in save_dir.glob('*') - if p.is_dir() and p.name not in [CHECKPOINT_BEST_DIR_NAME, CHECKPOINT_LAST_DIR_NAME] + if p.is_dir() and p.name not in [ + self.checkpointer.get_best_dir_name(), + self.checkpointer.get_last_dir_name() + ] ] if len(checkpoints) <= self.train_args.checkpoint.keep_last_n_checkpoints: return # (step, ckpt_name) pairs checkpoint_info = [(int(p.split('-')[1]), p) for p in checkpoints] + # map from ckpt_name to step + checkpoint_info_map = {p[1]: p[0] for p in checkpoint_info} checkpoint_info.sort() expire_list = [c[1] for c in checkpoint_info[:-self.train_args.checkpoint.keep_last_n_checkpoints]] - best_ckpt = save_dir / CHECKPOINT_BEST_DIR_NAME - last_ckpt = save_dir / CHECKPOINT_LAST_DIR_NAME + best_ckpt = save_dir / self.checkpointer.get_best_dir_name() + last_ckpt = save_dir / self.checkpointer.get_last_dir_name() for ckpt_dir in [best_ckpt, last_ckpt]: if not ckpt_dir.exists(): continue - for p in ckpt_dir.glob('*.ckpt'): - if p.is_symlink(): - ckpt_name = p.resolve().parent.name - if ckpt_name in expire_list: - expire_list.remove(ckpt_name) - logger.info('Keep old checkpoint `%s` because it is symbol linked in best or last.', ckpt_name) - break # just check the first file is enough + for ckpt_name in self._get_dependent_dirs(ckpt_dir): + if ckpt_name in expire_list: + expire_list.remove(ckpt_name) + logger.info('Keep old checkpoint `%s` because it is symbol linked in best or last.', ckpt_name) for ckpt_name in expire_list: logger.info('Removing old checkpoint: %s', ckpt_name) - shutil.rmtree(save_dir / ckpt_name) + self.hook.on_expire_checkpoint(self, checkpoint_info_map[ckpt_name], save_dir / ckpt_name) + try: + shutil.rmtree(save_dir / ckpt_name) + except FileNotFoundError: + # may have been removed by other processes (when the storage is shared) + pass + except Exception as e: + logger.warning('Error when expiring checkpoint `%s`: %s. Will try later.', ckpt_name, e) def _global_batch_iterator(self, num_skip_first=0, stage='train'): if stage == 'train': if self.dataloader_resumed or num_skip_first == 0: + logger.info(f'Trainer resumes dataloader directly.') # if the checkpoint stops at the end of an epoch, # the rng states must be resumed before creating iterator # because `DataLoader.__iter__()` uses the rng (dunno why), @@ -641,6 +689,7 @@ def _global_batch_iterator(self, num_skip_first=0, stage='train'): self._try_resume_rng_states() it = iter(self.dataloader[stage]) else: # dry run until reach the desired batch. + logger.info(f'Trainer try to resume dataloader for {stage} stage with {num_skip_first}.') it = iter(self.dataloader[stage]) for _ in range(num_skip_first * self.train_args.update_freq): _sample = next(it) @@ -773,7 +822,7 @@ def _train(self): torch.cuda.reset_peak_memory_stats() if self.train_status.finished_train_steps >= self.max_train_steps: - logger.info(f"Training is skipped: already done.") + logger.info(f"Training is skipped: already done, finished_train_steps={self.train_status.finished_train_steps} >= max_train_steps={self.max_train_steps}.") return start_epoch = self.train_status.finished_train_steps // self.total_train_steps_per_epoch @@ -948,7 +997,7 @@ def _train_epoch(self, epoch: int) -> None: self.hook.after_sync_grad(self) # scale gradients - multiplier = self.train_args.scaling_factor + multiplier = self.train_args.optimizer.grad_reduce_divisor or self.train_args.scaling_factor if self.train_args.optimizer.grad_reduction == 'sum': # do nothing. `multiplier` is already correct pass @@ -977,7 +1026,7 @@ def _train_epoch(self, epoch: int) -> None: step_stat.gnorm = step_stat.gnorm.item() # update parameters - step_stat.lr = self.optimizer.param_groups[0]['lr'] + step_stat.lr = self.optimizer.param_groups[0]['lr'] # only log the first group's lr self.hook.before_optimizer_step(self) self.optimizer.step() self.hook.after_optimizer_step(self) diff --git a/nnscaler/cli/trainer_args.py b/nnscaler/cli/trainer_args.py index 2fb99daa..bfa86cf7 100644 --- a/nnscaler/cli/trainer_args.py +++ b/nnscaler/cli/trainer_args.py @@ -3,7 +3,7 @@ from dataclasses import asdict, dataclass, field, replace import importlib -from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Union, TypeVar +from typing import Any, Callable, Dict, List, Literal, Optional, TYPE_CHECKING, Type, Union, TypeVar from typing_extensions import get_args from pathlib import Path import logging @@ -15,12 +15,12 @@ import torch import torch.utils import torch.utils.data -import torch.utils.data.dataloader +from torch.utils.data.dataloader import DataLoader import yaml import torch import nnscaler -from nnscaler.utils import fields, transform_recursively, load_type +from nnscaler.utils import enforce_zero_num_worker, fields, fn_field, transform_recursively, load_type, copy_dynamic from nnscaler.parallel import ComputeConfig, build_optimizer, ReuseType, BroadcastGenFilesStrategy, _PREDEFINED_POLICIES from nnscaler.runtime.module import ParallelModule @@ -32,6 +32,7 @@ ) from .loggers.logger_base import LoggerBase from .train_hook import TrainHook +from .serialization import Checkpointer if TYPE_CHECKING: from .trainer import Trainer @@ -123,9 +124,9 @@ def fix_input(input, input_dtype=None): return tuple(fix_input(v, input_dtype) for v in input) elif isinstance(input, torch.Tensor): if input.is_floating_point() and input_dtype is not None: - return input.to(input_dtype).cuda() + return copy_dynamic(input, input.to(input_dtype).cuda()) else: - return input.cuda() + return copy_dynamic(input, input.cuda()) return input @@ -238,9 +239,17 @@ class ModuleParallelizeConfig: # we can parallelize submodules instead of creating whole model. # This is useful sometimes. args: Optional[Dict[str, Any]] = None - # the full qualified name of the function to generate dummy forward args - # Its type should be `Callable[[TrainerArgs],Dict[str, Any]]` - forward_args_gen_fn: str = None + # the full qualified name of the function to generate dummy inputs for forward + # Its type should be `Callable[[TrainerArgs], dict[str, Any]]` + # where the output dict is the kwargs for forward function of the module + # The tensors in the sample will be moved to GPU and converted to input_dtype by trainer. + forward_args_gen_fn: Optional[Callable[['TrainerArgs'], dict[str, Any]]] = fn_field(default=None) + # the full qualified name of the function to post process the dummy inputs for forward + # Note the tensors in the inputs have been moved to GPU and converted to input_dtype + # But you can still further process the sample, + # for example, mark some dims of tensors as dynamic + # (you can do it in `forward_args_gen_fn` as well) + forward_args_post_process_fn: Optional[Callable[['TrainerArgs', dict[str, Any]], dict[str, Any]]] = fn_field(default=None) # the model state dict file for tracing. # It is only used in tracing to serve as the initial state dict of the model. tracing_from_weights: str = None @@ -289,8 +298,7 @@ def create_model(self, trainer_args: 'TrainerArgs', module_args: Optional[tuple[ return self.model_type(*args, **kwargs) def create_dummy_forward_args(self, trainer_args: 'TrainerArgs') -> dict[str, Any]: - forward_args_gen_fn = load_type(self.forward_args_gen_fn) - return forward_args_gen_fn(trainer_args) + return self.forward_args_gen_fn(trainer_args) @dataclass @@ -314,6 +322,7 @@ class OptimizerConfig: args: Dict[str, Any] = field(default_factory=dict) clip_gnorm: float = 0.0 + param_clss_fn: Optional[Callable[[str], Any]] = fn_field(default=None) # loss reduction method # mean: average the loss over all micro-batches # sum: sum the loss of all micro-batches @@ -328,6 +337,13 @@ class OptimizerConfig: # per-token-mean: average the gradients over all tokens # you must specify `aggregate_outputs_fn` and return the number of tokens grad_reduction: str = 'mean' + # the divisor applied to gradients before all-reduce. If not set, the default + # divisor is `runtime_ngpus / plan_ngpus`. We divide the gradients to avoid overflow. + # However, if the gradients are in high precision or the user has known the range of + # the gradients, he/she can set a smaller divisor to improve the accuracy. Note that + # the gradients will be recovered by multiplying the divisor after all-reduce and before + # optimizer step. + grad_reduce_divisor: Optional[float] = None # the function to aggregate the outputs from all micro-batches # inputs: (list of local outputs, torch group) # output: AggregateOutputs @@ -403,6 +419,28 @@ class ResumeOptions: # `None` means will load the sharded checkpoint files if the world size is not changed. # and will load merged checkpoint if the world size is changed. with_merged: Optional[bool] = None + # If the memory is limited, we can save memory by only loading merged state dict in GPU 0 of each node + # and broadcast trimmed state dict to other ranks in the same node + # although this will be slower + # Only used when resuming from a merged checkpoint. + save_memory: bool = True + + +@dataclass +class SerializerOptions: + # the serialization runner to be used + # It should be a name of registered SerializationRunners + name: str = '' + + # the full qualified name of the function to create the serialization runner + # Currently we do not support this way + # to make sure all serialization runners are registered and can be used in other places + # (like nnscaler.cli.Trainer.merge_checkpoint) + # type: str = None + + # arguments for the serialization runner + # Note You should be able to load for any arguments + args: Dict[str, Any] = field(default_factory=dict) @dataclass @@ -410,6 +448,19 @@ class CheckpointConfig: save_dir: str = './checkpoints' no_save: bool = False + # `"pt"`: PyTorch native format + # `"safetensors"`: Safetensors format + # You can also register new formats via `nnscaler.cli.serialization.register_format` + # or specify a custom format here by providing a CheckpointFormat subclass + format: str = 'pt' + + # the serialization runner to be used + # It should be a name of registered SerializationRunners + # If None, the default serializer will be used + serializer: Optional[SerializerOptions] = field(default=None, metadata={ + 'normalize': lambda x: {'name': x} if isinstance(x, str) else x + }) + # `"sharded"`: Each rank saves its shard of weights and optimizer states to a file. The checkpoint is # a folder with as many files as the world size. # `"deduped"`: Each rank saves its deduped shard of weights and optimizer states to a file. The checkpoint is @@ -452,6 +503,13 @@ def resolved_convert_fn(self) -> Optional[Callable[[Dict[str, Any]], Dict[str, A return load_type(self.resume_from.convert_fn) def __post_init__(self): + # backward compatibility + if isinstance(self.resume_from, str): + self.resume_from = ResumeOptions(checkpoint=self.resume_from) + + if isinstance(self.serializer, str): + self.serializer = SerializerOptions(name=self.serializer) + if self.resume_from and self.resume_from.checkpoint: if self.resume_from.checkpoint in ['last', 'best']: if not self.save_dir: @@ -468,6 +526,12 @@ def __post_init__(self): if not self.save_dir: raise ValueError("save_dir is required") + if self.format not in Checkpointer.NAME_MAP: + raise ValueError(f"Invalid format {self.format}") + + if self.serializer and self.serializer.name not in Checkpointer.REGISTERED_RUNNERS: + raise ValueError(f"Invalid Serialization runner {self.serializer.name}") + if self.every_n_epochs is not None and self.every_n_train_steps is not None: raise ValueError("Cannot specify both every_n_epochs and every_n_train_steps") if self.every_n_epochs is None and self.every_n_train_steps is None: @@ -551,6 +615,7 @@ class HookMapConfig: on_load_checkpoint: str = None after_load_checkpoint: str = None on_save_checkpoint: str = None + on_expire_checkpoint: str = None class ArgsTrainHook(TrainHook): @@ -595,9 +660,16 @@ class TrainerArgs(PrecisionMixin, PolicyMixin): # compile: compile the model but not training # run: compile and run the model run_mode: str = 'run' - # the full qualified name of the function to generate dummy sample for forward + # the full qualified name of the function to generate dummy sample # Its type should be `Callable[[TrainerArgs], Any]` - dummy_sample_gen_fn: str = None + # The tensors in the sample will be moved to GPU and converted to input_dtype by trainer. + dummy_sample_gen_fn: Optional[Callable[['TrainerArgs'], Any]] = fn_field(default=None) + # the full qualified name of the function to post process the dummy sample + # Note the tensors in the sample have been moved to GPU and converted to input_dtype + # But you can still further process the sample, + # for example, you can use this function to mark some dims of tensors as dynamic + # when you don't use `dummy_sample_gen_fn` or don't handle dynamic dims in it, + dummy_sample_post_process_fn: Optional[Callable[['TrainerArgs', Any], Any]] = fn_field(default=None) # the model state dict file for tracing. # It is only used in tracing to serve as the initial state dict of the model. tracing_from_weights: str = None @@ -723,6 +795,10 @@ def __post_init__(self): ) self._vars = self.create_kwarg(self.vars) + # will be initialized lazily + # because it is heavy, and may not be used in some cases + # and it looks weird to initialize it eagerly in __post_init__ + self._dummy_input = None @classmethod def from_cli(cls, argv: List[str]) -> 'TrainerArgs': @@ -811,12 +887,6 @@ def resolved_aggregate_outputs_fn(self): return None return load_type(self.optimizer.aggregate_outputs_fn) - @property - def resolved_dummy_sample_gen_fn(self): - if not self.dummy_sample_gen_fn: - return None - return load_type(self.dummy_sample_gen_fn) - @property def scaling_factor(self): return self.compute_config.runtime_ngpus // self.compute_config.plan_ngpus @@ -846,7 +916,7 @@ def init_env(self, trainer: 'Trainer'): init_env_fn = load_type(self.init_env_fn) init_env_fn(trainer) - def get_resolved_var(self, fqn: str) -> Any: + def get_resolved_var(self, fqn: str, *, default: Any = None) -> Any: """ Get a resolved variable from the vars dictionary. The fqn is a full qualified name of the variable, e.g. 'x.y.z'. @@ -855,18 +925,47 @@ def get_resolved_var(self, fqn: str) -> Any: var = self._vars for part in parts: if part not in var: - raise ValueError(f"Variable {fqn} not found in vars") + return default var = var[part] return var + @property + def dummy_input(self): + if self._dummy_input is None: + self._dummy_input = self._load_dummy_input() + self._dummy_input = fix_input(self._dummy_input, self.input_dtype) + if self.dummy_sample_post_process_fn: + self._dummy_input = self.dummy_sample_post_process_fn(self, self._dummy_input) + return self._dummy_input + + def _load_dummy_input(self): + if self.dummy_sample_gen_fn: + return self.dummy_sample_gen_fn(self) + + with enforce_zero_num_worker(DataLoader): + dataset = self.create_dataset('train') + dataloader = self.create_dataloader('train', dataset) + assert dataloader.num_workers == 0, "The dataloader must have `num_workers=0`." + value = next(iter(dataloader)) + if close_fn := getattr(dataloader, 'close', None): + close_fn() + return value + def create_model(self) -> torch.nn.Module: kwargs = self.create_kwarg(self.model.args) return self.model_type(**kwargs) + def should_delay_bucket_building(self) -> bool: + return self.optimizer.param_clss_fn is not None + def create_parallel_optimizer(self, parallel_model: torch.nn.Module): kwargs = self.create_kwarg(self.optimizer.args) optimizer_class = load_type(self.optimizer.type) - return build_optimizer(parallel_model, optimizer_class, self.compute_config, **kwargs) + return build_optimizer( + parallel_model, optimizer_class, self.compute_config, + self.optimizer.param_clss_fn, + **kwargs + ) def create_dataset(self, stage='train'): dataset_args = getattr(self.dataset, f'{stage}_args') @@ -947,3 +1046,12 @@ def create_hook(self) -> TrainHook: return ArgsTrainHook(hook_config) else: raise ValueError(f"Invalid hook_config {hook_config}") + + def create_checkpointer(self) -> Checkpointer: + if self.checkpoint.serializer: + return Checkpointer( + self.checkpoint.format, + self.checkpoint.serializer.name, + self.checkpoint.serializer.args + ) + return Checkpointer(self.checkpoint.format) diff --git a/nnscaler/codegen/emit.py b/nnscaler/codegen/emit.py index fc197b71..3f6e8652 100644 --- a/nnscaler/codegen/emit.py +++ b/nnscaler/codegen/emit.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import inspect from typing import Generator, Iterable, List, Any, Optional, Tuple, Dict import logging @@ -35,7 +36,10 @@ def __repr__(self): return self.name -def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: +def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None, + *, + strip_star: bool = True, +) -> Any: """ Return repr-able value of a tensor or value. For tensor, return IRValue({prefix}{tensor.name}_{tensor.tid}) @@ -44,6 +48,7 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: Args: val (Any): tensor or non-tensor value prefix_attr (str): prefix to the tensor name if the tensor is an attribute + strip_star (bool): whether to strip leading * for *args and **kwargs Returns: the val that can be repr safely """ @@ -51,20 +56,22 @@ def _safe_repr_value(val: Any, prefix_attr: Optional[str] = None) -> Any: return val if isinstance(val, IRObject): tensor_name = val.name + if strip_star: + tensor_name = tensor_name.lstrip('*') tensor_name = tensor_name.replace('.', '_') name = '_'.join([tensor_name, str(val.tid)]) if prefix_attr is not None and val.is_attr(): name = prefix_attr + name return IRValue(name) elif isinstance(val, slice): - return slice(_safe_repr_value(val.start, prefix_attr), _safe_repr_value(val.stop, prefix_attr), _safe_repr_value(val.step, prefix_attr)) + return slice(_safe_repr_value(val.start, prefix_attr, strip_star=strip_star), _safe_repr_value(val.stop, prefix_attr, strip_star=strip_star), _safe_repr_value(val.step, prefix_attr, strip_star=strip_star)) elif isinstance(val, dict): - return {_safe_repr_value(k, prefix_attr): _safe_repr_value(v, prefix_attr) for k, v in val.items()} + return {_safe_repr_value(k, prefix_attr, strip_star=strip_star): _safe_repr_value(v, prefix_attr, strip_star=strip_star) for k, v in val.items()} elif isinstance(val, list): - return [_safe_repr_value(v, prefix_attr) for v in val] + return [_safe_repr_value(v, prefix_attr, strip_star=strip_star) for v in val] elif isinstance(val, tuple): # TODO: support subclasses of tuple, like torch.Size? - return tuple(_safe_repr_value(v, prefix_attr) for v in val) + return tuple(_safe_repr_value(v, prefix_attr, strip_star=strip_star) for v in val) elif isinstance(val, (int, str, bool, float, type(None), bytes, type(Ellipsis), torch.dtype)): return val elif isinstance(val, torch.device): @@ -89,7 +96,10 @@ class CodeEmission: def node_name(self, node: IRCell) -> str: return f"{node.name}{node.cid}" - def tensor_name(self, val: Any, prefix_attr: Optional[str] = None) -> str: + def tensor_name(self, val: Any, prefix_attr: Optional[str] = None, + *, + strip_star: bool = True, + ) -> str: """ Return representation of a value or a tensor. For tensor, return the {prefix}{tensor.name}_{tensor.tid} @@ -98,10 +108,13 @@ def tensor_name(self, val: Any, prefix_attr: Optional[str] = None) -> str: Args: val (Any): tensor or non-tensor value prefix_attr (Optional[str]): prefix to the tensor name if the tensor is an attribute + strip_star (bool): whether to strip leading * for *args and **kwargs + You should set it to False when you want to generate code for + function arguments. Returns: representation of the val in str """ - return repr(_safe_repr_value(val, prefix_attr)) + return repr(_safe_repr_value(val, prefix_attr, strip_star=strip_star)) def complex_name(self, val: Any, prefix_attr: Optional[str]=None) -> str: """ @@ -225,8 +238,35 @@ def emit_fnode(self, node: IRFwOperation, runtime_devid: int, plan_ndevs: int, r emit_rule = self._emit_rules.map(signature) body = emit_rule(node, inputs, kwargs, runtime_devid, plan_ndevs, runtime_ndevs) + def _to_tuple_str(names: List[str]) -> str: + if len(names) == 1: + return f'({names[0]}, )' + return '(' + ', '.join(names) + ')' + + def _insert_hook(outputs=None, is_pre: bool=False, output_len: int = 0): + hook = node.pre_hook if is_pre else node.post_hook + if not hook: + return + module_path = inspect.getmodule(hook).__name__ + fsig = f'{module_path}.{hook.__name__}' + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + codes.append( + f'{fsig}(self, ' + + repr(node.hook_meta) + ', ' + + f"{_to_tuple_str(inputs)}, " + + f"dict({', '.join(kw_pairs)})" + + ('' if is_pre else ', ' + outputs) + + ')' + ) + + _insert_hook(is_pre=True) + if len(node.outputs()) == 0: codes.append(body) + _insert_hook(is_pre=False, outputs='None') else: irobj_path = {} def r(t, current_path): @@ -245,8 +285,12 @@ def r(t, current_path): if all(len(x) == 1 for x in irobj_path.values()): # if all IRObjects are leafs, we can directly assign the output outputs = [self.tensor_name(t) for t in node.outputs()] - outputs = ', '.join(outputs) - codes.append(f'{outputs} = {body}') + outputs_str = ', '.join(outputs) + codes.append(f'{outputs_str} = {body}') + _insert_hook( + outputs=outputs_str if len(node.outputs()) == 1 else _to_tuple_str(outputs), + is_pre=False + ) else: outputs = [] im_outputs = [] @@ -258,7 +302,12 @@ def r(t, current_path): im_ouptut = self.tensor_name(IRObject('im_output')) im_outputs.append(im_ouptut) outputs.append(im_ouptut) - codes.append(f'{", ".join(outputs)} = {body}') + outputs_str = ', '.join(outputs) + codes.append(f'{outputs_str} = {body}') + _insert_hook( + outputs=outputs_str if len(node.outputs()) == 1 else _to_tuple_str(outputs), + is_pre=False + ) for t, path in irobj_path.items(): if len(path) == 1: # immediate output, skip diff --git a/nnscaler/codegen/module/module.py b/nnscaler/codegen/module/module.py index 32cd75fe..c3bc8c34 100644 --- a/nnscaler/codegen/module/module.py +++ b/nnscaler/codegen/module/module.py @@ -8,8 +8,9 @@ import torch import numpy as np import inspect +import pickle -from nnscaler.ir.cten import IRCell +from nnscaler.ir.cten import IRCell, IRTensor from nnscaler.ir.tensor import IRFullTensor, IRSubTensor from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from nnscaler.ir.adapter import IRWeightReducer, IRAdapter @@ -23,7 +24,7 @@ from nnscaler.execplan.execplan import ExeReuseCell from nnscaler.codegen.syntax.symtable import SymbolTable -from nnscaler.codegen.syntax.blocks import ClassBlock, FunctionBlock, Block +from nnscaler.codegen.syntax.blocks import ClassBlock, ForBlock, FunctionBlock, Block from nnscaler.codegen.emit import FuncEmission from nnscaler.codegen.module.autograd import AutogradAdapterCodeGen @@ -126,6 +127,7 @@ def __init__( 'from pathlib import Path', 'import torch', 'import torch.utils.checkpoint as ckpt', 'import nnscaler', 'import nnscaler.flags', + 'import nnscaler.runtime.function', 'import _operator', 'from numpy import inf', 'import builtins', '', f'runtime_version = {runtime_version!r}', '', '' ] @@ -138,6 +140,18 @@ def __init__( # self.init_code.append('@torch.jit.script') self.init_code.append(op_impl) self.init_code += [''] + + # hooks + hook_imports = set() + for node in execplan.graph.select(ntype=IRFwOperation): + if node.pre_hook is not None: + hook_imports.add(inspect.getmodule(node.pre_hook).__name__) + if node.post_hook is not None: + hook_imports.add(inspect.getmodule(node.post_hook).__name__) + for modname in hook_imports: + self.init_code.append(f'import {modname}') + self.init_code += [''] + # module init code self.model_init_statements: List[str] = list() # module method bodies for forward computations, e.g. Segments, Adapters. @@ -317,7 +331,8 @@ def gen( *, as_parallel_module: bool = False, end2end_mode: bool = False, - forward_args: Optional[Dict[str, Any]] = None + forward_args: Optional[Dict[str, Any]] = None, + outfile_attr_meta_map: Optional[str] = None, ) -> str: """ Generate model implementation code based on the given graph. @@ -406,6 +421,7 @@ def forward(self, x, y=None, z=None): This is used only in parallel module. forward_args (Dict[str, Any]): argument names and their default values of forward function, if None, use node inputs. This is used only in parallel module. + outfile_attr_meta_map (str): output file path for parameter mapping. None if don't save Returns: generated code @@ -451,6 +467,7 @@ def forward(self, x, y=None, z=None): if k not in param_first_used_pos: param_first_used_pos[k] = (i, v) + attr_meta_map = {} # emit code for node in sequence: if isinstance(node, IRSegment): @@ -472,7 +489,7 @@ def forward(self, x, y=None, z=None): # emit node tensor declaration into `__init__` # typically it's about the `nn.Parameter` - self.init_attributes(node) + attr_meta_map.update(self.init_attributes(node)) # emit node code # codes : List[str] @@ -483,11 +500,15 @@ def forward(self, x, y=None, z=None): for t in node.inputs(): if isinstance(t, IRSubTensor): if not t.is_attr(): - args.append(self.tensor_name(t)) + args.append(self.tensor_name(t, strip_star=False)) else: - args.append(self.tensor_name(t)) + args.append(self.tensor_name(t, strip_star=False)) node_args.append(args) + if outfile_attr_meta_map: + with open(outfile_attr_meta_map, 'wb') as f: + pickle.dump(attr_meta_map, f) + # generate full code with ClassBlock( class_name='GenModel', @@ -499,6 +520,7 @@ def forward(self, x, y=None, z=None): if as_parallel_module: cb.insert_body(f'rank = {device}') # save rank in class level + cb.insert_body(f'world_size = {self.runtime_ndevs}') # save world size in class level # async_op, max_bucket_size_bytes and zero_use_reduce_scatter # parameters are for testing purpose # and will not expose to user @@ -506,15 +528,17 @@ def forward(self, x, y=None, z=None): args=[ 'self', 'init_params=True', - '*', + 'build_buckets=True', + '*args', f'async_op={CompileFlag.async_reducer}', f'max_bucket_size_bytes={CompileFlag.max_reducer_bucket}', f'zero_use_reduce_scatter={CompileFlag.zero_use_reduce_scatter}', + f'**kwargs', ] ) as ib: ib.insert_body(self.model_init_statements) ib.insert_body('') - ib.insert_body('self._post_init(init_params)') + ib.insert_body('self._post_init(init_params, build_buckets)') else: with FunctionBlock(func_name='__init__', args=['self']) as ib: ib.insert_body(self.model_init_statements) @@ -528,7 +552,10 @@ def forward(self, x, y=None, z=None): if isinstance(node, IRSegment): segment_idxs.append(idx) - with FunctionBlock(func_name=name, args=input_args) as fb: + saved_tensors_hooks_needed = isinstance(node, IRSegment) and CompileFlag.use_zero > 1 + func_name = name + '_impl' if saved_tensors_hooks_needed else name + + with FunctionBlock(func_name=func_name, args=input_args) as fb: fb.insert_body(forward_code) # generate output outputs = [self.tensor_name(t) for t in node.outputs()] @@ -541,6 +568,16 @@ def forward(self, x, y=None, z=None): cb.insert_body('@torch.jit.script_method') cb.insert_body(fb.code) + if saved_tensors_hooks_needed: + with FunctionBlock(func_name=name, args=input_args) as fb: + # call segment under save_params_hooks context + save_context_code = f'with self.save_params_hooks():' + with Block(save_context_code) as cblock: + cblock.insert_body(f'return self.{func_name}({", ".join(node_args[idx])})') + fb.insert_body(cblock.code) + cb.insert_body('') + cb.insert_body(fb.code) + if as_parallel_module: if not segment_idxs: raise RuntimeError("The graph has no segment, forward code cannot be generated.") @@ -627,8 +664,11 @@ def _get_resolved_arg(arg_name, default_value): outputs = self.return_name(node.outputs(), skip_attr=True) call_code = f'{outputs} = self.{self.node_name(node)}({", ".join(inputs)})' # be sure the user doesn't specify unused args. - for unused_arg in unused_args: - fb.insert_body(f'if {unused_arg} is not None: raise ValueError("{unused_arg} is not used in graph tracing, so it must be None when running forward.")') + # but sometimes this can cause issues + # (for example, the value is used in an `if` condition in the original forward function), + # so we disable it for now. + # for unused_arg in unused_args: + # fb.insert_body(f'if {unused_arg} is not None: raise ValueError("{unused_arg} is not used in graph tracing, so it must be None when running forward.")') fb.insert_body(call_code) return_code = f'return {self.return_name_complex(self.execplan.graph.outputs())}' fb.insert_body(return_code) @@ -644,6 +684,11 @@ def emit_comm_groups(self): - `model_init_statements` """ sign = 'self.init_group(ranks={ranks})' + # create single rank communication group + self.model_init_statements.append('# single rank communication groups') + with ForBlock(var='rank', iters=f'range({self.runtime_ndevs})') as fb: + fb.insert_body(sign.format(ranks='[rank]')) + self.model_init_statements.extend(fb.code) # create communication group self.model_init_statements.append('# communication groups') for ranks in self.comm_groups: @@ -651,7 +696,7 @@ def emit_comm_groups(self): self.model_init_statements.append(code) self.model_init_statements.append(' ') - def init_attributes(self, node: IRCell): + def init_attributes(self, node: IRCell) -> dict[str, dict[str, Any]]: """ Emit tensor declaration code @@ -660,10 +705,18 @@ def init_attributes(self, node: IRCell): This method also populates `self.symbols : SymbolTable` to record the names of the variables for the tensors ever encountered. + + Returns: + dict[str, dict[str, Any]]: A mapping of tensor names to their attributes. """ + attr_meta_map = {} + self._init_attributes(node, attr_meta_map) + return attr_meta_map + + def _init_attributes(self, node: IRCell, attr_meta_map: Dict[str, Any]): psign = "self.register_parameter('{name}', torch.nn.Parameter(torch.empty({shape}, dtype={dtype})))" bsign = "self.register_buffer('{name}', torch.empty({shape}, dtype={dtype}), persistent={persistent})" - map_sign = "self.add_full_map('{attr}', {tid}, {is_param}, '{orig_name}', {full_shape}, {slicers}, {val_chunks})" + map_sign = "self.add_full_map('{attr}', {tid}, {is_param}, '{orig_name}', {shape}, {slicers}, {val_chunks})" if not isinstance(node, IRSegment): for itensor in node.inputs(): name = self.tensor_name(itensor, prefix_attr='self.') @@ -691,14 +744,24 @@ def init_attributes(self, node: IRCell): assert len(slicers) == 1 and slicers[0] == slice(0, 1), f"Unexpected slicers {slicers} for scalar tensor." slicers = '...' # Ellipsis slicer for scalar tensor, x[...] is equivalent to x val_chunks = itensor.valmap[1] - code = map_sign.format( - attr=self.tensor_name(itensor), + attr_name = self.tensor_name(itensor) + attr_props = dict( tid=itensor.parent.tid, is_param=itensor.is_param(), orig_name=itensor.parent.name, - full_shape=tuple(itensor.parent.origin_shape), - slicers=str(slicers), - val_chunks=val_chunks + shape=tuple(itensor.parent.origin_shape), # full tensor shape + slicers=slicers, + val_chunks=val_chunks, + ) + attr_meta_map[attr_name] = dict( + **attr_props, + dtype=itensor.dtype, + sub_shape=tuple(itensor.shape) + ) + + code = map_sign.format( + attr=attr_name, + **attr_props ) self.model_init_statements.append(code) self.model_init_statements.append('') @@ -710,7 +773,7 @@ def init_attributes(self, node: IRCell): self.symbols.create(self.tensor_name(output, prefix_attr='self.')) else: for sub_node in node.nodes(): - self.init_attributes(sub_node) + self._init_attributes(sub_node, attr_meta_map) return def init_reducer(self, @@ -874,12 +937,64 @@ def emit_context_manager(node: IRCell): code = "with " + ", ".join(ctx_managers) + ":" return code - def emit_node(node): + def emit_node(node, node_idx): node_code = [] # execute if isinstance(node, IRFwOperation): + param_inputs = [ + self.tensor_name(t, prefix_attr='self.') for t in node.iobjs() + if isinstance(t, IRTensor) and t.is_param() + ] + + # for multiref node under zero3, we need to clone the params to avoid in-place modification issue + if param_inputs and CompileFlag.use_zero > 1 and node.name == 'multiref': + _logger.warning(f'Node {node} is a multiref node with param inputs under ZeRO-3, ' + f'we set clone_level=1 to avoid in-place modification issue.') + node.kwargs['clone_level'] = 1 + code = self.emit_fnode(node, runtime_devid=runtime_devid, plan_ndevs=len(self.devices), runtime_ndevs=self.runtime_ndevs, prefix_attr='self.') - node_code += code + + if not param_inputs or CompileFlag.use_zero <= 1: + node_code += code + else: + activation_inputs = [ + self.tensor_name(t) for t in node.iobjs() + if isinstance(t, IRTensor) and not t.is_attr() and t.requires_grad + ] + activation_outputs = [ + self.tensor_name(t) for t in node.oobjs() + if isinstance(t, IRTensor) and t.requires_grad + ] + + # insert param prefetch before each fnode for zero3 + for t in param_inputs: + node_code.append(f'self.prefetch_param({t})') + # The backward hook here is not reliable, + # 1. there can be no activation input requiring grad, + # 2. some inputs may not be used. + # so, to maximize the chance of triggering backward hook + # let's hook to every input requiring grad + # We also add evict logic in AccumulateGrad hook in bucket implementation, + # which can make sure params are evicted after backward use. + for q in activation_inputs: + node_code.append(f'{q} = self.backward_postevict_param({q}, {t}, {node_idx})') + + node_code += code + + # insert zero param release after each fnode + for t in param_inputs: + node_code.append(f'self.postevict_param({t})') + + # insert backward hook for activation outputs to fetch params in backward + for t in activation_outputs: + # we don't know which activation output will be used in backward + # (DCE may not work 100% correctly), + # so we add hook to all activation outputs for all input params + for p in param_inputs: + node_code.append( + f'{t} = self.backward_prefetch_param({t}, {p}, {node_idx})' + ) + elif isinstance(node, IRAdapter): # for adapters inside an IRSegment, we don't apply async communication to it # as it is mostly in critical path. @@ -905,15 +1020,15 @@ def insert_codes_under_ctx(ctx_code, codes): node_codes = [] current_context_manager_code = "" current_codes = [] - for node in nodes: + for node_idx, node in enumerate(nodes): if has_op_context_info(node): new_context_manager_code = emit_context_manager(node) if current_context_manager_code != new_context_manager_code: node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) - current_codes = emit_node(node) + current_codes = emit_node(node, node_idx) current_context_manager_code = new_context_manager_code else: - current_codes.extend(emit_node(node)) + current_codes.extend(emit_node(node, node_idx)) else: # Node without op context infortmation means it is inserted by nnscaler, not convert from original fx graph, # for example, multiref node and adapter node, currently for nodes inserted by nnscaler we have the following assumption: @@ -967,7 +1082,7 @@ def insert_codes_under_ctx(ctx_code, codes): # # TODO: all inserted nodes should have its op context field. node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) - node_codes += emit_node(node) + node_codes += emit_node(node, node_idx) current_codes = [] node_codes += insert_codes_under_ctx(current_context_manager_code, current_codes) diff --git a/nnscaler/customized_ops/__init__.py b/nnscaler/customized_ops/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/nnscaler/customized_ops/ring_attention/README.md b/nnscaler/customized_ops/ring_attention/README.md new file mode 100644 index 00000000..38fbec5f --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/README.md @@ -0,0 +1,219 @@ +# Ring Attention Implementation + +High-performance ring attention mechanisms for nnscaler, supporting multiple attention variants and distributed training. + +## ๐Ÿ“– Overview + +This module implements multiple efficient attention mechanisms designed to distribute computation evenly in long sequence processing: + +- **Ring Attention**: Standard ring attention supporting arbitrary sequence lengths +- **Ring Attention Variable Length**: Variable-length sequence optimized ring attention +- **Zigzag Attention**: Zigzag pattern ring attention optimized for causal attention + +All implementations are deeply integrated with nnscaler's parallel computing framework, supporting automatic distributed training. + +## ๐Ÿ—๏ธ Architecture Design + +``` +nnscaler/customized_ops/ring_attention/ +โ”œโ”€โ”€ __init__.py # Package import interface +โ”œโ”€โ”€ ring_attn.py # Standard ring attention +โ”œโ”€โ”€ ring_attn_varlen.py # Variable length ring attention +โ”œโ”€โ”€ zigzag_attn.py # Zigzag ring attention +โ”œโ”€โ”€ varlen_utils.py # Variable length utility functions +โ””โ”€โ”€ core/ # Core implementations + โ”œโ”€โ”€ ring_attn_implementation.py # Standard ring attention core + โ”œโ”€โ”€ ring_attn_varlen_implementation.py # Variable length core implementation + โ”œโ”€โ”€ zigzag_attn_implementation.py # Zigzag attention core implementation + โ””โ”€โ”€ utils.py # Common utility functions +``` + +## ๐Ÿš€ Quick Start + +### Standard Ring Attention + +```python +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func + +# Basic usage +output = wrap_ring_attn_func( + q, # [batch_size, seq_len, num_heads, head_dim] + k, # [batch_size, seq_len, num_heads, head_dim] + v, # [batch_size, seq_len, num_heads, head_dim] + causal=True, # Causal attention mask + window_size=(-1, -1), # Sliding window size, (-1,-1) means global attention + softmax_scale=None, # Softmax scale factor, defaults to 1/sqrt(head_dim) + dropout_p=0.0 # Dropout probability +) +``` + +### Variable Length Ring Attention + +```python +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_varlen_func + +# Variable length sequence attention +output = wrap_ring_attn_varlen_func( + q, # [total_tokens, num_heads, head_dim] + k, # [total_tokens, num_heads, head_dim] + v, # [total_tokens, num_heads, head_dim] + cu_seqlens_q, # Cumulative sequence lengths [batch_size + 1] + cu_seqlens_k, # Cumulative sequence lengths [batch_size + 1] + bias=None, # Optional attention bias + causal=True, # Causal attention mask + window_size=(-1, -1), # Sliding window size + softmax_scale=None, # Softmax scale factor + dropout_p=0.0 # Dropout probability +) +``` + +### Zigzag Ring Attention + +```python +from nnscaler.customized_ops.ring_attention import wrap_zigzag_attn_func + +# Zigzag attention (causal attention only) +output = wrap_zigzag_attn_func( + q, # [batch_size, seq_len, num_heads, head_dim] + k, # [batch_size, seq_len, num_heads, head_dim] + v, # [batch_size, seq_len, num_heads, head_dim] + causal=True, # Must be True + window_size=(-1, -1), # Must be (-1, -1), sliding window not supported + softmax_scale=None, + dropout_p=0.0 +) +``` + +## ๐Ÿ”ง Core Features + +### Performance Optimization +- **Flash Attention integration**: Efficient implementation based on flash_attn +- **TransformerEngine support**: Automatic detection and usage of TE 2.2.0+ +- **CUDA kernel optimization**: GPU-optimized low-level implementations +- **Distributed friendly**: Seamless integration with torch.distributed + +### Flexible Configuration +- **Attention patterns**: Support for causal and non-causal attention +- **Sliding window**: Configurable local attention windows +- **GQA support**: Grouped Query Attention optimization +- **Custom scaling**: Flexible softmax scaling strategies + +## ๐Ÿงฎ Algorithm Principles + +### Ring Attention Mechanism + +Ring Attention decomposes attention computation into multiple blocks: + +1. **Sequence chunking**: Divide long sequences into blocks distributed across devices +2. **Ring communication**: Devices pass key/value blocks by all-gather and reduce-scatter +3. **Incremental computation**: Each device computes attention with received key/value blocks + +### Variable Length Optimization + +Special optimizations for variable length sequences: + +```python +# Cumulative sequence length example +cu_seqlens = [0, 128, 256, 512] # 3 sequences with lengths 128, 128, 256 +# Corresponding token tensor shape: [512, num_heads, head_dim] +``` + +### Zigzag Pattern + +Zigzag Attention uses a special communication pattern for higher efficiency in causal attention scenarios: + +- **Causal constraint**: Only supports causal=True cases +- **Optimized communication**: Ring communication optimized for causal masks +- **Memory friendly**: Further reduces unnecessary computation and communication + +## ๐Ÿ”— nnscaler Integration + +### Automatic Parallelization + +```python +from nnscaler.parallel import parallelize, ComputeConfig +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func + +class AttentionModel(torch.nn.Module): + def forward(self, q, k, v): + return wrap_ring_attn_func(q, k, v, causal=True) + +# nnscaler automatically handles distribution +config = ComputeConfig( + plan_ngpus=4, + runtime_ngpus=4 +) +parallel_model = parallelize(model, config=config) +``` + +### Computation Graph Optimization + +nnscaler automatically provides: +- **Communication optimization**: Minimize inter-device communication overhead +- **Memory planning**: Optimize memory usage patterns +- **Operator fusion**: Fuse with other operators for optimization +- **Gradient synchronization**: Automatic gradient communication in backward pass + +## ๐Ÿงช Testing Framework + +Comprehensive test coverage ensures implementation correctness and performance: + +```bash +# Run all attention tests +pytest tests/customized_ops/ring_attn/ -v + +# Specific attention variant tests +pytest tests/customized_ops/ring_attn/test_ring_attn.py -v +pytest tests/customized_ops/ring_attn/test_ring_attn_varlen.py -v +pytest tests/customized_ops/ring_attn/test_zigzag_attn.py -v +``` + +### Test Types + +- **Correctness tests**: Compare outputs with standard attention +- **Multi-GPU scalability**: Behavior validation across different device counts +- **GQA compatibility**: Grouped Query Attention correctness +- **Sliding window**: Local attention pattern validation +- **Edge cases**: Stability testing under extreme conditions + +## ๐Ÿ› ๏ธ Development Guide + +### Adding New Attention Variants + +1. **Core implementation**: Add implementation file in `core/` directory +2. **Wrapper function**: Create corresponding wrap function +3. **Test coverage**: Add comprehensive test cases +4. **Documentation**: Update README and API documentation + +### Performance Optimization Tips + +- **TransformerEngine**: Install TE 2.2.0+ for optimal performance +- **CUDA version**: Use CUDA 11.8+ for latest optimizations +- **Memory configuration**: Adjust batch size and sequence length based on GPU memory +- **Communication optimization**: Use InfiniBand networks to reduce communication latency + +## ๐Ÿšจ Known Limitations + +### Ring Attention +- **alibi_slopes**: ALiBi positional encoding not currently supported +- **return_attn_probs**: Returning attention weights not supported + +### Zigzag Attention +- **causal**: Only supports causal attention (causal=True) +- **window_size**: Sliding window not supported (must be (-1,-1)) + +### General Limitations +- **Dynamic shapes**: Sequence length cannot change dynamically during training +- **Mixed precision**: May require special handling in certain configurations + +## ๐Ÿ“š References + +- **Ring Attention Paper**: [Ring Attention with Blockwise Transformers](https://arxiv.org/abs/2310.01889) +- **Flash Attention**: [FlashAttention: Fast and Memory-Efficient Exact Attention](https://arxiv.org/abs/2205.14135) +- **Llama3 Paper**: [The Llama3 Herd of Models](https://arxiv.org/pdf/2407.21783) +- **nnscaler Documentation**: [nnscaler Parallel Computing Framework](https://github.com/microsoft/nnscaler) +- **TransformerEngine**: [NVIDIA TransformerEngine](https://github.com/NVIDIA/TransformerEngine) + +--- + +**Note**: This implementation is optimized for large-scale distributed training. For single-GPU scenarios, standard Flash Attention is recommended for optimal performance. \ No newline at end of file diff --git a/nnscaler/customized_ops/ring_attention/__init__.py b/nnscaler/customized_ops/ring_attention/__init__.py new file mode 100644 index 00000000..e54f5bc1 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .ring_attn_varlen import wrap_ring_attn_varlen_func + +from .zigzag_attn import wrap_zigzag_attn_func + +from .ring_attn import wrap_ring_attn_func \ No newline at end of file diff --git a/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py b/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py new file mode 100644 index 00000000..39a3885d --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/core/ring_attn_implementation.py @@ -0,0 +1,326 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from .utils import shuffle_input, recover_output, GlobalMemoryBuffer, get_default_args, all_gather, reduce_scatter + + +_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer() + + +def ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + def forward(q, k, v, causal): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + up_q = q[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + + up_out, up_lse = forward(up_q, up_k, up_v, causal) + + down_q = q[:, block_len:] + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + down_out, down_lse = forward(down_q, down_k, down_v, causal) + + out = torch.cat([up_out, down_out], dim=1) + return out, up_lse, down_lse + + +def ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + up_lse, + down_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): # pragma: no cover + block_len = q.size(1) // 2 + curr_rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + keep_idx = 2 * curr_rank + dual_rank = world_size - curr_rank - 1 + dual_send_idx = 2 * dual_rank + 1 + up_rank = min(keep_idx, dual_send_idx) + down_rank = max(keep_idx, dual_send_idx) + + dq = torch.zeros_like(q) + dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_dk") + dk_buffer.zero_() + dv_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_dv") + dv_buffer.zero_() + + up_q = q[:, :block_len] + up_out = out[:, :block_len] + up_dout = dout[:, :block_len] + if causal: + up_k = k[:, :(up_rank + 1) * block_len] + up_v = v[:, :(up_rank + 1) * block_len] + up_dk = dk_buffer[:, :(up_rank + 1) * block_len] + up_dv = dv_buffer[:, :(up_rank + 1) * block_len] + else: + up_k, up_v = k, v + up_dk, up_dv = dk_buffer, dv_buffer + + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": up_dout, + "q": up_q, + "k": up_k, + "v": up_v, + "out": up_out, + "softmax_lse": up_lse, + "dq": dq[:, :block_len], + "dk": up_dk, + "dv": up_dv, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + + down_q = q[:, block_len:] + down_out = out[:, block_len:] + down_dout = dout[:, block_len:] + # TODO: optimize the buffer allocation + down_dk_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(k.size(), k.dtype, "bwd_down_dk") + down_dk_buffer.zero_() + down_dv_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(v.size(), v.dtype, "bwd_down_dv") + down_dv_buffer.zero_() + if causal: + down_k = k[:, :(down_rank + 1) * block_len] + down_v = v[:, :(down_rank + 1) * block_len] + down_dk = down_dk_buffer[:, :(down_rank + 1) * block_len] + down_dv = down_dv_buffer[:, :(down_rank + 1) * block_len] + else: + down_k, down_v = k, v + down_dk, down_dv = down_dk_buffer, down_dv_buffer + + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": down_dout, + "q": down_q, + "k": down_k, + "v": down_v, + "out": down_out, + "softmax_lse": down_lse, + "dq": dq[:, block_len:], + "dk": down_dk, + "dv": down_dv, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + dk_buffer.add_(down_dk_buffer) + dv_buffer.add_(down_dv_buffer) + + bsz = q.size(0) + if bsz == 1: + dim_size = list(k.size()) + dim_size[1] = dim_size[1] // world_size + dk = torch.empty(dim_size, dtype=k.dtype, device=k.device) + dv = torch.empty(dim_size, dtype=v.dtype, device=v.device) + dist.reduce_scatter_tensor(dk, dk_buffer, group=process_group) + dist.reduce_scatter_tensor(dv, dv_buffer, group=process_group) + else: + dk = reduce_scatter(dk_buffer, dim=1, process_group=process_group) + dv = reduce_scatter(dv_buffer, dim=1, process_group=process_group) + + return dq, dk, dv + + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, all gather k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, reduce scatter dk, dv +''' +class RingFlashAttnFunc(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + assert alibi_slopes is None + + bsz = q.size(0) + q = shuffle_input(to_send=q, process_group=group) + k = k.contiguous() + v = v.contiguous() + world_size = dist.get_world_size(group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + if bsz == 1: + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + # torch.distributed._all_gather_base function requires that the k and v tensors are contiguous. + torch.distributed.all_gather_into_tensor(k_buffer, k, group=group) + torch.distributed.all_gather_into_tensor(v_buffer, v, group=group) + else: + k_buffer = all_gather(k, dim=1, process_group=group) + v_buffer = all_gather(v, dim=1, process_group=group) + + out, up_lse, down_lse = ring_flash_attn_forward( + group, + q, + k_buffer, + v_buffer, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, up_lse, down_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_output(out, process_group=group) + return out + + @staticmethod + def backward(ctx, dout, *args): # pragma: no cover + dout = shuffle_input(to_send=dout, process_group=ctx.group) + q, k, v, out, up_lse, down_lse = ctx.saved_tensors + bsz = q.size(0) + world_size = dist.get_world_size(ctx.group) + dim_size = list(k.size()) + dim_size[1] = dim_size[1] * world_size + if bsz == 1: + k_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, k.dtype, "fwd_k") + v_buffer = _GLOBAL_MEMORY_BUFFER.get_tensor(dim_size, v.dtype, "fwd_v") + torch.distributed.all_gather_into_tensor(k_buffer, k, group=ctx.group) + torch.distributed.all_gather_into_tensor(v_buffer, v, group=ctx.group) + else: + k_buffer = all_gather(k, dim=1, process_group=ctx.group) + v_buffer = all_gather(v, dim=1, process_group=ctx.group) + + dq, dk, dv = ring_flash_attn_backward( + ctx.group, + dout, + q, + k_buffer, + v_buffer, + out, + up_lse, + down_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + dq = recover_output(dq, ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None diff --git a/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py b/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py new file mode 100644 index 00000000..709ffe09 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/core/ring_attn_varlen_implementation.py @@ -0,0 +1,516 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: Most of code is copied from project https://github.com/zhuzilin/ring-flash-attention + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import ( + _flash_attn_varlen_forward, + _flash_attn_varlen_backward, +) +from .utils import get_default_args, AllGatherComm as Comm + + +def llama3_flash_attn_prepare_cu_seqlens( + cu_seqlens: torch.Tensor, causal: bool, rank: int, world_size: int +): + """ + Args: + cu_seqlens: torch.Tensor, the cu_seqlens of all the sequences across the ring process group. + + Returns: + cu_seqlens_q: torch.Tensor, the cu_seqlens of the q slice for this rank. + cu_seqlens_k: torch.Tensor, the cu_seqlens of the k slice that the local q need. Note + that this may be longer than `total_seq_len // world_size`. + local_k_slice: slice, the slice of the k that the local q need. Note + that this may be longer than `total_seq_len // world_size`. + """ + total_length = cu_seqlens[-1].item() + assert total_length % world_size == 0, cu_seqlens + length_per_rank = total_length // world_size + left = torch.searchsorted(cu_seqlens, rank * length_per_rank) + right = torch.searchsorted(cu_seqlens, (rank + 1) * length_per_rank) + + # after this, cu_seqlens[left:right + 1] contains all the sequence for this rank + if cu_seqlens[left] != rank * length_per_rank: + left -= 1 + left = left.item() + right = right.item() + + # q is always the same. just calculate the cu_seqlens for the local slice + cu_seqlens_q = cu_seqlens[left : right + 1].clone() + cu_seqlens_q -= rank * length_per_rank + cu_seqlens_q[0] = 0 + cu_seqlens_q[-1] = length_per_rank + + cu_seqlens_k = cu_seqlens[left : right + 1].clone() + if causal: + # when causal, we hope + # - the last k seq is of the same length as the last q seq + slice_right = (rank + 1) * length_per_rank + cu_seqlens_k[-1] = slice_right + else: + # when not causal, we hope + # - the last k is full seq + slice_right = cu_seqlens[right].item() + + slice_left = cu_seqlens[left].item() + cu_seqlens_k -= slice_left + + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + local_k_slice = slice(slice_left, slice_right) + return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, local_k_slice + + +def llama3_flash_attn_varlen_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + out_list = [] + lse_list = [] + + nheads = q.shape[1] + total_k, nheads_k, head_dim = k.shape + assert nheads_k % heads_k_stride == 0 + + world_size = dist.get_world_size(process_group) + kv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + kv_buffer_copy = torch.empty_like(kv_buffer) + + k_0 = k[:, :heads_k_stride].contiguous() + v_0 = v[:, :heads_k_stride].contiguous() + comm = Comm(process_group) + + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + for i in range(0, nheads_k, heads_k_stride): + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + + if i < nheads_k - heads_k_stride: + # all_gather the next kv slice + kv_slice_left = i + heads_k_stride + kv_slice_right = kv_slice_left + heads_k_stride + send_k = k[:, kv_slice_left:kv_slice_right].contiguous() + send_v = v[:, kv_slice_left:kv_slice_right].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + q_i = q[:, i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + k_i = kv_buffer[0][local_k_slice] + v_i = kv_buffer[1][local_k_slice] + if alibi_slopes is None: + cur_alibi_slopes = None + else: + cur_alibi_slopes = alibi_slopes[i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + + params = get_default_args(_flash_attn_varlen_forward).copy() + params.update( + { + "q": q_i, + "k": k_i, + "v": v_i, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": cur_alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_varlen_forward(**params) + if len(outputs) == 8: + out, _, _, _, _, lse, _, _ = outputs + else: + assert len(outputs) == 4 + out, lse, _, _ = outputs + out_list.append(out) + lse_list.append(lse) + + out = torch.cat(out_list, dim=1) + lse = torch.cat(lse_list, dim=-2) + return out, lse + + +def llama3_flash_attn_varlen_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): # pragma: no cover + nheads = q.shape[1] + total_k, nheads_k, head_dim = k.shape + assert nheads_k % heads_k_stride == 0 + + world_size = dist.get_world_size(process_group) + kv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + kv_buffer_copy = torch.empty_like(kv_buffer) + + dkv_buffer = torch.empty( + (2, total_k * world_size, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + if heads_k_stride != nheads_k: + kv_contiguous_buffer = torch.empty( + (2, total_k, heads_k_stride, head_dim), + dtype=k.dtype, + device=k.device, + ) + + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + + comm = Comm(process_group) + + k_0 = k[:, :heads_k_stride].contiguous() + v_0 = v[:, :heads_k_stride].contiguous() + comm.all_gather(kv_buffer_copy[0], k_0) + comm.all_gather(kv_buffer_copy[1], v_0) + + for i in range(0, nheads_k, heads_k_stride): + dkv_buffer.zero_() + + q_slice = slice( + i * nheads // nheads_k, (i + heads_k_stride) * nheads // nheads_k + ) + q_i = q[:, q_slice] + dout_i = dout[:, q_slice] + out_i = out[:, q_slice] + dq_i = dq[:, q_slice] + if softmax_lse.dim() == 3: + lse_i = softmax_lse[:, q_slice].contiguous() + else: + lse_i = softmax_lse[q_slice] + + comm.wait() + kv_buffer, kv_buffer_copy = kv_buffer_copy, kv_buffer + + if i < nheads_k - heads_k_stride: + # all_gather the next kv slice + kv_slice_left = i + heads_k_stride + kv_slice_right = kv_slice_left + heads_k_stride + send_k = k[:, kv_slice_left:kv_slice_right].contiguous() + send_v = v[:, kv_slice_left:kv_slice_right].contiguous() + comm.all_gather(kv_buffer_copy[0], send_k) + comm.all_gather(kv_buffer_copy[1], send_v) + + k_i = kv_buffer[0][local_k_slice] + v_i = kv_buffer[1][local_k_slice] + dk_i = dkv_buffer[0][local_k_slice] + dv_i = dkv_buffer[1][local_k_slice] + + if alibi_slopes is None: + cur_alibi_slopes = None + else: + cur_alibi_slopes = alibi_slopes[i * nheads // nheads_k : (i + heads_k_stride) * nheads // nheads_k] + + params = get_default_args(_flash_attn_varlen_backward).copy() + params.update( + { + "dout": dout_i, + "q": q_i, + "k": k_i, + "v": v_i, + "out": out_i, + "softmax_lse": lse_i, + "dq": dq_i, + "dk": dk_i, + "dv": dv_i, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": cur_alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_varlen_backward(**params) + + if heads_k_stride != nheads_k: + # reduce_scatter needs contiguous buffer + dk_i = kv_contiguous_buffer[0] + dv_i = kv_contiguous_buffer[1] + else: + dk_i = dk + dv_i = dv + + dist.reduce_scatter_tensor(dk_i, dkv_buffer[0], group=process_group) + dist.reduce_scatter_tensor(dv_i, dkv_buffer[1], group=process_group) + + if heads_k_stride != nheads_k: + dk[:, i : i + heads_k_stride] = dk_i + dv[:, i : i + heads_k_stride] = dv_i + + return dq, dk, dv + + +class Llama3FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = llama3_flash_attn_varlen_forward( + group, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k) + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.heads_k_stride = heads_k_stride + ctx.local_k_slice = local_k_slice + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): # pragma: no cover + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + dq, dk, dv = llama3_flash_attn_varlen_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.heads_k_stride, + ctx.local_k_slice, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + return (dq, dk, dv) + (None,) * 15 + + +def llama3_flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def llama3_flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def llama3_flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return Llama3FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + heads_k_stride, + local_k_slice, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/nnscaler/customized_ops/ring_attention/core/utils.py b/nnscaler/customized_ops/ring_attention/core/utils.py new file mode 100644 index 00000000..6345b35b --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/core/utils.py @@ -0,0 +1,343 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This logger implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention + +from typing import Optional, Tuple +from functools import reduce +import operator +import inspect +from functools import cache +import random + +import torch +import torch.distributed as dist + + +def set_seed(rank, seed=42): + seed = rank + seed + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def log(msg, a, rank0_only=False): + world_size = dist.get_world_size() + rank = dist.get_rank() + if rank0_only: + if rank == 0: + print( + f"{msg}: " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + return + + for i in range(world_size): + if i == rank: + if rank == 0: + print(f"{msg}:") + print( + f"[{rank}] " + f"max {a.abs().max().item()}, " + f"mean {a.abs().mean().item()}", + flush=True, + ) + dist.barrier() + + +def gen_head_anno(query_states, key_states, value_states, head_pos=2): + if query_states.shape[head_pos] != key_states.shape[head_pos]: + assert query_states.shape[head_pos] % key_states.shape[head_pos] == 0 + group_size = query_states.shape[head_pos] // key_states.shape[head_pos] + assert query_states.shape[head_pos] == value_states.shape[head_pos] * group_size + q_anno = f'(group_num {group_size})' + kv_anno = 'group_num' + else: + q_anno = kv_anno = 'num_heads' + return q_anno, kv_anno + + +# copied from project https://github.com/zhuzilin/ring-flash-attention +@cache +def _get_default_args(func): + spec = inspect.getfullargspec(func) + defaults = spec.defaults if spec.defaults is not None else () + padded_defaults = (None,) * (len(spec.args) - len(defaults)) + defaults + args = dict(zip(spec.args, padded_defaults)) + if "softcap" in args: + args["softcap"] = 0.0 + return args + + +def get_default_args(func): + if inspect.isfunction(func): + return _get_default_args(func) + else: + # Use the origin _init_fn in CustomOpDef + return _get_default_args(func._init_fn) + + +class AllGatherComm: + def __init__(self, group=None) -> None: + self.group = group + self.handles = [] + + def all_gather(self, output_tensor: torch.Tensor, input_tensor: torch.Tensor): + handle = dist.all_gather_into_tensor( + output_tensor, input_tensor, group=self.group, async_op=True + ) + self.handles.append(handle) + + def wait(self): + for handle in self.handles: + handle.wait() + self.handles = [] + + +# copy from megatron/core/utils.py +class GlobalMemoryBuffer: + """Global buffer to avoid dynamic memory allocations. + Caller should ensure that buffers of the same name + are not used concurrently.""" + + def __init__(self): + self.buffer = {} + + def get_tensor(self, tensor_shape, dtype, name): + required_len = reduce(operator.mul, tensor_shape, 1) + if ( + self.buffer.get((name, dtype), None) is None + or self.buffer[(name, dtype)].numel() < required_len + ): + self.buffer[(name, dtype)] = torch.empty( + required_len, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False + ) + + return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape) + + +def update_out_and_lse( + out: Optional[torch.Tensor], + lse: Optional[torch.Tensor], + block_out: torch.Tensor, + block_lse: torch.Tensor, + slice_=None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + @torch.jit.script + def _update_out_and_lse( + out: torch.Tensor, + lse: torch.Tensor, + block_out: torch.Tensor, + block_lse: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + block_out = block_out.to(torch.float32) + block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + + new_lse = lse + torch.log(1 + torch.exp(block_lse - lse)) + + out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out + + lse = new_lse + return out, lse + + if out is None: + if slice_ is not None: + raise RuntimeError("first update_out_and_lse should not pass slice_ args") + out = block_out.to(torch.float32) + lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1) + elif slice_ is not None: + slice_out, slice_lse = out[slice_], lse[slice_] + slice_out, slice_lse = _update_out_and_lse( + slice_out, slice_lse, block_out, block_lse + ) + out[slice_], lse[slice_] = slice_out, slice_lse + else: + out, lse = _update_out_and_lse(out, lse, block_out, block_lse) + return out, lse + + +class RingComm: + def __init__(self, process_group: dist.ProcessGroup): + self._process_group = process_group + self._ops = [] + self.rank = dist.get_rank(self._process_group) + self.world_size = dist.get_world_size(self._process_group) + self._reqs = None + + parts = self.world_size // 2 + self.ring_list = [] + for i in range(parts): + self.ring_list.extend([i, self.world_size - i - 1]) + + self.revert_rank = self.ring_list.index(self.rank) + + offset = ((dist.get_rank() // self.world_size) * self.world_size) + self.send_rank = self.ring_list[(self.revert_rank + 1) % self.world_size] + offset + self.recv_rank = self.ring_list[(self.revert_rank - 1) % self.world_size] + offset + + def send_recv( + self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if recv_tensor is None: + res = torch.empty_like(to_send) + else: + res = recv_tensor + + send_op = dist.P2POp( + dist.isend, to_send, self.send_rank, group=self._process_group + ) + recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group) + self._ops.append(send_op) + self._ops.append(recv_op) + return res + + def commit(self): + if self._reqs is not None: + raise RuntimeError("commit called twice") + self._reqs = dist.batch_isend_irecv(self._ops) + + def wait(self): + if self._reqs is None: + raise RuntimeError("wait called before commit") + for req in self._reqs: + req.wait() + self._reqs = None + self._ops = [] + + +def shuffle_input(to_send: torch.Tensor, + process_group: dist.ProcessGroup = None): + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + # We must use outplace, otherwise it will raise error at backward due to inplace operations. + # We can not change to_send directly and create a new tensor to store the result. + to_send_f = torch.zeros_like(to_send) + + # assume the input sequence length is 8, and computation runs on 4 GPUs + # the seq is represented as [0 1 2 3 4 5 6 7], world size is 4 + # the input status before `shuffle_input` is + # - gpu A: [0 1] + # - gpu B: [2 3] + # - gpu C: [4 5] + # - gpu D: [6 7] + # the value of `to_send_slice` is + # - gpu A: [1] + # - gpu B: [3] + # - gpu C: [5] + # - gpu D: [7] + block_seq_len = to_send.shape[1] // 2 + to_send_slice = to_send[:, block_seq_len:].contiguous() + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + res = torch.zeros_like(to_send_slice) + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + # rank src_rank + # 0 3 + # 1 2 + # 2 1 + # 3 0 + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: # D: 6 7, -> 1 6 + to_send_f[:, block_seq_len:] = to_send[:, :block_seq_len] + to_send_f[:, :block_seq_len, ...] = res + else: # A: 0 1, -> 0 7 + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len] + to_send_f[:, block_seq_len:, ...] = res + # after shuffle, the status of `to_send_f` + # GPU A: [0 7] + # GPU B: [2 5] + # GPU C: [3 4] + # GPU D: [1 6] + + return to_send_f + + +def recover_output(to_send: torch.Tensor, + process_group: dist.ProcessGroup = None): + + if not to_send.is_contiguous(): + to_send = to_send.contiguous() + + to_send_f = torch.zeros_like(to_send) + + block_seq_len = to_send.shape[1] // 2 + + rank = dist.get_rank(process_group) + world_size = dist.get_world_size(process_group) + + if rank >= world_size // 2: + to_send_slice = to_send[:, :block_seq_len, ...].contiguous() + else: + to_send_slice = to_send[:, block_seq_len:, ...].contiguous() + res = torch.zeros_like(to_send_slice) + + assert to_send_slice.is_contiguous() + assert res.is_contiguous() + + _ops = [] + offset = ((dist.get_rank() // world_size) * world_size) + src_rank = (world_size - rank - 1) % world_size + offset + send_op = dist.P2POp( + dist.isend, to_send_slice, src_rank, group=process_group + ) + recv_op = dist.P2POp( + dist.irecv, res, src_rank, group=process_group) + + _ops.append(send_op) + _ops.append(recv_op) + + response = dist.batch_isend_irecv(_ops) + for resp in response: + resp.wait() + + if rank >= world_size // 2: + to_send_f[:, :block_seq_len] = to_send[:, block_seq_len:, ...] + to_send_f[:, block_seq_len:] = res + else: + to_send_f[:, :block_seq_len] = to_send[:, :block_seq_len, ...] + to_send_f[:, block_seq_len:] = res + + return to_send_f.contiguous() + + +def all_gather(tensor: torch.Tensor, dim: int, process_group: dist.ProcessGroup): + tensor = tensor.contiguous() if not tensor.is_contiguous() else tensor + world_size = dist.get_world_size(process_group) + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + tensor_list[torch.distributed.get_rank(process_group)] = tensor.data + torch.distributed.all_gather(tensor_list, tensor, group=process_group) + otensor = torch.concat(tuple(tensor_list), dim=dim) + return otensor + + +def reduce_scatter(tensor: torch.Tensor, dim: int, process_group: dist.ProcessGroup): + world_size = dist.get_world_size(process_group) + itensors = list(tensor.chunk(world_size, dim)) + for idx, t in enumerate(itensors): + itensors[idx] = t.contiguous() if not t.is_contiguous() else t + otensor = torch.empty_like(itensors[0], requires_grad=False) + torch.distributed.reduce_scatter(otensor, itensors, group=process_group) + return otensor diff --git a/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py b/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py new file mode 100644 index 00000000..7a643d59 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/core/zigzag_attn_implementation.py @@ -0,0 +1,516 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Credits: This implementation is inspired by project https://github.com/zhuzilin/ring-flash-attention + +import torch +import torch.distributed as dist +from flash_attn.flash_attn_interface import _flash_attn_forward, _flash_attn_backward +from .utils import RingComm, update_out_and_lse, shuffle_input, recover_output, get_default_args + +''' +Assume we have 4 GPUs A, B, C, D. +The sequence is represented as [0 1 2 3 4 5 6 7]. + +The P2P communication ring is A -> D -> B -> C -> A +The initial status of the attention computation is +X +X X +X X X +X X X X +X X X X X +X X X X X X +X X X X X X X +X X X X X X X X +Note: +- the computation in the diagonal is `causal=True` +- the computation in the off-diagonal is `causal=False` +We consider a `X` with `causal=True` as a unit computation block. +In this example, there are 4 steps. Each device is responsible for 2 unit computation blocks in each step. + +q status is same across all steps (q is not transmitted): +GPU A: [0 7] +GPU B: [2 5] +GPU C: [3 4] +GPU D: [1 6] + +Step 0, kv status: +GPU A: [0 7] +GPU B: [2 5] +GPU C: [3 4] +GPU D: [1 6] +Computation status: +A +X D +X X B +X X X C +X X X C C +X X B X X B +X D X X X X D +A X X X X X X A + +Step 1, kv status: +GPU A: [3 4] +GPU B: [1 6] +GPU C: [2 5] +GPU D: [0 7] +Computation status: +X +D X +X B X +X X C X +X X C X X +X B X X X X +D X X X X X X +X X X A A X X X + +Step 2, kv status: +GPU A: [2 5] +GPU B: [0 7] +GPU C: [1 6] +GPU D: [3 4] +Computation status: +X +X X +B X X +X C X X +X C X X X +B X X X X X +X X X D D X X +X X A X X A X X + +Step 3, kv status: +GPU A: [1 6] +GPU B: [3 4] +GPU C: [0 7] +GPU D: [2 5] +Computation status: +X +X X +X X X +C X X X +C X X X X +X X X B B X +X X D X X D X +X A X X X X A X + +From this example, we can conclude the key insight of zigzag ring flash attention is: +- split the sequence into fine-grained blocks to achieve balance across steps and gpus +- schedule the computation in a zigzag pattern to minimize the communication overhead + +To be more specific, if the sequence length is L=4n, the total computation cost of flash attention +with causal=True is 1/2 L^2 = 8n^2. Each device needs to compute 4n. Each step needs to compute 2. + +Computation task assigned for each GPU: + +GPU 0: (0, 4n-1) +GPU 1: (2, 4n-3) +... +GPU n-1: (2n-2, 2n+1) +GPU n: (2n-1, 2n) +GPU n+1: (2n-3, 2n+2) +... +GPU 2n-1: (1, 4n-2) + +Dependence of kv (required kv range) for each device: +GPU 0: [0, 4n-1] +GPU 1: [0, 4n-3] +... +GPU n-1: [0, 2n+1] +GPU n: [0, 2n] +GPU n+1: [0, 2n+2] +... +GPU 2n-1: [0, 4n-2] + +In general, if there are 2n GPUs, the ring is 0 -> 2n-1 -> 1 -> 2n-2 -> ... -> n -> n+1 -> 0 + +For each device, the 2n steps is divided into 3 parts: +1. compute the local attention with `causal=True` +2. if current step is less or equal to its relative rank in the ring, select the first half + of the received kv to compute the attention with `causal=False`. In the example above, each + device computes to `left` of its corresponding rows in the status matrix. +3. if current step is greater than its relative rank in the ring, select the second half of + local q and full received kv to compute the attention with `causal=False`. In the example + above, each device fills the remaining part of its lower row in the status matrix. +''' + +def zigzag_ring_flash_attn_forward( + process_group, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): + assert causal == True, "zigzag ring is meaningless for causal=False" + comm = RingComm(process_group) + + block_seq_len = q.shape[1] // 2 + q1 = q[:, block_seq_len:] + + out = None + lse = None + next_k, next_v = None, None + + def forward(q, k, v, causal): + params = get_default_args(_flash_attn_forward).copy() + params.update( + { + "q": q, + "k": k, + "v": v, + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "return_softmax": True and dropout_p > 0, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + outputs = _flash_attn_forward(**params) + if len(outputs) == 8: + block_out, _, _, _, _, block_lse, _, _ = outputs + else: + assert len(outputs) == 4 + block_out, block_lse, _, _ = outputs + return block_out, block_lse + + for step in range(comm.world_size): + if step + 1 != comm.world_size: + next_k: torch.Tensor = comm.send_recv(k) + next_v: torch.Tensor = comm.send_recv(v) + comm.commit() + + if step == 0: + block_out, block_lse = forward(q, k, v, causal=True) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + elif step <= comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + block_out, block_lse = forward(q, k0, v0, causal=False) + out, lse = update_out_and_lse(out, lse, block_out, block_lse) + else: + block_out, block_lse = forward(q1, k, v, causal=False) + out, lse = update_out_and_lse( + out, + lse, + block_out, + block_lse, + slice_=(slice(None), slice(block_seq_len, None)), + ) + + if step + 1 != comm.world_size: + comm.wait() + k = next_k + v = next_v + + out = out.to(q.dtype) + lse = lse.squeeze(dim=-1).transpose(1, 2) + return out, lse + + +''' +In the backward pass, we assume q, k, v and out are saved in the shuffled order. +In addition, the backward pass requires a shuffled dout as input and generates +a shuffled dq, dk, dv as output. Note that out is a sum of all step outputs, so +we can directly pass dout to each step's backward block to compute the local gradient +according to the differiential chain rule. + +Similar to the forward pass, in the backward pass, the 2n steps are divided into 3 parts. + +Different from the forward pass, we need to communicate the gradient of kv in a ring as well. +To be more specific, each device calculates the local gradients of dq, dk, dv. In the following +steps, dq will be accumulated in the initial device, while dk and dv will be transmitted to the +next consumer device, then accumulated in the consumer device. In the end, the dk and dv will be +transmitted back to the initial device. + +In addition, to be compatible with the flash-attn's interface and reduce the precision loss, +we will accumulate and transmit the gradients in float32. They will be converted back to the +original dtype at the end of the backward pass. +''' +def zigzag_ring_flash_attn_backward( + process_group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale, + dropout_p=0, + causal=True, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, +): # pragma: no cover + assert causal == True, "zigzag ring is meaningless for causal=False" + kv_comm = RingComm(process_group) + d_kv_comm = RingComm(process_group) + dq, dk, dv = None, None, None + next_dk, next_dv = None, None + next_k, next_v = None, None + dk_comm_buffer, dv_comm_buffer = None, None + + dout1 = dout.chunk(2, dim=1)[1] + q1 = q.chunk(2, dim=1)[1] + out1 = out.chunk(2, dim=1)[1] + softmax_lse1 = softmax_lse.chunk(2, dim=2)[1].contiguous() + block_seq_len = q.shape[1] // 2 + + # repeatly allocating buffer may be slow... + dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device) + dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device) + dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device) + + def backward(dout, q, k, v, out, softmax_lse, causal): + seqlen_q = q.shape[1] + seqlen_kv = k.shape[1] + + params = get_default_args(_flash_attn_backward).copy() + params.update( + { + "dout": dout, + "q": q, + "k": k, + "v": v, + "out": out, + "softmax_lse": softmax_lse, + "dq": dq_buffer[:, :seqlen_q], + "dk": dk_buffer[:, :seqlen_kv], + "dv": dv_buffer[:, :seqlen_kv], + "dropout_p": dropout_p, + "softmax_scale": softmax_scale, + "causal": causal, + "alibi_slopes": alibi_slopes, + "deterministic": deterministic, + } + ) + if "window_size" in params: + params.update({"window_size": window_size}) + else: + params.update( + { + "window_size_left": window_size[0], + "window_size_right": window_size[1], + } + ) + _flash_attn_backward(**params) + + for step in range(kv_comm.world_size): + if step + 1 != kv_comm.world_size: + next_k = kv_comm.send_recv(k) + next_v = kv_comm.send_recv(v) + kv_comm.commit() + + if step == 0: + backward(dout, q, k, v, out, softmax_lse, causal=True) + dq = dq_buffer.to(torch.float32) + dk = dk_buffer.to(torch.float32) + dv = dv_buffer.to(torch.float32) + else: + if step <= kv_comm.revert_rank: + k0 = k[:, :block_seq_len] + v0 = v[:, :block_seq_len] + backward(dout, q, k0, v0, out, softmax_lse, causal=False) + dq += dq_buffer + else: + backward(dout1, q1, k, v, out1, softmax_lse1, causal=False) + # always use the first half in dq_buffer. + dq[:, block_seq_len:] += dq_buffer[:, :block_seq_len] + + d_kv_comm.wait() + dk_comm_buffer, dv_comm_buffer = dk, dv + dk, dv = next_dk, next_dv + + if step <= kv_comm.revert_rank: + dk[:, :block_seq_len] += dk_buffer[:, :block_seq_len] + dv[:, :block_seq_len] += dv_buffer[:, :block_seq_len] + else: + dk += dk_buffer + dv += dv_buffer + + if step + 1 != kv_comm.world_size: + kv_comm.wait() + k = next_k + v = next_v + + next_dk = d_kv_comm.send_recv(dk, dk_comm_buffer) + next_dv = d_kv_comm.send_recv(dv, dv_comm_buffer) + d_kv_comm.commit() + + d_kv_comm.wait() + + return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype) + +''' +In nnscaler, sequence are stored in the initial order, e.g., [0 1 2 3 4 5 6 7]. +However, zigzag ring flash attention requires the sequence to be in the order of [0 7 2 5 3 4 1 6]. +As a result: +- in forward, we need to shuffle q, k, v and recover the out +- in backward, we need to shuffle dout and recover the dq, dk, dv +''' +class ZigZagRingFlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_softmax, + group, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + + assert alibi_slopes is None + q = shuffle_input(to_send=q, process_group=group) + k = shuffle_input(to_send=k, process_group=group) + v = shuffle_input(to_send=v, process_group=group) + k = k.contiguous() + v = v.contiguous() + out, softmax_lse = zigzag_ring_flash_attn_forward( + group, + q, + k, + v, + softmax_scale=softmax_scale, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=False, + ) + # this should be out_padded + ctx.save_for_backward(q, k, v, out, softmax_lse) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.group = group + out = recover_output(out, process_group=group) + return out if not return_softmax else (out, softmax_lse, None) + + @staticmethod + def backward(ctx, dout, *args): # pragma: no cover + dout = shuffle_input(to_send=dout, process_group=ctx.group) + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = zigzag_ring_flash_attn_backward( + ctx.group, + dout, + q, + k, + v, + out, + softmax_lse, + softmax_scale=ctx.softmax_scale, + dropout_p=ctx.dropout_p, + causal=ctx.causal, + window_size=ctx.window_size, + alibi_slopes=ctx.alibi_slopes, + deterministic=ctx.deterministic, + ) + dq = recover_output(dq, ctx.group) + dk = recover_output(dk, ctx.group) + dv = recover_output(dv, ctx.group) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def zigzag_ring_flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) + + +def zigzag_ring_flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + group=None, +): + return ZigZagRingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + group, + ) diff --git a/nnscaler/customized_ops/ring_attention/ring_attn.py b/nnscaler/customized_ops/ring_attention/ring_attn.py new file mode 100644 index 00000000..e7a8a4b8 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/ring_attn.py @@ -0,0 +1,113 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Tuple, List, Dict +from torch import Tensor + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir.operator import IRFwOperation +from .core.ring_attn_implementation import RingFlashAttnFunc +from .core.utils import gen_head_anno +from flash_attn import flash_attn_func + +from nnscaler.runtime.device import DeviceGroup + + +def wrap_ring_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=None, + dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None) -> Tensor: + ''' + wrap the ring_attn_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + + assert alibi_slopes is None, "alibi_slopes is not supported in ring_attn_func" + assert return_attn_probs is False, "return_attn_probs is not supported in ring_attn_func" + + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, softmax_scale=softmax_scale, causal=causal, window_size=window_size,) + return output + + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + + output = RingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ) + + return output + + +def emit_ring(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate ring_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + + +def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states) + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + + +register_op(flash_attention_anno, emit_fn=emit_ring)(wrap_ring_attn_func) diff --git a/nnscaler/customized_ops/ring_attention/ring_attn_varlen.py b/nnscaler/customized_ops/ring_attention/ring_attn_varlen.py new file mode 100644 index 00000000..bb9ff54b --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/ring_attn_varlen.py @@ -0,0 +1,308 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Tuple, List, Dict, Optional +import torch +from torch import Tensor +import torch.distributed as dist +import warnings + +from nnscaler.graph.parser.register import register_op +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir import IRTensor +from nnscaler.runtime.device import DeviceGroup +from flash_attn import flash_attn_varlen_func +from .core.ring_attn_varlen_implementation import llama3_flash_attn_prepare_cu_seqlens, llama3_flash_attn_varlen_func +from .core.utils import gen_head_anno +from .varlen_utils import shuffle_varlen, unshuffle_varlen + +# Try to import TransformerEngine with version check +_HAS_TRANSFORMER_ENGINE = False +_TE_VERSION_OK = False +attn_forward_func_with_cp = None + +try: + import transformer_engine + _HAS_TRANSFORMER_ENGINE = True + + # Check version - require 2.2.0+ + try: + from packaging import version + te_version = version.parse(transformer_engine.__version__) + required_version = version.parse("2.2.0") + _TE_VERSION_OK = te_version >= required_version + + if _TE_VERSION_OK: + # Try different import paths for different versions + try: + # For v2.5.0+ + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import attn_forward_func_with_cp + except ImportError: + try: + # For v2.2.0-v2.4.x + from transformer_engine.pytorch.attention import attn_forward_func_with_cp + except ImportError: + warnings.warn( + "TransformerEngine attention module not available or incompatible. " + "Falling back to basic ring attention implementation." + ) + else: + warnings.warn( + f"TransformerEngine version {transformer_engine.__version__} is too old. " + f"Require 2.2.0+. Falling back to basic ring attention implementation." + ) + except ImportError: + # packaging not available, try to import anyway + try: + # Try different import paths for different versions + try: + # For v2.5.0+ + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import attn_forward_func_with_cp + except ImportError: + # For v2.2.0-v2.4.x + from transformer_engine.pytorch.attention import attn_forward_func_with_cp + _TE_VERSION_OK = True + except (ImportError, AttributeError): + warnings.warn( + "TransformerEngine attention module not available or incompatible. " + "Falling back to basic ring attention implementation." + ) + +except ImportError: + warnings.warn( + "TransformerEngine not found. Falling back to basic ring attention implementation. " + "For better performance with context parallelism, install TransformerEngine 2.2.0+." + ) + + +def get_transformer_engine_info() -> Dict[str, any]: + """Get information about TransformerEngine availability and version.""" + return { + "has_transformer_engine": _HAS_TRANSFORMER_ENGINE, + "version_ok": _TE_VERSION_OK, + "has_cp_function": attn_forward_func_with_cp is not None, + "version": getattr(transformer_engine, "__version__", None) if _HAS_TRANSFORMER_ENGINE else None, + "required_version": "2.2.0+", + } + + +def print_transformer_engine_status(): + """Print TransformerEngine status for debugging.""" + info = get_transformer_engine_info() + print("TransformerEngine Status:") + print(f" - Available: {info['has_transformer_engine']}") + if info['has_transformer_engine']: + print(f" - Version: {info['version']}") + print(f" - Version OK (>= 2.2.0): {info['version_ok']}") + print(f" - CP Function Available: {info['has_cp_function']}") + else: + print(f" - Required Version: {info['required_version']}") + print(f" - Will use TE CP: {info['has_transformer_engine'] and info['version_ok'] and info['has_cp_function']}") + + +def wrap_ring_attn_varlen_func( + q: Tensor, + k: Tensor, + v: Tensor, + cu_seqlens_q: Tensor, + cu_seqlens_k: Tensor, + alibi_slopes: Tensor, + dropout_p: float = 0.0, + softmax_scale: Tensor = None, + causal: bool = False, + window_size: Tuple[int] = (-1, -1), + deterministic: bool = False, + return_attn_probs: bool = False, + process_group: Tuple[int] = None, +): + ''' + wrap the ring_attn_varlen_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_varlen_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + assert not return_attn_probs, "return_attn_probs is not supported in ring-attention" + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + + if process_group is None or len(process_group) == 1: + output = flash_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=False, + ) + return output + + assert len(q.shape) == 3, "q must have shape [total_q, qh, dim]" + assert len(k.shape) == 3, "k must have shape [total_k, kh, dim]" + assert len(v.shape) == 3, "v must have shape [total_k, vh, dim]" + total_q, qheads, qdim = q.shape + total_k, kheads, kdim = k.shape + total_v, vheads, vdim = v.shape + assert total_q == total_k == total_v, "total_q, total_k and total_v must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + local_rank = dist.get_rank(local_process_group) + local_world_size = dist.get_world_size(local_process_group) + assert local_world_size == len(process_group), "local_world_size should be the same with process_group size" + + if local_process_group is None: + local_process_group = dist.group.WORLD + + if window_size == (-1, -1): + # Use TransformerEngine with context parallelism if available and version is OK + if _HAS_TRANSFORMER_ENGINE and _TE_VERSION_OK and attn_forward_func_with_cp is not None: + shuffled_q = shuffle_varlen(q, cu_seqlens_q, process_group, local_process_group) + shuffled_k = shuffle_varlen(k, cu_seqlens_k, process_group, local_process_group) + shuffled_v = shuffle_varlen(v, cu_seqlens_k, process_group, local_process_group) + + te_cu_seqlens_q = cu_seqlens_q.clone() + te_cu_seqlens_k = cu_seqlens_k.clone() + te_cu_seqlens_q = torch.cat( + [ + te_cu_seqlens_q, + torch.tensor([cu_seqlens_q[-1].item()], dtype=te_cu_seqlens_q.dtype, device=te_cu_seqlens_q.device) + ] + ) + te_cu_seqlens_k = torch.cat( + [ + te_cu_seqlens_k, + torch.tensor([cu_seqlens_k[-1].item()], dtype=te_cu_seqlens_k.dtype, device=te_cu_seqlens_k.device) + ] + ) + shuffled_output = attn_forward_func_with_cp( + True, + shuffled_q, + shuffled_k, + shuffled_v, + te_cu_seqlens_q, + te_cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + te_cu_seqlens_q, + te_cu_seqlens_k, + dropout_p, + local_process_group, + process_group, + # TODO: optimize the stream usage + torch.cuda.current_stream(), + "p2p", # "all_gather" version cannot work with thd format + qkv_format="thd", + attn_mask_type="padding_causal" if causal else "padding", + ) + output = unshuffle_varlen(shuffled_output, cu_seqlens_q, process_group, local_process_group) + return output + else: + # Fallback to basic ring attention implementation + warnings.warn( + "TransformerEngine not available or version incompatible. " + "Using basic ring attention implementation which may be slower." + ) + + ( + local_cu_seqlens_q, + local_cu_seqlens_k, + local_max_seqlen_q, + local_max_seqlen_k, + local_k_slice, + ) = llama3_flash_attn_prepare_cu_seqlens( + cu_seqlens_q, + causal=causal, + rank=local_rank, + world_size=local_world_size, + ) + + output = llama3_flash_attn_varlen_func( + q, + k, + v, + local_cu_seqlens_q, + local_cu_seqlens_k, + local_max_seqlen_q, + local_max_seqlen_k, + heads_k_stride=1, + local_k_slice=local_k_slice, + dropout_p=dropout_p, + causal=causal, + window_size=window_size, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=False, + group=local_process_group, + ) + + return output + + +def emit_ring(node: IRDimops, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate ring_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + remainder = runtime_devid % plan_ndevs + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [(i, f // s) for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + if partition_dims[0][0] == 0: # partition on sequence dim + # the synchronization should occur across scaleunits + num = partition_dims[0][1] + scale_unit_dev_ids = [local_rank + offset for local_rank in range(remainder // num * num, (remainder // num + 1) * num)] + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0][0] == 1: + # partition the head dim, use local flash_attn_func + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + + +def flash_attention_anno(query_states, key_states, value_states, cu_seqlens_q, cu_seqlens_k, alibi_slopes, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states, head_pos=1) + if isinstance(alibi_slopes, IRTensor): + return f'l {q_anno} hd^, l {kv_anno} hd^, l {kv_anno} vd^, e^, e^, {q_anno} -> l {q_anno} vd^' + else: + return f'l {q_anno} hd^, l {kv_anno} hd^, l {kv_anno} vd^, e^, e^, ? -> l {q_anno} vd^' + + +def input_gen_fn(node: IRDimops): + inputs = [] + device = torch.cuda.current_device() + seqlen = node.inputs()[0].shape[0] + for i, t in enumerate(node.inputs()): + if i < 3: # query, key, value + inputs.append(torch.randn(t.shape, dtype=t.dtype, device=device, requires_grad=t.requires_grad)) + elif i in [3, 4]: # cu_seqlens + inputs.append(torch.Tensor([0, seqlen]).to(torch.int32).to(device)) + elif i == 5: # optional alibi_slopes + if isinstance(t, IRTensor): + inputs.append(torch.randn(t.shape, dtype=t.dtype, device=device, requires_grad=t.requires_grad)) + else: + inputs.append(None) + else: # other kwargs, use defaults + break + return tuple(inputs) + + +register_op(flash_attention_anno, emit_fn=emit_ring, input_gen_fn=input_gen_fn)(wrap_ring_attn_varlen_func) diff --git a/nnscaler/customized_ops/ring_attention/varlen_utils.py b/nnscaler/customized_ops/ring_attention/varlen_utils.py new file mode 100644 index 00000000..bdd1f127 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/varlen_utils.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Utilities for variable-length sequence processing in ring attention. +Contains shuffle and unshuffle functions for context parallel processing. +""" + +from typing import List +import torch +from torch import Tensor +import torch.distributed as dist +from nnscaler.runtime.adapter.nn import allgather_reducescatter + + +def shuffle_varlen(t: Tensor, cu_seqlens_padded: Tensor, cp_ranks: List[int], cp_group: dist.ProcessGroup) -> Tensor: + """ + Shuffle tensor data for variable-length sequences in context parallel processing. + + Args: + t: Input tensor to shuffle (local portion from each rank) + cu_seqlens_padded: Cumulative sequence lengths (global) + cp_ranks: List of ranks in the context parallel group + cp_group: Process group for context parallel communication + + Returns: + Shuffled tensor + """ + # Get context parallel size and rank + cp_size = torch.distributed.get_world_size(group=cp_group) + assert cp_size > 1, "cp_size should be greater than 1" + cp_rank = torch.distributed.get_rank(group=cp_group) + + # Calculate the chunk sizes for each sequence + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence + + # Process each tensor directly instead of using keys_to_change loop + def process_tensor(val): + if val is None: + return val + # Determine which dimension is the sequence dimension + # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + # Handle 1D tensors (like position_ids that don't have batch dimension) + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." + ) + else: + raise ValueError("Tensor must be at least 1D") + + # On this particular rank, for each sequence, get two slices, one from the beginning + # and one from the end. + cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + full_tensor = allgather_reducescatter(t, 0, cp_ranks) + return process_tensor(full_tensor) + + +def unshuffle_varlen(t: Tensor, cu_seqlens_padded: Tensor, cp_ranks: List[int], cp_group: dist.ProcessGroup) -> Tensor: + """ + Unshuffle tensor data to restore original variable-length sequence order. + This is the reverse operation of shuffle_varlen. + + Args: + t: Shuffled tensor to unshuffle (local portion from each rank) + cu_seqlens_padded: Cumulative sequence lengths (global) + cp_ranks: List of ranks in the context parallel group + cp_group: Process group for context parallel communication + + Returns: + Unshuffled tensor (local portion for each rank) + """ + # reverse operation of shuffle_varlen + cp_size = torch.distributed.get_world_size(group=cp_group) + assert cp_size > 1, "cp_size should be greater than 1" + cp_rank = torch.distributed.get_rank(group=cp_group) + total_slices_of_any_sequence = 2 * cp_size + slice_sizes = ( + cu_seqlens_padded[1:] - cu_seqlens_padded[:-1] + ) // total_slices_of_any_sequence + sum_len = cu_seqlens_padded[-1].item() + + def process_tensor(val): + if val is None: + return val + if isinstance(cu_seqlens_padded[-1], torch.Tensor): + seq_len_val = cu_seqlens_padded[-1].item() + else: + seq_len_val = cu_seqlens_padded[-1] + + if val.ndim == 1: + if val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "1D tensor shape doesn't match expected sequence length. Make sure the" + " inputs are in THD format and padded correctly." + ) + elif val.ndim >= 2: + if val.shape[1] == seq_len_val: + current_seq_dim = 1 + elif val.shape[0] == seq_len_val: + current_seq_dim = 0 + else: + raise ValueError( + "Make sure the inputs are in THD format and padded correctly." + ) + else: + raise ValueError("Tensor must be at least 1D") + + cp_rank_slices = [] + for rank in range(cp_size): + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (rank * slice_size), + seq_start + ((rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - rank) * slice_size), + device=val.device, + ) + ) + perm = torch.cat(cp_rank_slices) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(sum_len, device=val.device) + + # Create a tensor to hold the unshuffled result + unshuffled = val.index_select(current_seq_dim, inv_perm) + local_tensor = torch.chunk(unshuffled, cp_size, dim=current_seq_dim)[cp_rank] + return local_tensor + + full_tensor = allgather_reducescatter(t, 0, cp_ranks) + return process_tensor(full_tensor) diff --git a/nnscaler/customized_ops/ring_attention/zigzag_attn.py b/nnscaler/customized_ops/ring_attention/zigzag_attn.py new file mode 100644 index 00000000..2373f9d0 --- /dev/null +++ b/nnscaler/customized_ops/ring_attention/zigzag_attn.py @@ -0,0 +1,117 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Tuple, List, Dict +import torch +from torch import Tensor +import torch.distributed + +from nnscaler.graph.parser.register import register_op +from nnscaler.ir.operator import IRFwOperation +from .core.zigzag_attn_implementation import ZigZagRingFlashAttnFunc +from .core.utils import gen_head_anno +from flash_attn import flash_attn_func + +import torch.distributed as dist +from nnscaler.runtime.device import DeviceGroup + + +def wrap_zigzag_attn_func(q: Tensor, k: Tensor, v: Tensor, softmax_scale: Tensor=None, + dropout_p: float=0.0, causal: bool=True, window_size: Tuple[int]=(-1, -1), + alibi_slopes: Tensor=None, deterministic: bool=False, + return_attn_probs: bool=False, + process_group: Tuple[int]=None) -> Tensor: + ''' + wrap the zigzag_attn_func to support the distributed training in nnScaler. + most of the arguments are the same as the original flash_attn_func. + `process_group` should be none in the user code since nnScaler accepts the + program defined for the single device and will automatically generate the + required communications. + ''' + + assert window_size == (-1, -1), "window_size is not supported in zigzag-attention" + assert not return_attn_probs, "return_attn_probs is not supported in zigzag-attention" + assert alibi_slopes is None, "alibi_slopes is not supported in zigzag-attention" + + if process_group is None or len(process_group) == 1: + # there is an additional checker for the `softmax_scale`, which is equivalent + # to the behavior of the original flash_attn_func. + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + output = flash_attn_func(q, k, v, 0.0, softmax_scale, causal) + return output + + assert causal == True, "zigzag_ring is meaningless for causal=False" + assert len(q.shape) == 4, "q must have shape [bs, ql, qh, dim]" + assert len(k.shape) == 4, "k must have shape [bs, kl, kh, dim]" + assert len(v.shape) == 4, "v must have shape [bs, vl, vh, dim]" + qbsz, qlen, qheads, qdim = q.shape + kbsz, klen, kheads, kdim = k.shape + vbsz, vlen, vheads, vdim = v.shape + assert qbsz == kbsz == vbsz, "batch size must be the same" + assert qlen == klen == vlen, "sequence length must be the same" + assert kheads == vheads, "number of k and v heads must be the same" + assert qheads % kheads == 0, "number of q heads must be a multiple of k heads" + assert qdim == kdim == vdim, "dimension must be the same" + + local_process_group = DeviceGroup().get_group(process_group) + + output = ZigZagRingFlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + alibi_slopes, + deterministic, + return_attn_probs, + local_process_group, + ).contiguous() + + return output + +def emit_zigzag(node: IRFwOperation, args: List[str], kwargs: Dict[str, str], runtime_devid: int, plan_ndevs: int, runtime_ndevs: int) -> str: + """Special rule to generate zigzag_attn node""" + + signature = node.signature + + offset = (runtime_devid // plan_ndevs) * plan_ndevs + scale_unit_dev_ids = [local_rank + offset for local_rank in range(plan_ndevs)] + + kw_pairs = list() + for key, val in kwargs.items(): + code = f'{key}={val}' + kw_pairs.append(code) + + sub_input = node.inputs()[0] + full_input = sub_input.parent + partition_dims = [i for i, (s, f) in enumerate(zip(sub_input.shape, full_input.shape)) if s != f] + assert len(partition_dims) <= 1, f"support no more than one partition dim, but got {partition_dims}" + if not partition_dims: + kw_pairs.append("process_group=None") + else: + # if the 'process_group' is None, we will use the local attention (flash_attn_func) + if partition_dims[0] == 0: # partition on batch dim + # partition the bsz dim, use local flash_attn_func + kw_pairs.append("process_group=None") + elif partition_dims[0] == 1: # partition on sequence dim + # the synchronization should occur across scaleunits + kw_pairs.append(f"process_group={scale_unit_dev_ids}") + elif partition_dims[0] == 2: + # partition on num_head dim + kw_pairs.append("process_group=None") + else: + raise ValueError(f'unsupported partition dim: {partition_dims[0]}') + + args = ", ".join(list(args) + kw_pairs) + return f"{signature}({args})" + + +def flash_attention_anno(query_states, key_states, value_states, *args, **kwargs) -> str: + q_anno, kv_anno = gen_head_anno(query_states, key_states, value_states) + return f'b l {q_anno} hd^, b l {kv_anno} hd^, b l {kv_anno} vd^ -> b l {q_anno} vd^' + + +register_op(flash_attention_anno, emit_fn=emit_zigzag)(wrap_zigzag_attn_func) diff --git a/nnscaler/flags.py b/nnscaler/flags.py index 77333987..af903b91 100644 --- a/nnscaler/flags.py +++ b/nnscaler/flags.py @@ -47,7 +47,7 @@ class CompileFlag: # use zero optimization on optimizer status. # to cooperate with zero, user needs to call `model.parameters_for_optimizer()` # to get parameters for optimizer, and `model.gather_params()` after `optimizer.step()` - use_zero = _to_bool('USE_ZERO') + use_zero = _to_int('USE_ZERO') # use async communication to overlap gradient synchronization and backward computation async_reducer = _to_bool('ASYNC_REDUCER') # use async reducer # maximal reducer weight bytes for one allreduce (only effective for async): diff --git a/nnscaler/graph/function/dimops.py b/nnscaler/graph/function/dimops.py index 333b01ad..827c3ed2 100644 --- a/nnscaler/graph/function/dimops.py +++ b/nnscaler/graph/function/dimops.py @@ -72,7 +72,7 @@ import logging from itertools import dropwhile -from nnscaler.ir.cten import IRTensor, IRObject +from nnscaler.ir.cten import IRTensor, IRObject, ValueTrack from nnscaler.ir.operator import IRFwOperation @@ -753,7 +753,7 @@ def ianno(self, index: int) -> ShapeAnno: @return dim_annos ShapeAnno: a tuple that each element is a dimension annotation """ assert index < len(self.inputs()), "index out of boudary" - return tuple(self._iannos[index]) + return self._iannos[index] def oanno(self, index: int) -> ShapeAnno: """! @@ -853,7 +853,7 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict op_anno.reset_identifiers() identifier_values: Dict[str, int] = dict() - for ashape, itensor in zip(op_anno.inputs(), inputs): + for idx, (ashape, itensor) in enumerate(zip(op_anno.inputs(), inputs)): if not isinstance(itensor, IRTensor) or ashape.ignore: continue if ashape.ndims != len(itensor.shape): @@ -861,7 +861,12 @@ def align(self, signature, inputs: List[IRTensor], op_anno: OpAnno, kwargs: Dict for adim, dimlen in zip(ashape.dims, itensor.shape): if len(adim.identifiers) == 1: if adim.identifiers[0] in identifier_values and identifier_values[adim.identifiers[0]] != dimlen: - raise RuntimeError(f'the exist identifier value {identifier_values[adim.identifiers[0]]} is not equal to the new value {dimlen}') + error_msg = ( + f"at {signature} with {op_anno} the exist identifier {adim.identifiers[0]} value " + f"{identifier_values[adim.identifiers[0]]} is not equal to the new value {dimlen}, " + f"error idx {idx}, input tensors {inputs}" + ) + raise RuntimeError(error_msg) identifier_values[adim.identifiers[0]] = dimlen # check dimension consistency diff --git a/nnscaler/graph/function/function.py b/nnscaler/graph/function/function.py index c083ec9c..f6af51ae 100644 --- a/nnscaler/graph/function/function.py +++ b/nnscaler/graph/function/function.py @@ -34,7 +34,7 @@ import logging from collections.abc import Iterable -from nnscaler.ir.cten import IRTensor, IRObject, IR +from nnscaler.ir.cten import IRTensor, IRObject, IR, ValueTrack from nnscaler.ir.tensor import IRSubTensor, IRFullTensor from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.dimops import DimopSplit, ShapeAnno, OpAnno, IRDimops, TransformRule @@ -150,6 +150,16 @@ def Accum(*inputs, signature = None): return IRDimops(Cat, 'accum', signature, [anno], inputs) +def Dot(input, tensor, *, out=None, signature = None): + """ + torch.dot(input, tensor, *, out=None) -> Tensor + """ + assert out is None + signature = 'torch.dot' + annos = ['k+, k+ -> 1',] + return IRDimops(Dot, 'dot', signature, annos, [input, tensor]) + + def Linear(input, weight, bias=None, signature = None): signature = 'torch.nn.functional.linear' assert isinstance(input, IRTensor) and isinstance(weight, IRTensor) @@ -195,6 +205,7 @@ def CubeEinSum(*operands, equation=None, signature = None): anno = f'{lhs} -> {rhs}' return IRDimops(CubeEinSum, 'einsum', signature, [anno], operands, equation=equation) + def EinSum(equation: str, *operands, signature = None): return CubeEinSum(*operands, equation=equation, signature=signature) @@ -259,7 +270,21 @@ def CubeArange(start: Union[int, IRObject], end: Union[int, IRObject], step: Uni size = (math.ceil((end_val-start_val)/step_val),) anno, rules = _get_creator_anno_rules( tuple(dim.value if isinstance(dim, IRObject) else dim for dim in size), False) - return IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) + + # Output will be replaced in Parser, + # Here we just pass the value tracks out + output = IRFullTensor(size) + if not isinstance(start, IRObject) and start == 0 \ + and not isinstance(step, IRObject) and step == 1 \ + and isinstance(end, IRObject): + # a special case for arange(0, end), which is very common in practice + # we can directly use end's value track + output.dim_tracks = [end.value_track] + else: + output.dim_tracks = [ValueTrack.new([start, end, step])] + ret = IRDimops(CubeArange, 'arange', signature, [anno], [], rules, **kwargs) + ret.set_output(0, output) + return ret def Arange(*args, start=None, end=None, step=None, out=None, dtype=None, layout=None, @@ -352,13 +377,38 @@ def creation_function_size_check(op_name, size, *arg_size) -> Tuple[Union[int, I raise ValueError(f"get illegal input size={size}, arg_size={arg_size} in {op_name}") # convert scalar to shape (1,) tensor, nnscaler don't support empty shape [] now. if len(size_val) == 0: - _logger.warn(f"detect tensor creation function {op_name} create a scalar, force it to create a shape [1] tensor instead") + _logger.warning(f"detect tensor creation function {op_name} create a scalar, force it to create a shape [1] tensor instead") size = (1,) else: raise ValueError(f"get unknown input type size={size} in {op_name}") return size +def creation_function_dim_track(resolved_size: Union[IRObject, tuple[Union[int, IRObject]]]) -> list[ValueTrack]: + if isinstance(resolved_size, IRObject): + assert isinstance(resolved_size.value, (tuple, list)) + # all dims dependent on resolved_size + return [ValueTrack.new([resolved_size]) for _ in resolved_size.value] + + dim_tracks = [] + for dim in resolved_size: + if isinstance(dim, IRObject): + dim_tracks.append(ValueTrack.new([dim])) + else: + # no dim dependency when dim is not IRObject + dim_tracks.append(ValueTrack.new([])) + return dim_tracks + + +def creation_function_set_dim_tracks(op: IRDimops, resolved_size: Union[IRObject, tuple[Union[int, IRObject]]]) -> IRDimops: + # Output will be replaced in Parser, + # Here we just pass the value tracks out + output = IRFullTensor(_unwrap_value(resolved_size)) + output.dim_tracks = creation_function_dim_track(resolved_size) + op.set_output(0, output) + return op + + def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False, memory_format=None, signature=None): """ @@ -374,7 +424,10 @@ def Empty(size, *arg_size, out=None, dtype=None, layout=None, device=None, requi kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Empty, 'empty', signature, [anno], [], rules, **kwargs), + size + ) def Zeros(size, *arg_size, out=None, dtype=None, layout=None, @@ -390,7 +443,10 @@ def Zeros(size, *arg_size, out=None, dtype=None, layout=None, size = creation_function_size_check('torch.zeros', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Zeros, 'zeros', signature, [anno], [], rules, **kwargs), + size + ) def Ones(size, *arg_size, out=None, dtype=None, layout=None, @@ -406,7 +462,10 @@ def Ones(size, *arg_size, out=None, dtype=None, layout=None, size = creation_function_size_check('torch.ones', size, *arg_size) kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Ones, 'ones', signature, [anno], [], rules, **kwargs), + size + ) def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requires_grad=False, @@ -424,7 +483,10 @@ def Rand(size, *arg_size, out=None, dtype=None, layout=None, device=None, requir kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Rand, 'rand', signature, [anno], [], rules, **kwargs), + size + ) def Randn(size, *arg_size, generator=None, out=None, dtype=None, layout=None, device=None, requires_grad=False, @@ -442,7 +504,10 @@ def Randn(size, *arg_size, generator=None, out=None, dtype=None, layout=None, de kwargs = {'size': size, 'requires_grad': requires_grad, 'dtype': dtype, 'pin_memory': pin_memory} anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Randn, 'randn', signature, [anno], [], rules, **kwargs) + return creation_function_set_dim_tracks( + IRDimops(Randn, 'randn', signature, [anno], [], rules, **kwargs), + size + ) def Full(size, fill_value, *, out=None, dtype=None, layout=None, @@ -457,8 +522,11 @@ def Full(size, fill_value, *, out=None, dtype=None, layout=None, signature = 'nnscaler.runtime.function.full' size = creation_function_size_check('torch.full', size) anno, rules = _get_creator_anno_rules(_unwrap_value(size), True) - return IRDimops(Full, 'full', signature, [anno], [], rules, - size=size, fill_value=fill_value, dtype=dtype, requires_grad=requires_grad) + return creation_function_set_dim_tracks( + IRDimops(Full, 'full', signature, [anno], [], rules, + size=size, fill_value=fill_value, dtype=dtype, requires_grad=requires_grad), + size + ) def NewTensor(data, *, dtype=None, device=None, @@ -492,6 +560,22 @@ def NewTensor(data, *, dtype=None, device=None, return IRDimops(NewTensor, 'tensor', signature, [anno], [], rules, **kwargs) +def Eye(n: int, m: Optional[int] = None, *, dtype=None, device=None, + requires_grad=False, signature=None): + """ + torch.eye(n, m=None, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) โ†’ Tensor + """ + dtype = dtype if dtype is not None else torch.get_default_dtype() + creation_function_args_check('torch.eye', dtype=dtype, device=device) + + signature = 'nnscaler.runtime.function.eye' + if m is None: + m = n + kwargs = {'n': n, 'm': m, 'requires_grad': requires_grad, 'dtype': dtype} + anno, rules = _get_creator_anno_rules((_unwrap_value(n), _unwrap_value(m)), False) + return IRDimops(Eye, 'eye', signature, [anno], [], rules, **kwargs) + + def _handle_broadcast(lhs: IRTensor, rhs: IRTensor) -> Tuple[List[str]]: """Create shape annotations for element wise operator following broadcastable rules: https://pytorch.org/docs/stable/notes/broadcasting.html @@ -1809,14 +1893,18 @@ def CubeStack(*tensors, dim=0, signature=None): assert all(isinstance(tensor, IRTensor) for tensor in tensors), f'but got {tensors}' assert isinstance(dim, int), f"but not {dim}" signature = 'nnscaler.runtime.function.stack' - iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] - oanno = [None for i in range(len(tensors[0].shape) + 1)] - oanno[dim] = f'{len(tensors)}^' - offset = 0 - for i in range(len(oanno)): - if oanno[i] is None: - oanno[i] = copy.copy(iannos[-1][offset]) - offset += 1 + if tensors[0].is_scalar_tensor(): + iannos = ['1' for _ in tensors] + oanno = [f'{len(tensors)}'] + else: + iannos = [ShapeAnno.create_shape_str(t.shape) for t in tensors] + oanno = [None for i in range(len(tensors[0].shape) + 1)] + oanno[dim] = f'{len(tensors)}' + offset = 0 + for i in range(len(oanno)): + if oanno[i] is None: + oanno[i] = copy.copy(iannos[-1][offset]) + offset += 1 anno = OpAnno.create_op_str(iannos, [oanno]) return IRDimops(CubeStack, 'stack', signature, [anno], tensors, dim=dim) @@ -1834,7 +1922,7 @@ def Stack(tensors, dim=0, out=None, signature = None): return CubeStack(*tensors, dim=dim, signature=signature) -def Chunk(input, chunks, dim=0, signature = None): +def Chunk(input: IRTensor, chunks, dim=0, signature = None): """ torch.chunk(input, chunks, dim=0) """ @@ -1845,7 +1933,18 @@ def Chunk(input, chunks, dim=0, signature = None): for oanno in oannos: oanno[dim] = str(input.shape[dim] // chunks) anno = OpAnno.create_op_str(iannos, oannos) - return IRDimops(Chunk, 'chunk', signature, [anno], [input], chunks=chunks, dim=dim) + ret = IRDimops(Chunk, 'chunk', signature, [anno], [input], chunks=chunks, dim=dim) + + # set proper value tracks for outputs + output_shape = list(input.shape) + output_shape[dim] = input.shape[dim] // chunks + dim_vt = ValueTrack.new([chunks, input.dim_tracks[dim]]) + for d in range(chunks): + output = IRFullTensor(output_shape) + output.set_dim_track(dim, dim_vt) + ret.set_output(d, output) + + return ret def Select(input, dim, index, signature = None): @@ -2340,12 +2439,15 @@ def Size(tensor, dim=None, signature = None) -> Union[List[int], IRPyFunc]: torch.Tensor.size(tensor, dim=None) """ assert isinstance(tensor, IRTensor) - val = tensor.shape[dim] if isinstance(dim, int) else tensor.shape - assert val is not None + if isinstance(dim, int): + val = IRObject(name='size', value=tensor.shape[dim], value_track=tensor.dim_tracks[dim]) + else: + val = tuple(IRObject('size', value=s, value_track=t) for s, t in zip(tensor.shape, tensor.dim_tracks)) + if dim is None: - return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)]) + return IRPyFunc(signature, [tensor], [val]) else: - return IRPyFunc(signature, [tensor], [IRObject(name='size', value=val)], dim=dim) + return IRPyFunc(signature, [tensor], [val], dim=dim) def Dim(tensor, signature=None) -> Union[List[int], IRPyFunc]: @@ -2602,7 +2704,7 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], if isinstance(obj, IRTensor): if name == 'shape': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" - shape = IRObject('shape', value=obj.shape) + shape = tuple(IRObject('shape', value=s, value_track=t) for s, t in zip(obj.shape, obj.dim_tracks)) return IRPyFunc(signature, [instance, field], [shape]) if name == 'dtype': assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" @@ -2616,6 +2718,10 @@ def GetAttr(instance: object, field: str, signature = None) -> Union[List[int], assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" _logger.warning("getattr of 'layout' will always return torch.strided") return torch.strided + if name == 'T': + assert isinstance(obj, IRFullTensor), f"type {type(obj)} is not supported" + assert len(obj.shape) == 2, "only 2-dim tensor support .T operation" + return Transpose(obj, 0, 1, signature='torch.transpose') if isinstance(obj, torch.finfo): return getattr(obj, name) return IRPyFunc(signature, [instance, field], [IRObject.missing]) @@ -3391,10 +3497,14 @@ def Item(input, signature = None): """ torch.Tensor.item() """ - # set output to IRObject.missing, + # set output value to IRObject.missing_value, # because the output is unknown here. # It will be filled with real value in parser. - return IRPyFunc(signature, inputs=[input], outputs=[IRObject.missing], constant_foldable=False) + return IRPyFunc( + signature, inputs=[input], + outputs=[IRObject('item', value=IRObject.missing_value, is_constant=False)], + constant_foldable=False + ) def DictKeys(o: Union[Dict, IRObject], signature=None): @@ -3426,7 +3536,7 @@ def DictValues(o: Union[Dict, IRObject], signature=None): def DictItems(o: Union[Dict, IRObject], signature=None): - signature = 'nnscaler.runtime.function.dict_values' + signature = 'nnscaler.runtime.function.dict_items' if not isinstance(o, dict) and not (isinstance(o, IRObject) and isinstance(o.value, dict)): raise ValueError(f'the input should be a dict or an IRObject with dict value, but get {o}') diff --git a/nnscaler/graph/graph.py b/nnscaler/graph/graph.py index 550b21f1..6cf7372f 100644 --- a/nnscaler/graph/graph.py +++ b/nnscaler/graph/graph.py @@ -65,11 +65,10 @@ def __call__(self, *args): """ return self.forward(*args) - def forward(self, *args: Tuple[IRObject]) -> Union[IRTensor, Tuple[IRTensor]]: + def forward(self, *args: IRObject) -> Union[IRTensor, Tuple[IRTensor]]: """Forward the IRGraph to add model nodes into program. - Args: - args (Tuple[IRObject]): input IRObjects + args (Tuple[IRObject, ...]): input IRObjects Returns: Any: output that can be nested structure of IRObjects @@ -288,6 +287,7 @@ def use_dataloader_input(self): # IRDataOperation. Since we already know the output of the dataloader, # we don't need to set the value for it. ir_root_obj = IRObject(name='dataloader', value=None, is_constant=False) + ir_root_obj.value_track.with_no_dep() data_op = IRDataOperation(ir_root_obj, self.inputs()) # add the data operation to the graph, which will use `next` to get data. self.insert(data_op, 0) @@ -1212,7 +1212,8 @@ def checksum(self, strict: bool = True) -> str: def copy_node_meta_info(src_node: Union[IRFwOperation, IRDataOperation], dest_node: Union[IRFwOperation, IRDataOperation]): """ Copy meta information from src_node to dest_node. - Current copy fields: ['recompute', 'comment', 'op_context', 'module_stack', 'device'] + Current copy fields: ['recompute', 'comment', 'op_context', 'module_stack', 'device', + 'hook_meta', 'pre_hook', 'post_hook'] """ if isinstance(src_node, IRFwOperation): dest_node.recompute = src_node.recompute @@ -1222,3 +1223,6 @@ def copy_node_meta_info(src_node: Union[IRFwOperation, IRDataOperation], dest_no dest_node.op_context = src_node.op_context dest_node.module_stack = src_node.module_stack dest_node.device = src_node.device + dest_node.hook_meta = src_node.hook_meta + dest_node.pre_hook = src_node.pre_hook + dest_node.post_hook = src_node.post_hook diff --git a/nnscaler/graph/parser/__init__.py b/nnscaler/graph/parser/__init__.py index 1dea36e7..e7fa0900 100644 --- a/nnscaler/graph/parser/__init__.py +++ b/nnscaler/graph/parser/__init__.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from nnscaler.graph.parser.parser import FxModuleParser +from nnscaler.graph.parser.parser import FxModuleParser, parse_fx_module from nnscaler.graph.parser.converter import convert_model, to_fx_graph, to_ir_graph from nnscaler.graph.parser.register import register from nnscaler.graph.parser.external import * diff --git a/nnscaler/graph/parser/converter.py b/nnscaler/graph/parser/converter.py index a30dfa23..ae338b25 100644 --- a/nnscaler/graph/parser/converter.py +++ b/nnscaler/graph/parser/converter.py @@ -12,7 +12,7 @@ from nnscaler.graph import IRGraph from nnscaler.flags import CompileFlag -from nnscaler.graph.parser import FxModuleParser +from nnscaler.graph.parser import parse_fx_module from nnscaler.graph.tracer import concrete_trace from nnscaler.graph.tracer.wrap_utils import Location, is_autograd_apply, LeafWrapInfo from nnscaler.graph.tracer.torch_fx_patcher import side_effectful_inplace_ops @@ -30,8 +30,11 @@ class no_save_tensor_hook(saved_tensors_hooks): """skip saving tensors for backward since tracer only traces forward""" def __init__(self): def pack(x): - return None + return (x.shape, x.dtype, x.device) def unpack(x): + # in pytorch 2.4.0-, torch.compile will call backward when tracing graph + if torch.__version__ < (2, 4, 0): + return torch.empty(x[0], dtype=x[1], device=x[2]) raise RuntimeError("not expecting backward to be called on this tensor") super().__init__(pack, unpack) @@ -146,7 +149,7 @@ def to_ir_graph( _logger.info(f"constant folding {'enabled' if constant_folding else 'disabled'} to parse graph") with no_save_tensor_hook(): - inputs, nodes, outputs = FxModuleParser.parse( + inputs, nodes, outputs = parse_fx_module( traced_model, dummy_input, attr_savedir=attr_savedir, constant_folding=constant_folding, diff --git a/nnscaler/graph/parser/external/__init__.py b/nnscaler/graph/parser/external/__init__.py index 5c71d8f9..5a628d8f 100644 --- a/nnscaler/graph/parser/external/__init__.py +++ b/nnscaler/graph/parser/external/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .apex import * \ No newline at end of file +from .apex import * +from .einops import * diff --git a/nnscaler/graph/parser/external/apex.py b/nnscaler/graph/parser/external/apex.py index 94b22209..e7d321d1 100644 --- a/nnscaler/graph/parser/external/apex.py +++ b/nnscaler/graph/parser/external/apex.py @@ -83,6 +83,7 @@ def apex_fused_rms_norm_affine_anno(input, weight, normalized_shape, eps, *args, parser.register(apex_fused_layer_norm_affine_anno)(fused_layer_norm_affine) parser.register(apex_fused_rms_norm_anno)(fused_rms_norm) parser.register(apex_fused_rms_norm_affine_anno)(fused_rms_norm_affine) + _logger.info("apex ops registered successfully.") except: - _logger.warning('skip apex ops as it is not installed.') + _logger.debug('skip apex ops as it is not installed.') diff --git a/nnscaler/graph/parser/external/einops.py b/nnscaler/graph/parser/external/einops.py new file mode 100644 index 00000000..fb24f38e --- /dev/null +++ b/nnscaler/graph/parser/external/einops.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import logging + +import torch + + +_logger = logging.getLogger(__name__) + +try: + import einops + # trigger einops initialization + einops.rearrange(torch.arange(1), '(a b c) -> a b c', a=1, b=1, c=1) + + from nnscaler.graph.tracer.wrap_utils import default_never_wrap_function, LeafWrapInfo, Location + + default_never_wrap_function[einops.einops._prepare_transformation_recipe] = \ + LeafWrapInfo([Location(einops.einops, '_prepare_transformation_recipe')], False, None) + + # we comment out these two functions + # because it looks not necessary for now. + # and they also introduce some problems, + # i.e. dynamic shape will be lost even with `compute_config.constant_folding=False` + + # default_never_wrap_function[einops.einops._reconstruct_from_shape_uncached] = \ + # LeafWrapInfo([Location(einops.einops, '_reconstruct_from_shape_uncached')], False, None) + # default_never_wrap_function[einops.einops._reconstruct_from_shape] = \ + # LeafWrapInfo([Location(einops.einops, '_reconstruct_from_shape')], False, None) + +except ImportError as e: + _logger.debug("Einops is not installed") + pass diff --git a/nnscaler/graph/parser/mapping.py b/nnscaler/graph/parser/mapping.py index 55c2792f..56c2908f 100644 --- a/nnscaler/graph/parser/mapping.py +++ b/nnscaler/graph/parser/mapping.py @@ -55,6 +55,7 @@ def exist(signature: str) -> bool: kOpMap = { # __tnmtemplate('Dropout'): function.nnDropout, + __ttemplate('dot'): function.Dot, __fcntemplate('linear'): function.Linear, __ftemplate('dropout') : function.Dropout, __ttemplate('sum'): function.Sum, @@ -180,6 +181,7 @@ def exist(signature: str) -> bool: __ttemplate('rand_like'): function.RandLike, __ttemplate('randn'): function.Randn, __ttemplate('randn_like'): function.RandnLike, + __ttemplate('eye'): function.Eye, __ttemplate('clone'): function.Clone, '_operator.is_': function.Is, diff --git a/nnscaler/graph/parser/parser.py b/nnscaler/graph/parser/parser.py index 02c52611..4fa263b9 100644 --- a/nnscaler/graph/parser/parser.py +++ b/nnscaler/graph/parser/parser.py @@ -11,8 +11,9 @@ from nnscaler.graph.tracer.metadata import OpContext from nnscaler.ir.operator import IRFwOperation from nnscaler.ir.tensor import IRFullTensor -from nnscaler.ir.cten import IRObject, IRCell, IRTensor, IR +from nnscaler.ir.cten import IRObject, IRCell, IRTensor, IR, ValueTrack from nnscaler.graph.parser.frame import Frame +from nnscaler.graph.parser.value_tracker import ValueTracker from nnscaler.graph.parser.mapping import SignFx2Op from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.function.dimops import IRDimops @@ -38,14 +39,14 @@ class FxModuleParser: ATTR_CONTENT_FILE_FORMAT = '{stem}.{idx}' ATTR_MAP_FILE = 'dist_param_map.pt' - @staticmethod - def parse(module: torch.fx.GraphModule, + def __init__(self, + module: torch.fx.GraphModule, dummy_inputs: Dict[str, Any], attr_savedir='./', *, save_content: bool = True, constant_folding: bool = False - ) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: + ): """Parse torch.fx module into cube IR The overall entry to parse a torch.fx graph module @@ -56,6 +57,24 @@ def parse(module: torch.fx.GraphModule, attr_savedir (str): the directory to save the attribute content save_content (bool): whether to save the content of the module constant_folding (bool): whether to parse the module with constant folding + """ + + self.module = module + + self.dummy_inputs = dummy_inputs + assert isinstance(dummy_inputs, dict), f"Expected dummy inputs to parse module, but got {dummy_inputs} of type {type(dummy_inputs)}" + + self.attr_savedir = attr_savedir + self.save_content = save_content + self.constant_folding = constant_folding + + self.frame = Frame() + self.value_tracker = ValueTracker() + + def parse(self) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: + """Parse torch.fx module into cube IR + + The overall entry to parse a torch.fx graph module Returns: inputs (List[IRObject]): the input IRObjects @@ -67,12 +86,10 @@ def parse(module: torch.fx.GraphModule, # (Those ops creators include user registered function, all functions returning tensors and more) # We will connect the real op outputs (saved in frame) to all ir op outputs and inputs later. - frame = Frame() - frame.push_var() + self.frame.push_var() # shape propagation - assert isinstance(dummy_inputs, dict), f"Expected dummy inputs to parse module, but got {dummy_inputs} of type {type(dummy_inputs)}" - output_nodes = [node for node in module.graph.nodes if node.op == 'output'] + output_nodes = [node for node in self.module.graph.nodes if node.op == 'output'] # currently fx graph always has only one output # even if a tuple/list is returned, it is still just one output assert len(output_nodes) == 1, f"Expect only one output, but got {len(output_nodes)}" @@ -81,11 +98,11 @@ def parse(module: torch.fx.GraphModule, assert len(output_node.args) == 1 and len(output_node.kwargs) == 0 # create IRObjects and IRTensors - for node in module.graph.nodes: + for node in self.module.graph.nodes: if node.op == 'placeholder': - FxModuleParser.init_objects(node, module, frame, is_constant=False) + self._init_objects(node, is_constant=False) else: - FxModuleParser.init_objects(node, module, frame, is_constant=True) + self._init_objects(node, is_constant=True) # note the output node will be reset later by `parse_prim_output_node` # with the help of `parse_complex` @@ -98,76 +115,93 @@ def parse(module: torch.fx.GraphModule, # to make sure the IRGraph has the correct output number # see `IRGrpah.from_logic_graph` - val = frame.get_var(node.name) + val = self.frame.get_var(node.name) if node == output_node.args[0] \ and IR.is_object(val) and isinstance(val.value, tuple): tuple_val = tuple(IRObject(name=node.name, value=v, is_constant=val.is_constant) for v in val.value) - frame.set_var(node.name, tuple_val) + self.frame.set_var(node.name, tuple_val) # get graph inputs - placeholders = [n for n in module.graph.nodes if n.op == 'placeholder'] - inputs = [frame.get_var(n.name) for n in placeholders] + placeholders = [n for n in self.module.graph.nodes if n.op == 'placeholder'] + inputs = [self.frame.get_var(n.name) for n in placeholders] + self.value_tracker.track_values(inputs) # - if the graph inputs contain nested strcuture, # it should be wrapped into an IRObject for idx, placeholder in enumerate(placeholders): if not isinstance(inputs[idx], IRObject): - obj = IRObject(name=placeholder.name, value=inputs[idx], is_constant=False) + obj = IRObject(name=placeholder.target, value=inputs[idx], is_constant=False) + obj.value_track.mark_as_input() inputs[idx] = obj - frame.set_var(placeholder.name, obj) + self.value_tracker.track_values([obj]) + self.frame.set_var(placeholder.name, obj) # parse graph nodes all_ir_nodes = [] - for node in module.graph.nodes: - ir_nodes = FxModuleParser.parse_node(node, module, constant_folding, frame) + for node in self.module.graph.nodes: + ir_nodes = self._parse_node(node) all_ir_nodes += ir_nodes + self.value_tracker.complete_tracking(all_ir_nodes) + # get graph outputs - outputs = [frame.get_var(node.name) for node in module.graph.nodes if node.op == 'output'] + outputs = [self.frame.get_var(node.name) for node in self.module.graph.nodes if node.op == 'output'] # currently fx graph always has only one output # even if a tuple/list is returned, it is still just one output assert len(outputs) == 1, f"Expect only one output, but got {len(outputs)}" - if save_content: - attr_savedir = Path(attr_savedir) - frame.save_attr_content(attr_savedir / FxModuleParser.ATTR_CONTENT_FILE_STEM) - frame.save_attr_map(attr_savedir / FxModuleParser.ATTR_MAP_FILE) + if self.save_content: + attr_savedir = Path(self.attr_savedir) + self.frame.save_attr_content(attr_savedir / self.ATTR_CONTENT_FILE_STEM) + self.frame.save_attr_map(attr_savedir / self.ATTR_MAP_FILE) - frame.pop_var() + self.frame.pop_var() return inputs, all_ir_nodes, outputs - @staticmethod - def parse_node(node: torch.fx.Node, module, constant_folding: bool, frame: Frame) -> List[IRFwOperation]: + def _parse_node(self, node: torch.fx.Node) -> List[IRFwOperation]: """ Parse the node and return the IRFwOperation nodes """ if node.op == 'placeholder': return [] if node.op == 'output': - return FxModuleParser.parse_prim_output_node(node, module, frame) + return self._parse_prim_output_node(node) if node.op in ('call_function', 'call_method'): - return FxModuleParser.parse_prim_function_method(node, module, constant_folding, frame) + return self._parse_prim_function_method(node) if node.op == 'get_attr': - return FxModuleParser.parse_prim_get_attr_node(node, module, frame) + return self._parse_prim_get_attr_node(node) if node.op == 'call_module': - return FxModuleParser.parse_prim_module(node, module, frame) + return self._parse_prim_module(node) else: raise TypeError(f"Unknown node kind {node.op}") - @staticmethod - def init_objects(node: torch.fx.Node, module: torch.fx.GraphModule, - frame: Frame, is_constant: bool = True): + def _init_objects(self, node: torch.fx.Node, is_constant: bool = True): assert isinstance(node, torch.fx.Node) assert hasattr(node, 'meta') and 'tensor_meta' in node.meta, f"Node {node} should have tensor_meta" meta = node.meta['tensor_meta'] - val = IR.new(node.name, meta, + val = IR.new( + # node.target is necesssary for input + # its name will be used to align with model forward args when generating code. + node.target if node.op == 'placeholder' else node.name, + meta, tensor_types=(TensorMetadata,), - is_constant=is_constant + is_constant=is_constant, ) - frame.add_var(node.name, val) - @staticmethod - def parse_complex(val: Any, frame: Frame) -> Any: + if node.op == 'placeholder': + def mark_as_input(x: IRObject): + if isinstance(x, IRTensor): + # let's the value_track of tensor stay None(unknown) + # because we don't care about it. + for dt in x.dim_tracks: + dt.with_no_dep() + else: + x.value_track.mark_as_input() + IR.modify_objects(val, mark_as_input) + + self.frame.add_var(node.name, val) + + def _parse_complex(self, val: Any) -> Any: """parse complex fx.Node into IRObject The val is usually from a node's input or output, can be fx.Node nested @@ -183,28 +217,28 @@ def parse_complex(val: Any, frame: Frame) -> Any: # to support more nested types, we can refer to the implementation of # https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py if isinstance(val, tuple): - return tuple(FxModuleParser.parse_complex(t, frame) for t in val) + return tuple(self._parse_complex(t) for t in val) if isinstance(val, list): - return list(FxModuleParser.parse_complex(t, frame) for t in val) + return list(self._parse_complex(t) for t in val) if isinstance(val, dict): - return {key: FxModuleParser.parse_complex(val, frame) for key, val in val.items()} + return {key: self._parse_complex(val) for key, val in val.items()} # TODO: Currently slice/DICT_VALUES_TYPE/DICT_ITEMS_TYPE cases are never found. # We need to find some examples to test them. if isinstance(val, slice): - return slice(FxModuleParser.parse_complex(val.start, frame), - FxModuleParser.parse_complex(val.stop, frame), - FxModuleParser.parse_complex(val.step, frame)) + return slice(self._parse_complex(val.start), + self._parse_complex(val.stop), + self._parse_complex(val.step)) # because fx node cannot be a dict key, so skip DICT_KEYS_TYPE here if isinstance(val, DICT_VALUES_TYPE): - return tuple(FxModuleParser.parse_complex(x, frame) for x in val) + return tuple(self._parse_complex(x) for x in val) if isinstance(val, DICT_ITEMS_TYPE): - return tuple((i, FxModuleParser.parse_complex(x, frame)) for i, x in val) + return tuple((i, self._parse_complex(x)) for i, x in val) if isinstance(val, torch.fx.Node): - return frame.get_var(val.name) + return self.frame.get_var(val.name) return val - @staticmethod - def fetch_attr(mod: torch.fx.GraphModule, target: str): + @classmethod + def _fetch_attr(cls, mod: torch.fx.GraphModule, target: str): target_atoms = target.split('.') attr_itr = mod for i, atom in enumerate(target_atoms): @@ -213,23 +247,19 @@ def fetch_attr(mod: torch.fx.GraphModule, target: str): attr_itr = getattr(attr_itr, atom) return attr_itr - @staticmethod - def parse_prim_module(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: - prim_module = FxModuleParser.fetch_attr(module, node.target) + def _parse_prim_module(self, node: torch.fx.Node) -> List[IRFwOperation]: + prim_module = self._fetch_attr(self.module, node.target) if prim_module.__class__.__module__.startswith('torch.nn.modules'): raise RuntimeError(f'{prim_module.__class__.__module__} can not be parsed as leaf nodes') else: raise RuntimeError(f'unknown module: {prim_module.__class__.__module__}') - @staticmethod - def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule, constant_folding: bool, frame: Frame) -> List[IRFwOperation]: + def _parse_prim_function_method(self, node: torch.fx.Node) -> List[IRFwOperation]: """ Convert `call_function`/`call_method` op to IRFwOperation. Args: node (torch.fx.Node): the node to be parsed - module (torch.fx.GraphModule): the module containing the node - constant_folding (bool): global setting of whether to fold the constant Returns: List[IRFwOperation]: the IRFwOperation nodes. @@ -238,10 +268,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule """ # get signature - fsig = FxModuleParser._get_qualified_name(node.target, node) + fsig = self._get_qualified_name(node.target, node) # get inputs - input_vals = FxModuleParser.parse_complex(list(node.args), frame) - kwargs = FxModuleParser.parse_complex(node.kwargs, frame) + input_vals = self._parse_complex(list(node.args)) + kwargs = self._parse_complex(node.kwargs) # use context constant_folding if set # Please note constant_folding only controls the output of the op @@ -249,6 +279,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # when we enter the code block with different constant folding setting # as a workaround, # you can use `nnscaler.runtime.function.fold_constant` to fold inputs if needed + constant_folding = self.constant_folding op_context: Optional[Dict[str, Any]] = node.meta.get('op_context') if op_context is not None and op_context.get(fields(OpContext).constant_folding) is not None: constant_folding = op_context[fields(OpContext).constant_folding] @@ -258,12 +289,12 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule else: # FIXME: handle cases for IRObject in kwargs # case1: unknown torch operator - if FxModuleParser._is_torch_autograd_op(node, frame, fsig): + if self._is_torch_autograd_op(node, fsig): _logger.warning(f'Find unknown pytorch operation: {fsig}') fname = fsig.split('.')[-1] if '.' in fsig else fsig ir_node = IRFwOperation(fname, fsig, input_vals, 1, **kwargs) # case2: custom autograd function - elif FxModuleParser._is_custom_autograd_op(node): + elif self._is_custom_autograd_op(node): # custom autograd function _logger.warning(f'Find unknown custom autograd operation: {fsig}. You should register it with nnscaler.register_op') ir_node = IRFwOperation(fsig, fsig, input_vals, 1, **kwargs) @@ -276,7 +307,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule 'You can register it as a customized function using nnscaler.register_op to remove this warning' _logger.warning(warning_msg) is_constant = False - output = frame.get_var(node.name) + output = self.frame.get_var(node.name) if not isinstance(output, IRObject): # avoid nested IRObject output = IRObject(name=node.name, value=output, is_constant=is_constant) @@ -292,10 +323,10 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # As node is deleted, we must set concrete value or IRTensor/IRObject into framework. # TODO: check the value saved in frame should equal to the value returned by the op - frame.set_var(node.name, ir_node) + self.frame.set_var(node.name, ir_node) return [] - FxModuleParser._set_node_meta(node, ir_node) + self._set_node_meta(node, ir_node) # step 1: align the node output with the value in frame @@ -307,11 +338,11 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # but its output is used in other nodes. # By removing from frame, # we can catch the case earlier - frame.del_val(node.name) + self.frame.del_val(node.name) # if the function has no output, just return return [ir_node] - vals = frame.get_var(node.name) + vals: Union[Any, IRObject, List[IRObject], IRTensor, List[IRTensor]] = self.frame.get_var(node.name) if len(ir_node.outputs()) == 1: vals = [vals] elif IR.is_object(vals): @@ -324,7 +355,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule if not isinstance(vals, (list, tuple)): raise RuntimeError(f'Expect list or tuple for multiple outputs, but got {type(vals)}') vals = type(vals)(IRObject(name=node.name, value=v, is_constant=is_constant) for v in vals) - frame.set_var(node.name, vals) + self.frame.set_var(node.name, vals) # verify the inferred shape are consistent with actual output if isinstance(ir_node, IRFwOperation): @@ -337,11 +368,43 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # 1. output tensors are not set in function.py # 2. IRObject output from some functions (registered functions/getattr) are not set # For above two cases, we need to set them with values from frame. + if isinstance(ir_node.output(i), IRTensor): + assert isinstance(vals[i], IRTensor), f'Expect tensor for output {i}, but got {type(vals[i])}' + assert ir_node.output(i).shape == vals[i].shape, f'Expect shape {ir_node.output(i).shape} for output {i}, but got {vals[i].shape}' + # We need to copy dim tracks + # As we will use frame version as node output, instead of the placeholder created in function.py + for dim in range(len(vals[i].shape)): + vals[i].dim_tracks[dim].merge(ir_node.output(i).dim_tracks[dim]) ir_node.set_output(i, vals[i]) + elif isinstance(ir_node.output(i), IRObject) and ir_node.output(i).is_value_missing(): + # output is IRObject with missing value + # we need to set it with the value from frame + assert not IR.contains_object(vals[i], lambda x: isinstance(x, IRTensor)), \ + f'Output {i} of node {node} is expected to be IRObject, but got tensor: {vals[i]}' + ir_node.output(i).value = IR.try_unwrap(vals[i]) + else: + # Currently we don't support missing-value IRObject in tuple/list/dict/... + # TODO: add support when needed + assert not IR.contains_object(ir_node.output(i), lambda x: not isinstance(x, IRTensor) and x.is_value_missing()), \ + f'Output {i} of node {node} contains missing value: {ir_node.output(i)}' + + # per-op value tracking via its annotation + # TODO: + # This may be not accurate because many ops in function.py are not properly annotated their value deps + # Two ways to improve it: + # 1. add value deps annotation for those ops in function.py + # 2. use global data flow analysis to track value deps + # a. add all nodes without folding + # b. use value_tracker.track_nodes to analyze value deps for all nodes + # c. remove nodes that can be folded. + # It is not easy because some op logic in function.py works differently + # when its inputs are constant or not. + # For now, we just use per-op value tracking for simplicity. + self.value_tracker.track_nodes([ir_node]) # update frame with ir output # Please note when there is only one output, we will unwrap it from `ir_node.outputs()` here - frame.set_var( + self.frame.set_var( node.name, type(vals)(ir_node.outputs()) if len(ir_node.outputs()) > 1 else ir_node.output(0) ) @@ -349,6 +412,7 @@ def parse_prim_function_method(node: torch.fx.Node, module: torch.fx.GraphModule # update the name of output tensors # Note assignment is not allowed in lambda # so we use a helper function to update the name + def _update_name(x: IRObject): x.name = node.name IR.modify_objects_inplace(ir_node.outputs(), _update_name) @@ -378,22 +442,28 @@ def _is_primitive_type(val): # use a white list instead of a black list return isinstance(val, (int, float, bool, type(None), str, type(Ellipsis))) - # Note when it is not IRObject as a whole, we will not fold it if constant_folding and ir_node.constant_foldable \ and len(ir_node.outputs()) == 1 \ - and isinstance(ir_node.output(0), IRObject) \ - and not isinstance(ir_node.output(0), IRTensor) \ and not contains_undefined_output \ and not ir_node.signature.startswith(nnscaler.runtime.function.__name__ + '.')\ - and ir_node.output(0).is_constant \ - and _is_primitive_type(ir_node.output(0).value): - frame.set_var(node.name, ir_node.output(0).value) + and not IR.contains_object(ir_node.output(0), lambda x: isinstance(x, IRTensor) or not x.is_constant) \ + and _is_primitive_type(cval := IR.try_unwrap(ir_node.output(0))): + # TODO: + # This will break the value tracking graph + # for example, if not folded: + # value1 -> op1 -> value2 -> op2 -> value3 -> op3 + # if op2 is folded, then op3 will not know the value1 dependency + # So the value tracking becomes: + # value1 -> op1 value3 -> op3 + # In many cases, op1 and op3 can be connected by other ops, + # But when this becomes a problem, we need to fix it by using global data flow analysis. + self.frame.set_var(node.name, cval) + self.value_tracker.untrack_node(ir_node) return [] else: return [ir_node] - @staticmethod - def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRFwOperation]: + def _parse_prim_get_attr_node(self, node: torch.fx.Node) -> List[IRFwOperation]: """ There are two types of get_attr, one is `FxNodeKind.PrimGetAttr` which is dealt with in this function. The other is `FxNodeKind.PrimCallFunction ` (i.e., ) @@ -403,74 +473,84 @@ def parse_prim_get_attr_node(node: torch.fx.Node, module: torch.fx.GraphModule, node.target is the attribute name of the object. """ ir_nodes = [] - concrete_value = FxModuleParser.fetch_attr(module, node.target) + concrete_value = self._fetch_attr(self.module, node.target) if isinstance(concrete_value, torch.Tensor): assert isinstance(concrete_value, torch.Tensor), \ f"GetAttrPrim: expect tensor but got {type(concrete_value)}" - exist_tensor = frame.get_attr_var(concrete_value) + exist_tensor = self.frame.get_attr_var(concrete_value) # the case that the parameter is the first time used by getattr if not exist_tensor: - tensor = frame.get_var(node.name) + tensor: IRFullTensor = self.frame.get_var(node.name) # set tensor name same with the name in original model tensor.name = node.target if tensor.requires_grad: tensor.as_param() else: - direct_module = module + direct_module = self.module full_qualified_name = node.target.split('.') for name in full_qualified_name[:-1]: # last one is the attribute name direct_module = getattr(direct_module, name) persistent = full_qualified_name[-1] not in direct_module._non_persistent_buffers_set tensor.as_buffer(persistent=persistent) - frame.add_attr(tensor, concrete_value, node.target) + + # Parameters and buffers have no dependency on other values + for dt in tensor.dim_tracks: + dt.is_constant = True + dt.with_no_dep() + + self.frame.add_attr(tensor, concrete_value, node.target) # the case that the parameter is consumed multiple times and registered previously else: - frame.set_var(node.name, exist_tensor) + self.frame.set_var(node.name, exist_tensor) else: assert isinstance(node.target, str), f"GetAttrPrim: expect `node.target` to be str but got {type(node.target)}" # in sub modules, the target is full qualified name (for example `embeddings.dropout.training`) if node.target.split('.')[-1] == 'training': # Let's just support `self.training` and ignore all other cases for now - output = IRObject(name=node.name, value=frame.get_var(node.name), is_constant=False) + if isinstance(output := self.frame.get_var(node.name), IRObject): + output.is_constant = False + else: + output = IRObject(name=node.name, value=output, is_constant=False) ir_node = IRPyFunc(SELF_GETATTR_SIG, ['training'], [output]) - FxModuleParser._set_node_meta(node, ir_node) - frame.set_var(node.name, output) + self._set_node_meta(node, ir_node) + self.frame.set_var(node.name, output) # never fold the IRPyFunc node ir_nodes.append(ir_node) else: - frame.set_var(node.name, concrete_value) + self.frame.set_var(node.name, concrete_value) + self.value_tracker.track_nodes(ir_nodes) return ir_nodes - @staticmethod - def parse_prim_output_node(node: torch.fx.Node, module: torch.fx.GraphModule, frame: Frame) -> List[IRCell]: + def _parse_prim_output_node(self, node: torch.fx.Node) -> List[IRCell]: assert len(node.args) == 1 and len(node.kwargs) == 0 - output = FxModuleParser.parse_complex(node.args[0], frame) - frame.set_var(node.name, output) + output = self._parse_complex(node.args[0]) + self.frame.set_var(node.name, output) return [] - @staticmethod - def _set_node_meta(node: torch.fx.Node, ir_node: Union[IRCell, Any]): + @classmethod + def _set_node_meta(cls, node: torch.fx.Node, ir_node: Union[IRCell, Any]): if not isinstance(ir_node, IRCell): return ir_node.op_context = node.meta.get('op_context') module_stack = node.meta.get('nn_module_stack') ir_node.module_stack = module_stack + ir_node.call_expr = node.meta.get('call_expr') comment = str(node.meta.get('frame_record', '')) if comment: ir_node.comment = comment - @staticmethod - def _get_qualified_name(node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: + @classmethod + def _get_qualified_name(cls, node_target: Union[str, Callable[..., Any]], node: torch.fx.Node = None) -> str: if isinstance(node_target, str): assert node is not None - return FxModuleParser._get_qualified_name_of_call_method(node_target, node) + return cls._get_qualified_name_of_call_method(node_target, node) else: - return FxModuleParser._get_qualified_name_of_call_function(node_target) + return cls._get_qualified_name_of_call_function(node_target) - @staticmethod - def _get_qualified_name_of_call_function(node_target: Callable[..., Any]) -> str: + @classmethod + def _get_qualified_name_of_call_function(cls, node_target: Callable[..., Any]) -> str: """ The target field of call_function node must be an callable object. """ @@ -480,12 +560,12 @@ def _get_qualified_name_of_call_function(node_target: Callable[..., Any]) -> str # TODO(yizhu1): find a general solution assert callable(node_target) name = node_target.__name__ - module = FxModuleParser._find_module_of_method(node_target) + module = cls._find_module_of_method(node_target) module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module return f'{module}.{name}' - @staticmethod - def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> str: + @classmethod + def _get_qualified_name_of_call_method(cls, node_target: str, node: torch.fx.Node) -> str: """ The target field of call_method node must be a string. """ @@ -513,8 +593,8 @@ def _get_qualified_name_of_call_method(node_target: str, node: torch.fx.Node) -> else: return f'{in_type.__module__}.{in_type.__name__}.{node_target}' - @staticmethod - def _find_module_of_method(orig_method: Callable[..., Any]) -> str: + @classmethod + def _find_module_of_method(cls, orig_method: Callable[..., Any]) -> str: if getattr(orig_method, '__name__', None) == 'apply' and isinstance(getattr(orig_method, '__self__', None), Type) \ and issubclass(orig_method.__self__, torch.autograd.Function): # for torch.autograd.Function @@ -547,18 +627,49 @@ def _find_module_of_method(orig_method: Callable[..., Any]) -> str: return guess.__name__ raise RuntimeError(f'cannot find module for {orig_method}') - @staticmethod - def _is_torch_autograd_op(node: torch.fx.Node, frame: Frame, signature: str) -> bool: + def _is_torch_autograd_op(self, node: torch.fx.Node, signature: str) -> bool: """Check whether the node is of a pytorch autograd operation.""" # note: some python operations like torch.Tensor.size() doesn't return # an IRTensor, thus cannot be considered as a pytorch autograd operator. return signature.startswith('torch.') and \ - isinstance(frame.get_var(node.name), IRFullTensor) + isinstance(self.frame.get_var(node.name), IRFullTensor) - @staticmethod - def _is_custom_autograd_op(node: torch.fx.Node) -> bool: + @classmethod + def _is_custom_autograd_op(cls, node: torch.fx.Node) -> bool: node_target = node.target return callable(node_target) \ and getattr(node_target, '__name__', None) == 'apply' \ and isinstance(getattr(node_target, '__self__', None), Type) \ and issubclass(node_target.__self__, torch.autograd.Function) + + +def parse_fx_module( + module: torch.fx.GraphModule, + dummy_inputs: Dict[str, Any], + attr_savedir='./', + *, + save_content: bool = True, + constant_folding: bool = False +) -> Tuple[List[IRObject], List[IRFwOperation], List[IRObject]]: + """Parse torch.fx module into cube IR + + The overall entry to parse a torch.fx graph module + + Args: + module (torch.fx.GraphModule): the torch.fx module + dummy_inputs (Dict[str, Any]): the dummy inputs to run the module + attr_savedir (str): the directory to save the attribute content + constant_folding (bool): whether to parse the module with constant folding + + Returns: + inputs (List[IRObject]): the input IRObjects + all_ir_nodes (List[IRFwOperation]): the IRFwOperation nodes + outputs (List[IRObject]): the output IRObjects + """ + return FxModuleParser( + module, + dummy_inputs, + attr_savedir, + save_content=save_content, + constant_folding=constant_folding + ).parse() diff --git a/nnscaler/graph/parser/value_tracker.py b/nnscaler/graph/parser/value_tracker.py new file mode 100644 index 00000000..45a3cf0f --- /dev/null +++ b/nnscaler/graph/parser/value_tracker.py @@ -0,0 +1,303 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import defaultdict +from typing import Any +from nnscaler.graph.function.dimops import IRDimops +from nnscaler.ir.cten import IR, IRObject, IRTensor, ValueTrack +from nnscaler.ir.operator import IRFwOperation + + +class ValueTracker: + """ + Example: + >>> vt = ValueTracker() + >>> vt.track_value(input1) + >>> vt.track_value(input2) + >>> ... + >>> vt.track_nodes([node1]) + >>> vt.track_nodes([node2]) + >>> vt.untrack_node(node2) # when node2 is folded + >>> vt.track_nodes([node3]) + >>> ... + >>> vt.complete_tracking([node1, node3, ...]) # pass all tracked nodes here + """ + def __init__(self): + # value_id -> ValueTrack + # Please note some ValueTracks may be merged together (from annotation) + # So the key can be different from the id of the ValueTrack + self._vtm: dict[int, ValueTrack] = {} + self._equiv_value_ids: dict[int, set[int]] = {} + # store removed value ids + # used to delay the removal of value tracks in deps + self._removed_value_ids: set[int] = set() + + def _add_track_value(self, value: ValueTrack): + if value.value_id not in self._vtm: + # always use the updated value track in self._vtm + self._vtm[value.value_id] = value + + if value.value_id not in self._equiv_value_ids: + self._equiv_value_ids[value.value_id] = {value.value_id} + + def track_values(self, objs: list[Any]) -> set[int]: + """ + Track the value tracks of the given objects. + Args: + objs (list[Any]): the objects to be tracked + Returns: + set[int]: the set of value ids tracked + """ + value_ids = set() + for obj in objs: + value_ids.update(self._track_value(obj)) + return value_ids + + def _track_value(self, value: Any): + for obj in IR.get_objects(value): + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + self._add_track_value(dt) + yield dt.value_id + else: + assert isinstance(obj, IRObject) + self._add_track_value(obj.value_track) + yield obj.value_track.value_id + + def _update_track_value(self, obj: IRObject): + if isinstance(obj, IRTensor): + new_dim_tracks = [] + for dt in obj.dim_tracks: + new_dim_tracks.append(self._vtm[dt.value_id]) + obj.dim_tracks = new_dim_tracks + else: + assert isinstance(obj, IRObject) + obj.value_track = self._vtm[obj.value_track.value_id] + + def _update_constness(self, obj: IRObject): + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + dt.is_constant = dt.is_constant and all(self._vtm[dep].is_constant for dep in dt.deps or []) + else: + assert isinstance(obj, IRObject) + obj.value_track.is_constant = obj.value_track.is_constant and all(self._vtm[dep].is_constant for dep in obj.value_track.deps or []) + + def track_nodes(self, nodes: list[IRFwOperation]): + """ + Track the value tracks of the input and output objects in the given nodes. + Here we assume the nodes are topologically sorted. + + Please note we only update the tracks of nodes in arguments. + For nodes not in arguments, their tracks are not updated. + + Args: + nodes (list[IRFwOperation]): the nodes to be tracked + """ + # collect all value tracks from nodes + if not nodes: + return + + # collect all involved value ids from nodes + node_value_ids = set() + for node in nodes: + for obj in node.iobjs(): + node_value_ids.update(self._track_value(obj)) + for obj in node.oobjs(): + node_value_ids.update(self._track_value(obj)) + + # collect extra value tracks from dimops + for node in nodes: + if isinstance(node, IRDimops): + self._track_dims(node) + + # merge equivalent value tracks together + done_value_ids = set() + for value_id in node_value_ids: + equiv_ids = self._equiv_value_ids[value_id] + + min_value_id = min(equiv_ids) + if min_value_id in done_value_ids: + continue + done_value_ids.add(min_value_id) + + # use the smallest id as the representative + rep_one = self._vtm[min_value_id] + for vid in equiv_ids: + if vid == min_value_id or self._vtm[vid] is rep_one: + continue + # TODO: how we merge dependencies? + # current we take union (Union may be too strict) + if rep_one.deps is None: + rep_one.deps = self._vtm[vid].deps + elif self._vtm[vid].deps is not None: + # deps can still have duplicates here + # because merging of the rest value tracks haven't been done yet + # NOTE: + # 1. this duplication is temporary, + # Duplicated value ids will be removed when we touch the same value track again + # in future track_nodes call. + # 2. duplication is not harmful for correctness + rep_one.deps = list( + set(rep_one.deps) + .union(self._vtm[vid].deps) + .difference(self._removed_value_ids) + ) + self._vtm[vid] = rep_one + + self._propagate_tracks(nodes) + + def untrack_node(self, node: IRFwOperation): + """ + Untrack the value tracks of output objects in the given node. + This function is used when we fold a node from the graph. + + Args: + node (IRFwOperation): the node to be untracked + """ + input_value_ids = set() + for obj in node.iobjs(): + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + input_value_ids.add(dt.value_id) + else: + assert isinstance(obj, IRObject) + input_value_ids.add(obj.value_track.value_id) + + for obj in node.oobjs(): + # we can only remove value tracks that are not used by inputs + if isinstance(obj, IRTensor): + for dt in obj.dim_tracks: + if dt.value_id not in input_value_ids: + self._removed_value_ids.add(dt.value_id) + else: + assert isinstance(obj, IRObject) + if obj.value_track.value_id not in input_value_ids: + self._removed_value_ids.add(obj.value_track.value_id) + + def complete_tracking(self, nodes: list[IRFwOperation]): + """ + Complete the tracking process. + Should be called after all nodes are tracked. + """ + # remove all removed value ids for vtm + # note we don't remove them from equivalence classes + for removed_id in self._removed_value_ids: + if self._vtm[removed_id].value_id == removed_id \ + and (new_equiv_cls := self._equiv_value_ids[removed_id].difference(self._removed_value_ids)): + # change the representative value id of this equivalence class + # NOTE: + # In current usage, code should not reach here. + # As we remove value tracks only for constant irobjects, + # and all equivalent value tracks should be removed together. + self._vtm[removed_id].value_id = min(new_equiv_cls) + self._vtm.pop(removed_id, None) + + # replace dependencies with their representative value tracks + # which can introduce some duplicates + # So we use `set` to further dedup dependencies + for vt in self._vtm.values(): + if vt.deps is not None: + vt.deps = list(set( + self._vtm[d].value_id for d in vt.deps + if d not in self._removed_value_ids + )) + + self._propagate_tracks(nodes) + + def _propagate_tracks(self, nodes: list[IRFwOperation]): + """ + Update value tracks and constantness information of the input and output objects + in the given nodes. + """ + # propagate the merged value tracks back to nodes + for node in nodes: + for obj in node.iobjs(): + self._update_track_value(obj) + for obj in node.oobjs(): + self._update_track_value(obj) + + # propagate the constantness information back to nodes + for node in nodes: + for obj in node.iobjs(): + self._update_constness(obj) + for obj in node.oobjs(): + self._update_constness(obj) + + def _track_dims(self, node: IRDimops): + """ + Track the dimension values of output tensors according to input tensors. + This function should be called after shape inference. + """ + # align the dim_ids of output with inputs + # not-hidden-dimension means the identifier is all for this dimension + # for example, in `l (2 h) m`, + # l and m are not-hidden-dimension identifiers, h is hidden-dimension identifier + # + # If the annotation is `l (2 h) m -> l h (m 2 h)` + # We will get the following relations (nhd->not-hidden-dimension, hd->hidden-dimension): + # 1. for `l`: `input.dim_tracks[0] is output.dim_tracks[0]` # both nhd, equality + # 2. for `m`: `input.dim_tracks[2].value_id in output.dim_tracks[2].deps` # one is hd, depencency + # 3. for `h`: `input.dim_tracks[1].value_id in output.dim_tracks[2].deps` # one is hd, depencency + # `input.dim_tracks[1] in output.dim_tracks[1].deps` # one is hd, depencency + + # TODO: We can handle more complex cases in the future if needed. + # In current version, we don't handle the case like + # 1. `(2 h) -> (2 h)`: input.dim_tracks[0] should be equal to output.dim_tracks[0]? (2 can be a runtime number, so we cannot be sure) + # 2. `(l m) -> (l m)`: input.dim_tracks[0] should be equal to output.dim_tracks[0]. + + # ivt => identifier_value_track_map + hidden_ivt: dict[str, list[ValueTrack]] = defaultdict(list) + non_hidden_ivt: dict[str, list[ValueTrack]] = defaultdict(list) + + for i, input_tensor in enumerate(node.inputs()): + if not isinstance(input_tensor, IRTensor) or node.ianno(i).ignore: + continue + + ianno = node.ianno(i) + for dim, dim_track in zip(ianno.dims, input_tensor.dim_tracks): + identifiers = [i for i in dim.identifiers if not str.isdecimal(i)] + if len(identifiers) == 1 and len(dim.identifiers) == 1: + # not hidden dimension + non_hidden_ivt[identifiers[0]].append(dim_track) + else: + for iden in identifiers: + hidden_ivt[iden].append(dim_track) + + for iden, iden_infos in non_hidden_ivt.items(): + # merge all not-hidden-dimension infos together + first = iden_infos[0] + for info in iden_infos[1:]: + self._add_equiv_value(first.value_id, info.value_id) + + for i, output_tensor in enumerate(node.outputs()): + if not isinstance(output_tensor, IRTensor) or node.oanno(i).ignore: + continue + + oanno = node.oanno(i) + for dim, dim_track in zip(oanno.dims, output_tensor.dim_tracks): + # find the first identifier that is not a number + identifiers = [i for i in dim.identifiers if not str.isdecimal(i)] + if len(identifiers) == 1 and len(dim.identifiers) == 1: + ident = identifiers[0] + if ident in non_hidden_ivt: + first = non_hidden_ivt[ident][0] + self._add_equiv_value(first.value_id, dim_track.value_id) + else: + # this identifier is used together with other identifiers + # so it is just a dependency. + dim_track.deps = dim_track.deps or [] + dim_track.deps.extend(v.value_id for v in hidden_ivt[ident]) + dim_track.deps = list(set(dim_track.deps)) # deduplicate + else: + dim_track.deps = dim_track.deps or [] + for ident in identifiers: + if ident in hidden_ivt: + dim_track.deps.extend(v.value_id for v in hidden_ivt[ident]) + if ident in non_hidden_ivt: + first = non_hidden_ivt[ident][0] + dim_track.deps.append(first.value_id) + + def _add_equiv_value(self, value_id, other_value_id): + self._equiv_value_ids[value_id].update(self._equiv_value_ids[other_value_id]) + for vid in self._equiv_value_ids[other_value_id]: + self._equiv_value_ids[vid] = self._equiv_value_ids[value_id] diff --git a/nnscaler/graph/tracer/concrete_tracer.py b/nnscaler/graph/tracer/concrete_tracer.py index 9ec92c27..d2471d36 100644 --- a/nnscaler/graph/tracer/concrete_tracer.py +++ b/nnscaler/graph/tracer/concrete_tracer.py @@ -16,6 +16,7 @@ from types import FunctionType, MethodDescriptorType, MethodType, MethodWrapperType, ModuleType from typing import Any, Dict, Optional, Set, Tuple, Type, List, Callable, Union, Literal from contextlib import contextmanager +import weakref import torch from torch._C import ScriptObject @@ -29,6 +30,8 @@ from torch.fx.proxy import TracerBase, Scope from torch.fx.operator_schemas import check_for_mutable_operation +from nnscaler.utils import transform_recursively + dict_keys_type = type(dict().keys()) dict_values_type = type(dict().values()) dict_items_type = type(dict().items()) @@ -89,6 +92,7 @@ def __init__(self, strategy, record_frames = False): self.scope = Scope("", None) self.module_stack = collections.OrderedDict() self.node_name_to_scope = {} + self.call_expr_stack = [] self.strategy = TRACE_STRATEGY[strategy](self) self.record_frames = record_frames self.patcher = FunctionPatcher() @@ -100,6 +104,16 @@ def __init__(self, strategy, record_frames = False): self.need_revert_functions = set() self.need_revert_wrapped_functions = set() + # Save functions decorated with functools.cache/lru_cache + # We need to clear up caches after tracing to avoid memory leak or tracing error. + # TODO: currently only functions/methods are tracked. + # Cached Properties (via @property @cache or @cached_property) are not tracked + # The reason is: + # 1. Cached properties is rare to cause problem as they have no arguments (no ConcrateProxy object will pass to it) + # 2. We need to patch all getattr (`a.b``) to support this scenario, which is too expensive + # Currently only function calls (`f(a,b)`) are patched and tracked. (See `operator_patcher`) + self.cached_function = weakref.WeakSet() + self.temp_call_origin = False def add_need_revert_function(self, func, wrapped_func): @@ -109,6 +123,49 @@ def add_need_revert_function(self, func, wrapped_func): def need_revert(self, func): return func in self.need_revert_functions or func in self.need_revert_wrapped_functions + @classmethod + def _is_cache_wrapped_function(cls, func): + return callable(func) \ + and hasattr(func, 'cache_clear') \ + and hasattr(func, 'cache_info') \ + and hasattr(func, 'cache_parameters') \ + and hasattr(func, '__wrapped__') \ + and callable(func.__wrapped__) + + def _track_cache_wrapped_function(self, func): + while func is not None: + if self._is_cache_wrapped_function(func): + self.cached_function.add(func) + break + func = getattr(func, '__wrapped__', None) + + @classmethod + def _is_torch_compile_function(cls, func): + return callable(func) \ + and hasattr(func, '__wrapped__') \ + and hasattr(func, '_torchdynamo_orig_callable') + + def _check_torch_compile_function(self, func): + outmost_func = func + while func is not None: + if self._is_torch_compile_function(func): + # If func is registered, run this func will be in a reverted context. + if not self.need_revert(outmost_func): + raise RuntimeError( + f"@torch.compile decorated function `{outmost_func.__module__}.{outmost_func.__qualname__}` is not registered. " + f"You must register it to avoid tracing failure." + ) + break + func = getattr(func, '__wrapped__', None) + + def on_function_call(self, func, expr): + self.call_expr_stack.append(expr) + self._track_cache_wrapped_function(func) + self._check_torch_compile_function(func) + + def on_function_call_end(self): + self.call_expr_stack.pop() + @contextmanager def do_temp_call_origin(self): temp_call_origin = self.temp_call_origin @@ -159,6 +216,15 @@ def create_node(self, kind : str, target : Target, else: node.meta['nn_module_stack'] = collections.OrderedDict() + if self.call_expr_stack: + last_call_expr = None + for item in reversed(self.call_expr_stack): + # if not found, leave last_call_expr as None + if item: + last_call_expr = item + break + node.meta['call_expr'] = last_call_expr + def unwrap_nested_proxy(proxy: ep.ConcreteProxy): return pytree_utils.tree_map_only(ep.ConcreteProxy, unwrap_nested_proxy, proxy.value) @@ -397,7 +463,6 @@ def proxy_placeholder(name: str): return self.create_proxy('placeholder', name, default_arg, {}) args.extend(proxy_placeholder(names) for names in arg_names) - if hasattr(co, 'co_kwonlyargcount') and ( co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF): # TODO: type annotations for *args and **kwargs @@ -407,6 +472,13 @@ def proxy_placeholder(name: str): more_args = proxy_placeholder(name) if co.co_flags & inspect.CO_VARKEYWORDS: name = '**' + next(names_iter) + if name not in concrete_args: + # auto pack the additional kwargs + kwargs_val = {} + for cc_name in concrete_args: + if cc_name not in arg_names and not cc_name.startswith('*'): + kwargs_val[cc_name] = concrete_args[cc_name] + concrete_args[name] = kwargs_val default_args[name] = {} kwargs = proxy_placeholder(name) @@ -470,7 +542,10 @@ def get_wrapped_leaves(self, leaf_functions: Dict[Callable, wrap_utils.LeafWrapI method_name=func.__name__, ) elif func.__name__ != func.__qualname__ and func.__qualname__ != 'boolean_dispatch..fn' \ - and not func.__qualname__.startswith('PyCapsule'): + and not func.__qualname__.startswith('PyCapsule') \ + and not func.__qualname__.startswith('pybind11_detail_function_'): + # this branch is for method/functions originally not defined in module level. + # in torch >= 2.9, we found pybind11_builtins are included in torch namespace. # method # in torch >= 2.2, we found two functions under torch._C has no __module__: # @@ -480,8 +555,11 @@ def get_wrapped_leaves(self, leaf_functions: Dict[Callable, wrap_utils.LeafWrapI path = sys.modules.get(func.__module__[1:], sys.modules[func.__module__]) else: path = sys.modules[func.__module__] - path = getattr(path, func.__qualname__.split('.')[0]) - locations = (*locations, wrap_utils.Location(path, func.__name__)) + try: + path = getattr(path, func.__qualname__.split('.')[0]) + locations = (*locations, wrap_utils.Location(path, func.__name__)) + except AttributeError: + _logger.warning(f'Can not get the class path of method {func} {func.__qualname__}!') if len(locations) == 0: _logger.warning(f'Can not find location of {func}, skip wrap it.') continue @@ -647,8 +725,14 @@ def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): def wrap_never_wrap_function(func, *args, **kwargs): if self.patcher.patch_mode: with self.patcher.revert(): + # unwrap all proxy in args/kwargs + args = transform_recursively(args, lambda x: x.value, target_types=ep.ConcreteProxy) + kwargs = transform_recursively(kwargs, lambda x: x.value, target_types=ep.ConcreteProxy) return func(*args, **kwargs) else: + # unwrap all proxy in args/kwargs + args = transform_recursively(args, lambda x: x.value, target_types=ep.ConcreteProxy) + kwargs = transform_recursively(kwargs, lambda x: x.value, target_types=ep.ConcreteProxy) return func(*args, **kwargs) try: @@ -692,6 +776,10 @@ def unwrap_nested_proxy(proxy: ep.ConcreteProxy): {}, type_expr=fn.__annotations__.get('return', None), node_result=node_result) finally: _retain_weight_consistency(self.root) + # clean up caches + for func in self.cached_function: + if func is not None: + func.cache_clear() return self.graph diff --git a/nnscaler/graph/tracer/metadata.py b/nnscaler/graph/tracer/metadata.py index f6f898a2..75f4de9a 100644 --- a/nnscaler/graph/tracer/metadata.py +++ b/nnscaler/graph/tracer/metadata.py @@ -8,6 +8,7 @@ from torch.fx.node import Node from . import pytree_utils +from nnscaler.utils import get_dynamic DICT_KEYS_TYPE = type({}.keys()) DICT_VALUES_TYPE= type({}.values()) @@ -95,6 +96,9 @@ class TensorMetadata(NamedTuple): is_quantized : bool qparams: Dict[str, Any] + # all dynamic dimensions in shape + dynamic_dims: set[int] + def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: """ @@ -134,7 +138,7 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] - return TensorMetadata(shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) + return TensorMetadata(shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams, get_dynamic(result)) def extract_metadata(results: Any, node: Node): diff --git a/nnscaler/graph/tracer/operator_patcher.py b/nnscaler/graph/tracer/operator_patcher.py index 235da0b5..783d9627 100644 --- a/nnscaler/graph/tracer/operator_patcher.py +++ b/nnscaler/graph/tracer/operator_patcher.py @@ -171,7 +171,11 @@ def visit_Call(self, node: ast.Call): self.modified = True return self.generic_visit(ast.Call( func=ast.Name(id=self.proxy_call_name, ctx=ast.Load()), - args=[node.func, *node.args], + args=[ + node.func, + ast.fix_missing_locations(ast.Constant(value=ast.unparse(node))), + *node.args + ], keywords=node.keywords, )) else: @@ -311,7 +315,7 @@ def patch_func_helper(self, func): # use func.__code__.co_filename to make the new function easily debuggable. compile(new_tree, func_inner.__code__.co_filename, 'exec'), { - self.proxy_call_name: OperatorPatcherContext.patch_run, + self.proxy_call_name: OperatorPatcherContext._patch_run, **func_inner.__globals__, **closure_dict, }, @@ -346,9 +350,19 @@ def __exit__(self, exc_type, exc_value, tb): return exc_type is None @staticmethod - def patch_run(func, *args, **kwargs): + def _patch_run(func, expr, *args, **kwargs): assert OperatorPatcherContext.ctx_tracer is not None assert OperatorPatcherContext.ctx_patcher is not None with wrap_utils.do_temp_call_origin(): + OperatorPatcherContext.ctx_tracer.on_function_call(func, expr) new_func = OperatorPatcherContext.ctx_patcher.patch_func_or_module(func) - return new_func(*args, **kwargs) + + ret = new_func(*args, **kwargs) + + with wrap_utils.do_temp_call_origin(): + OperatorPatcherContext.ctx_tracer.on_function_call_end() + return ret + + @staticmethod + def patch_run(func, *args, **kwargs): + return OperatorPatcherContext._patch_run(func, '', *args, **kwargs) diff --git a/nnscaler/graph/tracer/torch_fx_patcher.py b/nnscaler/graph/tracer/torch_fx_patcher.py index 8affab89..af358275 100644 --- a/nnscaler/graph/tracer/torch_fx_patcher.py +++ b/nnscaler/graph/tracer/torch_fx_patcher.py @@ -191,7 +191,7 @@ def format_import_statement_new(name: str, obj: Any, importer) -> str: return TorchFXPatcher.format_import_statement_ori(name, obj, importer) @staticmethod - def is_impure_new(node: fx_node.Node): + def is_impure_new(node: fx_node.Node, impure_random: bool = True) -> bool: """ Returns whether this op is impure, i.e. if its op is a placeholder or output, or if a call_function or call_module which is impure. @@ -208,6 +208,39 @@ def is_impure_new(node: fx_node.Node): # Check if an impure function. if node.op == "call_function": + schema = getattr(node.target, "_schema", None) + if schema is not None and schema.is_mutable: + # impure since it mutates inputs + return True + + if impure_random: + if getattr(node.target, "_nondeterministic_seeded", False): + # impure since it mutates RNG state + return True + + # Handle Python random functions that don't have _nondeterministic_seeded + # but still affect global RNG state (issue #151524) + # These should be impure regardless of impure_random setting to maintain + # consistency between eager and compiled execution + _random_functions = { + torch.rand, + torch.randn, + torch.randint, + torch.randperm, + torch.rand_like, + torch.randn_like, + torch.randint_like, + torch.normal, + torch.poisson, + torch.bernoulli, + torch.multinomial, + } + + if node.target in _random_functions: + # All random operations are impure to ensure consistent behavior + # between eager and compiled execution, regardless of generator usage + return True + return node.target in _side_effectful_functions # NOTE by nnscaler: we assume all method end with "_" is inplace operation, diff --git a/nnscaler/ir/cten.py b/nnscaler/ir/cten.py index f359fec3..42a05eaf 100644 --- a/nnscaler/ir/cten.py +++ b/nnscaler/ir/cten.py @@ -18,18 +18,19 @@ from __future__ import annotations +from dataclasses import dataclass, field from functools import lru_cache -from typing import ClassVar, List, Tuple, Union, Optional, Any, Dict, Callable +from typing import ClassVar, Iterable, List, Set, Tuple, Type, Union, Optional, Any, Dict, Callable from collections import OrderedDict import copy import torch from nnscaler.ir.unique import IDGenerator from nnscaler.ir.dtype import DTypeInfo -from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE +from nnscaler.utils import _DICT_ITEMS_TYPE, _DICT_VALUES_TYPE, load_type, get_dynamic -NestedVarOrStatic = Any +NestedVarOrStatic = Union[Any, 'IRObject', List['IRObject'], 'IRTensor'] class IRCell: @@ -77,9 +78,30 @@ def __init__(self, self._comment: Optional[str] = None # the module stack that preserves the hierarchy information self._module_stack: Optional[OrderedDict[str, Any]] = None + # the original call expression + # Note: + # 1. some cells may not have call expression if the cell is not from function call (e.g., __getitem__) + # 2. call_expr can be inaccurate when function call happens + # inside pytorch official module (like in torch.nn namespace) forward, + # (e.g., F.linear inside nn.Linear), in this case, call_expr will be module call expression. + self._call_expr: Optional[str] = None # the operation context information self._op_context: Optional[Dict[str, Any]] = None + # function to be called before the op is executed + # which will be inserted in the runtime code before the op call. + # op's inputs will be passed to the hook. + # The signature will be like + # def pre_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None: + self._pre_hook: Optional[Callable[..., None]] = None + # function to be called after the op is executed + # which will be inserted in the runtime code after the op call. + # op's inputs and outputs will be passed to the hook. + # the signature will be like + # def post_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any], output: Any) -> None: + self._post_hook: Optional[Callable[..., None]] = None + self._hook_meta: Any = None + @property def cid(self) -> int: """ @@ -377,6 +399,22 @@ def comment(self, info: str): @property def module_stack(self) -> Optional[OrderedDict[str, Any]]: + """ + Get the module stack, which preserves the hierarchy information + of modules this cell belongs to. + For example, if this cell is from model.submodule.layers.0.block0.conv2d, + then the module stack will be: + OrderedDict([ + ('model.submodule', ), + ('model.submodule.layers.0.block0', ), + ('model.submodule.layers.0.block0.conv2d', ), + ]) + + Please note + 1. Root module (e.g., model) is not included in the stack. + 2. Only modules that have `.forward` function are included in the stack, + so in above example, `torch.nn.ModuleList` is not included. + """ return self._module_stack @module_stack.setter @@ -386,6 +424,97 @@ def module_stack(self, stack: OrderedDict[str, Any]): """ self._module_stack = stack + @property + def module_class_chain(self) -> list[type[torch.nn.Module]]: + """ + Get the module chains the IRCell belongs to. + If module stack is None or empty, return []. + """ + if not self._module_stack: + return [] + return list(self._module_stack.values()) + + @property + def fqn(self) -> str: + """ + Get the fully qualified module name the IRCell belongs to. + If module stack is None or empty, return ''. + """ + if not self._module_stack: + return '' + return list(self._module_stack.keys())[-1] + + def get_module_fqn( + self, module_class: Type[torch.nn.Module], + *, + include_subclass: bool = False + ) -> str: + """ + Get the first fully qualified module name for the given module class + in the module stack. If not found, return ''. + + Args: + module_class (Type[torch.nn.Module]): the module class to find + include_subclass (bool): whether to include subclass of the module_class + + Returns: + str: the fully qualified module name + """ + if not self._module_stack: + return '' + for fqn, mod_cls in self._module_stack.items(): + if mod_cls == module_class or ( + include_subclass and issubclass(mod_cls, module_class) + ): + return fqn + return '' + + @property + def call_expr(self) -> Optional[str]: + return self._call_expr + + @call_expr.setter + def call_expr(self, expr: Optional[str]): + self._call_expr = expr + + @property + def fn(self) -> Optional[Callable]: + """ + Get the function of this cell based on its signature. + Return None if the function cannot be loaded. (e.g. virtual ops like `self_getattr`) + + Returns: + Callable: the function object + """ + try: + return load_type(self.signature) + except Exception as e: + return None + + @property + def pre_hook(self) -> Optional[Callable[..., None]]: + return self._pre_hook + + @pre_hook.setter + def pre_hook(self, hook: Optional[Callable[..., None]]): + self._pre_hook = hook + + @property + def post_hook(self) -> Optional[Callable[..., None]]: + return self._post_hook + + @post_hook.setter + def post_hook(self, hook: Optional[Callable[..., None]]): + self._post_hook = hook + + @property + def hook_meta(self) -> Any: + return self._hook_meta + + @hook_meta.setter + def hook_meta(self, meta: Any): + self._hook_meta = meta + @property def op_context(self) -> Optional[Dict[str, Any]]: return self._op_context @@ -459,14 +588,127 @@ def modify_objects_of_complex(val: Any, modifier: Callable[['IRObject'], 'IRObje return val +@dataclass +class ValueTrack: + """ + Track the value of an IRObject or a dimension of IRTensor. + Currently only implemented for dimension via IRDimops annotation. + + Example: + `l (2 h) m -> l h (2 m)`: + Input Tensor Tracks (2/5 is external dependencies for illustration): + dim 0: ValueTrack(value_id=10, dependencies=[]) # l + dim 1: ValueTrack(value_id=20, dependencies=[]) # (2 h) + dim 2: ValueTrack(value_id=30, dependencies=[2, 5]) # m + Then we can infer the output Tensor Tracks: + Output Tensor Tracks: + dim 0: ValueTrack(value_id=10, dependencies=[]) # reuse input dim 0, since they are the same + dim 1: ValueTrack(value_id=40, dependencies=[20]) # it depends on input dim 1: (2 h) + dim 2: ValueTrack(value_id=50, dependencies=[30]) # it depends on input dim 2: m + """ + value_id: int = field(default_factory=IDGenerator().gen_value_id) + # By default, we consider the value is constant + # unless it is set to not constant + # via mark_dynamic or it is from input or explicitly set in function.py + is_constant: bool = True + # None: unknown dependencies + # []: no dependencies + deps: Optional[list[int]] = None + + def with_no_dep(self) -> 'ValueTrack': + """ + Initialize this ValueTrack with no dependencies. + """ + self.deps = [] + return self + + def add_dep(self, dep: Union[Any, 'ValueTrack', 'IRObject']) -> 'ValueTrack': + """ + Initialize or add a dependency to the ValueTrack. + If dep is not IRObject or ValueTrack, do nothing. + """ + if self.deps is None: + self.deps = [] + + if not isinstance(dep, (ValueTrack, IRObject)): + return self + + if isinstance(dep, IRTensor): + raise TypeError("Cannot directly add IRTensor as dependency.") + + dep: ValueTrack = dep.value_track if isinstance(dep, IRObject) else dep + dep_value_id = dep.value_id + if dep_value_id not in self.deps: + self.deps.append(dep_value_id) + self.is_constant = self.is_constant and dep.is_constant + + return self + + def merge(self, other: ValueTrack) -> 'ValueTrack': + """ + Merge another ValueTrack into this one. + The merged ValueTrack will have dependencies from both ValueTracks. + """ + if self.deps is None: + self.deps = other.deps + else: + self.deps.extend(other.deps or []) + + if self.deps is not None: + self.deps = list(set(self.deps)) + + self.is_constant = self.is_constant and other.is_constant + return self + + @classmethod + def new(cls, deps: Iterable[Union[Any, 'ValueTrack', 'IRObject']], is_constant: Optional[bool] = None) -> 'ValueTrack': + vt = cls() + if is_constant is not None: + vt.is_constant = is_constant + vt.deps = [] + for dep in deps: + vt.add_dep(dep) + return vt + + def mark_as_input(self) -> 'ValueTrack': + """ + Mark this ValueTrack as graph input, which should be not constant and have no dependencies. + """ + self.is_constant = False + self.deps = [] + return self + + +_missing_value = object() + class IRObject: """ IRObject serves as general data of IRGraph edge + + There are two special IRObject for lazy evaluation: + 1. IRObject.missing: a singleton object to represent missing object + It is used to tell parser that we don't know the real object yet. + The parser is supposed to create a new IRObject to replace it. + For example, all custom ops will have missing outputs.It relies on parser to set them. + 2. IRObject(..., value=missing_value, ...): an object with unknown value + It is used to tell parser that we don't know the real value yet. + The parser is supposed to set the value. + We have this because we want ops to pass out `value_track` even when the value is unknown. + For example, `Item()` op in `function.py` will create such object. """ # will be set after class definition missing: ClassVar['IRObject'] = None - - def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: Optional[None] = None, is_constant: bool = True): + missing_value: ClassVar[object] = _missing_value + + def __init__( + self, + name: Optional[str] = None, + tid: Optional[int] = None, + value: Any = _missing_value, + is_constant: Optional[bool] = None, + *, + value_track: Optional[ValueTrack] = None, + ) -> None: """ Args: name (str): object name @@ -479,13 +721,19 @@ def __init__(self, name: Optional[str] = None, tid: Optional[int] = None, value: 2. val is model input, or is the result of a non-torch operation on another not constant IRObject Please note is_constant flag is only used in parser, so after parser, you can totally ignore this flag. + We keep this flag in IRObject for backward compatibility. + If both is_constant and value_track are provided, + `value_track.is_constant` will be overrided by this flag. + value_track (ValueTrack): the value track info of this object """ self._id: int = tid if isinstance(tid, int) else IDGenerator().gen_tensor_id() self.name: str = name if name else 'obj' self._cell: Optional[IRCell] = None self._is_attr: bool = False - self._value: Optional[Any] = value - self._is_constant: bool = is_constant + self._value: Any = value + self._value_track: ValueTrack = value_track or ValueTrack() + if is_constant is not None: + self._value_track.is_constant = is_constant def __hash__(self) -> int: return self._id @@ -538,13 +786,30 @@ def value(self) -> Any: """Get example value""" return self._value + @value.setter + def value(self, val: Any): + self._value = val + + def is_value_missing(self) -> bool: + """Check if the value is missing""" + return self._value is IRObject.missing_value + + @property + def value_track(self) -> ValueTrack: + """Get value track info""" + return self._value_track + + @value_track.setter + def value_track(self, val: ValueTrack): + self._value_track = val + @property def is_constant(self) -> bool: - return self._is_constant + return self._value_track.is_constant @is_constant.setter def is_constant(self, val: bool): - self._is_constant = val + self._value_track.is_constant = val def __eq__(self, obj) -> bool: if not isinstance(obj, IRObject): @@ -555,7 +820,7 @@ def __copy__(self): """Copy this object but remove the cell information""" if self is IRObject.missing: # missing object is singleton return IRObject.missing - return IRObject(self.name, self._id, self._value, self._is_constant) + return IRObject(self.name, self._id, self._value, self.is_constant, value_track=self._value_track) def as_attr(self): """ @@ -651,7 +916,10 @@ def _inner(obj) -> Tuple[Any, bool]: new_ir_tensor._value = obj.value return new_ir_tensor, True else: - return IRObject(name, value=obj.value, is_constant=is_constant), False + return IRObject( + name, value=obj.value, + is_constant=is_constant, value_track=obj.value_track + ), False if isinstance(obj, tensor_types): if requires_grad is None: @@ -667,6 +935,10 @@ def _inner(obj) -> Tuple[Any, bool]: dtype=obj.dtype, requires_grad=rg, ) + + for dyn_idx in get_dynamic(obj): + tensor.dim_tracks[dyn_idx].is_constant = False + if tosub: tensor = tensor.tosub() tensor._value = obj # is required in SemanticModel.forward @@ -907,11 +1179,12 @@ class IRTensor(IRObject): You can get the original shape with `origin_shape` property. """ def __init__(self, shape=None, name='tensor', dtype=None, tid=None, *, - is_attr=False, is_grad=False, requires_grad=False, persistent=False + is_attr=False, is_grad=False, requires_grad=False, persistent=False, ): super().__init__(name, tid, is_constant=False) self._is_scalar_tensor: bool = True - self._shape: Tuple[int] = () + self._shape: Tuple[int, ...] = () + self._dim_tracks: Tuple[ValueTrack, ...] = () self._dtype: Optional[torch.dtype] = None # tensor gradient self._is_grad: bool = False @@ -946,7 +1219,9 @@ def _update( if shape is not None: self._is_scalar_tensor = not shape # will always convert scalar tensor to 1-d tensor - self._shape: Tuple[int] = (1,) if not shape else tuple(shape) + self._shape: Tuple[int, ...] = (1,) if not shape else tuple(shape) + # reset dim tracks + self._dim_tracks = tuple(ValueTrack() for _ in self._shape) if name is not None or self.name is None: self.name = name if dtype is not None: @@ -973,7 +1248,7 @@ def dtype(self) -> Optional[torch.dtype]: def is_param(self) -> bool: """! - Check if the tensor is parameter + Check if the tensor is parameter (with requires_grad = True). @return is_param boolean: True if is parameter. """ @@ -1039,12 +1314,55 @@ def origin_shape(self) -> Tuple[int]: return self.shape if not self.is_scalar_tensor() else () @property - def shape(self) -> Tuple[int]: + def shape(self) -> Tuple[int, ...]: # NOTE: here return a tuple but not a real torch.Size obj may have risk, here is an example: # (torch.Size + tuple -> torch.Size) will change to (tuple + tuple -> tuple), is ok. # (torch.Size + list -> torch.Size) will change to (tuple + list -> error), is wrong. return self._shape + @property + def dim_tracks(self) -> Tuple[ValueTrack, ...]: + """ + Get the track of each dimension + """ + return self._dim_tracks + + @dim_tracks.setter + def dim_tracks(self, val: Tuple[Optional[ValueTrack], ...]): + """ + Set the unique id of each dimension + """ + if not isinstance(val, (list, tuple)): + raise ValueError("dim_tracks must be a list or tuple") + if len(val) != len(self._shape): + raise ValueError("dim_tracks length must be equal to shape length") + # None means starting a new dim track + self._dim_tracks = tuple(v if v is not None else ValueTrack() for v in val) + + def set_dim_track(self, dim: int, track: ValueTrack): + """ + Set the track of a specific dimension + """ + if dim < 0 or dim >= len(self._shape): + raise IndexError("dim out of range") + dim_tracks = list(self._dim_tracks) + dim_tracks[dim] = track + self._dim_tracks = tuple(dim_tracks) + + def dim_constant(self, dim: int) -> bool: + """ + Check if a dim is constant + """ + if dim < 0 or dim >= len(self._shape): + raise IndexError("dim out of range") + return self._dim_tracks[dim].is_constant + + def dims_constant(self) -> bool: + """ + Check if all dims are constant + """ + return all(track.is_constant for track in self._dim_tracks) + def nelement(self) -> int: """ Get total number of element in the tensor. diff --git a/nnscaler/ir/tensor.py b/nnscaler/ir/tensor.py index 6546720f..f3f2c9eb 100644 --- a/nnscaler/ir/tensor.py +++ b/nnscaler/ir/tensor.py @@ -27,10 +27,10 @@ 3) gradient of parameters """ -from typing import List, Optional, Union, Tuple, NewType, Dict, Any +from typing import List, Optional, Set, Union, Tuple, NewType, Dict, Any import torch -from nnscaler.ir.cten import IRTensor +from nnscaler.ir.cten import IRTensor, ValueTrack StartEnd = NewType('[start:end)', Tuple[int, int]) IdxChunk = NewType('(index, chunks)', Tuple[int, int]) @@ -260,14 +260,17 @@ class IRFullTensor(IRTensor): """ def __init__(self, shape=None, name='tensor', requires_grad=False, dtype=None, *, - is_attr=False, is_grad=False, persistent=False, is_loss=False + is_attr=False, is_grad=False, persistent=False, is_loss=False, ): self._is_loss: bool = False # record all created sub_tensors self._subtensors : Dict[(ValueMap, IndexMap), int] = dict() self._grad: Optional[IRFullTensor] = None - super().__init__(shape, name, dtype, requires_grad=requires_grad, is_attr=is_attr, is_grad=is_grad, persistent=persistent) + super().__init__( + shape, name, dtype, requires_grad=requires_grad, + is_attr=is_attr, is_grad=is_grad, persistent=persistent, + ) self._update( is_loss=is_loss, ) @@ -334,6 +337,7 @@ def like(self): self.origin_shape, self.name, self._requires_grad, self._dtype, is_loss=self._is_loss ) + tensor.dim_tracks = self.dim_tracks return tensor def like_grad(self): @@ -346,6 +350,7 @@ def like_grad(self): self.origin_shape, 'g' + self.name, requires_grad=False, dtype=self.dtype ).as_grad(self._is_attr) + grad.dim_tracks = self.dim_tracks return grad @property @@ -363,6 +368,7 @@ def grad(self, val: Optional[IRTensor]): assert self._requires_grad, f"Cannot assign {val} to no grad-required tensor" assert val.origin_shape == self.origin_shape assert val.is_attr() == self.is_attr() + val.dim_tracks = self.dim_tracks # TODO: we should check the grad-required here # it is very common in current code that we assign None to grad # so currently it is impossible to check the grad-required here @@ -507,6 +513,7 @@ def __init__(self, ftensor: IRFullTensor, del self._is_grad del self._requires_grad del self._persistent + del self._dim_tracks self.cell = None # the index from full_tensor @@ -556,7 +563,7 @@ def ndims(self) -> int: def as_attr(self): raise RuntimeError("as_attr is not allowed for SubTensor") - def splitdims(self) -> Tuple[int]: + def splitdims(self) -> Tuple[int, ...]: """! Get partitioned dimensions @@ -677,6 +684,10 @@ def dtype(self) -> Optional[torch.dtype]: """Tensor data type""" return self.parent.dtype + @property + def dim_tracks(self) -> Tuple[ValueTrack, ...]: + return self.parent.dim_tracks + @IRTensor.shape.setter def shape(self, val: Tuple[int]): # TODO: remove this function diff --git a/nnscaler/ir/unique.py b/nnscaler/ir/unique.py index dde3ceb2..72338ee5 100644 --- a/nnscaler/ir/unique.py +++ b/nnscaler/ir/unique.py @@ -5,14 +5,14 @@ class IDGenerator: """ Tensor / Operator manager. To guarantee that each IRTensor / IROperator id is unique and progressively increases. - + This class is designed in singleton pattern. """ class __IDGenerator: def __init__(self): - self._tensor_id = 0 self._cell_id = 0 + self._value_id = 0 instance = None @@ -31,13 +31,19 @@ def gen_cell_id(self): self.instance._cell_id += 1 return self.instance._cell_id + def gen_value_id(self): + self.instance._value_id += 1 + return self.instance._value_id + def get_states(self): - return (self._tensor_id, self._cell_id) - + return (self._tensor_id, self._cell_id, self._value_id) + def load_states(self, states: tuple): IDGenerator.instance._tensor_id = states[0] IDGenerator.instance._cell_id = states[1] + IDGenerator.instance._value_id = states[2] def clear(self): self.instance._tensor_id = 0 self.instance._cell_id = 0 + self.instance._value_id = 0 diff --git a/nnscaler/parallel.py b/nnscaler/parallel.py index b728269b..e8c2cb67 100644 --- a/nnscaler/parallel.py +++ b/nnscaler/parallel.py @@ -14,6 +14,7 @@ import logging import copy import os +from collections import OrderedDict, defaultdict import torch import torch.distributed @@ -40,15 +41,26 @@ from nnscaler.ir.tensor import IRFullTensor from nnscaler.ir.unique import IDGenerator -from nnscaler.runtime.adapter.reducer import Reducer +from nnscaler.runtime.adapter.reducer import Bucket, Reducer from nnscaler.runtime.device import DeviceGroup from nnscaler.runtime.gnorm import calcuate_gnorm, clip_grads -from nnscaler.runtime.module import AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState +from nnscaler.runtime.module import AttrMeta, Zero3AttrMeta, CubeModule, ParallelModule, OriginModuleMetadata, ExtraState, dedup_attrs from nnscaler.flags import CompileFlag, RuntimeFlag import nnscaler.policies as policies from nnscaler.program import disable_global_graph -from nnscaler.utils import get_member_by_name, setup_stride_broadcast_group, get_shared_params +from nnscaler.utils import ( + get_member_by_name, + load_type, + set_member_by_name, + setup_stride_broadcast_group, + get_shared_params, + OptStateDict, + copy_dynamic, + broadcast_files, + broadcast_mixed_data, + gather_mixed_data, +) logger = logging.getLogger(__name__) @@ -78,11 +90,17 @@ class ComputeConfig: # how to execute the functions during trace trace_strategy: str = 'cuda_run_cpu_offload' - use_zero: bool = False + # Only support 0/1/3 for now + # If you set use_zero to 2, ZeRO stage 3 will be used internally. + # 0: no zero + # 1: ZeRO stage 1 + # 2: ZeRO stage 3 + # 3: ZeRO stage 3 + use_zero: int = 0 zero_ngroups: int = 1 # whether to use reduce scatter for zero # Please note - # 1. this only works when `use_zero` is True and `zero_ngroups` is 1. + # 1. this only works when `use_zero` is not 0 and `zero_ngroups` is 1. # 2. In some cases, it can introduce parity issue. So use it with caution. zero_use_reduce_scatter: bool = False @@ -149,16 +167,38 @@ def __post_init__(self): raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be > 0") if self.runtime_ngpus % self.plan_ngpus != 0: raise ValueError(f"runtime_ngpus {self.runtime_ngpus} must be a multiple of plan_ngpus {self.plan_ngpus}") + + if self.reducer_bucket_cap_mb and self.reducer_bucket_cap_mb < 0: + raise ValueError(f"reducer_bucket_cap_mb {self.reducer_bucket_cap_mb} should not be negative.") + + # for backward compatibility, convert bool to int + super().__setattr__('use_zero', int(self.use_zero)) + if self.use_zero not in (0, 1, 2, 3): + raise ValueError(f"use_zero {self.use_zero} must be 0, 1, 2 or 3.") + if self.use_zero == 2: + logger.warning("use_zero=2 is not supported. ZeRO stage 3 will be used instead.") + super().__setattr__('use_zero', 3) + + num_scale_units = self.runtime_ngpus // self.plan_ngpus + if self.use_zero: + if num_scale_units % self.zero_ngroups != 0: + raise ValueError(f"zero_ngroups {self.zero_ngroups} must be a divisor of runtime_ngpus/plan_ngpus {num_scale_units}.") + # NOTE: + # we can't disable zero optimization when num_scale_units == zero_ngroups here + # because some ops are replicated inside a scale unit, + # and those ops can still utilize zero optimization. + # if num_scale_units == self.zero_ngroups: + # logger.warning(f"zero_ngroups {self.zero_ngroups} equals to runtime_ngpus/plan_ngpus {num_scale_units}. Zero optimization is disabled.") + # super().__setattr__('use_zero', 0) + if self.use_zero and self.zero_ngroups <= 0: raise ValueError(f"zero_ngroups {self.zero_ngroups} must be > 0") + if not self.use_zero and self.zero_ngroups != 1: logger.warning(f"use_zero is False, but zero_ngroups is {self.zero_ngroups}. Will set zero_ngroups to 1.") # have to use __setattr__ for frozen dataclass super().__setattr__('zero_ngroups', 1) - if self.reducer_bucket_cap_mb and self.reducer_bucket_cap_mb < 0: - raise ValueError(f"reducer_bucket_cap_mb {self.reducer_bucket_cap_mb} should not be negative.") - # TODO: Please note in current implementation of Bucket, # zero_use_reduce_scatter still works when zero_ngroups > 1 in sync mode # Let's hide this feature for now for consistency. @@ -215,7 +255,11 @@ def module_dedup_group_size(self) -> int: """ Get the size of the deduplication group of the model state dict, which is `plan_ngpus`. """ - return self.plan_ngpus + if self.use_zero > 1: + # for zero3 + return self.runtime_ngpus // self.zero_ngroups + else: + return self.plan_ngpus @property def optimizer_dedup_group_size(self) -> int: @@ -337,19 +381,29 @@ def _runtime_flags(**kwargs): return _flags(RuntimeFlag, **kwargs) -def _to_cpu(val: Any): - """Complex to CPU""" +def _to_cpu(val: Any, requires_grad: Optional[bool] = None) -> Any: + """ + Complex to CPU + Recursively move the input to CPU. + Args: + val (Any): the input value + requires_grad (Optional[bool]): whether the returned tensor requires grad. + If it is None, will keep the same as the input tensor. + """ if isinstance(val, tuple): - return tuple(_to_cpu(t) for t in val) + return tuple(_to_cpu(t, requires_grad) for t in val) if isinstance(val, list): - return list(_to_cpu(t) for t in val) + return list(_to_cpu(t, requires_grad) for t in val) if isinstance(val, dict): - return {_to_cpu(key):_to_cpu(val) for key, val in val.items()} + return {_to_cpu(key, requires_grad):_to_cpu(val, requires_grad) for key, val in val.items()} if isinstance(val, set): - return {_to_cpu(t) for t in val} + return {_to_cpu(t, requires_grad) for t in val} if isinstance(val, torch.Tensor): - requires_grad = val.is_floating_point() or val.is_complex() - return val.detach().clone().cpu().requires_grad_(requires_grad) + if requires_grad is None: + requires_grad = val.requires_grad + else: + requires_grad = requires_grad and (val.is_floating_point() or val.is_complex()) + return copy_dynamic(val, val.detach().clone().cpu().requires_grad_(requires_grad)) return val @@ -379,7 +433,7 @@ def _add_gen_savedir_to_syspath(gen_savedir: str) -> Path: gen_savedir = Path(gen_savedir).resolve() gen_savedir.mkdir(parents=True, exist_ok=True) if str(gen_savedir) not in sys.path: - sys.path.append(str(gen_savedir)) + sys.path.insert(0, str(gen_savedir)) return gen_savedir @@ -556,6 +610,10 @@ def _prepare_and_check_reusable( if reuse == ReuseType.MATCH or reuse == ReuseType.MOO: # check if the module is already generated expected_output_files = [outdir / _GENCODE_FILE_TEMPLATE.format(rank) for rank in range(compute_config.runtime_ngpus)] + expected_output_files.extend([ + outdir / ParallelModule.ATTR_META_FILE_TEMPLATE.format(rank) + for rank in range(compute_config.runtime_ngpus) + ]) expected_output_files.extend(trace_meta_files) expected_output_files.append(config_file) expected_output_files.append(outdir / _GRAPH_DUMP_FILE) @@ -636,7 +694,13 @@ def _gen_graph( raise ValueError(f"Default value type {type(v)} of forward args is not supported.") # generate fx graph - dummy_forward_args = _to_cpu(dummy_forward_args) + dummy_forward_args = _to_cpu( + dummy_forward_args, + # in end2end mode, we don't need gradients for inputs + # in normal mode, we assume all inputs require gradients + # so it can connect to other parts of the graph correctly + requires_grad=not end2end_mode + ) fx_graph = parser.to_fx_graph(module, dummy_forward_args) # generate ir logic graph @@ -653,51 +717,22 @@ def _gen_graph( node.target: forward_args_default.get(node.target, inspect.Parameter.empty) for node in fx_input_nodes } - ir_dummy_inputs = [] - for node in fx_input_nodes: - if node.target.startswith('*'): # *args or **kwargs - if node.target.strip('*') in dummy_forward_args: - raise ValueError(f"Input {node.target}: *args or **kwargs is not suppported") - ir_dummy_inputs.append(None) # always set None to *args/**kwargs - elif node.target in dummy_forward_args: - ir_dummy_inputs.append(dummy_forward_args[node.target]) - elif forward_args[node.target] is not inspect.Parameter.empty: - ir_dummy_inputs.append(forward_args[node.target]) - else: - raise ValueError(f"Input {node.target} not in dummy forward args, nor has default value.") - for i in range(len(ir_dummy_inputs)): - # note: we will always set tensor to require gradient, which may - # generate backward communications in adapter. However, as long as - # the data doesn't require gradient in real runtime, the backward - # communication will not be triggered. - ir_dummy_inputs[i] = IR.new( - fx_input_nodes[i].target, ir_dummy_inputs[i], - requires_grad=True, - tosub=True, - is_constant=False, - ) - # if the input is a complex type, we should wrap it with IRObject - if not isinstance(ir_dummy_inputs[i], IRObject): - ir_dummy_inputs[i] = IRObject(fx_input_nodes[i].target, value=ir_dummy_inputs[i], is_constant=False) - # generate complete ir graph - ir_dummy_outputs = graph(*ir_dummy_inputs) if end2end_mode: # in end2end mode, we must use dataloader as the first argument of forward # we assume the first argument of forward is the data sample (which is a requirement in our doc) graph.use_dataloader_input() # we require the first output is the loss - if isinstance(ir_dummy_outputs, (list, tuple)): - ir_loss = ir_dummy_outputs[0] - else: - ir_loss = ir_dummy_outputs + ir_loss = graph.output(0) if not isinstance(ir_loss, IRTensor) or ir_loss.shape != (1,): # internally scalar tensor will be reshaped to (1,) in IRGraph raise RuntimeError(f"Loss can only be scalar tensor but got {ir_loss.shape if isinstance(ir_loss, IRTensor) else ir_loss}") else: ir_loss = None + # we generate backward nodes and setup gradient tensors here + # forward nodes are done when we trace the model if not inference_only: graph.backward(ir_loss) else: @@ -839,12 +874,14 @@ def _gencode( sgener = ScheduleCodeGen(execplan, compute_config.runtime_ngpus) for rank in range(compute_config.runtime_ngpus): fname = outdir / _GENCODE_FILE_TEMPLATE.format(rank) + attr_meta_map_fname = outdir / ParallelModule.ATTR_META_FILE_TEMPLATE.format(rank) mgener.gen(rank, forward_args=forward_args, outfile=fname, attach=False, as_parallel_module=True, - end2end_mode=compute_config.use_end2end + end2end_mode=compute_config.use_end2end, + outfile_attr_meta_map=attr_meta_map_fname ) # generate temporal schedule code only for end2end module # because the code generated is wrong for non-end2end module. @@ -912,6 +949,7 @@ def parallelize( module_dtype: Optional[torch.dtype] = None, module_fn: Optional[Callable[[], torch.nn.Module]] = None, init_module_params: bool = True, + build_module_buckets: bool = True, broadcast_strategy: Union[str, BroadcastGenFilesStrategy] = 'none', ) -> Union[None, ParallelModule, Type[ParallelModule]]: """ @@ -982,6 +1020,12 @@ def __init__(self, init_params=True): Otherwise, they will be empty tensor. This parameter will be passed to the module constructor, so it is only used when module_or_module_class is a module object, and load_module is true. + build_module_buckets (bool): For parallel module, parameters that needs to synchronize will be grouped into buckets for more efficient communication. + If true, grouping process will be done in `__init__` + If false, you should do this by yourself. + This parameter will be passed to the module constructor, + so it is only used when module_or_module_class is a module object, and load_module is true. + Please leave it to true until you have a good reason to change it. module_dtype (Optional[torch.dtype]): the dtype of the module. Keep the module as it is if it is None. module_fn (Optional[Callable[[], torch.nn.Module]]): the function to create the module. Will use __init__ if it is None. broadcast_strategy (Union[str, BroadcastGenFilesStrategy]): the broadcast strategy for generated files. @@ -1008,7 +1052,11 @@ def __init__(self, init_params=True): if isinstance(pas_policy, str): if not pas_policy in _PREDEFINED_POLICIES: raise ValueError(f"Invalid pas_policy: {pas_policy}") - pas_policy = _PREDEFINED_POLICIES[pas_policy] + pas_policy = partial(policies.fn, policy=_PREDEFINED_POLICIES[pas_policy]) + else: + if not callable(pas_policy): + raise ValueError("pas_policy should be a callable or a predefined policy name") + pas_policy = partial(policies.fn, policy=pas_policy) is_module_class = inspect.isclass(module_or_module_class) module_class = module_or_module_class if is_module_class else module_or_module_class.__class__ @@ -1115,7 +1163,7 @@ def __init__(self, init_params=True): if is_module_class: return parallel_module_class else: - parallel_module = parallel_module_class(init_module_params) + parallel_module = parallel_module_class(init_module_params, build_module_buckets) parallel_module.train(module_or_module_class.training) # set training state to the same as original module return parallel_module @@ -1244,12 +1292,34 @@ def register_reducer_post_hook(self, fn: Callable[[Reducer, torch.Tensor], None] OptimizerT = TypeVar('OptimizerT', bound=torch.optim.Optimizer) +HybridOptimizerT = TypeVar('HybridOptimizer', bound=torch.optim.Optimizer) + + +def hybrid( + params: list[torch.nn.Parameter], + param_clss: dict[torch.nn.Parameter, tuple[int, int]], + **kwargs, +) -> HybridOptimizerT: + """ + Stub for hybrid optimizer creation. + Signature of Hybrid optimizer constructor: + ``` + def __init__(self, params, param_clss, **kwargs): + ... + ``` + When you pass arguments to `build_optimizer` + You must pass `param_clss_fn`, + and `build_optimizer` will automatically pass `param_clss` to its constructor. + """ + ... +hybrid.is_hybrid = True # mark this function as hybrid optimizer factory def build_optimizer( module: torch.nn.Module, optimizer_fn: Union[Type[OptimizerT], Callable[..., OptimizerT]], compute_config: Optional[ComputeConfig] = None, + param_clss_fn: Optional[Callable[[str], Any]] = None, **kwargs, ) -> Union[OptimizerT, ParallelOptimizer]: """ @@ -1277,6 +1347,11 @@ def build_optimizer( compute_config (Optional[ComputeConfig]): The config will be used to generate communication reducer. If it is None, Default configuration will be used when creating reducer for non-parallel modules. + param_clss_fn (Optional[Callable[[str], Any]]): + A function that maps original full qualified parameter names to their class IDs. + If you are using a hybrid optimizer, + you must specify this function + and the return value of this function must be a tuple[int, int] of (optimizer_index, param_group_index). **kwargs: the kwargs for optimizer constructor Returns: @@ -1285,7 +1360,6 @@ def build_optimizer( and will be patched with the methods in ParallelModule class to support parallelized module. Please note the type annotation of the returned optimizer (`Union[OptimizerT, ParallelOptimizer]`) is just for intellisense. """ - if isinstance(module, CubeModule) and not isinstance(module, ParallelModule): raise RuntimeError("Old style CubeModule is not supported") @@ -1293,12 +1367,17 @@ def build_optimizer( if any(m != module and isinstance(m, ParallelModule) and m.compute_config.use_end2end for m in module.modules()): raise RuntimeError("End2End module cannot be nested in another module") + is_hybrid = getattr(optimizer_fn, 'is_hybrid', False) + if is_hybrid and param_clss_fn is None: + raise ValueError("param_clss_fn must be provided when using hybrid optimizer") + RuntimeFlag.skip_reducer = True RuntimeFlag.skip_zero_grad = False non_parallel_module_reducer = None non_parallel_modules = [m for m in module.modules() if not isinstance(m, ParallelModule)] parallel_modules = [m for m in module.modules() if isinstance(m, ParallelModule)] + parallel_modules_prefix = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} if not parallel_modules: raise RuntimeError("No ParallelModule found in the module. Please make sure you have called parallelize() before build_optimizer().") @@ -1310,6 +1389,23 @@ def build_optimizer( non_parallel_parameters_dict[param] = None non_parallel_parameters = list(non_parallel_parameters_dict.keys()) + param_original_names = {} + for n, p in module.named_parameters(): + nparts = n.split('.') + module_prefix = '.'.join(nparts[:-1]) + if module_prefix in parallel_modules_prefix: + name_mapping = parallel_modules_prefix[module_prefix].get_full_map() + original_name = name_mapping[nparts[-1]].orig_name + param_original_names[p] = \ + f'{module_prefix}.{original_name}' if module_prefix else original_name + else: + param_original_names[p] = n + + if param_clss_fn: + param_clss = {p: param_clss_fn(n) for p, n in param_original_names.items()} + else: + param_clss = {} + # check if all ParallelModules have the same gpu_config compute_configs = [m.compute_config for m in parallel_modules] for i in range(1, len(compute_configs)): @@ -1331,7 +1427,9 @@ def build_optimizer( if compute_config: reducer_config = { 'async_op': compute_config.use_async_reducer, - 'zero': compute_config.use_zero, + # zero3 can't be used in non-parallel module reducer + # because we are unable to insert hooks to prefetch/postevict params + 'zero': 1 if compute_config.use_zero else 0, 'max_bucket_size_bytes': compute_config.max_bucket_size_bytes, 'zero_use_reduce_scatter': compute_config.zero_use_reduce_scatter, 'zero_ngroups': compute_config.zero_ngroups, @@ -1339,7 +1437,13 @@ def build_optimizer( non_parallel_module_reducer = Reducer(group, **reducer_config) for param in non_parallel_parameters: non_parallel_module_reducer.add_param(param) - non_parallel_module_reducer.build_buckets() + non_parallel_module_reducer.build_buckets(param_clss=param_clss) + + if param_clss_fn: + for pm in parallel_modules: + pm.build_buckets(param_clss=param_clss) + for reducer in pm.reducers: + param_clss.update(reducer.get_opt_params()) opt_module_locs: Dict[str, ModuleParameterLocation] = {} def _local_parameters(module: torch.nn.Module): @@ -1372,7 +1476,13 @@ def _local_parameters(module: torch.nn.Module): opt_module_locs[name].count += 1 yield param - optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), **kwargs) + if is_hybrid: + optimizer = optimizer_fn(_local_parameters(module), + param_clss, + **kwargs + ) + else: + optimizer: torch.optim.Optimizer = optimizer_fn(_local_parameters(module), **kwargs) optimizer._non_parallel_module_reducer = non_parallel_module_reducer optimizer._extra_state = OptimizerExtraState( rank=torch.distributed.get_rank(), @@ -1385,21 +1495,21 @@ def _local_parameters(module: torch.nn.Module): } ) - def _step_pre_hook(opt, *args, **kwargs): - opt.sync_shard_grad() - - def _step_post_hook(opt, *args, **kwargs): + orig_step = optimizer.step + def _patched_step(self, closure=None): + # Please note: + # when closure is used in optimizer.step() + # the backward is done in closure, + # and it is useless to sync grad because grad is still unavailable there + # so you must call sync_shard_grad() manually in this case. + if closure is None: + self.sync_shard_grad() + orig_step(closure=closure) for m in parallel_modules: m.gather_params() if non_parallel_module_reducer: non_parallel_module_reducer.gather_params() - - # Please note: - # register_step_pre_hook doesn't work expectly - # when closure is used in optimizer.step() - # in that case, you must call sync_shard_grad() manually - optimizer.register_step_pre_hook(_step_pre_hook) - optimizer.register_step_post_hook(_step_post_hook) + optimizer.step = types.MethodType(_patched_step, optimizer) orig_zero_grad = optimizer.zero_grad def _patched_zero_grad(self, set_to_none: bool = True): @@ -1574,6 +1684,13 @@ def _get_parallel_module_state_dict_info( return pm_extra_states, pm_state_dicts, non_pm_state_dict +def _is_supported_optimizer(name: str): + from nnscaler.runtime.hybrid_optimizer import HybridOptimizer + return ('adam' in name.lower()) \ + or ('muon' in name.lower()) \ + or name == HybridOptimizer.__name__ + + def _get_optimizer_state_dict_info( optimizer_state_dicts: List[Dict[str, Any]] ) -> Tuple[ @@ -1630,8 +1747,8 @@ def _get_optimizer_state_dict_info( ] = {} for opt_state_dict in optimizer_state_dicts: opt_extra_state = OptimizerExtraState(**opt_state_dict[ParallelModule.EXTRA_STATE_KEY]) - if 'adam' not in opt_extra_state.name.lower(): - raise ValueError("Only Adam-like optimizers are supported.") + if not _is_supported_optimizer(opt_extra_state.name): + raise ValueError("Only Adam-like or Muon-like optimizers are supported.") opt_extra_states[opt_extra_state.rank] = opt_extra_state for module_prefix, loc in opt_extra_state.parallel_module_locs.items(): @@ -1675,11 +1792,15 @@ def merge_state_dicts( Please Note: We don't garantee the devices of tensors are the same in the merged state dict. - You can assume the device of the tensors in the merged state dict can be one of the following: + You can assume the device of the tensors in the merged state dict + can be 'cpu' or the device of the tensor in the original state dict. - 1. the current device when running this function - 2. the current cuda device when running this function - 3. the device of the tensor in the original state dict + Quick Explanation: + In current implementation, + For non-parallel modules, we directly take the tensor from input state dicts + For parallel modules, we will create new tensors from cpu, and copy/merge the tensors from input state dicts to it. + (this may be optimized later as we can avoid copying for replicated tensors.) + So in summary, the devices of the tensors in output state dicts can be either 'cpu' or the device in original state dict. When you load the state dict from file, you can just use `torch.load(..., map_location='...')` to unify the device of the tensors. @@ -1703,10 +1824,10 @@ def _sort_state_dicts(state_dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]] sorted_state_dicts =[None] * len(state_dicts) for state_dict in state_dicts: rank = _get_state_dict_rank(state_dict) + if rank >= len(state_dicts): + raise ValueError(f"Invalid rank {rank} in state_dicts with length {len(state_dicts)}.") if sorted_state_dicts[rank] is not None: raise ValueError(f"Duplicate rank {rank} in state_dicts.") - if rank >= len(state_dicts): - raise ValueError(f"Invalid rank {rank} in state_dicts.") sorted_state_dicts[rank] = state_dict return sorted_state_dicts @@ -1738,7 +1859,7 @@ def _sort_state_dicts(state_dicts: List[Dict[str, Any]]) -> List[Dict[str, Any]] module_prefix = '.'.join(k) opt_state_dicts_for_merge = None if opt_state_dicts is None else opt_state_dicts[module_prefix] - merge_partial_states_zero_idx_maps = [(e.model_idx2opt_idx, e.opt_idx2ranks) for e in extra_states] + merge_partial_states_zero_idx_maps = [(e.model_idx2opt_idx, e.opt_idx2ranks, e.zero, e.zero3_param_metadata) for e in extra_states] if not extra_states[0].compute_config.use_zero: # all ranks should have the same use_zero merge_partial_states_zero_idx_maps = None merged_state_dict, merged_opt_state_dict = ParallelModule.merge_state_dicts( @@ -1856,6 +1977,8 @@ def load_merged_state_dict( """ device = device or torch.cuda.current_device() + module.to(device) + # non ParallelModule parameters will be loaded here # there will be mismatched keys if the module is a ParallelModule or contains ParallelModule # so we need to ignore the mismatched keys @@ -1866,94 +1989,190 @@ def load_merged_state_dict( prefix = name + '.' if name else '' child_module.load_merged_state_dict(module_state_dict, prefix=prefix) - module.to(device) - if optimizer is not None and optimizer_state_dict is not None: - if 'adam' not in optimizer._extra_state.name.lower(): - raise ValueError("Only Adam-like optimizers are supported.") - - # handle non-paralleled module parameters - # make sure the order of the parameters - pm_name_locs: Dict[str, ModuleParameterLocation] = dict(sorted(optimizer._extra_state.parallel_module_locs.items(), key=lambda x: x[1].offset)) - pm_modules: List[torch.nn.Module] = [] - pm_locs = list(pm_name_locs.values()) - for name in pm_name_locs: - m = get_member_by_name(module, name) - if not isinstance(m, ParallelModule): - raise ValueError(f"Module {name} is not a ParallelModule") - pm_modules.append(m) - - merged_cur = 0 # the current index of the merged state dict - pm_cur = 0 # the current index of the parallel module in pm_locs - new_states: Dict[int, Dict[str, Any]] = {} - new_cur = 0 # the current index of the new state dict - assert len(optimizer_state_dict['param_groups']) == 1 - effective_state_len = len(optimizer_state_dict['param_groups'][0]['params']) - while merged_cur < effective_state_len: - # N: non-paralleled module parameters, P: paralleled module (will have multiple parameters) - # The parameter list would look like: NNPNPPPN - # []: the current processing parameter - # <>: the current processing parallel module - if ( - pm_cur >= len(pm_modules) # NNPNPPP[N]: the ending parameters, no current parallel module - or new_cur < pm_locs[pm_cur].offset # [N]N

NPPPN: other parameters - ): - # non-parallel module - if merged_cur in optimizer_state_dict['state']: - new_states[new_cur] = optimizer_state_dict['state'][merged_cur] - merged_cur += 1 - new_cur += 1 - else: - # NNPN<[P]PP>N: the current parallel module - # parallel module - pm_param_count = len(pm_modules[pm_cur]._orign_module_metadata.origin_param_names) - # will map `pm_param_count` parameters in merge state dict - # to `pm_locs[pm_cur].count` in optimizer state. - cur_states = {} - for i in range(pm_param_count): - if merged_cur + i in optimizer_state_dict['state']: - cur_states[i] =optimizer_state_dict['state'][merged_cur + i] - pm_new_states = _opt_load_merged_state_dict(pm_modules[pm_cur], cur_states) - for idx, value in pm_new_states.items(): - new_states[new_cur + idx] = value - new_cur += pm_locs[pm_cur].count - merged_cur += pm_param_count - pm_cur += 1 - - # move the new states to the device if needed - for idx, state in new_states.items(): - for key, value in state.items(): - if isinstance(value, torch.Tensor): - new_states[idx][key] = value.to(device) - - new_optimizer_state_dict = {} - new_optimizer_state_dict['state'] = new_states - new_optimizer_state_dict['param_groups'] = copy.deepcopy(optimizer_state_dict['param_groups']) - new_optimizer_state_dict['param_groups'][0]['params'] = list(range(new_cur)) + new_optimizer_state_dict = _trim_optimizer_merged_state_dict(module, optimizer._extra_state, optimizer_state_dict, device='cpu') optimizer.load_state_dict(new_optimizer_state_dict) +def _trim_optimizer_merged_state_dict( + module: torch.nn.Module, + opt_extra_state: OptimizerExtraState, + optimizer_state_dict: Dict[str, Any], + *, + device: Union[str, torch.device] = None +) -> Dict[str, Any]: + """ + Trim the merged state dict to only keep the states needed for the optimizer. + + Args: + module (torch.nn.Module): the module to be loaded + opt_extra_state (OptimizerExtraState): the extra state of the optimizer + optimizer_state_dict (Dict[str, Any]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the optimizer state dict. + + Returns: + Dict[str, Any]: the trimmed optimizer state dict + """ + if not _is_supported_optimizer(opt_extra_state.name): + raise ValueError("Only Adam-like or Muon-like optimizers are supported.") + + device = device or torch.cuda.current_device() + + # handle non-paralleled module parameters + # make sure the order of the parameters + pm_name_locs: Dict[str, ModuleParameterLocation] = dict(sorted(opt_extra_state.parallel_module_locs.items(), key=lambda x: x[1].offset)) + pm_modules: List[ParallelModule] = [] + pm_locs = list(pm_name_locs.values()) + for name in pm_name_locs: + m = get_member_by_name(module, name) + if not isinstance(m, ParallelModule): + raise ValueError(f"Module {name} is not a ParallelModule") + pm_modules.append(m) + + merged_cur = 0 # the current index of the merged state dict + pm_cur = 0 # the current index of the parallel module in pm_locs + new_states: Dict[int, Dict[str, Any]] = {} + new_cur = 0 # the current index of the new state dict + assert len(optimizer_state_dict['param_groups']) == 1 + effective_state_len = len(optimizer_state_dict['param_groups'][0]['params']) + while merged_cur < effective_state_len: + # N: non-paralleled module parameters, P: paralleled module (will have multiple parameters) + # The parameter list would look like: NNPNPPPN + # []: the current processing parameter + # <>: the current processing parallel module + if ( + pm_cur >= len(pm_modules) # NNPNPPP[N]: the ending parameters, no current parallel module + or new_cur < pm_locs[pm_cur].offset # [N]N

NPPPN: other parameters + ): + # non-parallel module + if merged_cur in optimizer_state_dict['state']: + new_states[new_cur] = optimizer_state_dict['state'][merged_cur] + merged_cur += 1 + new_cur += 1 + else: + # NNPN<[P]PP>N: the current parallel module + # parallel module + pm_param_count = len(pm_modules[pm_cur].origin_module_metadata.origin_param_names) + # will map `pm_param_count` parameters in merge state dict + # to `pm_locs[pm_cur].count` in optimizer state. + cur_states = {} + for i in range(pm_param_count): + if merged_cur + i in optimizer_state_dict['state']: + cur_states[i] =optimizer_state_dict['state'][merged_cur + i] + pm_new_states = _opt_load_merged_state_dict(pm_modules[pm_cur], cur_states) + for idx, value in pm_new_states.items(): + new_states[new_cur + idx] = value + new_cur += pm_locs[pm_cur].count + merged_cur += pm_param_count + pm_cur += 1 + + # move the new states to the device if needed + for idx, state in new_states.items(): + for key, value in state.items(): + if isinstance(value, torch.Tensor): + new_states[idx][key] = value.to(device) + + new_optimizer_state_dict = {} + new_optimizer_state_dict['state'] = new_states + new_optimizer_state_dict['param_groups'] = copy.deepcopy(optimizer_state_dict['param_groups']) + new_optimizer_state_dict['param_groups'][0]['params'] = list(range(new_cur)) + + return new_optimizer_state_dict + + def _opt_load_merged_state_dict(module: ParallelModule, states: Dict[int, Dict[str, Any]]): + """ + Args: + module (ParallelModule): the parallel module + states (Dict[int, Dict[str, Any]]): the merged optimizer state dict for a parallel module + key: optimizer parameter index in the merged state dict + value: the state dict for each attribute, e.g. 'step', 'exp_avg', 'exp_avg_sq' are keys + """ with torch.no_grad(): # orig_name -> state + # state: Dict[str, Any], e.g. 'step', 'exp_avg', 'exp_avg_sq' are keys orig_param_dict: Dict[str, Dict[str, Any]] = {} cnt = 0 - origin_param_names = module._orign_module_metadata.origin_param_names + origin_param_names = module.origin_module_metadata.origin_param_names for name in origin_param_names: if cnt in states: # some parameters may not in the sates when it is not used or requires_grad is False in training orig_param_dict[name] = states[cnt] cnt = cnt + 1 - if module.compute_config.use_zero: + if module.compute_config.use_zero == 1: return _construct_optim_state_zero(module, orig_param_dict) + elif module.compute_config.use_zero > 1: + return _construct_optim_state_zero3(module, orig_param_dict) else: return _construct_optim_state_nonzero(module, orig_param_dict) +def _construct_optim_state_zero3( + module: ParallelModule, + orig_param_dict: Dict[str, Dict[str, Any]] +): + # state for each parameter in the parallel module + new_states = _construct_optim_state_nonzero(module, orig_param_dict) + param_state_map = {p: new_states[idx] for idx, p in enumerate(module.parameters())} + + state_dict, opt_param_idx = {}, 0 + opt_param = module.parameters_for_optimizer() + # first load the params' optimizer state for the reducers's flattened params + for reducer in module.reducers: + for bucket in reducer.buckets: + bucket: Bucket + # one bucket corresponds to one flattened param + assert len(opt_param[opt_param_idx].shape) == 1 + chunk_size = bucket._contiguous_params.shape[0] + opt_states = {} + offset = 0 + for param in bucket.params: + sliced_new_val = param_state_map[param] + param_numel = bucket.get_aligned_numel(param) + # init the optimizer state + if not opt_states: + for key in sliced_new_val.keys(): + if key == 'step': + opt_states[key] = sliced_new_val[key] + else: + opt_states[key] = torch.zeros( + [chunk_size], dtype=sliced_new_val[key].dtype, + device='cpu', requires_grad=False + ) + # copy the param's slices to the optimizer's chunk + for key in opt_states.keys(): + if key == 'step': + continue + opt_states[key][offset:offset+sliced_new_val[key].numel()] = sliced_new_val[key] + + offset += param_numel + state_dict[opt_param_idx] = opt_states + opt_param_idx += 1 + + # load the params' optimizer state that are not in reducers + reducer_pids = set() + for reducer in module.reducers: + reducer_pids.update(id(p) for p in reducer.params) + for param in module.parameters(): + if id(param) not in reducer_pids: + state_dict[opt_param_idx] = param_state_map[param] + opt_param_idx += 1 + + return state_dict + + def _construct_optim_state_zero( module: ParallelModule, orig_param_dict: Dict[str, Dict[str, Any]], ): + """ + Construct the optimizer state for a ParallelModule with ZeRO optimization. + Args: + module (ParallelModule): the parallel module + orig_param_dict (Dict[str, Dict[str, Any]]): the original parameter optimizer state + key: original parameter name + value: the state dict for each attribute, e.g. 'step', 'exp_avg', 'exp_avg_sq' are keys + """ dist_param_map = module.dist_param_map # name in parallel module (without tid suffix) -> name in origin module param_area_map = module.fullmap # str -> AttrMeta def _get_optimizer_state_of_param(param, param_ids, local_names): @@ -1999,7 +2218,7 @@ def _get_optimizer_state_of_param(param, param_ids, local_names): opt_state_keys.remove('step') for key in opt_state_keys: opt_states[key] = torch.zeros([chunk_size], dtype=sliced_new_val[key].dtype, - device=sliced_new_val[key].device, requires_grad=False) + device='cpu', requires_grad=False) # copy the param's slices to the optimizer's chunk for key in opt_state_keys: sliced_new_val[key] = sliced_new_val[key].view(-1) @@ -2070,9 +2289,12 @@ def _construct_optim_state_nonzero( dist_param_map = module.dist_param_map # name in parallel module (without tid suffix) -> name in origin module param_area_map = module.fullmap # str -> AttrMeta - new_states = {} + new_states: dict[int, dict[str, torch.Tensor]] = {} for index, (local_name, _) in enumerate(module.named_parameters()): - new_states[index] = _extract_new_state(local_name, orig_param_dict, dist_param_map, param_area_map) + new_states[index] = _extract_new_state( + local_name, orig_param_dict, dist_param_map, param_area_map, + module.get_zero3_attr_meta(local_name) + ) return new_states @@ -2082,7 +2304,8 @@ def _extract_new_state( orig_param_dict: Dict[str, Dict[str, Any]], dist_param_map: Dict[str, str], param_area_map: Dict[str, AttrMeta], -): + zero3_info: Optional[Zero3AttrMeta] = None +) -> Dict[str, torch.Tensor]: name = '_'.join(local_name.split('_')[:-1]) # remove the integer suffix assert name in dist_param_map attr_meta = param_area_map[local_name] @@ -2093,6 +2316,16 @@ def _extract_new_state( sliced_new_val[key] = new_val[key] else: sliced_new_val[key] = new_val[key][attr_meta.slicers] / attr_meta.val_chunks + if zero3_info is not None: + sliced_new_val[key] = sliced_new_val[key].view(-1)[zero3_info.start:zero3_info.end] + if sliced_new_val[key].numel() < zero3_info.chunk_size: + # padding if needed + sliced_new_val[key] = torch.nn.functional.pad( + sliced_new_val[key].cpu(), + (0, zero3_info.chunk_size - sliced_new_val[key].numel()), + mode='constant', + value=0.0 + ) return sliced_new_val @@ -2147,61 +2380,95 @@ def _broadcast_gen_files( return curr_rank = torch.distributed.get_rank() - ranks = list(range(0, world_size, local_world_size)) - group = DeviceGroup().get_group(ranks) - - # use the first rank of each node to broadcast - if curr_rank % local_world_size == 0: - _, outdir = _prepare_namespace(gen_savedir, module_class, instance_name) - files: List[str] = [] - # send file list - if curr_rank == 0: - for file in outdir.glob('*'): - if file.is_file() and ( - broadcast_strategy == BroadcastGenFilesStrategy.ALL or - ( - broadcast_strategy == BroadcastGenFilesStrategy.NO_WEIGHTS - and not file.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) - ) or - ( - # broadcast code files and compute config file - # please note the compute config file can be updated - # even when the graph is reused. - broadcast_strategy == BroadcastGenFilesStrategy.CODE - and (file.suffix == '.py' or file.name == ParallelModule.COMPUTE_CONFIG_FILE) - ) - ): - files.append(file.name) - sent_obj = [files] + + # use all ranks of each node to broadcast + _, outdir = _prepare_namespace(gen_savedir, module_class, instance_name) + files: List[str] = [] + # send file list + if curr_rank == 0: + for file in outdir.glob('*'): + if file.is_file() and ( + broadcast_strategy == BroadcastGenFilesStrategy.ALL or + ( + broadcast_strategy == BroadcastGenFilesStrategy.NO_WEIGHTS + and not file.name.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM) + ) or + ( + # broadcast code files and compute config file + # please note the compute config file can be updated + # even when the graph is reused. + broadcast_strategy == BroadcastGenFilesStrategy.CODE + and (file.suffix == '.py' or file.name == ParallelModule.COMPUTE_CONFIG_FILE) + ) + ): + files.append(file.name) + sent_obj = [files] + else: + sent_obj = [None] + torch.distributed.broadcast_object_list( + sent_obj, + src=0, + ) + # get file list + if curr_rank != 0: + files = sent_obj[0] + + logger.info(f'File list broadcasted ({len(files)} in total).') + + grouped_files = [[]] # 0th groups for small files (attribute content files excluded) + for fname in files: + if not fname.startswith(FxModuleParser.ATTR_CONTENT_FILE_STEM): + grouped_files[0].append(outdir / fname) else: - sent_obj = [None] - torch.distributed.broadcast_object_list( - sent_obj, - src=0, - group=group, - ) - # get file list - if curr_rank != 0: - files = sent_obj[0] - - logging.info(f'File list broadcasted ({len(files)} in total).') - # send file content one by one - for fname in files: - if curr_rank == 0: - with open(outdir / fname, 'rb') as f: - data = [f.read()] - else: - data = [None] - torch.distributed.broadcast_object_list(data, src=0, group=group) - if curr_rank != 0: - with open(outdir / fname, 'wb') as f: - f.write(data[0]) - logging.info(f'File {fname} broadcasted.') + grouped_files.append([outdir / fname]) + + broadcast_files(grouped_files) # wait for all nodes to finish torch.distributed.barrier() +def _collect_dedup_info(parallel_modules: Dict[str, ParallelModule]) -> Tuple[ + Dict[int, Dict[str, Dict[str, AttrMeta]]], + Dict[str, int], + Dict[int, Dict[str, Dict[str, AttrMeta]]] +]: + """ + A helper function that computes the deduplicated attribute information from all ranks. + Note that this function may be removed in the future and dedup information are computed + directly at the compilation stage. + + Returns: + A tuple containing: + - rank2deduped_fullmap: a mapping from rank id to deduplicated attribute information + - dedup_group_size: the size of the deduplication group for each parallel module + - global_fullmaps: a mapping from rank id to full attribute information + """ + dedup_group_size = {} + for prefix, parallel_module in parallel_modules.items(): + dedup_group_size[prefix] = parallel_module.module_dedup_group_size + + world_size = torch.distributed.get_world_size() + global_fullmaps: Dict[ + int, # rank id + Dict[str, # submodule prefix + Dict[str, # attribute name in parallel module + AttrMeta]] + ] = {} + for rank in range(world_size): + global_fullmaps[rank] = {} + for prefix, m in parallel_modules.items(): + global_fullmaps[rank][prefix] = m.get_attr_meta_map(rank) + # `dedup_attrs` is a deterministic algorithm, so it produces same results across different ranks + rank2deduped_fullmap = dedup_attrs(global_fullmaps) + + for prefix, group_size in dedup_group_size.items(): + for rank in range(group_size, world_size): + assert len(rank2deduped_fullmap[rank].get(prefix, {})) == 0, f'Rank {rank} has non-empty deduped_fullmap: {rank2deduped_fullmap[rank]}' + + return rank2deduped_fullmap, dedup_group_size, global_fullmaps + + @torch.no_grad() def deduped_state_dict( module: torch.nn.Module, @@ -2224,6 +2491,9 @@ def deduped_state_dict( module_state_dict, opt_state_dict = None, None parallel_modules = {prefix: m for prefix, m in module.named_modules() if isinstance(m, ParallelModule)} + rank2deduped_fullmap, _, _ = _collect_dedup_info(parallel_modules) + cur_deduped_fullmap = rank2deduped_fullmap[cur_rank] + # The reason we use `Module.state_dict` on the whole to get the complete state dict # instead of call `Module.state_dict` on each submodule # is to make sure the hooks to state_dict are called. @@ -2231,11 +2501,18 @@ def deduped_state_dict( for key in list(module_state_dict.keys()): if key.endswith(ParallelModule.EXTRA_STATE_KEY): # never remove extra state continue - prefix = '.'.join(key.split('.')[:-1]) # remove the last part of the key - dedup_group_size = parallel_modules[prefix].module_dedup_group_size \ - if prefix in parallel_modules else 1 - # only keep the first `dedup_group_size` ranks' state - if cur_rank >= dedup_group_size: + split_names = key.split('.') + prefix = '.'.join(split_names[:-1]) # remove the last part of the key + if prefix in parallel_modules: + if parallel_modules[prefix].compute_config.use_zero > 1: + # for zero3, we don't use advanced deduplication. + # TODO: handle zero3 case + if cur_rank >= parallel_modules[prefix].module_dedup_group_size: + module_state_dict.pop(key, None) + elif prefix not in cur_deduped_fullmap or split_names[-1] not in cur_deduped_fullmap[prefix]: + module_state_dict.pop(key, None) + # since replicated non-parallel modules, we only keep weights on rank 0 + elif cur_rank >= 1: module_state_dict.pop(key, None) if optimizer is not None: @@ -2267,7 +2544,7 @@ def load_deduped_state_dict( module: torch.nn.Module, module_state_dict: Dict[str, Any], optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, - optimizer_state_dict: Optional[Dict[str, Any]] = None, + optimizer_state_dict: Optional[OptStateDict] = None, *, device: Union[str, torch.device] = None ) -> None: @@ -2285,25 +2562,112 @@ def load_deduped_state_dict( None """ device = device or torch.cuda.current_device() + cur_rank = torch.distributed.get_rank() - # only load partial state for all ranks except rank 0 - module.load_state_dict(module_state_dict, strict=False) module.to(device) + + # step 1: load deduped state dict at each rank + missing_keys, unexpected_keys = module.load_state_dict(module_state_dict, strict=False) torch.distributed.barrier() + logger.debug(f'At rank {cur_rank}, state_dict keys: {module_state_dict.keys()}.') + logger.debug(f'At rank {cur_rank}, missing_keys: {missing_keys}, unexpected_keys: {unexpected_keys}.') + + # step 2: broadcast deduped weights inside 1st scale unit for non-zero3 parallel modules + # for zero3 modules, the weights are already complete after step 1 + # TODO: refine zero3 modules support + no_zero3_pms = { + prefix: m + for prefix, m in module.named_modules() + if isinstance(m, ParallelModule) and m.compute_config.use_zero <= 1 + } + if no_zero3_pms: + rank2deduped_fullmap, dedup_group_size, _ = _collect_dedup_info(no_zero3_pms) + logger.debug(f'At rank {cur_rank}, dedup_group_size: {dedup_group_size}, rank2deduped_fullmap: {rank2deduped_fullmap}.') + + # collect dedup info from attr meta maps + # Key: (prefix, local_name) + # Value: list[rank]: a list of ranks that have the local_name + local_name2ranks: Dict[tuple[str, str], list[int]] = {} + + for prefix, m in no_zero3_pms.items(): + for rank in range(dedup_group_size[prefix]): + for local_name, _ in m.get_attr_meta_map(rank).items(): + key = (prefix, local_name) + if key not in local_name2ranks: + local_name2ranks[key] = [] + local_name2ranks[key].append(rank) + + # create process groups for broadcasting + for key, ranks in local_name2ranks.items(): + if len(ranks) <= 1: + continue + # should have sorted. + ranks.sort() + logger.debug(f'At rank {cur_rank}, create groups for ranks: {ranks}.') + DeviceGroup().get_group(ranks) - # broadcast weights - broadcast_weights(module) + torch.distributed.barrier() - if optimizer is not None: - if 'adam' not in optimizer._extra_state.name.lower(): - raise ValueError("Only Adam-like optimizers are supported.") - if optimizer_state_dict is None: - raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") + # broadcast weights in parallel modules inside dedup group (most time it is the 1st scale unit) + # Implementation of `deduped_state_dict` can guarantee that the first rank in each rank group always has the weights + for key_name, ranks in local_name2ranks.items(): + if len(ranks) <= 1: + continue + prefix, local_name = key_name + if cur_rank in ranks: + key = f'{prefix}.{local_name}' if prefix else local_name + broadcast_group = DeviceGroup().get_group(ranks) + assert prefix in no_zero3_pms, f'Prefix {prefix} not found in parallel_modules: {list(no_zero3_pms.keys())}.' + pm = no_zero3_pms[prefix] + assert hasattr(pm, local_name), f'Local name {local_name} not found in {pm}.' + # the shared tensor will always store in the smallest rank in the dedup group + if cur_rank == ranks[0]: + broadcast_tensor = getattr(pm, local_name) + logger.info(f'Broadcast: {key} from {cur_rank}.') + else: + existing_tensor = None + logger.info(f'At rank {cur_rank}, try to load: {key} from rank {ranks[0]}.') + attr = getattr(pm, local_name) - for idx, state in optimizer_state_dict['state'].items(): - for key, value in state.items(): - if isinstance(value, torch.Tensor): - optimizer_state_dict['state'][idx][key] = value.to(device) + broadcast_tensor = attr.data + if key in missing_keys: + missing_keys.remove(key) + else: + # the tensor is already loaded, we need to check if they are equal + # it should not come here if _collect_dedup_info is strict + existing_tensor = broadcast_tensor.cpu() + + logger.debug(f'At rank {cur_rank}, broadcast from {ranks[0]} to {ranks} for `{key}`.') + torch.distributed.broadcast(broadcast_tensor, src=ranks[0], group=broadcast_group) + + if cur_rank != ranks[0]: + # it should not come here if _collect_dedup_info is strict + # anyway, we add an assertion here to make sure + if existing_tensor is not None: + assert torch.equal(existing_tensor, broadcast_tensor.cpu()), \ + f'At rank {cur_rank}, the attribute {key} is already loaded, ' \ + f'but not equal to the broadcasted tensor from rank {ranks[0]}.' + + torch.distributed.barrier() + + for key in missing_keys: + split_names = key.split('.') + prefix = '.'.join(split_names[:-1]) # remove the last part of the key + assert prefix not in no_zero3_pms or cur_rank >= dedup_group_size[prefix], f'At rank {cur_rank}, the missing key {key} should be in non-parallel modules.' + + # At this point + # - All parallel modules in first scale unit should be complete. + # - Non-parallel modules in rank0 should be complete. The rest ranks will get the weights via broadcast_weights. + torch.distributed.barrier() + + # step 3: + # - broadcast non-parallel module weights from 0th rank to other ranks + # - broadcast parallel modules weights from 1st scale unit to other units + broadcast_weights(module) + + if optimizer is not None and optimizer_state_dict is not None: + if not _is_supported_optimizer(optimizer._extra_state.name): + raise ValueError("Only Adam-like or Muon-like optimizers are supported.") # get the locations of non-parallel module parameters # by removing the parallel module locations @@ -2325,12 +2689,13 @@ def load_deduped_state_dict( for bg in opt_broadcast_groups: _broadcast_opt_state(optimizer_state_dict, *bg) + optimizer.load_state_dict(optimizer_state_dict) - torch.distributed.barrier() + torch.distributed.barrier() -def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_group_size: int): +def _broadcast_opt_state(optimizer_state_dict: OptStateDict, state_indexes: List[int], dedup_group_size: int): if not state_indexes: return @@ -2338,7 +2703,7 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g broadcast_group = setup_stride_broadcast_group(dedup_group_size) src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks - logging.info(f'Rank-{rank} is broadcasting optimizer states to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') + logger.info(f'Rank-{rank} is broadcasting optimizer states to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') # broadcast param groups and state keys/shapes/dtypes via broadcast_object_list if rank == src_rank: @@ -2359,25 +2724,37 @@ def _broadcast_opt_state(optimizer_state_dict, state_indexes: List[int], dedup_g key: torch.zeros(value[0], dtype=value[1], device=torch.cuda.current_device()) for key, value in v.items() } + else: + for idx in state_indexes: + for key, value in optimizer_state_dict['state'][idx].items(): + optimizer_state_dict['state'][idx][key] = optimizer_state_dict['state'][idx][key].cuda() # broadcast step # step is too small, so we can just broadcast all of them all together # some adam/adamw optimizers may not have step in their state dict # so we need to check if 'step' is in the state dict - if 'step' in optimizer_state_dict['state'][state_indexes[0]]: + step_state_indexes = [k for k in state_indexes if 'step' in optimizer_state_dict['state'][k]] + if step_state_indexes: + assert all( + optimizer_state_dict['state'][k]['step'].dtype == + optimizer_state_dict['state'][step_state_indexes[0]]['step'].dtype and + optimizer_state_dict['state'][k]['step'].shape == + optimizer_state_dict['state'][step_state_indexes[0]]['step'].shape + for k in step_state_indexes + ) if rank == src_rank: step_stack = torch.stack( - [optimizer_state_dict['state'][k]['step'] for k in state_indexes] + [optimizer_state_dict['state'][k]['step'] for k in step_state_indexes] ) else: step_stack = torch.zeros( - len(state_indexes), - dtype=optimizer_state_dict['state'][state_indexes[0]]['step'].dtype, + len(step_state_indexes), + dtype=optimizer_state_dict['state'][step_state_indexes[0]]['step'].dtype, device=torch.cuda.current_device() ) torch.distributed.broadcast(step_stack, src=src_rank, group=curr_parallel_group) if rank != src_rank: - for k, v in zip(state_indexes, step_stack): + for k, v in zip(step_state_indexes, step_stack): optimizer_state_dict['state'][k]['step'].copy_(v) # broadcast other states @@ -2424,7 +2801,7 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): broadcast_group = setup_stride_broadcast_group(stride_size) rank = torch.distributed.get_rank() src_rank, curr_parallel_group, curr_parallel_group_ranks = broadcast_group.src_rank, broadcast_group.group, broadcast_group.ranks - logging.info(f'Rank-{rank} is broadcasting weights of {module.__class__.__name__} to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') + logger.info(f'Rank-{rank} is broadcasting weights of {module.__class__.__name__} to ranks {curr_parallel_group_ranks}, broadcast source: {src_rank}...') if isinstance(module, ParallelModule): if not _broadcast_single_value(src_rank, curr_parallel_group, module.non_presistent_buffers_inited): @@ -2432,15 +2809,15 @@ def _broadcast_weights(module: torch.nn.Module, stride_size: int): # we have a special optimization for ParallelModule params = module.parameters_for_broadcast() if isinstance(module, ParallelModule) else list(module.parameters(False)) - logging.info(f'Inplace broadcasting {len(params)} parameters...') + logger.info(f'Inplace broadcasting {len(params)} parameters...') for i, param in enumerate(params): torch.distributed.broadcast(param.data, src=src_rank, group=curr_parallel_group) - logging.info(f'Inplace broadcasted {i+1}/{len(params)} parameters') + logger.info(f'Inplace broadcasted {i+1}/{len(params)} parameters') # NOTE: may batch buffers for efficient broadcast, # current implementation is the most memory efficient way. buffers = list(module.buffers(False)) - logging.info(f'Inplace broadcasting {len(buffers)} buffers...') + logger.info(f'Inplace broadcasting {len(buffers)} buffers...') for buffer in buffers: torch.distributed.broadcast(buffer.data, src=src_rank, group=curr_parallel_group) @@ -2475,11 +2852,10 @@ def load_sharded_state_dict( """ device = device or torch.cuda.current_device() - module.load_state_dict(module_state_dict) module.to(device) - if optimizer: - if optimizer_state_dict is None: - raise ValueError("optimizer_state_dict should be provided when optimizer is not None.") + + module.load_state_dict(module_state_dict) + if optimizer and optimizer_state_dict: optimizer.load_state_dict(optimizer_state_dict) @@ -2513,3 +2889,451 @@ def sync_grad_when(cond: bool): cond (bool): whether to synchronize gradients. """ return _runtime_flags(skip_reducer=not cond) + + +def _construct_parallel_module_stub(metadata): + pmodules = {prefix: ParallelModule._unpack(minfo) for prefix, minfo in metadata.items()} + + # whole parallel module + if len(pmodules) == 1 and list(pmodules.keys())[0] == '': + module = pmodules[''] + else: + module = torch.nn.Module() + for prefix, pmodule in pmodules.items(): + set_member_by_name(module, prefix, pmodule) + + # mock `named_modules` to list parallel modules in stub module + def named_modules( + memo=None, + prefix: str = "", + remove_duplicate: bool = True, + ): + assert memo is None and prefix == '' and remove_duplicate is True, \ + "Only support default arguments" + return pmodules.items() + + module.named_modules = named_modules + + return module + + +def _trim_module_merged_state_dict( + module: torch.nn.Module, + module_state_dict: Dict[str, Any], + *, + device: Union[str, torch.device] = None, +): + device = device or torch.cuda.current_device() + + parallel_modules = {module_path: m for module_path, m in module.named_modules() if isinstance(m, ParallelModule)} + + trimmed_state_dict = {} + # collect non-parallel module parameters + for key, tensor in module_state_dict.items(): + parts = key.split('.') + if not any('.'.join(parts[:i]) in parallel_modules for i in range(0, len(parts))): + trimmed_state_dict[key] = tensor.to(device) + + for module_path, pmodule in parallel_modules.items(): + prefix = module_path + '.' if module_path else '' + trimmed_state_dict.update( + pmodule.trim_merged_state_dict( + module_state_dict, prefix=prefix, + device=device + ) + ) + return trimmed_state_dict + + +def _send_trimmed_module_state_dict( + trimmed_state_dict: Dict[str, torch.Tensor], + group: torch.distributed.ProcessGroup, + dst_rank: int, +): + """ + Send the trimmed state dict to the specified destination rank. + + Args: + trimmed_state_dict (Dict[str, torch.Tensor]): the trimmed state dict to send. + dst_rank (int): the destination rank to send the state dict to. + """ + # send trimmed state dict to rank + # one tensor each time + keys = list(trimmed_state_dict.keys()) + shape_dtypes = [(tensor.shape, tensor.dtype) for tensor in trimmed_state_dict.values()] + torch.distributed.send_object_list([keys, shape_dtypes], group=group, dst=dst_rank) + for key in keys: + tensor = trimmed_state_dict[key] + # NOTE: send is broken if the tensor is not contiguous + torch.distributed.send(tensor.cuda().contiguous(), group=group, dst=dst_rank) + + +def _receive_trimmed_module_state_dict( + src_rank: int, + group: torch.distributed.ProcessGroup, + device: Union[str, torch.device] = None, +): + """ + Receive the trimmed state dict from the specified source rank. + + Args: + src_rank (int): the source rank to receive the state dict from. + """ + device = device or torch.cuda.current_device() + + # receive trimmed state dict from rank + # one at a time + keys_shape_dtypes=[None, None] + torch.distributed.recv_object_list(keys_shape_dtypes, group=group, src=src_rank) + keys: list[str] = keys_shape_dtypes[0] + shape_dtypes: list[tuple[torch.Size, torch.dtype]] = keys_shape_dtypes[1] + + trimmed_state_dict = {} + for key, shape_dtype in zip(keys, shape_dtypes): + tensor = torch.zeros(shape_dtype[0], dtype=shape_dtype[1], device='cuda') + torch.distributed.recv(tensor, group=group, src=src_rank) + trimmed_state_dict[key] = tensor.to(device) + return trimmed_state_dict + + +def _send_trimmed_opt_state_dict( + trimmed_opt_state_dict: OptStateDict, + group: torch.distributed.ProcessGroup, + dst_rank: int, +): + """ + Send the trimmed optimizer state dict to the specified destination rank. + + Args: + trimmed_opt_state_dict (OptStateDict): the trimmed optimizer state dict to send. + dst_rank (int): the destination rank to send the state dict to. + """ + # send trimmed optimizer state dict to rank + # one tensor each time + + # broadcast param groups and state keys/shapes/dtypes via broadcast_object_list + state_info = {} + state_keys = list(trimmed_opt_state_dict['state'].keys()) + param_group = trimmed_opt_state_dict['param_groups'] + for idx in state_keys: + state_info[idx] = {key: (value.shape, value.dtype) for key, value in trimmed_opt_state_dict['state'][idx].items()} + sent = [state_keys, state_info, param_group] + torch.distributed.send_object_list(sent, group=group, dst=dst_rank) + + # broadcast step in stack + step_state_keys = [k for k in state_keys if 'step' in trimmed_opt_state_dict['state'][k]] + if step_state_keys: + assert all( + trimmed_opt_state_dict['state'][k]['step'].dtype == + trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].dtype and + trimmed_opt_state_dict['state'][k]['step'].shape == + trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].shape + for k in step_state_keys + ) + step_stack = torch.stack( + [trimmed_opt_state_dict['state'][k]['step'] for k in step_state_keys] + ) + torch.distributed.send(step_stack.cuda(), group=group, dst=dst_rank) + + # broadcast other states + # TODO: can be slow? + for k in state_keys: + keys = sorted(trimmed_opt_state_dict['state'][k].keys()) + if 'step' in keys: + keys.remove('step') # we have done step in previous. + for key in keys: + value = trimmed_opt_state_dict['state'][k][key] + torch.distributed.send(value.data.cuda(), group=group, dst=dst_rank) + + +def _receive_trimmed_opt_state_dict( + src_rank: int, + group: torch.distributed.ProcessGroup, + device: Union[str, torch.device] = None, + ) -> OptStateDict: + """ + Receive the trimmed optimizer state dict from the specified source rank. + + Args: + src_rank (int): the source rank to receive the state dict from. + """ + device = device or torch.cuda.current_device() + + # receive trimmed optimizer state dict from rank + # one at a time + state_dict_info = [None, None, None] + torch.distributed.recv_object_list(state_dict_info, group=group, src=src_rank) + state_keys: list[str] = state_dict_info[0] + state_info: list[tuple[torch.Size, torch.dtype]] = state_dict_info[1] + param_group = state_dict_info[2] + + trimmed_opt_state_dict = { + 'state': {}, + 'param_groups': param_group + } + for key in state_keys: + trimmed_opt_state_dict['state'][key] = { + k: torch.zeros(v[0], dtype=v[1], device=device) + for k, v in state_info[key].items() + } + + # receive steps + step_state_keys = [k for k in state_keys if 'step' in trimmed_opt_state_dict['state'][k]] + if step_state_keys: + assert all( + trimmed_opt_state_dict['state'][k]['step'].dtype == + trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].dtype and + trimmed_opt_state_dict['state'][k]['step'].shape == + trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].shape + for k in step_state_keys + ) + step_stack = torch.zeros( + len(step_state_keys), + dtype=trimmed_opt_state_dict['state'][step_state_keys[0]]['step'].dtype, + device='cuda' + ) + torch.distributed.recv(step_stack, group=group, src=src_rank) + for k, v in zip(step_state_keys, step_stack): + trimmed_opt_state_dict['state'][k]['step'].copy_(v) + + # receive other states + for k in state_keys: + keys = sorted(trimmed_opt_state_dict['state'][k].keys()) + if 'step' in keys: + keys.remove('step') # we have done step in previous. + for key in keys: + value = trimmed_opt_state_dict['state'][k][key].cuda() + torch.distributed.recv(value, group=group, src=src_rank) + trimmed_opt_state_dict['state'][k][key] = value.to(device) + + return trimmed_opt_state_dict + + +def trimmed_broadcast_merged_state_dict( + module: torch.nn.Module, + module_state_dict: Optional[Dict[str, Any]] = None, + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + src_rank: int = 0, + dst_ranks: Optional[list[int]] = None, + device: Union[str, torch.device] = None, +) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + """ + trim merged state dict and broadcast to each rank. + + Args: + module (torch.nn.Module): the module to be loaded + module_state_dict (Dict[str, Any]): the merged model state dict + optimizer (Optional[torch.optim.Optimizer]): the optimizer to be loaded + optimizer_state_dict (Optional[Dict[str, Any]]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the module and optimizer state dicts. + Use torch.cuda.current_device() if it is None. + src_rank (int): the source rank to load the merged state dict from. + dst_ranks (Optional[list[int]]): the destination ranks to load the merged state dict to. + + Returns: + Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + the trimmed state dicts for the module and optimizer + """ + device = device or torch.cuda.current_device() + world_size = torch.distributed.get_world_size() + dst_ranks = dst_ranks or list(range(world_size)) + cur_rank = torch.distributed.get_rank() + + if cur_rank not in dst_ranks or src_rank not in dst_ranks: + raise ValueError( + f"Invalid rank configuration. Both current rank ({cur_rank}) and source rank ({src_rank}) " + f"must be in the destination ranks {dst_ranks}." + ) + + pg = DeviceGroup().get_group(dst_ranks) + + if cur_rank == src_rank: + if optimizer_state_dict and not optimizer: + raise ValueError("Optimizer must be provided when loading optimizer state dict.") + else: + if optimizer_state_dict or module_state_dict: + raise ValueError("Only the source rank can provide the merged state dicts.") + + rank_metadata = ( + {module_path: m._pack() for module_path, m in module.named_modules() if isinstance(m, ParallelModule)}, + optimizer._extra_state if optimizer else None, + ) + + rank_metadatas = [None] * len(dst_ranks) if cur_rank == src_rank else None + torch.distributed.gather_object(rank_metadata, rank_metadatas, group=pg, dst=src_rank) + + if cur_rank == src_rank: + will_load_opt_state = [optimizer_state_dict is not None] + else: + will_load_opt_state = [None] + torch.distributed.broadcast_object_list(will_load_opt_state, group=pg, src=src_rank) + will_load_opt_state = will_load_opt_state[0] + if will_load_opt_state and not optimizer: + raise ValueError("Optimizer must be provided when loading optimizer state dict.") + + ret = None + + if cur_rank == src_rank: + pmodule_stubs = {rank : _construct_parallel_module_stub(r[0]) for rank, r in zip(dst_ranks, rank_metadatas)} + opt_extra_states = {rank : r[1] for rank, r in zip(dst_ranks, rank_metadatas)} + for rank in dst_ranks: + if rank != cur_rank: + logger.info(f'At rank {src_rank}: Trimming module state dict for rank {rank}') + trimmed_module_state_dict = _trim_module_merged_state_dict( + pmodule_stubs[rank], + module_state_dict, + device=device, + ) + logger.info(f'At rank {src_rank}: Sending trimmed module state dict for rank {rank}') + _send_trimmed_module_state_dict(trimmed_module_state_dict, dst_rank=rank, group=pg) + del trimmed_module_state_dict + + if will_load_opt_state: + logger.info(f'At rank {src_rank}: Trimming optimizer state dict for rank {rank}') + trimmed_opt_state_dict = _trim_optimizer_merged_state_dict( + pmodule_stubs[rank], + opt_extra_states[rank], + optimizer_state_dict, + device=device, + ) + logger.info(f'At rank {src_rank}: Sending trimmed optimizer state dict for rank {rank}') + _send_trimmed_opt_state_dict(trimmed_opt_state_dict, dst_rank=rank, group=pg) + del trimmed_opt_state_dict + + torch.distributed.barrier(group=pg) + + # load for self after state dict for all other ranks are sent + # this can lower gpu memory peak + logger.info(f'At rank {src_rank}: Trimming module state dict for self rank {cur_rank}') + trimmed_module_state_dict = _trim_module_merged_state_dict( + pmodule_stubs[cur_rank], + module_state_dict, + device=device, + ) + if will_load_opt_state: + logger.info(f'At rank {src_rank}: Trimming optimizer state dict for self rank {cur_rank}') + trimmed_opt_state_dict = _trim_optimizer_merged_state_dict( + pmodule_stubs[cur_rank], + opt_extra_states[cur_rank], + optimizer_state_dict, + device=device, + ) + else: + trimmed_opt_state_dict = None + ret = (trimmed_module_state_dict, trimmed_opt_state_dict) + else: + for rank in dst_ranks: + if rank == cur_rank: + # receive state dict from src_rank + logger.info(f'At rank {cur_rank}: Receiving trimmed module state dict from rank {src_rank}') + trimmed_module_state_dict = _receive_trimmed_module_state_dict(src_rank, group=pg) + + if will_load_opt_state: + logger.info(f'At rank {cur_rank}: Receiving trimmed optimizer state dict from rank {src_rank}') + trimmed_opt_state_dict = _receive_trimmed_opt_state_dict(src_rank, group=pg) + else: + trimmed_opt_state_dict = None + + ret = (trimmed_module_state_dict, trimmed_opt_state_dict) + + torch.distributed.barrier(group=pg) + + assert ret is not None + # make it a sharded state dict. + for module_path, m in module.named_modules(): + prefix = module_path + '.' if module_path else '' + if isinstance(m, ParallelModule): + m._add_extra_state(ret[0], prefix) + return ret + + +def load_merged_state_dict_from_rank( + module: torch.nn.Module, + module_state_dict: Optional[Dict[str, Any]] = None, + optimizer: Optional[Union[torch.optim.Optimizer, ParallelOptimizer]] = None, + optimizer_state_dict: Optional[Dict[str, Any]] = None, + *, + src_rank: int = 0, + dst_ranks: Optional[list[int]] = None, + device: Union[str, torch.device] = None, +): + """ + load the merged state dict from rank. + + Only src_rank will load merged state dict to memory (for saving memory), + and dst_ranks will receive the sharded state dict from src_rank via communication. + + Args: + module (torch.nn.Module): the module to be loaded + module_state_dict (Dict[str, Any]): the merged model state dict + optimizer (Optional[torch.optim.Optimizer]): the optimizer to be loaded + optimizer_state_dict (Optional[Dict[str, Any]]): the merged optimizer state dict + device (Union[str, torch.device]): the device to put the module and optimizer state dicts. + Use torch.cuda.current_device() if it is None. + src_rank (int): the source rank to load the merged state dict from. + dst_ranks (Optional[list[int]]): the destination ranks to load the merged state dict to. + + Returns: + None + """ + device = device or torch.cuda.current_device() + module.to(device) + trimmed_module_state_dict, trimmed_opt_state_dict = trimmed_broadcast_merged_state_dict( + module, + module_state_dict, + optimizer, + optimizer_state_dict, + device='cpu', + src_rank=src_rank, + dst_ranks=dst_ranks, + ) + module.load_state_dict(trimmed_module_state_dict) + if trimmed_opt_state_dict: + optimizer.load_state_dict(trimmed_opt_state_dict) + + +@torch.no_grad() +def gather_full_model_state_dict( + module: torch.nn.Module, +) -> Dict[str, Any]: + """ + Gather model state dicts from all ranks, + And merge them into a single merged model state dict in all ranks. + + Args: + module (torch.nn.Module): the module to gather state dicts from + + Returns: + Dict[str, Any]: the merged model state dict + """ + + rank = torch.distributed.get_rank() + parallel_modules = [m for m in module.modules() if isinstance(m, ParallelModule)] + if not parallel_modules: + raise ValueError("No ParallelModule found in the module.") + parallel_module = parallel_modules[0] + compute_config = parallel_module.compute_config + num_involved_ranks = compute_config.module_dedup_group_size + involved_group = DeviceGroup().get_group(list(range(num_involved_ranks))) + + logger.info(f'Gathering full model state dict from ranks {list(range(num_involved_ranks))}') + + if rank < num_involved_ranks: + local_state_dict, _ = deduped_state_dict(module, optimizer=None) + logger.info(f'Rank {rank}: gathering state dict') + state_dicts = gather_mixed_data(local_state_dict, src_rank=0, group=involved_group, device='cpu') + if rank == 0: + logger.info(f'Rank {rank}: merging gathered state dicts') + merge_state_dict = merge_state_dicts(state_dicts) + else: + merge_state_dict = None + else: + merge_state_dict = None + + logger.info(f'Rank {rank}: Broadcasting merged state dict to all ranks') + merge_state_dict = broadcast_mixed_data(merge_state_dict, src_rank=0, device='cpu') + logger.info(f'Rank {rank}: Finished gathering full model state dict') + torch.distributed.barrier() + return merge_state_dict diff --git a/nnscaler/policies.py b/nnscaler/policies.py index f1db5858..2768e0a6 100644 --- a/nnscaler/policies.py +++ b/nnscaler/policies.py @@ -18,8 +18,10 @@ IRDataOperation is recommended to be replicated to all devices. """ +import ast +from dataclasses import dataclass, field import logging -from typing import List, Optional, TYPE_CHECKING +from typing import Any, List, Literal, Optional, TYPE_CHECKING, Callable, Iterable, Union import random import torch @@ -30,13 +32,17 @@ from nnscaler.graph import IRGraph from nnscaler.graph.function.anchor import IRGraphAnchor from nnscaler.graph.function.dimops import IRDimops +from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.segment import IRSegment -from nnscaler.ir.operator import IRDataOperation, IRFwOperation +from nnscaler.ir.operator import IRBpOperation, IRDataOperation, IRFwOperation from nnscaler.ir import IRCell, IRSubTensor, IRFullTensor +from nnscaler.ir.cten import IR +from nnscaler.runtime.function import identity, multiref +from nnscaler.utils import load_type if TYPE_CHECKING: - from nnscaler.parallel import ComputeConfig + from nnscaler.parallel import ComputeConfig, ParallelModule _logger = logging.getLogger(__name__) @@ -94,6 +100,8 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): random tensor parallelism inside a scale unit, and dp across scale units """ ngpus = cfg.plan_ngpus + pas_cfg = cfg.pas_config + enable_random_replicated = pas_cfg.get('enable_random_replicated', False) # get the current random state state = random.getstate() @@ -108,7 +116,7 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): continue if isinstance(node, IRDimops): configs = node.transform_space() - if len(configs) == 0: + if len(configs) == 0 or (enable_random_replicated and random.random() < 0.5): _replica(graph, node, devs) else: configs = sorted(configs, reverse=True, @@ -116,6 +124,8 @@ def pas_tp(graph: IRGraph, cfg: 'ComputeConfig'): random.shuffle(configs) for (idx, dim) in configs: if node.input(idx).shape[dim] % len(devs) != 0: continue + # only partition when all input tensors are constant on this dim + if not node.input(idx).dim_tracks[dim].is_constant: continue if node.algorithm('dim').satisfy(idx=idx, dim=dim, num=len(devs)): _tp(graph, node, devs, idx, dim) break @@ -219,6 +229,8 @@ def pas_hybrid(graph: IRGraph, cfg: 'ComputeConfig'): def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: + from nnscaler.autodist.util import get_default_profile_path + pas_cfg = cfg.pas_config update_freq = pas_cfg.get('update_freq', 1) @@ -266,18 +278,24 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: use_memory_efficient_bf16 = pas_cfg.get('use_memory_efficient_bf16', False) use_fp16 = pas_cfg.get('use_fp16', use_memory_efficient_fp16) use_bf16 = pas_cfg.get('use_bf16', use_memory_efficient_bf16) + profile_dir = pas_cfg.get('profile_dir', None) + if profile_dir is None: + profile_dir = get_default_profile_path() re_profile = pas_cfg.get('re_profile', False) verbose = pas_cfg.get('verbose', False) load_plan_path = pas_cfg.get('load_plan_path', None) save_plan_path = pas_cfg.get('save_plan_path', None) partition_constraints_path = pas_cfg.get('partition_constraints_path', '') recompute_modules = pas_cfg.get('recompute_modules', '') + recompute_ratio = pas_cfg.get('recompute_ratio', 1.0) pipeline_pivots = pas_cfg.get('pipeline_pivots', '') max_pipeline_bubble_ratio = pas_cfg.get('max_pipeline_bubble_ratio', 0.2) max_pipeline_unbalance_ratio = pas_cfg.get('max_pipeline_unbalance_ratio', 0.5) use_apex_fused_adam_v2 = pas_cfg.get('use_apex_fused_adam_v2', False) parallel_profile = pas_cfg.get('parallel_profile', True) transient_mem_coef = pas_cfg.get('transient_mem_coef', 2) + disable_shared_param_constraint = pas_cfg.get('disable_shared_param_constraint', False) + solver = pas_cfg.get('solver', 'dp') task_name = f'{task_name}_{cfg.plan_ngpus}gpus_{update_freq}update_freq' if memory_constraint == -1: @@ -332,8 +350,10 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: opt_transient_coef=opt_transient_coef, verbose=verbose, re_profile=re_profile, + profile_dir=profile_dir, world_size=cfg.runtime_ngpus, recompute_modules=recompute_modules, + recompute_ratio=recompute_ratio, zero_stage=zero_stage, zero_ngroups=zero_ngroups, load_plan_path=load_plan_path, @@ -344,6 +364,503 @@ def pas_autodist(graph: IRGraph, cfg: 'ComputeConfig') -> IRGraph: max_pipeline_unbalance_ratio=max_pipeline_unbalance_ratio, parallel_profile=parallel_profile, transient_mem_coef=transient_mem_coef, + disable_shared_param_constraint=disable_shared_param_constraint, + solver=solver, ) return parallelize_graph(graph, autodist_cfg) + + +@dataclass(unsafe_hash=True, frozen=True) +class OpPartition: + """ + OpPartition represents a partition plan for an operator dimension. + """ + input: int + dim: int + + +@dataclass +class OpPlan: + """ + OpPlan represents the distributed plan for an operator. + """ + op: IRFwOperation + recompute_id: int = -1 # -1 means no recompute + stage_id: int = -1 # pipeline stage id, -1 means following the previous op's stage + + # user defined meta data for hooks + # which will be passed to the pre_hook and post_hook functions + # Note: Only types that can be safely `repr`-ed can be used here. (e.g., str, int, float, tuple, list, dict) + hook_meta: Any = None + + # function to be called before the op is executed + # which will be inserted in the runtime code before the op call. + # op's inputs will be passed to the hook. + # The signature will be like + # def pre_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None: + pre_hook: Optional[Callable[['ParallelModule', Any, tuple[Any, ...], dict[str, Any]], None]] = None + + # function to be called after the op is executed + # which will be inserted in the runtime code after the op call. + # op's inputs and outputs will be passed to the hook. + # the signature will be like + # def post_hook(module: ParallelModule, meta: Any, inputs: Tuple[Any, ...], kwargs: Dict[str, Any], output: Any) -> None: + post_hook: Optional[Callable[['ParallelModule', Any, tuple[Any, ...], dict[str, Any], Any], None]] = None + + # OpPartition: user specified partition plan + # You only need to specify one partition plan here. + # For example, torch.matmul has annotation of `m k+, k+ n -> m n`, + # If you want to partition the matmul on the k dimension, + # you can set OpPartition(input=0, dim=1) or OpPartition(input=1, dim=0). + # They are equivalent. + # None: replicated + # 'auto': auto partition based on the input tensor partition info + # 1. if any of the input tensors is value partitioned, we replicate the op + # TODO: is it too strict? + # 2. if any of the input tensors is partitioned on a dim, + # we will try to partition the op on the same dim first, + # if the partition is invalid, we replicate the op + # 3. if all the input tensor is replicated, we replicate the op + partition: OpPartition | None | Literal['auto'] = None # partition plan + # for future extension + # don't use it now. + partitions: List[OpPartition | None] = field(default_factory=list) # multiple partition plans + + def __post_init__(self): + if self.partition is not None and len(self.partitions) > 0: + raise ValueError("Only one of partition and partitions can be set") + + if len(self.partitions) > 1: + raise NotImplementedError("Multiple partitions are not supported yet") + + if len(self.partitions) == 1: + self.partition = self.partitions[0] + self.partitions = [] + + +def get_layer_index(fqn: str) -> int: + """ + Extract the layer index from full qualified name. + If there are multiple integers in the name, raise ValueError. + """ + nums = [int(s) for s in fqn.split(".") if s.isdigit()] + if len(nums) != 1: + raise ValueError(f"Name {fqn} should only contain one integer") + return nums[0] + + +def get_called_self_module_name(node_call_expr: str) -> str: + """ + Get the called module name from the node's call expr by ast. + For example: + self.up_proj(x) -> up_proj + self.act_fn(self.gate_proj(x)) -> act_fn + self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) -> down_proj + torch.tanh(x) -> '' # because it's not called from self + self.up_proj(x).transpose() -> '' # because it's an attribute call + + Other cases return empty string. + + NOTE: regex is not easy to make it work + + """ + + if not node_call_expr: + return '' + call_expr: ast.Call = ast.parse(node_call_expr, mode='eval').body # type: ignore + if isinstance(call_expr, ast.Call): # self.up_proj(x) + if isinstance(call_expr.func, ast.Attribute): # self.up_proj + if isinstance(call_expr.func.value, ast.Name) and call_expr.func.value.id == 'self': + return call_expr.func.attr # up_proj + return '' + + +def get_pas_ops(graph: IRGraph) -> List[IRFwOperation]: + """ + Get all operators in the graph that can set operator plan. + When we write a policy, only ops returned from this function need to be considered. + + Args: + graph: the input IRGraph + + Returns: + List[IRFwOperation]: list of IRFwOperation nodes + """ + return graph.select(ntype=IRFwOperation) + + +def fn( + graph: IRGraph, cfg: 'ComputeConfig', + policy: Union[ + Callable[[IRGraph, 'ComputeConfig'], IRGraph], + Callable[[IRGraph, 'ComputeConfig'], Iterable[OpPlan]], + ] +) -> IRGraph: + """ + General policy function based on user-defined policy. + The user-defined policy can either return the final IRGraph, or + return a list of OpPlan to describe the distributed plan for each operator. + + To write a new-style policy, the most important part is to locate the operator node in the graph. + Here are some tips: + 1. use `node.name` to get the operator name. + 2. use `node.fn` to get the operator function. + 3. use `node.module_stack` to get the module stack info. + 4. use `node.module_class_chain` to get the module class chain. + 5. use `node.call_expr` to get the call expression string. And you can user `ast.parse` to parse it. + 6. use `get_layer_index` to get the layer index in a torch.nn.ModuleList. + 7. use `get_called_self_module_name` to get the called self module name from the call expression. + 8. use `node.inputs()` the get the input tensors of the operator. + We can further check whether the input tensor is a parameter by `tensor.is_param`, + or get the full name of the parameter by `tensor.name`, etc. + 9. insert anchors in code with `nnscaler.anchor` to help locate the operator (intrusive way). + + A good way to locate the operator will be like: + 1. Locate the module first by module_class_chain (`target_module in node.module_class_chain`) + 2. If the module are used multiple times (e.g., in ModuleList), + locate further by layer index (`get_layer_index`) or `node.fqn`. + 3. Once the module is located, + we can further locate the operator by + `node.name`,`node.call_expr`, `node.fn`, `node.inputs()` (especially the `is_param`/`name` of input) + or other properties. + + Args: + graph: the input IRGraph + cfg: the compute config + policy: the user-defined policy function. It can either return the final IRGraph, + or return an iterable of OpPlan for each operator. + + Returns: + the distributed IRGraph + """ + result = policy(graph, cfg) + if isinstance(result, IRGraph): # traditional policy + return result + + op_plans = {r.op: r for r in result} + ngpus: int = cfg.plan_ngpus + + recompute_groups: dict[int, list[IRFwOperation]] = {} + recompute_last_id: int = -1 + recompute_group_stages: dict[int, int] = {} + + pp_stages: list[list[IRFwOperation]] = [[]] + pp_cur_stage_id = 0 + + # key: IRFullTensor + # value: + # key: stage_id + # value: set of OpPartition in this stage + tensor_splits: dict[IRFullTensor, dict[int, set[OpPartition]]] = {} + # store the last split info for each tensor to help handle auto partition + # None: replicated + # 'value': value partitioned + # int: the partitioned dim + output_tensor_last_split: dict[IRFullTensor, int | None | Literal['value']] = {} + + fw_nodes = dict.fromkeys(graph.select(ntype=IRFwOperation)) + + for node in fw_nodes: + if node not in op_plans: + op_plans[node] = OpPlan(op=node) # default: no partition, stage 0, no recompute + + node.hook_meta = op_plans[node].hook_meta + node.pre_hook = op_plans[node].pre_hook + node.post_hook = op_plans[node].post_hook + + op_plan = op_plans[node] + + # set pipeline stage id if not set + if op_plan.stage_id == -1: + op_plan.stage_id = pp_cur_stage_id + + # currently we only support partition for IRDimops + if not isinstance(op_plan.op, IRDimops): + if op_plan.partition == 'auto': + op_plan.partition = None + if op_plan.partition is not None: + raise ValueError("Only IRDimops can be partitioned.") + + # list of partitions for the op + # [] means no partition(replicated) + op_partitions = [op_plan.partition] if op_plan.partition is not None else [] + + if op_partitions == ['auto']: + # auto partition based on input tensor partition info + op_partitions = [] # reset to collect partitions + for idx, input in enumerate(op_plan.op.inputs()): + if not isinstance(input, IRSubTensor): + continue + ftensor = input.parent + last_partition_dim = output_tensor_last_split.get(ftensor, None) + if last_partition_dim == 'value': + # value partitioned input, replicate the op + op_partitions = [] + break + elif last_partition_dim is not None: + op_partitions.append(OpPartition(input=idx, dim=last_partition_dim)) + + # final partition plan for the op + # key: input idx, value: partitioned dim + op_partition_map: dict[int, int] = {} + if op_partitions: + # we partition the op based on the first partition plan + # and then check the rest partitions are satisfied or not + op_first_partition = op_partitions[0] + partitioned_nodes = op_plan.op.algorithm('dim')\ + .instantiate(idx=op_first_partition.input, dim=op_first_partition.dim, num=ngpus) + subnode = partitioned_nodes[0] # first subnode carries all necessary partition info + + # collect input partition info + # key: input idx, value: partitioned dim + result_partitions: dict[int, int] = {} + for idx, input in enumerate(subnode.inputs()): + if not isinstance(input, IRSubTensor): + continue + split_dims = input.splitdims() + assert len(split_dims) <= 1, "Internal Error: multiple splitdims in one input" + if split_dims: + result_partitions[idx] = split_dims[0] + + # check the rest partitions + # Note if we only have one partition plan, the check is skipped, we can always partition it + # In fact, if `auto` is not specified, we always have at most one partition plan + for op_partition in op_partitions[1:]: + if op_partition.input not in result_partitions or \ + result_partitions[op_partition.input] != op_partition.dim: + _logger.warning( + f"Operator {op_plan.op} cannot be partitioned as specified: {op_partition}" + f", replicate it instead." + ) + op_partitions = [] + op_partition_map = {} + break + else: + # all partitions are satisfied + # then we can update input/output partition info + + # make sure the first item in op_partition_map is the first partition plan + op_partition_map[op_first_partition.input] = op_first_partition.dim + op_partition_map.update(result_partitions) + + for output in subnode.outputs(): + if not isinstance(output, IRSubTensor): + continue + ftensor = output.parent + if output.valmap != (0, 1): + output_tensor_last_split[ftensor] = 'value' + else: + split_dims = output.splitdims() + assert len(split_dims) <= 1, "Internal Error: multiple splitdims in one output" + if split_dims: + output_tensor_last_split[ftensor] = split_dims[0] + + if op_plan.partition == 'auto': + if not op_partition_map: + op_plan.partition = None + else: + # use the first partition plan, + # which is consistent with the logic above + first_input_idx = list(op_partition_map.keys())[0] + op_plan.partition = OpPartition( + input=first_input_idx, + dim=op_partition_map[first_input_idx] + ) + + # update tensor_splits for input tensors + for idx, input in enumerate(op_plan.op.inputs()): + if not isinstance(input, IRSubTensor): + continue + ftensor = input.parent + if ftensor not in tensor_splits: + tensor_splits[ftensor] = {} + if idx not in op_partition_map: + tensor_splits[ftensor].setdefault(op_plan.stage_id, set()).add(None) + else: + tensor_splits[ftensor].setdefault(op_plan.stage_id, set()).add( + OpPartition(input=idx, dim=op_partition_map[idx])) + + if op_plan.recompute_id != -1: + if op_plan.recompute_id in recompute_group_stages: + if recompute_group_stages[op_plan.recompute_id] != op_plan.stage_id: + raise ValueError("All ops in a recompute group must be in the same stage") + else: + recompute_group_stages[op_plan.recompute_id] = op_plan.stage_id + + if op_plan.recompute_id != recompute_last_id and op_plan.recompute_id in recompute_groups: + raise ValueError("Nodes in a recompute group must be continuous.") + + recompute_groups.setdefault(op_plan.recompute_id, []).append(op_plan.op) + + recompute_last_id = op_plan.recompute_id + + # update pipeline stages + if op_plan.stage_id == pp_cur_stage_id: + pp_stages[pp_cur_stage_id].append(op_plan.op) + elif op_plan.stage_id == pp_cur_stage_id + 1: + pp_cur_stage_id += 1 + pp_stages.append([op_plan.op]) + else: + raise ValueError("Pipeline stage ids must be continuous integers starting from 0") + + if len(op_plans) != len(fw_nodes): + assert len(op_plans) > len(fw_nodes) + for op_plan in op_plans.values(): + if op_plan.op not in fw_nodes: + raise ValueError(f"OpPlan contains operator {op_plan.op} not in the graph or not a forward operator") + + pp_segs = [graph] + nstages = len(pp_stages) + pp_enabled = nstages > 1 + # not all schedulers support pp_size < nstages + pp_size = cfg.pas_config.get('pipeline_size', nstages) + nmicros = cfg.pas_config.get('pipeline_nmicros', None) + scheduler = cfg.pas_config.get('pipeline_scheduler', '1f1b') + tp_size = ngpus // pp_size + + if pp_enabled: + if not cfg.use_end2end: + raise ValueError("Pipeline parallelism requires use_end2end to be True") + if pp_size <= 1: + raise ValueError("pipeline_size must be greater than 1 when pipeline is enabled") + if not nmicros: + raise ValueError("nmicros must be set when pipeline is enabled") + if nstages % pp_size != 0: + raise ValueError(f'invalid pipeline_size {pp_size} for nstages {nstages}') + if ngpus % pp_size != 0: + raise ValueError(f'invalid pipeline_size {pp_size} for ngpus {ngpus}') + else: + if pp_size != 1: + raise ValueError("pipeline_size must be 1 when pipeline is disabled") + + # set recompute groups + for group in recompute_groups.values(): + if len(group) <= 1: + continue + graph.recompute(group) + + # add multiref for shared parameters across stages + # note that we have constrained that shared parameters cannot be partitioned in SPMDSolver, other input tensors + # belonging to the same operator can be partitioned. For example, in some LLMs, the embedding matrix is shared + # with the output layer. In this case, the batch dim / seq dim of the activation tensor can be partitioned. + for ftensor, stage_info in tensor_splits.items(): + if not ftensor.is_param(): + continue + splits = set(k.dim if k is not None else None for v in stage_info.values() for k in v) + find_replicated = None in splits + splits = list(splits) + # For safety, we will add multiref when detecting shared param are all replicated for pipeline parallelism. + # The reason is that stages may have different number of devices, it is hard to synchronize gradients directly + # by inserting reducers although weights are all REPLICAED. + if len(splits) > 1 or (pp_enabled and find_replicated): + _logger.info(f'add multiref for shared param {ftensor}') + graph.multiref(ftensor, comment='shared param') + + # set pipeline stages + if pp_enabled: + graph.staging([s[0] for s in pp_stages]) + pp_segs: list[IRSegment] = graph.select(ntype=IRSegment, flatten=False) + + for stage_id, stage in enumerate(pp_segs): + for node in stage.select(ntype=IRFwOperation): + if node in fw_nodes: + continue + if node.fn == multiref: # skip multiref nodes + continue + assert node.fn == identity, "Internal Error: non-identity node added in staging" + # force identity nodes to be replicated + # these nodes are usually added for data transfer between stages in graph.staging + # TODO: is it possible to have TP here? + op_plans[node] = OpPlan(op=node, stage_id=stage_id, partition=None) + + # add multiref to an activation tensor when the states of the tensor and its grad are different + # among consumers and current segment's outputs + for ftensor, stage_info in tensor_splits.items(): + # Parameter are already handled above + if ftensor.is_grad() or ftensor.is_param(): + continue + + # check if this tensor is in the output of each stage + is_seg_output: dict[int, bool] = {} + for idx, stage in enumerate(pp_segs): + is_seg_output[idx] = IR.contains_object( + stage.outputs(), + lambda x: isinstance(x, IRSubTensor) and x.parent == ftensor + ) + + for idx, splits in stage_info.items(): + stage = pp_segs[idx] + split_list = list(splits) + if len(split_list) > 1 or ( + is_seg_output[idx] and split_list[0] is not None # treat segment output as a consumer + ): + _logger.debug(f'add multiref for {ftensor} in stage {stage}') + stage.multiref(ftensor, comment='activation') + + # stage-wise tensor parallelism + curr_devices = list(range(ngpus)) + for op_plan in op_plans.values(): + idx = op_plan.stage_id % pp_size + devs = curr_devices[idx * tp_size: (idx + 1)* tp_size] + if op_plan.partition is not None: + _tp(graph, op_plan.op, devs, idx=op_plan.partition.input, dim=op_plan.partition.dim) + else: + _replica(graph, op_plan.op, devs) + + # replicate dataloader + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, devs=list(range(ngpus))) + + if pp_enabled: + cfg.apply_pipeline_scheduler(graph, nstages, nmicros, scheduler) + + return graph + + +def pas_fsdp(graph, cfg: 'ComputeConfig'): + """ + A simple FSDP policy: + 1. all operators are replicated + 2. user specified modules with `cfg.pas_config.recompute_modules` are recomputed + 3. shard policy is configured in cfg.use_zero and cfg.zero_ngroups + 4. CPU offload is not supported + """ + if cfg.plan_ngpus != 1: + raise ValueError("FSDP policy only supports 1 plan GPU") + if not cfg.use_zero: + raise ValueError("FSDP policy requires use_zero to be 1/3") + # use 'recomputes' instead of 'recompute_modules' + # to avoid confliction with autodist config + recompute_modules = cfg.pas_config.get('recomputes', '') + # parse recompute_modules + # user can also provide a list of Module classes. + if isinstance(recompute_modules, str): + recompute_modules = recompute_modules.strip() + if not recompute_modules: + recompute_modules = [] + else: + recompute_modules = [m.strip() for m in recompute_modules.split(',')] + + if recompute_modules: + recompute_modules = [load_type(rm) for rm in recompute_modules] + else: + recompute_modules = [] + + cur_recompute_id = -1 + cur_recompute_module_fqn = None + for node in get_pas_ops(graph): + recompute_module: torch.nn.Module + for rm in recompute_modules: + if rm in node.module_class_chain: + recompute_module = rm + break + else: + cur_recompute_module_fqn = None + continue + + mod_fqn = node.get_module_fqn(recompute_module) + if cur_recompute_module_fqn is None or cur_recompute_module_fqn != mod_fqn: + cur_recompute_id += 1 + cur_recompute_module_fqn = mod_fqn + yield OpPlan(node, recompute_id=cur_recompute_id) diff --git a/nnscaler/profiler/database.py b/nnscaler/profiler/database.py index 27c7f54d..23ba4cd0 100644 --- a/nnscaler/profiler/database.py +++ b/nnscaler/profiler/database.py @@ -260,7 +260,7 @@ def unpack_hook(x): # warmup warmup_cnt = 0 tic = time.perf_counter() - while time.perf_counter() - tic < warmup_sec: + while time.perf_counter() - tic < warmup_sec and warmup_cnt < prof_times: run_step(func, tensors, train_kwargs, backward=require_backward) torch.cuda.synchronize() warmup_cnt += 1 @@ -511,7 +511,10 @@ def load_ops(self, folder: str): if filename.endswith('.json'): with open(os.path.join(folder, filename)) as f: signature = filename[:-len('.json')] - loaded_json = json.load(f) + try: + loaded_json = json.load(f) + except json.JSONDecodeError: + raise RuntimeError(f'fail to load profiling data from {filename}, please check the file content') self._data[signature] = {key: ProfiledMetrics(**value) for key, value in loaded_json.items()} def __repr__(self) -> str: diff --git a/nnscaler/runtime/__init__.py b/nnscaler/runtime/__init__.py index d0171757..46be9e99 100644 --- a/nnscaler/runtime/__init__.py +++ b/nnscaler/runtime/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from nnscaler.runtime import _patch_torch from nnscaler.runtime import executor from nnscaler.runtime import device from nnscaler.runtime import adapter diff --git a/nnscaler/runtime/_patch_torch.py b/nnscaler/runtime/_patch_torch.py new file mode 100644 index 00000000..53ab7438 --- /dev/null +++ b/nnscaler/runtime/_patch_torch.py @@ -0,0 +1,104 @@ +# The following code is copied from torch.distributed.distributed_c10d in PyTorch 2.4.0 +# For copyright, see pytorch/LICENSE +# https://github.com/pytorch/pytorch/blob/main/LICENSE + + +import torch +import torch.distributed + + +if torch.__version__ < (2, 4, 0): + # send_object_list and recv_object_list only available in PyTorch 2.4.0+ + + import torch.distributed.distributed_c10d as dist_c10d + + + if torch.__version__ < (2, 3, 0): + def _object_to_tensor(obj, device, group): + return dist_c10d._object_to_tensor(obj, device) + else: + def _object_to_tensor(obj, device, group): + return dist_c10d._object_to_tensor(obj, device, group) + + + if torch.__version__ < (2, 3, 0): + def _tensor_to_object(tensor, size, group): + return dist_c10d._tensor_to_object(tensor, size) + else: + def _tensor_to_object(tensor, size, group): + return dist_c10d._tensor_to_object(tensor, size, group) + + + def send_object_list(object_list, dst, group=None, device=None): + if torch.distributed.get_rank() == dst: + raise ValueError( + "Invalid destination rank: destination rank should not be the same as " + "the rank of the current process." + ) + + if dist_c10d._rank_not_in_group(group): + dist_c10d._warn_not_in_group("send_object_list") + return + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # sent to this device. + current_device = device or torch.device("cuda", torch.cuda.current_device()) + # Serialize object_list elements to tensors on src rank. + tensor_list, size_list = zip(*[_object_to_tensor(obj, current_device, group) for obj in object_list]) + object_sizes_tensor = torch.cat(size_list) + + # Send object sizes + torch.distributed.send(object_sizes_tensor, dst=dst, group=group) + + # Concatenate and send serialized object tensors + # Note: torch.cat will do an extra memory copy to the current device, if the tensor_list + # has only one element, we can skip the copy. + if len(tensor_list) == 1: # type: ignore[possibly-undefined] + object_tensor = tensor_list[0] + else: + object_tensor = torch.cat(tensor_list) + + torch.distributed.send(object_tensor, dst=dst, group=group) + + + def recv_object_list(object_list, src=None, group=None, device=None): + if dist_c10d._rank_not_in_group(group): + dist_c10d._warn_not_in_group("recv_object_list") + return -1 + + # Current device selection. + # To preserve backwards compatibility, ``device`` is default to ``None`` + # in which case we run current logic of device selection, i.e. + # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the + # case it is not ``None`` we move the size and object tensors to be + # received to this device. + current_device = device or torch.device("cuda", torch.cuda.current_device()) + object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long, device=current_device) + + # Receive object sizes + rank_sizes = torch.distributed.recv(object_sizes_tensor, src=src, group=group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + torch.sum(object_sizes_tensor).item(), # type: ignore[arg-type] + dtype=torch.uint8, + device=current_device + ) + + rank_objects = torch.distributed.recv(object_tensor, src=src, group=group) + assert rank_sizes == rank_objects, "Mismatch in return ranks for object sizes and objects." + # Deserialize objects using their stored sizes. + offset = 0 + for i, obj_size in enumerate(object_sizes_tensor): + obj_view = object_tensor[offset : offset + obj_size] + obj_view = obj_view.type(torch.uint8) + offset += obj_size + object_list[i] = _tensor_to_object(obj_view, obj_size, group) + return rank_objects + + torch.distributed.send_object_list = send_object_list + torch.distributed.recv_object_list = recv_object_list diff --git a/nnscaler/runtime/adapter/reducer.py b/nnscaler/runtime/adapter/reducer.py index be83fdc3..19c3c8b0 100644 --- a/nnscaler/runtime/adapter/reducer.py +++ b/nnscaler/runtime/adapter/reducer.py @@ -3,6 +3,7 @@ from typing import List, Dict, Tuple, Any, Callable, Optional, Set, Sequence from functools import partial +from dataclasses import dataclass import math import logging import torch @@ -11,12 +12,13 @@ from nnscaler.runtime.device import DeviceGroup from nnscaler.profiler.timer import CudaTimer from nnscaler.flags import RuntimeFlag +from nnscaler.utils import unchecked_fields _logger = logging.getLogger(__name__) # According to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#device-memory-accesses -# Any address of a variable residing in global memory or returned by one of the memory allocation +# Any address of a variable residing in global memory or returned by one of the memory allocation # routines from the driver or runtime API is always aligned to at least 256 bytes. # But in our practice, we found that 16 bytes alignment is enough, it can be modified if unaligned access is detected. ALIGNED_BYTES = 16 @@ -59,15 +61,31 @@ def _get_reduce_op(reduce_op: str) -> torch.distributed.ReduceOp: raise KeyError(f"Unsupported reduce op {reduce_op}. Supported reduce op: {supported}") +@dataclass +class _Z3ParamInfo: + shape: torch.Size # original shape of the parameter + start: int + end: int + param_buffer_start: int = -1 + param_buffer_end: int = -1 + + def numel(self) -> int: + return self.end - self.start + + def numel_with_padding(self) -> int: + return self.param_buffer_end - self.param_buffer_start + + class Bucket: - def __init__(self, params: List[torch.nn.Parameter], + def __init__(self, reducer: 'Reducer', params: List[torch.nn.Parameter], param_buffer: torch.Tensor, grad_buffer: torch.Tensor, reduce_op: torch.distributed.ReduceOp, - group: torch.distributed.ProcessGroup, async_op: bool, zero: bool, + group: torch.distributed.ProcessGroup, async_op: bool, zero: int, zero_subgroup: torch.distributed.ProcessGroup = None, zero_crossgroup: torch.distributed.ProcessGroup = None, zero_use_reduce_scatter: bool = False, align_size: int = ALIGNED_BYTES, + param_cls: Any = None, ): """ Create a communication unit for parameter allreduce. @@ -82,14 +100,17 @@ def __init__(self, params: List[torch.nn.Parameter], reduce_op (torch.distributed.ReduceOp): the reduce op used by collectives group (torch.distributed.ProcessGroup): communication group async_op (bool): whether to use asynchronous operation - zero (bool): whether to use zero optimization on gradients + zero (int): whether to use zero optimization on gradients, currently only 0/1/3 are supported + zero=2 will be treated as zero=3 zero_subgroup (torch.distributed.ProcessGroup): the subgroup for zero optimization the current rank belongs to zero_crossgroup (torch.distributed.ProcessGroup): the communication group for cross zero group allreduce when reduce scatter is enabled zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization align_size (int): the alignment size in bytes for each parameter + param_cls (Any): the class of the parameters """ self._params: List[torch.nn.Parameter] = params + self._param_cls: Any = param_cls self._pofset: Dict[torch.nn.Parameter, int] = {} self._reduce_op = reduce_op self._group = group @@ -99,7 +120,7 @@ def __init__(self, params: List[torch.nn.Parameter], self._hooks: List[Tuple[Any, RemovableHandle]] = [] self._async: bool = async_op - self._zero: bool = zero + self._zero: int = zero self._zero_use_reduce_scatter = zero_use_reduce_scatter self._contiguous_params = param_buffer self._contiguous_grads = grad_buffer @@ -123,6 +144,9 @@ def __init__(self, params: List[torch.nn.Parameter], self._pre_hooks: List[Callable] = [] self._post_hooks: List[Callable] = [] + self._z3 = self._zero > 1 + self._reducer = reducer + # only async will enable contiguous gradient self.build() self.register_hooks() @@ -137,11 +161,21 @@ def params(self) -> List[torch.nn.Parameter]: """Parameter list""" return self._params + @property + def param_cls(self) -> Any: + """Class of the parameters in the bucket""" + return self._param_cls + @property def zero(self) -> bool: """Whether enable zero for this bucket""" return self._zero + @property + def zero3(self) -> bool: + """Whether enable zero3 for this bucket""" + return self._z3 + def get_aligned_numel(self, param) -> int: """ Get the aligned number of elements for a parameter @@ -168,6 +202,22 @@ def _group_reduce_scatter(self): partial_tensor, self._contiguous_grads, op=self._reduce_op, group=self._zero_subgroup) + def _get_opt_param_data(self): + if not self._zero or self._zero > 1: + # when zero3 is used, the parameters are already sharded in reducer + opt = self._contiguous_params + else: + assert self._zero == 1 + rank = torch.distributed.get_rank(group=self._zero_subgroup) + assert len(self._contiguous_params) % self._zgroup_sz == 0 + # Note: + # There may be paddings both in the middle and at the end of the contiguous buffer + # When there are paddings in the middle or end of the contiguous buffer, + # the calculation of gnorm is not affected as long as the paddings are all 0. + # So for now, it looks harmless. + opt = self._contiguous_params.chunk(self._zgroup_sz)[rank] + return opt + def build(self): """ Build offset for each parameter @@ -179,18 +229,7 @@ def build(self): ofst += _aligned_nelement(param.nelement(), param.element_size(), self._align_size) # build parameter for optimizer (shared storage). # Its gradient will be updated everytime calling `self.sync_grads()` - if not self._zero: - opt = self._contiguous_params - else: - rank = torch.distributed.get_rank(group=self._zero_subgroup) - assert len(self._contiguous_params) % self._zgroup_sz == 0 - # Note: - # There may be paddings both in the middle and at the end of the contiguous buffer - # When there are paddings in the middle or end of the contiguous buffer, - # the calculation of gnorm is not affected as long as the paddings are all 0. - # So for now, it looks harmless. - opt = self._contiguous_params.chunk(self._zgroup_sz)[rank] - self._param_for_optimizer = torch.nn.Parameter(opt) + self._param_for_optimizer = torch.nn.Parameter(self._get_opt_param_data()) def register_hooks(self): """ @@ -205,13 +244,43 @@ def register_hooks(self): """ @torch.no_grad() - def post_grad_hook(param: torch.nn.Parameter, *unused): + def post_grad_hook(param: torch.nn.Parameter, *unused): # pragma: no cover # stream = DeviceGroup().get_stream('reducer') ofst = self._pofset[param] + rank = torch.distributed.get_rank() # TODO: need to handle sparse gradients in torch.nn.Embedding - self._contiguous_grads[ofst:ofst+param.numel()].add_(param.grad.data.view(-1)) + if self._z3: + z3_info = self._reducer.get_z3_info(param) + grad = param.grad.data.view(-1) + padded_numel = z3_info.numel_with_padding() * self._zgroup_sz + if grad.numel() < padded_numel: + # add padding + grad = torch.nn.functional.pad( + grad, + (0, padded_numel - grad.numel()), + mode='constant', + value=0.0, + ) + output = torch.zeros(z3_info.numel_with_padding(), device=grad.device, dtype=grad.dtype) + torch.distributed.reduce_scatter_tensor( + output, + grad, + op=self._reduce_op, + group=self._zero_subgroup + ) + # accumulate the param grad in zero3 way + self._contiguous_grads[ofst:ofst+z3_info.numel()]\ + .add_(output[0:z3_info.end-z3_info.start]) + else: + self._contiguous_grads[ofst:ofst+param.numel()].add_(param.grad.data.view(-1)) + param.grad = None + if self._z3: + # in most cases, it is not necessary to post-evict here, + # let's add it for safety + self._reducer.postevict_param(param) + if RuntimeFlag.skip_reducer: return self._async_param_cnt += 1 @@ -225,7 +294,9 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): # apply pre hooks self._apply_pre_hooks() # communication - if self._zero and self._zero_use_reduce_scatter: + if self._zero == 1 and self._zero_use_reduce_scatter: + # when zero3 is used, the parameters and gradients are already sharded in reducer + # so only allreduce is needed if self._zgroup_sz == self._wsz: rank = torch.distributed.get_rank(group=self._group) shards = list(self._contiguous_grads.chunk(self._wsz, dim=0)) @@ -236,9 +307,13 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): group=self._group, async_op=True) else: assert False, "group zero + reducescatter is not supported in async mode, " \ - "because the two steps (allreduce, reducescatter) use " \ - "two communication groups, which may induce deadlock." + "because the two steps (allreduce, reducescatter) use " \ + "two communication groups, which may induce deadlock." self._group_reduce_scatter() + elif self._zero > 1: + self._async_handle = torch.distributed.all_reduce( + self._contiguous_grads, op=self._reduce_op, + group=self._zero_crossgroup, async_op=True) else: self._async_handle = torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, @@ -247,13 +322,24 @@ def post_grad_hook(param: torch.nn.Parameter, *unused): for param in self._params: # same trick with FSDP and Megatron # reference: https://github.com/pytorch/pytorch/blob/v1.13.1/torch/distributed/fsdp/fully_sharded_data_parallel.py#L3177-L3188 - param_tmp = param.expand_as(param) + if self._z3: + old_param_data = param.data + # here we need the full parameter to build the computation graph + # let's create a temporary parameter with full shape to fake it. + param.data = torch.empty(self._reducer.get_z3_info(param).shape, dtype=param.dtype, device=param.device) + param_tmp = param.expand_as(param) + param.data = old_param_data + else: + param_tmp = param.expand_as(param) + # gets its AccumulateGrad object grad_acc = param_tmp.grad_fn.next_functions[0][0] hook = grad_acc.register_hook(partial(post_grad_hook, param)) # grad_acc must keep, otherwise the hook won't take effect self._hooks.append((grad_acc, hook)) + torch.cuda.empty_cache() + def sync_grads(self): """ Wait until allreduce finished (async), or perform allreduce (sync). @@ -274,8 +360,14 @@ def sync_grads(self): # apply pre-hooks self._apply_pre_hooks() # synchrnoize gradients - if self._zero and self._zero_use_reduce_scatter: + if self._zero == 1 and self._zero_use_reduce_scatter: self._group_reduce_scatter() + elif self._zero > 1: + torch.distributed.all_reduce( + self._contiguous_grads, + op=self._reduce_op, + group=self._zero_crossgroup + ) else: torch.distributed.all_reduce( self._contiguous_grads, op=self._reduce_op, group=self._group) @@ -284,10 +376,16 @@ def sync_grads(self): for param in self._params: assert param.grad is None pofst = self._pofset[param] + if self._z3: + z3_info = self._reducer.get_z3_info(param) + # the param should have been evicted + assert z3_info.numel_with_padding() == param.numel() and len(param.shape) == 1, \ + f"internal error: zero3 param size mismatch, " \ + f"expect {[z3_info.numel_with_padding()]} got {param.shape}" param.grad = self._contiguous_grads[pofst:pofst+param.numel()].view(param.size()) # setup gradient for optimizer parameters - if self._zero: + if self._zero == 1: rank = torch.distributed.get_rank(group=self._zero_subgroup) grad = self._contiguous_grads.chunk(self._zgroup_sz, dim=0)[rank] self._param_for_optimizer.grad = grad @@ -301,7 +399,7 @@ def gather_params(self): """ All-gather parameters """ - assert self._zero, "gathering paramters is only for zero optimization." + assert self._zero == 1, "gathering paramters is only for zero1 optimization." rank = torch.distributed.get_rank(group=self._zero_subgroup) CudaTimer().start(field_name='comm', predefined=True) src_tensor = self._contiguous_params.chunk(self._zgroup_sz, dim=0)[rank] @@ -363,6 +461,81 @@ def reset(self): self._async_param_cnt = 0 self._async_handle = None + def sleep(self): + """ + release reference to contiguous buffer in reducer + """ + cpu = torch.device('cpu') + self._param_for_optimizer.data = self._param_for_optimizer.data.to(cpu) + # set none to release memory + self._contiguous_params = None + self._contiguous_grads = None + + def wake_up(self, param_buffer, grad_buffer): + """ + re-attach to the contiguous buffer and re-build hooks + """ + self._contiguous_params = param_buffer + self._contiguous_grads = grad_buffer + self._param_for_optimizer.data = self._get_opt_param_data() + + # TODO(yizhu1): seems moving attributes to cpu will make hooks invalid. + # The reason is that torch's autograd will reset the AccumulateGrad object if the data is set: + # https://github.com/pytorch/pytorch/blob/38a492d40d7ebb2856cb120df337c6cdac244528/torch/csrc/autograd/variable.cpp#L473 + # To make the resuming process safe, re-register them here. + self._hooks = [] + self.register_hooks() + + def _pack( + self, + param_map: dict[torch.nn.Parameter, torch.nn.Parameter], + ): + """ + Get the information of the bucket. + """ + state = self.__dict__.copy() + + fields = unchecked_fields(self) + state[fields._params] = [param_map[p] for p in self._params] + state[fields._pofset] = {param_map[p]: ofst for p, ofst in self._pofset.items()} + state[fields._param_for_optimizer] = torch.nn.Parameter(torch.empty_like(self._param_for_optimizer, device='meta')) + state[fields._contiguous_params] = torch.empty_like(self._contiguous_params, device='meta') + state[fields._contiguous_grads] = torch.empty_like(self._contiguous_grads, device='meta') + + # remove torch handles + state.pop(fields._group, None) + state.pop(fields._async_handle, None) + state.pop(fields._async_param_cnt, None) + state.pop(fields._zero_subgroup, None) + state.pop(fields._zero_crossgroup, None) + + # remove hooks + state.pop(fields._hooks, None) + state.pop(fields._pre_hooks, None) + state.pop(fields._post_hooks, None) + + # remove reducer reference + state.pop(fields._reducer, None) + + return state + + @classmethod + def _unpack(cls, state: dict, reducer: 'Reducer'): + """ + Return a fake bucket that carries the same information. + """ + bucket = object.__new__(cls) + bucket.__dict__.update(state) + bucket._reducer = reducer + + for param in bucket._params: + assert param.device.type == 'meta' + assert bucket._contiguous_grads.device.type == 'meta' + assert bucket._contiguous_grads.device.type == 'meta' + assert bucket._param_for_optimizer.device.type == 'meta' + + return bucket + class Reducer: # the default bucket cap for async reducer in megabytes @@ -370,11 +543,13 @@ class Reducer: # https://github.com/pytorch/pytorch/blob/4fd16dd8aa259cd75c9a6d2ddcd8171cd1ee8e28/torch/nn/parallel/distributed.py#L548 _DEFAULT_BUCKET_CAP_MB = 25 # 25MB, the same as pytorch - def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None, - reduce_op: str = 'sum', async_op: bool = False, - zero: bool = False, zero_ngroups: int = 1, - zero_use_reduce_scatter: bool = False, - align_size: int = ALIGNED_BYTES + def __init__(self, ranks: List[int], + *, + max_bucket_size_bytes: Optional[int] = None, + reduce_op: str = 'sum', async_op: bool = False, + zero: int = 0, zero_ngroups: int = 1, + zero_use_reduce_scatter: bool = False, + align_size: int = ALIGNED_BYTES, ): """ Create a reducer applied on a set of weights for weight reduction @@ -389,12 +564,15 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None Default is `None` reduce_op (str): reduce operation, can be 'sum', 'avg', 'max' or 'min' (default 'sum') async_op (bool): whether to overlap with backward computation (default False) - zero (bool): whether to apply ZeRO optimization on gradients + zero (int): whether to use zero optimization on gradients, currently only 0/1/3 are supported + zero=2 will be treated as zero=3 zero_ngroups (int): number of ZeRO subgroups in the original ZeRO group zero_use_reduce_scatter (bool): whether to use reduce scatter for zero optimization align_size (int): the alignment size in bytes for each parameter """ + # the parameters with same class will be consecutive in the list. self._params: List[torch.nn.Parameter] = list() + self._param_clss: Dict[torch.nn.Parameter, Any] = dict() # the class of each parameter, used for sorting self._param_ids: Set[int] = set() self._numel: int = 0 self._ranks = ranks @@ -409,7 +587,7 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None # buckets stands for a transission unit self._buckets: List[Bucket] = list() self._async: bool = async_op - self._zero: bool = zero + self._zero: int = int(zero) self._zero_use_reduce_scatter = zero_use_reduce_scatter self._align_size: int = align_size if self._align_size % ALIGNED_BYTES != 0: @@ -419,6 +597,13 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None self._contiguous_params: torch.Tensor = None self._contiguous_grads: torch.Tensor = None + # record following variables for params offload + # items in the bucket is params list + self.seq_buckets: List[List[torch.nn.Parameter]] = [] + # bucket start and stop pos in buffer + self.starts, self.stops = [], [] + self.buffer_length: int = 0 + # build the subgroup of zero the current rank belongs to. # When zero_ngroups is larger than 1, the number of ranks # will be divided by zero_ngroups into sub rank groups, @@ -454,9 +639,18 @@ def __init__(self, ranks: List[int], max_bucket_size_bytes: Optional[int] = None else: assert zero_ngroups == 1, f"ZeRO number of groups must be 1, but got {zero_ngroups}" self._zero_subgroup = self._group - self._zero_crossgroup = None + # trivial crossgroup for single rank + self._zero_crossgroup = DeviceGroup().get_group([torch.distributed.get_rank()]) + self._zero_ngroups = zero_ngroups + self._z3_size = torch.distributed.get_world_size(group=self._zero_subgroup) + if self._z3_size == 1: + self._zero = 0 # disable zero when only one rank in subgroup + self._z3 = self._zero > 1 + self._z3_rank = torch.distributed.get_rank(group=self._zero_subgroup) + self._z3_params_info: dict[torch.nn.Parameter, _Z3ParamInfo] = dict() + @property def zero_ngroups(self) -> int: return self._zero_ngroups @@ -479,6 +673,11 @@ def zero(self) -> bool: """Whether to apply zero optimization on gradients""" return self._zero + @property + def zero3(self) -> bool: + """Whether to apply ZeRO3""" + return self._zero > 1 + @property def buckets(self) -> Tuple[Bucket, ...]: return tuple(self._buckets) @@ -506,15 +705,75 @@ def add_param(self, param: torch.nn.Parameter): self._param_ids.add(param.data.data_ptr()) self._numel += param.numel() - def build_buckets(self): + def _allocate_buffers(self): + # gradient buffer + self._contiguous_grads: torch.Tensor = torch.zeros( + (self.buffer_length,), dtype=self._params[0].dtype, + device=torch.cuda.current_device(), requires_grad=False) + # parameter buffer + self._contiguous_params: torch.Tensor = torch.zeros( + (self.buffer_length,), dtype=self._params[0].dtype, + device=torch.cuda.current_device(), requires_grad=False) + + def _bind_params(self): + for params, start, stop in zip(self.seq_buckets, self.starts, self.stops): + # replace underlying parameter content using shared storage from parameter + ofst = start + for param in params: + with torch.no_grad(): + self._contiguous_params[ofst:ofst+param.numel()].copy_(param.data.view(-1)) + if self._z3: + param.data = self._contiguous_params[ofst:ofst+param.numel()] + self._z3_params_info[param].param_buffer_start = ofst + self._z3_params_info[param].param_buffer_end = ofst + param.numel() + else: + param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) + aligned_nelements = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) + ofst += aligned_nelements + + def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None): """ Build buckets the reducer. - The parameters in each bucket have consistent data types, + The parameters in each bucket have consistent data types and classes, and each bucket contains at least one parameter. If the bucket contains more than 2 parameters, than the total size is samller than the max_bucket_size_bytes. """ + self._param_clss = {} + if param_clss: + # only keep parameters that are in self._params + self._param_clss = {p: param_clss[p] for p in self._params} + # sort parameters by their class + # which can help bucket building + self._params.sort(key=lambda p: self._param_clss[p]) + + # step 0: param split for zero3 + if self._z3: + for param in self._params: + if not param.requires_grad: + continue + + chunk_size = (param.numel() + self._z3_size - 1) // self._z3_size + start = self._z3_rank * chunk_size + end = min(start + chunk_size, param.numel()) + self._z3_params_info[param] = _Z3ParamInfo(shape=param.shape, start=start, end=end) + # clone the data so original param can be released + # this padding is required + # to make sure all ranks in the zero subgroup have the same bucket layout. + if end - start < chunk_size: + padding = chunk_size - (end - start) + param.data = torch.nn.functional.pad( + param.data.view(-1)[start:end], + (0, padding), + mode='constant', + value=0.0, + ) + else: + param.data = param.data.view(-1)[start:end].clone() + + torch.cuda.empty_cache() + # step 1: build bucket for overlapping gradient synchronization # self._numel * 8 + 1 here is to make sure # the bucket size is larger than the total size of all parameters @@ -526,9 +785,9 @@ def build_buckets(self): # (used in pytorch, with a couple percentage improvement) bucket_size = self._numel * 8 + 1 if not self._bucket_size else self._bucket_size - # items in the bucket is params list - seq_buckets: List[List[torch.nn.Parameter]] = [] + seq_buckets_cls: List[Any] = [] last_bucket_size = None + last_bucket_cls = None assert len(set(p.dtype for p in self._params)) == 1, ( "All parameters in the reducer should have the same data type" @@ -540,53 +799,45 @@ def build_buckets(self): # It will go the `else` branch # and finish the current bucket and start a new bucket. # This new bucket will be sealed in the next iteration - if len(seq_buckets) == 0: - seq_buckets.append([param]) + if len(self.seq_buckets) == 0: + self.seq_buckets.append([param]) last_bucket_size = cur_byte_size - elif last_bucket_size + cur_byte_size <= bucket_size: - seq_buckets[-1].append(param) + last_bucket_cls = self._param_clss.get(param, None) + seq_buckets_cls.append(last_bucket_cls) + elif last_bucket_size + cur_byte_size <= bucket_size \ + and last_bucket_cls == self._param_clss.get(param, None): + self.seq_buckets[-1].append(param) last_bucket_size += cur_byte_size else: - seq_buckets.append([param]) + self.seq_buckets.append([param]) last_bucket_size = cur_byte_size + last_bucket_cls = self._param_clss.get(param, None) + seq_buckets_cls.append(last_bucket_cls) # step 2: build meta data for the offset of each bucket # the start of each bucket will be padded to the next multiple of `len(self.ranks)` - buffer_length: int = 0 - starts, stops = [], [] - for params in seq_buckets: - starts.append(buffer_length) + for params in self.seq_buckets: + self.starts.append(self.buffer_length) numel = sum(_aligned_nelement(p.nelement(), p.element_size(), self._align_size) for p in params) # this pad is for zero, which needs numels in each Bucket can be divided by the number of ranks in this group * _align_size # so that each chunck during zero can be divided by _align_size align_nelements = self._align_size // params[0].element_size() * len(self._ranks) padding = (align_nelements - numel % align_nelements) % len(self._ranks) - buffer_length += numel + padding - stops.append(buffer_length) + self.buffer_length += numel + padding + self.stops.append(self.buffer_length) - # step3: allocate memory - # gradient buffer - self._contiguous_grads: torch.Tensor = torch.zeros( - (buffer_length,), dtype=self._params[0].dtype, - device=torch.cuda.current_device(), requires_grad=False) - # parameter buffer - self._contiguous_params: torch.Tensor = torch.zeros( - (buffer_length,), dtype=self._params[0].dtype, - device=torch.cuda.current_device(), requires_grad=False) + # step 3: allocate memory + self._allocate_buffers() + + # step 4: bind parameters + self._bind_params() - # step 4: build buckets + # step 5: build buckets buckets: List[Bucket] = [] - for params, start, stop in zip(seq_buckets, starts, stops): - # replace underlying parameter content using shared storage from parameter - ofst = start - for param in params: - with torch.no_grad(): - self._contiguous_params[ofst:ofst+param.numel()].copy_(param.data.view(-1)) - param.data = self._contiguous_params[ofst:ofst+param.numel()].view(param.size()) - aligned_nelements = _aligned_nelement(param.nelement(), param.element_size(), self._align_size) - ofst += aligned_nelements + for params, param_cls, start, stop in zip(self.seq_buckets, seq_buckets_cls, self.starts, self.stops): # initialize buckets bucket = Bucket( + self, params, self._contiguous_params[start:stop], self._contiguous_grads[start:stop], @@ -598,6 +849,7 @@ def build_buckets(self): self._zero_crossgroup, self._zero_use_reduce_scatter, self._align_size, + param_cls=param_cls, ) buckets.append(bucket) torch.cuda.empty_cache() @@ -617,12 +869,58 @@ def sync_grads(self): for bucket in self._buckets: bucket.sync_grads() + def get_z3_info(self, param: torch.nn.Parameter) -> _Z3ParamInfo: + """ + Get zero3 param info + if the param is not in zero3, return None + """ + return self._z3_params_info.get(param, None) + + @torch.no_grad() + def prefetch_param(self, param: torch.nn.Parameter): + """Prefetch parameter before forward and backward. + + This is required when zero3 is used. + """ + if not self._z3: + raise RuntimeError("postevict_param is only for zero3 optimization.") + if param not in self._z3_params_info: + raise ValueError(f"parameter {param} not found in zero3 params info.") + + info = self._z3_params_info[param] + if param.shape == info.shape: + # no need to gather + return + + full_data = torch.zeros(info.numel_with_padding() * self._z3_size, dtype=param.dtype, + device=torch.cuda.current_device()) + torch.distributed.all_gather_into_tensor( + full_data, + param.data, + group=self._zero_subgroup + ) + param.data = full_data[0:math.prod(info.shape)].view(info.shape).contiguous() + + @torch.no_grad() + def postevict_param(self, param: torch.nn.Parameter): + """Release parameter after forward and backward. + + This is required when zero3 is used. + """ + if not self._z3: + raise RuntimeError("postevict_param is only for zero3 optimization.") + if param not in self._z3_params_info: + raise ValueError(f"parameter {param} not found in zero3 params info.") + info = self._z3_params_info[param] + param.data = self._contiguous_params[info.param_buffer_start:info.param_buffer_end] + def gather_params(self): """Gather parameters with Zero optimizations after `optimizer.step()`. This is required when zero optimization is turned on. """ if not self._zero: return + if self._z3: return # in zero3 mode, no need to gather params for bucket in self._buckets: bucket.gather_params() @@ -652,9 +950,23 @@ def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: Returns: List[torch.nn.Parameter]: parameters for optimizer """ - params = [] + return list(self.get_opt_params().keys()) + + def get_opt_params(self) -> dict[torch.nn.Parameter, Any]: + """ + Get parameters and their classes for optimizers + Please note for ZeRO optimization, + the returned parameters are not the same as the original parameters, + and can have paddings (with value 0.0) both at the end and in the middle of paramters data. + + the calculation of gnorm is not affected as paddings are all 0. + + Returns: + List[torch.nn.Parameter]: parameters for optimizer + """ + params = {} for bucket in self._buckets: - params.append(bucket._param_for_optimizer) + params[bucket._param_for_optimizer] = bucket.param_cls return params def broadcast_params(self): @@ -723,3 +1035,82 @@ def clear_post_hooks(self): """Clear all post hooks.""" for bucket in self._buckets: bucket.clear_post_hooks() + + def sleep(self): + """ + release contiguous buffers on the device to save memory + """ + for bucket in self._buckets: + bucket.sleep() + + self._contiguous_params = None + self._contiguous_grads = None + + def wake_up(self): + """ + reallocate contiguous buffers and related objects + """ + self._allocate_buffers() + self._bind_params() + + for start, stop, bucket in zip(self.starts, self.stops, self._buckets): + bucket.wake_up( + self._contiguous_params[start:stop], + self._contiguous_grads[start:stop], + ) + + def _pack( + self, + param_map: dict[torch.nn.Parameter, torch.nn.Parameter], + ): + """ + Get the information of the bucket. + """ + state = self.__dict__.copy() + fields = unchecked_fields(self) + + state[fields._params] = [param_map[p] for p in self._params] + state[fields._z3_params_info] = {param_map[p]: info for p, info in self._z3_params_info.items()} + state[fields._param_clss] = {param_map[p]: param_cls for p, param_cls in self._param_clss.items()} + state[fields._contiguous_params] = torch.empty_like(self._contiguous_params, device='meta') + state[fields._contiguous_grads] = torch.empty_like(self._contiguous_grads, device='meta') + + state[fields._buckets] = [ + bucket._pack(param_map) + for bucket in self._buckets + ] + + # remove torch handles + state.pop(fields._group, None) + state.pop(fields._zero_subgroup, None) + state.pop(fields._zero_crossgroup, None) + + # remove unuseful information + state.pop(fields._param_ids, None) + state.pop(fields.seq_buckets, None) + + return state + + @classmethod + def _unpack(cls, state: dict): + """ + Return a fake bucket that carries the same information. + """ + reducer = object.__new__(cls) + fields = unchecked_fields(reducer) + + buckets = state.pop(fields._buckets) + reducer._buckets = [ + Bucket._unpack(bucket, reducer) for bucket in buckets + ] + reducer.__dict__.update(state) + for param in reducer._params: + assert param.device.type == 'meta' + + for param in reducer._param_clss.keys(): + assert param.device.type == 'meta' + + assert reducer._contiguous_grads.device.type == 'meta' + assert reducer._contiguous_params.device.type == 'meta' + + return reducer diff --git a/nnscaler/runtime/f16_optimizer.py b/nnscaler/runtime/f16_optimizer.py index 414361a5..bee85b89 100644 --- a/nnscaler/runtime/f16_optimizer.py +++ b/nnscaler/runtime/f16_optimizer.py @@ -4,11 +4,12 @@ # CREDITS: This implementation is inspired by Fairseq https://github.com/facebookresearch/fairseq/blob/main/fairseq/optim/fp16_optimizer.py import logging -from typing import Optional, TYPE_CHECKING +import types +from typing import TYPE_CHECKING import torch -from nnscaler.cli.train_hook import TrainHook +from nnscaler.runtime.hybrid_optimizer import ScaleDelayedOptimizerMixin if TYPE_CHECKING: from nnscaler.cli.trainer import Trainer @@ -16,7 +17,7 @@ logger = logging.getLogger(__name__) -class MixedPrecisionF16OptimizerMixin(TrainHook): +class MixedPrecisionF16OptimizerMixin(ScaleDelayedOptimizerMixin): """ A mixin class for mixed precision optimizer. Support both FP16 and BF16 parameters. @@ -31,7 +32,10 @@ class MixedPrecisionF16OptimizerMixin(TrainHook): def __init__(self, *args, **kwargs): # forward __init__ call to the next class in mro(method resolution order) super().__init__(*args, **kwargs) - self._multiply_factor = 1.0 + # This flag is used to indicate whether fp32_params are loaded from checkpoint. + # If not, we will sync from fp16 params to fp32 params in after_load_checkpoint. + # If the model is trained from scratch, this flag will be None. + self._fp32_params_loaded = None def after_setup(self, trainer: 'Trainer') -> None: """ @@ -44,17 +48,25 @@ def after_setup(self, trainer: 'Trainer') -> None: Assumption: `clip_gnorm` is called immediately after `scale_grads` in training loop. """ - trainer.optimizer._clip_gnorm = trainer.optimizer.clip_gnorm - trainer.optimizer.clip_gnorm = self.overrided_clip_gnorm - trainer.optimizer._scale_grads = trainer.optimizer.scale_grads - trainer.optimizer.scale_grads = self.overrided_scale_grads + if trainer.optimizer is self: + # don't override when using HybridOptimizer + trainer.optimizer._clip_gnorm = trainer.optimizer.clip_gnorm + trainer.optimizer.clip_gnorm = self.overrided_clip_gnorm + trainer.optimizer._scale_grads = trainer.optimizer.scale_grads + trainer.optimizer.scale_grads = self.overrided_scale_grads + + # step method is overrided below to apply the scaling factor @classmethod - def build_fp32_params(cls, params): + def build_fp32_params(cls, params: list[torch.nn.Parameter]) -> list[torch.nn.Parameter]: # create FP32 copy of parameters and grads fp32_params = [] for p in params: - p32 = torch.nn.Parameter(p.data.float()) + if p.data.dtype != torch.float32: + p32 = torch.nn.Parameter(p.data.float()) + else: + # make sure the storage is not shared with original parameter + p32 = torch.nn.Parameter(p.data.clone()) p32.grad = torch.zeros_like(p32.data) fp32_params.append(p32) return fp32_params @@ -70,18 +82,22 @@ def step(self, closure=None): def zero_grad(self, set_to_none: bool = True): """ Clears the gradients of all optimized parameters. - Will ignore `set_to_none` and always set fp16 grads to None, and fp32 grads to zero. + Will ignore `set_to_none` and always set fp16 grads and fp32 grads to None. """ for p in self.f16_params: p.grad = None for p32 in self.fp32_params: - if p32.grad is not None: - p32.grad.zero_() + p32.grad = None def state_dict(self): """Return the optimizer's state dict.""" state_dict = super().state_dict() + # called from hybrid optimizer before call `.step` (to get the param_groups of the wrapped optimizer) + # In this case, state_dict['state'] is empty. + if not state_dict['state']: + return state_dict + # move fp32_params to the same level with 'exp_avg' and 'exp_avg_sq' # we do this to handle the merge of sharded checkpoint in nnscaler assert 'state' in state_dict, f'state not found in state_dict: {state_dict.keys()}' @@ -111,12 +127,15 @@ def load_state_dict(self, state_dict): param.data = state_dict['state'][i]['fp32_params'].data.to(device) # pop to avoid store a redundant copy in the wrapped optimizer state_dict['state'][i].pop('fp32_params') + else: + logger.warning('fp32_params not found in state_dict, will sync from fp16 params to fp32 params') + self._sync_fp16_params_to_fp32() - if len(self.param_groups) != 1: - raise RuntimeError('only support one param group') - self.param_groups[0]['params'] = self.fp32_params + if len(self.param_groups) != 1: + raise RuntimeError('only support one param group') super().load_state_dict(state_dict) + self._fp32_params_loaded = True def _sync_f16_grads_to_fp32(self): # copy FP16 grads to FP32 @@ -148,39 +167,46 @@ def _sync_fp16_params_to_fp32(self): continue p32.data.copy_(p.data) + def on_load_checkpoint(self, trainer, checkpoint) -> None: + self._fp32_params_loaded = False + logger.info('Set _fp32_params_loaded to False in on_load_checkpoint hook') + def after_load_checkpoint(self, trainer, checkpoint) -> None: - if 'nnscaler' not in checkpoint: - # this checkpoint is not created by nnscaler. + if not self._fp32_params_loaded: + logger.info('fp32_params not loaded, will sync from fp16 params to fp32 params') self._sync_fp16_params_to_fp32() + self._fp32_params_loaded = True - def overrided_scale_grads(self, scale: float): - """ - Scale the gradients by a factor. - Will override the original scale_grads method in ParallelOptimizer. - """ - self._multiply_factor *= scale + def _unfold_params(self, params) -> tuple[list[torch.nn.Parameter], dict]: + params = list(params) + if not params: + raise ValueError("optimizer got an empty parameter list") - def overrided_clip_gnorm(self, max_norm: Optional[float] = None) -> float: - """ - Will override the original clip_gnorm method in ParallelOptimizer. - """ - # self._clip_gnorm() is ParallelOptimizer.clip_gnorm - grad_norm = self._multiply_factor * self._clip_gnorm() - if max_norm is not None and max_norm > 0.0: - clip_coef = (max_norm / (grad_norm + 1e-6)).clamp(max=1.0) - self._multiply_factor *= clip_coef - return grad_norm + if isinstance(params[0], dict): + if len(params) > 1: + raise ValueError("MixedPrecisionF16OptimizerMixin only supports one param group") + unfolded_params = list(params[0]['params']) + unfolded_kwargs = {k: v for k, v in params[0].items() if k != 'params'} + else: + if not all(isinstance(p, torch.nn.Parameter) for p in params): + raise ValueError("optimizer params should be either a list of Parameters or a dict with 'params' key") + unfolded_params = params + unfolded_kwargs = {} + + return unfolded_params, unfolded_kwargs class MixedPrecisionAdam(MixedPrecisionF16OptimizerMixin, torch.optim.Adam): def __init__(self, params, **kwargs): - self.f16_params = list(params) + self.f16_params, unfolded_kwargs = self._unfold_params(params) self.fp32_params = self.build_fp32_params(self.f16_params) + kwargs = {**unfolded_kwargs, **kwargs} super().__init__(self.fp32_params, **kwargs) class MixedPrecisionAdamW(MixedPrecisionF16OptimizerMixin, torch.optim.AdamW): def __init__(self, params, **kwargs): - self.f16_params = list(params) + self.f16_params, unfolded_kwargs = self._unfold_params(params) self.fp32_params = self.build_fp32_params(self.f16_params) + kwargs = {**unfolded_kwargs, **kwargs} super().__init__(self.fp32_params, **kwargs) diff --git a/nnscaler/runtime/function/function.py b/nnscaler/runtime/function/function.py index 51ef947c..8bda380a 100644 --- a/nnscaler/runtime/function/function.py +++ b/nnscaler/runtime/function/function.py @@ -8,7 +8,7 @@ """ from contextlib import contextmanager -from typing import Optional, List, Tuple, Union, Any +from typing import Callable, Optional, List, Tuple, Union, Any import torch import torch.nn.functional as TorchF import operator @@ -81,11 +81,24 @@ def fold_constant(a: Any) -> Any: return a -def multiref(tensor: torch.Tensor, times: int) -> Tuple[torch.Tensor]: +def multiref(tensor: torch.Tensor, times: int, *, clone_level: int = 0) -> Union[torch.Tensor, Tuple[torch.Tensor]]: """ identity forward. Create multiple same tensor. + Args: + tensor (torch.Tensor): input tensor + times (int): number of same tensor to create + clone_level (int): 0: no clone, 1: clone once for all, 2: clone each time + Returns: + Union[torch.Tensor, Tuple[torch.Tensor]]: + if times==1, return tensor; else return tuple of tensors """ - return tensor if times == 1 else tuple([tensor] * times) + if clone_level == 0: + return tensor if times == 1 else tuple([tensor] * times) + elif clone_level == 1: + cloned_tensor = tensor.clone() + return cloned_tensor if times == 1 else tuple([cloned_tensor] * times) + else: # clone_level == 2 + return tensor.clone() if times == 1 else tuple([tensor.clone() for _ in range(times)]) def to(tensor: torch.Tensor, dtype_or_device: Union[torch.device, torch.dtype]) -> torch.Tensor: @@ -312,6 +325,10 @@ def linspace(start: Union[int, torch.Tensor], end: Union[int, torch.Tensor], device=torch.cuda.current_device()) +def eye(n: int, m: Optional[int]=None, requires_grad=False, dtype: torch.dtype=torch.float32) -> torch.Tensor: + return torch.eye(n, m=m, dtype=dtype, device=torch.cuda.current_device(), requires_grad=requires_grad) + + def index_select(input: torch.Tensor, index: torch.Tensor, dim: int) -> torch.Tensor: return torch.index_select(input, dim, index) @@ -366,4 +383,24 @@ def print_time(content: str): rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1 if torch.cuda.is_available(): torch.cuda.synchronize() - print(f"line timer: {rank} - {datetime.datetime.now()} - {content}") \ No newline at end of file + print(f"line timer: {rank} - {datetime.datetime.now()} - {content}") + + +class _BackwardHook(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, backward_hook: Callable[[], None]): + ctx.save_for_backward() + ctx.backward_hook = backward_hook + return x + + @staticmethod + def backward(ctx, grad_output): + ctx.backward_hook() + return grad_output, None + + +def insert_backward_hook(x: torch.Tensor, backward_hook: Optional[Callable[[], None]]) -> torch.Tensor: + if backward_hook is None: + # no need to add hook + return x + return _BackwardHook.apply(x, backward_hook) diff --git a/nnscaler/runtime/gnorm.py b/nnscaler/runtime/gnorm.py index eb6a3e5b..36d46c38 100644 --- a/nnscaler/runtime/gnorm.py +++ b/nnscaler/runtime/gnorm.py @@ -40,7 +40,7 @@ class TidReplicaInfo: def _calc_grad_shape(slicers_list): - # caculate the shape of each full parameters/grads + # calculate the shape of each full parameters/grads tid2shape = {} for rank_slicers in slicers_list: for tid, slicers in rank_slicers.items(): @@ -50,7 +50,7 @@ def _calc_grad_shape(slicers_list): # slicer: (start, end, step) if slicer.stop > tid2shape[tid][i]: tid2shape[tid][i] = slicer.stop - # caculate the number of replicas of each model parameter + # calculate the number of replicas of each model parameter tid2nreplicas = {} for rank_slicers in slicers_list: for tid, slicers in rank_slicers.items(): @@ -117,7 +117,7 @@ def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int Returns: tid2nreplicas: dict, tid -> TidReplicaInfo """ - # caculate the number of replicas of each model parameter + # calculate the number of replicas of each model parameter tid2nreplicas = {} tid2ranksset = defaultdict(set) for tid2ranks in tid2ranks_list: @@ -135,7 +135,7 @@ def _calc_grad_replicas(tid2ranks_list: List[Dict[int, Tuple[int]]]) -> Dict[int return tid2nreplicas -def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, List[torch.nn.Parameter]]: +def prepare_for_grad_clip(cube_model: 'CubeModule', use_zero: int) -> Dict[int, List[torch.nn.Parameter]]: params_info_for_gnorm = cube_model.parameters_for_calc_gnorm() tid2ranks = {} tid2info_list_seq = {} @@ -174,7 +174,7 @@ def prepare_for_grad_clip(cube_model: 'CubeModule', is_zero: bool) -> Dict[int, # multiplied by the number of ZeRO groups. Multiplying the number of pure replicated is easy # to understand. Multiplying the number of ZeRO groups is because the gradients of each ZeRO group # are full model gradients, so the number of ZeRO groups is the number of gradient replicas of the full model. - if not is_zero: + if not use_zero: nreplicas = replicated_info.nranks else: nreplicas = replicated_info.nreplicated * params_info.zero_ngroups @@ -241,7 +241,8 @@ def grad_exists(p): elif len(grads) == 1: total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) else: - if multi_tensor_l2norm_available: + dtypes = set([g.dtype for g in grads]) + if multi_tensor_l2norm_available and len(dtypes) == 1: total_norm = _multi_tensor_total_norm(grads).to(device) else: # torch.nn.utils.clip_grad_norm_ way to calculate the norm diff --git a/nnscaler/runtime/hybrid_optimizer.py b/nnscaler/runtime/hybrid_optimizer.py new file mode 100644 index 00000000..be8a5c85 --- /dev/null +++ b/nnscaler/runtime/hybrid_optimizer.py @@ -0,0 +1,421 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections import defaultdict +from dataclasses import dataclass, field +import types +from typing import Any, Callable, Iterable, Type, Union, TYPE_CHECKING, Optional + +import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LRScheduler +from torch.utils.hooks import RemovableHandle + +from nnscaler.cli.arg_parser import deserialize_dataclass +from nnscaler.cli.train_hook import TrainHookHost, TrainHook +from nnscaler.utils import fn_field, OptStateDict + +if TYPE_CHECKING: + from nnscaler.cli.trainer import Trainer + + +@dataclass +class HybridSubOptParamGroupConfig: + options: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HybridSubOptConfig: + type: Union[Type[Optimizer], Callable[..., Optimizer]] = fn_field(default=None) + options: dict[str, Any] = field(default_factory=dict) + param_groups: list[HybridSubOptParamGroupConfig] = field(default_factory=list) + + def __post_init__(self): + if not self.type: + raise ValueError("Optimizer type must be specified in HybridSubOptConfig") + + +@dataclass +class HybridOptConfig: + optimizers: list[HybridSubOptConfig] = field(default_factory=list) + + def __post_init__(self): + if not self.optimizers: + raise ValueError("At least one optimizer must be specified in HybridOptConfig") + + +class HybridRemovableHandle: + def __init__(self, removable_handles: list[RemovableHandle]): + self.removable_handles = removable_handles + + def remove(self): + for removable_handle in self.removable_handles: + removable_handle.remove() + + def __enter__(self) -> "HybridRemovableHandle": + return self + + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + self.remove() + + +class ScaleDelayedOptimizerMixin(TrainHook): + """ + A mixin class to add scale-delayed optimization support to an optimizer. + This mixin overrides the `scale_grads`, `clip_gnorm`, and `step` methods + of the optimizer to delay the scaling of gradients until the `step` method is called. + """ + def __init__(self, *args, **kwargs): + # forward __init__ call to the next class in mro(method resolution order) + super().__init__(*args, **kwargs) + self._multiply_factor = 1.0 + + def after_setup(self, trainer: 'Trainer') -> None: + if trainer.optimizer is self: + # do nothing if we are in the hybrid optimizer, + # who is responsible for overriding these methods. + trainer.optimizer._clip_gnorm = trainer.optimizer.clip_gnorm + trainer.optimizer.clip_gnorm = self.overrided_clip_gnorm + trainer.optimizer._scale_grads = trainer.optimizer.scale_grads + trainer.optimizer.scale_grads = self.overrided_scale_grads + + # we need to override the step method to apply the scaling factor + # hybrid optimizer will also call `step` of child optimizers, + self._step = self.step + self.step = self.override_step + + def overrided_scale_grads(self, scale: float): + """ + Scale the gradients by a factor. + Will override the original scale_grads method in ParallelOptimizer. + """ + self._multiply_factor *= scale + + def overrided_clip_gnorm(self, max_norm: Optional[float] = None) -> float: + """ + Will override the original clip_gnorm method in ParallelOptimizer. + """ + # self._clip_gnorm() is ParallelOptimizer.clip_gnorm + grad_norm = self._multiply_factor * self._clip_gnorm() + if max_norm is not None and max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp(max=1.0) + self._multiply_factor *= clip_coef + return grad_norm + + def override_step(self, closure=None): + """ + Performs a single optimization step. + """ + # apply the accumulated multiply factor to grads + if self._multiply_factor != 1.0: + for pg_idx in range(len(self.param_groups)): + for p in self.param_groups[pg_idx]['params']: + if p.grad is not None: + p.grad.mul_(self._multiply_factor) + self._multiply_factor = 1.0 + # can't use super() here because we need to support applying this mixin to existing optimizers + self._step(closure) + + @classmethod + def apply_mixin(cls, obj: Any) -> Any: + """Apply this mixin to an existing object.""" + obj._multiply_factor = 1.0 + # bind the new methods + obj.after_setup = types.MethodType(cls.after_setup, obj) + obj.overrided_scale_grads = types.MethodType(cls.overrided_scale_grads, obj) + obj.overrided_clip_gnorm = types.MethodType(cls.overrided_clip_gnorm, obj) + obj.override_step = types.MethodType(cls.override_step, obj) + + return obj + + +class HybridOptimizer(torch.optim.Optimizer, TrainHookHost, TrainHook): + """ + A hybrid optimizer that combines multiple optimizers/multiple param groups + into a single optimizer. + + Please note HybridOptimizer doesn't call super().__init__(), + So it is actually a duck type for optimizer. + """ + + # Identifier for hybrid optimizer + is_hybrid = True + + def __init__( + self, + params: Iterable[torch.nn.Parameter], + param_clss: dict[torch.nn.Parameter, tuple[int, int]], + config: Union[HybridOptConfig, dict[str, Any]] + ): + """ + Initialize the hybrid optimizer. + + Args: + params (Iterable[torch.nn.Parameter]): The parameters to optimize. + param_clss (dict[torch.nn.Parameter, tuple[int, int]]): The parameter classes for each parameter. + config (Union[HybridOptConfig, dict[str, Any]]): The configuration for the hybrid optimizer. + """ + params = list(params) + if isinstance(config, dict): + config = deserialize_dataclass(config, HybridOptConfig) + self.config = config + + self.optimizers = [] + classified_params = defaultdict(list) + # map from (optimizer_idx, pg_idx, param_pg_idx) to param global param index + param_loc = {} + + for idx, param in enumerate(params): + param_cls = param_clss[param] + assert param_cls[0] < len(self.config.optimizers) + classified_params[param_cls].append(param) + + loc = *param_cls, len(classified_params[param_cls]) - 1 + param_loc[loc] = idx + + # sort with key i.e. (optimizer idx, param group idx) + classified_params = dict(sorted(classified_params.items())) + + quick_param_groups = {param_cls: {"params": params} for param_cls, params in classified_params.items()} + opt_param_groups = defaultdict(dict) + for param_cls, group in quick_param_groups.items(): + opt_param_groups[param_cls[0]][param_cls[1]] = group + + for idx, opt_config in enumerate(config.optimizers): + param_groups = opt_param_groups[idx] + if len(param_groups) > 1: + if len(param_groups) != len(opt_config.param_groups): + raise ValueError(f"Expected {len(opt_config.param_groups)} param groups, got {len(param_groups)}") + # param group indices must be consecutive. + if max(param_groups.keys()) != len(opt_config.param_groups) - 1: + raise ValueError(f"Param group indices must be consecutive. We have {len(opt_config.param_groups)} groups, got max group id {max(param_groups.keys())}") + for param_group_idx, param_group in param_groups.items(): + param_group.update(opt_config.param_groups[param_group_idx].options) + else: + if len(opt_config.param_groups) > 1: + raise ValueError(f"Expected at most 1 param group, got {len(opt_config.param_groups)}") + if opt_config.param_groups: + param_groups[0].update(opt_config.param_groups[0].options) + optimizer = opt_config.type(param_groups.values(), **opt_config.options) + self.optimizers.append(optimizer) + + # map from param global index to (optimizer_idx, param_idx) + self._param_map: dict[int, tuple[int, int]] = {} + # map from (optimizer_idx, param_idx) to param global idx + self._reverse_param_map: dict[tuple[int, int], int] = {} + for opt_idx, optimizer in enumerate(self.optimizers): + state_dict: OptStateDict = optimizer.state_dict() + for pg_idx, pg in enumerate(state_dict['param_groups']): + for param_idx_in_pg, param_idx in enumerate(pg['params']): + # param_idx_in_pg is the index in this param group + # param_idx is the index in this optimizer + global_idx = param_loc[(opt_idx, pg_idx, param_idx_in_pg)] + self._param_map[global_idx] = (opt_idx, param_idx) + self._reverse_param_map[(opt_idx, param_idx)] = global_idx + + # Don't call base init + # So HybridOptimizer is a duck optimizer + # super().__init__(params, {}) + + # simulated param groups + self.param_groups = [] + for optimizer in self.optimizers: + self.param_groups.extend(optimizer.param_groups) + + # to support scale-delayed optimizers like mixed-precision f16 optimizer + self._has_scale_delayed = any(isinstance(opt, ScaleDelayedOptimizerMixin) for opt in self.optimizers) + + def after_setup(self, trainer: 'Trainer') -> None: + if not self._has_scale_delayed: + return + + assert trainer.optimizer is self, "HybridOptimizer should not be nested inside another optimizer" + trainer.optimizer._clip_gnorm = trainer.optimizer.clip_gnorm + trainer.optimizer._scale_grads = trainer.optimizer.scale_grads + + # if any one of the optimizers is scale-delayed, + # we must apply the mixin to make sure all optimizers are scale-delayed + # this is the only way to calculate gnorm correctly. + for opt in self.optimizers: + if not isinstance(opt, ScaleDelayedOptimizerMixin): + ScaleDelayedOptimizerMixin.apply_mixin(opt) + # after_setup of non-scale-delayed optimizers can't be called automatically by Trainer + # we need to call it here manually + # For consistency, let's call all optimizers' after_setup manually here (including scale-delayed ones) + opt.after_setup(trainer) + # disable after_setup for sub optimizers + # as we have already handled it here + opt.after_setup = lambda *args, **kwargs: None + + def overrided_scale_grads(self, scale: float) -> None: + for optimizer in self.optimizers: + optimizer.overrided_scale_grads(scale) + + self.scale_grads = types.MethodType(overrided_scale_grads, self) + + def override_clip_gnorm(self, max_norm: Optional[float] = None) -> float: + # self._clip_gnorm() is ParallelOptimizer.clip_gnorm + # all optimizers have the same `multiply_factor` + grad_norm = self.optimizers[0]._multiply_factor * self._clip_gnorm() + if max_norm is not None and max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clamp(max=1.0) + # will update all optimizers' multiply_factor + self.scale_grads(clip_coef) + return grad_norm + + self.clip_gnorm = types.MethodType(override_clip_gnorm, self) + + def _get_hook_objects(self): + return self.optimizers + + def step(self, closure=None): + """ + Perform a single optimization step. + """ + assert closure is None, "Closure is not supported in HybridOptimizer" + for optimizer in self.optimizers: + optimizer.step(closure) + + def zero_grad(self, set_to_none: bool = False): + """ + Zero the gradients of all optimizers. + """ + for optimizer in self.optimizers: + optimizer.zero_grad(set_to_none=set_to_none) + + def __repr__(self) -> str: + format_string = self.__class__.__name__ + " [\n" + format_string += ",\n".join(f"{repr(opt)}" for opt in self.optimizers) + format_string += "\n]" + return format_string + + def register_step_pre_hook(self, hook) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_step_pre_hook(hook) for opt in self.optimizers]) + + def register_step_post_hook(self, hook) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_step_post_hook(hook) for opt in self.optimizers]) + + def register_state_dict_pre_hook( + self, hook, prepend: bool = False + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_state_dict_pre_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def register_state_dict_post_hook( + self, + hook, + prepend: bool = False, + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_state_dict_post_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def state_dict(self): + state_dicts: list[OptStateDict] = [opt.state_dict() for opt in self.optimizers] + merged_state_dict: OptStateDict = {'state': {}, 'param_groups': [{'children': {}}]} + + for opt_idx, sd in enumerate(state_dicts): + for param_idx, s in sd['state'].items(): + merged_state_dict['state'][self._reverse_param_map[(opt_idx, param_idx)]] = s + merged_state_dict['param_groups'][0]['children'][opt_idx] = sd['param_groups'] + + merged_state_dict['param_groups'][0]['params'] = list(range(len(self._param_map))) + merged_state_dict['param_groups'][0]['param_map'] = self._param_map + merged_state_dict['param_groups'][0]['reverse_param_map'] = self._reverse_param_map + merged_state_dict['state'] = dict(sorted(merged_state_dict['state'].items())) + + return merged_state_dict + + def register_load_state_dict_pre_hook( + self, + hook, + prepend: bool = False, + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_load_state_dict_pre_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def register_load_state_dict_post_hook( + self, hook, prepend: bool = False + ) -> HybridRemovableHandle: + return HybridRemovableHandle([opt.register_load_state_dict_post_hook(hook, prepend=prepend) for opt in self.optimizers]) + + def load_state_dict(self, state_dict) -> None: + child_state_dicts = [{'state': {}, 'param_groups': []} for _ in self.optimizers] + + for idx, sd in enumerate(child_state_dicts): + # copy param groups from state dict + sd['param_groups'] = state_dict['param_groups'][0]['children'][idx] + if len(sd['param_groups']) != len(self.optimizers[idx].param_groups): + raise ValueError(f"Number of param groups mismatch. Expected {len(self.optimizers[idx].param_groups)} got {len(sd['param_groups'])}") + # param groups can be changed (for example, the compute config is changed) + # state_dict for HybridOptimizer is already well organized, + # here we will carefully dispatch parameters to each optimizer. + current_state_dict = self.optimizers[idx].state_dict() + for pg, current_pg in zip(sd['param_groups'], current_state_dict['param_groups']): + pg['params'] = current_pg['params'][:] # make a copy + + for param_idx, param_state in state_dict['state'].items(): + opt_idx, param_state_idx = self._param_map[param_idx] + child_state_dicts[opt_idx]['state'][param_state_idx] = param_state + + for child_state_dict, opt in zip(child_state_dicts, self.optimizers): + opt.load_state_dict(child_state_dict) + + # after loading from state dict, the param_groups of optimizers are reassigned + # (instead of updated inplace), so we need to gather them again (as we have done + # in the constructor). + self.param_groups = [] + for optimizer in self.optimizers: + self.param_groups.extend(optimizer.param_groups) + + def add_param_group(self, param_group: dict[str, Any]) -> None: + # no-op to avoid creating new parameter groups + # all parameter groups are managed by the individual optimizers + pass + + +@dataclass +class HybridSubLRSchedulerConfig: + type: Union[Type[LRScheduler], Callable[..., LRScheduler]] = fn_field(default=None) + options: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class HybridLRSchedulerConfig: + schedulers: list[HybridSubLRSchedulerConfig] = field(default_factory=list) + + +class HybridLRScheduler(LRScheduler, TrainHookHost): + """ + A hybrid learning rate scheduler that combines multiple schedulers. + + Please note HybridLRScheduler doesn't call super().__init__(), + So it is actually a duck type for scheduler. + """ + + def __init__( + self, + optimizer: HybridOptimizer, + config: Union[HybridLRSchedulerConfig, dict[str, Any]], + last_epoch: int = -1, + ): + assert isinstance(optimizer, HybridOptimizer), "Optimizer must be an instance of HybridOptimizer" + if isinstance(config, dict): + config = deserialize_dataclass(config, HybridLRSchedulerConfig) + + if len(config.schedulers) == 1: + self.schedulers = [config.schedulers[0].type(optimizer, **config.schedulers[0].options)] + elif len(config.schedulers) == len(optimizer.optimizers): + self.schedulers = [sub_config.type(opt, **sub_config.options) for sub_config, opt in zip(config.schedulers, optimizer.optimizers)] + else: + raise ValueError(f"Expected {len(optimizer.optimizers)} or 1 schedulers, got {len(config.schedulers)}") + + def _get_hook_objects(self): + return self.schedulers + + def step(self, epoch=None): + for scheduler in self.schedulers: + scheduler.step(epoch) + + def state_dict(self): + return {idx: scheduler.state_dict() for idx, scheduler in enumerate(self.schedulers)} + + def load_state_dict(self, state_dict): + for idx, sd in state_dict.items(): + self.schedulers[idx].load_state_dict(sd) diff --git a/nnscaler/runtime/module.py b/nnscaler/runtime/module.py index 0e26d483..9f57f835 100644 --- a/nnscaler/runtime/module.py +++ b/nnscaler/runtime/module.py @@ -1,16 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union +import functools +import pickle +from typing import Callable, List, Set, Dict, Tuple, Optional, TYPE_CHECKING, Any, Union, ClassVar +from typing_extensions import Self import logging import os import sys +import gc +import warnings from pathlib import Path from dataclasses import dataclass, asdict from collections import defaultdict import torch import torch.distributed as dist +from torch import device +from torch.autograd.graph import saved_tensors_hooks from nnscaler.graph.parser import FxModuleParser @@ -19,10 +26,11 @@ from nnscaler.runtime.executor import Executor from nnscaler.runtime.gnorm import ParamsInfo from nnscaler.runtime.utils import microbatches +from nnscaler.runtime.function import insert_backward_hook from nnscaler import __version__ as runtime_version from nnscaler.flags import CompileFlag -from nnscaler.utils import accum_mode +from nnscaler.utils import accum_mode, classproperty, unchecked_fields if TYPE_CHECKING: from nnscaler.parallel import ComputeConfig @@ -46,24 +54,57 @@ class AttrMeta: # the number of the partitioned values, usually 1 # (i.e., no partition on value -> no need to sum up) val_chunks: int + # data type of the full tensor and sub tensor + dtype: torch.dtype + # shape of the sub tensor + # it should be the shape of full_tensor[slicers] + sub_shape: Tuple[int, ...] -def dedup_attrs(rank2attr_area_map: Dict[int, Dict[str, AttrMeta]]) -> Dict[int, Dict[str, AttrMeta]]: +@dataclass +class Zero3AttrMeta: + """ + Used for loading merged state dict + """ + # original name in the module + orig_name: str + # name in the module + attr_name: str + # start index of the sub tensor + start: int + # end index of the sub tensor + end: int + # chunk size of the sub tensor, can be bigger than end - start due to padding + chunk_size: int + + +def dedup_attrs(rank2attr_area_map: Dict[int, Dict[str, Dict[str, AttrMeta]]]) -> Dict[int, Dict[str, Dict[str, AttrMeta]]]: ''' Deduplicate the attributes according to `rank2attr_area_map`. - For each `slicers` of a full tensor with the name `orig_name`, we only store its first appearance - in the `rank2attr_area_map`. + For each `slicers` of a full tensor identified by its full qualified name, we only store its first appearance + in the `rank2attr_area_map`. In nnscaler, this dedup process leads to: + - If an attribute is not within the first scale unit, it will be deduplicated. + - If an attribute is shared by different operators, it will be deduplicated. + - If an attribute is replicated across several devices, we only save it at the devices with the smallest rank. + - If an attribute is partitioned across several devices, all these sub tensors will be saved. + - Note that nnscaler supports partition an operator across multiple dimensions, attributes in the operator may + be saved at a subset of related devices. + - Pipeline parallelism is supported since it is composed of different segments in nnscaler, which are different + parallel modules with their own attribute maps at runtime. In addition, we will check - the shape of the full tensor is consistent across different ranks - the slicers of the full tensor are not intersected with each other - the slicers of the full tensor can cover the full tensor - The input and output attribute area map's key is the local attribute name. Args: - rank2attr_area_map (Dict[int, Dict[str, AttrMeta]]): the mapping from rank to the attribute area map + rank2attr_area_map ( + Dict[int, # rank id + Dict[str, # submodule prefix + Dict[str, # attribute name in parallel module (not original name) + AttrMeta]]]): fullmap information for all parallel modules in all ranks. Returns: - Dict[int, Dict[str, AttrMeta]]: the deduplicated attribute area map + Dict[int, Dict[str, Dict[str, AttrMeta]]]: the deduplicated fullmap info, the structure is the same as the input. ''' # assume ranks in rank2attr_area_map are in increasing order ranks = list(rank2attr_area_map.keys()) @@ -87,26 +128,32 @@ def need_save(slicers: Tuple[slice, ...], saved_slicers_list: List[Tuple[slice, return True ret = dict() - for rank, attr_area_map in rank2attr_area_map.items(): - dedup_attr_area_map = dict() - for attr, attr_meta in attr_area_map.items(): - assert attr_meta.val_chunks == 1, 'not support partitioning on value dimension' - if attr_meta.orig_name not in orig_name2shape: - orig_name2shape[attr_meta.orig_name] = attr_meta.shape - else: - assert orig_name2shape[attr_meta.orig_name] == attr_meta.shape, \ - f'unmatched shape {orig_name2shape[attr_meta.orig_name]} vs {attr_meta.shape}' - if need_save(attr_meta.slicers, orig_name2slice_info[attr_meta.orig_name]): - orig_name2slice_info[attr_meta.orig_name].append(attr_meta.slicers) - dedup_attr_area_map[attr] = attr_meta - ret[rank] = dedup_attr_area_map + for rank, module_fullmaps in rank2attr_area_map.items(): + dedup_module_fullmaps = dict() + for module_name, attr_area_map in module_fullmaps.items(): + dedup_attr_area_map = dict() + for attr, attr_meta in attr_area_map.items(): + assert attr_meta.val_chunks == 1, 'not support partitioning on value dimension' + # use module_name.orig_name as the unique identifier for full tensor + full_tensor_name = f"{module_name}.{attr_meta.orig_name}" + if full_tensor_name not in orig_name2shape: + orig_name2shape[full_tensor_name] = attr_meta.shape + else: + assert orig_name2shape[full_tensor_name] == attr_meta.shape, \ + f'unmatched shape {orig_name2shape[full_tensor_name]} vs {attr_meta.shape}' + if need_save(attr_meta.slicers, orig_name2slice_info[full_tensor_name]): + orig_name2slice_info[full_tensor_name].append(attr_meta.slicers) + dedup_attr_area_map[attr] = attr_meta + if dedup_attr_area_map: # only add non-empty maps + dedup_module_fullmaps[module_name] = dedup_attr_area_map + ret[rank] = dedup_module_fullmaps # since we # - skip saving when there are identical weights # - assert the slicers are disjoint # we can use the sum of the sub-slicers to verify the full tensor is covered - for orig_name, slicerss in orig_name2slice_info.items(): - shape = orig_name2shape[orig_name] + for full_tensor_name, slicerss in orig_name2slice_info.items(): + shape = orig_name2shape[full_tensor_name] full_size = 1 for s in shape: full_size *= s @@ -116,7 +163,7 @@ def need_save(slicers: Tuple[slice, ...], saved_slicers_list: List[Tuple[slice, for s in slicers: size *= s.stop - s.start covered_size += size - assert full_size == covered_size, f'uncovered size for {orig_name} with shape {shape}, slicerss {slicerss}' + assert full_size == covered_size, f'uncovered size for {full_tensor_name} with shape {shape}, slicerss {slicerss}' return ret @@ -192,14 +239,28 @@ def zero_grad(self): def parameters_for_optimizer(self) -> List[torch.nn.Parameter]: """Get parameter list for optimizer""" - params = [] + return list(self.get_opt_params().keys()) + + def get_opt_params(self, prefix='', classify_param_cls_fn: Callable[[str], Any]=None) -> dict[torch.nn.Parameter, Any]: + """ + Get all parameters and their classifications. Parameters in reducers come first. + + Args: + prefix (str): The prefix of this module, + which will be used to generate full names of parameters and further classify them. + classify_param_cls_fn (Callable[[str], Any], optional): A function to classify parameters by name. + + Returns: + dict[torch.nn.Parameter, Any]: A dictionary mapping parameters to their classifications. + """ + params = {} reducer_pids = set() for reducer in self._reducers: - params += reducer.parameters_for_optimizer() + params.update(reducer.get_opt_params()) reducer_pids.update(id(p) for p in reducer.params) - for param in self.parameters(): + for name, param in self.named_parameters(prefix): if id(param) not in reducer_pids: - params.append(param) + params[param] = classify_param_cls_fn(name) if classify_param_cls_fn else None # print(f'> get out parameters: {sum(p.numel() for p in params)}') return params @@ -277,7 +338,8 @@ def add_full_map(self, attr: str, tid: int, is_param: bool, orig_name: str, shap val_chunks int: the number of value chunks. """ assert hasattr(self, attr), f"{attr} is not in the module" - meta = AttrMeta(tid, is_param, orig_name, shape, slicers, val_chunks) + attr_tensor: torch.Tensor = getattr(self, attr) + meta = AttrMeta(tid, is_param, orig_name, shape, slicers, val_chunks, attr_tensor.dtype, tuple(attr_tensor.shape)) self._fullmap[attr] = meta # TODO: remove this function, use the property instead @@ -330,7 +392,7 @@ def get_checkpoint(self, optimizer: torch.optim.Optimizer = None): # backward compatibility # in old version, dist_param_map is not loaded in constructor # so we will try to load it from file on the fly. - dist_param_map = getattr(self, '_dist_param_map', None) + dist_param_map = getattr(self, 'dist_param_map', None) if not dist_param_map: module_file = Path(sys.modules[self.__module__].__file__) # load from the same directory as the module file @@ -372,9 +434,11 @@ def _safe_tensor_equal(cls, tensor1: Any, tensor2: Any): @staticmethod def merge_model_state_dicts( state_dicts: List[Dict], - fullmaps: List[Dict[str, AttrMeta]] + fullmaps: List[Dict[str, AttrMeta]], + zero_idx_maps: Optional[List[Dict]] = None ): """Merge model states from multiple shard into a single-model state. + Here we assume the order of state_dicts and fullmaps are aligned, and is the same as the rank order. Note: Users only need to provide as fewer local model states as necessary to @@ -383,6 +447,7 @@ def merge_model_state_dicts( Args: state_dicts (List[Dict[str, torch.Tensor]]): per-rank local model state dict from model.state_dict() fullmaps (List[Dict[str, AttrMeta]]): per-rank fullmap + zero_idx_maps (Optional[List[Dict]]): zero information for the model, `None` if zero is not enabled Returns: full_state_dicts (List[Dict[str, torch.Tensor]]): Full model state dict @@ -396,8 +461,14 @@ def merge_model_state_dicts( # Here we expand slice to (start, step, stop) tuple, # because before python 3.12, slice object is not hashable state_dict_merge_track: Dict[str, Set[Tuple[Tuple[Any, Any, Any], ...]]] = {} + # the fill progress of zero3 parameters + # key: param name + # value: Dict[ tuple(start, step, stop) , filled chunk] + # used to track how many elements have been filled for each zero3 parameter + zero3_current_filled: Dict[str, Dict[Tuple[Tuple[int, int, int], ...], List[Tuple[int, int]]]] = {} + zero3_param_metadatas = [info[-1] for info in zero_idx_maps] if zero_idx_maps is not None else [None] * len(state_dicts) # gather param/buffer full tensor - for rank, (model_state_dict, local_fullmap) in enumerate(zip(state_dicts, fullmaps)): + for rank, (model_state_dict, local_fullmap, zero3_param_metadata) in enumerate(zip(state_dicts, fullmaps, zero3_param_metadatas)): for local_name, meta in local_fullmap.items(): if local_name not in model_state_dict: # the parameter may not in model_state_dict (deduped with optimization) @@ -408,20 +479,53 @@ def merge_model_state_dicts( partial_tensor = model_state_dict[local_name] if meta.orig_name not in full_model_state_dict: full_model_state_dict[meta.orig_name] = torch.empty( - meta.shape, dtype=partial_tensor.dtype) + meta.shape, dtype=partial_tensor.dtype, device='cpu') state_dict_merge_track[meta.orig_name] = set() # assign partial tensor if meta.val_chunks > 1: raise NotImplementedError("Not support of partitioning parameter / buffer at value dimension") state_dict_merge_track_id = tuple((i.start, i.step, i.stop) for i in meta.slicers) - if state_dict_merge_track_id in state_dict_merge_track[meta.orig_name]: - if not CubeModule._safe_tensor_equal(full_model_state_dict[meta.orig_name][meta.slicers], partial_tensor): - raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") + dest_tensor = full_model_state_dict[meta.orig_name][meta.slicers] + if dest_tensor.shape == partial_tensor.shape and state_dict_merge_track_id in state_dict_merge_track[meta.orig_name]: + if not CubeModule._safe_tensor_equal(dest_tensor, partial_tensor): + raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") _logger.debug(f'rank {rank}: skip merging duplicated model state for param {meta.orig_name} with slicers {meta.slicers}') else: state_dict_merge_track[meta.orig_name].add(state_dict_merge_track_id) - full_model_state_dict[meta.orig_name][meta.slicers] = partial_tensor + if dest_tensor.shape == partial_tensor.shape: + dest_tensor.copy_(partial_tensor) + else: + # we assume zero3 is on when dest_tensor.shape != partial_tensor.shape + if len(partial_tensor.shape) != 1: + raise ValueError("Invalid tensor as a ZeRO3 parameter, expected a 1D tensor.") + curr_filled = zero3_current_filled.setdefault(meta.orig_name, {}).setdefault(state_dict_merge_track_id, []) + curr_z3_info = zero3_param_metadata[local_name] + curr_start, curr_end = curr_z3_info['start'], curr_z3_info['end'] + fill_len = curr_end - curr_start + if (curr_start, curr_end) in curr_filled: + # already filled, let's check consistency + if not CubeModule._safe_tensor_equal(dest_tensor.view(-1)[curr_start: curr_end], partial_tensor[0:fill_len]): + raise ValueError(f"Conflict in merging {meta.orig_name} from rank {rank}") + else: + old_shape = dest_tensor.shape + dest_tensor = dest_tensor.reshape(-1) + dest_tensor[curr_start: curr_end] = partial_tensor[0: fill_len] + full_model_state_dict[meta.orig_name][meta.slicers] = dest_tensor.view(old_shape) + zero3_current_filled[meta.orig_name][state_dict_merge_track_id].append((curr_start, curr_end)) + + if zero3_current_filled: + # verify all zero3 parameters are fully filled + for param_name, slicers2filled in zero3_current_filled.items(): + for slicers, filled_chunks in slicers2filled.items(): + full_size = 1 + for s in slicers: + full_size *= s[-1] - s[0] + covered_size = 0 + for start, end in filled_chunks: + covered_size += end - start + if full_size != covered_size: + raise ValueError(f'Uncovered ZeRO3 parameter {param_name} with slicers {slicers}, full size {full_size}, covered size {covered_size}') return full_model_state_dict @staticmethod @@ -504,7 +608,7 @@ def merge_state_dicts( # help understand the whole logic. In other words, the real plan_ngpus is <= len(model_state_dicts). plan_ngpus = len(model_state_dicts) # gather model states - full_model_state_dict = CubeModule.merge_model_state_dicts(model_state_dicts, fullmaps[0: plan_ngpus]) + full_model_state_dict = CubeModule.merge_model_state_dicts(model_state_dicts, fullmaps[0: plan_ngpus], zero_idx_maps) _logger.info('finish merge model states') if optim_state_dicts is None: return full_model_state_dict, None @@ -542,7 +646,7 @@ def _check_state_size(opt_state_keys, bucket_state): return all(bucket_state[key].shape == bucket_state[opt_state_keys[0]].shape for key in opt_state_keys) - def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): + def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size, zero_version): assert bucket_size % len(bucket_states) == 0 opt_state_keys = list(bucket_states[0].keys()) if 'step' in bucket_states[0]: @@ -551,36 +655,72 @@ def _retrieve_param_opt_state(bucket_states, pstart, pend, pshape, bucket_size): # NOTE: only support adam for now assert 'exp_avg' in opt_state_keys assert 'exp_avg_sq' in opt_state_keys - chunk_size = bucket_size // len(bucket_states) - start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size - end_rank_id, end_offset = pend // chunk_size, pend % chunk_size + opt_states, opt_states_1d = {}, {} for key in opt_state_keys: opt_states[key] = torch.zeros(pshape, dtype=bucket_states[0][key].dtype, - device=bucket_states[0][key].device, requires_grad=False) + device='cpu', requires_grad=False) opt_states_1d[key] = opt_states[key].view(-1) - if start_rank_id == end_rank_id: - for key in opt_state_keys: - opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] - else: - offset = chunk_size-start_offset - for key in opt_state_keys: - opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] - for i in range(start_rank_id+1, end_rank_id): + if zero_version == 1: + chunk_size = bucket_size // len(bucket_states) + start_rank_id, start_offset = pstart // chunk_size, pstart % chunk_size + end_rank_id, end_offset = pend // chunk_size, pend % chunk_size + if start_rank_id == end_rank_id: for key in opt_state_keys: - opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] - offset += chunk_size - if end_offset: # skip if end_offset == 0, because it is a no-op + opt_states_1d[key][:] = bucket_states[start_rank_id][key][start_offset:end_offset] + else: + offset = chunk_size-start_offset for key in opt_state_keys: - opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + opt_states_1d[key][:offset] = bucket_states[start_rank_id][key][start_offset:] + for i in range(start_rank_id+1, end_rank_id): + for key in opt_state_keys: + opt_states_1d[key][offset:offset+chunk_size] = bucket_states[i][key][:] + offset += chunk_size + if end_offset: # skip if end_offset == 0, because it is a no-op + for key in opt_state_keys: + opt_states_1d[key][offset:] = bucket_states[end_rank_id][key][:end_offset] + else: # zero_version == 3 + assert zero_version > 1, f'unsupported zero version {zero_version}' + for key in opt_state_keys: + fill_start = 0 + fill_len = pend - pstart + param_numel = opt_states_1d[key].numel() + for bstate in bucket_states: + if fill_start >= param_numel: + # from current implementation, code never goes here + # because we have used model_idx2opt_idx to filter out unnecessary ranks + # but let's keep the logic here for safety + fill_start = fill_start % param_numel + if fill_start + fill_len > param_numel: + fill_len = param_numel - fill_start + # check consistency for the already filled part + if not CubeModule._safe_tensor_equal( + opt_states_1d[key][fill_start: fill_start + fill_len], + bstate[key][pstart: pstart+fill_len] + ): + raise ValueError(f"Conflict in merging optimizer state for param with shape {pshape}") + else: + if fill_start + fill_len > param_numel: + fill_len = param_numel - fill_start + # remove padding part + opt_states_1d[key][fill_start: fill_start + fill_len] = bstate[key][pstart: pstart+fill_len] + fill_start += fill_len if 'step' in bucket_states[0]: - opt_states['step'] = bucket_states[0]['step'] + # make sure all steps are different tensors (with same value) + opt_states['step'] = bucket_states[0]['step'].cpu().clone() return opt_states - def _merge_opt_zero(worker_idx, param_idx): - model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[worker_idx] + def _merge_opt_zero(param_shape, worker_idx, param_idx): + if len(zero_idx_maps[worker_idx]) == 4: + model_idx2opt_idx, opt_idx2ranks, zero_version, _ = zero_idx_maps[worker_idx] + elif len(zero_idx_maps[worker_idx]) == 3: # backward compatibility + model_idx2opt_idx, opt_idx2ranks, zero_version = zero_idx_maps[worker_idx] + else: # backward compatibility + assert len(zero_idx_maps[worker_idx]) == 2 + model_idx2opt_idx, opt_idx2ranks = zero_idx_maps[worker_idx] + zero_version = 1 # default to ZeRO-1 opt_idx = model_idx2opt_idx[param_idx] if isinstance(opt_idx, int): # the param without reducer @@ -589,14 +729,19 @@ def _merge_opt_zero(worker_idx, param_idx): else: # the param in reducer bucket opt_idx, pstart, pend, pshape = opt_idx + if zero_version == 1: + assert param_shape == pshape, f'param shape {param_shape} vs pshape {pshape}' ranks, bucket_size = opt_idx2ranks[opt_idx] + # parameters in reducer come first, so we can directly use opt_idx to index. bucket_states = [optim_state_dicts[rank]['state'][opt_idx] for rank in ranks] return _retrieve_param_opt_state( bucket_states, pstart, pend, - pshape, - bucket_size) + param_shape, + bucket_size, + zero_version + ) # full_index: param IDs in the full optimizer state for full_index, param_name in enumerate(origin_parameter_names): @@ -639,7 +784,7 @@ def _merge_opt_zero(worker_idx, param_idx): # As ZeRO is applied, the optimizer state of this parameter (a shard) # may not be stored locally in its optimizer state. # _merge_opt_zero is for recovering the optimizer state corresponding to this parameter shard. - states: Dict[str, torch.Tensor] = _merge_opt_zero(work_idx, local_index) + states: Dict[str, torch.Tensor] = _merge_opt_zero(meta.sub_shape, work_idx, local_index) zero_done_track.add(track_id) else: _logger.debug(f'rank {work_idx}: skip merging duplicated optimizer state for param {full_index} with slicers {meta.slicers}') @@ -653,7 +798,7 @@ def _merge_opt_zero(worker_idx, param_idx): if not CubeModule._safe_tensor_equal(full_states[full_index][state_name], value): raise ValueError(f"Conflict in merging {param_name}.{state_name} from rank {work_idx}") else: - full_states[full_index][state_name] = value + full_states[full_index][state_name] = value.cpu() continue # for non-tensor states @@ -668,7 +813,7 @@ def _merge_opt_zero(worker_idx, param_idx): else: # create optimizer state tensor if state_name not in full_states[full_index]: - full_states[full_index][state_name] = torch.empty(meta.shape, dtype=value.dtype) + full_states[full_index][state_name] = torch.empty(meta.shape, dtype=value.dtype, device='cpu') if track_id in state_merge_track: if not CubeModule._safe_tensor_equal(full_states[full_index][state_name][meta.slicers], value): @@ -718,6 +863,100 @@ def merge_checkpoints(filename_prefix='dist_checkpoint'): 'optim_state_dict': merged_optimizer_state_dict }, filename_prefix + '.full.ckpt') + def sleep(self): + """ + Move attributes (buffer and param) to cpu and release contiguous buffer in reducers. Different from + nn.Module's cpu() method, references to attributes are unchanged. + """ + for name, param in self.named_parameters(): + assert param.grad is None, f'expect {name} with shape {param.shape} has no grad' + + for reducer in self._reducers: + reducer.zero_grad() + + # we want attribute references are unchanged, so super().cpu() is not used here + cpu = torch.device('cpu') + for buffer in self.buffers(): + buffer.data = buffer.data.to(cpu) + + for param in self.parameters(): + param.data = param.data.to(cpu) + + for reducer in self._reducers: + reducer.sleep() + + gc.collect() + torch.cuda.empty_cache() + return self + + def wake_up(self, device: Optional[Union[int, device]] = None) -> Self: + """ + Move attributes (buffer and param) back to gpu and reallocate memories in reducers. It is a reverse + operation of `self.sleep()`. + """ + gpu = torch.cuda.current_device() + if device is not None: + if isinstance(device, int): + index = device + elif isinstance(device, torch.device): + index = device.index + else: + raise RuntimeError(f'unexpected device type {type(device)}') + assert gpu == index, f'nnscaler module does not support cross gpu transport, expect {gpu} but got {index}' + + for name, param in self.named_parameters(): + assert param.grad is None, f'expect {name} with shape {param.shape} has no grad' + + # we want attribute references are unchanged, so super().gpu() is not used here + for buffer in self.buffers(): + buffer.data = buffer.data.to(gpu) + + for param in self.parameters(): + param.data = param.data.to(gpu) + + for reducer in self._reducers: + reducer.wake_up() + + gc.collect() + torch.cuda.empty_cache() + return self + + def to(self, *args, **kwargs): + """ + Override nn.Module's to function, currently we only allow transfer data from host and device + + Args: + device (:class:`torch.device`): the desired device of the parameters + and buffers in this module + dtype (:class:`torch.dtype`): the desired floating point or complex dtype of + the parameters and buffers in this module + tensor (torch.Tensor): Tensor whose dtype and device are the desired + dtype and device for all parameters and buffers in this module + memory_format (:class:`torch.memory_format`): the desired memory + format for 4D parameters and buffers in this module (keyword + only argument) + + Returns: + Module: self + """ + device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( + *args, **kwargs + ) + if dtype is not None: + raise ValueError(f'nnscaler does not support passing dtype {dtype} to to()') + if convert_to_format is not None: + raise ValueError(f'nnscaler does not support passing convert_to_format {convert_to_format} to to()') + if non_blocking is not None: + warnings.warn(f'nnscaler moves tensors in a blocking approach currently') + + # after _parse_to `device` must in type of torch.device + if device.type == 'cpu': + return self.cpu() + elif device.type == 'cuda': + return self.cuda(device) + else: + raise ValueError(f'unsupported device type {device}') + @dataclass class OriginModuleMetadata: @@ -733,6 +972,12 @@ class ZeroMetadata: model_idx2opt_idx: Optional[Dict] = None # a mapping from optimizer_index to the related bucket information (sub_ranks, bucket_size) opt_idx2ranks: Optional[Dict] = None + # the level of zero optimization + # 0: no zero optimization + # 1: zero1 + # > 1: zero3 + zero: int = 0 + zero3_param_metadata: Optional[Dict[str, Dict]] = None @dataclass @@ -765,10 +1010,27 @@ class ParallelModule(CubeModule): COMPUTE_CONFIG_FILE = 'compute_config.pt' ORIGIN_MODULE_METADATA_FILE = 'origin_module_metadata.pt' EXTRA_STATE_KEY = 'CUBE_EXTRA_STATE' + ATTR_META_FILE_PREFIX = 'attr_meta' + ATTR_META_FILE_TEMPLATE = ATTR_META_FILE_PREFIX + '{}.pkl' # 'attr_meta{}.pkl' + # the rank of the module, will be assigned in the generated subclasses rank: int + # the world size to run this module, will be assigned in the generated subclasses + world_size: int # the runtime version of the module when it is generated, will be assigned in the generated subclasses runtime_version: str + # mapping from the name of local attribute tensor + # to its corresponding fulltensor meta for all ranks. + # it is a list of dictionaries mapping from attribute names to their metadata + # and it is a replacement of `CubeModule.fullmap` + attr_meta_maps: list[dict[str, AttrMeta]] + # the directory of the module located + module_dir: Path + # The map is a dict mapping from the new parameter name (without tid suffix) in parallel module + # to the parameter name in original module. + dist_param_map: dict[str, str] + compute_config: 'ComputeConfig' + origin_module_metadata: OriginModuleMetadata def __init__(self): if self.__class__ == ParallelModule: # not init via super().__init__() @@ -790,6 +1052,42 @@ def __init__(self): self._nreplicas2localparams: Optional[Dict[int, List[torch.nn.Parameter]]] = None # track whether all the parames (especially the non-persistent buffers) have been initialized self._non_presistent_buffers_inited = False + # track the params that have been prefetched in backward + # this is only used for zero3 + # The reason is the eviction of prefetched params in backward + # relies on the input.requires_grad flag to be True + # If all the inputs do not require grad, + # the eviction logic will not be triggered + # In that case, we will delay the eviction until next backward hook. + self._backward_prefetched_params: dict[torch.nn.Parameter, int] = {} + # the params that have been prefetched in forward + self._forward_prefetched_params: set[torch.nn.Parameter] = set() + + def __init_subclass__(cls, skip_init=False, **kwargs): + # special case when we just fake a ParallelModule class + # In this case, you should also use object.__new__ instead of __init__ + if skip_init: + return + + from nnscaler.parallel import ComputeConfig + + super().__init_subclass__(**kwargs) + cls.attr_meta_maps = [] + cls.module_dir = Path(sys.modules[cls.__module__].__file__).parent + + for rank in range(cls.world_size): + attr_map_file = cls.module_dir / cls.ATTR_META_FILE_TEMPLATE.format(rank) + with open(attr_map_file, 'rb') as f: + attr_meta_map = pickle.load(f) + attr_meta_map = {attr: AttrMeta(**meta) for attr, meta in attr_meta_map.items()} + cls.attr_meta_maps.append(attr_meta_map) + + cls.dist_param_map = torch.load(cls.module_dir / FxModuleParser.ATTR_MAP_FILE, weights_only=False) + cls.compute_config = ComputeConfig.safe_load_from_file( + cls.module_dir / cls.COMPUTE_CONFIG_FILE, + return_none_on_error=False + ) + cls.origin_module_metadata = torch.load(cls.module_dir / cls.ORIGIN_MODULE_METADATA_FILE, weights_only=False) @property def non_presistent_buffers_inited(self): @@ -812,12 +1110,14 @@ def _warn_uninitialized_non_persistent_buffers(self, raise_error = False): else: _logger.warning(_non_persistent_buffers_load_warning) - def _post_init(self, init_params=True): + def _post_init(self, init_params=True, build_buckets=True): """ This is post init function to further initialize the model. Should be called by subclass's __init__(). Args: init_params (bool): whether to load model init parameters. Default True. + build_buckets (bool): whether to build buckets for the model. Default True. + If it is False, you must manually call `build_buckets()` later before use this module. """ # Here we check the rank to load the module file name # Current we don't check rank when we are not in distributed mode @@ -825,28 +1125,14 @@ def _post_init(self, init_params=True): # TODO: re-enable this check # if dist.is_initialized() and self.rank != dist.get_rank(): # raise RuntimeError(f"The rank to load this module file name is expected to be {self._rank}, but got {dist.get_rank()}") - from nnscaler.parallel import ComputeConfig self._non_presistent_buffers_inited = init_params or not self._non_persistent_buffers_set module_file = Path(sys.modules[self.__module__].__file__) - self.module_dir = module_file.parent if init_params: self.load_attr_content(str(module_file.with_name(f"{FxModuleParser.ATTR_CONTENT_FILE_STEM}"))) self._warn_uninitialized_non_persistent_buffers() - self._dist_param_map = torch.load(module_file.with_name(f"{FxModuleParser.ATTR_MAP_FILE}"), weights_only=False) - self._compute_config: ComputeConfig = ComputeConfig.safe_load_from_file( - module_file.with_name(f"{self.COMPUTE_CONFIG_FILE}"), - return_none_on_error=False - ) - self._orign_module_metadata: OriginModuleMetadata = torch.load(module_file.with_name(f"{self.ORIGIN_MODULE_METADATA_FILE}"), weights_only=False) - - for reducer in self.reducers: - reducer.build_buckets() - - self._zero_metadata = self._get_zero_metadata() - # add state_dict hook to save extra state # Please note extra_state is only used for merging, not for loading # so we can safely remove it in load_state_dict pre hook @@ -854,11 +1140,174 @@ def _post_init(self, init_params=True): # add load_state_dict pre hook to pop extra state to prevent warning self._register_load_state_dict_pre_hook(ParallelModule._pre_load_state_dict_hook, with_module=True) + if build_buckets: + self.build_buckets() + + def build_buckets(self, param_clss: Optional[dict[torch.nn.Parameter, Any]]=None): + """ + Build buckets for the model reducers. + + You should call this method exactly once before using this module. + Typically this will be called when building optimizer when multiple optimizers/param groups are used. + And we will put parameters with different optimizer or different param groups into different buckets. + + Currently we have done an optimization to make sure this is only called once even for hybrid optimizers + by + 1. setting `build_buckets=False` when calling constructor in `nnscaler.parallelize`. + 2. manually calling `build_buckets()` later in `nnscaler.build_optimizer` + """ + # needs all parameters to be in cuda memory before building buckets + self.cuda() + self._param_reducer_map: dict[torch.nn.Parameter, int] = {} + model_params = {p: n for n, p in self.named_parameters()} + # key: attr name of the parameter + # value: Zero3AttrMeta + self._zero3_param_metadata: dict[str, Zero3AttrMeta] = {} + for idx, reducer in enumerate(self.reducers): + reducer.build_buckets(param_clss) + for param in reducer.params: + self._param_reducer_map[param] = idx + attr_name = model_params[param] + param_attr = self._fullmap[attr_name] + zero3_info = reducer.get_z3_info(param) + self._zero3_param_metadata[attr_name] = Zero3AttrMeta( + attr_name=attr_name, + orig_name=param_attr.orig_name, + start = zero3_info.start, + end = zero3_info.end, + chunk_size=zero3_info.numel_with_padding(), + ) if zero3_info is not None else None + + self._zero_metadata = self._get_zero_metadata() + + def get_zero3_attr_meta(self, attr_name: str) -> Optional[Zero3AttrMeta]: + """ + Get the Zero3AttrMeta for the given attribute name. + + Args: + attr_name (str): the attribute name of the parameter + Returns: + Optional[Zero3AttrMeta]: the Zero3AttrMeta for the given attribute name + """ + return self._zero3_param_metadata.get(attr_name, None) + + @torch.no_grad() + def prefetch_param(self, param: torch.nn.Parameter): + """ + Gather the full parameter tensor for FSDP. + + Args: + param (torch.nn.Parameter): the local parameter to gather + """ + reducer = self._reducers[self._param_reducer_map[param]] + reducer.prefetch_param(param) + self._forward_prefetched_params.add(param) + + @torch.no_grad() + def postevict_param(self, param: torch.nn.Parameter): + """ + Release the full parameter tensor for zero3. + + Args: + param (torch.nn.Parameter): the local parameter + """ + reducer = self._reducers[self._param_reducer_map[param]] + reducer.postevict_param(param) + self._forward_prefetched_params.discard(param) + + def _backward_evict_leftover_params(self, order: int): + for p in [p for p, o in self._backward_prefetched_params.items() if o > order]: + self.postevict_param(p) + self._backward_prefetched_params.pop(p, None) + + def backward_postevict_param(self, input: torch.Tensor, param: torch.nn.Parameter, order: int): + """ + Here we need an input tensor to register the backward hook. + """ + if not input.requires_grad: + # if input does not require grad, we cannot register backward hook on it + return input + + @torch.no_grad() + def _postevict_param(param): # pragma: no cover + self.postevict_param(param) + self._backward_prefetched_params.pop(param, None) + self._backward_evict_leftover_params(order) + + return insert_backward_hook(input, functools.partial(_postevict_param, param)) + + def backward_prefetch_param(self, activation: torch.Tensor, param: torch.nn.Parameter, order: int): + """ + Here we need an activation tensor to register the backward hook. + """ + if not activation.requires_grad: + # if activation does not require grad, we cannot register backward hook on it + return activation + + @torch.no_grad() + def _prefetch_param(param): # pragma: no cover + self.prefetch_param(param) + self._backward_prefetched_params[param] = order + self._backward_evict_leftover_params(order) + + return insert_backward_hook(activation, functools.partial(_prefetch_param, param)) + + def save_params_hooks(self) -> saved_tensors_hooks: + """ + A hook to save tensors during forward pass. + This is used to avoid parameters being saved for activation checkpointing. + + Returns: + saved_tensors_hooks: the saved tensors hooks + """ + def pack(x: torch.Tensor): + for param in self._forward_prefetched_params: + if x.untyped_storage() == param.untyped_storage(): + return (param, x.shape, x.stride(), x.storage_offset()) + return x + + def unpack(x): + if isinstance(x, tuple) and len(x) == 4: + return torch.as_strided(x[0], x[1], x[2], x[3]) + return x + + return saved_tensors_hooks(pack, unpack) + + @classmethod + def get_attr_meta_map(cls, rank=None): + """ + Get the attribute meta map for the given rank. + If rank is None, return the attribute map for the current rank. + + This function is preferred over accessing `CubeModule.fullmap` in most cases, + since it doesn't need to instantiate the module. + """ + if rank is None: + rank = cls.rank + if rank < 0 or rank >= cls.world_size: + raise ValueError(f"Rank {rank} is out of range [0, {cls.world_size})") + return cls.attr_meta_maps[rank] + def forward(self, *args, **kwargs): self._warn_uninitialized_non_persistent_buffers(raise_error=True) if self.training: self._sync_grad_required = True # mark sync_grad() can be called again - return self._forward_impl(*args, **kwargs) + # all prefetched params should have been evicted + # please note the param can be evicted in Reducer, + # which is not tracked in self._backward_prefetched_params + # so we just check the shape to make sure the param is evicted + for param in self._backward_prefetched_params.keys(): + old_shape = param.shape + self.postevict_param(param) + assert param.shape == old_shape, \ + f'Param {param} is not properly evicted in backward' + self._backward_prefetched_params.clear() + + ret = self._forward_impl(*args, **kwargs) + + assert not self._forward_prefetched_params, \ + f'All forward prefetched params should have been evicted in forward' + return ret def _forward_impl(self, *args, **kwargs): """ @@ -1015,19 +1464,6 @@ def infer_step(self, samples: List[Any]) -> List[Any]: outputs.append(output) return outputs - @property - def dist_param_map(self) -> Dict[str, str]: - """ - Get the parameter map of the model. - The map is a dict mapping from the new parameter name (without tid suffix) in parallel module - to the parameter name in original module. - """ - return self._dist_param_map - - @property - def compute_config(self) -> 'ComputeConfig': - return self._compute_config - def clip_gnorm(self, max_norm: Optional[float] = None) -> Tuple[torch.Tensor, List[torch.Tensor]]: """ Calculate the gradient norm and clip gradients. @@ -1116,6 +1552,8 @@ def _get_zero_metadata(self) -> ZeroMetadata: return ZeroMetadata( model_idx2opt_idx=model_idx2opt_idx, opt_idx2ranks=opt_idx2ranks, + zero=self.compute_config.use_zero, + zero3_param_metadata=self._zero3_param_metadata, ) def _get_zero_subranks(self, reducer: Reducer) -> Tuple[int, List[int]]: @@ -1150,11 +1588,11 @@ def _add_extra_state(self, state_dict, prefix) -> None: state_dict[f'{prefix}{self.EXTRA_STATE_KEY}'] = asdict( ExtraState( rank=self.rank, - compute_config=self._compute_config, - dist_param_map=self._dist_param_map, + compute_config=self.compute_config, + dist_param_map=self.dist_param_map, param_area_map=self._fullmap, cube_param_names=[name for name, _ in self.named_parameters()], - **asdict(self._orign_module_metadata), + **asdict(self.origin_module_metadata), **asdict(self._zero_metadata), ) ) @@ -1190,19 +1628,19 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, if strict: missing_keys.extend(new_missing_keys) - @property - def module_dedup_group_size(self) -> int: + @classproperty + def module_dedup_group_size(cls) -> int: """ Get the size of the deduplication group of the model state dict, which is `plan_ngpus`. """ - return self.compute_config.module_dedup_group_size + return cls.compute_config.module_dedup_group_size - @property - def optimizer_dedup_group_size(self) -> int: + @classproperty + def optimizer_dedup_group_size(cls) -> int: """ Get the size of the deduplication group of the optimizer state dict. """ - return self.compute_config.optimizer_dedup_group_size + return cls.compute_config.optimizer_dedup_group_size def _list_fullmodel_files(self) -> List[Path]: legacy_fullmodel_path = self.module_dir / FxModuleParser.ATTR_CONTENT_FILE_STEM @@ -1237,18 +1675,58 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s Raises: RuntimeError: if strict=True and there are missing keys. """ - - dist2param = self.dist_param_map - orig_param_names = list(dist2param.values()) # param names in original module (without prefix) non_persistent_buffers = self.get_non_persistent_buffers() with torch.no_grad(): # avoid checking the non-persistent buffers attr_names = set([attr for attr in self._fullmap.keys() if attr not in non_persistent_buffers]) - origname_tid_map = {meta.orig_name: meta.tid for meta in self._fullmap.values()} + for prefix_attr, content in self.trim_merged_state_dict(state_dict, prefix, device='cpu').items(): + attr = prefix_attr[len(prefix):] + tensor: torch.Tensor = getattr(self, attr) + tensor.copy_(content) + attr_names.remove(attr) + + missing_keys = [prefix + self._fullmap[attr].orig_name for attr in attr_names] + if len(attr_names) != 0: + erro_msg = f'Missing key(s) in state_dict: {missing_keys}.' + if strict: + raise RuntimeError(erro_msg) + else: + _logger.warning(erro_msg) + + self._warn_uninitialized_non_persistent_buffers() + return missing_keys + + def trim_merged_state_dict( + self, + state_dict: Dict[str, Any], + prefix: str = '', + *, + device=None, + ) -> Dict[str, Any]: + """ + Trim the merged state dict to only keep the parameters needed for the module. + Please note we don't check missing/unexpected keys. + + Args: + state_dict (Dict[str, Any]): the merged state dict + prefix (str): the prefix of the model state dict in the merged state dict + + Returns: + Dict[str, Any]: the trimmed state dict + """ + device = device or torch.cuda.current_device() + trimmed_state_dict = {} + + dist2param = self.dist_param_map + orig_param_names = list(dist2param.values()) # param names in original module (without prefix) + attr_meta_map = self.get_attr_meta_map(self.rank) + with torch.no_grad(): + # avoid checking the non-persistent buffers + origname_tid_map = {meta.orig_name: meta.tid for meta in attr_meta_map.values()} tid_info = defaultdict(list) - for attr, meta in self._fullmap.items(): + for attr, meta in attr_meta_map.items(): tid_info[meta.tid].append((attr, meta.slicers, meta.val_chunks)) # multiple params may share the same tid for orig_param_name in orig_param_names: @@ -1261,20 +1739,82 @@ def load_merged_state_dict(self, state_dict: Dict[str, Any], prefix: str = '', s param_value = state_dict[orig_param_name_with_prefix] tid = origname_tid_map[orig_param_name] for attr, slicer, nchunks in tid_info[tid]: - tensor: torch.Tensor = getattr(self, attr) - content = param_value[slicer] + content: torch.Tensor = param_value[slicer] if nchunks != 1: content = content / nchunks - tensor.copy_(content) - attr_names.remove(attr) + if self.compute_config.use_zero <= 1 or self._zero3_param_metadata.get(attr, None) is None: + trimmed_state_dict[prefix + attr] = content.to(device) + else: + z3_info = self._zero3_param_metadata[attr] + start, end, chunk_size = z3_info.start, z3_info.end, z3_info.chunk_size + if end - start < chunk_size: + # need padding + padding = chunk_size - (end - start) + trimmed_state_dict[prefix + attr] = torch.nn.functional.pad( + content.view(-1)[start:end].to(device), + (0, padding), + mode='constant', + value=0.0, + ) + else: + trimmed_state_dict[prefix + attr] = content.reshape(-1)[start:end].to(device) - missing_keys = [prefix + self._fullmap[attr].orig_name for attr in attr_names] - if len(attr_names) != 0: - erro_msg = f'Missing key(s) in state_dict: {missing_keys}.' - if strict: - raise RuntimeError(erro_msg) - else: - _logger.warning(erro_msg) + return trimmed_state_dict - self._warn_uninitialized_non_persistent_buffers() - return missing_keys + def _pack( + self, + ): + """ + Get a packed information of the ParallelModule, so it can be sent to other ranks. + """ + param_map: dict[torch.nn.Parameter, torch.nn.Parameter] = {} + for p in self.parameters(): + param_map[p] = torch.nn.Parameter( + torch.empty_like(p, device='meta')) if p is not None else None + for b in self.buffers(): + param_map[b] = torch.empty_like( + b, device='meta') if b is not None else None + state = {} + fields = unchecked_fields(self) + state[fields._parameters] = {n: param_map[p] for n, p in self._parameters.items()} + state[fields._buffers] = {n: param_map[b] for n, b in self._buffers.items()} + state[fields._reducers] = [reducer._pack(param_map) for reducer in self._reducers] + state[fields._zero_metadata] = self._zero_metadata + state[fields._fullmap] = self._fullmap + state[fields._param_reducer_map] = { + param_map[p]: rid for p, rid in self._param_reducer_map.items() + } + state[fields._zero3_param_metadata] = self._zero3_param_metadata + + for cv in ParallelModule.__annotations__: + state[cv] = getattr(self, cv) + return state + + @classmethod + def _unpack(cls, state: dict): + """ + Unpack the information and return a fake ParallelModule that carries the same information. + """ + class GenModelX(ParallelModule, skip_init=True): + pass + pm = object.__new__(GenModelX) + fields = unchecked_fields(pm) + object.__setattr__(pm, fields._parameters, state[fields._parameters]) + object.__setattr__(pm, fields._buffers, state[fields._buffers]) + object.__setattr__(pm, fields._reducers, [Reducer._unpack(reducer) for reducer in state[fields._reducers]]) + object.__setattr__(pm, fields._zero_metadata, state[fields._zero_metadata]) + object.__setattr__(pm, fields._fullmap, state[fields._fullmap]) + object.__setattr__(pm, fields._param_reducer_map, state[fields._param_reducer_map]) + object.__setattr__(pm, fields._zero3_param_metadata, state[fields._zero3_param_metadata]) + + def named_parameters( + prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ): + assert prefix == "" and recurse is True, "Only support default arguments" + return pm._parameters.items() + + pm.named_parameters = named_parameters + + for cv in ParallelModule.__annotations__: + setattr(GenModelX, cv, state[cv]) + return pm diff --git a/nnscaler/runtime/serialization.py b/nnscaler/runtime/serialization.py new file mode 100644 index 00000000..ff492c26 --- /dev/null +++ b/nnscaler/runtime/serialization.py @@ -0,0 +1,252 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any, TypedDict +import pickle +import base64 +import copy + +import torch +from safetensors.torch import save_file +from safetensors import safe_open + +from nnscaler.utils import transform_recursively, check_recursively +from nnscaler.version import __version__ + + +class MetadataDict(TypedDict): + obj: str + nnscaler: str + + +class _Index: + def __init__(self, index: int): + self.index = index + + def __repr__(self): + return f"_Index({self.index})" + + +def save(obj: Any, f, *, format="safetensors") -> None: + """ + Saves an object containing tensors into a safetensors file. + Args: + obj (`Any`): + The object you want to save. It can be a nested structure containing + tensors, lists, tuples, and dictionaries. + f: + The file-like object or filename where to save the safetensors file. + format (`str`, *optional*, defaults to `"safetensors"`): + The format to save the object. Currently `"safetensors"` and `"pt"` is supported. + """ + if format == 'pt': + torch.save(obj, f) + return + + if format != 'safetensors': + raise ValueError(f"Unsupported format: {format}") + + index = 0 + + # all tensors to be saved + tensors = {} + # detect shared tensors + # because safetensors does not support shared tensors, we need to + # save shared tensors only once and replace other occurrences + # TODO: Currently we only detect shared tensors that are exactly the same + # (i.e., share the same data_ptr and shape and stride). + # We may improve it in the future if needed. + # key: (tensor.data_ptr(), tensor.shape, tensor.stride()), value: _Index + tensor_ids: dict[tuple[int, tuple[int, ...], tuple[int, ...]], _Index] = {} + def transform_fn(o: Any) -> Any: + nonlocal index + if isinstance(o, torch.Tensor): + key = (o.data_ptr(), o.shape, o.stride()) + if key in tensor_ids: + idx = tensor_ids[key] + else: + idx = _Index(index) + tensor_ids[key] = idx + tensors[f'{index}'] = o + index += 1 + return idx + return o + metadata = transform_recursively(obj, transform_fn, target_types=(torch.Tensor,)) + save_file(tensors, f, metadata={ + 'obj': base64.b64encode(pickle.dumps(metadata)).decode('utf-8'), + 'nnscaler': __version__ + }) + + +class _LazyContainer: + """ + Mock class for dictionary, list, and tuple that loads tensors lazily from safetensors file. + """ + def __init__(self, data: dict | tuple | list, tensors: safe_open): + self.data = data + self.tensors = tensors + + def __getitem__(self, key): + return self._v(self.data[key]) + + def __setitem__(self, key, value): + raise NotImplementedError("Lazy containers are read-only.") + + def __delitem__(self, key): + raise NotImplementedError("Lazy containers are read-only.") + + def pop(self, key, default=None): + raise NotImplementedError("Lazy containers are read-only.") + + def __len__(self): + return len(self.data) + + def __contains__(self, item): + return self.data.__contains__(item) + + def get(self, key, default=None): + return self._v(self.data.get(key, default)) + + def keys(self): + return self.data.keys() + + def values(self): + return map(self._v, self.data.values()) + + def items(self): + return ((k, self._v(v)) for k, v in self.data.items()) + + def _v(self, v): + return _wrap_value(v, self.tensors) + + def load_all(self): + def _load(v): + if isinstance(v, _Index): + return self.tensors.get_tensor(f'{v.index}') + return v + return transform_recursively(self.data, _load, target_types=(_Index,)) + + def __copy__(self): + return copy.copy(self.load_all()) + + def __deepcopy__(self, memo): + return copy.deepcopy(self.load_all(), memo) + + def __iter__(self): + return iter(self.data) + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.data)})" + + +class _LazyList(_LazyContainer, list): + pass + + +class _LazyDict(_LazyContainer, dict): + pass + + +class _LazyTuple(_LazyContainer, tuple): + # tuple is immutable, so we need to override __new__ + def __new__(cls, *args, **kwargs): + return tuple.__new__(cls, ()) + + +def _wrap_value(v: Any, tensors: safe_open) -> Any: + if isinstance(v, _Index): + return tensors.get_tensor(f'{v.index}') + if not check_recursively(v, lambda k: isinstance(k, _Index)): + return v + if isinstance(v, dict): + return _LazyDict(v, tensors) + if isinstance(v, list): + return _LazyList(v, tensors) + if isinstance(v, tuple): + return _LazyTuple(v, tensors) + # should not reach here + return v + + +class LazyLoader: + def __init__(self, filename, device="cpu"): + self.filename = filename + self.device = device + self.tensor_loader = safe_open(self.filename, framework="pt", device=self.device) + self.tensors = None + self.data = None + + def __enter__(self): + self.tensors = self.tensor_loader.__enter__() + metadata: MetadataDict = self.tensors.metadata() + metadata_obj_b64 = metadata['obj'] + self.data = pickle.loads(base64.b64decode(metadata_obj_b64.encode('utf-8'))) + return self + + def __exit__(self, _exc_type, _exc_value, _traceback): + self.tensor_loader.__exit__(_exc_type, _exc_value, _traceback) + + def get_lazy_data(self) -> _LazyContainer | Any: + if self.tensors is None: + raise RuntimeError("LazyLoader context is not entered.") + return _wrap_value(self.data, self.tensors) + + +def load(f, *, device="cpu", format="safetensors", lazy=False) -> LazyLoader | Any: + """ + Loads an object containing tensors from a safetensors file lazily. + Args: + f: The file-like object or filename from which to load the safetensors file. + device (`str`, *optional*, defaults to `"cpu"`): + The device where the tensors will be loaded. + lazy (`bool`, *optional*, defaults to `False`): + If set to `False`, loads all tensors into memory eagerly. + Returns: + (`LazyLoader` | `Any`): + The lazy loader object that can be used to access the data. + If `lazy` is set to `False`, returns the loaded object with all tensors + loaded into memory. + """ + if format == 'pt': + return torch.load(f, map_location=device, weights_only=False) + if format != 'safetensors': + raise ValueError(f"Unsupported format: {format}") + + if not lazy: + with LazyLoader(f, device=device) as loader: + data = loader.get_lazy_data() + if isinstance(data, _LazyContainer): + return data.load_all() + else: + # pure data without any tensors + return data + return LazyLoader(f, device=device) + + +def convert(src: str, dst: str, *, src_format="safetensors", dst_format="pt", device="cpu") -> None: + """ + Converts a serialized file from one format to another. + Args: + src (`str`): + The source filename. + dst (`str`): + The destination filename. + src_format (`str`, *optional*, defaults to `"safetensors"`): + The format of the source file. Currently `"safetensors"` and `"pt"` is supported. + dst_format (`str`, *optional*, defaults to `"pt"`): + The format of the destination file. Currently `"safetensors"` and `"pt"` is supported. + device (`str`, *optional*, defaults to `"cpu"`): + The device where the tensors will be loaded. + + Returns: + (`None`): + This function does not return anything. + """ + if src_format == dst_format: + raise ValueError("Source and destination formats are the same.") + + save( + load(src, device=device, format=src_format, lazy=False), + dst, + format=dst_format + ) diff --git a/nnscaler/runtime/utils.py b/nnscaler/runtime/utils.py index b15748ea..e6cb1bd9 100644 --- a/nnscaler/runtime/utils.py +++ b/nnscaler/runtime/utils.py @@ -5,6 +5,7 @@ from typing import Any, List import logging +import heapq _logger = logging.getLogger(__name__) @@ -13,7 +14,7 @@ class MicroBatchDataLoader: """ MicroBatchDataLoader is used for scenarios of gradient accumulation, where a training iteration will have multiple data samples and perform - multiple forward and backward on each sample (i.e., each refers to + multiple forward and backward on each sample (i.e., each refers to as a micro-batch). To support more flexible training patterns, e.g., pipeline parallelism, @@ -25,7 +26,7 @@ class MicroBatchDataLoader: ```python # compilation phase dataloader = MicroBatchDataLoader([(input1,),]) # only need one micro-batch - + @nnscaler.compile(model, dataloader, ...) def train_iter(model, dataloader): input1 = next(dataloader) @@ -36,9 +37,9 @@ def train_iter(model, dataloader): ... # runtime phase - + for mini_batch_samples in iter(dataloader): - # mini_batch_samples are sample list for + # mini_batch_samples are sample list for # all micro-batches in one iteration. dl = MicroBatchDataLoader(mini_batch_samples) loss =train_iter(model, dl) @@ -68,7 +69,7 @@ def __init__(self, samples: List[Any], cycle: bool = False): def __iter__(self): self._idx = 0 return self - + def __next__(self): if self._idx == self.nmicros: raise StopIteration @@ -77,10 +78,10 @@ def __next__(self): if self.cycle: self._idx = self._idx % self.nmicros return batch - + def __len__(self): return self.nmicros - + def get_micro_batch(self, idx: int): idx = idx % self.nmicros if self.cycle else idx return self.samples[idx] @@ -104,3 +105,114 @@ def microbatches(samples: List[Any], cycle: bool = False) -> MicroBatchDataLoade MicroBatchDataLoader: a micro-batch data loader. """ return MicroBatchDataLoader(samples, cycle=cycle) + + +def split_array_min_max(nums: list[int], g: int, *, keep_order: bool = True) -> tuple[list[list[int]], list[list[int]]]: + """ + Split the array nums into g continuous subarrays such that the maximum sum + of the subarrays is minimized. + + Args: + nums (list[int]): The input array of integers. + g (int): The number of groups to split the array into. + keep_order (bool): Whether to keep the order of elements in the subarrays. + If True, the order of elements in the original array is preserved + in the subarrays. If False, the order can be changed. + Returns: + tuple[list[list[int]], list[list[int]]]: + A tuple containing a list of g subarrays and their corresponding indices. + """ + if g <= 0 or g > len(nums): + raise ValueError("g must be in the range [1, len(nums)]") + + if not keep_order: + return _split_array_min_max_out_of_order(nums, g) + + def _check(limit): + count = 1 + count_sum = nums[0] + for x in nums[1:]: + if count_sum + x > limit: + count += 1 + count_sum = x + else: + count_sum += x + return count <= g + + # 1. Binary search to find the "minimum maximum sum" (Target Limit) + left = max(nums) + right = sum(nums) + target_limit = right + + while left <= right: + mid = (left + right) // 2 + if _check(mid): + target_limit = mid + right = mid - 1 + else: + left = mid + 1 + + # 2. Reconstruct the result based on the calculated target_limit + # Note: A special greedy strategy is needed here to ensure exactly g groups + # A simple greedy approach may result in fewer than g groups (although the maximum sum meets the condition, the number of groups is insufficient) + + result = [[nums[0]]] + result_idx = [[0]] + current_sum = nums[0] + + # We process in forward order, or forcefully reserve enough elements for the remaining groups during forward processing + # Here we use forward iteration with a "remaining quota" check + for i, x in enumerate(nums[1:], start=1): + # Remaining groups needed + groups_needed = g - len(result) + # Remaining elements not yet processed + elements_left = len(nums) - i + if elements_left == groups_needed: + # Each element must form a separate group + result.append([x]) + result_idx.append([i]) + current_sum = x + continue + + if current_sum + x > target_limit: + result.append([x]) + result_idx.append([i]) + current_sum = x + else: + result[-1].append(x) + result_idx[-1].append(i) + current_sum += x + + return result, result_idx + + +def _split_array_min_max_out_of_order(nums: list[int], g: int) -> tuple[list[list[int]], list[list[int]]]: + """ + Split the array nums into g subarrays (order of elements can be changed) + This problem (multi-way number partitioning) is NP-hard. We use a greedy approximation algorithm here. + """ + # 1. Sort numbers in descending order + nums_with_indices = list((nun, i) for i, nun in enumerate(nums)) + sorted_nums = sorted(nums_with_indices, reverse=True) + + # 2. Initialize heap + heap = [(0, i) for i in range(g)] + + # groups to save results + groups = [[] for _ in range(g)] + group_idx = [[] for _ in range(g)] + + # 3. greedy assignment + for num, idx in sorted_nums: + # Pop the bucket with the smallest current sum + current_sum, gidx = heapq.heappop(heap) + + # Add the number to this bucket + groups[gidx].append(num) + group_idx[gidx].append(idx) + + # Update the sum of this bucket and push it back to the heap + new_sum = current_sum + num + heapq.heappush(heap, (new_sum, gidx)) + + return groups, group_idx diff --git a/nnscaler/utils.py b/nnscaler/utils.py index 310c6649..f4d68cb1 100644 --- a/nnscaler/utils.py +++ b/nnscaler/utils.py @@ -4,24 +4,29 @@ import builtins import importlib from contextlib import contextmanager -from functools import wraps +from functools import wraps, cache from typing import ( Generator, Optional, Tuple, Callable, Dict, List, Set, Any, - Iterable, Type, Union, Protocol, ClassVar, cast, TypeVar + Iterable, Type, TypedDict, Union, Protocol, ClassVar, cast, TypeVar ) import logging from pathlib import Path import sys from collections import defaultdict -from dataclasses import dataclass +from dataclasses import dataclass, field import inspect import os +import warnings +from concurrent.futures import ThreadPoolExecutor +import itertools +import numpy as np import nnscaler from nnscaler.flags import RuntimeFlag, CompileFlag import torch + _logger = logging.getLogger(__name__) @@ -112,6 +117,27 @@ def get_member_by_name(model: torch.nn.Module, name: str) -> Any: return model_attr +def set_member_by_name(model: Any, name: str, value: Any) -> None: + """ + Set the member of the model by its full name. + """ + if not name: + raise ValueError("Name cannot be empty") + class _ValueHolder: + """ + A value holder. + In python you can't call `setattr` on object, but you can call it on its subclasses. + """ + pass + sliced_names = name.split(".") + model_attr = model + for sliced_name in sliced_names[:-1]: + if not hasattr(model_attr, sliced_name): + setattr(model_attr, sliced_name, _ValueHolder()) + model_attr = getattr(model_attr, sliced_name) + setattr(model_attr, sliced_names[-1], value) + + def get_shared_params(model: torch.nn.Module) -> List[List[str]]: paramid2name = defaultdict(set) for name in model.state_dict().keys(): @@ -211,65 +237,197 @@ def wrapped_fn(*args, **kwargs): _DICT_ITEMS_TYPE = type({}.items()) _DICT_KEYS_TYPE = type({}.keys()) _DICT_VALUES_TYPE = type({}.values()) +TRANSFORM_SUPPORTED_COLLECTION_TYPES = (tuple, list, dict, set, slice, _DICT_ITEMS_TYPE, _DICT_KEYS_TYPE, _DICT_VALUES_TYPE) -def transform_recursively(data: Any, fn: Callable[[Any], Any], +def _transform_recursively(data: Any, fn: Callable[[Any], Any], target_types: Union[Callable[[Any], bool], Type, Tuple[Type, ...]], collection_types = (tuple, list, dict), skip_dict_keys = True -) -> Any: - """ - Transform the data with the given function, will recursively apply the function to the nested data. - Args: - data: the data to be transformed. - fn: the function to apply. - target_types: the target types to apply the function. - collection_types: the collection types to apply the function to the nested data. - skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). - _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. - """ +) -> tuple[bool, Any]: + if collection_types is None: + collection_types = TRANSFORM_SUPPORTED_COLLECTION_TYPES if isinstance(data, collection_types): if isinstance(data, tuple): - return tuple(transform_recursively(t, fn, target_types, collection_types) for t in data) + result = tuple(_transform_recursively(t, fn, target_types, collection_types, skip_dict_keys) for t in data) + changed = any(c for c, _ in result) + if changed: + return True, tuple(v for _, v in result) + else: + return False, data if isinstance(data, list): - return list(transform_recursively(t, fn, target_types, collection_types) for t in data) + result = [_transform_recursively(t, fn, target_types, collection_types, skip_dict_keys) for t in data] + changed = any(c for c, _ in result) + if changed: + return True, [v for _, v in result] + else: + return False, data if isinstance(data, set): - return set(transform_recursively(t, fn, target_types, collection_types) for t in data) + result = [_transform_recursively(t, fn, target_types, collection_types, skip_dict_keys) for t in data] + changed = any(c for c, _ in result) + if changed: + return True, {v for _, v in result} + else: + return False, data if isinstance(data, dict): - return { - k if skip_dict_keys else transform_recursively(k, fn, target_types, collection_types): - transform_recursively(v, fn, target_types, collection_types) + if skip_dict_keys: + keys = {k: (False, k) for k in data.keys()} + else: + keys = { + k: _transform_recursively(k, fn, target_types, collection_types, skip_dict_keys) + for k in data.keys() + } + changed = any(c for c, _ in keys.values()) + result = { + k: _transform_recursively(v, fn, target_types, collection_types, skip_dict_keys) for k, v in data.items() - } + } + changed = changed or any(c for c, _ in result.values()) + if changed: + return True, { + keys[k][1]: v for k, (_, v) in result.items() + } + else: + return False, data if isinstance(data, _DICT_ITEMS_TYPE): - return { - k if skip_dict_keys else transform_recursively(k, fn, target_types, collection_types): - transform_recursively(v, fn, target_types, collection_types) + if skip_dict_keys: + keys = {k: (False, k) for k, _ in data} + else: + keys = { + k: _transform_recursively(k, fn, target_types, collection_types, skip_dict_keys) + for k, _ in data + } + + changed = any(c for c, _ in keys.values()) + result = { + k: _transform_recursively(v, fn, target_types, collection_types, skip_dict_keys) for k, v in data - }.items() + } + changed = changed or any(c for c, _ in result.values()) + if changed: + return True, { + keys[k][1]: v for k, (_, v) in result.items() + }.items() + else: + return False, data if isinstance(data, _DICT_KEYS_TYPE): - return { - transform_recursively(k, fn, target_types, collection_types): i - for i, k in enumerate(data) - }.keys() + result = [ + _transform_recursively(k, fn, target_types, collection_types, skip_dict_keys) + for k in data + ] + changed = any(c for c, _ in result) + if changed: + return True, { + v: i for i, (_, v) in enumerate(result) + }.keys() + else: + return False, data if isinstance(data, _DICT_VALUES_TYPE): - return { - i: transform_recursively(v, fn, target_types, collection_types) + result = { + i: _transform_recursively(v, fn, target_types, collection_types, skip_dict_keys) for i, v in enumerate(data) - }.values() + } + changed = any(c for c, _ in result.values()) + if changed: + return True, { + i: v for i, (_, v) in result.items() + }.values() + else: + return False, data if isinstance(data, slice): - return slice( - transform_recursively(data.start, fn, target_types, collection_types), - transform_recursively(data.stop, fn, target_types, collection_types), - transform_recursively(data.step, fn, target_types, collection_types) + result = ( + _transform_recursively(data.start, fn, target_types, collection_types, skip_dict_keys), + _transform_recursively(data.stop, fn, target_types, collection_types, skip_dict_keys), + _transform_recursively(data.step, fn, target_types, collection_types, skip_dict_keys), ) + if any(c for c, _ in result): + return True, slice( + result[0][1], + result[1][1], + result[2][1] + ) + else: + return False, data raise ValueError(f"Unsupported collection type: {type(data)}") elif isinstance(target_types, (tuple, list)) or inspect.isclass(target_types): if isinstance(data, target_types): - return fn(data) + return True, fn(data) elif callable(target_types): # not a class, but callable. treat as a check function. if target_types(data): - return fn(data) - return data + return True, fn(data) + return False, data + + +def transform_recursively(data: Any, fn: Callable[[Any], Any], + target_types: Union[Callable[[Any], bool], Type, Tuple[Type, ...]], + collection_types = (tuple, list, dict), skip_dict_keys = True +) -> Any: + """ + Transform the data with the given function, will recursively apply the function to the nested data. + Currently supported collection types is SUPPORTED_COLLECTION_TYPES. + Args: + data: the data to be transformed. + fn: the function to apply. + target_types: the target types to apply the function. + collection_types: the collection types to apply the function to the nested data. + Will handle all supported types if None. + skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). + _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. + """ + _, result = _transform_recursively(data, fn, target_types, collection_types, skip_dict_keys) + return result + + +def check_recursively(data, fn: Callable[[Any], bool], + collection_types = (tuple, list, dict), + skip_dict_keys = True +) -> bool: + """ + Check the data with the given function, will recursively apply the function to the nested data. + Args: + data: the data to be checked. + fn: the function to check. + collection_types: the collection types to apply the function to the nested data. + skip_dict_keys: whether to skip the dict keys (for types dict, _DICT_ITEMS_TYPE). + _DICT_KEYS_TYPE is not skipped, if you want to skip it, just remove it from the collection_types. + + """ + if collection_types is None: + collection_types = TRANSFORM_SUPPORTED_COLLECTION_TYPES + + if isinstance(data, collection_types): + if isinstance(data, (list, tuple, set, _DICT_KEYS_TYPE, _DICT_VALUES_TYPE)): + return any(check_recursively(t, fn, collection_types) for t in data) + if isinstance(data, dict): + if skip_dict_keys: + return any( + check_recursively(v, fn, collection_types) + for v in data.values() + ) + else: + return any( + check_recursively(k, fn, collection_types) or check_recursively(v, fn, collection_types) + for k, v in data.items() + ) + if isinstance(data, _DICT_ITEMS_TYPE): + if skip_dict_keys: + return any( + check_recursively(v, fn, collection_types) + for _, v in data + ) + else: + return any( + check_recursively(k, fn, collection_types) or check_recursively(v, fn, collection_types) + for k, v in data + ) + if isinstance(data, slice): + return any(( + check_recursively(data.start, fn, collection_types), + check_recursively(data.stop, fn, collection_types), + check_recursively(data.step, fn, collection_types) + )) + raise ValueError(f"Unsupported collection type: {type(data)}") + + return fn(data) def is_running_distributed() -> bool: @@ -325,6 +483,21 @@ def fields(model: TDataClass, /) -> TDataClass: return cast(TDataClass, _GetFields(model)) +class _UncheckedFields: + def __getattr__(self, item: str) -> Any: + return item + + +TUncheckedClass = TypeVar("TAnyClass") +def unchecked_fields(_: TUncheckedClass, /) -> TUncheckedClass: + """ + This function is used to get the field names(in str) of any object without checking + This is a workaround for the lack of `__name__` of member. + """ + return cast(TUncheckedClass, _UncheckedFields()) + + +@cache def load_type(type_name: str): """ Load function/class from its full qualified name @@ -457,3 +630,373 @@ def steps(nsteps: int): RuntimeFlag.skip_reducer = (not (step == nsteps - 1)) yield step RuntimeFlag.skip_zero_grad, RuntimeFlag.skip_reducer = old + + +class AdamOptState(TypedDict): + step: torch.Tensor + exp_avg: torch.Tensor + exp_avg_sq: torch.Tensor + + +class OptStateParamGroup(TypedDict): + params: list[int] + lr: float + + +class OptStateDict(TypedDict): + state: dict[int, AdamOptState | dict[str, Union[Any, torch.Tensor]]] + param_groups: list[OptStateParamGroup | dict[str, Union[Any, torch.Tensor]]] + + +def fn_field(**kwargs): + metadata = kwargs.pop('metadata', {}) + metadata['deserialize'] = lambda t: None if t is None else load_type(t) + return field(**kwargs, metadata=metadata) + + +TENSOR_DYNAMIC_DIMS_FIELD_NAME = '_nnscaler_dynamic_dims' +# for nnscaler custom class (TensorMetadata) +NNSCALER_DYNAMIC_DIMS_NAME = 'dynamic_dims' + + +def mark_dynamic(tensor: torch.Tensor, dims: int | list[int] | tuple[int]) -> torch.Tensor: + """ + Mark the dim of a tensor as dynamic, which means it can be changed in the future. + This is the same with `torch._dynamo.mark_dynamic` + """ + dims = [dims] if isinstance(dims, int) else dims + setattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, set(dims) if dims else set()) + return tensor + + +def copy_dynamic(src: torch.Tensor, tensor: torch.Tensor) -> torch.Tensor: + """ + Copy the dynamic dims from src to tensor, and return the tensor. + """ + if hasattr(src, TENSOR_DYNAMIC_DIMS_FIELD_NAME): + setattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, getattr(src, TENSOR_DYNAMIC_DIMS_FIELD_NAME)) + return tensor + + +def get_dynamic(tensor: Any) -> set[int]: + """ + Get the dynamic dims of a tensor. + It also works when tensor is not an instance of torch.Tensor + """ + if isinstance(tensor, torch.Tensor): + return getattr(tensor, TENSOR_DYNAMIC_DIMS_FIELD_NAME, set()) + else: + return getattr(tensor, NNSCALER_DYNAMIC_DIMS_NAME, set()) + + +@contextmanager +def suppress_warnings(message): + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', message=message) + yield + + +def broadcast_files( + file_groups: List[List[Union[str, Path]]], + *, + max_workers: int = 8, +): + """Broadcast files from src to all other nodes. Files are grouped into file_groups, + and each group of files are broadcasted together to get better performance. + + Args: + files (List[List[str | Path]]): List of file groups to be broadcasted. + Note that the file names should be the same across all ranks. + """ + from nnscaler.runtime.device import DeviceGroup + + # filter out empty file groups + file_groups = [ + fg for fg in file_groups if fg + ] + + curr_rank = torch.distributed.get_rank() + local_world_size = DeviceGroup().local_world_size + world_size = torch.distributed.get_world_size() + local_rank = curr_rank % local_world_size + + # create groups, make sure all groups are created correctly + for i in range(local_world_size): + group_ranks = list(range(i, world_size, local_world_size)) + DeviceGroup().get_group(group_ranks) + + # collect file sizes and broadcast + if curr_rank == 0: + file_group_sizes: List[List[int]] = [ + [os.path.getsize(file) for file in files] for files in file_groups + ] + exchange_objects = [file_group_sizes] + else: + exchange_objects = [None] + + torch.distributed.broadcast_object_list(exchange_objects, src=0) + file_group_sizes = exchange_objects[0] + + # sort file_groups by size descending to improve overlapping + file_groups_sizes_pairs = list(zip(file_groups, file_group_sizes)) + file_groups_sizes_pairs.sort(key=lambda x: sum(x[1]), reverse=True) + file_groups = [pair[0] for pair in file_groups_sizes_pairs] + file_group_sizes = [pair[1] for pair in file_groups_sizes_pairs] + + def _write_file(file: Union[str, Path], buffer, start, size): + _logger.info(f'Rank {curr_rank}: Writing file {file} of size {size} bytes.') + # have better performance than open + write + buffer[start: start + size].numpy().tofile(file) + + def _read_file(file, buffer, start, size): + _logger.info(f'Rank {curr_rank}: Reading file {file} of size {size} bytes.') + # slightly faster than open + read + buffer[start: start + size] = torch.from_numpy(np.fromfile(file, dtype=np.uint8)) + + def _write_files(buffer, files, file_sizes): + buffer = buffer.cpu() + file_starts = itertools.accumulate([0] + file_sizes[:-1]) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + executor.map( + lambda args: _write_file(args[0], buffer, args[1], args[2]), + zip(files, file_starts, file_sizes) + ) + + def _send_file_group(src, files, file_sizes): + total_size = sum(file_sizes) + + ranks = list(range(src, world_size, local_world_size)) + group = DeviceGroup().get_group(ranks) + file_buffer = torch.empty(total_size, dtype=torch.uint8, device='cpu').pin_memory() + + if curr_rank < local_world_size: + file_starts = itertools.accumulate([0] + file_sizes[:-1]) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + executor.map( + lambda args: _read_file(args[0], file_buffer, args[1], args[2]), + zip(files, file_starts, file_sizes) + ) + broadcast_tensor = file_buffer.cuda() + else: + broadcast_tensor = torch.empty(total_size, dtype=torch.uint8, device='cuda') + + torch.distributed.broadcast(broadcast_tensor, src=src, group=group) + + if curr_rank >= local_world_size: + file_buffer.copy_(broadcast_tensor) + _write_files(file_buffer, files, file_sizes) + + # we split the file groups among local ranks + # each local rank sends its assigned file groups (in round robin fashion) + for i in range(local_rank, len(file_groups), local_world_size): + _send_file_group(local_rank, file_groups[i], file_group_sizes[i]) + + +class _TensorIndex: + def __init__(self, index: int): + self.index = index + + def __repr__(self): + return f"_TensorIndex({self.index})" + + +def extract_tensors(data: Dict[str, Any]) -> Tuple[Dict[str, Any], List[torch.Tensor]]: + """ + Extract tensors from a collection, and return the skeleton (by replacing tensors with _TensorIndex) and the list of tensors. + Args: + data (Dict[str, Any]): The collection to be extracted. + Returns: + Tuple[Dict[str, Any], List[torch.Tensor]]: The skeleton and the list of tensors. + """ + tensors = [] + + # used to deduplicate tensors + # TODO: Consider more robust way to identify tensors + # key: (tensor.data_ptr(), tensor.shape, tensor.stride()), value: _Index + tensor_ids: dict[tuple[int, tuple[int, ...], tuple[int, ...]], _TensorIndex] = {} + def transform_fn(o: torch.Tensor) -> Any: + key = (o.data_ptr(), o.shape, o.stride()) + if key in tensor_ids: + idx = tensor_ids[key] + else: + idx = _TensorIndex(len(tensors)) + tensor_ids[key] = idx + tensors.append(o) + return idx + skeleton = transform_recursively(data, transform_fn, target_types=(torch.Tensor,)) + + return skeleton, tensors + + +def refill_tensors(skeleton: Dict[str, Any], tensors: List[torch.Tensor]) -> Dict[str, Any]: + """ + Refill tensors into the skeleton, and return the data. + This is the inverse operation of `extract_tensors`. + + Args: + skeleton (Dict[str, Any]): The skeleton to be refilled. + tensors (List[torch.Tensor]): The list of tensors to be refilled. + Returns: + Dict[str, Any]: The data. + """ + def transform_fn(o: _TensorIndex) -> Any: + return tensors[o.index] + state_dict = transform_recursively(skeleton, transform_fn, target_types=_TensorIndex) + return state_dict + + +def broadcast_mixed_data( + data: Optional[dict] = None, + *, + src_rank: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, + device: Optional[Union[str, torch.device]] = None, +): + """ + Broadcast the data (containing tensors) from src_rank to all other ranks. + + Args: + data (Optional[dict]): The data to be broadcasted. + for non-src ranks, this must be None. + src_rank (int): The source rank to broadcast from. Default: 0. + group (torch.distributed.ProcessGroup, optional): The process group to use for broadcasting. + If None, the default process group will be used. Default: None. + device (str or torch.device, optional): The device to use for receiving tensors on non-src ranks. + If None, the current cuda device will be used. Default: None. + + Returns: + dict: The broadcasted data. + For src_rank, it is the same as the input data. + For non-src ranks, it is the broadcasted data. the device of tensors will be cuda. + """ + device = device or torch.cuda.current_device() + if isinstance(device, str): + # need to compare device later, so convert to torch.device + device = torch.device(device) + rank = torch.distributed.get_rank(group=group) + + # share the structure and tensor shapes + if rank == src_rank: + if data is None: + raise ValueError("data must not be None in src_rank") + skeleton, tensors = extract_tensors(data) + meta_tensors = [t.to('meta') for t in tensors] + sent = [(skeleton, meta_tensors)] + else: + if data is not None: + raise ValueError("data must be None in non-src ranks") + skeleton, tensors, meta_tensors = None, None, None + sent = [None] + + torch.distributed.broadcast_object_list(sent, src=src_rank, group=group) + skeleton, meta_tensors = sent[0] + if rank != src_rank: + tensors = [None] * len(meta_tensors) + + # broadcast tensor data + for i in range(len(tensors)): + if rank != src_rank: + tensor = torch.empty_like(meta_tensors[i], device='cuda') + else: + # make sure tensors are in cuda + tensor = tensors[i].cuda() + + torch.distributed.broadcast(tensor, src=src_rank, group=group) + + if rank != src_rank: + tensors[i] = tensor.to(device, non_blocking=True) + else: + # try to reuse the existing tensors if device matches + if tensor.device == device: + tensors[i] = tensor + else: + tensors[i] = tensors[i].to(device, non_blocking=True) + + # refill tensors + return refill_tensors(skeleton, tensors) + + +def gather_mixed_data( + data: dict, + *, + src_rank: int = 0, + group: Optional[torch.distributed.ProcessGroup] = None, + device: Optional[Union[str, torch.device]] = None, +): + """ + Gather the data (containing tensors) from all ranks to src_rank. + + Args: + data (dict): The data to be gathered. + src_rank (int): The source rank to gather to. Default: 0. + group (torch.distributed.ProcessGroup, optional): The process group to use for gathering. + If None, the default process group will be used. Default: None. + device (str or torch.device, optional): The device to use for receiving tensors on src_rank. + If None, the current cuda device will be used. Default: None. + If you want to save memory, you can set it to 'cpu' to move tensors to cpu after receiving. + Returns: + dict: The gathered data. + For src_rank, it is the gathered data from all ranks. + For non-src ranks, it is None. + """ + device = torch.cuda.current_device() if device is None else device + + rank = torch.distributed.get_rank(group=group) + world_size = torch.distributed.get_world_size(group=group) + result = [None] * world_size + result[rank] = data + + skeleton, tensors = extract_tensors(data) + sent = (skeleton, [t.to('meta') for t in tensors]) + + # Gather metadata from all ranks + gathered_sent = [None for _ in range(world_size)] + torch.distributed.all_gather_object(gathered_sent, sent, group=group) + + def _send_recv_tensors( + sender: int, + skel: Dict[str, Any], + tensors: list[torch.Tensor] + ) -> Dict[str, Any]: + if rank == src_rank: + assert all(tensor.device.type == 'meta' for tensor in tensors), \ + "Tensors should be on meta device on rank 0." + if rank != src_rank: + assert all(tensor.device.type != 'meta' for tensor in tensors), \ + f"Tensors should not be on meta device on rank {rank}." + + if rank == src_rank: + cuda_tensors = [torch.empty_like(tensor, device='cuda') for tensor in tensors] + else: + cuda_tensors = [tensor.cuda() for tensor in tensors] + + for i in range(len(tensors)): + if rank == src_rank: + torch.distributed.recv(cuda_tensors[i], group_src=sender, group=group) + else: + torch.distributed.send(cuda_tensors[i], group_dst=src_rank, group=group) + + if rank == src_rank: + tensors = [tensor.to(device, non_blocking=True) for tensor in cuda_tensors] + return transform_recursively( + skel, + lambda idx: tensors[idx.index], + target_types=_TensorIndex, + ) + else: + return None # only rank 0 needs the recovered state dict + + # TODO: It may have performance issue if the number of ranks is large + for i in range(0, world_size): + if i == src_rank: + continue + if rank == src_rank: + result[i] = _send_recv_tensors(i, gathered_sent[i][0], gathered_sent[i][1]) + elif rank == i: + _send_recv_tensors(rank, skeleton, tensors) + torch.distributed.barrier(group=group) + + if rank == src_rank: + return result + else: + return None diff --git a/pipelines/nightly-build.yaml b/pipelines/nightly-build.yaml new file mode 100644 index 00000000..da792cf6 --- /dev/null +++ b/pipelines/nightly-build.yaml @@ -0,0 +1,25 @@ +trigger: +- main + +pool: + vmImage: ubuntu-latest + +steps: +- task: TwineAuthenticate@1 + inputs: + artifactFeed: SuperScaler/nightly + +- script: | + python -m pip install --upgrade build twine + displayName: prepare environment + +- script: | + python pipelines/scripts/update_version.py --nightly + python -m build + displayName: build wheel + +- script: | + number_of_wheels=`ls dist/*.whl | wc -l` + test $number_of_wheels -eq 1 + python -m twine upload -r nightly --config-file $(PYPIRC_PATH) dist/*.whl + displayName: upload nightly wheel diff --git a/pipelines/release.yaml b/pipelines/release.yaml new file mode 100644 index 00000000..b7a06433 --- /dev/null +++ b/pipelines/release.yaml @@ -0,0 +1,39 @@ +# depends on two variables: +# +# - version +# must be set on devops website for each run +# the value should be something like "0.1" or "v0.1a1" (w/ or w/o leading v) +# +# - test_pypi_token +# secret, should never expire +# to view it or to update it, check onenote accounts/pypi page (test.pypi token) + +trigger: none +pr: none + +pool: + vmImage: ubuntu-latest + +steps: +- task: TwineAuthenticate@1 + inputs: + artifactFeed: SuperScaler/release + +- script: | + python -m pip install --upgrade build twine + displayName: prepare environment + +- script: | + python pipelines/scripts/update_version.py $(version) + python -m build + number_of_wheels=`ls dist/*.whl | wc -l` + test $number_of_wheels -eq 1 + displayName: build wheel + +- script: | + python -m twine upload -r release --config-file $(PYPIRC_PATH) dist/*.whl + displayName: upload to artifact + +- script: | + python -m twine upload -r testpypi -p $(test_pypi_token) dist/*.whl + displayName: upload to testpypi diff --git a/pipelines/scripts/update_version.py b/pipelines/scripts/update_version.py new file mode 100644 index 00000000..98b4f289 --- /dev/null +++ b/pipelines/scripts/update_version.py @@ -0,0 +1,71 @@ +""" +Update "nnscaler/version.py" before building the wheel. + +Usage 1: + + python update_version.py --nightly + +Update version.py to "X.Y.dev{TIMESTAMP}+{GIT_COMMIT}". + +Usage 2: + + python update_version.py 1.2 + python update_version.py v1.2b3 + +Update version.py to the specified version (normalized, leading "v" removed). +It will verify that the release part matches the old version. +""" + +import argparse +from datetime import datetime +from pathlib import Path +import subprocess + +from packaging.version import Version + +project_dir = Path(__file__).parents[2] + +def main(): + parser = argparse.ArgumentParser(add_help=False) + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument('--nightly', action='store_true') + group.add_argument('version', nargs='?') + args = parser.parse_args() + + version_file = Path(project_dir, 'nnscaler/version.py') + file_content = version_file.read_text() + version_str = file_content.split('=')[-1].strip()[1:-1] # "version = 'x'" -> "x" + repo_version = Version(version_str) + + if args.nightly: + timestamp = datetime.now().strftime('%y%m%d%H%M') + + r = subprocess.run( + 'git rev-parse --short HEAD'.split(), + stdout=subprocess.PIPE, + cwd=project_dir, + text=True, + ) + if r.returncode != 0: + print('[error] failed to get git commit hash') + exit(1) + commit = r.stdout.strip() + + new_version_str = f'{repo_version.base_version}.dev{timestamp}+{commit}' + + else: + arg_version = Version(args.version) + + if repo_version.release != arg_version.release: + print('[error] version not match') + print(f' repo: {version_str} -> {repo_version}') + print(f' arg: {args.version} -> {arg_version}') + exit(1) + + new_version_str = str(arg_version) # normalize + + file_content = file_content.replace(version_str, new_version_str) + version_file.write_text(file_content) + +if __name__ == '__main__': + main() diff --git a/requirements-dev.txt b/requirements-dev.txt index 7d181749..df05299d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,14 +8,15 @@ pytest pytest-cov pytest-mock scikit-learn -lightning +lightning==2.5.1.post0 sphinx sphinxcontrib-napoleon tabulate tox -tox-conda +tox-uv yapf wandb tensorboard mosaicml-streaming cppimport +einops diff --git a/requirements.txt b/requirements.txt index 6ae52a83..0d337d5b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,6 @@ psutil pulp pybind11<3.0.0 pyyaml -torch>=2.0,<=2.6 +torch>=2.0,<=2.8 tqdm +safetensors diff --git a/tests/autodist/test_dp_solver.py b/tests/autodist/test_dp_solver.py index 846e1c05..c6b172a4 100644 --- a/tests/autodist/test_dp_solver.py +++ b/tests/autodist/test_dp_solver.py @@ -37,6 +37,7 @@ def test_dp_solver(): # the optimal plan is each operator's first partition assert best.path == [(0, 0), (1, 0), (2, 0)] + def test_dp_solver_mem(): solver = dp_solver.DPSolver(True, 100, 1) solver.add_interval(0, 4) @@ -73,6 +74,7 @@ def test_dp_solver_mem(): assert best.path == [(0, 0), (1, 0), (2, 0), (3, 0), (4, 0)] assert best.memory == 71 + def test_dp_solver_build_in_edges(): # mock following code # dropout_rate = self.attention_dropout if self.training else 0.0 @@ -102,6 +104,7 @@ def test_dp_solver_build_in_edges(): best = ans[0] assert best.path == [(0, 0), (1, 0), (2, 0)] + def test_dp_solver_mem_bound(): solver = dp_solver.DPSolver(True, 10, 1) solver.add_interval(0, 2) @@ -119,3 +122,26 @@ def test_dp_solver_mem_bound(): ans = solver.get_results(0, 2) assert len(ans) == 0 + + +def test_dp_solver_output(): + solver = dp_solver.DPSolver(True, 1024, 1) + solver.add_interval(0, 2) + + solver.add_node(0, 0, [0], [], 2, False, False, False) + solver.add_partition(0, 0, 10, 16, 0, 0, 0, 0, 0, [[]]) + solver.add_partition(0, 1, 5, 8, 0, 0, 0, 0, 1, [[]]) + + solver.add_node(1, 1, [0, 1], [], 2, False, False, False) + solver.add_partition(1, 0, 4, 6, 0, 0, 0, 0, 0, [[]]) + solver.add_partition(1, 1, 2, 3, 0, 0, 0, 0, 1, [[]]) + + solver.add_node(2, 2, [2], [0], 1, False, False, False) + solver.add_partition(2, 0, 0, 0, 0, 0, 0, 0, 0, [[0, 0]]) + + solver.solve() + ans = solver.get_results(0, 2) + best = ans[0] + assert best.all_time == 7 + assert best.path == [(0, 1), (1, 1), (2, 0)] + assert best.memory == 11 diff --git a/tests/cli/common.py b/tests/cli/common.py index f5be1c9d..02f8b197 100644 --- a/tests/cli/common.py +++ b/tests/cli/common.py @@ -1,6 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# CausalSelfAttention is copied from https://github.com/karpathy/nanoGPT/blob/master/model.py +# with minor modifications. +# See the original license in the file https://github.com/karpathy/nanoGPT/blob/master/LICENSE + from pathlib import Path import torch from torch import nn @@ -9,11 +13,88 @@ from streaming import MDSWriter, StreamingDataset, StreamingDataLoader +import nnscaler from nnscaler.cli.trainer_args import TrainerArgs from tests.parallel_module.test_end2end import MLP from tests.utils import init_random as init_random_fn + +class CausalSelfAttention(nn.Module): + def __init__(self, n_embd: int, n_head: int, dropout: float): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=True) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd, bias=True) + # regularization + self.attn_dropout = nn.Dropout(dropout) + self.resid_dropout = nn.Dropout(dropout) + self.n_head = n_head + self.n_embd = n_embd + self.dropout = dropout + + def forward(self, x): + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) + y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + + +class SimpleTransformerModel(nn.Module): + def __init__(self, n_embd: int, n_head: int, dropout: float, nlayers: int, vocab_size: int): + super().__init__() + + self.layers = nn.ModuleList([]) + self.lm_head = nn.Linear(n_embd, vocab_size, bias=False) + for _ in range(nlayers): + self.layers.append(CausalSelfAttention(n_embd, n_head, dropout)) + + def forward(self, data): + x = data['input'] + target = data['target'] + for layer in self.layers: + x = layer(x) + logits = self.lm_head(x) + loss = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1), ignore_index=-1) + return loss + + +def csa_forward_args_gen_fn(trainer_args: TrainerArgs): + seq_len = 128 # dynamicness is controlled by trainer_args.vars['dynamic_dims'] + + return { + 'x': torch.randn(1, seq_len, trainer_args.model.args['n_embd']), + } + + +def post_csa_forward_args_gen_fn(trainer_args: TrainerArgs, args): + dynamic_dims = trainer_args.get_resolved_var('dynamic_dims', default=[]) + nnscaler.mark_dynamic(args['x'], dynamic_dims) + return args + + +def transformer_dummy_sample_gen_fn(trainer_args: TrainerArgs): + seq_len = 128 # dynamicness is controlled by trainer_args.vars['dynamic_dims'] + dynamic_dims = trainer_args.get_resolved_var('dynamic_dims', default=[]) + return { + 'input': nnscaler.mark_dynamic(torch.randn(1, seq_len, trainer_args.model.args['n_embd']), dynamic_dims), + 'target': nnscaler.mark_dynamic(torch.randint(0, trainer_args.model.args['vocab_size'], (1, seq_len)), dynamic_dims), + } + + class MixModuleMLP(nn.Module): def __init__(self, dim: int, nlayers: int, init_random: bool = True): super().__init__() diff --git a/tests/cli/test_arg_parser.py b/tests/cli/test_arg_parser.py index 427c8f26..dffb0813 100644 --- a/tests/cli/test_arg_parser.py +++ b/tests/cli/test_arg_parser.py @@ -209,6 +209,28 @@ class A: assert y.p == {'value': 'auto'} +def test_merge_dict(): + a = { + 'compute_config': { + 'plan_ngpus': 1 + }, + 'optimizer': { + 'type': 'torch.nn.Adam', + 'args': { + 'lr': 0.001 + } + } + } + merge_args(a, ['--optimizer', { + 'type': 'torch.nn.AdamW', + 'args': { + 'hello': 'haha' + } + }]) + assert a['optimizer']['args']['lr'] == 0.001 + assert a['optimizer']['args']['hello'] == 'haha' + + def test_merge_list(): @dataclass class A: diff --git a/tests/cli/test_hooks.py b/tests/cli/test_hooks.py new file mode 100644 index 00000000..102a48bc --- /dev/null +++ b/tests/cli/test_hooks.py @@ -0,0 +1,30 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Any, List + +from nnscaler.cli.train_hook import TrainHook, TrainHookHost + + +class A(TrainHook): + pass + +class B(TrainHook): + pass + +class C(TrainHook, TrainHookHost): + def _get_hook_objects(self) -> List[Any]: + return [A(), B(), self] + + +class D(TrainHookHost): + def _get_hook_objects(self) -> List[Any]: + return [self, A(), C()] + +def test_hook(): + hooks = D().get_hooks() + assert len(hooks) == 4 + assert isinstance(hooks[0], A) + assert isinstance(hooks[1], C) + assert isinstance(hooks[2], A) + assert isinstance(hooks[3], B) diff --git a/tests/cli/test_serialization.py b/tests/cli/test_serialization.py new file mode 100644 index 00000000..e4120cfd --- /dev/null +++ b/tests/cli/test_serialization.py @@ -0,0 +1,251 @@ +import pytest +import torch +from pathlib import Path + +from nnscaler.cli.serialization import ( + convert_format, SerializationRunner, register_serialization_runner, + Checkpointer +) +from nnscaler.cli.trainer import Trainer +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.common import assert_equal + + +def test_runner(tmp_path): + + class SplitSerializationRunner(SerializationRunner): + name: str = 'split' + + def run_load(self, load_func, f, *, device='cpu'): + model_state_dict = load_func(f, device=device) + opt_state_dict = load_func(str(f) + '.opt', device=device) + return { + 'model': model_state_dict, + 'optimizer': opt_state_dict + } + + def run_save(self, save_func, state_dict, f): + save_func(state_dict['model'], f) + save_func(state_dict['optimizer'], str(f) + '.opt') + + register_serialization_runner(SplitSerializationRunner) + + a = torch.randn((2, 2), device='cpu') + b = torch.randn((2, 3), device='cpu') + c = torch.randn((4, 4), device='cpu') + d = torch.randn((3, 3), device='cpu') + tensors = { + "model": { + "embedding": a, + "attention": b, + }, + "optimizer": { + "state": { + 0: { + "exp_avg": c, + "exp_avg_sq": d, + } + } + } + } + checkpointer = Checkpointer() + checkpointer.save(tensors, tmp_path / "model.ckpt") + checkpointer.flush() + + convert_format( + src=str(tmp_path / "model.ckpt"), + dst=str(tmp_path / "model_split.ckpt"), + dst_serializer='split', + ) + + assert Path(tmp_path / "model_split.ckpt").exists() + assert Path(tmp_path / "model_split.ckpt.opt").exists() + tensor3 = Checkpointer(serializer='split').load(tmp_path / "model_split.ckpt") + assert_equal(tensors, tensor3) + + checkpointer2 = Checkpointer(serializer=':split') + tensor2 = checkpointer2.load(tmp_path / "model.ckpt") + assert_equal(tensors, tensor2) + + checkpointer2.save(tensor2, tmp_path / "model_split2.ckpt") + checkpointer2.flush() + assert Path(tmp_path / "model_split2.ckpt").exists() + assert Path(tmp_path / "model_split2.ckpt.opt").exists() + + tensor4 = Checkpointer(serializer='split').load(tmp_path / "model_split2.ckpt") + assert_equal(tensors, tensor4) + + +def trainer_split_serializer_worker(tmp_path, symblink): + save_dir = Path(tmp_path) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + optimizer_type = 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam' + use_zero = 1 + format = 'safetensors' + rev_format = 'pt' if format == 'safetensors' else 'safetensors' + + def list_ckpt_files(dir): + return set(dir.glob('**/*.ckpt')) | set(dir.glob('**/*.safetensors')) + + + class SplitSerializationRunner(SerializationRunner): + name: str = 'split' + + def __init__(self, mark=''): + self.mark = mark + + def run_load(self, load_func, f, *, device='cpu'): + other_state_dict = load_func(f, device=device) + opt_state_dict = load_func(str(f) + '.opt', device=device) + model_state_dict = load_func(str(f) + '.model', device=device) + return { + 'model': model_state_dict, + 'optimizer': opt_state_dict, + **other_state_dict + } + + def run_save(self, save_func, state_dict, f): + save_func(state_dict['model'], str(f) + '.model') + save_func(state_dict['optimizer'], str(f) + '.opt') + other_state_dict = {k: v for k, v in state_dict.items() if k not in ['model', 'optimizer']} + other_state_dict['mark'] = self.mark + save_func(other_state_dict, f) + + register_serialization_runner(SplitSerializationRunner) + + # train 4 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '4', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.format', format, + '--checkpoint.serializer.name', 'split', + '--checkpoint.serializer.args.mark', 'hello', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + torch.distributed.barrier() + + ckpt_files = list_ckpt_files(ckpt_savedir) + assert len(ckpt_files)/4 == min(10, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last + + for f in ckpt_files: + assert trainer.checkpointer.load(f)['mark'] == 'hello' + assert Path(str(f) + '.opt').exists() + assert Path(str(f) + '.model').exists() + + torch.distributed.barrier() + # train 4 epcho two times (resume from last) + ckpt0_savedir = save_dir / 'ckpt0' + # first two epochs + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '2', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt0_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.format', format, + '--checkpoint.serializer', 'split', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + + torch.distributed.barrier() + # create merged checkpoint + ckpt1_savedir = save_dir / 'ckpt1' + ckpt1_savedir.mkdir(parents=True, exist_ok=True) + merged_file_name = f'merged{Checkpointer.NAME_MAP[format]}' + if trainer.rank == 0: + Trainer.merge_checkpoint(trainer.checkpointer.list_checkpoints(ckpt0_savedir / 'last'), ckpt1_savedir / merged_file_name, serializer='split') + assert Path(str(ckpt1_savedir / merged_file_name) + '.opt').exists() + assert Path(str(ckpt1_savedir / merged_file_name) + '.model').exists() + + torch.distributed.barrier() + # continue with the last two epochs (resume for sharded/deduped checkpoint) + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '4', + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt0_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.format', rev_format, + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.serializer', 'split', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + + torch.distributed.barrier() + + # continue with the last two epochs (resume for merged) + trainer = Trainer([ + '-f', config_path, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--max_epochs', '4', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--compute_config.use_zero', str(use_zero), + '--checkpoint.save_type', 'deduped', + '--checkpoint.save_dir', str(ckpt1_savedir), + '--checkpoint.format', rev_format, + '--checkpoint.resume_from', str(ckpt1_savedir / merged_file_name), + '--checkpoint.keep_last_n_checkpoints', '10', + '--checkpoint.serializer', 'split', + '--checkpoint.symlink_best_and_last', str(symblink), + ]) + trainer.run() + + torch.distributed.barrier() + + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) + + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + assert {f.parent.name for f in ckpt_files} == {f.parent.name for f in ckpt0_files1} + for i in range(4): + x = trainer.checkpointer.load_for_rank(ckpt_savedir / 'last', i) + y = trainer.checkpointer.load_for_rank(ckpt0_savedir / 'last', i) + z = trainer.checkpointer.load_for_rank(ckpt1_savedir / 'last', i) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + assert_equal(x['lr_scheduler'], y['lr_scheduler']) + assert_equal(x['model'], z['model']) + assert_equal(x['optimizer'], z['optimizer']) + assert_equal(x['lr_scheduler'], z['lr_scheduler']) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('symblink', [True, False]) +def test_trainer_split_serializer(tmp_path, symblink): + launch_torchrun(4, trainer_split_serializer_worker, tmp_path, symblink) diff --git a/tests/cli/test_trainer.py b/tests/cli/test_trainer.py index 2dc1d252..4e3ffea5 100644 --- a/tests/cli/test_trainer.py +++ b/tests/cli/test_trainer.py @@ -2,13 +2,19 @@ # Licensed under the MIT License. from pathlib import Path +import re import shutil +from typing import Any +from mock import PropertyMock import torch import pytest import torch.distributed +from unittest.mock import patch from nnscaler import merge_state_dicts +from nnscaler.cli.serialization import Checkpointer +import nnscaler from nnscaler.cli.trainer import Trainer, logger from nnscaler.cli.trainer_args import AggregatedOutputs, TrainerArgs from tests.parallel_module.common import assert_equal, assert_close @@ -117,7 +123,12 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): optimizer_type = 'nnscaler.runtime.f16_optimizer.MixedPrecisionAdam' \ if bf16 == 'Mixed' \ else 'torch.optim.Adam' - use_zero = save_type == 'sharded' + use_zero = 1 if save_type == 'sharded' else 0 + format = 'safetensors' if parallel_type % 2 else 'pt' + rev_format = 'pt' if format == 'safetensors' else 'safetensors' + + def list_ckpt_files(dir): + return set(dir.glob('**/*.ckpt')) | set(dir.glob('**/*.safetensors')) if parallel_type == 0: additional_args = [] @@ -147,7 +158,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False' if use_zero else 'True', + '--model.parallel_modules.0.compute_config.use_zero', str(use_zero), '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -164,7 +175,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False' if use_zero else 'True', + '--model.parallel_modules.0.compute_config.use_zero', str(use_zero), '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -191,10 +202,11 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_dir', str(ckpt_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + '--checkpoint.format', format, *additional_args, ]) trainer.run() - ckpt_files = set(ckpt_savedir.glob('**/*.ckpt')) + ckpt_files = list_ckpt_files(ckpt_savedir) assert len(ckpt_files)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last torch.distributed.barrier() @@ -215,10 +227,11 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + '--checkpoint.format', format, *additional_args, ]) trainer.run() - ckpt0_files0 = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} + ckpt0_files0 = {f: f.stat().st_mtime_ns for f in list_ckpt_files(ckpt0_savedir)} assert len(ckpt0_files0)/4 == min(30, trainer.total_train_steps_per_epoch * 2) + 2 # 2 for best/last # resume from last without update max_epochs @@ -236,18 +249,20 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', '--checkpoint.keep_last_n_checkpoints', '30', + '--checkpoint.format', rev_format, *additional_args, ]) trainer.run() - ckpt0_files0_x = {f: f.stat().st_mtime_ns for f in ckpt0_savedir.glob('**/*.ckpt')} + ckpt0_files0_x = {f: f.stat().st_mtime_ns for f in list_ckpt_files(ckpt0_savedir)} # nothing should be updated in this case. assert ckpt0_files0 == ckpt0_files0_x # create merged checkpoint ckpt1_savedir = save_dir / 'ckpt1' ckpt1_savedir.mkdir(parents=True, exist_ok=True) + merged_file_name = f'merged{Checkpointer.NAME_MAP[format]}' if trainer.rank == 0: - Trainer.merge_checkpoint(list((ckpt0_savedir / 'last').glob('*.ckpt')), ckpt1_savedir / 'merged.pt') + Trainer.merge_checkpoint(trainer.checkpointer.list_checkpoints(ckpt0_savedir / 'last'), ckpt1_savedir / merged_file_name) torch.distributed.barrier() # continue with the last two epochs (resume for sharded/deduped checkpoint) @@ -264,6 +279,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt0_savedir), '--checkpoint.resume_from', 'last', + '--checkpoint.format', rev_format, '--checkpoint.keep_last_n_checkpoints', '30', *additional_args, ]) @@ -276,7 +292,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): for f, s in left_files.items(): # make sure the old checkpoints are not overwritten assert ckpt0_files0[f] == s - ckpt0_files1 = set(ckpt0_savedir.glob('**/*.ckpt')) + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) assert len(ckpt0_files1)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last torch.distributed.barrier() @@ -293,7 +309,8 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): '--compute_config.use_zero', str(use_zero), '--checkpoint.save_type', save_type, '--checkpoint.save_dir', str(ckpt1_savedir), - '--checkpoint.resume_from', str(ckpt1_savedir / 'merged.pt'), + '--checkpoint.format', rev_format, + '--checkpoint.resume_from', str(ckpt1_savedir / merged_file_name), '--checkpoint.keep_last_n_checkpoints', '30', *additional_args, ]) @@ -306,7 +323,7 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): for f, s in left_files.items(): # make sure the old checkpoints are not overwritten assert ckpt0_files0[f] == s - ckpt0_files1 = set(ckpt0_savedir.glob('**/*.ckpt')) + ckpt0_files1 = list_ckpt_files(ckpt0_savedir) assert len(ckpt0_files1)/4 == min(30, trainer.total_train_steps_per_epoch * 4) + 2 # 2 for best/last torch.distributed.barrier() @@ -314,9 +331,9 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): if torch.distributed.get_rank() == 0: assert {f.parent.name for f in ckpt_files} == {f.parent.name for f in ckpt0_files1} for i in range(4): - x = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) - y = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) - z = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + x = trainer.checkpointer.load_for_rank(ckpt_savedir / 'last', i) + y = trainer.checkpointer.load_for_rank(ckpt0_savedir / 'last', i) + z = trainer.checkpointer.load_for_rank(ckpt1_savedir / 'last', i) assert_equal(x['model'], y['model']) assert_equal(x['optimizer'], y['optimizer']) assert_equal(x['lr_scheduler'], y['lr_scheduler']) @@ -324,12 +341,13 @@ def trainer_resume_worker(save_dir, save_type, bf16, parallel_type=0): assert_equal(x['optimizer'], z['optimizer']) assert_equal(x['lr_scheduler'], z['lr_scheduler']) + suffix = Checkpointer.NAME_MAP[format] if save_type == 'deduped': - assert (ckpt_savedir / 'last/0.ckpt').stat().st_size > (ckpt_savedir / 'last/2.ckpt').stat().st_size - assert (ckpt_savedir / 'last/1.ckpt').stat().st_size > (ckpt_savedir / 'last/3.ckpt').stat().st_size + assert (ckpt_savedir / f'last/0{suffix}').stat().st_size > (ckpt_savedir / f'last/2{suffix}').stat().st_size + assert (ckpt_savedir / f'last/1{suffix}').stat().st_size > (ckpt_savedir / f'last/3{suffix}').stat().st_size else: - assert (ckpt_savedir / 'last/0.ckpt').stat().st_size == (ckpt_savedir / 'last/2.ckpt').stat().st_size - assert (ckpt_savedir / 'last/1.ckpt').stat().st_size == (ckpt_savedir / 'last/3.ckpt').stat().st_size + assert (ckpt_savedir / f'last/0{suffix}').stat().st_size == (ckpt_savedir / f'last/2{suffix}').stat().st_size + assert (ckpt_savedir / f'last/1{suffix}').stat().st_size == (ckpt_savedir / f'last/3{suffix}').stat().st_size @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @@ -593,7 +611,7 @@ def trainer_grad_sync_check(save_dir, use_bf16, zero_ngroups, runtime_ngpus): gen_savedir = save_dir / 'gen' ckpt_savedir = save_dir / 'ckpt' optimizer_type = 'torch.optim.Adam' - use_zero = False if zero_ngroups is None else True + use_zero = 0 if zero_ngroups is None else 1 zero_ngroups = '1' if zero_ngroups is None else zero_ngroups trainer = Trainer([ @@ -614,7 +632,7 @@ def trainer_grad_sync_check(save_dir, use_bf16, zero_ngroups, runtime_ngpus): torch.distributed.barrier() -def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): +def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False, hybrid_opt=False, use_zero=0): save_dir = Path(save_dir) config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) gen_savedir = save_dir / 'gen' @@ -651,7 +669,7 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False', + '--model.parallel_modules.0.compute_config.use_zero', '0', '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -668,7 +686,7 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): '--model.parallel_modules.0.type', 'tests.cli.common.MixModuleMLP', '--model.parallel_modules.0.args.dim', '16', '--model.parallel_modules.0.args.nlayers', '16', - '--model.parallel_modules.0.compute_config.use_zero', 'False', + '--model.parallel_modules.0.compute_config.use_zero', '0', '--model.parallel_modules.0.compute_config.constant_folding', 'False', '--model.parallel_modules.0.pas_policy', 'tp', '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.forward_args_gen_fn', @@ -680,6 +698,47 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): else: raise ValueError(f'parallel_type {parallel_type} is not supported') + + def param_clss_fn(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'mlp0.' in param_name: + return 0, 0 + elif 'mlp1.' in param_name: + return 0, 1 + else: + return 1, 0 + + optimizer_config = { + 'type': 'nnscaler.HybridOptimizer', + 'param_clss_fn': param_clss_fn, + 'args': { + 'config': { + 'optimizers':[ + { + 'type': torch.optim.Adam, + 'options': { + 'lr': 0.01, + }, + 'param_groups': [ + {}, + {} + ], + },{ + 'type': torch.optim.Adam, + 'options': { + 'lr': 0.01 + } + } + ] + } + } + } + + if hybrid_opt: + additional_args.extend(['--optimizer!', '--optimizer', optimizer_config]) + # train 4 epcho in one time trainer = Trainer([ '-f', config_path, @@ -687,7 +746,7 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): '--max_epochs', '2', '--enable_progress_bar', 'false', '--gen_savedir', str(gen_savedir), - '--compute_config.use_zero', 'False', + '--compute_config.use_zero', str(use_zero), '--compute_config.plan_ngpus', '1', '--compute_config.runtime_ngpus', '2', '--compute_config.use_async_reducer', str(async_reducer), @@ -713,33 +772,49 @@ def trainer_correctness_worker(save_dir, parallel_type=0, async_reducer=False): torch.distributed.barrier() -def trainer_correctness_worker_aggregate(tmp_path): +def trainer_correctness_worker_aggregate(tmp_path, use_zero): for parallel_type in range(5): for async_reducer in [False, True]: - print(f'parallel_type={parallel_type}, async_reducer={async_reducer}') - save_dir = tmp_path/f'{parallel_type}-{async_reducer}' - trainer_correctness_worker(save_dir, parallel_type, async_reducer) + for hybrid_opt in [True, False]: + print(f'parallel_type={parallel_type}, async_reducer={async_reducer}, hybrid_opt={hybrid_opt}') + save_dir = tmp_path/f'{parallel_type}-{async_reducer}-{hybrid_opt}' + trainer_correctness_worker(save_dir, parallel_type, async_reducer, hybrid_opt, use_zero) @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') -def test_trainer_correctness(tmp_path): - launch_torchrun(2, trainer_correctness_worker_aggregate, tmp_path) +@pytest.mark.parametrize('use_zero', [0, 1, 3]) +def test_trainer_correctness(tmp_path, use_zero): + launch_torchrun(2, trainer_correctness_worker_aggregate, tmp_path, use_zero) merged_ckpts = {} for parallel_type in range(5): for async_reducer in [False, True]: - save_dir = tmp_path/f'{parallel_type}-{async_reducer}' - merged_ckpts[(parallel_type, async_reducer)] = torch.load(save_dir/'merged.pt') + for hybrid_opt in [True, False]: + save_dir = tmp_path/f'{parallel_type}-{async_reducer}-{hybrid_opt}' + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)] = torch.load(save_dir/'merged.pt') + + if use_zero == 3: + assert_fn = assert_close + else: + assert_fn = assert_equal for parallel_type in range(5): for async_reducer in [False, True]: - assert_equal( - merged_ckpts[(parallel_type, async_reducer)]['model'], - merged_ckpts[(0, False)]['model'] - ) - assert_equal( - merged_ckpts[(parallel_type, async_reducer)]['optimizer'], - merged_ckpts[(0, False)]['optimizer'] - ) + for hybrid_opt in [True, False]: + assert_fn( + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['model'], + merged_ckpts[(0, False, False)]['model'] + ) + if not hybrid_opt: + assert_fn( + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['optimizer'], + merged_ckpts[(0, False, False)]['optimizer'] + ) + else: + # param_groups are different when using hybrid optimizer. + assert_fn( + merged_ckpts[(parallel_type, async_reducer, hybrid_opt)]['optimizer']['state'], + merged_ckpts[(0, False, False)]['optimizer']['state'] + ) def tracing_from_weights_worker(tmp_path): @@ -936,19 +1011,45 @@ def trainer_resumable_dataloader(save_dir): torch.distributed.barrier() # resume for merged - trainer = Trainer([ - '-f', config_path_streaming, - '--precision', 'bf16', - '--optimizer.type', optimizer_type, - '--enable_progress_bar', 'false', - '--gen_savedir', str(gen_savedir), - '--checkpoint.save_type', save_type, - '--checkpoint.save_dir', str(ckpt2_savedir), - '--checkpoint.resume_from', str(ckpt2_savedir / 'merged.pt'), - '--checkpoint.keep_last_n_checkpoints', '30', - ]) - trainer.run() - assert trainer.dataloader_resumed + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt2_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt2_savedir / 'merged.pt'), + '--checkpoint.resume_from.save_memory', False, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting trimmed checkpoint to all ranks.' not in log.getvalue() # no warning about dataloader states + + torch.distributed.barrier() + + + ckpt2_1_savedir = save_dir / 'ckpt2_1' + ckpt2_1_savedir.mkdir(parents=True, exist_ok=True) + # resume for merged + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt2_1_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt2_savedir / 'merged.pt'), + '--checkpoint.resume_from.save_memory', True, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting trimmed checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states torch.distributed.barrier() @@ -981,20 +1082,44 @@ def trainer_resumable_dataloader(save_dir): '--checkpoint.save_dir', str(ckpt4_savedir), '--checkpoint.resume_from.checkpoint', str(ckpt1_savedir / '0002-0035'), '--checkpoint.resume_from.with_merged', True, + '--checkpoint.resume_from.save_memory', False, '--checkpoint.keep_last_n_checkpoints', '30', ]) trainer.run() assert trainer.dataloader_resumed assert 'Broadcasting merged checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states + # resume from auto-merged with save_memory + ckpt5_savedir = save_dir / 'ckpt5' + with catch_log(logger) as log: + trainer = Trainer([ + '-f', config_path_streaming, + '--precision', 'bf16', + '--optimizer.type', optimizer_type, + '--enable_progress_bar', 'false', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_type', save_type, + '--checkpoint.save_dir', str(ckpt5_savedir), + '--checkpoint.resume_from.checkpoint', str(ckpt1_savedir / '0002-0035'), + '--checkpoint.resume_from.with_merged', True, + '--checkpoint.resume_from.save_memory', True, + '--checkpoint.keep_last_n_checkpoints', '30', + ]) + trainer.run() + assert trainer.dataloader_resumed + assert 'Broadcasting trimmed checkpoint to all ranks.' in log.getvalue() # no warning about dataloader states + + if torch.distributed.get_rank() == 0: for i in range(4): g = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) z = torch.load(ckpt2_savedir / 'last' / f'{i}.ckpt', weights_only=False) + z_1 = torch.load(ckpt2_1_savedir / 'last' / f'{i}.ckpt', weights_only=False) w = torch.load(ckpt3_savedir / 'last' / f'{i}.ckpt', weights_only=False) v = torch.load(ckpt4_savedir / 'last' / f'{i}.ckpt', weights_only=False) + u = torch.load(ckpt5_savedir / 'last' / f'{i}.ckpt', weights_only=False) assert 'dataloader' not in g assert 'dataloader' in x for key in ['model', 'optimizer', 'lr_scheduler', 'dataloader']: @@ -1002,6 +1127,8 @@ def trainer_resumable_dataloader(save_dir): assert_equal(x[key], z[key]) assert_equal(x[key], w[key]) assert_equal(x[key], v[key]) + assert_equal(x[key], u[key]) + assert_equal(x[key], z_1[key]) if key != 'dataloader': assert_equal(g[key], x[key]) @@ -1009,3 +1136,396 @@ def trainer_resumable_dataloader(save_dir): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_trainer_resumable_dataloader(tmp_path): launch_torchrun(4, trainer_resumable_dataloader, tmp_path) + + +@replace_all_device_with('cpu') +def test_trainer_dynamic_worker(tmp_path): + + def check_match(code_dir: Path, should_exist: bool): + gencode_files = list(code_dir.glob('**/*.py')) + assert set(f.name for f in gencode_files) == set(['gencode0.py', 'gencode1.py', 'gencode2.py', 'gencode3.py']) + for gencode_file in gencode_files: + filecontent = gencode_file.read_text() + matches = re.findall(r'B, T, C = x\.size\(\)', filecontent) + if should_exist: + assert matches + else: + assert not matches + + shutil.rmtree(code_dir) + + save_dir = Path(tmp_path) + config_path = str(Path(__file__).with_name('trainer_args_csa.yaml').resolve()) + gen_savedir = save_dir / 'gen' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[1]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + ]) + trainer.run() + check_match(gen_savedir, should_exist=True) + + gen_savedir = save_dir / 'gen0' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + ]) + trainer.run() + check_match(gen_savedir, should_exist=False) + + # mixed compile + gen_savedir = save_dir / 'gen1' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[1]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + + '--model.parallel_modules.0.type', 'tests.cli.common.CausalSelfAttention', + '--model.parallel_modules.0.args.n_embd', '$(model.args.n_embd)', + '--model.parallel_modules.0.args.n_head', '$(model.args.n_head)', + '--model.parallel_modules.0.args.dropout', '$(model.args.dropout)', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.csa_forward_args_gen_fn', + '--model.parallel_modules.0.forward_args_post_process_fn', 'tests.cli.common.post_csa_forward_args_gen_fn', + ]) + trainer.run() + check_match(gen_savedir, should_exist=True) + + # mixed compile + gen_savedir = save_dir / 'gen2' + # compile only + trainer = Trainer([ + '-f', config_path, + '--vars.dynamic_dims', '[]', + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--compute_config.plan_ngpus', '2', + '--compute_config.runtime_ngpus', '4', + '--checkpoint.no_save', 'true', + '--run_mode', 'compile', + '--broadcast_strategy', 'none', + + '--model.parallel_modules.0.type', 'tests.cli.common.CausalSelfAttention', + '--model.parallel_modules.0.args.n_embd', '$(model.args.n_embd)', + '--model.parallel_modules.0.args.n_head', '$(model.args.n_head)', + '--model.parallel_modules.0.args.dropout', '$(model.args.dropout)', + '--model.parallel_modules.0.forward_args_gen_fn', 'tests.cli.common.csa_forward_args_gen_fn', + '--model.parallel_modules.0.forward_args_post_process_fn', 'tests.cli.common.post_csa_forward_args_gen_fn', + ]) + trainer.run() + check_match(gen_savedir, should_exist=False) + + +def trainer_zero3(dim, save_dir, plan_ngpus, runtime_ngpus): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + + zero_ngroups = runtime_ngpus // plan_ngpus // 2 + if zero_ngroups < 1: + zero_ngroups = 1 + policy = 'dp' if plan_ngpus == 1 else 'tp' + + gen3_savedir = save_dir / 'gen3' + ckpt3_savedir = save_dir / 'ckpt3' + # train 1 epcho in one time with zero3 + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '5', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + torch.distributed.barrier() + + # load from sharded + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', 'last', + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'deduped', + ]) + trainer.run() + + # load from deduped + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '15', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', 'last', + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'deduped', + ]) + trainer.run() + + if torch.distributed.get_rank() == 0: + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged.pt') + + torch.distributed.barrier() + + # load from merged (from deduped) + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', str(ckpt3_savedir / 'merged.pt'), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + if torch.distributed.get_rank() == 0: + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged.pt') + + with ( + patch('nnscaler.ComputeConfig.module_dedup_group_size', new_callable=PropertyMock) as mock_dgs, + patch('nnscaler.ComputeConfig.optimizer_dedup_group_size', new_callable=PropertyMock) as mock_dgs2 + ): + # to mock the case where we have duplicated data in merging + mock_dgs.return_value = runtime_ngpus + mock_dgs2.return_value = runtime_ngpus + + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged2.pt') + zero3_merged_state_dict2 = torch.load(ckpt3_savedir / 'merged2.pt') + zero3_merged_state_dict = torch.load(ckpt3_savedir / 'merged.pt') + assert_equal(zero3_merged_state_dict, zero3_merged_state_dict2) + + torch.distributed.barrier() + + # load from merged (from sharded) + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '25', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen3_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '3', + '--compute_config.pas_config.enable_random_replicated', 'True', + '--checkpoint.save_dir', str(ckpt3_savedir), + '--checkpoint.resume_from', str(ckpt3_savedir / 'merged.pt'), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'deduped', + ]) + trainer.run() + + torch.distributed.barrier() + + gen1_savedir = save_dir / 'gen1' + ckpt1_savedir = save_dir / 'ckpt1' + + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '25', + '--vars.dim', f'{dim}', + '--gen_savedir', str(gen1_savedir), + '--compute_config.plan_ngpus', f'{plan_ngpus}', + '--compute_config.runtime_ngpus', f'{runtime_ngpus}', + '--compute_config.zero_ngroups', f'{zero_ngroups}', + '--compute_config.use_zero', '1', + '--compute_config.pas_config.enable_random_replicated', 'True', + '--checkpoint.save_dir', str(ckpt1_savedir), + '--pas_policy', f'{policy}', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + Trainer.merge_checkpoint(list((ckpt1_savedir / 'last').glob('*.ckpt')), ckpt1_savedir / 'merged.pt') + zero1_merged_state_dict = torch.load(ckpt1_savedir / 'merged.pt') + Trainer.merge_checkpoint(list((ckpt3_savedir / 'last').glob('*.ckpt')), ckpt3_savedir / 'merged.pt') + zero3_merged_state_dict = torch.load(ckpt3_savedir / 'merged.pt') + assert_equal(zero1_merged_state_dict['model'], zero3_merged_state_dict['model']) + assert_equal(zero1_merged_state_dict['optimizer'], zero3_merged_state_dict['optimizer']) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_trainer_zero3(tmp_path): + launch_torchrun(2, trainer_zero3, 16, tmp_path, 1, 2) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_trainer_zero3_tp(tmp_path): + launch_torchrun(4, trainer_zero3, 16, tmp_path, 2, 4) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_trainer_zero3_ngroup(tmp_path): + # dim that needs padding + launch_torchrun(4, trainer_zero3, 13, tmp_path, 1, 4) + + +def trainer_checkpointer_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + from nnscaler.cli import register_format + + load_triggered = False + + class TestFormat: + name: str = 'test_format' + suffix: str = '.testpt' + + @classmethod + def save(cls, obj: Any, f: Path) -> None: + obj['test'] = True + return torch.save(obj, f) + + @classmethod + def load(cls, f: str | Path, *, device='cpu') -> Any: + x = torch.load(f, map_location=device, weights_only=False) + assert x['test'] is True + nonlocal load_triggered + load_triggered = True + return x + + register_format(TestFormat) + + # train 1 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '1', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.format', 'test_format', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + trainer.run() + + files0 = list(ckpt_savedir.glob('**/*.testpt')) + assert files0, 'No checkpoint files saved with custom format.' + + # train 1 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.format', 'test_format', + '--checkpoint.resume_from', 'last', + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + trainer.run() + assert load_triggered, 'Custom load function not triggered when resuming.' + + files1 = list(ckpt_savedir.glob('**/*.testpt')) + assert len(files1) > len(files0), 'Checkpoint files not updated after resuming.' + assert all(f in files1 for f in files0), 'Some checkpoint files missing after resuming.' + assert files1, 'No checkpoint files saved with custom format.' + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_custom_checkpointer(tmp_path): + launch_torchrun(1, trainer_checkpointer_worker, tmp_path) + + +def trainer_pas_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + from nnscaler.policies import pas_dp + from nnscaler.cli import TrainerArgs + called = False + + def custom_pas(graph, cfg): + nonlocal called + called = True + return pas_dp(graph, cfg) + + args = TrainerArgs.from_cli([ + '-f', config_path, + '--max_epochs', '1', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + args.pas_policy = custom_pas + # train 1 epcho in one time + trainer = Trainer(train_args=args) + trainer.run() + + assert called, 'Custom PAS policy not called.' + + gen_savedir = save_dir / 'gen2' + # train 1 epcho in one time + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--pas-policy', 'nnscaler.policies.pas_dp', # use full qualified name of pas policy + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--compute_config.plan_ngpus', '1', + '--compute_config.runtime_ngpus', '1', + ]) + trainer.run() + + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') +def test_custom_pas(tmp_path): + launch_torchrun(1, trainer_pas_worker, tmp_path) diff --git a/tests/cli/test_trainer2.py b/tests/cli/test_trainer2.py new file mode 100644 index 00000000..34007486 --- /dev/null +++ b/tests/cli/test_trainer2.py @@ -0,0 +1,134 @@ +from pathlib import Path +import pytest +import torch +from torch.utils.data import Dataset + +from nnscaler.cli import TrainerArgs, Trainer +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode + + +class NanoGptDataset(Dataset): + def __init__(self, *args, **kwargs): + pass + + def __getitems__(self, indices): + return [torch.randint(0, 151936, (1, 4096), dtype=torch.int64) for _ in indices] + + def __len__(self): + return 10000 + + +def gen_args(trainer_args: 'TrainerArgs'): + src_token = torch.randint(0, 151936, (1, 4096), dtype=torch.int64) + ret = dict( + input_ids=src_token, # torch.Size([1, 4096]) torch.int64 + ) + return ret + + +class WrappedSubModel(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + self.embedding = torch.nn.Embedding(151936, 1536) + + def forward(self, input_ids): + x = self.embedding(input_ids) + return x + + +class WrapperModel(torch.nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + self.model = WrappedSubModel() + + def forward(self, src_tokens): + # the logic is from task.train_step + logits = self.model( + src_tokens + ) + return torch.sum(logits) + + +def trainer_mixed_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args_mixed_bf16.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + args = TrainerArgs.from_cli([ + '-f', config_path, + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + ]) + trainer = Trainer(train_args=args) + trainer.run() + # should reach here without error + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_mixed_bf16_model(tmp_path): + launch_torchrun(2, trainer_mixed_worker, tmp_path) + + +class SharedWeightsDataset(Dataset): + def __init__(self, *args, **kwargs): + pass + + def __getitems__(self, indices): + return [torch.randn(4, 4) for _ in indices] + + def __len__(self): + return 10000 + + +class SharedWeightsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + self.linear2 = torch.nn.Linear(4, 4, bias=False) + self.linear2.weight = self.linear.weight # share weight + + def forward(self, x): + y = x * 2 + z = x + 2 + r = self.linear2(y) + r = r + self.linear(z) + return torch.sum(r) + + +def trainer_zero3_shared_weights_worker(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('trainer_args_shared_weights.yaml').resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + args = TrainerArgs.from_cli([ + '-f', config_path, + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + ]) + trainer = Trainer(train_args=args) + trainer.run() + # weight sharing multiref should have clone_level=1 in gencode + assert _gencode_contains( + gen_savedir, + SharedWeightsModule, + torch.distributed.get_rank(), + r'linear_weight_\d+, linear_weight_\d+ = nnscaler.runtime.function.multiref\(self.linear_weight_\d+, times=2, clone_level=1\)' + ) + # non-weight tensor multiref should not have clone_level + assert _gencode_contains( + gen_savedir, + SharedWeightsModule, + torch.distributed.get_rank(), + r'x_\d+, x_\d+ = nnscaler.runtime.function.multiref\(x_\d+, times=2\)' + ) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_zero3_shared_weights(tmp_path): + launch_torchrun(4, trainer_zero3_shared_weights_worker, tmp_path) diff --git a/tests/cli/test_trainer_muon.py b/tests/cli/test_trainer_muon.py new file mode 100644 index 00000000..dbb4af3b --- /dev/null +++ b/tests/cli/test_trainer_muon.py @@ -0,0 +1,114 @@ +from pathlib import Path +import shutil + +import torch + +import pytest + +from nnscaler.cli.trainer import Trainer +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.common import assert_equal + + +try: + from torch.optim import Muon +except ImportError: + pytest.skip("Muon not available", allow_module_level=True) + + + +def trainer_muon_worker(save_dir, config_file): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name(config_file).resolve()) + gen_savedir = save_dir / 'gen' + ckpt_savedir = save_dir / 'ckpt' + + # train first epoch + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '1', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # create merged checkpoint + if trainer.rank == 0: + Trainer.merge_checkpoint(list((ckpt_savedir / 'last').glob('*.ckpt')), ckpt_savedir / 'merged.pt') + + torch.distributed.barrier() + + # train 2nd epoch, resume from merged checkpoint + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '2', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', str(ckpt_savedir / 'merged.pt'), + ]) + trainer.run() + + torch.distributed.barrier() + + # train 3rd epoch, resume from deduped checkpoint + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '3', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', 'last', + '--checkpoint.save_type', 'sharded', + ]) + trainer.run() + + torch.distributed.barrier() + + # train 4th epoch, resume from sharded checkpoint + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '4', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt_savedir), + '--checkpoint.resume_from', 'last', + ]) + trainer.run() + + torch.distributed.barrier() + + ckpt1_savedir = save_dir / 'ckpt1' + # train 4 epoch without resuming + trainer = Trainer([ + '-f', config_path, + '--max_epochs', '4', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + ]) + trainer.run() + + torch.distributed.barrier() + + if trainer.rank == 0: + for i in range(2): + x = torch.load(ckpt_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + for key in ['model', 'optimizer']: + assert_equal(x[key], y[key]) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('config_file', ['trainer_args_muon.yaml', 'trainer_args_muon_hybrid.yaml']) +def test_trainer_muon_resume_correctness(tmp_path, config_file): + launch_torchrun(2, trainer_muon_worker, tmp_path, config_file) + + +def param_clss_fn(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'layers.1.' in param_name or 'layers.10.' in param_name: + return 0, 0 + else: + return 1, 0 diff --git a/tests/cli/trainer_args.yaml b/tests/cli/trainer_args.yaml index db05b3e6..5006b87a 100644 --- a/tests/cli/trainer_args.yaml +++ b/tests/cli/trainer_args.yaml @@ -1,6 +1,7 @@ vars: dim: 16 drop_last: true + compute_config: plan_ngpus: 4 runtime_ngpus: 100 diff --git a/tests/cli/trainer_args_csa.yaml b/tests/cli/trainer_args_csa.yaml new file mode 100644 index 00000000..1a18c6e3 --- /dev/null +++ b/tests/cli/trainer_args_csa.yaml @@ -0,0 +1,53 @@ +vars: + dynamic_dims: [1] + dim: 16 + drop_last: true +compute_config: + plan_ngpus: 4 + runtime_ngpus: 100 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: tp +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 4 +max_train_steps: 100 +seed: 0 +dummy_sample_gen_fn: tests.cli.common.transformer_dummy_sample_gen_fn + +model: + type: tests.cli.common.SimpleTransformerModel + args: + n_embd: 1024 + n_head: 8 + dropout: 0.001 + nlayers: 2 + vocab_size: 10000 + +optimizer: + type: torch.optim.Adam + args: + lr: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: $(vars.dim) + size: 100 + val_args: + dim: $(vars.dim) + size: 10 + +dataloader: + train_args: + drop_last: $(vars.drop_last) + val_args: + drop_last: $(vars.drop_last) + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/cli/trainer_args_mixed_bf16.yaml b/tests/cli/trainer_args_mixed_bf16.yaml new file mode 100644 index 00000000..3ba80cbb --- /dev/null +++ b/tests/cli/trainer_args_mixed_bf16.yaml @@ -0,0 +1,36 @@ +compute_config: + plan_ngpus: 1 + runtime_ngpus: 2 + constant_folding: false + use_zero: 3 + use_end2end: true + +run_mode: run +pas_policy: dp +micro_batch_size: 1 +grad_accumulation_steps: 4 +max_train_steps: 10 +enable_progress_bar: false +log_progress_every_n_train_steps: 10 +precision: bf16 +seed: 1 + +model: + type: tests.cli.test_trainer2.WrapperModel + + parallel_modules: + - type: tests.cli.test_trainer2.WrappedSubModel + forward_args_gen_fn: tests.cli.test_trainer2.gen_args + +optimizer: + type: torch.optim.AdamW + args: + betas: (0.9, 0.95) + eps: 1e-08 + weight_decay: 0.1 + lr: 0.0001 + fused: true + clip_gnorm: 2.0 + +dataset: + type: tests.cli.test_trainer2.NanoGptDataset diff --git a/tests/cli/trainer_args_muon.yaml b/tests/cli/trainer_args_muon.yaml new file mode 100644 index 00000000..ae1ef334 --- /dev/null +++ b/tests/cli/trainer_args_muon.yaml @@ -0,0 +1,51 @@ +vars: + dim: 16 + drop_last: true + +compute_config: + plan_ngpus: 1 + runtime_ngpus: 2 + constant_folding: true + use_zero: 0 + use_end2end: true + +run_mode: run +pas_policy: dp +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 4 +max_train_steps: 100 +seed: 0 +precision: fp32 +enable_progress_bar: false + +model: + type: tests.cli.common.MLP + args: + dim: $(vars.dim) + nlayers: 16 + +optimizer: + type: torch.optim.Muon + args: + lr: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: $(vars.dim) + size: 100 + val_args: + dim: $(vars.dim) + size: 10 + +dataloader: + train_args: + drop_last: $(vars.drop_last) + val_args: + drop_last: $(vars.drop_last) + +checkpoint: + keep_last_n_checkpoints: 5 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/cli/trainer_args_muon_hybrid.yaml b/tests/cli/trainer_args_muon_hybrid.yaml new file mode 100644 index 00000000..8e2f8987 --- /dev/null +++ b/tests/cli/trainer_args_muon_hybrid.yaml @@ -0,0 +1,60 @@ +vars: + dim: 16 + drop_last: true + +compute_config: + plan_ngpus: 1 + runtime_ngpus: 2 + constant_folding: true + use_zero: 0 + use_end2end: true + +run_mode: run +pas_policy: dp +micro_batch_size: 2 +global_batch_size: 8 +max_epochs: 4 +max_train_steps: 100 +seed: 0 +precision: bf16 +enable_progress_bar: false + +model: + type: tests.cli.common.MLP + args: + dim: $(vars.dim) + nlayers: 16 + + +optimizer: + type: nnscaler.HybridOptimizer + param_clss_fn: tests.cli.test_trainer_muon.param_clss_fn + args: + config: + optimizers: + - type: nnscaler.runtime.f16_optimizer.MixedPrecisionAdamW + options: + lr: 0.01 + - type: torch.optim.Muon + options: + lr: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: $(vars.dim) + size: 100 + val_args: + dim: $(vars.dim) + size: 10 + +dataloader: + train_args: + drop_last: $(vars.drop_last) + val_args: + drop_last: $(vars.drop_last) + +checkpoint: + keep_last_n_checkpoints: 5 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/cli/trainer_args_shared_weights.yaml b/tests/cli/trainer_args_shared_weights.yaml new file mode 100644 index 00000000..2dbd6f7a --- /dev/null +++ b/tests/cli/trainer_args_shared_weights.yaml @@ -0,0 +1,32 @@ +compute_config: + plan_ngpus: 2 + runtime_ngpus: 4 + constant_folding: false + use_zero: 3 + use_end2end: true + +run_mode: run +pas_policy: tp +micro_batch_size: 1 +grad_accumulation_steps: 4 +max_train_steps: 10 +enable_progress_bar: false +log_progress_every_n_train_steps: 10 +precision: bf16 +seed: 1 + +model: + type: tests.cli.test_trainer2.SharedWeightsModule + +optimizer: + type: torch.optim.AdamW + args: + betas: (0.9, 0.95) + eps: 1e-08 + weight_decay: 0.1 + lr: 0.0001 + fused: true + clip_gnorm: 2.0 + +dataset: + type: tests.cli.test_trainer2.SharedWeightsDataset diff --git a/tests/conftest.py b/tests/conftest.py index b581d41c..126b05f6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,3 +29,12 @@ def clean_generated_files(): f.unlink() for f in basedir.glob('gencode*.py'): f.unlink() + + +def pytest_collection_modifyitems(session, config, items): + def policy_first(item): + # it is very easy to break policy related tests, so run them first + if item.fspath.basename == 'test_policies.py': + return 0 + return 1 + items.sort(key=policy_first) diff --git a/tests/customized_ops/__init__.py b/tests/customized_ops/__init__.py new file mode 100644 index 00000000..78e3db5e --- /dev/null +++ b/tests/customized_ops/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Ring Attention test module""" \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/__init__.py b/tests/customized_ops/ring_attn/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/customized_ops/ring_attn/configs.py b/tests/customized_ops/ring_attn/configs.py new file mode 100644 index 00000000..ebc7182c --- /dev/null +++ b/tests/customized_ops/ring_attn/configs.py @@ -0,0 +1,268 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Configuration file for ring attention tests. +This file contains predefined test configurations for both correctness and performance testing. +""" + +from dataclasses import dataclass +from typing import List, Tuple, Optional + + +@dataclass +class RingAttnConfig: + """Configuration for ring attention test cases""" + batch_size: int + num_heads: int + head_dim: int + max_seqlen: int + dtype: str = "bf16" + name: str = "" + num_kv_heads: Optional[int] = None # For GQA/MQA support + causal: bool = True # Most attention patterns are causal + window_size: Tuple[int, int] = (-1, -1) # Sliding window attention (-1, -1) means no window + + def __post_init__(self): + # Set num_kv_heads to num_heads if not specified (standard MHA) + if self.num_kv_heads is None: + self.num_kv_heads = self.num_heads + + if not self.name: + gqa_suffix = f"_gqa{self.num_kv_heads}" if self.num_kv_heads != self.num_heads else "" + causal_suffix = "" if self.causal else "_noncausal" + window_suffix = f"_w{self.window_size[0]}-{self.window_size[1]}" if self.window_size != (-1, -1) else "" + self.name = f"b{self.batch_size}_h{self.num_heads}_d{self.head_dim}_s{self.max_seqlen}_{self.dtype}{gqa_suffix}{causal_suffix}{window_suffix}" + + # Generate cu_seqlens for variable length sequences + # Create sequences with different lengths for more realistic testing + seq_lens = [ + self.max_seqlen // 8, # Short sequence + self.max_seqlen // 4, # Medium sequence + self.max_seqlen // 2, # Long sequence + self.max_seqlen - self.max_seqlen // 8 - self.max_seqlen // 4 - self.max_seqlen // 2 # Remaining + ] + self.cu_seqlens = [0] + for seq_len in seq_lens: + self.cu_seqlens.append(self.cu_seqlens[-1] + seq_len) + + @property + def total_tokens(self) -> int: + """Total number of tokens across all sequences""" + return self.cu_seqlens[-1] + + @property + def is_gqa(self) -> bool: + """Check if this is a GQA (Grouped Query Attention) configuration""" + return self.num_kv_heads < self.num_heads + + @property + def is_mqa(self) -> bool: + """Check if this is an MQA (Multi-Query Attention) configuration""" + return self.num_kv_heads == 1 + + @property + def num_groups(self) -> int: + """Number of query heads per KV head (group size)""" + return self.num_heads // self.num_kv_heads + + +# Small test cases for quick correctness validation +SMALL_CONFIGS = { + "tiny": RingAttnConfig(2, 8, 64, 1024, "bf16", "tiny", causal=True), + "small": RingAttnConfig(4, 12, 128, 4096, "bf16", "small", causal=True), + "small_fp16": RingAttnConfig(4, 12, 128, 4096, "fp16", "small_fp16", causal=False), # One non-causal config + "small_window": RingAttnConfig(4, 12, 128, 4096, "bf16", "small_window", causal=True, window_size=(512, 0)), # Sliding window +} + +# Medium test cases for standard testing +MEDIUM_CONFIGS = { + "medium": RingAttnConfig(4, 24, 128, 8192, "bf16", "medium", causal=True), + "medium_large_head": RingAttnConfig(4, 12, 256, 8192, "bf16", "medium_large_head", causal=False), # One non-causal config + "medium_many_heads": RingAttnConfig(4, 32, 128, 8192, "bf16", "medium_many_heads", causal=True), + "medium_fp16": RingAttnConfig(4, 24, 128, 8192, "fp16", "medium_fp16", causal=True), + "medium_window": RingAttnConfig(4, 24, 128, 8192, "bf16", "medium_window", causal=True, window_size=(512, 0)), # Sliding window +} + +# Large test cases for performance benchmarking +LARGE_CONFIGS = { + "large": RingAttnConfig(4, 32, 128, 16384, "bf16", "large", causal=True), + "large_seq": RingAttnConfig(4, 24, 128, 32768, "bf16", "large_seq", causal=True), + "large_head": RingAttnConfig(4, 24, 256, 16384, "bf16", "large_head", causal=False), # One non-causal config + "xlarge": RingAttnConfig(8, 32, 128, 32768, "bf16", "xlarge", causal=True), + "large_window": RingAttnConfig(4, 32, 128, 16384, "bf16", "large_window", causal=True, window_size=(512, 0)), # Sliding window +} + +# Realistic model configurations (kept minimal, most covered by medium/large configs) +MODEL_CONFIGS = { +} + +# GQA (Grouped Query Attention) configurations based on Qwen models +GQA_CONFIGS = { + # Qwen3-235B-A22B: 64 heads, 4 kv_heads, 128 head_dim + "qwen3_235b_a22b": RingAttnConfig( + batch_size=2, + num_heads=64, + head_dim=64, + max_seqlen=16384, + dtype="bf16", + name="qwen3_235b_a22b", + num_kv_heads=4, + causal=True + ), + + # Qwen3-30B-A3B: 40 heads, 8 kv_heads, 128 head_dim + "qwen3_30b_a3b": RingAttnConfig( + batch_size=4, + num_heads=32, + head_dim=64, + max_seqlen=16384, + dtype="bf16", + name="qwen3_30b_a3b", + num_kv_heads=4, + causal=True + ), + + # Qwen3-4B: 32 heads, 4 kv_heads, 80 head_dim + "qwen3_4b": RingAttnConfig( + batch_size=4, + num_heads=32, + head_dim=80, + max_seqlen=16384, + dtype="bf16", + name="qwen3_4b", + num_kv_heads=4, + causal=True + ), + + # Qwen3-32B: 64 heads, 8 kv_heads, 128 head_dim + "qwen3_32b": RingAttnConfig( + batch_size=2, + num_heads=64, + head_dim=128, + max_seqlen=16384, + dtype="bf16", + name="qwen3_32b", + num_kv_heads=8, + causal=True + ), + + # Qwen3-14B: 40 heads, 8 kv_heads, 128 head_dim + "qwen3_14b": RingAttnConfig( + batch_size=4, + num_heads=40, + head_dim=128, + max_seqlen=16384, + dtype="bf16", + name="qwen3_14b", + num_kv_heads=8, + causal=True + ), +} + +# MQA is already covered by medium/large configs, so removed duplicate MQA_CONFIGS + +# Zigzag attention configurations (only supports causal=True and window_size=(-1, -1)) +ZIGZAG_CONFIGS = { + "zigzag_tiny": RingAttnConfig(2, 8, 64, 1024, "bf16", "zigzag_tiny", causal=True, window_size=(-1, -1)), + "zigzag_small": RingAttnConfig(4, 12, 128, 4096, "bf16", "zigzag_small", causal=True, window_size=(-1, -1)), + "zigzag_medium": RingAttnConfig(4, 24, 128, 8192, "bf16", "zigzag_medium", causal=True, window_size=(-1, -1)), + "zigzag_large": RingAttnConfig(4, 32, 128, 16384, "bf16", "zigzag_large", causal=True, window_size=(-1, -1)), + "zigzag_fp16": RingAttnConfig(4, 12, 128, 4096, "fp16", "zigzag_fp16", causal=True, window_size=(-1, -1)), + "zigzag_gqa": RingAttnConfig(4, 32, 128, 8192, "bf16", "zigzag_gqa", num_kv_heads=8, causal=True, window_size=(-1, -1)), +} + +# All configurations combined +ALL_CONFIGS = { + **SMALL_CONFIGS, + **MEDIUM_CONFIGS, + **LARGE_CONFIGS, + **MODEL_CONFIGS, + **GQA_CONFIGS, + **ZIGZAG_CONFIGS, +} + +# Default configurations for different test types +DEFAULT_CORRECTNESS_CONFIGS = ["tiny", "small", "medium"] +DEFAULT_PERFORMANCE_CONFIGS = ["medium", "large"] +DEFAULT_MULTI_GPU_CONFIGS = ["small", "medium"] +DEFAULT_GQA_CONFIGS = ["qwen3_4b", "qwen3_14b", "qwen3_32b"] +DEFAULT_ZIGZAG_CONFIGS = ["zigzag_tiny", "zigzag_small", "zigzag_medium"] + + +def get_config(name: str) -> RingAttnConfig: + """Get a configuration by name""" + if name in ALL_CONFIGS: + return ALL_CONFIGS[name] + else: + raise ValueError(f"Unknown configuration: {name}. Available: {list(ALL_CONFIGS.keys())}") + + +def list_configs(category: str = "all") -> List[str]: + """List available configurations by category""" + if category == "all": + return list(ALL_CONFIGS.keys()) + elif category == "small": + return list(SMALL_CONFIGS.keys()) + elif category == "medium": + return list(MEDIUM_CONFIGS.keys()) + elif category == "large": + return list(LARGE_CONFIGS.keys()) + elif category == "model": + return list(MODEL_CONFIGS.keys()) + elif category == "gqa": + return list(GQA_CONFIGS.keys()) + elif category == "zigzag": + return list(ZIGZAG_CONFIGS.keys()) + elif category == "correctness": + return DEFAULT_CORRECTNESS_CONFIGS + elif category == "performance": + return DEFAULT_PERFORMANCE_CONFIGS + elif category == "multi_gpu": + return DEFAULT_MULTI_GPU_CONFIGS + elif category == "gqa_default": + return DEFAULT_GQA_CONFIGS + elif category == "zigzag_default": + return DEFAULT_ZIGZAG_CONFIGS + else: + raise ValueError(f"Unknown category: {category}") + + +def get_configs_by_category(category: str) -> dict: + """Get all configurations in a category""" + config_names = list_configs(category) + return {name: get_config(name) for name in config_names} + + +def get_gqa_configs() -> dict: + """Get all GQA (Grouped Query Attention) configurations""" + return {name: config for name, config in ALL_CONFIGS.items() if config.is_gqa and not config.is_mqa} + + +def get_mqa_configs() -> dict: + """Get all MQA (Multi-Query Attention) configurations""" + return {name: config for name, config in ALL_CONFIGS.items() if config.is_mqa} + + +def get_mha_configs() -> dict: + """Get all MHA (Multi-Head Attention) configurations""" + return {name: config for name, config in ALL_CONFIGS.items() if not config.is_gqa} + + +def get_zigzag_configs() -> dict: + """Get all Zigzag attention configurations""" + return ZIGZAG_CONFIGS + + +def filter_configs_by_attention_type(attention_type: str) -> dict: + """Filter configurations by attention type: 'mha', 'gqa', 'mqa', or 'zigzag'""" + if attention_type.lower() == "mha": + return get_mha_configs() + elif attention_type.lower() == "gqa": + return get_gqa_configs() + elif attention_type.lower() == "mqa": + return get_mqa_configs() # Will return empty dict since no dedicated MQA configs + elif attention_type.lower() == "zigzag": + return get_zigzag_configs() + else: + raise ValueError(f"Unknown attention type: {attention_type}. Supported: 'mha', 'gqa', 'mqa', 'zigzag'") \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/ring_attn_runner.py b/tests/customized_ops/ring_attn/ring_attn_runner.py new file mode 100644 index 00000000..405c61b1 --- /dev/null +++ b/tests/customized_ops/ring_attn/ring_attn_runner.py @@ -0,0 +1,118 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Correctness Test Runner Script + +This script runs ring attention correctness tests in a distributed environment. +It compares the outputs of single-GPU and multi-GPU ring attention to ensure correctness. +""" + +import sys +from typing import Tuple +import torch + +from runner_base import RingAttnRunnerBase +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_func + + +class TestModule(torch.nn.Module): + """Test module for ring attention""" + def __init__(self, causal=True, window_size=(-1, -1)): + super(TestModule, self).__init__() + self.causal = causal + self.window_size = window_size + + def forward(self, q, k, v): + result = wrap_ring_attn_func( + q, k, v, + causal=self.causal, + window_size=self.window_size + ) + return result + + +class RingAttnRunner(RingAttnRunnerBase): + """Runner for ring attention tests""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn.wrap_ring_attn_func' + + @property + def partition_position(self) -> Tuple[int, int]: + return 0, 1 + + @property + def function_name(self) -> str: + return 'wrap_ring_attn_func' + + def create_test_module(self, config) -> torch.nn.Module: + return TestModule(causal=config.causal, window_size=config.window_size) + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare regular inputs with shape [batch_size, seq_len, num_heads, head_dim]""" + q = torch.clamp(torch.randn( + config.batch_size, + config.max_seqlen, + config.num_heads, + config.head_dim, + device=device, + dtype=torch_dtype + ), min=-1, max=1) + + k = torch.clamp(torch.randn( + config.batch_size, + config.max_seqlen, + config.num_kv_heads, + config.head_dim, + device=device, + dtype=torch_dtype + ), min=-1, max=1) + + v = torch.clamp(torch.randn( + config.batch_size, + config.max_seqlen, + config.num_kv_heads, + config.head_dim, + device=device, + dtype=torch_dtype + ), min=-1, max=1) + + return {'q': q, 'k': k, 'v': v} + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + # Run single GPU version (this should call flash_attn internally when no process_group) + single_out = wrap_ring_attn_func( + inputs['q'], inputs['k'], inputs['v'], + causal=config.causal, + window_size=config.window_size + ) + return single_out, [inputs['q'], inputs['k'], inputs['v']] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization""" + return { + "q": inputs["q"], + "k": inputs["k"], + "v": inputs["v"], + } + + +def ring_attn_test(dtype="bf16", config_name="tiny", **kwargs): + """Pure test function that can be used with torchrun""" + runner = RingAttnRunner() + return runner.run_correctness_test(dtype=dtype, config_name=config_name, **kwargs) + + +def run_correctness_test(**kwargs): + """Legacy function for backward compatibility""" + runner = RingAttnRunner() + runner.run_correctness_test(**kwargs) + + +if __name__ == "__main__": + kwargs = dict(arg.split("=") for arg in sys.argv[1:]) + runner = RingAttnRunner() + runner.main(**kwargs) \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py b/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py new file mode 100644 index 00000000..2ca40313 --- /dev/null +++ b/tests/customized_ops/ring_attn/ring_attn_varlen_runner.py @@ -0,0 +1,107 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Variable Length Correctness Test Runner + +This script runs ring attention variable length correctness tests in a distributed environment. +It compares the outputs of single-GPU and multi-GPU ring attention to ensure correctness. +""" + +import sys +from typing import Tuple +import torch + +from runner_base import RingAttnRunnerBase +from nnscaler.customized_ops.ring_attention import wrap_ring_attn_varlen_func + + +class TestModule(torch.nn.Module): + def __init__(self, causal=True, window_size=(-1, -1)): + super(TestModule, self).__init__() + self.causal = causal + self.window_size = window_size + + def forward(self, q, k, v, cu_seqlens_q, cu_seqlens_k): + out = wrap_ring_attn_varlen_func( + q, k, v, cu_seqlens_q, cu_seqlens_k, None, + causal=self.causal, + window_size=self.window_size + ) + return out + + +class RingAttnVarlenRunner(RingAttnRunnerBase): + """Runner for ring attention variable length tests""" + + @property + def function_signature(self) -> str: + return 'nnscaler.customized_ops.ring_attention.ring_attn_varlen.wrap_ring_attn_varlen_func' + + @property + def partition_position(self) -> Tuple[int, int]: + return 0, 0 + + @property + def function_name(self) -> str: + return 'ring_attn_varlen_func' + + def create_test_module(self, config) -> torch.nn.Module: + return TestModule(causal=config.causal, window_size=config.window_size) + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare variable length inputs with cu_seqlens""" + cu_seqlens_tensor = torch.tensor(config.cu_seqlens, dtype=torch.int32, device=device) + total_seqlen = config.cu_seqlens[-1] + + # Create inputs with total sequence length (don't set requires_grad here, base class handles it) + q = torch.clamp(torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + k = torch.clamp(torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + v = torch.clamp(torch.randn(total_seqlen, config.num_heads, config.head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + + return { + 'q': q, + 'k': k, + 'v': v, + 'cu_seqlens_q': cu_seqlens_tensor, + 'cu_seqlens_k': cu_seqlens_tensor + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + single_out = wrap_ring_attn_varlen_func( + inputs['q'], inputs['k'], inputs['v'], + inputs['cu_seqlens_q'], inputs['cu_seqlens_k'], None, + causal=config.causal, + window_size=config.window_size + ) + single_out.retain_grad() + return single_out, [inputs['q'], inputs['k'], inputs['v']] + + def get_dummy_forward_args(self, inputs): + """Get dummy forward arguments for model parallelization""" + return { + "q": inputs["q"], + "k": inputs["k"], + "v": inputs["v"], + 'cu_seqlens_q': inputs['cu_seqlens_q'], + 'cu_seqlens_k': inputs['cu_seqlens_k'] + } + + +def ring_attn_varlen_test(dtype="bf16", config_name="tiny", **kwargs): + """Pure test function that can be used with torchrun""" + runner = RingAttnVarlenRunner() + return runner.run_correctness_test(dtype=dtype, config_name=config_name, **kwargs) + + +def run_ring_attn_correctness_test(**kwargs): + """Legacy function for backward compatibility""" + runner = RingAttnVarlenRunner() + runner.run_correctness_test(**kwargs) + + +if __name__ == "__main__": + kwargs = dict(arg.split("=") for arg in sys.argv[1:]) + runner = RingAttnVarlenRunner() + runner.main(**kwargs) diff --git a/tests/customized_ops/ring_attn/runner_base.py b/tests/customized_ops/ring_attn/runner_base.py new file mode 100644 index 00000000..e7d9a32f --- /dev/null +++ b/tests/customized_ops/ring_attn/runner_base.py @@ -0,0 +1,301 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Base runner framework for ring attention correctness tests. +This module provides common functionality for both ring_attn and ring_attn_varlen test runners. +""" + +import os +import sys +import logging +from abc import ABC, abstractmethod +from typing import Dict, Any, Tuple, Union + +import torch +import torch.distributed as dist +import nnscaler +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType + +from nnscaler.customized_ops.ring_attention.core.utils import set_seed, log +from configs import get_config + + +class RingAttnRunnerBase(ABC): + """Base class for ring attention test runners""" + + @property + @abstractmethod + def function_signature(self) -> str: + """Return the function signature to look for in the graph""" + pass + + @property + @abstractmethod + def partition_position(self) -> Tuple[int, int]: + """Return the partition position (idx, dim)""" + pass + + @property + @abstractmethod + def function_name(self) -> str: + """Return the function name for partitioning""" + pass + + @abstractmethod + def create_test_module(self, config) -> torch.nn.Module: + """Create the test module with the appropriate configuration""" + pass + + @abstractmethod + def prepare_inputs(self, config, device, torch_dtype): + """Prepare input tensors based on the configuration and attention type""" + pass + + @abstractmethod + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + pass + + @abstractmethod + def get_dummy_forward_args(self, inputs) -> Dict[str, Any]: + """Get dummy forward arguments for model parallelization""" + pass + + def create_policy(self) -> callable: + """Create partitioning policy for the specific attention type""" + def policy(graph: IRGraph, resource: ComputeConfig) -> IRGraph: + ngpus = resource.plan_ngpus + partitioned = False + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if not partitioned and node.signature == self.function_signature: + print(f'\nPartitioned node: {node}\n') + idx, dim = self.partition_position + sub_nodes = graph.partition(node, node.algorithm('dim'), idx=idx, dim=dim, num=ngpus) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + if not partitioned: + signatures = [node.signature for node in graph.select(ntype=IRFwOperation)] + raise RuntimeError(f"Failed to find the target function '{self.function_signature}' in {signatures}") + return graph + return policy + + def initialize_distributed(self): + """Initialize distributed environment""" + # Check CUDA availability first + if not torch.cuda.is_available(): + print("ERROR: CUDA is not available") + sys.exit(1) + + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + + # Check if we have enough GPUs + available_gpus = torch.cuda.device_count() + if available_gpus < world_size: + print(f"ERROR: Test requires {world_size} GPUs, but only {available_gpus} available") + sys.exit(1) + + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + device_count = torch.cuda.device_count() + device = rank % device_count + try: + torch.cuda.set_device(device) + except Exception as e: + print(f"ERROR: Failed to set CUDA device {device}: {e}") + sys.exit(1) + + print(f"[INFO] world_size:{world_size}, rank:{rank}, available_gpus:{available_gpus}") + + try: + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + except Exception as e: + print(f"ERROR: Failed to initialize process group: {e}") + sys.exit(1) + + # Initialize nnscaler + nnscaler.init() + return world_size, rank + + def get_tolerances(self, dtype: str, num_heads: int, num_kv_heads: int) -> Dict[str, float]: + """Get tolerance values based on data type""" + if dtype == "bf16": + if num_heads == num_kv_heads: + return dict(atol=2.5e-2, rtol=2.5e-2) + else: + return dict(atol=3.5e-2, rtol=3.5e-2) + elif dtype == "fp16": + return dict(atol=5e-3, rtol=5e-3) + else: + return dict(atol=2.5e-2, rtol=2.5e-2) + + def print_debug_info(self, single_out, para_out, single_grads, para_grads, rank_id): + """Print debug information when correctness test fails""" + if rank_id == 0: + print("โœ— Correctness test FAILED!") + # Print detailed error information + log("single out", single_out, rank0_only=True) + log("multi out", para_out, rank0_only=True) + log("out diff", single_out - para_out, rank0_only=True) + + for i, (single_grad, para_grad, name) in enumerate(zip(single_grads, para_grads, ['q', 'k', 'v'])): + log(f"single d{name}", single_grad, rank0_only=True) + log(f"multi d{name}", para_grad, rank0_only=True) + log(f"d{name} diff", single_grad - para_grad, rank0_only=True) + + def print_success_info(self, rank_id, config_name=None): + """Print success information""" + if rank_id == 0: + config_suffix = f" for config '{config_name}'" if config_name else "" + print(f"โœ“ Correctness test PASSED{config_suffix}!") + + def run_correctness_test(self, config_name: str, dtype: str = "bf16", **kwargs): + """Run correctness test with the specific attention implementation""" + # Initialize distributed + world_size, rank = self.initialize_distributed() + rank_id = torch.distributed.get_rank() + + # Get configuration + config = get_config(config_name) + torch_dtype = torch.bfloat16 if dtype == "bf16" else torch.float16 + + if rank_id == 0: + print(f"Testing {self.function_name} correctness") + print(f"Configuration: {config.name}") + print(f" Batch size: {config.batch_size}") + print(f" Sequence length: {config.max_seqlen}") + print(f" Num heads: {config.num_heads}") + print(f" KV heads: {config.num_kv_heads}") + print(f" Head dim: {config.head_dim}") + print(f" Data type: {dtype}") + print(f" World size: {world_size}") + print("=" * 60) + + # Set seed for reproducibility + set_seed(42 + rank_id) + device = torch.device(f"cuda:{rank_id}") + + # Prepare inputs (implementation-specific) + inputs = self.prepare_inputs(config, device, torch_dtype) + + # Broadcast inputs to ensure consistency across ranks + for tensor in inputs.values(): + if isinstance(tensor, torch.Tensor): + dist.broadcast(tensor, src=0) + dist.barrier() + + # Setup models + model = self.create_test_module(config) + + # Create parallel model + dummy_args = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + dummy_args[k] = v.detach().clone().requires_grad_() + else: + dummy_args[k] = v.detach().clone() + else: + dummy_args[k] = v + + parallel_model = parallelize( + model, + dummy_forward_args=self.get_dummy_forward_args(dummy_args), + pas_policy=self.create_policy(), + compute_config=ComputeConfig(world_size, world_size), + reuse=ReuseType.OVERRIDE + ) + parallel_model = parallel_model.cuda() + parallel_model.train() + + # Run correctness test + print("Running correctness test..." if rank_id == 0 else "", end="") + + # Single mode for reference + single_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + single_inputs[k] = v.detach().clone().requires_grad_() + else: + single_inputs[k] = v.detach().clone() + else: + single_inputs[k] = v + + single_out, single_grad_tensors = self.run_single_gpu_reference(single_inputs, config) + + # Create gradient for backward pass + dout = torch.clamp(torch.randn_like(single_out, device=device, dtype=torch_dtype), min=-1, max=1) + # Ensure dout is consistent across all ranks + dist.broadcast(dout, src=0) + single_out.backward(dout) + + # Extract single gradients + single_grads = [tensor.grad for tensor in single_grad_tensors] + + # Parallel mode for correctness + para_inputs = {} + for k, v in inputs.items(): + if isinstance(v, torch.Tensor): + if v.is_floating_point(): + para_inputs[k] = v.detach().clone().requires_grad_() + else: + para_inputs[k] = v.detach().clone() + else: + para_inputs[k] = v + + para_out = parallel_model(**para_inputs) + para_out.backward(dout) + parallel_model.sync_grad() + + # Extract gradients for q, k, v tensors + para_grads = [para_inputs[k].grad for k in ['q', 'k', 'v']] + + print(" Done!" if rank_id == 0 else "") + + # Check correctness with tolerances + tols = self.get_tolerances(dtype, config.num_heads, config.num_kv_heads) + + # Verify outputs and gradients + try: + torch.testing.assert_close(single_out, para_out, **tols) + for single_grad, para_grad in zip(single_grads, para_grads): + torch.testing.assert_close(single_grad, para_grad, **tols) + + self.print_success_info(rank_id, config_name) + + except AssertionError as e: + self.print_debug_info(single_out, para_out, single_grads, para_grads, rank_id) + raise e + + dist.destroy_process_group() + + def main(self, **kwargs): + """Main entry point for the test runner""" + # Filter out torch.distributed.launch arguments + filtered_kwargs = {} + for k, v in kwargs.items(): + if k.startswith('--'): + # Remove leading '--' from argument names + k = k[2:].replace('-', '_') + if k not in ['local_rank', 'local-rank']: # Filter out torch.distributed.launch args + filtered_kwargs[k] = v + + # Convert string arguments back to appropriate types + for numeric_arg in ['batch_size', 'num_heads', 'head_dim', 'max_seqlen']: + if numeric_arg in filtered_kwargs and filtered_kwargs[numeric_arg] is not None: + filtered_kwargs[numeric_arg] = int(filtered_kwargs[numeric_arg]) + + for float_arg in ['rtol', 'atol']: + if float_arg in filtered_kwargs and filtered_kwargs[float_arg] is not None: + filtered_kwargs[float_arg] = float(filtered_kwargs[float_arg]) + + self.run_correctness_test(**filtered_kwargs) \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_base.py b/tests/customized_ops/ring_attn/test_base.py new file mode 100644 index 00000000..44870792 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_base.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Base test framework for ring attention tests. +This module provides common functionality for both ring_attn and ring_attn_varlen tests. +""" + +import os +import sys +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Tuple +from functools import partial + +import pytest +import torch + +from .configs import ( + DEFAULT_CORRECTNESS_CONFIGS, + DEFAULT_MULTI_GPU_CONFIGS, + DEFAULT_GQA_CONFIGS, + get_config, + list_configs +) + +from ...launch_torchrun import torchrun + + +class RingAttnTestBase(ABC): + """Base class for ring attention tests""" + + @property + @abstractmethod + def runner_script_name(self) -> str: + """Return the name of the runner script (e.g., 'run_correctness.py')""" + pass + + @property + @abstractmethod + def test_name_prefix(self) -> str: + """Return the prefix for test names (e.g., 'ring_attn' or 'ring_attn_varlen')""" + pass + + @property + @abstractmethod + def test_function_name(self) -> str: + """Return the name of the test function to import (e.g., 'zigzag_attn_test')""" + pass + + def _check_gpu_availability(self, required_gpus: int): + """Check if enough GPUs are available and skip test if not""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is not available") + + available_gpus = torch.cuda.device_count() + if available_gpus < required_gpus: + pytest.skip(f"Test requires {required_gpus} GPUs, but only {available_gpus} available") + + def _get_project_root(self): + """Get the absolute path to nnscaler root directory""" + current_dir = os.path.dirname(__file__) # tests/customized_ops/ring_attn/ + return os.path.abspath(os.path.join(current_dir, "../../../")) + + def get_bash_arguments(self, num_gpus_per_node: int, **kwargs) -> List[str]: + """Generate command line arguments for running the test script + + Deprecated: This method is kept for backward compatibility. + The new implementation uses launch_torchrun directly. + """ + args = [ + "python3", + "-m", + "torch.distributed.launch", + "--nproc-per-node=" + str(num_gpus_per_node), + ] + + project_root = self._get_project_root() + script_path = os.path.join( + project_root, "tests", "customized_ops", "ring_attn", + self.runner_script_name + ) + args.append(script_path) + + for k, v in kwargs.items(): + args.append(f"{k}={v}") + return args + + def _get_test_function(self): + """Get the test function for this test""" + # Add the script directory to sys.path to allow imports + project_root = self._get_project_root() + script_dir = os.path.join(project_root, "tests", "customized_ops", "ring_attn") + if script_dir not in sys.path: + sys.path.insert(0, script_dir) + + # Import the module and get the test function + module_name = self.runner_script_name.replace('.py', '') + module = __import__(module_name) + + if hasattr(module, self.test_function_name): + return getattr(module, self.test_function_name) + else: + raise ImportError(f"Could not find function '{self.test_function_name}' in {module_name}") + + def run_test_subprocess(self, num_gpus: int, **kwargs): + """Run test using torchrun with the configured test function""" + # Check GPU availability before running test + self._check_gpu_availability(num_gpus) + + # Get the test function and use torchrun to execute it + test_function = self._get_test_function() + + # Extract common parameters + dtype = kwargs.get('dtype', 'bf16') + config_name = kwargs.get('config_name', 'tiny') + + # Use partial with positional arguments like test_gnorm.py + return partial(torchrun, num_gpus, test_function, dtype, config_name)() + + # Common test methods that can be used by both ring_attn and ring_attn_varlen + + def run_correctness_basic(self, dtype: str, config_name: str): + """Test correctness with different configurations""" + num_gpus = 2 # Default to 2 GPUs for correctness tests + config = get_config(config_name) + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + def run_multi_gpu_scaling(self, num_gpus: int, config_name: str): + """Test with different numbers of GPUs""" + self.run_test_subprocess( + num_gpus=num_gpus, + dtype="bf16", + config_name=config_name, + ) + + def run_comprehensive_configs(self, dtype: str): + """Test all available configurations (comprehensive test)""" + num_gpus = 2 + + # Test a selection of configurations + test_configs = ["tiny", "small", "medium"] + + for config_name in test_configs: + config = get_config(config_name) + # Skip very large configs for comprehensive test + if config.max_seqlen > 16384: + continue + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + def run_gqa_correctness(self, dtype: str, config_name: str): + """Test GQA correctness with Qwen model configurations""" + num_gpus = 2 + config = get_config(config_name) + + # Ensure it's actually a GQA config + assert config.is_gqa, f"Configuration {config_name} should be GQA" + assert config.num_kv_heads < config.num_heads, f"Configuration {config_name} should have fewer KV heads" + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + def run_sliding_window(self, dtype: str, config_name: str): + """Test with sliding window configurations""" + num_gpus = 2 + config = get_config(config_name) + + # Ensure it's actually a sliding window config + assert config.window_size != (-1, -1), f"Configuration {config_name} should have sliding window" + + self.run_test_subprocess( + num_gpus=num_gpus, + dtype=dtype, + config_name=config_name, + ) + + +def create_parametrized_tests(test_class: RingAttnTestBase): + """ + Factory function to create parametrized test methods for a test class. + This reduces code duplication between ring_attn and ring_attn_varlen tests. + """ + + # Correctness tests with different dtypes and configs + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + @pytest.mark.parametrize("config_name", DEFAULT_CORRECTNESS_CONFIGS) + def test_correctness(dtype, config_name): + """Test correctness with different configurations""" + instance = test_class() + instance.run_correctness_basic(dtype, config_name) + + # Multi-GPU tests + @pytest.mark.parametrize("num_gpus", [2, 4]) + @pytest.mark.parametrize("config_name", DEFAULT_MULTI_GPU_CONFIGS) + def test_multi_gpu(num_gpus, config_name): + """Test with different numbers of GPUs""" + instance = test_class() + instance.run_multi_gpu_scaling(num_gpus, config_name) + + # Comprehensive tests + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + def test_all_configs(dtype): + """Test all available configurations (comprehensive test)""" + instance = test_class() + instance.run_comprehensive_configs(dtype) + + # GQA tests + @pytest.mark.parametrize("dtype", ["bf16"]) + @pytest.mark.parametrize("config_name", DEFAULT_GQA_CONFIGS) + def test_gqa_correctness(dtype, config_name): + """Test GQA correctness with Qwen model configurations""" + instance = test_class() + instance.run_gqa_correctness(dtype, config_name) + + # Sliding window tests + @pytest.mark.parametrize("dtype", ["bf16"]) + @pytest.mark.parametrize("config_name", ["small_window", "medium_window"]) + def test_sliding_window(dtype, config_name): + """Test with sliding window configurations""" + instance = test_class() + instance.run_sliding_window(dtype, config_name) + + return { + f'test_{test_class().test_name_prefix}_correctness': test_correctness, + f'test_{test_class().test_name_prefix}_multi_gpu': test_multi_gpu, + f'test_{test_class().test_name_prefix}_all_configs': test_all_configs, + f'test_{test_class().test_name_prefix}_gqa_correctness': test_gqa_correctness, + f'test_{test_class().test_name_prefix}_sliding_window': test_sliding_window, + } \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_ring_attn.py b/tests/customized_ops/ring_attn/test_ring_attn.py new file mode 100644 index 00000000..e1378101 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_ring_attn.py @@ -0,0 +1,60 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Correctness Tests + +This module tests the correctness of regular ring attention (non-variable length). +It uses the shared test base framework to avoid code duplication. +""" + +import pytest +import torch + +# Skip all tests if flash_attn_func is not available +try: + from flash_attn import flash_attn_func +except ImportError: + pytest.skip("flash_attn_func not available", allow_module_level=True) + +from .test_base import RingAttnTestBase, create_parametrized_tests +from .configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS + + +class RingAttnTest(RingAttnTestBase): + """Test class for regular ring attention""" + + @property + def runner_script_name(self) -> str: + return "ring_attn_runner.py" + + @property + def test_function_name(self) -> str: + return "ring_attn_test" + + @property + def test_name_prefix(self) -> str: + return "ring_attn" + + +# Create parametrized test functions using the factory +test_functions = create_parametrized_tests(RingAttnTest) + +# Assign test functions to module globals for pytest discovery +test_ring_attn_correctness = test_functions['test_ring_attn_correctness'] +test_ring_attn_multi_gpu = test_functions['test_ring_attn_multi_gpu'] +test_ring_attn_all_configs = test_functions['test_ring_attn_all_configs'] +test_ring_attn_gqa_correctness = test_functions['test_ring_attn_gqa_correctness'] +test_ring_attn_sliding_window = test_functions['test_ring_attn_sliding_window'] + + +if __name__ == "__main__": + # Run specific test if called directly + test_instance = RingAttnTest() + test_instance.run_correctness_basic("bf16", "small") + + # Example of running GQA test + # test_instance.run_gqa_correctness("bf16", "qwen3_4b") + + # Example of running sliding window test + # test_instance.run_sliding_window("bf16", "small_window") \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_ring_attn_varlen.py b/tests/customized_ops/ring_attn/test_ring_attn_varlen.py new file mode 100644 index 00000000..f86fc276 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_ring_attn_varlen.py @@ -0,0 +1,57 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ring Attention Variable Length Correctness Tests + +This module tests the correctness of ring attention with variable length sequences. +It uses the shared test base framework to avoid code duplication. +""" + +import pytest +import torch + +# Skip all tests if flash_attn_varlen_func is not available +try: + from flash_attn import flash_attn_varlen_func +except ImportError: + pytest.skip("flash_attn_varlen_func not available", allow_module_level=True) + +from .test_base import RingAttnTestBase, create_parametrized_tests +from .configs import DEFAULT_CORRECTNESS_CONFIGS, DEFAULT_MULTI_GPU_CONFIGS, DEFAULT_GQA_CONFIGS + + +class RingAttnVarlenTest(RingAttnTestBase): + """Test class for ring attention variable length""" + + @property + def runner_script_name(self) -> str: + return "ring_attn_varlen_runner.py" + + @property + def test_function_name(self) -> str: + return "ring_attn_varlen_test" + + @property + def test_name_prefix(self) -> str: + return "ring_attn_varlen" + + +# Create parametrized test functions using the factory +test_functions = create_parametrized_tests(RingAttnVarlenTest) + +# Assign test functions to module globals for pytest discovery +test_ring_attn_varlen_correctness = test_functions['test_ring_attn_varlen_correctness'] +test_ring_attn_varlen_multi_gpu = test_functions['test_ring_attn_varlen_multi_gpu'] +test_ring_attn_varlen_all_configs = test_functions['test_ring_attn_varlen_all_configs'] +test_ring_attn_varlen_gqa_correctness = test_functions['test_ring_attn_varlen_gqa_correctness'] +test_ring_attn_varlen_sliding_window = test_functions['test_ring_attn_varlen_sliding_window'] + + +if __name__ == "__main__": + # Run specific test if called directly + test_instance = RingAttnVarlenTest() + test_instance.run_correctness_basic("bf16", "small") + + # Example of running GQA test + # test_instance.run_gqa_correctness("bf16", "qwen3_4b") diff --git a/tests/customized_ops/ring_attn/test_shuffle_varlen.py b/tests/customized_ops/ring_attn/test_shuffle_varlen.py new file mode 100644 index 00000000..a8d6b3c9 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_shuffle_varlen.py @@ -0,0 +1,220 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Simple test for shuffle_varlen and unshuffle_varlen functions. +""" + +import pytest +import torch +import torch.distributed as dist +from dataclasses import dataclass +from typing import List +from functools import partial + +from tests.launch_torchrun import torchrun + + +# Skip all tests if flash_attn_func is not available +try: + from flash_attn import flash_attn_func +except ImportError: + pytest.skip("flash_attn_func not available", allow_module_level=True) + + +@dataclass +class ShuffleVarlenConfig: + """Simple test configuration""" + name: str + batch_size: int + seq_lens: List[int] + hidden_dim: int + + +# Test configurations +CONFIGS = { + "tiny": ShuffleVarlenConfig("tiny", 2, [512, 768], 64), + "small": ShuffleVarlenConfig("small", 2, [1024, 1536], 128), + "medium": ShuffleVarlenConfig("medium", 2, [1024, 1536], 256), + "uneven": ShuffleVarlenConfig("uneven", 3, [256, 768, 1024], 128), +} + + +def shuffle_varlen_test(config_name="tiny", dtype="float32", world_size=2): + """Test shuffle_varlen and unshuffle_varlen functions""" + + if not dist.is_initialized(): + dist.init_process_group(backend='nccl') + + rank = dist.get_rank() + world_size_actual = dist.get_world_size() + device = torch.device(f'cuda:{rank}') + torch.cuda.set_device(device) + + if rank == 0: + print(f"Testing shuffle_varlen and unshuffle_varlen functions") + print(f"Configuration: {config_name}") + print(f"World size: {world_size_actual}") + print(f"Data type: {dtype}") + print("=" * 60) + + # Get configuration + config = CONFIGS[config_name] + + # Set up process group for context parallel + cp_ranks = list(range(world_size_actual)) + cp_group = dist.new_group(cp_ranks) + + # Create cumulative sequence lengths (padded to be divisible by 2*world_size) + cu_seqlens = torch.zeros(config.batch_size + 1, dtype=torch.int32, device=device) + total_slices_per_seq = 2 * world_size_actual + + for i, seq_len in enumerate(config.seq_lens): + # Pad sequence length to be divisible by total_slices_per_seq + padded_seq_len = ((seq_len + total_slices_per_seq - 1) // total_slices_per_seq) * total_slices_per_seq + cu_seqlens[i + 1] = cu_seqlens[i] + padded_seq_len + + total_seq_len = cu_seqlens[len(config.seq_lens)].item() # Use len(config.seq_lens) instead of -1 + + # Convert dtype string to torch dtype + torch_dtype = getattr(torch, dtype) + + # Import functions from varlen_utils + from nnscaler.customized_ops.ring_attention.varlen_utils import shuffle_varlen, unshuffle_varlen + + if rank == 0: + print("Running shuffle/unshuffle correctness tests...") + + tolerance = 1e-5 if torch_dtype == torch.float32 else 1e-2 + + # Test 1: 1D tensor (like position_ids) + if rank == 0: + print(" Test: 1D tensor (total_seq_len,)...") + + try: + # Create full tensor first (on rank 0) + if rank == 0: + full_tensor_1d = torch.arange(total_seq_len, dtype=torch_dtype, device=device) + else: + full_tensor_1d = torch.empty(total_seq_len, dtype=torch_dtype, device=device) + + # Broadcast full tensor to all ranks for reference + dist.broadcast(full_tensor_1d, src=0, group=cp_group) + + # Split tensor for local input (each rank gets a chunk) + chunk_size = total_seq_len // world_size_actual + start_idx = rank * chunk_size + end_idx = start_idx + chunk_size if rank < world_size_actual - 1 else total_seq_len + local_tensor_1d = full_tensor_1d[start_idx:end_idx].clone() + + # Test shuffle -> unshuffle + shuffled = shuffle_varlen(local_tensor_1d, cu_seqlens, cp_ranks, cp_group) + unshuffled = unshuffle_varlen(shuffled, cu_seqlens, cp_ranks, cp_group) + + # Compare with original local chunk + if torch.allclose(local_tensor_1d, unshuffled, atol=tolerance): + if rank == 0: + print(" โœ“ 1D tensor test passed") + else: + if rank == 0: + print(" โœ— 1D tensor test FAILED") + raise AssertionError("1D tensor test failed") + + except Exception as e: + if rank == 0: + print(f" โœ— 1D tensor test FAILED with error: {e}") + raise e + + # Test 2: 2D tensor (total_seq_len, hidden_dim) + if rank == 0: + print(" Test: 2D tensor (total_seq_len, hidden_dim)...") + + try: + # Create full tensor first (on rank 0) + if rank == 0: + full_tensor_2d = torch.randn(total_seq_len, config.hidden_dim, dtype=torch_dtype, device=device) + else: + full_tensor_2d = torch.empty(total_seq_len, config.hidden_dim, dtype=torch_dtype, device=device) + + # Broadcast full tensor to all ranks for reference + dist.broadcast(full_tensor_2d, src=0, group=cp_group) + + # Split tensor for local input (each rank gets a chunk) + chunk_size = total_seq_len // world_size_actual + start_idx = rank * chunk_size + end_idx = start_idx + chunk_size if rank < world_size_actual - 1 else total_seq_len + local_tensor_2d = full_tensor_2d[start_idx:end_idx].clone() + + # Test shuffle -> unshuffle + shuffled = shuffle_varlen(local_tensor_2d, cu_seqlens, cp_ranks, cp_group) + unshuffled = unshuffle_varlen(shuffled, cu_seqlens, cp_ranks, cp_group) + + # Compare with original local chunk + if torch.allclose(local_tensor_2d, unshuffled, atol=tolerance): + if rank == 0: + print(" โœ“ 2D tensor test passed") + else: + if rank == 0: + print(" โœ— 2D tensor test FAILED") + raise AssertionError("2D tensor test failed") + + except Exception as e: + if rank == 0: + print(f" โœ— 2D tensor test FAILED with error: {e}") + raise e + + dist.barrier() + + if rank == 0: + print("โœ“ All shuffle/unshuffle tests PASSED!") + + dist.destroy_process_group() + + +class TestShuffleVarlen: + """Simple test class for shuffle/unshuffle varlen""" + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_tiny(self, dtype): + """Test shuffle/unshuffle varlen with tiny configuration""" + partial(torchrun, 2, shuffle_varlen_test, "tiny", dtype)() + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_small(self, dtype): + """Test shuffle/unshuffle varlen with small configuration""" + partial(torchrun, 2, shuffle_varlen_test, "small", dtype)() + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_medium(self, dtype): + """Test shuffle/unshuffle varlen with medium configuration""" + partial(torchrun, 2, shuffle_varlen_test, "medium", dtype)() + + @pytest.mark.parametrize("dtype", ["float32", "float16"]) + def test_shuffle_varlen_uneven(self, dtype): + """Test shuffle/unshuffle varlen with uneven sequence lengths""" + partial(torchrun, 2, shuffle_varlen_test, "uneven", dtype)() + + @pytest.mark.parametrize("num_gpus", [2, 4]) + def test_shuffle_varlen_multi_gpu(self, num_gpus): + """Test shuffle/unshuffle varlen on multiple GPUs""" + partial(torchrun, num_gpus, shuffle_varlen_test, "tiny", "float32")() + + +# Standalone test functions for pytest discovery +@pytest.mark.parametrize("config,dtype", [ + ("tiny", "float32"), ("tiny", "float16"), + ("small", "float32"), ("small", "float16"), + ("uneven", "float32"), ("uneven", "float16"), +]) +def test_shuffle_varlen_correctness(config, dtype): + """Test shuffle/unshuffle varlen correctness""" + partial(torchrun, 2, shuffle_varlen_test, config, dtype)() + + +@pytest.mark.parametrize("config,num_gpus", [ + ("tiny", 2), ("tiny", 4), + ("small", 2), ("small", 4), +]) +def test_shuffle_varlen_multi_gpu(config, num_gpus): + """Test shuffle/unshuffle varlen on multiple GPUs""" + partial(torchrun, num_gpus, shuffle_varlen_test, config, "float32")() \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/test_zigzag_attn.py b/tests/customized_ops/ring_attn/test_zigzag_attn.py new file mode 100644 index 00000000..3bca5792 --- /dev/null +++ b/tests/customized_ops/ring_attn/test_zigzag_attn.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Zigzag attention correctness tests. + +This module contains correctness tests for the zigzag attention implementation. +Note: Zigzag attention only supports causal=True and window_size=(-1, -1). + +Usage: + python -m pytest test_zigzag_attn.py -v + python -m pytest test_zigzag_attn.py::TestZigzagAttn::test_zigzag_attn_tiny_bf16 -v +""" + +import pytest + +# Skip all tests if flash_attn_func is not available +try: + from flash_attn import flash_attn_func +except ImportError: + pytest.skip("flash_attn_func not available", allow_module_level=True) + +from .test_base import RingAttnTestBase + + +class TestZigzagAttn(RingAttnTestBase): + """Test class for zigzag attention correctness testing""" + + @property + def runner_script_name(self) -> str: + return "zigzag_attn_runner.py" + + @property + def test_function_name(self) -> str: + return "zigzag_attn_test" + + @property + def test_name_prefix(self) -> str: + return "zigzag_attn" + + # Basic correctness tests + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + def test_zigzag_attn_tiny(self, dtype): + """Test zigzag attention with tiny configuration""" + self.run_correctness_basic(dtype, "zigzag_tiny") + + @pytest.mark.parametrize("dtype", ["bf16", "fp16"]) + def test_zigzag_attn_small(self, dtype): + """Test zigzag attention with small configuration""" + self.run_correctness_basic(dtype, "zigzag_small") + + @pytest.mark.parametrize("dtype", ["bf16"]) + def test_zigzag_attn_medium(self, dtype): + """Test zigzag attention with medium configuration""" + self.run_correctness_basic(dtype, "zigzag_medium") + + # Multi-GPU tests + @pytest.mark.parametrize("num_gpus", [2, 4]) + def test_zigzag_attn_multi_gpu_small(self, num_gpus): + """Test zigzag attention with small config on multiple GPUs""" + self.run_multi_gpu_scaling(num_gpus, "zigzag_small") + + @pytest.mark.parametrize("num_gpus", [2, 4]) + def test_zigzag_attn_multi_gpu_medium(self, num_gpus): + """Test zigzag attention with medium config on multiple GPUs""" + self.run_multi_gpu_scaling(num_gpus, "zigzag_medium") + + # GQA test + def test_zigzag_attn_gqa(self): + """Test zigzag attention with GQA configuration""" + self.run_gqa_correctness("bf16", "zigzag_gqa") + + +if __name__ == "__main__": + # For direct execution, run a simple test + test_instance = TestZigzagAttn() + test_instance.run_correctness_basic("bf16", "zigzag_tiny") \ No newline at end of file diff --git a/tests/customized_ops/ring_attn/zigzag_attn_runner.py b/tests/customized_ops/ring_attn/zigzag_attn_runner.py new file mode 100644 index 00000000..6e557e2a --- /dev/null +++ b/tests/customized_ops/ring_attn/zigzag_attn_runner.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Zigzag attention test runner implementation. +This module provides the specific runner for testing zigzag attention. +Note: Zigzag attention only supports causal=True and window_size=(-1, -1). +""" + +import os +import sys +from typing import Dict, Any, Tuple + +import torch +import torch.nn as nn + +from nnscaler.customized_ops.ring_attention.zigzag_attn import wrap_zigzag_attn_func +from runner_base import RingAttnRunnerBase + + +class ZigzagAttnRunner(RingAttnRunnerBase): + """Zigzag attention test runner""" + + @property + def function_signature(self) -> str: + return "nnscaler.customized_ops.ring_attention.zigzag_attn.wrap_zigzag_attn_func" + + @property + def partition_position(self) -> Tuple[int, int]: + return 0, 1 + + @property + def function_name(self) -> str: + return "wrap_zigzag_attn_func" + + def create_test_module(self, config) -> torch.nn.Module: + """Create test module for zigzag attention""" + class TestModule(nn.Module): + def __init__(self, causal=True, window_size=(-1, -1)): + super().__init__() + # Zigzag attention only supports causal=True and window_size=(-1, -1) + assert causal is True, "Zigzag attention only supports causal=True" + assert window_size == (-1, -1), "Zigzag attention only supports window_size=(-1, -1)" + self.causal = causal + self.window_size = window_size + + def forward(self, q, k, v): + # Note: zigzag_attn always uses causal=True and window_size=(-1, -1) + return wrap_zigzag_attn_func(q, k, v, causal=self.causal, window_size=self.window_size) + + return TestModule(causal=config.causal, window_size=config.window_size) + + def prepare_inputs(self, config, device, torch_dtype): + """Prepare inputs for zigzag attention""" + batch_size = config.batch_size + max_seqlen = config.max_seqlen + num_heads = config.num_heads + num_kv_heads = config.num_kv_heads + head_dim = config.head_dim + + # Create input tensors + q = torch.clamp(torch.randn(batch_size, max_seqlen, num_heads, head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + k = torch.clamp(torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + v = torch.clamp(torch.randn(batch_size, max_seqlen, num_kv_heads, head_dim, device=device, dtype=torch_dtype), min=-1, max=1) + + return { + 'q': q, + 'k': k, + 'v': v + } + + def run_single_gpu_reference(self, inputs, config): + """Run single GPU reference implementation""" + # Note: zigzag_attn always uses causal=True and window_size=(-1, -1) + output = wrap_zigzag_attn_func( + inputs['q'], inputs['k'], inputs['v'], + causal=config.causal, window_size=config.window_size) + output.retain_grad() + + return output, [inputs['q'], inputs['k'], inputs['v']] + + def get_dummy_forward_args(self, inputs) -> Dict[str, Any]: + """Get dummy forward arguments for model parallelization""" + return { + 'q': inputs['q'], + 'k': inputs['k'], + 'v': inputs['v'] + } + + +def zigzag_attn_test(dtype="bf16", config_name="zigzag_tiny", **kwargs): + """Pure test function that can be used with torchrun""" + runner = ZigzagAttnRunner() + return runner.run_correctness_test(dtype=dtype, config_name=config_name, **kwargs) + + +def main(): + """Main entry point for command line execution""" + kwargs = dict(arg.split("=") for arg in sys.argv[1:]) + + runner = ZigzagAttnRunner() + runner.main(**kwargs) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/tests/graph/function/test_functions.py b/tests/graph/function/test_functions.py index 319e3b6f..8a32e20d 100644 --- a/tests/graph/function/test_functions.py +++ b/tests/graph/function/test_functions.py @@ -7,6 +7,7 @@ from operator import add from nnscaler.graph.function.dimops import IRDimops, OpAnno import nnscaler.graph.function.function as F +from nnscaler.graph.parser.value_tracker import ValueTracker from nnscaler.ir.cten import IR, IRObject, IRTensor import pytest @@ -47,6 +48,35 @@ def test_Full(): assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 1' +def test_Randn(): + op = F.Randn(IRObject(value=[2, 3, 4])) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 2 3 4' + + for dim_track in op.output(0).dim_tracks: + assert dim_track.deps == [op.kwargs['size'].value_track.value_id] + + op = F.Randn(2, IRObject(value=3), IRObject(value=4)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 2 3 4' + + assert op.output(0).dim_tracks[0].deps == [] + assert op.output(0).dim_tracks[1].deps == [op.kwargs['size'][1].value_track.value_id] + assert op.output(0).dim_tracks[2].deps == [op.kwargs['size'][2].value_track.value_id] + + +def test_Eye(): + op = F.Eye(IRObject(value=3)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 3^ 3^' + + op = F.Eye(IRObject(value=3), IRObject(value=4)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 3^ 4^' + + op = F.Eye(3, IRObject(value=4)) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 3^ 4^' + + op = F.Eye(IRObject(value=3), 4) + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == ' -> 3^ 4^' + + def test_Expand(): inp = IRTensor([10, 1]) out = IRTensor([10, 2]) @@ -1147,3 +1177,43 @@ def test_dict_keys_values_items(): # key will never be wrapped with IRObject # IRFullTensor will be reconstructed, so their ids are different assert all(x[0] == y[0] and x[1].shape == y[1].shape and x[1] != y[1] for x, y in zip(r.output(0), d.items())) + +def test_Stack(): + op = F.Stack([IRTensor([2, 3]), IRTensor([2, 3]), IRTensor([2, 3])], dim=0) + expected_annotation = 'a b, a b, a b -> 3 a b' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + op = F.Stack([IRTensor([2, 3]), IRTensor([2, 3]), IRTensor([2, 3])], dim=1) + expected_annotation = 'a b, a b, a b -> a 3 b' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + op = F.Stack([IRTensor([2, 3]), IRTensor([2, 3]), IRTensor([2, 3])], dim=2) + expected_annotation = 'a b, a b, a b -> a b 3' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + + op = F.Stack([IRTensor([]), IRTensor([]), IRTensor([])], dim=0) + expected_annotation = '1, 1, 1 -> 3' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Stack." + + +def test_Dot(): + op = F.Dot(IRTensor([4]), IRTensor([4])) + expected_annotation = 'k+, k+ -> 1' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation, "Annotation mismatch for Dot." + + +def test_chunk(): + op = F.Chunk(IRTensor([8, 10]), chunks=4, dim=0) + expected_annotation = '8 b -> 2 b, 2 b, 2 b, 2 b' + assert len(op._annos_candidates) == 1 and op._annos_candidates[0] == expected_annotation + value_tracker = ValueTracker() + value_tracker.track_nodes([op]) + value_tracker.complete_tracking([op]) + input_dim_tracks = op.input(0).dim_tracks + output_dim_tracks = [out.dim_tracks for out in op.outputs()] + # all dim 1 tracks should be the same + assert output_dim_tracks[0][1] is input_dim_tracks[1] + # output dim 0 tracks should depend on input dim 0 track + assert output_dim_tracks[0][0].deps == [input_dim_tracks[0].value_id] + for output_dim_track in output_dim_tracks[1:]: + assert output_dim_track[0] is output_dim_tracks[0][0] + assert output_dim_track[1] is output_dim_tracks[0][1] + assert True diff --git a/tests/graph/parser/test_ast_transformer.py b/tests/graph/parser/test_ast_transformer.py index 51330fd9..548e9024 100644 --- a/tests/graph/parser/test_ast_transformer.py +++ b/tests/graph/parser/test_ast_transformer.py @@ -164,12 +164,12 @@ def f(self) -> None: assert modified assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' def f(func_name, type: int, /, *args, **kwargs): - return patched_run(func_name, type, *args, **kwargs) + return patched_run(func_name, 'func_name(type, *args, **kwargs)', type, *args, **kwargs) def g(): - return patched_run(x + y, a, b) + return patched_run(x + y, '(x + y)(a, b)', a, b) class A: def f(self) -> None: - patched_run(patched_run(super).f) + patched_run(patched_run(super, 'super()').f, 'super().f()') ''').strip() @@ -188,10 +188,10 @@ def __init__(self) -> None: modified, new_ast = transform(tree, transfomers) assert modified assert '\n'.join(line for line in ast.unparse(new_ast).split('\n') if line.strip()) == dedent(''' - x = patched_run(not_, True) + x = patched_run(not_, 'not_(True)', True) def f(func_name, type: int, /, *args, **kwargs): - return patched_run(func_name, type, *args, **kwargs) + return patched_run(func_name, 'func_name(type, *args, **kwargs)', type, *args, **kwargs) class A: def __init__(self) -> None: - patched_run(super(self.__class__, self).__init__) + patched_run(super(self.__class__, self).__init__, 'super(self.__class__, self).__init__()') ''').strip() diff --git a/tests/graph/parser/test_converter.py b/tests/graph/parser/test_converter.py index 20f2bcff..aed04fe6 100644 --- a/tests/graph/parser/test_converter.py +++ b/tests/graph/parser/test_converter.py @@ -44,7 +44,6 @@ def forward(self, x, **kwargs): assert any(node.op == 'call_function' and node.target == torch.nn.functional.linear for node in nodes) with tempfile.TemporaryDirectory() as tempdir: - to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tempdir, constant_folding=False) assert ir_graph is not None assert (Path(tempdir) / FxModuleParser.ATTR_MAP_FILE).exists() @@ -52,13 +51,45 @@ def forward(self, x, **kwargs): assert ir_graph.name == 'MyModule' inputs = ir_graph.inputs() assert len(inputs) == 2 - assert inputs[0].name == nodes[0].name + assert inputs[0].name == nodes[0].target assert isinstance(inputs[0], IRTensor) - assert inputs[1].name == nodes[1].name + assert inputs[0].value_track.deps == None + # inputs has no dependency + assert all(dt.deps == [] for dt in inputs[0].dim_tracks) + assert inputs[1].name == nodes[1].target assert isinstance(inputs[1], IRObject) + assert inputs[1].value_track.deps == [] + + assert len(ir_graph.nodes()) == 1 + linear_node = ir_graph.nodes()[0] + assert len(linear_node.inputs()) == 3 # x, weight, bias + + assert all(isinstance(i, IRTensor) for i in linear_node.inputs()) + # from its annotation, a k^, n k^, n -> a n + # we can check the value_track and dim_track dependencies + + # the same with graph inputs + assert all(linear_node.input(0).dim_tracks[i] is inputs[0].dim_tracks[i] for i in range(len(inputs[0].dim_tracks))) + # weights has no dependency + assert linear_node.input(1).dim_tracks[0].deps == [] + # the `k` dimension + assert linear_node.input(1).dim_tracks[1] is inputs[0].dim_tracks[1] + # the `n` dimension + assert linear_node.input(2).dim_tracks[0] is linear_node.input(1).dim_tracks[0] + + assert len(linear_node.outputs()) == 1 + assert isinstance(linear_node.outputs()[0], IRTensor) + # `a` + assert linear_node.output(0).dim_tracks[0] is inputs[0].dim_tracks[0] + # `n` + assert linear_node.output(0).dim_tracks[1] is linear_node.input(1).dim_tracks[0] outputs = ir_graph.outputs() assert len(outputs) == 1 + # `a` + assert outputs[0].dim_tracks[0] is inputs[0].dim_tracks[0] + # `n` + assert outputs[0].dim_tracks[1] is linear_node.input(1).dim_tracks[0] nodes = list(ir_graph.nodes()) assert any(node.signature == 'torch.nn.functional.linear' for node in nodes) diff --git a/tests/graph/parser/test_parser.py b/tests/graph/parser/test_parser.py index a0bc33b8..c468417a 100644 --- a/tests/graph/parser/test_parser.py +++ b/tests/graph/parser/test_parser.py @@ -166,7 +166,7 @@ def forward(self, x): ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) print(ir_graph.extra_repr()) - assert len(ir_graph.nodes()) == 5 + assert len(ir_graph.nodes()) == 4 assert len(ir_graph.nodes()[0].outputs()) == 3 assert len(ir_graph.outputs()) == 1 assert isinstance(ir_graph.output(0), list) @@ -324,3 +324,23 @@ def forward(self, x): # so the output number is 1 for now. # Will be fixed later. assert len(ir_graph.outputs()) == 1 + + +@replace_all_device_with('cpu') +def test_T(tmp_path): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.matmul(x, x.T) + + dummy_input = {'x': torch.randn(4, 8)} + module = MyModule() + fx_graph = to_fx_graph(module, dummy_input) + + print(fx_graph.graph) + ir_graph = to_ir_graph(fx_graph, dummy_input, attr_savedir=tmp_path, constant_folding=False) + print(ir_graph.extra_repr()) + + assert ir_graph.nodes()[0].signature == 'torch.transpose' diff --git a/tests/graph/parser/test_value_tracker.py b/tests/graph/parser/test_value_tracker.py new file mode 100644 index 00000000..ff45eacf --- /dev/null +++ b/tests/graph/parser/test_value_tracker.py @@ -0,0 +1,194 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tempfile + +import pytest +import torch + +from nnscaler.graph.parser.converter import convert_model +from nnscaler import register_op, mark_dynamic + +from ...utils import replace_all_device_with + + +@replace_all_device_with('cpu') +def test_hidden_dim(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + return x.repeat(4, 1) + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 1 + node = ir_graph.node(0) + assert str(node.anno) == 'a^ b -> (4^ a^) b' + dim0_vi = node.input(0).dim_tracks[0].value_id + dim1_vi = node.input(0).dim_tracks[1].value_id + + assert node.output(0).dim_tracks[0].value_id != dim0_vi + assert node.output(0).dim_tracks[0].deps == [dim0_vi] + assert node.output(0).dim_tracks[1].value_id == dim1_vi + assert node.output(0).dim_tracks[1].deps == [] + + +@replace_all_device_with('cpu') +def test_equiv_class(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor): + x = x + 1 + y = y * 2 + return x@y + + dummy_input = {'x': torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 'y': torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 3 + x_node = ir_graph.node(0) + y_node = ir_graph.node(1) + assert x_node.input(0).dim_tracks[0] is x_node.output(0).dim_tracks[0] + assert x_node.input(0).dim_tracks[1] is x_node.output(0).dim_tracks[1] + + assert y_node.input(0).dim_tracks[0] is y_node.output(0).dim_tracks[0] + assert y_node.input(0).dim_tracks[1] is y_node.output(0).dim_tracks[1] + + node = ir_graph.node(-1) + assert str(node.anno) == 'm k+, k+ n -> m n' + # the `k` dimension of input 1 should be the same as input 0 + # they are in the same equivalence class + assert node.input(0).dim_tracks[0] is x_node.input(0).dim_tracks[0] + assert node.input(0).dim_tracks[1] is x_node.input(0).dim_tracks[1] + assert node.input(1).dim_tracks[0] is node.input(0).dim_tracks[1] + assert node.input(1).dim_tracks[1] is y_node.input(0).dim_tracks[1] + + assert node.output(0).dim_tracks[0] is node.input(0).dim_tracks[0] + assert node.output(0).dim_tracks[1] is node.input(1).dim_tracks[1] + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dim', [True, False]) +def test_size(dynamic_dim): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + x = x + 1 + s = x.size() + y = torch.randn(s) + return x + y + + dummy_input = {'x': mark_dynamic(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), [0] if dynamic_dim else [])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 4 + size_node = ir_graph.node(1) + randn_node = ir_graph.node(2) + + assert size_node.output(0)[0].value_track is ir_graph.inputs()[0].dim_tracks[0] + assert size_node.output(0)[0].value_track.is_constant != (dynamic_dim is True) + assert size_node.output(0)[1].value_track is ir_graph.inputs()[0].dim_tracks[1] + assert size_node.output(0)[1].value_track.is_constant is True + + # dim tracks of randn node is from equivalence class originally from torch.add + assert randn_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[0] + assert randn_node.output(0).dim_tracks[0].is_constant != (dynamic_dim is True) + assert randn_node.output(0).dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[1] + assert randn_node.output(0).dim_tracks[1].is_constant is True + + +# Note: the custom op here is just for testing purpose +@register_op('l (2 m) n -> n (2 l) m') +def my_op(x: torch.Tensor) -> torch.Tensor: + return torch.randn(x.size(2), x.size(0) * 2, x.size(1) // 2) + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dim', [True, False]) +def test_custom_op(dynamic_dim): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + x = my_op(x) + s = x.size() + y = torch.randn(s) + return x + y + + dummy_input = {'x': mark_dynamic(torch.randn(2, 2, 2), [0, 2] if dynamic_dim else [])} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 4 + my_op_node = ir_graph.node(0) + size_node = ir_graph.node(1) + randn_node = ir_graph.node(2) + + assert [t.is_constant for t in ir_graph.inputs()[0].dim_tracks] == [not dynamic_dim, True, not dynamic_dim] + + assert my_op_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[2] + assert ir_graph.inputs()[0].dim_tracks[0].value_id in my_op_node.output(0).dim_tracks[1].deps + assert ir_graph.inputs()[0].dim_tracks[1].value_id in my_op_node.output(0).dim_tracks[2].deps + + assert [t.is_constant for t in my_op_node.outputs()[0].dim_tracks] == [not dynamic_dim, not dynamic_dim, True] + + assert size_node.output(0)[0].value_track is my_op_node.output(0).dim_tracks[0] + assert size_node.output(0)[1].value_track is my_op_node.output(0).dim_tracks[1] + assert size_node.output(0)[2].value_track is my_op_node.output(0).dim_tracks[2] + + assert [t.value_track.is_constant for t in size_node.output(0)] == [not dynamic_dim, not dynamic_dim, True] + + # dim tracks of randn node is from equivalence class originally from torch.add + assert randn_node.output(0).dim_tracks[0] is my_op_node.output(0).dim_tracks[0] + # assert randn_node.output(0).dim_tracks[0].is_constant != (dynamic_dim is True) + assert randn_node.output(0).dim_tracks[1] is my_op_node.output(0).dim_tracks[1] + assert randn_node.output(0).dim_tracks[2] is my_op_node.output(0).dim_tracks[2] + + assert [t.is_constant for t in randn_node.outputs()[0].dim_tracks] == [not dynamic_dim, not dynamic_dim, True] + + +# Note: the custom op here is just for testing purpose +@register_op('l l -> l l') +def my_identity(x: torch.Tensor) -> torch.Tensor: + return x + + +@replace_all_device_with('cpu') +def test_custom_op2(): + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor): + return my_identity(x) + + dummy_input = {'x': torch.randn(2, 2)} + module = MyModule() + + with tempfile.TemporaryDirectory() as tempdir: + ir_graph = convert_model(module, dummy_input, attr_savedir=tempdir, constant_folding=False) + assert ir_graph is not None + assert len(ir_graph.nodes()) == 1 + my_op_node = ir_graph.node(0) + + assert ir_graph.inputs()[0].dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[0] + assert my_op_node.output(0).dim_tracks[0] is ir_graph.inputs()[0].dim_tracks[0] + assert my_op_node.output(0).dim_tracks[1] is ir_graph.inputs()[0].dim_tracks[1] diff --git a/tests/graph/schedule/test_interleaved_1f1b.py b/tests/graph/schedule/test_interleaved_1f1b.py index 593ff23f..ed779ba3 100644 --- a/tests/graph/schedule/test_interleaved_1f1b.py +++ b/tests/graph/schedule/test_interleaved_1f1b.py @@ -21,7 +21,7 @@ from nnscaler.ir.operator import IRFwOperation, IRDataOperation from nnscaler.graph.segment import IRSegment from nnscaler.graph.schedule.predefined import PredefinedSched -from tests.utils import clear_dir_on_rank0, init_random +from tests.utils import clear_dir_on_rank0, init_random, PYTEST_RUN_ID from tests.launch_torchrun import torchrun from tests.parallel_module.common import assert_equal from tests.parallel_module.test_gencode import _gencode_contains @@ -131,7 +131,7 @@ def worker_pipeline_2(n_micro_batches): trace_data = torch.randn([2, 32], dtype=torch.float32, device=torch.cuda.current_device()) cfg = ComputeConfig(2, 2, use_end2end=True, pas_config=dict(n_micro_batches=n_micro_batches)) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_1f1b_interleaved') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_1f1b_interleaved_{PYTEST_RUN_ID}') as tempdir: pm_1f1b = parallelize( m, {'x': trace_data}, diff --git a/tests/graph/test_segment.py b/tests/graph/test_segment.py index 108b9e3e..32e36bf6 100644 --- a/tests/graph/test_segment.py +++ b/tests/graph/test_segment.py @@ -18,7 +18,7 @@ from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer from nnscaler.ir.operator import IRFwOperation, IRDataOperation from tests.parallel_module.test_gencode import _gencode_contains, print_gencode -from ..utils import replace_all_device_with, clear_dir_on_rank0, init_random +from ..utils import replace_all_device_with, clear_dir_on_rank0, init_random, PYTEST_RUN_ID from ..launch_torchrun import torchrun @@ -119,7 +119,7 @@ def worker_a(): m.train() trace_data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_infer_grad_pyfunc') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_infer_grad_pyfunc_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'q': trace_data,}, @@ -167,7 +167,11 @@ def policy_nograd(graph: IRGraph, cfg: ComputeConfig) -> IRGraph: else: fc1_node = graph.nodes()[0] func_node = graph.nodes()[1] - assert fc1_node.inputs()[0].requires_grad and fc1_node.inputs()[0].grad + if cfg.use_end2end: + assert not fc1_node.inputs()[0].requires_grad and not fc1_node.inputs()[0].grad + else: + assert fc1_node.inputs()[0].requires_grad and fc1_node.inputs()[0].grad + assert fc1_node.inputs()[1].requires_grad and fc1_node.inputs()[1].grad assert fc1_node.outputs()[0].requires_grad and fc1_node.outputs()[0].grad assert func_node.inputs()[0].requires_grad and not func_node.inputs()[0].grad @@ -207,7 +211,7 @@ def worker_b(use_end2end): init_random() trace_data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) data = torch.randn([2, 2, 2, 8], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_infer_grad_no_grad') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_infer_grad_no_grad_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'q': trace_data,}, diff --git a/tests/graph/tracer/test_pack_kwargs.py b/tests/graph/tracer/test_pack_kwargs.py new file mode 100644 index 00000000..4db75d5d --- /dev/null +++ b/tests/graph/tracer/test_pack_kwargs.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +from torch.nn import Module + +from nnscaler.graph.tracer import concrete_trace +from ...utils import replace_all_device_with + + +class Model(Module): + def __init__(self): + super(Model, self).__init__() + self.linear = torch.nn.Linear(10, 5) + + def forward(self, **kwargs): + return self.linear(kwargs['input']) + + +@replace_all_device_with('cpu') +def test_pack_kwargs(): + model = Model() + example_inputs = {'input': torch.randn(1, 10)} + traced_model = concrete_trace(model, example_inputs) + assert list(traced_model.graph.nodes)[0].target == '**kwargs' + + +@replace_all_device_with('cpu') +def test_direct_kwargs(): + model = Model() + example_inputs = {'**kwargs': {'input': torch.randn(1, 10)}} + traced_model = concrete_trace(model, example_inputs) + assert list(traced_model.graph.nodes)[0].target == '**kwargs' diff --git a/tests/ir/test_cten.py b/tests/ir/test_cten.py index c406f7bc..42f7d115 100644 --- a/tests/ir/test_cten.py +++ b/tests/ir/test_cten.py @@ -120,9 +120,9 @@ def test_from_complex(tosub, requires_grad): assert type(obj[2]) == tensor_type and obj[2].parent.tid != obj_tensor_item.tid t1 = TensorMetadata(shape=(), dtype=torch.float, requires_grad=False, - stride=None, memory_format=None, is_quantized=None, qparams=None) + stride=None, memory_format=None, is_quantized=None, qparams=None, dynamic_dims=set()) t2 = TensorMetadata(shape=(2,), dtype=torch.float, requires_grad=True, - stride=None, memory_format=None, is_quantized=None, qparams=None) + stride=None, memory_format=None, is_quantized=None, qparams=None, dynamic_dims=set()) obj = IR.new('n', {'a': t1, 'b': t2}.values(), tensor_types=(TensorMetadata,), diff --git a/tests/launch_torchrun.py b/tests/launch_torchrun.py index fce62ed2..27933ecc 100644 --- a/tests/launch_torchrun.py +++ b/tests/launch_torchrun.py @@ -4,11 +4,12 @@ from typing import Callable import uuid import torch +import os from torch.distributed.run import elastic_launch, LaunchConfig from torch.distributed.elastic.multiprocessing.errors import ChildFailedError -from .utils import retry +from .utils import retry, MASTER_PORT @retry(ChildFailedError, delay=10, match='The server socket has failed to listen on any local network address.') @@ -18,7 +19,7 @@ def launch_torchrun(nproc_per_node, worker_fn, *args, **kwargs): max_nodes=1, nproc_per_node=nproc_per_node, rdzv_backend = "c10d", - rdzv_endpoint = "localhost:29401", + rdzv_endpoint = f"localhost:{MASTER_PORT}", run_id = str(uuid.uuid4()), monitor_interval=0.1, max_restarts=0, diff --git a/tests/parallel_module/common.py b/tests/parallel_module/common.py index 0cfac768..cc3b22e3 100644 --- a/tests/parallel_module/common.py +++ b/tests/parallel_module/common.py @@ -107,6 +107,19 @@ def forward(self, x): return x +class FFN(nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = torch.nn.Tanh() + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + def init_distributed(): torch.distributed.init_process_group(backend='nccl') rank = torch.distributed.get_rank() @@ -115,9 +128,10 @@ def init_distributed(): def assert_equal(a: Any, b: Any): - assert type(a) == type(b) + # treat dict and OrderedDict as same for comparison + assert type(a) == type(b) or (isinstance(a, dict) and isinstance(b, dict)), f'{type(a)} != {type(b)}' if isinstance(a, torch.Tensor): - assert torch.equal(a.cpu(), b.cpu()) + assert torch.equal(a.cpu(), b.cpu()), torch.max(torch.abs(a.cpu() - b.cpu())) elif isinstance(a, dict): assert len(a) == len(b) for k in a.keys(): @@ -127,7 +141,7 @@ def assert_equal(a: Any, b: Any): for i in range(len(a)): assert_equal(a[i], b[i]) else: - assert a == b + assert a == b, f"Values are not equal: {a} != {b}" def assert_close(a: Any, b: Any, atol=1e-6, rtol=1e-6): @@ -137,10 +151,10 @@ def assert_close(a: Any, b: Any, atol=1e-6, rtol=1e-6): elif isinstance(a, dict): assert len(a) == len(b) for k in a.keys(): - assert_close(a[k], b[k]) + assert_close(a[k], b[k], atol=atol, rtol=rtol) elif isinstance(a, (list, tuple)): assert len(a) == len(b) for i in range(len(a)): - assert_close(a[i], b[i]) + assert_close(a[i], b[i], atol=atol, rtol=rtol) else: - raise ValueError(f'unsupported type {type(a)}') \ No newline at end of file + assert a == b, f"Values are not equal: {a} != {b}" diff --git a/tests/parallel_module/test_async.py b/tests/parallel_module/test_async.py index 97b770b5..af9d75f0 100644 --- a/tests/parallel_module/test_async.py +++ b/tests/parallel_module/test_async.py @@ -13,7 +13,7 @@ from tests.launch_torchrun import launch_torchrun from tests.launch_torchrun import clone_to_cpu_recursively from tests.parallel_module.common import assert_equal, init_distributed -from tests.utils import clear_dir_on_rank0, init_random +from tests.utils import clear_dir_on_rank0, init_random, PYTEST_RUN_ID from .test_wholemodule import FcRelu_4_4 @@ -88,7 +88,7 @@ def _train(model: ParallelModule, update_freq): def _gpu_worker(pas, ngpus, update_freq): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_async') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_async_{PYTEST_RUN_ID}') as tempdir: whole_module_async, sub_module_async = _create_modules( pas, ComputeConfig( 1, ngpus, use_async_reducer=True, @@ -203,7 +203,7 @@ def _train_pp(model: ParallelModule, num_replicas, rank): def _gpu_worker_pp(pas, pp_ngpus, runtime_ngpus, update_freq): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_pp_async') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_pp_async_{PYTEST_RUN_ID}') as tempdir: init_random() whole_module_async = parallelize( OrigModuleEnd2End(), { diff --git a/tests/parallel_module/test_attr_dedup.py b/tests/parallel_module/test_attr_dedup.py index 64bf490a..99d18443 100644 --- a/tests/parallel_module/test_attr_dedup.py +++ b/tests/parallel_module/test_attr_dedup.py @@ -20,7 +20,8 @@ from .common import init_distributed, assert_equal from ..launch_torchrun import launch_torchrun -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID + class Net(torch.nn.Module): def __init__(self): @@ -38,6 +39,7 @@ def forward(self, x): x = self.buffer + x return x + def pas(graph: IRGraph, config: ComputeConfig): fw_nodes = graph.select(ntype=IRFwOperation) assert len(fw_nodes) == 4 @@ -50,9 +52,10 @@ def pas(graph: IRGraph, config: ComputeConfig): _replica(graph, fw_nodes[3], devs=devs) return graph + def _gpu_worker_spmd(cc: ComputeConfig): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'nnscaler_test_dedup_attr') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'nnscaler_test_dedup_attr_{PYTEST_RUN_ID}') as tempdir: module = parallelize( Net(), {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, @@ -65,13 +68,17 @@ def _gpu_worker_spmd(cc: ComputeConfig): world_size = torch.distributed.get_world_size() attr_area_maps = [None for _ in range(world_size)] curr_rank = torch.distributed.get_rank() - torch.distributed.all_gather_object(attr_area_maps, module.fullmap) + # Construct the three-level nested structure: rank -> module_name -> fullmap + # In this test case, we have only one module instance 'attr_dedup' + module_fullmap = {'attr_dedup': module.fullmap} + torch.distributed.all_gather_object(attr_area_maps, module_fullmap) rank2attr_area_map = {} for i, attr_area_map in enumerate(attr_area_maps): rank2attr_area_map[i] = attr_area_map torch.distributed.barrier() dedup_meta_info = dedup_attrs(rank2attr_area_map) - dedup_area_map = list(dedup_meta_info[curr_rank].items()) + # Access the deduped fullmap for the specific module + dedup_area_map = list(dedup_meta_info[curr_rank]['attr_dedup'].items()) if curr_rank == 0: assert len(dedup_area_map) == 4 assert dedup_area_map[0][1].orig_name == 'fc1.weight' diff --git a/tests/parallel_module/test_broadcast.py b/tests/parallel_module/test_broadcast.py index c7bc814a..e7552ab8 100644 --- a/tests/parallel_module/test_broadcast.py +++ b/tests/parallel_module/test_broadcast.py @@ -8,7 +8,7 @@ import pytest import torch -from nnscaler.parallel import ComputeConfig, parallelize, broadcast_weights +from nnscaler.parallel import ComputeConfig, _prepare_namespace, parallelize, broadcast_weights from .common import init_distributed from ..launch_torchrun import launch_torchrun @@ -41,13 +41,24 @@ def _to_cube_model(module, compute_config, cube_savedir, ) -def _gpu_worker(): +def _gpu_worker(tmp_path): init_distributed() + world_size = torch.distributed.get_world_size() + local_world_size = world_size // 2 # fake two machines, as we use different cube_savedir for each worker - os.environ['LOCAL_WORLD_SIZE'] = '1' + os.environ['LOCAL_WORLD_SIZE'] = str(local_world_size) + tempdir = tmp_path / f'worker_{torch.distributed.get_rank() // local_world_size}' + node_rank = torch.distributed.get_rank() // local_world_size + + # from nnscaler.runtime.device import DeviceGroup + # # create groups + # for i in range(local_world_size): + # group_ranks = list(range(i, world_size, local_world_size)) + # DeviceGroup().get_group(group_ranks) + p = lambda t, b, i, load_module=True, **kwargs: _to_cube_model( Module(), - ComputeConfig(1, 2), + ComputeConfig(1, world_size), t, load_module=load_module, broadcast_strategy=b, @@ -56,74 +67,78 @@ def _gpu_worker(): ) # case 1: no broadcast, so only rank 0 can load the module # rank 1 will raise ModuleNotFoundError - with tempfile.TemporaryDirectory() as tempdir: - if torch.distributed.get_rank() == 0: - p(tempdir, 'none', '_1') - else: - with pytest.raises(ModuleNotFoundError): - p(tempdir, 'none', '_1') + # this will hang forever due to the distributed group creation in generated code. + # if node_rank == 0: + # p(tempdir, 'none', '_1') + # else: + # with pytest.raises(ModuleNotFoundError): + # p(tempdir, 'none', '_1') # case 2: broadcast only code, so only rank 0 can load the module - # rank 1 will raise RuntimeError because it will fail to load fullmodel.pt - with tempfile.TemporaryDirectory() as tempdir: - if torch.distributed.get_rank() == 0: + # rank 1 will raise FileNotFoundError because it will fail to load attr_map files and more + if node_rank == 0: + p(tempdir, 'code', '_2') + else: + with pytest.raises(FileNotFoundError): p(tempdir, 'code', '_2') - else: - with pytest.raises(RuntimeError, match='Cannot find file.*'): - p(tempdir, 'code', '_2') # case 3: broadcast except weights, so only rank 0 can load the module # rank 1 will raise RuntimeError because it will fail to load fullmodel.pt - with tempfile.TemporaryDirectory() as tempdir: - if torch.distributed.get_rank() == 0: + if node_rank == 0: + p(tempdir, 'no_weights', '_3') + else: + with pytest.raises(RuntimeError, match='Cannot find file.*'): p(tempdir, 'no_weights', '_3') - else: - with pytest.raises(RuntimeError, match='Cannot find file.*'): - p(tempdir, 'no_weights', '_3') # case 4: broadcast except weights, every rank can succeed if don't lood init params - with tempfile.TemporaryDirectory() as tempdir: - m = p(tempdir, 'no_weights', '_4', - init_module_params=torch.distributed.get_rank() == 0 - ) - if torch.distributed.get_rank() == 0: - for n, pa in m.named_parameters(): - if n.startswith('linear_weight'): - pa.data.fill_(1.0) - else: - for n, pa in m.named_parameters(): - if n.startswith('linear_weight'): - assert not torch.equal(pa.data, torch.ones_like(pa.data)) - broadcast_weights(m) - # check if broadcast_weights works + m = p(tempdir, 'no_weights', '_4', + init_module_params=torch.distributed.get_rank() == 0 + ) + if node_rank == 0: for n, pa in m.named_parameters(): if n.startswith('linear_weight'): - assert torch.equal(pa.data, torch.ones_like(pa.data)) + pa.data.fill_(1.0) + else: + for n, pa in m.named_parameters(): + if n.startswith('linear_weight'): + assert not torch.equal(pa.data, torch.ones_like(pa.data)) + broadcast_weights(m) + # check if broadcast_weights works + for n, pa in m.named_parameters(): + if n.startswith('linear_weight'): + assert torch.equal(pa.data, torch.ones_like(pa.data)) # case 5: broadcast all, all ranks will succeed - with tempfile.TemporaryDirectory() as tempdir: - p(tempdir, 'all', '_5') + p(tempdir, 'all', '_5') # case 6: test incremental broadcast - with tempfile.TemporaryDirectory() as tempdir: - # generate without broadcasting - m = p(tempdir, 'none', '_6', load_module=False) - if torch.distributed.get_rank() != 0: - assert list(Path(tempdir).glob('*')) == [] - - # case 6.1: broadcast code even we set broadcast_strategy to `all` - # because only code is new generated. - m = p(tempdir, 'all', '_6', load_module=False, reuse='graph') - if torch.distributed.get_rank() != 0: - # only python files are broadcasted - assert set(f.name for f in Path(tempdir).glob('**/*') if f.is_file()) == set(['gencode0.py', 'gencode1.py', 'compute_config.pt']) + # generate without broadcasting + _, outdir6 = _prepare_namespace(tempdir, Module, '_6') + m = p(tempdir, 'none', '_6', load_module=False) + if node_rank != 0: + assert list(Path(outdir6).glob('*')) == [] + + # case 6.1: broadcast code even we set broadcast_strategy to `all` + # because only code is new generated. + m = p(tempdir, 'all', '_6', load_module=False, reuse='graph') + if node_rank != 0: + # only python files are broadcasted + assert set(f.name for f in Path(outdir6).glob('**/*') if f.is_file()) == set( + [f'gencode{i}.py' for i in range(world_size)] + ['compute_config.pt'] + ) - # case 6.2: everything should be broadcasted, including weights - # so the load_module will succeed. - m = p(tempdir, 'all', '_6', load_module=True, reuse='override') + torch.distributed.barrier() + # case 6.2: everything should be broadcasted, including weights + # so the load_module will succeed. + m = p(tempdir, 'all', '_6', load_module=True, reuse='override') @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') -def test_broadcast(): - launch_torchrun(2, _gpu_worker) +def test_broadcast(tmp_path): + launch_torchrun(2, _gpu_worker, tmp_path) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +def test_broadcast4(tmp_path): + launch_torchrun(4, _gpu_worker, tmp_path) diff --git a/tests/parallel_module/test_checkpoint.py b/tests/parallel_module/test_checkpoint.py index 6faee640..565ad8b1 100644 --- a/tests/parallel_module/test_checkpoint.py +++ b/tests/parallel_module/test_checkpoint.py @@ -17,13 +17,20 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts, load_merged_state_dict +from nnscaler.parallel import ( + ComputeConfig, parallelize, + build_optimizer, + merge_state_dicts, + load_merged_state_dict, + load_merged_state_dict_from_rank, + trimmed_broadcast_merged_state_dict, +) from nnscaler.runtime.module import ParallelModule, ExtraState from nnscaler.runtime.gnorm import calcuate_gnorm -from .common import CubeLinear, init_random, init_distributed, PASMegatron +from .common import CubeLinear, init_random, init_distributed, PASMegatron, assert_equal from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import replace_all_device_with, clear_dir_on_rank0 +from ..utils import replace_all_device_with, clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -345,6 +352,23 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf optimizer_from_merged, merged_opt_state_dict, ) + model_from_merged_rank = type(model)() + optimizer_from_merged_rank = build_optimizer(model_from_merged_rank, torch.optim.Adam, lr=0.01) + load_merged_state_dict_from_rank( + model_from_merged_rank, merged_model_state_dict if torch.distributed.get_rank() == 0 else None, + optimizer_from_merged_rank, merged_opt_state_dict if torch.distributed.get_rank() == 0 else None, + ) + assert_equal(model_from_merged.state_dict(), model_from_merged_rank.state_dict()) + assert_equal(optimizer_from_merged.state_dict(), optimizer_from_merged_rank.state_dict()) + + trimmed_model_state_dict, trimmed_opt_state_dict = trimmed_broadcast_merged_state_dict( + model_from_merged_rank, merged_model_state_dict if torch.distributed.get_rank() == 0 else None, + optimizer_from_merged_rank, merged_opt_state_dict if torch.distributed.get_rank() == 0 else None, + ) + assert_equal(dict(model_from_merged.state_dict()), trimmed_model_state_dict) + assert_equal(optimizer_from_merged.state_dict()['state'], trimmed_opt_state_dict['state']) + assert_equal(optimizer_from_merged.state_dict()['param_groups'], trimmed_opt_state_dict['param_groups']) + # check merged model result_orig_model_state_dict = model.state_dict() result_merged_model_state_dict = model_from_merged.state_dict() @@ -358,7 +382,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf assert set(result_orig_opt_state_dict['state']) == set(result_merged_opt_state_dict['state']) for index in result_orig_opt_state_dict['state']: for key in ('step', 'exp_avg', 'exp_avg_sq'): - assert torch.equal(result_orig_opt_state_dict['state'][index][key], result_merged_opt_state_dict['state'][index][key]) + assert_equal(result_orig_opt_state_dict['state'][index][key], result_merged_opt_state_dict['state'][index][key]) torch.distributed.barrier() data = gendata(model, DATA_SIZE, start, end, rank, num_replicas) results = [] @@ -425,6 +449,25 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf 'model': merged_model_state_dicts, 'optimizer': merged_optimizer_state_dict }, ckpt_merged_file) + from nnscaler.runtime.serialization import convert, load + from contextlib import ExitStack + ckpt_st_file_template = 'ckpt_{rank}_{start}.safetensors' + ckpt_st_files = [ckpt_dir / ckpt_st_file_template.format(rank=i, start=end) for i in range(torch.distributed.get_world_size())] + for pt, st in zip(ckpt_files, ckpt_st_files): + convert(pt, st, src_format='pt', dst_format='safetensors') + ckpt_st_state_dict_loaders = [load(f, lazy=True) for f in ckpt_st_files] + with ExitStack() as stack: + ckpt_st_state_dicts = [] + for f in ckpt_st_state_dict_loaders: + ckpt_st_state_dicts.append(stack.enter_context(f).get_lazy_data()) + model_st_state_dicts = [ckpt['model'] for ckpt in ckpt_st_state_dicts] + optimizer_st_state_dicts = [ckpt['optimizer'] for ckpt in ckpt_st_state_dicts] + merged_model_st_state_dicts, merged_optimizer_st_state_dict = merge_state_dicts( + model_st_state_dicts, optimizer_st_state_dicts + ) + assert_equal(merged_model_state_dicts, merged_model_st_state_dicts) + assert_equal(merged_optimizer_state_dict, merged_optimizer_st_state_dict) + torch.distributed.barrier() return results @@ -432,7 +475,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, start, end, ckpt_dir, inf def _gpu_worker(module_type, use_zero, pas, plan_ngpus, runtime_ngpus, per_resume_update_count, resume_count, check_module=None): init_distributed() compiled_results = [] - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_{PYTEST_RUN_ID}') as tempdir: for i in range(resume_count): start = i * per_resume_update_count end = (i + 1) * per_resume_update_count @@ -548,9 +591,9 @@ def test_checkpoint_intra_reducer(module_type, use_zero): def _gpu_merge_worker(): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_merge') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_merge_{PYTEST_RUN_ID}') as tempdir: compiled_module = _create_cube_module('data', - ComputeConfig(2, 2, use_zero=True), + ComputeConfig(2, 4, use_zero=True), tempdir, 'whole', ) @@ -565,6 +608,40 @@ def _gpu_merge_worker(): ) -@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_checkpoint_merge(): - launch_torchrun(2, _gpu_merge_worker) + launch_torchrun(4, _gpu_merge_worker) + + +def _gather_full_model_state_dict_worker(tmp_path, use_zero): + from .test_end2end import MLP, dummy_data + from nnscaler.parallel import gather_full_model_state_dict, merge_state_dicts + init_distributed() + + model = MLP() + model = parallelize( + model, + {'data': dummy_data()}, + pas_policy='tp', + compute_config= ComputeConfig( + 2, 4, + use_end2end=True, + use_zero=use_zero, + ), + gen_savedir=tmp_path + ) + model.cuda() + rank = torch.distributed.get_rank() + torch.save(model.state_dict(), tmp_path / f'{rank}.pt') + torch.distributed.barrier() + merged_state_dict = merge_state_dicts( + [torch.load(tmp_path / f'{i}.pt', weights_only=False) for i in range(torch.distributed.get_world_size())] + ) + full_state_dict = gather_full_model_state_dict(model) + assert_equal(merged_state_dict, full_state_dict) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') +@pytest.mark.parametrize('use_zero', [0, 1, 3]) +def test_gather_full_model_state_dict(tmp_path, use_zero): + launch_torchrun(4, _gather_full_model_state_dict_worker, tmp_path, use_zero) diff --git a/tests/parallel_module/test_checkpoint_buffer.py b/tests/parallel_module/test_checkpoint_buffer.py index cde6f864..c81c11a2 100644 --- a/tests/parallel_module/test_checkpoint_buffer.py +++ b/tests/parallel_module/test_checkpoint_buffer.py @@ -12,7 +12,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun -from ..utils import catch_log, clear_dir_on_rank0 +from ..utils import catch_log, clear_dir_on_rank0, PYTEST_RUN_ID class Net1(torch.nn.Module): @@ -63,7 +63,7 @@ def _to_cube_model(module, compute_config, cube_savedir, instance_name, input_sh def _gpu_worker(): init_distributed() compute_config = ComputeConfig(1, 1, use_zero=False) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_{PYTEST_RUN_ID}') as tempdir: net1 = _to_cube_model(Net1(), compute_config, tempdir, 'net1', (128, 64)) cube_state_dict = net1.state_dict() assert not any(key.startswith('buffer') for key in cube_state_dict) @@ -129,7 +129,7 @@ def _gpu_worker_broadcast(): init_distributed() compute_config = ComputeConfig(1, 2, use_zero=False) rank = torch.distributed.get_rank() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_broadcast_fail') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_broadcast_fail_{PYTEST_RUN_ID}') as tempdir: net1 = _to_cube_model(Net1(), compute_config, tempdir, 'net1', (128, 64), init_module_params=False) with pytest.raises(RuntimeError, match="Non-persistent buffers haven't been initialized."): broadcast_weights(net1) diff --git a/tests/parallel_module/test_checkpoint_dedup.py b/tests/parallel_module/test_checkpoint_dedup.py index 339ea888..6ffdafdd 100644 --- a/tests/parallel_module/test_checkpoint_dedup.py +++ b/tests/parallel_module/test_checkpoint_dedup.py @@ -17,7 +17,7 @@ from .common import PASMegatron, CubeLinear, init_random, init_distributed, assert_equal from ..launch_torchrun import launch_torchrun from .test_checkpoint import gendata, train_step, End2EndMLP, End2EndMLPWithUnusedAndShared -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -139,8 +139,6 @@ def _check_deduped(model: torch.nn.Module, ckpt_dir): if not isinstance(model, ParallelModule): # in this case, non parallel module is removed, so it should have less keys assert len(parallel_modules) < len(dedupped_model_state_dict) < len(model_state_dict) - else: - assert len(dedupped_model_state_dict) == len(model_state_dict) for k, v in dedupped_model_state_dict.items(): assert_equal(v, model_state_dict[k]) @@ -181,7 +179,7 @@ def _check_deduped(model: torch.nn.Module, ckpt_dir): def _gpu_worker(pas, cc1, cc2): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_compact') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_compact_{PYTEST_RUN_ID}') as tempdir: _train(_create_cube_module(pas, cc1, cc2, tempdir), tempdir) torch.distributed.barrier() _check_deduped( @@ -204,7 +202,7 @@ def test_checkpoint_compact(use_zero): def _gpu_worker_pipeline(cc): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt_compact_pipeline') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_compact_pipeline_{PYTEST_RUN_ID}') as tempdir: for model_cls in [End2EndMLP, End2EndMLPWithUnusedAndShared]: pipeline_moule_cls = model_cls.to_pipeline_module(cc, tempdir) _train(pipeline_moule_cls().cuda(), tempdir) diff --git a/tests/parallel_module/test_checkpoint_shared.py b/tests/parallel_module/test_checkpoint_shared.py index 8b5e91b1..efceb594 100644 --- a/tests/parallel_module/test_checkpoint_shared.py +++ b/tests/parallel_module/test_checkpoint_shared.py @@ -14,7 +14,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun from .test_checkpoint import End2EndMLP, train_step, gendata -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcReluWithShared(nn.Module): @@ -194,7 +194,7 @@ def _gpu_worker(module_type, use_zero, pas, plan_ngpus, runtime_ngpus): # d. compare the full state dict in step a and the merged state dict in step c. They should be the same. init_distributed() compute_config = ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_{PYTEST_RUN_ID}') as tempdir: if torch.distributed.get_rank() == 0: tempdir.mkdir(parents=True, exist_ok=True) _train_raw(_create_cube_module(pas, compute_config, tempdir, f'{module_type}/raw'), tempdir) diff --git a/tests/parallel_module/test_checkpoint_unused.py b/tests/parallel_module/test_checkpoint_unused.py index ca8ea820..c0aec698 100644 --- a/tests/parallel_module/test_checkpoint_unused.py +++ b/tests/parallel_module/test_checkpoint_unused.py @@ -24,7 +24,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively from .test_checkpoint_shared import _train_raw, _load_merged -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcReluWithUnused(nn.Module): @@ -113,7 +113,7 @@ def _gpu_worker(use_zero, pas, plan_ngpus, runtime_ngpus): # d. compare the full state dict in step a and the merged state dict in step c. They should be the same. init_distributed() compute_config = ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ckpt') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ckpt_{PYTEST_RUN_ID}') as tempdir: if torch.distributed.get_rank() == 0: tempdir.mkdir(parents=True, exist_ok=True) _train_raw(_create_cube_module(pas, compute_config, tempdir, 'raw'), tempdir) diff --git a/tests/parallel_module/test_ddp.py b/tests/parallel_module/test_ddp.py index b43e0d20..9a4a179a 100644 --- a/tests/parallel_module/test_ddp.py +++ b/tests/parallel_module/test_ddp.py @@ -17,13 +17,13 @@ import numpy as np -from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer, merge_state_dicts from nnscaler.runtime.module import ParallelModule from nnscaler.runtime.gnorm import calcuate_gnorm from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -206,7 +206,7 @@ def _gpu_worker_ddp(update_freq): def _gpu_worker_cube(pas, plan_ngpus, runtime_ngpus, update_freq, use_zero): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_{PYTEST_RUN_ID}') as tempdir: compiled_module = _create_cube_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=use_zero), tempdir @@ -261,6 +261,36 @@ def _compare_weights(orig0, compiled0, compiled1, fc1_fullmap, fc2_fullmap, fc1_ # print(f'key: {k}, max diff: {torch.max(torch.abs(orig0[k] - v))}') assert torch.allclose(v, orig0[k], rtol=1e-4, atol=1e-4) +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('update_freq', [1, 2, 4]) +def test_zero3(update_freq): + zero3_results = launch_torchrun(2, _gpu_worker_cube, 'dp', 1, 2, update_freq, 3) + zero1_results = launch_torchrun(2, _gpu_worker_cube, 'dp', 1, 2, update_freq, 1) + no_zero_results = launch_torchrun(2, _gpu_worker_cube, 'dp', 1, 2, update_freq, 0) + + zero3_results0: List[StepResult] + zero3_results1: List[StepResult] + zero1_results0: List[StepResult] + zero1_results1: List[StepResult] + no_zero_results0: List[StepResult] + no_zero_results1: List[StepResult] + + zero3_results0, zero3_results1 = zero3_results[0][0], zero3_results[1][0] + zero1_results0, zero1_results1 = zero1_results[0][0], zero1_results[1][0] + no_zero_results0, no_zero_results1 = no_zero_results[0][0], no_zero_results[1][0] + + for r0, r1 in [ + (zero3_results0, zero1_results0), (zero1_results0, no_zero_results0), + (zero3_results1, zero1_results1), (zero1_results1, no_zero_results1), + ]: + # have the same input + assert len(r0) == len(r1) # iteration count + for i in range(len(r0)): + a, b = r0[i], r1[i] + assert torch.equal(a.pred, b.pred) # pred + assert torch.equal(a.loss, b.loss) # loss + assert torch.equal(a.gnorm, b.gnorm) # gnorm + @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') @pytest.mark.parametrize('update_freq', [1, 2, 4]) diff --git a/tests/parallel_module/test_e2e_detach_loss.py b/tests/parallel_module/test_e2e_detach_loss.py index f2284af9..4b4cc914 100644 --- a/tests/parallel_module/test_e2e_detach_loss.py +++ b/tests/parallel_module/test_e2e_detach_loss.py @@ -19,7 +19,7 @@ from nnscaler.ir.operator import IRFwOperation, IRDataOperation from nnscaler.graph.segment import IRSegment from nnscaler.graph.schedule.predefined import PredefinedSched -from tests.utils import clear_dir_on_rank0, init_random +from tests.utils import clear_dir_on_rank0, init_random, PYTEST_RUN_ID from tests.launch_torchrun import torchrun from tests.parallel_module.test_gencode import _gencode_contains @@ -94,7 +94,7 @@ def worker_pipeline_2x2(model_cls): torch.cuda.manual_seed(0) trace_data = torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_detach_loss_pp_2x2') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_detach_loss_pp_2x2_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'x': trace_data}, @@ -159,7 +159,7 @@ def worker_pipeline_2(model_cls): torch.cuda.manual_seed(0) trace_data = torch.randn([2048, 4096], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_detach_loss_pp_2') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_detach_loss_pp_2_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'x': trace_data}, diff --git a/tests/parallel_module/test_end2end.py b/tests/parallel_module/test_end2end.py index c38d4403..a9b1f8d9 100644 --- a/tests/parallel_module/test_end2end.py +++ b/tests/parallel_module/test_end2end.py @@ -23,7 +23,7 @@ from nnscaler.parallel import ComputeConfig, build_optimizer, parallelize, merge_state_dicts from .common import assert_equal, init_distributed, PASMegatron, init_random from ..launch_torchrun import clone_to_cpu_recursively, launch_torchrun -from ..utils import replace_all_device_with, clear_dir_on_rank0 +from ..utils import replace_all_device_with, clear_dir_on_rank0, PYTEST_RUN_ID from .test_checkpoint import End2EndMLP @@ -111,7 +111,7 @@ def gpu_worker_cube_general(runtime_ngpus, plan_ngpus, policy, nstages=None, nmi init_random() nstages = nstages or plan_ngpus nmicros = nmicros or plan_ngpus - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_end2end') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_end2end_{PYTEST_RUN_ID}') as tempdir: init_random() model = model_cls() model = parallelize( @@ -204,14 +204,14 @@ def allclose(a, b, atol=1e-6, rtol=1e-6): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_end2end(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') - model = MLP() - ga4_result = _train_ga(model, 4) # micro_batch_size = 4 - assert len(ga4_result) == 16 - # will be used for comparision when zero_use_reduce_scatter is True - ga4_result_without_grads = [] - for i in range(len(ga4_result)): - ga4_result_without_grads.append([ga4_result[i][1], ga4_result[i][2]]) + with torch.device('cuda:0'): + model = MLP() + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 16 + # will be used for comparision when zero_use_reduce_scatter is True + ga4_result_without_grads = [] + for i in range(len(ga4_result)): + ga4_result_without_grads.append([ga4_result[i][1], ga4_result[i][2]]) cube2_results = launch_torchrun(4, gpu_worker_cube, 4, 2, 'hybrid') # micro_batch_size = 4 for _, v in cube2_results.items(): @@ -311,14 +311,14 @@ def __init__(self): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_pipeline_shared(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') - model = MLPShared() - ga4_result = _train_ga(model, 4) # micro_batch_size = 4 - assert len(ga4_result) == 16 - for step in range(len(ga4_result)): - # fake shared weights for later compare - ga4_result[step][0]['layers.5.weight'] = ga4_result[step][0]['layers.0.weight'] - ga4_result[step][1]['layers.5.weight'] = ga4_result[step][1]['layers.0.weight'] + with torch.device('cuda:0'): + model = MLPShared() + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 16 + for step in range(len(ga4_result)): + # fake shared weights for later compare + ga4_result[step][0]['layers.5.weight'] = ga4_result[step][0]['layers.0.weight'] + ga4_result[step][1]['layers.5.weight'] = ga4_result[step][1]['layers.0.weight'] with pytest.raises(ValueError, match='is not supported in training mode'): ComputeConfig( @@ -356,10 +356,10 @@ def test_pipeline_shared(): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 8, reason='lack of gpu devices') def test_pipeline(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') - model = MLP() - ga4_result = _train_ga(model, 4) # micro_batch_size = 4 - assert len(ga4_result) == 16 + with torch.device('cuda:0'): + model = MLP() + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 16 # pp_size = 2 # tp_size = 2 @@ -416,7 +416,7 @@ def _train_cube_one_sample(model: ParallelModule, mbs): def gpu_worker_cube_one_sample(): init_distributed() init_random() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_end2end') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_end2end_{PYTEST_RUN_ID}') as tempdir: init_random() model = MLP() model = parallelize( @@ -441,12 +441,12 @@ def gpu_worker_cube_one_sample(): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_loss_scaling(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') - model = MLP() - ga4_result = _train_ga(model, 1, 1) - assert len(ga4_result) == 1 - ga4_grads = ga4_result[0][0] - scaled_ga4_grads = {n: g * 2.0 for n, g in ga4_grads.items()} + with torch.device('cuda:0'): + model = MLP() + ga4_result = _train_ga(model, 1, 1) + assert len(ga4_result) == 1 + ga4_grads = ga4_result[0][0] + scaled_ga4_grads = {n: g * 2.0 for n, g in ga4_grads.items()} cube2_results = launch_torchrun(2, gpu_worker_cube_one_sample) cube2_result = merge_cube_result({k: v for k, v in cube2_results.items()}) diff --git a/tests/parallel_module/test_end2end_mix_precision.py b/tests/parallel_module/test_end2end_mix_precision.py index c1922b5b..f1785acd 100644 --- a/tests/parallel_module/test_end2end_mix_precision.py +++ b/tests/parallel_module/test_end2end_mix_precision.py @@ -26,7 +26,7 @@ from .test_checkpoint import End2EndMLP from .test_end2end import allclose, merge_cube_result -from ..utils import init_parameter, clear_dir_on_rank0 +from ..utils import init_parameter, clear_dir_on_rank0, PYTEST_RUN_ID DATA_SIZE = 16 @@ -136,7 +136,7 @@ def gpu_worker_cube(use_zero=False, async_reducer=False, use_bucket=False): plan_ngpus = 2 runtime_ngpus = 4 nmicros = plan_ngpus - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_end2end_mp') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_end2end_mp_{PYTEST_RUN_ID}') as tempdir: init_random() model = MPModule() model = parallelize( @@ -171,12 +171,12 @@ def gpu_worker_cube(use_zero=False, async_reducer=False, use_bucket=False): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 4, reason='lack of gpu devices') def test_mixed_precision(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') init_random() - model = MPModule() - torch.save(model.state_dict(), 'model.pth') - ga4_result = _train_ga(model, 4) # micro_batch_size = 4 - assert len(ga4_result) == 4 + with torch.device('cuda:0'): + model = MPModule() + torch.save(model.state_dict(), 'model.pth') + ga4_result = _train_ga(model, 4) # micro_batch_size = 4 + assert len(ga4_result) == 4 cube2_results_non_pipeline = {} for use_async_reducer in [False, True]: diff --git a/tests/parallel_module/test_gencode.py b/tests/parallel_module/test_gencode.py index b7117563..dbca3de0 100644 --- a/tests/parallel_module/test_gencode.py +++ b/tests/parallel_module/test_gencode.py @@ -6,6 +6,7 @@ import re from contextlib import nullcontext from typing import Union +from functools import partial import torch import torch.nn.functional as F @@ -17,7 +18,8 @@ from nnscaler.graph.function.pyfunc import IRPyFunc from nnscaler.graph.parser.mapping import SignFx2Op from nnscaler.ir.cten import IR, IRObject -from nnscaler.parallel import parallelize, ComputeConfig, CubeModule, _gen_graph +from nnscaler.parallel import _load_parallel_module_class, parallelize, ComputeConfig, CubeModule, _gen_graph +from nnscaler.utils import mark_dynamic from .common import init_distributed from ..launch_torchrun import launch_torchrun @@ -67,6 +69,7 @@ def __init__(self): def forward(self, x): return x[:2] + @replace_all_device_with('cpu') def test_codegen_slice(): with tempfile.TemporaryDirectory() as tempdir: @@ -208,9 +211,8 @@ def _gencode_unused_args_worker(tempdir): m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), m=1) ) - with pytest.raises(ValueError): - # y must be None - m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) + # if y is not None, we will not raise error now. + m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @@ -256,9 +258,6 @@ def _gencode_unused_args_worker2(tempdir): with pytest.raises(TypeError, match='.*must be Tensor, not NoneType.*'): # raise by torch.add, as m is None m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])) - with pytest.raises(ValueError): - # y must be None - m_new(torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), 1) @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') @@ -394,11 +393,11 @@ def print_gencode(cubesave_dir, module_class, index=0): print(filecontent) -def _gencode_contains(cubesave_dir, module_class, index, search_re): +def _gencode_contains(cubesave_dir, module_class, index, search_re, *, instance_name=None): from nnscaler.parallel import _PARALLEL_MODULE_NAMESPACE, _get_full_qualified_name, _DEFAULT_INSTANCE_NAME from pathlib import Path import re - namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{_DEFAULT_INSTANCE_NAME}' + namespace = f'{_PARALLEL_MODULE_NAMESPACE}.{_get_full_qualified_name(module_class)}.{instance_name or _DEFAULT_INSTANCE_NAME}' outdir: Path = cubesave_dir / Path(namespace.replace('.', '/').strip('/')) filecontent = (outdir /f'gencode{index}.py').read_text() matches = re.findall(search_re, filecontent) @@ -453,8 +452,13 @@ def test_codegen_getitem(): gen_savedir=tempdir, load_module=False, ) - assert _gencode_contains(tempdir, GetItemModule, 0, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') - assert _gencode_contains(tempdir, GetItemModule, 1, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') + + assert _gencode_contains(tempdir, GetItemModule, 0, r"_operator.getitem\(batched_data.*, 'x'\)") + assert _gencode_contains(tempdir, GetItemModule, 1, r"_operator.getitem\(batched_data.*, 'x'\)") + # data_x.size() will be expanded to a list of ir objects, + # so no slice operation will be generated. + assert not _gencode_contains(tempdir, GetItemModule, 0, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') + assert not _gencode_contains(tempdir, GetItemModule, 1, r'_operator.getitem\(.*, slice\(None, 2, None\)\)') assert m_new is None @@ -654,6 +658,37 @@ def test_codegen_dictget(): assert m_new is None +class NonConstModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + shape = x.shape + z = torch.randn(shape) + shape = z.shape + return z + shape[0] + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dims', [[], [0]]) +def test_codegen_nonconst(dynamic_dims): + with tempfile.TemporaryDirectory() as tempdir: + m_new = parallelize( + NonConstModule(), + {'x': mark_dynamic(torch.tensor([[[1.0], [2.0], [3.0], [6.0]]]), dynamic_dims)}, # shape 1/4/1 + 'dp', + ComputeConfig(1, 1, constant_folding=True), + gen_savedir=tempdir, + load_module=False + ) + if not dynamic_dims: + # shape[0] is constant 1, so can be folded to constant 1 + assert _gencode_contains(tempdir, NonConstModule, 0, r'torch.add\(.*, 1, alpha=1\)') + else: + # shape[0] is dynamic, so cannot be folded to constant 1 + assert not _gencode_contains(tempdir, NonConstModule, 0, r'torch.add\(.*, 1, alpha=1\)') + + class CloneModule(torch.nn.Module): def __init__(self): super().__init__() @@ -733,6 +768,7 @@ def __init__(self): def forward(self, a, b): return torch.min(a, b) + def _gencode_min_function_worker(tempdir): init_distributed() m_new = parallelize( @@ -1657,7 +1693,7 @@ def check_op(*names): for name in names: code = add_codes.pop(0) if name in not_folded_names: - assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, getitem_.*, alpha=1\)', code) + assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, get.*, alpha=1\)', code) else: assert re.match(r'\s*add_.* = torch\.add\((linear|add)_.*, 2, alpha=1\)', code) @@ -1726,7 +1762,7 @@ def test_fold_constant(tmp_path, fold_input): else: # add_27 = torch.add(linear_30, getitem_20, alpha=1) assert _gencode_contains(tmp_path, CCFModule2, 0, - r'add_.* = torch\.add\(linear_.*, getitem_.*, alpha=1\)') + r'add_.* = torch\.add\(linear_.*, get.*, alpha=1\)') # b = b * ashape3 # mul_2_51 = torch.mul(mul_1_57, add_38) assert _gencode_contains(tmp_path, CCFModule2, 0, @@ -1963,3 +1999,303 @@ def test_codegen_forward_error_compile(tmp_path): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of GPU devices') def test_codegen_forward_error(tmp_path): launch_torchrun(2, _gencode_forward_error_worker, tmp_path) + + +class WeightModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weights = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, input): + input = input + self.weights + out = input @ self.weights + return out + + +class WeightModel2(torch.nn.Module): + def __init__(self): + super().__init__() + self.weights = WeightModel() + + def forward(self, input): + return self.weights(input) + + +def pas_weight(graph, cfg, with_auto_multiref=True): + from nnscaler.ir import IRFwOperation, IRDataOperation + from nnscaler.policies import _tp, _replica, auto_multiref + ngpus = cfg.plan_ngpus + if with_auto_multiref: + auto_multiref(graph) + for node in graph.select(ntype=(IRFwOperation, IRDataOperation)): + if node.name == 'add': + _tp(graph, node, list(range(ngpus)), 1, 0) + elif node.name == 'matmul': + _tp(graph, node, list(range(ngpus)), 1, 0) + else: + _replica(graph, node, list(range(ngpus))) + return graph + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('with_auto_multiref', [True, False]) +def test_weight_partition(tmp_path, with_auto_multiref): + """ + If auto_multiref is not applied, the weight will correctly partitioned + If auto_multiref is applied, the weight will be replicated as a whole + """ + input = torch.randn((4, 4)) + instance_name = f'with_auto_multiref_{with_auto_multiref}' + + dummy_input = {'input': input} + + m = WeightModel2() + m.train() + + parallelize( + m, + dummy_input, + partial(pas_weight, with_auto_multiref=with_auto_multiref), + ComputeConfig(2, 2), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + instance_name=instance_name, + ) + + module_class = _load_parallel_module_class(WeightModel2, gen_savedir=tmp_path, instance_name=instance_name, rank=0) + + if with_auto_multiref: + for rank in range(2): + fullmap = module_class.attr_meta_maps[rank] + assert fullmap[list(fullmap.keys())[0]].sub_shape == (4, 4) + else: + for rank in range(2): + fullmap = module_class.attr_meta_maps[rank] + assert fullmap[list(fullmap.keys())[0]].sub_shape == (2, 4) + +class DynamicInputModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weights = torch.nn.Parameter(torch.randn(1, 1)) + + def forward(self, input): + return input + self.weights + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('dynamic_dims', [[], [0, 1]]) +def test_dynamic_dim_partition(tmp_path, dynamic_dims): + input = mark_dynamic(torch.randn((4, 4)), dynamic_dims) + dummy_input = {'input': input} + instance_name=f'{"no" if not dynamic_dims else ""}_dynamic_dims' + + m = DynamicInputModel() + m.train() + + parallelize( + m, + dummy_input, + 'tp', + ComputeConfig(2, 2), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + instance_name=instance_name, + ) + if dynamic_dims: + # no partition for dynamic input + assert not _gencode_contains(tmp_path, DynamicInputModel, 0, r'nnscaler.runtime.adapter.nn.split_allgather', instance_name=instance_name) + else: + assert _gencode_contains(tmp_path, DynamicInputModel, 0, r'nnscaler.runtime.adapter.nn.split_allgather', instance_name=instance_name) + + +@replace_all_device_with('cpu') +def test_zero3_normal(tmp_path): + from tests.parallel_module.test_end2end import MLP + m = MLP(2, 2) + dummy_input = { + 'data': torch.randn( + 2, 2), + 'target': torch.rand( + 2, 2) + } + m.train() + parallelize( + m, + {'data': dummy_input}, + 'dp', + ComputeConfig(1, 2, use_zero=3), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + # code looks like: + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.prefetch_param\(self\.layers_0_weight_\d+\)') + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.postevict_param\(self\.layers_0_weight_\d+\)') + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.backward_postevict_param\(.*, self\.layers_0_weight_\d+, 1\)') + assert _gencode_contains(tmp_path, MLP, 0, + r'self\.backward_prefetch_param\(.*, self\.layers_0_weight_\d+, 1\)') + + # def segment35_impl(self, data_23): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 46, in forward, x = data['data'] + # getitem_25 = _operator.getitem(data_23, 'data') + # self.prefetch_param(self.layers_0_weight_26) + # getitem_25 = self.backward_postevict_param(getitem_25, self.layers_0_weight_26, 1) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 48, in forward, x = layer(x) + # linear_27 = torch.nn.functional.linear(getitem_25, self.layers_0_weight_26, bias=None) + # self.postevict_param(self.layers_0_weight_26) + # linear_27 = self.backward_prefetch_param(linear_27, self.layers_0_weight_26, 1) + # del getitem_25 + # self.prefetch_param(self.layers_1_weight_28) + # linear_27 = self.backward_postevict_param(linear_27, self.layers_1_weight_28, 2) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 48, in forward, x = layer(x) + # linear_1_29 = torch.nn.functional.linear(linear_27, self.layers_1_weight_28, bias=None) + # self.postevict_param(self.layers_1_weight_28) + # linear_1_29 = self.backward_prefetch_param(linear_1_29, self.layers_1_weight_28, 2) + # del linear_27 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 49, in forward, x = torch.sigmoid(x) + # sigmoid_30 = torch.sigmoid(linear_1_29) + # del linear_1_29 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 50, in forward, loss = self.loss_fn(x, data['target']) + # getitem_1_31 = _operator.getitem(data_23, 'target') + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_end2end.py", line 50, in forward, loss = self.loss_fn(x, data['target']) + # binary_cross_entropy_24 = torch.nn.functional.binary_cross_entropy(sigmoid_30, getitem_1_31, weight=None, reduction='mean') + # del sigmoid_30, getitem_1_31 + # return binary_cross_entropy_24 + + # def segment35(self, data_23): + # with self.save_params_hooks(): + # return self.segment35_impl(data_23) + + +class SoloOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4, 4)) + self.p = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x): + x = self.scale.sum() + x + self.p + return torch.sum(x) + + +def launch_zero3_run_solo_param(tmp_path): + init_distributed() + m = SoloOpModule() + dummy_input = torch.randn(4, 4) + m.train() + m_new = parallelize( + m, + {'x': dummy_input}, + 'dp', + ComputeConfig(1, 2, use_zero=3), + gen_savedir=tmp_path, + load_module=True, + reuse='override', + ) + loss = m_new(dummy_input) + loss.backward() + # scale can't be evicited with backward hook + assert len(m_new._backward_prefetched_params) == 1 + # but it should have been evicted in reducer. + assert list(m_new._backward_prefetched_params.keys())[0].shape == (8,) + assert not _gencode_contains(tmp_path, SoloOpModule, 0, + r'self\.backward_postevict_param\(.*, self\.scale_\d+, \d+\)') + assert _gencode_contains(tmp_path, SoloOpModule, 0, + r'self\.backward_postevict_param\(.*, self\.p_\d+, \d+\)') + # code looks like: + # def segment32_impl(self, x_17): + # self.prefetch_param(self.scale_19) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2186, in forward, x = self.scale.sum() + x + self.p + # sum_1_20 = torch.sum(self.scale_19) + # self.postevict_param(self.scale_19) + # sum_1_20 = self.backward_prefetch_param(sum_1_20, self.scale_19, 0) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2186, in forward, x = self.scale.sum() + x + self.p + # add_21 = torch.add(sum_1_20, x_17, alpha=1) + # del x_17, sum_1_20 + # self.prefetch_param(self.p_22) + # add_21 = self.backward_postevict_param(add_21, self.p_22, 2) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2186, in forward, x = self.scale.sum() + x + self.p + # add_1_23 = torch.add(add_21, self.p_22, alpha=1) + # self.postevict_param(self.p_22) + # add_1_23 = self.backward_prefetch_param(add_1_23, self.p_22, 2) + # del add_21 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2187, in forward, return torch.sum(x) + # sum_2_18 = torch.sum(add_1_23) + # del add_1_23 + # return sum_2_18 + + # def segment32(self, x_17): + # with self.save_params_hooks(): + # return self.segment32_impl(x_17) + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of GPU devices') +def test_zero3_run_solo_param(tmp_path): + launch_torchrun(2, launch_zero3_run_solo_param, tmp_path) + + +@nnscaler.register_op('*, *, *, * -> *, *, *, *') +def _zero3_multi_inout(x, y, z, w): + return x + 1, y + 1, z + 1, w + 1 + + +class Zero3MultiInoutModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.p = torch.nn.Parameter(torch.randn(4, 4)) + self.q = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x, y): + return _zero3_multi_inout(x, y, self.p, self.q) + + +@replace_all_device_with('cpu') +def test_zero3_multi_inout(tmp_path): + m = Zero3MultiInoutModule() + m.train() + m_new = parallelize( + m, + {'x': torch.randn(4, 4), 'y': torch.randn(4, 4)}, + 'dp', + ComputeConfig(1, 2, use_zero=3), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert len(_gencode_contains(tmp_path, Zero3MultiInoutModule, 0, + 'self.backward_prefetch_param')) == 8 + assert len(_gencode_contains(tmp_path, Zero3MultiInoutModule, 0, + 'self.backward_postevict_param')) == 4 + # code looks like: + # def segment34_impl(self, x_25, y_26): + # self.prefetch_param(self.p_31) + # x_25 = self.backward_postevict_param(x_25, self.p_31, 0) + # y_26 = self.backward_postevict_param(y_26, self.p_31, 0) + # self.prefetch_param(self.q_32) + # x_25 = self.backward_postevict_param(x_25, self.q_32, 0) + # y_26 = self.backward_postevict_param(y_26, self.q_32, 0) + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 2259, in forward, return _zero3_multi_inout(x, y, self.p, self.q) + # _zero3_multi_inout_27, _zero3_multi_inout_28, _zero3_multi_inout_29, _zero3_multi_inout_30 = tests.parallel_module.test_gencode._zero3_multi_inout(x_25, y_26, self.p_31, self.q_32) + # self.postevict_param(self.p_31) + # self.postevict_param(self.q_32) + # _zero3_multi_inout_27 = self.backward_prefetch_param(_zero3_multi_inout_27, self.p_31, 0) + # _zero3_multi_inout_27 = self.backward_prefetch_param(_zero3_multi_inout_27, self.q_32, 0) + # _zero3_multi_inout_28 = self.backward_prefetch_param(_zero3_multi_inout_28, self.p_31, 0) + # _zero3_multi_inout_28 = self.backward_prefetch_param(_zero3_multi_inout_28, self.q_32, 0) + # _zero3_multi_inout_29 = self.backward_prefetch_param(_zero3_multi_inout_29, self.p_31, 0) + # _zero3_multi_inout_29 = self.backward_prefetch_param(_zero3_multi_inout_29, self.q_32, 0) + # _zero3_multi_inout_30 = self.backward_prefetch_param(_zero3_multi_inout_30, self.p_31, 0) + # _zero3_multi_inout_30 = self.backward_prefetch_param(_zero3_multi_inout_30, self.q_32, 0) + # del x_25, y_26 + # return _zero3_multi_inout_27, _zero3_multi_inout_28, _zero3_multi_inout_29, _zero3_multi_inout_30 + + # def segment34(self, x_25, y_26): + # with self.save_params_hooks(): + # return self.segment34_impl(x_25, y_26) + assert True diff --git a/tests/parallel_module/test_gencode_ctx_manager.py b/tests/parallel_module/test_gencode_ctx_manager.py index eba20786..fc4f07c1 100644 --- a/tests/parallel_module/test_gencode_ctx_manager.py +++ b/tests/parallel_module/test_gencode_ctx_manager.py @@ -12,7 +12,7 @@ from .common import init_distributed, init_random from .test_end2end import merge_cube_result from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class CtxManagerModel(torch.nn.Module): @@ -77,7 +77,7 @@ def check_ctx_manager_codegen(tempdir): # use_scheduler = False # nmicros_per_scheduler_step = 1 # rank = 0 - + # def __init__(self, init_params=True, *, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False): # super().__init__() # # communication groups @@ -85,17 +85,17 @@ def check_ctx_manager_codegen(tempdir): # self.init_group(ranks=[1, 3]) # self.init_group(ranks=[0, 1]) # self.init_group(ranks=[2, 3]) - + # self.register_parameter('param_1_62', torch.nn.Parameter(torch.empty((16, 16), dtype=torch.float32))) # self.add_full_map('param_1_62', 5, True, 'param_1', (16, 16), (slice(0, 16, None), slice(0, 16, None)), 1) - - + + # self.wreducer312 = nnscaler.runtime.adapter.Reducer(ranks=[0, 2], reduce_op='sum', async_op=async_op, zero=False, max_bucket_size_bytes=max_bucket_size_bytes, zero_use_reduce_scatter=zero_use_reduce_scatter, zero_ngroups=1) # self.wreducer312.add_param(self.param_1_62) # self.add_reducer(self.wreducer312) - + # self._post_init(init_params) - + # def segment308(self, x_75, y_78): # # auto_multiref # param_1_106, param_1_107, param_1_108, param_1_109, param_1_110 = nnscaler.runtime.function.multiref(self.param_1_62, times=5) @@ -117,12 +117,12 @@ def check_ctx_manager_codegen(tempdir): # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref # matmul_1_202, matmul_1_228 = nnscaler.runtime.function.multiref(matmul_1_182, times=2) # del matmul_1_182 - + # with torch.no_grad(): # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 24, in forward, r_3 = torch.matmul(r_1, self.param_1) # matmul_2_196 = torch.matmul(matmul_194, param_1_106) # del param_1_106, matmul_194 - + # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref # matmul_2_216, matmul_2_242 = nnscaler.runtime.function.multiref(matmul_2_196, times=2) # del matmul_2_196 @@ -133,12 +133,12 @@ def check_ctx_manager_codegen(tempdir): # # create at IRAdapterGener:autoref, comment before transformation: auto_multiref # matmul_3_252, matmul_3_218 = nnscaler.runtime.function.multiref(matmul_3_204, times=2) # del matmul_3_204 - + # with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True, cache_enabled=True): # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 28, in forward, r_5 = r_3 * r_4 # mul_220 = torch.mul(matmul_2_216, matmul_3_218) # del matmul_2_216, matmul_3_218 - + # # File "/scratch/nishang/MagicCube/tests/parallel_module/test_gencode_ctx_manager.py", line 29, in forward, r = r_1 * r_2 * r_3 * r_4 * r_5 # mul_1_230 = torch.mul(matmul_226, matmul_1_228) # del matmul_226, matmul_1_228 @@ -161,11 +161,11 @@ def check_ctx_manager_codegen(tempdir): # norm_61 = torch.norm(matmul_4_72, p='fro', dim=None, keepdim=False, out=None, dtype=None) # del matmul_4_72 # return norm_61 - + # def reducer312(self): # self.wreducer312.sync_grads() - # return - + # return + # def _forward_impl(self, x, y): # norm_61 = self.segment308(x, y) # return norm_61 @@ -257,7 +257,7 @@ def _train_cube_one_sample(model: ParallelModule, mbs): def gpu_worker_cube_one_sample(): init_distributed() init_random() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_ctx_manager') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_ctx_manager_{PYTEST_RUN_ID}') as tempdir: init_random() model = CtxManagerModel() model = parallelize( @@ -301,12 +301,12 @@ def _train_ga(model, update_freq, data_size): @pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') def test_loss_scaling(): torch.cuda.set_device(0) - torch.set_default_device(f'cuda:0') init_random() - model = CtxManagerModel() - ga4_result = _train_ga(model, 1, 1) - assert len(ga4_result) == 1 - ga4_grads = ga4_result[0][0] + with torch.device('cuda:0'): + model = CtxManagerModel() + ga4_result = _train_ga(model, 1, 1) + assert len(ga4_result) == 1 + ga4_grads = ga4_result[0][0] cube2_results = launch_torchrun(2, gpu_worker_cube_one_sample) cube2_result = merge_cube_result({k: v for k, v in cube2_results.items()}) diff --git a/tests/parallel_module/test_gencode_einops.py b/tests/parallel_module/test_gencode_einops.py new file mode 100644 index 00000000..2a9d2aa3 --- /dev/null +++ b/tests/parallel_module/test_gencode_einops.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tempfile +import functools +from einops import rearrange +import torch +import pytest + +from nnscaler import parallelize, ComputeConfig +from nnscaler.graph import parser +from nnscaler.graph.tracer import ConcreteTracer + +from tests.utils import replace_all_device_with +from .test_gencode import _gencode_contains, print_gencode + + +class RearrangeModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(3, 3) + + def forward(self, x, y): + return self.linear(x) + rearrange(y, '(h w) -> h w', h=3, w=3) + f(3) + + +def log_f(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + print(f"Function '{func.__name__}' called") + return func(*args, **kwargs) + return wrapper + + +@log_f +@functools.cache +def f(x: int) -> int: + return x * 2 + + +@replace_all_device_with('cpu') +def test_trace_rearrange(): + import gc + def _convert(): + model = RearrangeModule() + parser.to_fx_graph(model, {'x': torch.randn(3, 3), 'y': torch.randn(9)}) + gc.collect() + + _convert() + for obj in gc.get_objects(): + # einops is using functools.cache + # will leak memory if not properly handle it. + assert not isinstance(obj, ConcreteTracer) + + +@replace_all_device_with('cpu') +def test_codegen_rearrange(): + with tempfile.TemporaryDirectory() as tempdir: + parallelize( + RearrangeModule(), + {'x': torch.randn(3, 3), 'y': torch.randn(9)}, + 'dp', + ComputeConfig(1, 1), + gen_savedir=tempdir, + load_module=False + ) + # parallelize will succeed. + assert True + + +class RearrangeModule2(torch.nn.Module): + def __init__(self, shape: tuple[int, ...]): + super().__init__() + self.shape = shape + self.weight = torch.nn.Parameter(torch.ones(self.shape)) + + def forward(self, x): + bsz = x.size(0) + x = rearrange(x, 'n l d -> (n l) d', n=bsz) + return x + self.weight + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize("constant_folding", [True, False]) +def test_rearrange2(tmp_path, constant_folding): + parallelize( + RearrangeModule2(4), + {'x': torch.randn(4, 4, 4)}, + 'dp', + ComputeConfig(1, 1, constant_folding=constant_folding), + gen_savedir=tmp_path, + load_module=False + ) + # parallelize will succeed. + assert True + + # code will look like this when constant_folding=True + # def segment22(self, x_25): + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/_backends.py", line 93, in reshape, return x.reshape(shape) + # reshape_27 = torch.Tensor.reshape(x_25, shape=(16, 4)) + # del x_25 + # # File "/data/weijiangxu/nnscaler/tests/parallel_module/test_gencode_einops.py", line 80, in forward, return x + self.weight + # add_26 = torch.add(reshape_27, self.weight_28, alpha=1) + # del reshape_27 + # return add_26 + + + # code will look like this when constant_folding=False + # def segment25(self, x_28): + # # File "/data/weijiangxu/nnscaler/tests/parallel_module/test_gencode_einops.py", line 78, in forward, bsz = x.size(0) + # size_21 = torch.Tensor.size(x_28, dim=0) + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/_backends.py", line 90, in shape, return x.shape + # im_output_36 = builtins.getattr(x_28, 'shape') + # getattr_3_22 = im_output_36[0] + # getattr_3_23 = im_output_36[1] + # getattr_3_24 = im_output_36[2] + # del im_output_36 + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/einops.py", line 33, in _product, result *= element + # mul_1_25 = _operator.mul(1, size_21) + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/einops.py", line 33, in _product, result *= element + # imul_26 = _operator.imul(mul_1_25, getattr_3_23) + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/einops.py", line 33, in _product, result *= element + # mul_2_27 = _operator.mul(1, getattr_3_24) + # # File "/data/weijiangxu/uvenv/kosmos3.10/lib/python3.10/site-packages/einops/_backends.py", line 93, in reshape, return x.reshape(shape) + # reshape_30 = torch.Tensor.reshape(x_28, shape=(imul_26, mul_2_27)) + # del x_28 + # # File "/data/weijiangxu/nnscaler/tests/parallel_module/test_gencode_einops.py", line 80, in forward, return x + self.weight + # add_29 = torch.add(reshape_30, self.weight_31, alpha=1) + # del reshape_30 + # return add_29 + diff --git a/tests/parallel_module/test_gencode_kwargs.py b/tests/parallel_module/test_gencode_kwargs.py new file mode 100644 index 00000000..85a1b3c5 --- /dev/null +++ b/tests/parallel_module/test_gencode_kwargs.py @@ -0,0 +1,189 @@ +import nnscaler +from nnscaler import parallelize, ComputeConfig + +import torch + +from .test_gencode import _gencode_contains, replace_all_device_with, print_gencode + + +# note: annotation is wrong. test only +@nnscaler.register_op('*, * -> 1', name='kw_operator') +def kw_operator(x: torch.Tensor, y: torch.Tensor, **kwargs) -> torch.Tensor: + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + x = (x * a + b) @ y + return torch.sum(x) + +# note: annotation is wrong. test only +@nnscaler.register_op('*, * -> 1', name='kw_operator2') +def kw_operator2(x: torch.Tensor, y: torch.Tensor, kwargs) -> torch.Tensor: + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + x = (x * a + b) @ y + return torch.sum(x) + + +class KwargsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x, **kwargs): + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + c = kwargs['c'] + return kw_operator(x, self.scale, **kwargs) \ + + kw_operator2(x, self.scale, kwargs) + a + b + c + + +@replace_all_device_with('cpu') +def test_kwargs(tmp_path): + m = KwargsModule() + m.train() + parallelize( + m, + {'x': torch.randn(4, 4), 'a': 3, 'c': 4, 'd': 5}, + 'dp', + ComputeConfig(1, 1, constant_folding=False), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'a\'\)') + assert not _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'b\'\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'c\'\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'d\'\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r'tests.parallel_module.test_gencode_kwargs.kw_operator\(x_\d+, self.scale_\d+, a=getitem_[\d_]+, c=getitem_[\d_]+, d=getitem_[\d_]+\)') + assert _gencode_contains(tmp_path, KwargsModule, 0, + r"tests.parallel_module.test_gencode_kwargs.kw_operator2\(x_\d+, self.scale_\d+, kwargs=\{'a': getitem_[\d_]+, 'c': getitem_[\d_]+, 'd': getitem_[\d_]+\}\)") + # code looks like: + # def segment49(self, x_31, **kwargs_6): + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2338, in test_kwargs, parallelize( + # getitem_28 = _operator.getitem(kwargs_6, 'a') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2338, in test_kwargs, parallelize( + # getitem_1_29 = _operator.getitem(kwargs_6, 'c') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2338, in test_kwargs, parallelize( + # getitem_2_30 = _operator.getitem(kwargs_6, 'd') + # # created at IRAdapterGener:local_consumer_multiref + # x_52, x_56 = nnscaler.runtime.function.multiref(x_31, times=2) + # del x_31 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # kw_operator_34 = tests.parallel_module.test_gencode_kwargs.kw_operator(x_52, self.scale_33, a=getitem_28, c=getitem_1_29, d=getitem_2_30) + # del x_52 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # kw_operator2_35 = tests.parallel_module.test_gencode_kwargs.kw_operator2(x_56, self.scale_33, kwargs={'a': getitem_28, 'c': getitem_1_29, 'd': getitem_2_30}) + # del x_56 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_36 = torch.add(kw_operator_34, kw_operator2_35, alpha=1) + # del kw_operator_34, kw_operator2_35 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_1_37 = torch.add(add_36, getitem_28, alpha=1) + # del add_36 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_2_38 = torch.add(add_1_37, 2, alpha=1) + # del add_1_37 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2330, in forward, return kw_operator(x, self.scale, **kwargs) \ + # add_3_32 = torch.add(add_2_38, getitem_1_29, alpha=1) + # del add_2_38 + # return add_3_32 + + # def _forward_impl(self, x, **kwargs): + # add_3_32 = self.segment49(x, **kwargs) + # return add_3_32 + assert True + + +# note: annotation is wrong. test only +@nnscaler.register_op('*, * -> 1', name='dict_operator') +def dict_operator(x: torch.Tensor, y: torch.Tensor, kwargs: dict) -> torch.Tensor: + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + x = (x * a + b) @ y + return torch.sum(x) + + +class DictargsModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.scale = torch.nn.Parameter(torch.randn(4, 4)) + + def forward(self, x, kwargs: dict): + a = kwargs.get('a', 1) + b = kwargs.get('b', 2) + c = kwargs['c'] + return dict_operator(x, self.scale, kwargs) \ + + kw_operator(x, self.scale, **kwargs) + a + b + c + + +@replace_all_device_with('cpu') +def test_dictargs(tmp_path): + m = DictargsModule() + m.train() + parallelize( + m, + {'x': torch.randn(4, 4), 'kwargs': {'a': 3, 'c': 4, 'd': 5}}, + 'dp', + ComputeConfig(1, 1, constant_folding=False), + gen_savedir=tmp_path, + load_module=False, + reuse='override', + ) + assert _gencode_contains(tmp_path, DictargsModule, 0, + r"builtins.dict.get\(kwargs_.*, 'a', 1\)") + assert _gencode_contains(tmp_path, DictargsModule, 0, + r"builtins.dict.get\(kwargs_.*, 'b', 2\)") + assert len(_gencode_contains(tmp_path, DictargsModule, 0, + r"_operator\.getitem\(kwargs_.*, 'a'\)")) == 1 + assert len(_gencode_contains(tmp_path, DictargsModule, 0, + r"_operator\.getitem\(kwargs_.*, 'c'\)")) == 2 + assert _gencode_contains(tmp_path, DictargsModule, 0, + r'_operator\.getitem\(kwargs_.*, \'d\'\)') + assert _gencode_contains(tmp_path, DictargsModule, 0, + r'tests.parallel_module.test_gencode_kwargs.kw_operator\(x_\d+, self.scale_\d+, a=getitem_[\d_]+, c=getitem_[\d_]+, d=getitem_[\d_]+\)') + assert _gencode_contains(tmp_path, DictargsModule, 0, + r"tests.parallel_module.test_gencode_kwargs.dict_operator\(x_\d+, self.scale_\d+, kwargs=kwargs_\d+\)") + # code looks like: + # def segment52(self, x_35, kwargs_6): + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2411, in forward, a = kwargs.get('a', 1) + # get_7 = builtins.dict.get(kwargs_6, 'a', 1) + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2412, in forward, b = kwargs.get('b', 2) + # get_1_8 = builtins.dict.get(kwargs_6, 'b', 2) + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2413, in forward, c = kwargs['c'] + # getitem_31 = _operator.getitem(kwargs_6, 'c') + # # created at IRAdapterGener:local_consumer_multiref + # x_56, x_60 = nnscaler.runtime.function.multiref(x_35, times=2) + # del x_35 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # dict_operator_38 = tests.parallel_module.test_gencode_kwargs.dict_operator(x_56, self.scale_37, kwargs=kwargs_6) + # del x_56 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # getitem_1_32 = _operator.getitem(kwargs_6, 'a') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # getitem_2_33 = _operator.getitem(kwargs_6, 'c') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # getitem_3_34 = _operator.getitem(kwargs_6, 'd') + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # kw_operator_39 = tests.parallel_module.test_gencode_kwargs.kw_operator(x_60, self.scale_37, a=getitem_1_32, c=getitem_2_33, d=getitem_3_34) + # del x_60 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_40 = torch.add(dict_operator_38, kw_operator_39, alpha=1) + # del dict_operator_38, kw_operator_39 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_1_41 = torch.add(add_40, get_7, alpha=1) + # del add_40 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_2_42 = torch.add(add_1_41, get_1_8, alpha=1) + # del add_1_41 + # # File "/data/weijiangxu/MagicCube/tests/parallel_module/test_gencode_kwargs.py", line 2414, in forward, return dict_operator(x, self.scale, kwargs) \ + # add_3_36 = torch.add(add_2_42, getitem_31, alpha=1) + # del add_2_42 + # return add_3_36 + + # def _forward_impl(self, x, kwargs): + # add_3_36 = self.segment52(x, kwargs) + # return add_3_36 diff --git a/tests/parallel_module/test_gencode_torch_compile.py b/tests/parallel_module/test_gencode_torch_compile.py index 62ba534c..4865783c 100644 --- a/tests/parallel_module/test_gencode_torch_compile.py +++ b/tests/parallel_module/test_gencode_torch_compile.py @@ -9,7 +9,7 @@ from nnscaler import parallelize, ComputeConfig, register_op -from tests.utils import replace_all_device_with +from tests.utils import raises_with_cause, replace_all_device_with from .test_gencode import _gencode_contains, print_gencode @@ -182,7 +182,7 @@ def forward(self, x): @pytest.mark.skipif(not torch.cuda.is_available(), reason='lack of gpu devices') def test_codegen_compile_failed_g(): - with pytest.raises(RuntimeError), tempfile.TemporaryDirectory() as tempdir: + with raises_with_cause(RuntimeError, match=".*You must register it to avoid tracing failure..*"), tempfile.TemporaryDirectory() as tempdir: parallelize( Module2(), {'x': torch.randn(3, 3)}, diff --git a/tests/parallel_module/test_inference.py b/tests/parallel_module/test_inference.py index 6b15d8a1..ce9adbdd 100644 --- a/tests/parallel_module/test_inference.py +++ b/tests/parallel_module/test_inference.py @@ -13,7 +13,7 @@ from .common import CubeLinear, init_distributed, init_random from ..launch_torchrun import torchrun -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -62,7 +62,7 @@ def _inference_worker(ngpus, inference_only): init_distributed() init_random() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_inference_test') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_inference_test_{PYTEST_RUN_ID}') as tempdir: model = Module() model.eval() diff --git a/tests/parallel_module/test_init.py b/tests/parallel_module/test_init.py index 3d046393..3278cc0f 100644 --- a/tests/parallel_module/test_init.py +++ b/tests/parallel_module/test_init.py @@ -72,13 +72,13 @@ def test_empty_weights(model_class, tp): model_class, {'x': torch.tensor([[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]])}, 'tp', - ComputeConfig(2, 4, use_zero=True, zero_ngroups=2), + ComputeConfig(2, 8, use_zero=True, zero_ngroups=2), gen_savedir=tempdir, reuse='match', load_module=False, instance_name=instance_name, ) - for i in range(4): + for i in range(8): module_class = _load_parallel_module_class(model_class, gen_savedir=tempdir, instance_name=instance_name, rank=i) m = new_empty(module_class) assert m.rank == i @@ -86,9 +86,9 @@ def test_empty_weights(model_class, tp): assert p.device == torch.device('meta') for r in m.reducers: if tp: - assert r.ranks == ((0, 2) if i in (0, 2) else (1, 3)) + assert r.ranks == ((0, 2, 4, 6) if i in (0, 2, 4, 6) else (1, 3, 5, 7)) else: - assert r.ranks == (0, 1, 2, 3) + assert r.ranks == (0, 1, 2, 3, 4, 5, 6, 7) assert len(r.buckets) == 1 assert r.zero assert r.zero_ngroups == 2 diff --git a/tests/parallel_module/test_line_timer.py b/tests/parallel_module/test_line_timer.py index 90483848..28948a38 100644 --- a/tests/parallel_module/test_line_timer.py +++ b/tests/parallel_module/test_line_timer.py @@ -13,7 +13,7 @@ from .common import init_distributed from ..launch_torchrun import launch_torchrun -from ..utils import catch_stdout, clear_dir_on_rank0 +from ..utils import catch_stdout, clear_dir_on_rank0, PYTEST_RUN_ID class Net(torch.nn.Module): @@ -43,7 +43,7 @@ def _gpu_worker(): compute_config = ComputeConfig(1, 1, use_zero=False) try: CompileFlag.line_timer = True - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_line_timer') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_line_timer_{PYTEST_RUN_ID}') as tempdir: net = _to_cube_model(Net(), compute_config, tempdir, 'net', (128, 64)) x = torch.randn(128, 64).cuda() diff --git a/tests/parallel_module/test_offload_params.py b/tests/parallel_module/test_offload_params.py new file mode 100644 index 00000000..2e96a123 --- /dev/null +++ b/tests/parallel_module/test_offload_params.py @@ -0,0 +1,213 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import tempfile +from pathlib import Path +import pytest +from typing import Dict, Tuple, List, Any + +import torch +from torch import nn + +from nnscaler.parallel import ComputeConfig, parallelize, build_optimizer +from nnscaler.graph import IRGraph + +from .common import PASMegatron, CubeLinear, init_random, init_distributed, assert_equal +from ..launch_torchrun import launch_torchrun +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID + + +class SimpleMLP(nn.Module): + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + super(SimpleMLP, self).__init__() + init_random() + self.register_buffer('buffer', torch.zeros(hidden_dim,)) + self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False) + self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc1(x) + x = x + self.buffer + x = torch.relu(x) + x = self.fc2(x) + return x + + +def get_tensor_bytesize(t: torch.Tensor) -> int: + return t.numel() * t.element_size() + + +def pas_test_offload(graph: IRGraph, cfg: ComputeConfig): + ngpus = cfg.plan_ngpus + auto_multiref(graph) + + batch_dim = 0 + for dl in graph.select(ntype=IRDataOperation): + _replica(graph, dl, list(range(ngpus))) + + found_linear = False + for node in graph.nodes(): + if isinstance(node, IRFwOperation): + if 'linear' in node.signature and not found_linear: + found_linear = True + algo = node.algorithm('dim') + sub_nodes = graph.partition( + node, algo, idx=1, dim=1, num=ngpus) + else: + sub_nodes = graph.replicate(node, ngpus) + + for idx, node in enumerate(sub_nodes): + graph.assign(node, idx) + return graph + + +def _mem_worker(): + init_distributed() + bsz, dim = 32, 1024 + compute_config = ComputeConfig( + plan_ngpus=1, + runtime_ngpus=2, + ) + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'nnscaler_test_offload_mem_{PYTEST_RUN_ID}') as tempdir: + module = SimpleMLP(dim, dim, dim) + p_module = parallelize( + module, + {'x': torch.randn(bsz, dim)}, + 'dp', + compute_config, + gen_savedir=tempdir, + ) + + before_mem = torch.cuda.memory_allocated() + size_to_free = 0 + for reducer in p_module.reducers: + assert get_tensor_bytesize(reducer._contiguous_params) == get_tensor_bytesize(reducer._contiguous_grads) + size_to_free += get_tensor_bytesize(reducer._contiguous_params) + + for buffer in p_module.buffers(): + size_to_free += get_tensor_bytesize(buffer) + + for param in p_module.parameters(): + size_to_free += get_tensor_bytesize(param) + + p_module.sleep() + torch.distributed.barrier() + after_mem = torch.cuda.memory_allocated() + print(f"Memory before offload: {before_mem}, after offload: {after_mem}, freed: {before_mem - after_mem}") + print(f"Total size to free: {size_to_free}") + + assert size_to_free == before_mem - after_mem, f"Expected {size_to_free}, but got {before_mem - after_mem}" + + +def _correctness_worker(): + init_distributed() + bsz, dim, num_steps = 32, 1024, 5 + compute_config = ComputeConfig( + plan_ngpus=1, + runtime_ngpus=2, + ) + + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'nnscaler_test_offload_correctness_{PYTEST_RUN_ID}') as tempdir: + # Create test data + torch.manual_seed(42 + torch.distributed.get_rank()) + test_data = [torch.randn(bsz, dim).cuda() for _ in range(num_steps)] + + # Test 1: Normal execution without offload/load + init_random() + module1 = SimpleMLP(dim, dim, dim) + p_module1 = parallelize( + module1, + {'x': torch.randn(bsz, dim)}, + 'dp', + compute_config, + gen_savedir=tempdir, + instance_name='normal' + ) + optimizer1 = build_optimizer(p_module1, torch.optim.Adam, lr=0.01) + + results_normal = [] + for step, x in enumerate(test_data): + p_module1.train() + output = p_module1(x) + loss = output.sum() + loss.backward() + optimizer1.step() + optimizer1.zero_grad() + + # Save intermediate results for comparison + results_normal.append({ + 'loss': loss.detach().cpu(), + 'output': output.detach().cpu(), + 'params': {name: param.detach().cpu().clone() for name, param in p_module1.named_parameters()} + }) + + torch.distributed.barrier() + + # Test 2: Execution with offload/load + init_random() + module2 = SimpleMLP(dim, dim, dim) + p_module2 = parallelize( + module2, + {'x': torch.randn(bsz, dim)}, + 'dp', + compute_config, + gen_savedir=tempdir, + instance_name='offload' + ) + optimizer2 = build_optimizer(p_module2, torch.optim.Adam, lr=0.01) + + # First offload to initialize the buffer_shape + p_module2.sleep() + + results_offload = [] + for step, x in enumerate(test_data): + # Load params at the beginning of each step + p_module2.wake_up() + + p_module2.train() + output = p_module2(x) + loss = output.sum() + loss.backward() + optimizer2.step() + optimizer2.zero_grad() + + # Save intermediate results for comparison + results_offload.append({ + 'loss': loss.detach().cpu(), + 'output': output.detach().cpu(), + 'params': {name: param.detach().cpu().clone() for name, param in p_module2.named_parameters()} + }) + + # Offload params at the end of each step + p_module2.sleep() + + torch.distributed.barrier() + + # Compare results + for step in range(num_steps): + normal_result = results_normal[step] + offload_result = results_offload[step] + + # Compare loss + assert torch.equal(normal_result['loss'], offload_result['loss']), \ + f"Loss mismatch at step {step}: {normal_result['loss']} vs {offload_result['loss']}" + + # Compare output + assert torch.equal(normal_result['output'], offload_result['output']), \ + f"Output mismatch at step {step}" + + # Compare parameters + for param_name in normal_result['params']: + assert torch.equal(normal_result['params'][param_name], + offload_result['params'][param_name]), \ + f"Parameter {param_name} mismatch at step {step}" + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_offload_params_mem(): + launch_torchrun(2, _mem_worker) + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_offload_params_correctness(): + launch_torchrun(2, _correctness_worker) diff --git a/tests/parallel_module/test_reducer_hook.py b/tests/parallel_module/test_reducer_hook.py index 851dfad0..eb6807ae 100644 --- a/tests/parallel_module/test_reducer_hook.py +++ b/tests/parallel_module/test_reducer_hook.py @@ -14,7 +14,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -118,7 +118,7 @@ def post_hook(reducer, grad): def _gpu_worker(pas, plan_ngpus, runtime_ngpus=None): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_hook') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_hook_{PYTEST_RUN_ID}') as tempdir: compiled_module = _create_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus or plan_ngpus), tempdir) _train(compiled_module) diff --git a/tests/parallel_module/test_scale_grads.py b/tests/parallel_module/test_scale_grads.py index 255869cf..ce7b9b45 100644 --- a/tests/parallel_module/test_scale_grads.py +++ b/tests/parallel_module/test_scale_grads.py @@ -23,7 +23,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -130,7 +130,7 @@ def _train(model: torch.nn.Module, num_replicas, rank, scale_grads: bool): def _gpu_worker(pas, plan_ngpus, runtime_ngpus, scale_grads: bool): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test_scale_grads') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_scale_grads_{PYTEST_RUN_ID}') as tempdir: compiled_module = _create_cube_module(pas, ComputeConfig(plan_ngpus, runtime_ngpus, use_zero=True), tempdir diff --git a/tests/parallel_module/test_shared_param_pipeline.py b/tests/parallel_module/test_shared_param_pipeline.py index 86dc4380..3d6159e7 100644 --- a/tests/parallel_module/test_shared_param_pipeline.py +++ b/tests/parallel_module/test_shared_param_pipeline.py @@ -19,7 +19,7 @@ from nnscaler.ir.operator import IRFwOperation, IRDataOperation from nnscaler.graph.segment import IRSegment from nnscaler.graph.schedule.predefined import PredefinedSched -from tests.utils import clear_dir_on_rank0, init_random, raises_with_cause +from tests.utils import clear_dir_on_rank0, init_random, raises_with_cause, PYTEST_RUN_ID from tests.launch_torchrun import torchrun from tests.parallel_module.test_gencode import _gencode_contains, print_gencode @@ -264,7 +264,7 @@ def worker_pipeline(model_cls, pas, plan_ngpus, checker): torch.cuda.manual_seed(0) trace_data = torch.randn([2, 16], dtype=torch.float32, device=torch.cuda.current_device()) - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'test_detach_loss_pp_2') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'test_detach_loss_pp_2_{PYTEST_RUN_ID}') as tempdir: pm = parallelize( m, {'x': trace_data}, diff --git a/tests/parallel_module/test_submodule.py b/tests/parallel_module/test_submodule.py index 047f9f5b..628071ea 100644 --- a/tests/parallel_module/test_submodule.py +++ b/tests/parallel_module/test_submodule.py @@ -17,7 +17,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -118,7 +118,7 @@ def _train(model, update_freq, is_cube): def _gpu_worker(pas, ngpus, update_freq): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_{PYTEST_RUN_ID}') as tempdir: orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) orig_results = _train(orig_module, update_freq, False) compiled_results = _train(compiled_module, update_freq, True) diff --git a/tests/parallel_module/test_wholemodule.py b/tests/parallel_module/test_wholemodule.py index 64ec5ff0..3d385d19 100644 --- a/tests/parallel_module/test_wholemodule.py +++ b/tests/parallel_module/test_wholemodule.py @@ -17,7 +17,7 @@ from .common import CubeLinear, init_random, init_distributed from ..launch_torchrun import launch_torchrun, clone_to_cpu_recursively -from ..utils import clear_dir_on_rank0 +from ..utils import clear_dir_on_rank0, PYTEST_RUN_ID class FcRelu(nn.Module): @@ -106,7 +106,7 @@ def _train(model, is_cube): def _gpu_worker(pas, ngpus): init_distributed() - with clear_dir_on_rank0(Path(tempfile.gettempdir()) / 'cube_test') as tempdir: + with clear_dir_on_rank0(Path(tempfile.gettempdir()) / f'cube_test_{PYTEST_RUN_ID}') as tempdir: orig_module, compiled_module = _create_modules(pas, ComputeConfig(ngpus, ngpus), tempdir) orig_results = _train(orig_module, False) compiled_results = _train(compiled_module, True) diff --git a/tests/runtime/test_gnorm.py b/tests/runtime/test_gnorm.py index 80025906..1de3ee35 100644 --- a/tests/runtime/test_gnorm.py +++ b/tests/runtime/test_gnorm.py @@ -57,7 +57,7 @@ def cal_wnorm_cube(model: CubeModule): for p in model.parameters_for_optimizer(): p.grad = p.data # p.grad.copy_(p.data) - nreplicas2localparams = prepare_for_grad_clip(model, is_zero=CompileFlag.use_zero) + nreplicas2localparams = prepare_for_grad_clip(model, use_zero=CompileFlag.use_zero) wnorm, _ = clip_gnorm(nreplicas2localparams, None) # maps = {tid: [t.size() for t in ts] for tid, ts in nreplicas2localparams.items()} # print(f'cube nrepicas len: {maps}') diff --git a/tests/runtime/test_hybrid_optimizer.py b/tests/runtime/test_hybrid_optimizer.py new file mode 100644 index 00000000..72494837 --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer.py @@ -0,0 +1,283 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from pathlib import Path +import shutil + +import torch +import pytest +import torch.distributed + +from nnscaler.cli.trainer import Trainer +from nnscaler.runtime.hybrid_optimizer import ScaleDelayedOptimizerMixin +from tests.parallel_module.common import assert_close, assert_equal +from ..launch_torchrun import launch_torchrun + + +def param_clss_fn(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'layers.1.' in param_name or 'layers.10.' in param_name: + return 0, 0 + elif 'layers.2.' in param_name or 'layers.12.' in param_name: + return 0, 1 + else: + return 1, 0 + +_lr_history = [] +def on_train_step_start(trainer: 'Trainer', batches) -> None: + _lr_history.append(( + trainer.optimizer.optimizers[0].param_groups[0]['lr'], + trainer.optimizer.optimizers[0].param_groups[1]['lr'], + trainer.optimizer.optimizers[1].param_groups[0]['lr'], + )) + + +def trainer_worker(save_dir, use_zero): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('test_hybrid_optimizer_trainer_args.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + _lr_history.clear() + + # train with a resume + ckpt0_savedir = save_dir / 'ckpt0' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + '--compute_config.use_zero', str(use_zero), + ]) + trainer.run() + torch.distributed.barrier() + assert len(_lr_history) == 10 + assert all(x == (0.02, 0.03, 0.008) for x in _lr_history[:5]) + assert all(x == (0.04, 0.06, 0.04) for x in _lr_history[5:]) + + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--checkpoint.resume_from', 'last', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + '--compute_config.use_zero', str(use_zero), + ]) + trainer.run() + torch.distributed.barrier() + + assert len(_lr_history) == 20 + assert all(x == (0.02, 0.03, 0.008) for x in _lr_history[:5]) + assert all(x == (0.04, 0.06, 0.04) for x in _lr_history[5:]) + + _lr_history.clear() + # train in one time + ckpt1_savedir = save_dir / 'ckpt1' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--compute_config.use_zero', str(use_zero), + ]) + trainer.run() + torch.distributed.barrier() + assert len(_lr_history) == 20 + assert all(x == (0.02, 0.03, 0.008) for x in _lr_history[:5]) + assert all(x == (0.04, 0.06, 0.04) for x in _lr_history[5:]) + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + + # train with different config + trainer_config = [ + '-f', config_path, + '--compute_config.plan_ngpus', '2', + '--pas_policy', 'tp', + '--max_train_steps', '30', + '--checkpoint.resume_from.checkpoint', 'last', + '--checkpoint.resume_from.with_merged', str(True), + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + '--compute_config.use_zero', str(1 - use_zero), + ] + trainer = Trainer(trainer_config) + trainer.run() + torch.distributed.barrier() + if torch.distributed.get_rank() == 0: + r = trainer._merge_checkpoint([ckpt0_savedir / 'last' / f'{i}.ckpt' for i in range(2)]) + # should success + assert r + + torch.distributed.barrier() + + from subprocess import check_call as _call + from functools import partial + call = partial(_call, shell=True) + + if torch.distributed.get_rank() == 0: + call(f"python -m nnscaler.cli.checkpoint distribute {ckpt1_savedir}/last {ckpt1_savedir}/sharded {' '.join(trainer_config)} --compute_config.runtime_ngpus {torch.distributed.get_world_size()}") + + torch.distributed.barrier() + + trainer = Trainer([ + '-f', config_path, + '--compute_config.plan_ngpus', '2', + '--pas_policy', 'tp', + '--max_train_steps', '30', + '--checkpoint.resume_from.checkpoint', f'{ckpt1_savedir}/sharded', + '--checkpoint.resume_from.with_merged', str(False), + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--compute_config.use_zero', str(1 - use_zero), + ]) + trainer.run() + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer'], y['optimizer']) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +@pytest.mark.parametrize('use_zero', [0, 1]) +def test_hybrid_optimizer(tmp_path, use_zero): + launch_torchrun(2, trainer_worker, tmp_path, use_zero) + + +def param_clss_fn_mp(param_name: str) -> tuple[int, int]: + """ + Classify a parameter name into an optimizer index and a parameter group index. + """ + if 'layers.1.' in param_name or 'layers.10.' in param_name: + return 0, 0 + else: + return 1, 0 + + +def trainer_worker_mp(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('test_hybrid_optimizer_trainer_args_mixed_precision.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + # train with a hybrid optimizer + ckpt0_savedir = save_dir / 'ckpt0' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # resume training with hybrid optimizer + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--checkpoint.resume_from', 'last', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # train with normal optimizer + ckpt1_savedir = save_dir / 'ckpt1' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--optimizer.args.config!', + '--optimizer.type', 'nnscaler.MixedPrecisionAdamW', + '--optimizer.args.lr', '0.02', + ]) + trainer.run() + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer']['state'], y['optimizer']['state']) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_hybrid_optimizer_mp(tmp_path): + launch_torchrun(2, trainer_worker_mp, tmp_path) + + + +class ScaleDelayedAdamW(ScaleDelayedOptimizerMixin, torch.optim.AdamW): + pass + + +def trainer_worker_mp2(save_dir): + save_dir = Path(save_dir) + config_path = str(Path(__file__).with_name('test_hybrid_optimizer_trainer_args_mixed_precision2.yaml').resolve()) + gen_savedir = save_dir / 'gen' + + # train with a hybrid optimizer + ckpt0_savedir = save_dir / 'ckpt0' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '10', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # resume training with hybrid optimizer + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--checkpoint.resume_from', 'last', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt0_savedir), + ]) + trainer.run() + torch.distributed.barrier() + + # train with normal optimizer + ckpt1_savedir = save_dir / 'ckpt1' + trainer = Trainer([ + '-f', config_path, + '--max_train_steps', '20', + '--gen_savedir', str(gen_savedir), + '--checkpoint.save_dir', str(ckpt1_savedir), + '--optimizer.args.config.optimizers.1.type', 'tests.runtime.test_hybrid_optimizer.ScaleDelayedAdamW', + ]) + trainer.run() + torch.distributed.barrier() + + if torch.distributed.get_rank() == 0: + for i in range(2): + x = torch.load(ckpt0_savedir / 'last' / f'{i}.ckpt', weights_only=False) + y = torch.load(ckpt1_savedir / 'last' / f'{i}.ckpt', weights_only=False) + assert_equal(x['model'], y['model']) + assert_equal(x['optimizer']['state'], y['optimizer']['state']) + + torch.distributed.barrier() + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_hybrid_optimizer_mp2(tmp_path): + """ + Demonstrate that ScaleDelayedOptimizerMixin that is applied to existing optimizers + are equivalent to defining new optimizers that inherit from the mixin. + """ + launch_torchrun(2, trainer_worker_mp2, tmp_path) diff --git a/tests/runtime/test_hybrid_optimizer_trainer_args.yaml b/tests/runtime/test_hybrid_optimizer_trainer_args.yaml new file mode 100644 index 00000000..b84c4870 --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer_trainer_args.yaml @@ -0,0 +1,76 @@ +compute_config: + plan_ngpus: 1 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_train_steps: 10 +enable_progress_bar: false +precision: bf16 +instance_name: p$(compute_config.plan_ngpus) + +model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + +optimizer: + type: nnscaler.HybridOptimizer + param_clss_fn: tests.runtime.test_hybrid_optimizer.param_clss_fn + args: + config: + optimizers: + - type: torch.optim.Adam + options: + lr: 0.02 + param_groups: + - options: + lr: 0.04 + - options: + lr: 0.06 + - type: torch.optim.AdamW + options: + lr: 0.04 + +lr_scheduler: + type: nnscaler.HybridLRScheduler + args: + config: + schedulers: + - type: torch.optim.lr_scheduler.ConstantLR + options: + factor: 0.5 + total_iters: 5 + - type: torch.optim.lr_scheduler.ConstantLR + options: + factor: 0.2 + total_iters: 5 + interval: step + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +dataloader: + train_args: + drop_last: true + val_args: + drop_last: true + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped + +hook: + on_train_step_start: tests.runtime.test_hybrid_optimizer.on_train_step_start diff --git a/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision.yaml b/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision.yaml new file mode 100644 index 00000000..971ada07 --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision.yaml @@ -0,0 +1,54 @@ +compute_config: + plan_ngpus: 1 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_train_steps: 10 +enable_progress_bar: false +precision: bf16 +instance_name: p$(compute_config.plan_ngpus) + +model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + +optimizer: + type: nnscaler.HybridOptimizer + param_clss_fn: tests.runtime.test_hybrid_optimizer.param_clss_fn_mp + args: + config: + optimizers: + - type: nnscaler.MixedPrecisionAdamW + options: + lr: 0.02 + - type: nnscaler.MixedPrecisionAdamW + options: + lr: 0.02 + clip_gnorm: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +dataloader: + train_args: + drop_last: true + val_args: + drop_last: true + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision2.yaml b/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision2.yaml new file mode 100644 index 00000000..8601c18e --- /dev/null +++ b/tests/runtime/test_hybrid_optimizer_trainer_args_mixed_precision2.yaml @@ -0,0 +1,54 @@ +compute_config: + plan_ngpus: 1 + constant_folding: true + use_zero: true + use_end2end: true + +run_mode: run +pas_policy: autodist +micro_batch_size: 2 +global_batch_size: 8 +max_train_steps: 10 +enable_progress_bar: false +precision: bf16 +instance_name: p$(compute_config.plan_ngpus) + +model: + type: tests.cli.common.MLP + args: + dim: 16 + nlayers: 16 + +optimizer: + type: nnscaler.HybridOptimizer + param_clss_fn: tests.runtime.test_hybrid_optimizer.param_clss_fn_mp + args: + config: + optimizers: + - type: nnscaler.MixedPrecisionAdamW + options: + lr: 0.02 + - type: torch.optim.AdamW + options: + lr: 0.02 + clip_gnorm: 0.01 + +dataset: + type: tests.cli.common.SimpleDataset + train_args: + dim: 16 + size: 100 + val_args: + dim: 16 + size: 10 + +dataloader: + train_args: + drop_last: true + val_args: + drop_last: true + +checkpoint: + keep_last_n_checkpoints: 30 + every_n_train_steps: 1 + save_type: deduped diff --git a/tests/runtime/test_serialization.py b/tests/runtime/test_serialization.py new file mode 100644 index 00000000..aaee1eae --- /dev/null +++ b/tests/runtime/test_serialization.py @@ -0,0 +1,75 @@ +import torch +import pytest + +from nnscaler.runtime.serialization import load, save, convert +from nnscaler.cli.serialization import convert_format + +from tests.parallel_module.common import assert_equal + + +def test_normal(tmp_path): + a = torch.randn((2, 2), device='cpu') + b = torch.randn((2, 3), device='cpu') + c = torch.randn((4, 4), device='cpu') + tensors = { + "embedding": a, + "attention": b, + "fc": a, # shared tensor + "bias": {'inner': b, 'outer': {'deep': c}} + } + save(tensors, tmp_path / "model.safetensors") + loaded = load(tmp_path / "model.safetensors", lazy=False) + assert_equal(tensors, loaded) + convert(tmp_path / "model.safetensors", tmp_path / "model.pt") + convert_format( + src=str(tmp_path / "model.safetensors"), + dst=str(tmp_path / "model2.ckpt"), + ) + loaded_pt = torch.load(tmp_path / "model.pt") + assert_equal(tensors, loaded_pt) + loaded_pt2 = torch.load(tmp_path / "model2.ckpt") + assert_equal(tensors, loaded_pt2) + + +def test_shared_params(tmp_path): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 4) + self.fc2 = torch.nn.Linear(4, 4) + # share the same weight + self.fc2.weight = self.fc1.weight + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = Model() + save(model.state_dict(), tmp_path / "model.safetensors") + loaded = load(tmp_path / "model.safetensors", lazy=False) + assert_equal(model.state_dict(), loaded) + convert(tmp_path / "model.safetensors", tmp_path / "model.pt") + loaded_pt = torch.load(tmp_path / "model.pt") + assert_equal(model.state_dict(), loaded_pt) + + +def test_bad_shared_params(tmp_path): + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 4) + self.fc2 = torch.nn.Linear(4, 4) + # share the same weight + # This case is not common, + # so we don't support it currently. + self.fc2.weight.data = self.fc1.weight.reshape(-1) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + model = Model() + with pytest.raises(RuntimeError): + save(model.state_dict(), tmp_path / "model.safetensors") diff --git a/tests/runtime/test_utils.py b/tests/runtime/test_utils.py new file mode 100644 index 00000000..4e736649 --- /dev/null +++ b/tests/runtime/test_utils.py @@ -0,0 +1,62 @@ +from nnscaler.runtime.utils import split_array_min_max + + +def test_split_array_min_max(): + nums = [1, 2, 3, 4, 5, 6, 7, 8, 9] + g = 3 + groups, group_idx = split_array_min_max(nums, g, keep_order=True) + assert groups == [[1, 2, 3, 4, 5], [6, 7], [8, 9]] + assert group_idx == [[0, 1, 2, 3, 4], [5, 6], [7, 8]] + + groups, group_idx = split_array_min_max(nums, g, keep_order=False) + assert groups == [[9, 4, 3], [8, 5, 2], [7, 6, 1]] + assert group_idx == [[8, 3, 2], [7, 4, 1], [6, 5, 0]] + + nums = [10, 10, 10, 10, 10, 10] + g = 3 + groups, group_idx = split_array_min_max(nums, g, keep_order=True) + assert groups == [[10, 10], [10, 10], [10, 10]] + assert group_idx == [[0, 1], [2, 3], [4, 5]] + + groups, group_idx = split_array_min_max(nums, g, keep_order=False) + assert groups == [[10, 10], [10, 10], [10, 10]] + assert group_idx == [[5, 2], [4, 1], [3, 0]] + + nums = [ + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 1310720, 1310720, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 6553600, 6553600, 2621440, 2621440, 2621440, + 1310720, 1310720 + ] + g = 8 + best_sum = sum(nums) // g + + groups, group_idx = split_array_min_max(nums, g, keep_order=True) + max_sum = max(sum(group) for group in groups) + assert len(groups) == 8 + assert list(j for k in group_idx for j in k) == list(range(len(nums))) + + groups, group_idx = split_array_min_max(nums, g, keep_order=False) + assert len(groups) == 8 + max_sum2 = max(sum(group) for group in groups) + assert list(j for k in group_idx for j in k) != list(range(len(nums))) + + assert best_sum< max_sum2 < max_sum + print(f'best_sum: {best_sum}, keep_order: {max_sum}, not keep_order: {max_sum2}') diff --git a/tests/test_policies.py b/tests/test_policies.py index 4f1fa7b0..003adc0c 100644 --- a/tests/test_policies.py +++ b/tests/test_policies.py @@ -8,9 +8,13 @@ import torch import torch.nn as nn -from nnscaler.parallel import ComputeConfig, parallelize +from nnscaler.parallel import ComputeConfig, _load_parallel_module_class, parallelize +from nnscaler.policies import get_called_self_module_name, get_pas_ops +from tests.launch_torchrun import launch_torchrun +from tests.parallel_module.common import FFN, init_distributed +from tests.parallel_module.test_gencode import _gencode_contains, print_gencode -from .utils import init_random +from .utils import init_random, replace_all_device_with MBS = 2 DIM = 16 @@ -58,3 +62,1048 @@ def test_autodist(): load_module=False ) assert m_new is None + + +def test_call_name(): + assert get_called_self_module_name('self.up_proj(x)') == 'up_proj' + assert get_called_self_module_name('self.act_fn(self.gate_proj(x))') == 'act_fn' + assert get_called_self_module_name('self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))') == 'down_proj' + assert get_called_self_module_name('torch.tanh(x)') == '' + assert get_called_self_module_name('x * y') == '' + assert get_called_self_module_name('self.up_proj(x).transpose()') == '' + + +class FnPolicyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ffn = FFN(4, 8) + + def forward(self, x): + x = x * 2 + x = self.ffn(x) + x = x + 3 + return x + + +def megatron_ffn_policy(graph, cfg): + from nnscaler.ir import IRSubTensor + from nnscaler.policies import OpPlan, OpPartition + + for node in get_pas_ops(graph): + if FFN not in node.module_class_chain: # work on FFN module + continue + + if node.fn in [torch.tanh, torch.mul]: + yield OpPlan(node, partition=OpPartition(input=0, dim=1)) + continue + + assert node.fn == torch.nn.functional.linear + + input1: IRSubTensor = node.input(1) + if not input1.is_param(): # linear weight param + continue + + # we will partition gate_proj/up_proj with column parallelism (tp=ngpus) + # and partition down_proj with row parallelism (tp=ngpus) + + if input1.name.endswith('gate_proj.weight') or input1.name.endswith('up_proj.weight'): + # gate_proj/up_proj + # column parallelism + yield OpPlan(node, partition=OpPartition(input=1, dim=0)) + elif input1.name.endswith('down_proj.weight'): + # down_proj + yield OpPlan(node, partition=OpPartition(input=1, dim=1)) + + +def megatron_ffn_policy_auto(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition + + linear_rank = 0 + for node in get_pas_ops(graph): + if FFN not in node.module_class_chain: # work on FFN module + continue + + if node.fn == torch.nn.functional.linear: + if linear_rank in [0, 1]: + # gate_proj/up_proj + yield OpPlan(node, partition=OpPartition(input=1, dim=0)) + else: + assert linear_rank == 2 + # down_proj + yield OpPlan(node, partition=OpPartition(input=1, dim=1)) + linear_rank += 1 + else: + # other ops + yield OpPlan(node, partition='auto') + + +@replace_all_device_with('cpu') +@pytest.mark.parametrize('policy', [megatron_ffn_policy, megatron_ffn_policy_auto]) +def test_codegen_fn(tmp_path, policy): + parallelize( + FnPolicyModule(), + {'x': torch.randn(2, 4)}, + policy, + ComputeConfig(2, 4), + gen_savedir=tmp_path, + load_module=False + ) + for rank in range(4): + module_class = _load_parallel_module_class(FnPolicyModule, gen_savedir=tmp_path, rank=rank) + + fullmap = {m.orig_name: m for m in module_class.attr_meta_maps[rank].values()} + assert fullmap['ffn.gate_proj.weight'].shape == (8, 4) and fullmap['ffn.gate_proj.weight'].sub_shape == (4, 4) + assert fullmap['ffn.up_proj.weight'].shape == (8, 4) and fullmap['ffn.up_proj.weight'].sub_shape == (4, 4) + assert fullmap['ffn.down_proj.weight'].shape == (4, 8) and fullmap['ffn.down_proj.weight'].sub_shape == (4, 4) + + # will generate two communication ops + # one for ffn input + assert _gencode_contains(tmp_path, FnPolicyModule, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + # one for ffn output + assert _gencode_contains(tmp_path, FnPolicyModule, rank, f'nnscaler.runtime.adapter.nn.allreduce_identity') + + assert len(_gencode_contains(tmp_path, FnPolicyModule, rank, f'nnscaler.runtime.adapter.nn.')) == 2 + + # Generated code of rank 0 should looks like: + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + + # self.register_parameter('ffn_gate_proj_weight_49', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_gate_proj_weight_49', 5, True, 'ffn.gate_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_up_proj_weight_63', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_up_proj_weight_63', 11, True, 'ffn.up_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_down_proj_weight_77', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_down_proj_weight_77', 17, True, 'ffn.down_proj.weight', (4, 8), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self._post_init(init_params, build_buckets) + + # def segment118(self, x_25): + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1653, in forward, x = x * 2 + # mul_27 = torch.mul(x_25, 2) + # del x_25 + # mul_27 = nnscaler.runtime.adapter.nn.identity_allreduce(mul_27, ranks=[0, 1]) + # # created at IRAdapterGener:local_consumer_multiref + # mul_85, mul_89 = nnscaler.runtime.function.multiref(mul_27, times=2) + # del mul_27 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_51 = torch.nn.functional.linear(mul_85, self.ffn_gate_proj_weight_49, bias=None) + # del mul_85 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_59 = torch.tanh(linear_51) + # del linear_51 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_65 = torch.nn.functional.linear(mul_89, self.ffn_up_proj_weight_63, bias=None) + # del mul_89 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_73 = torch.mul(tanh_59, linear_1_65) + # del tanh_59, linear_1_65 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/common.py", line 119, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_79 = torch.nn.functional.linear(mul_1_73, self.ffn_down_proj_weight_77, bias=None) + # del mul_1_73 + # linear_2_35 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_2_79, ranks=[0, 1]) + # del linear_2_79 + # # File "/home/weijiangxu/MagicCube/tests/parallel_module/test_gencode.py", line 1655, in forward, x = x + 3 + # add_26 = torch.add(linear_2_35, 3, alpha=1) + # del linear_2_35 + # return add_26 + + +class FFNDropout(torch.nn.Module): + def __init__(self, hidden_size, intermediate_size): + super().__init__() + self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False) + self.act_fn = torch.nn.Tanh() + self.dropout = torch.nn.Dropout(p=0.1) + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return self.dropout(down_proj) + + +class FnPolicyModuleList(torch.nn.Module): + def __init__(self): + super().__init__() + self.ffn = torch.nn.ModuleList([ + FFNDropout(4, 8), + FFNDropout(4, 8), + ]) + + def forward(self, x): + x = x * 2 + for ffn in self.ffn: + x = ffn(x) + x = x + 3 + return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + + +def megatron_ffn_policy_list(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition, get_layer_index, get_called_self_module_name + + for node in get_pas_ops(graph): + if FFNDropout not in node.module_class_chain: # work on FFN module + continue + + ffn_idx = get_layer_index(node.fqn) + module_called = get_called_self_module_name(node.call_expr) + + if node.fn == torch.nn.functional.linear: + if module_called in ['gate_proj', 'up_proj']: + # gate_proj/up_proj + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition=OpPartition(input=1, dim=0)) + else: + # down_proj + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition=OpPartition(input=1, dim=1)) + else: + # other ops + yield OpPlan(node, recompute_id=ffn_idx, stage_id=ffn_idx, partition='auto') + + +@replace_all_device_with('cpu') +def test_codegen_fn_pipeline(tmp_path): + parallelize( + FnPolicyModuleList(), + {'x': torch.randn(4, 4)}, + # 'pp', + megatron_ffn_policy_list, + ComputeConfig(4, 4, use_end2end=True, + pas_config={ + 'pipeline_nmicros': 2, + 'pipeline_size': 2, + } + ), + gen_savedir=tmp_path, + load_module=False + ) + + for rank in range(4): + module_class = _load_parallel_module_class(FnPolicyModuleList, gen_savedir=tmp_path, rank=rank) + + fullmap = {m.orig_name: m for m in module_class.attr_meta_maps[rank].values()} + tp_idx = rank // 2 + assert fullmap[f'ffn.{tp_idx}.gate_proj.weight'].shape == (8, 4) and fullmap[f'ffn.{tp_idx}.gate_proj.weight'].sub_shape == (4, 4) + assert fullmap[f'ffn.{tp_idx}.up_proj.weight'].shape == (8, 4) and fullmap[f'ffn.{tp_idx}.up_proj.weight'].sub_shape == (4, 4) + assert fullmap[f'ffn.{tp_idx}.down_proj.weight'].shape == (4, 8) and fullmap[f'ffn.{tp_idx}.down_proj.weight'].sub_shape == (4, 4) + + # will generate two communication ops + # one for ffn input + if tp_idx == 0: + assert not _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + else: + assert _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.identity_allreduce') + # one for ffn output + assert _gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.allreduce_identity') + + if tp_idx == 0: + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.')) == 1 + else: + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, f'nnscaler.runtime.adapter.nn.')) == 2 + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, r'ckpt.checkpoint\(recompute')) == 1 + assert len(_gencode_contains(tmp_path, FnPolicyModuleList, rank, r'def recompute\(')) == 1 + + + # Generated code of rank 0 looks like: + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = True + # nmicros_per_scheduler_step = 2 + # rank = 0 + # world_size = 4 + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('ffn_0_gate_proj_weight_168', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_gate_proj_weight_168', 5, True, 'ffn.0.gate_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_up_proj_weight_182', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_up_proj_weight_182', 11, True, 'ffn.0.up_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_down_proj_weight_196', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_down_proj_weight_196', 17, True, 'ffn.0.down_proj.weight', (4, 8), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self._post_init(init_params, build_buckets) + + # def segment79(self, x_49): + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 243, in forward, x = x * 2 + # mul_51 = torch.mul(x_49, 2) + # del x_49 + # mul_51 = nnscaler.runtime.adapter.nn.identity_allreduce(mul_51, ranks=[0, 1]) + + # def recompute(mul_51): + # # created at IRAdapterGener:local_consumer_multiref + # mul_246, mul_250 = nnscaler.runtime.function.multiref(mul_51, times=2) + # del mul_51 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_170 = torch.nn.functional.linear(mul_246, self.ffn_0_gate_proj_weight_168, bias=None) + # del mul_246 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_178 = torch.tanh(linear_170) + # del linear_170 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_184 = torch.nn.functional.linear(mul_250, self.ffn_0_up_proj_weight_182, bias=None) + # del mul_250 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_192 = torch.mul(tanh_178, linear_1_184) + # del tanh_178, linear_1_184 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_198 = torch.nn.functional.linear(mul_1_192, self.ffn_0_down_proj_weight_196, bias=None) + # del mul_1_192 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # ffn_0_dropout_training_21 = self.training + # linear_2_59 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_2_198, ranks=[0, 1]) + # del linear_2_198 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # dropout_60 = torch. + # nn.functional.dropout(linear_2_59, p=0.1, training=ffn_0_dropout_training_21, inplace=False) + # del linear_2_59 + # return dropout_60 + + # dropout_60 = ckpt.checkpoint(recompute, mul_51, use_reentrant=False) + # return dropout_60 + + # def adapter196(self, dropout_60): + # dropout_236 = nnscaler.runtime.adapter.chunk(dropout_60, dim=1, ranks=[0, 1]) + # _ = nnscaler.runtime.adapter.move(dropout_236, shape=(4, 2), dtype=torch.float32, src=0, dst=2) + # return + + # def adapter207(self): + # gdropout_242 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=2, dst=0) + # gdropout_85 = nnscaler.runtime.adapter.all_gather(gdropout_242, dim=1, ranks=[0, 1]) + # return gdropout_85 + + # def adapter160(self): + # sum_1_50 = nnscaler.runtime.adapter.move((), shape=(), dtype=torch.float32, src=2, dst=0) + # return sum_1_50 + + # def _forward_impl(self, *args, **kwargs): + # raise NotImplementedError("Code of forward is not generated. You should use module.train_step/module.infer_step instead.") + + # ########## Generated Schedule Code ########### + # import torch + # import nnscaler + + # def _train_step(model, dataloader_71): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # x_49 = next(*(dataloader_71, )) + # dropout_60 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_49, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_60, ), requires_grad=False) + # x_278 = next(*(dataloader_71, )) + # dropout_286 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_278, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_286, ), requires_grad=False) + # gdropout_85 = nnscaler.runtime.executor.aexecute(model.adapter207, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gx_73 = nnscaler.runtime.executor.backward('segment79', (x_49, ), (dropout_60, ), (gdropout_85, )) + # del x_49, dropout_60, gdropout_85, gx_73 + # gdropout_287 = nnscaler.runtime.executor.aexecute(model.adapter207, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gx_279 = nnscaler.runtime.executor.backward('segment79', (x_278, ), (dropout_286, ), (gdropout_287, )) + # del x_278, dropout_286, gdropout_287, gx_279 + # sum_1_50 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=True) + # sum_1_306 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=True) + + # def _infer_step(model, dataloader_71): + # _ = None + # x_49 = next(*(dataloader_71, )) + # dropout_60 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_49, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_60, ), requires_grad=False) + # x_278 = next(*(dataloader_71, )) + # dropout_286 = nnscaler.runtime.executor.fexecute('segment79', model.segment79, *(x_278, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter196, *(dropout_286, ), requires_grad=False) + # sum_1_50 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=False) + # sum_1_306 = nnscaler.runtime.executor.aexecute(model.adapter160, *(), requires_grad=False) + # return sum_1_50, sum_1_306 + assert True + + +class FnPolicyModuleSharedWeight(torch.nn.Module): + def __init__(self): + super().__init__() + self.input_projection = torch.nn.Linear(4, 4, bias=False) + self.ffn = torch.nn.ModuleList([ + FFNDropout(4, 8), + FFNDropout(4, 8), + ]) + self.output_projection = torch.nn.Linear(4, 4, bias=False) + self.output_projection.weight = self.input_projection.weight # share weight + + def forward(self, x): + x = self.input_projection(x) + for ffn in self.ffn: + x = ffn(x) + x = self.output_projection(x) + return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + + + +@replace_all_device_with('cpu') +def test_codegen_fn_pipeline_shared_weight(tmp_path): + parallelize( + FnPolicyModuleSharedWeight(), + {'x': torch.randn(4, 4)}, + # 'pp', + megatron_ffn_policy_list, + ComputeConfig(4, 4, use_end2end=True, + pas_config={ + 'pipeline_nmicros': 2, + 'pipeline_size': 2, + } + ), + gen_savedir=tmp_path, + load_module=False + ) + for rank in range(2): + # the input projection is multiref'ed + assert _gencode_contains(tmp_path, FnPolicyModuleSharedWeight, rank, r'nnscaler.runtime.function.multiref\(self.input_projection') + + for rank in range(2, 4): + # receive shared weight projection via identity + assert _gencode_contains(tmp_path, FnPolicyModuleSharedWeight, rank, r'nnscaler.runtime.function.identity\(input_projection') + + # Generated code of rank 0 looks like: + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = True + # nmicros_per_scheduler_step = 2 + # rank = 1 + # world_size = 4 + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('input_projection_weight_55', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('input_projection_weight_55', 3, True, 'input_projection.weight', (4, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_gate_proj_weight_189', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_gate_proj_weight_189', 7, True, 'ffn.0.gate_proj.weight', (8, 4), (slice(4, 8, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_up_proj_weight_203', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_up_proj_weight_203', 13, True, 'ffn.0.up_proj.weight', (8, 4), (slice(4, 8, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_0_down_proj_weight_217', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_0_down_proj_weight_217', 19, True, 'ffn.0.down_proj.weight', (4, 8), (slice(0, 4, None), slice(4, 8, None)), 1) + # self._post_init(init_params, build_buckets) + + # def segment83(self, x_53): + # # shared param + # input_projection_weight_173, input_projection_weight_174 = nnscaler.runtime.function.multiref(self.input_projection_weight_55, times=2) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 441, in forward, x = self.input_projection(x) + # linear_56 = torch.nn.functional.linear(x_53, input_projection_weight_173, bias=None) + # del x_53, input_projection_weight_173 + # linear_56 = nnscaler.runtime.adapter.nn.identity_allreduce(linear_56, ranks=[0, 1]) + + # def recompute(linear_56): + # # created at IRAdapterGener:local_consumer_multiref + # linear_278, linear_282 = nnscaler.runtime.function.multiref(linear_56, times=2) + # del linear_56 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_191 = torch.nn.functional.linear(linear_278, self.ffn_0_gate_proj_weight_189, bias=None) + # del linear_278 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_199 = torch.tanh(linear_1_191) + # del linear_1_191 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_205 = torch.nn.functional.linear(linear_282, self.ffn_0_up_proj_weight_203, bias=None) + # del linear_282 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_213 = torch.mul(tanh_199, linear_2_205) + # del tanh_199, linear_2_205 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_3_219 = torch.nn.functional.linear(mul_213, self.ffn_0_down_proj_weight_217, bias=None) + # del mul_213 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # ffn_0_dropout_training_23 = self.training + # linear_3_64 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_3_219, ranks=[0, 1]) + # del linear_3_219 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # dropout_65 = torch.nn.functional.dropout(linear_3_64, p=0.1, training=ffn_0_dropout_training_23, inplace=False) + # del linear_3_64 + # return dropout_65 + + # dropout_65 = ckpt.checkpoint(recompute, linear_56, use_reentrant=False) + # return dropout_65, input_projection_weight_174 + + # def adapter190(self, input_projection_weight_174): + # input_projection_weight_257 = nnscaler.runtime.adapter.chunk(input_projection_weight_174, dim=1, ranks=[0, 1]) + # _ = nnscaler.runtime.adapter.move(input_projection_weight_257, shape=(4, 2), dtype=torch.float32, src=1, dst=3) + # return + + # def adapter234(self, dropout_65): + # dropout_265 = nnscaler.runtime.adapter.chunk(dropout_65, dim=1, ranks=[0, 1]) + # _ = nnscaler.runtime.adapter.move(dropout_265, shape=(4, 2), dtype=torch.float32, src=1, dst=3) + # return + + # def adapter245(self): + # gdropout_267 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=3, dst=1) + # gdropout_92 = nnscaler.runtime.adapter.all_gather(gdropout_267, dim=1, ranks=[0, 1]) + # return gdropout_92 + + # def adapter201(self): + # ginput_projection_weight_263 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=3, dst=1) + # ginput_projection_weight_177 = nnscaler.runtime.adapter.all_gather(ginput_projection_weight_263, dim=1, ranks=[0, 1]) + # return ginput_projection_weight_177 + + # def adapter214(self): + # sum_1_54 = nnscaler.runtime.adapter.move((), shape=(), dtype=torch.float32, src=3, dst=1) + # return sum_1_54 + + # def _forward_impl(self, *args, **kwargs): + # raise NotImplementedError("Code of forward is not generated. You should use module.train_step/module.infer_step instead.") + + # ########## Generated Schedule Code ########### + # import torch + # import nnscaler + + # def _train_step(model, dataloader_76): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # x_53 = next(*(dataloader_76, )) + # dropout_65, input_projection_weight_174 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_53, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_174, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_65, ), requires_grad=False) + # x_302 = next(*(dataloader_76, )) + # dropout_310, input_projection_weight_314 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_302, ), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_314, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_310, ), requires_grad=False) + # gdropout_92 = nnscaler.runtime.executor.aexecute(model.adapter245, *(), requires_grad=False) + # ginput_projection_weight_177 = nnscaler.runtime.executor.aexecute(model.adapter201, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gx_78 = nnscaler.runtime.executor.backward('segment83', (x_53, ), (dropout_65, input_projection_weight_174, ), (gdropout_92, ginput_projection_weight_177, )) + # del x_53, dropout_65, input_projection_weight_174, gdropout_92, ginput_projection_weight_177, gx_78 + # gdropout_311 = nnscaler.runtime.executor.aexecute(model.adapter245, *(), requires_grad=False) + # ginput_projection_weight_315 = nnscaler.runtime.executor.aexecute(model.adapter201, *(), requires_grad=False) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gx_303 = nnscaler.runtime.executor.backward('segment83', (x_302, ), (dropout_310, input_projection_weight_314, ), (gdropout_311, ginput_projection_weight_315, )) + # del x_302, dropout_310, input_projection_weight_314, gdropout_311, ginput_projection_weight_315, gx_303 + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=True) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=True) + # return sum_1_54, sum_1_349 + + # def _infer_step(model, dataloader_76): + # _ = None + # x_53 = next(*(dataloader_76, )) + # dropout_65, input_projection_weight_174 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_53, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_174, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_65, ), requires_grad=False) + # x_302 = next(*(dataloader_76, )) + # dropout_310, input_projection_weight_314 = nnscaler.runtime.executor.fexecute('segment83', model.segment83, *(x_302, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter190, *(input_projection_weight_314, ), requires_grad=False) + # _ = nnscaler.runtime.executor.aexecute(model.adapter234, *(dropout_310, ), requires_grad=False) + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=False) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(), requires_grad=False) + # return sum_1_54, sum_1_349 + + # Generated code of rank 2 looks like: + + # class GenModel(nnscaler.runtime.module.ParallelModule): + # use_scheduler = True + # nmicros_per_scheduler_step = 2 + # rank = 2 + # world_size = 4 + + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('ffn_1_gate_proj_weight_222', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_1_gate_proj_weight_222', 26, True, 'ffn.1.gate_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_1_up_proj_weight_236', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_1_up_proj_weight_236', 32, True, 'ffn.1.up_proj.weight', (8, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.register_parameter('ffn_1_down_proj_weight_250', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('ffn_1_down_proj_weight_250', 38, True, 'ffn.1.down_proj.weight', (4, 8), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self._post_init(init_params, build_buckets) + + # def adapter190(self): + # input_projection_weight_256 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=0, dst=2) + # input_projection_weight_174 = nnscaler.runtime.adapter.all_gather(input_projection_weight_256, dim=1, ranks=[2, 3]) + # return input_projection_weight_174 + + # def adapter234(self): + # dropout_264 = nnscaler.runtime.adapter.move((), shape=(4, 2), dtype=torch.float32, src=0, dst=2) + # dropout_65 = nnscaler.runtime.adapter.all_gather(dropout_264, dim=1, ranks=[2, 3]) + # return dropout_65 + + # def segment93(self, dropout_65, input_projection_weight_174): + # input_projection_weight_184 = nnscaler.runtime.function.identity(input_projection_weight_174) + # del input_projection_weight_174 + # dropout_180 = nnscaler.runtime.function.identity(dropout_65) + # del dropout_65 + # dropout_180 = nnscaler.runtime.adapter.nn.identity_allreduce(dropout_180, ranks=[2, 3]) + + # def recompute(dropout_180): + # # created at IRAdapterGener:local_consumer_multiref + # dropout_286, dropout_290 = nnscaler.runtime.function.multiref(dropout_180, times=2) + # del dropout_180 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_4_224 = torch.nn.functional.linear(dropout_286, self.ffn_1_gate_proj_weight_222, bias=None) + # del dropout_286 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_1_232 = torch.tanh(linear_4_224) + # del linear_4_224 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_5_238 = torch.nn.functional.linear(dropout_290, self.ffn_1_up_proj_weight_236, bias=None) + # del dropout_290 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_246 = torch.mul(tanh_1_232, linear_5_238) + # del tanh_1_232, linear_5_238 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 230, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_6_252 = torch.nn.functional.linear(mul_1_246, self.ffn_1_down_proj_weight_250, bias=None) + # del mul_1_246 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # ffn_1_dropout_training_42 = self.training + # linear_6_73 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_6_252, ranks=[2, 3]) + # del linear_6_252 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 231, in forward, return self.dropout(down_proj) + # dropout_1_74 = torch.nn.functional.dropout(linear_6_73, p=0.1, training=ffn_1_dropout_training_42, inplace=False) + # del linear_6_73 + # return dropout_1_74 + + # dropout_1_74 = ckpt.checkpoint(recompute, dropout_180, use_reentrant=False) + # del dropout_180 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 444, in forward, x = self.output_projection(x) + # linear_7_75 = torch.nn.functional.linear(dropout_1_74, input_projection_weight_184, bias=None) + # del input_projection_weight_184, dropout_1_74 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 445, in forward, return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + # sum_1_54 = torch.sum(linear_7_75) + # del linear_7_75 + # return sum_1_54 + + # def adapter245(self, gdropout_92): + # gdropout_266 = nnscaler.runtime.adapter.chunk(gdropout_92, dim=1, ranks=[2, 3]) + # _ = nnscaler.runtime.adapter.move(gdropout_266, shape=(4, 2), dtype=torch.float32, src=2, dst=0) + # return + + # def adapter201(self, ginput_projection_weight_177): + # ginput_projection_weight_262 = nnscaler.runtime.adapter.chunk(ginput_projection_weight_177, dim=1, ranks=[2, 3]) + # _ = nnscaler.runtime.adapter.move(ginput_projection_weight_262, shape=(4, 2), dtype=torch.float32, src=2, dst=0) + # return + + # def adapter214(self, sum_1_54): + # _ = nnscaler.runtime.adapter.move(sum_1_54, shape=(), dtype=torch.float32, src=2, dst=0) + # return sum_1_54 + + # def _forward_impl(self, *args, **kwargs): + # raise NotImplementedError("Code of forward is not generated. You should use module.train_step/module.infer_step instead.") + + # ########## Generated Schedule Code ########### + # import torch + # import nnscaler + + # def _train_step(model, dataloader_76): + # _ = None + # nnscaler.flags.RuntimeFlag.skip_zero_grad = False + # model.zero_grad() + # input_projection_weight_174 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=True) + # dropout_65 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=True) + # sum_1_54 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_65, input_projection_weight_174, ), requires_grad=True) + # nnscaler.flags.RuntimeFlag.skip_reducer = True + # gdropout_92, ginput_projection_weight_177 = nnscaler.runtime.executor.backward('segment93', (dropout_65, input_projection_weight_174, ), (sum_1_54, ), (None, )) + # sum_1_54 = sum_1_54.detach() + # input_projection_weight_314 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=True) + # dropout_310 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=True) + # _ = nnscaler.runtime.executor.aexecute(model.adapter245, *(gdropout_92, ), requires_grad=False) + # del dropout_65, gdropout_92 + # _ = nnscaler.runtime.executor.aexecute(model.adapter201, *(ginput_projection_weight_177, ), requires_grad=False) + # del input_projection_weight_174, ginput_projection_weight_177 + # sum_1_349 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_310, input_projection_weight_314, ), requires_grad=True) + # nnscaler.flags.RuntimeFlag.skip_reducer = False + # gdropout_311, ginput_projection_weight_315 = nnscaler.runtime.executor.backward('segment93', (dropout_310, input_projection_weight_314, ), (sum_1_349, ), (None, )) + # sum_1_349 = sum_1_349.detach() + # _ = nnscaler.runtime.executor.aexecute(model.adapter245, *(gdropout_311, ), requires_grad=False) + # del dropout_310, gdropout_311 + # _ = nnscaler.runtime.executor.aexecute(model.adapter201, *(ginput_projection_weight_315, ), requires_grad=False) + # del input_projection_weight_314, ginput_projection_weight_315 + # x_302 = next(*(dataloader_76, )) + # del x_302 + # x_53 = next(*(dataloader_76, )) + # del x_53 + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_54, ), requires_grad=True) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_349, ), requires_grad=True) + # return sum_1_54, sum_1_349 + + # def _infer_step(model, dataloader_76): + # _ = None + # input_projection_weight_174 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=False) + # dropout_65 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=False) + # sum_1_54 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_65, input_projection_weight_174, ), requires_grad=False) + # input_projection_weight_314 = nnscaler.runtime.executor.aexecute(model.adapter190, *(), requires_grad=False) + # dropout_310 = nnscaler.runtime.executor.aexecute(model.adapter234, *(), requires_grad=False) + # sum_1_349 = nnscaler.runtime.executor.fexecute('segment93', model.segment93, *(dropout_310, input_projection_weight_314, ), requires_grad=False) + # x_302 = next(*(dataloader_76, )) + # del x_302 + # x_53 = next(*(dataloader_76, )) + # del x_53 + # sum_1_54 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_54, ), requires_grad=False) + # sum_1_349 = nnscaler.runtime.executor.aexecute(model.adapter214, *(sum_1_349, ), requires_grad=False) + # return sum_1_54, sum_1_349 + + +class FnPolicySharedWeightModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.input_projection = torch.nn.Linear(4, 4, bias=False) + self.output_projection = torch.nn.Linear(4, 4, bias=False) + self.output_projection.weight = self.input_projection.weight # share weight + + def forward(self, x): + x = self.input_projection(x) + x = self.output_projection(x) + return x + + +def shared_weight_different_partition_policy(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition, get_called_self_module_name + + for node in get_pas_ops(graph): + module_called = get_called_self_module_name(node.call_expr) + + if node.fn == torch.nn.functional.linear and module_called == 'output_projection': + # input_projection.weight is used two times with different partition + # x = self.input_projection(x) --> no partition + # x = self.output_projection(x) --> partition dim=1 + yield OpPlan(node, partition=OpPartition(input=1, dim=1)) + + +@replace_all_device_with('cpu') +def test_codegen_fn_shared_weight(tmp_path): + parallelize( + FnPolicySharedWeightModule(), + {'x': torch.randn(4, 4)}, + # 'pp', + shared_weight_different_partition_policy, + ComputeConfig(2, 4), + gen_savedir=tmp_path, + load_module=False + ) + + for rank in range(4): + module_class = _load_parallel_module_class(FnPolicySharedWeightModule, gen_savedir=tmp_path, rank=rank) + + fullmap = {m.orig_name: m for m in module_class.attr_meta_maps[rank].values()} + # the input projection is multiref'ed + assert _gencode_contains(tmp_path, FnPolicySharedWeightModule, rank, r'nnscaler.runtime.function.multiref\(self.input_projection') + # input_projection.weight will not be splitted + # because it is multiref'ed + assert fullmap['input_projection.weight'].shape == (4, 4) and fullmap['input_projection.weight'].sub_shape == (4, 4) + + # Generated code of rank 0 looks like: + # def __init__(self, init_params=True, build_buckets=True, *args, async_op=False, max_bucket_size_bytes=None, zero_use_reduce_scatter=False, **kwargs): + # super().__init__() + # # communication groups + # self.init_group(ranks=[0, 2]) + # self.init_group(ranks=[1, 3]) + # self.init_group(ranks=[0, 1]) + # self.init_group(ranks=[2, 3]) + + # self.register_parameter('input_projection_weight_15', torch.nn.Parameter(torch.empty((4, 4), dtype=torch.float32))) + # self.add_full_map('input_projection_weight_15', 3, True, 'input_projection.weight', (4, 4), (slice(0, 4, None), slice(0, 4, None)), 1) + + # self.wreducer80 = nnscaler.runtime.adapter.Reducer(ranks=[0, 2], reduce_op='sum', async_op=async_op, zero=False, max_bucket_size_bytes=max_bucket_size_bytes, zero_use_reduce_scatter=zero_use_reduce_scatter, zero_ngroups=1) + # self.wreducer80.add_param(self.input_projection_weight_15) + # self.add_reducer(self.wreducer80) + + # self._post_init(init_params, build_buckets) + + # def segment76(self, x_13): + # # shared param + # input_projection_weight_32, input_projection_weight_33 = nnscaler.runtime.function.multiref(self.input_projection_weight_15, times=2) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 763, in forward, x = self.input_projection(x) + # linear_16 = torch.nn.functional.linear(x_13, input_projection_weight_32, bias=None) + # del x_13, input_projection_weight_32 + # linear_22 = nnscaler.runtime.adapter.nn.split_allgather(linear_16, dim=1, ranks=[0, 1]) + # del linear_16 + # input_projection_weight_37 = nnscaler.runtime.adapter.nn.split_allgather(input_projection_weight_33, dim=1, ranks=[0, 1]) + # del input_projection_weight_33 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 764, in forward, x = self.output_projection(x) + # linear_1_26 = torch.nn.functional.linear(linear_22, input_projection_weight_37, bias=None) + # del linear_22, input_projection_weight_37 + # linear_1_14 = nnscaler.runtime.adapter.nn.allreduce_identity(linear_1_26, ranks=[0, 1]) + # del linear_1_26 + # return linear_1_14 + + +class FnPolicyModuleList2(torch.nn.Module): + def __init__(self): + super().__init__() + self.ffn = torch.nn.ModuleList([ + FFNDropout(4, 8), + FFNDropout(4, 8), + FFNDropout(4, 8), + FFNDropout(4, 8), + ]) + + def forward(self, x): + x = x * 2 + for ffn in self.ffn: + x = ffn(x) + x = x + 3 + return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + + +@replace_all_device_with('cpu') +def test_codegen_fn_pipeline2(tmp_path): + parallelize( + FnPolicyModuleList2(), + {'x': torch.randn(4, 4)}, + # 'pp', + megatron_ffn_policy_list, + ComputeConfig(4, 4, use_end2end=True, + pas_config={ + 'pipeline_nmicros': 2, + # 4 stages, with pp=2 + 'pipeline_size': 2, + 'pipeline_scheduler': '1f1b_interleaved', + } + ), + gen_savedir=tmp_path, + load_module=False + ) + # should successfully generate code without error + assert True + + +class HookModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + a, b = x.size()[:2] + r = torch.randn(a * 2, b) + r = r.chunk(2, dim=0)[0] + return self.linear(x) + r + + + +def hello(module, meta, *args, **kwargs): + print(f'hello: {meta}') + + +def policy_hook(graph, cfg): + from nnscaler.policies import OpPlan, OpPartition + # add hook to all ops + def _hook(op_plan: OpPlan): + op_plan.pre_hook = hello + op_plan.post_hook = hello + op_plan.hook_meta = op_plan.op.name + return op_plan + + for node in get_pas_ops(graph): + if node.fn == torch.nn.functional.linear: + yield _hook(OpPlan(node, partition=OpPartition(input=1, dim=0))) + else: + yield _hook(OpPlan(node)) + + +@replace_all_device_with('cpu') +def test_codegen_fn_with_hook(tmp_path): + parallelize( + HookModule(), + {'x': torch.randn(4, 4)}, + policy_hook, + ComputeConfig(2, 4), + gen_savedir=tmp_path, + load_module=False + ) + # should successfully generate code without error + # and hooks are inserted + for rank in range(4): + assert _gencode_contains(tmp_path, HookModule, rank, r'tests.test_policies.hello\(self,') + + # Generated code of rank 0 looks like: + # def segment64(self, x_32): + # x_32 = nnscaler.runtime.adapter.nn.identity_allreduce(x_32, ranks=[0, 1]) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 883, in forward, a, b = x.size()[:2] + # tests.test_policies.hello(self, 'size', (x_32, ), dict()) + # im_output_63 = torch.Tensor.size(x_32) + # tests.test_policies.hello(self, 'size', (x_32, ), dict(), im_output_63) + # size_26 = im_output_63[0] + # size_27 = im_output_63[1] + # del im_output_63 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 884, in forward, r = torch.randn(a * 2, b) + # tests.test_policies.hello(self, 'mul', (size_26, 2), dict()) + # mul_28 = _operator.mul(size_26, 2) + # tests.test_policies.hello(self, 'mul', (size_26, 2), dict(), mul_28) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 884, in forward, r = torch.randn(a * 2, b) + # tests.test_policies.hello(self, 'randn', (), dict(size=(mul_28, size_27), requires_grad=False, dtype=torch.float32, pin_memory=False)) + # randn_34 = nnscaler.runtime.function.randn(size=(mul_28, size_27), requires_grad=False, dtype=torch.float32, pin_memory=False) + # tests.test_policies.hello(self, 'randn', (), dict(size=(mul_28, size_27), requires_grad=False, dtype=torch.float32, pin_memory=False), randn_34) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 885, in forward, r = r.chunk(2, dim=0)[0] + # tests.test_policies.hello(self, 'chunk', (randn_34, ), dict(chunks=2, dim=0)) + # chunk_35, chunk_36 = torch.chunk(randn_34, chunks=2, dim=0) + # tests.test_policies.hello(self, 'chunk', (randn_34, ), dict(chunks=2, dim=0), (chunk_35, chunk_36)) + # del randn_34, chunk_36 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 886, in forward, return self.linear(x) + r + # tests.test_policies.hello(self, 'linear', (x_32, self.linear_weight_45, self.linear_bias_47), dict()) + # linear_49 = torch.nn.functional.linear(x_32, self.linear_weight_45, self.linear_bias_47) + # tests.test_policies.hello(self, 'linear', (x_32, self.linear_weight_45, self.linear_bias_47), dict(), linear_49) + # del x_32 + # linear_39 = nnscaler.runtime.adapter.nn.allgather_split(linear_49, dim=1, ranks=[0, 1]) + # del linear_49 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 886, in forward, return self.linear(x) + r + # tests.test_policies.hello(self, 'add', (linear_39, chunk_35), dict(alpha=1)) + # add_33 = torch.add(linear_39, chunk_35, alpha=1) + # tests.test_policies.hello(self, 'add', (linear_39, chunk_35), dict(alpha=1), add_33) + # del chunk_35, linear_39 + # return add_33 + + +def _gencode_unused_args_worker(tempdir): + init_distributed() + m_new = parallelize( + HookModule(), + {'x': torch.randn(4, 4)}, + policy_hook, + ComputeConfig(2, 2), + gen_savedir=tempdir, + load_module=True + ) + assert m_new is not None + m_new(torch.randn(4, 4)) + # should successfully run without error + assert True + + +@pytest.mark.skipif(not torch.cuda.is_available() or torch.cuda.device_count() < 2, reason='lack of gpu devices') +def test_run_codegen_fn_with_hook(): + """ + Verify the generated code can run correctly. + """ + with tempfile.TemporaryDirectory() as tempdir: + launch_torchrun(2, _gencode_unused_args_worker, tempdir) + + +@replace_all_device_with('cpu') +def test_codegen_fsdp(tmp_path): + parallelize( + FnPolicyModuleList(), + {'x': torch.randn(4, 4)}, + 'fsdp', + ComputeConfig( + 1, 2, + use_end2end=True, + use_zero=3, + pas_config={ + 'recomputes': [FFNDropout], + } + ), + gen_savedir=tmp_path, + load_module=False + ) + # code should look like: + # def segment105_impl(self, x_49): + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 239, in forward, x = x * 2 + # mul_51 = torch.mul(x_49, 2) + # del x_49 + + # def recompute(mul_51): + # # created at IRAdapterGener:local_consumer_multiref + # mul_100, mul_104 = nnscaler.runtime.function.multiref(mul_51, times=2) + # del mul_51 + # self.prefetch_param(self.ffn_0_gate_proj_weight_52) + # mul_100 = self.backward_postevict_param(mul_100, self.ffn_0_gate_proj_weight_52, 1) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_53 = torch.nn.functional.linear(mul_100, self.ffn_0_gate_proj_weight_52, bias=None) + # self.postevict_param(self.ffn_0_gate_proj_weight_52) + # linear_53 = self.backward_prefetch_param(linear_53, self.ffn_0_gate_proj_weight_52, 1) + # del mul_100 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_54 = torch.tanh(linear_53) + # del linear_53 + # self.prefetch_param(self.ffn_0_up_proj_weight_55) + # mul_104 = self.backward_postevict_param(mul_104, self.ffn_0_up_proj_weight_55, 3) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_1_56 = torch.nn.functional.linear(mul_104, self.ffn_0_up_proj_weight_55, bias=None) + # self.postevict_param(self.ffn_0_up_proj_weight_55) + # linear_1_56 = self.backward_prefetch_param(linear_1_56, self.ffn_0_up_proj_weight_55, 3) + # del mul_104 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_1_57 = torch.mul(tanh_54, linear_1_56) + # del tanh_54, linear_1_56 + # self.prefetch_param(self.ffn_0_down_proj_weight_58) + # mul_1_57 = self.backward_postevict_param(mul_1_57, self.ffn_0_down_proj_weight_58, 5) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_2_59 = torch.nn.functional.linear(mul_1_57, self.ffn_0_down_proj_weight_58, bias=None) + # self.postevict_param(self.ffn_0_down_proj_weight_58) + # linear_2_59 = self.backward_prefetch_param(linear_2_59, self.ffn_0_down_proj_weight_58, 5) + # del mul_1_57 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # ffn_0_dropout_training_21 = self.training + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # dropout_60 = torch.nn.functional.dropout(linear_2_59, p=0.1, training=ffn_0_dropout_training_21, inplace=False) + # del linear_2_59 + # return dropout_60 + + # dropout_60 = ckpt.checkpoint(recompute, mul_51, use_reentrant=False) + # del mul_51 + + # def recompute(dropout_60): + # # created at IRAdapterGener:local_consumer_multiref + # dropout_108, dropout_112 = nnscaler.runtime.function.multiref(dropout_60, times=2) + # del dropout_60 + # self.prefetch_param(self.ffn_1_gate_proj_weight_61) + # dropout_108 = self.backward_postevict_param(dropout_108, self.ffn_1_gate_proj_weight_61, 1) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_3_62 = torch.nn.functional.linear(dropout_108, self.ffn_1_gate_proj_weight_61, bias=None) + # self.postevict_param(self.ffn_1_gate_proj_weight_61) + # linear_3_62 = self.backward_prefetch_param(linear_3_62, self.ffn_1_gate_proj_weight_61, 1) + # del dropout_108 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # tanh_1_63 = torch.tanh(linear_3_62) + # del linear_3_62 + # self.prefetch_param(self.ffn_1_up_proj_weight_64) + # dropout_112 = self.backward_postevict_param(dropout_112, self.ffn_1_up_proj_weight_64, 3) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_4_65 = torch.nn.functional.linear(dropout_112, self.ffn_1_up_proj_weight_64, bias=None) + # self.postevict_param(self.ffn_1_up_proj_weight_64) + # linear_4_65 = self.backward_prefetch_param(linear_4_65, self.ffn_1_up_proj_weight_64, 3) + # del dropout_112 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # mul_2_66 = torch.mul(tanh_1_63, linear_4_65) + # del tanh_1_63, linear_4_65 + # self.prefetch_param(self.ffn_1_down_proj_weight_67) + # mul_2_66 = self.backward_postevict_param(mul_2_66, self.ffn_1_down_proj_weight_67, 5) + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 226, in forward, down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + # linear_5_68 = torch.nn.functional.linear(mul_2_66, self.ffn_1_down_proj_weight_67, bias=None) + # self.postevict_param(self.ffn_1_down_proj_weight_67) + # linear_5_68 = self.backward_prefetch_param(linear_5_68, self.ffn_1_down_proj_weight_67, 5) + # del mul_2_66 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # ffn_1_dropout_training_40 = self.training + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 227, in forward, return self.dropout(down_proj) + # dropout_1_69 = torch.nn.functional.dropout(linear_5_68, p=0.1, training=ffn_1_dropout_training_40, inplace=False) + # del linear_5_68 + # return dropout_1_69 + + # dropout_1_69 = ckpt.checkpoint(recompute, dropout_60, use_reentrant=False) + # del dropout_60 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 242, in forward, x = x + 3 + # add_70 = torch.add(dropout_1_69, 3, alpha=1) + # del dropout_1_69 + # # File "/home/weijiangxu/MagicCube/tests/test_policies.py", line 243, in forward, return torch.sum(x) # make sure output is scalar loss (required by pipeline parallelism) + # sum_1_50 = torch.sum(add_70) + # del add_70 + # return sum_1_50 + + # def segment105(self, x_49): + # with self.save_params_hooks(): + # return self.segment105_impl(x_49) + assert True diff --git a/tests/test_utils.py b/tests/test_utils.py index 7fa7d80a..864cab67 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,10 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from collections import OrderedDict from dataclasses import dataclass import pytest +import torch -from nnscaler.utils import select_many, classproperty, fields +from nnscaler.utils import ( + select_many, classproperty, fields, set_member_by_name, unchecked_fields, + transform_recursively, +) def test_select_many(): @@ -53,3 +58,123 @@ class A: assert fields(A).y == 'y' with pytest.raises(AttributeError): fields(A).z + + assert unchecked_fields(A).x == 'x' + assert unchecked_fields(A).y == 'y' + assert unchecked_fields(A).z == 'z' + + a = A(x=0, y=0) + assert unchecked_fields(a).x == 'x' + assert unchecked_fields(a).y == 'y' + assert unchecked_fields(a).z == 'z' + + class B: + def __init__(self): + self.a = A(x=1, y=2) + + assert unchecked_fields(B).x == 'x' + b = B() + assert unchecked_fields(b).x == 'x' + assert unchecked_fields(b.a).x == 'x' + + +def test_set_member_by_name(): + model = torch.nn.Module() + set_member_by_name(model, "x", 42) + assert model.x == 42 + with pytest.raises(AttributeError): + set_member_by_name(model, 'x.y.z', 43) + + set_member_by_name(model, 'a.b.c', 44) + assert model.a.b.c == 44 + + model = torch.nn.Module() + child_module = torch.nn.Module() + set_member_by_name(model, "x.y", child_module) + assert model.x.y == child_module + + set_member_by_name(model, 'x.y.z', 45) + assert model.x.y == child_module + assert model.x.y.z == 45 + + +def test_transform_recursively(): + data = { + 'a': torch.tensor([1]), + 'b': [torch.tensor(4), {'c': torch.tensor([5])}], + 'd': (7, torch.tensor(8)), + 'e': {1: 9, 2: torch.tensor(10)}.keys(), + 'f': {1: 9, 2: torch.tensor(11)}.items(), + 'g': {1: 9, 2: torch.tensor(12)}.values(), + 'h': {1: 9, 2: torch.tensor(13)}, + 'i': slice(0, 10, None), + 'j': torch.Size([11, 12]), + 'k': OrderedDict({1: 9, 2: 10}), + 'l': {1: 9, 2: 10}.values(), + 'm': [1, 2, 3], + 'n': slice(0, 10, torch.tensor(2)), + 'o': {torch.tensor(1): 9, torch.tensor(2): 10}, + 'p': {torch.tensor(1): 9, torch.tensor(2): 10}.items(), + 'q': {torch.tensor(1): 9, torch.tensor(2): 10}.keys() + } + + def fn(x): + if isinstance(x, torch.Tensor): + return x.item() + return x + + result1 = transform_recursively( + data, fn, + target_types=torch.Tensor, + collection_types=None, + skip_dict_keys=True, + ) + + result2 = transform_recursively( + data, fn, + target_types=torch.Tensor, + collection_types=None, + skip_dict_keys=False, + ) + target = { + 'a': 1, + 'b': [4, {'c': 5}], + 'd': (7, 8), + 'e': {1: 1, 2: 2}.keys(), + 'f': dict([(1, 9), (2, 11)]).items(), + 'g': {1: 9, 2: 12}.values(), + 'h': {1: 9, 2: 13}, + 'i': slice(0, 10, None), + 'j': torch.Size([11, 12]), + 'k': OrderedDict({1: 9, 2: 10}), + 'l': data['l'], + 'm': [1, 2, 3], + 'n': slice(0, 10, 2), + } + # dict values are not comparable. + assert list(target['g']) == list(result1.pop('g')) + assert list(target['g']) == list(result2.pop('g')) + target.pop('g') + + + skip_key_target = { + **target, + 'o': {torch.tensor(1): 9, torch.tensor(2): 10}, + 'p': {torch.tensor(1): 9, torch.tensor(2): 10}.items(), + 'q': {1: 9, 2: 10}.keys() + } + noskip_key_target = { + **target, + 'o': {1: 9, 2: 10}, + 'p': dict([(1, 9), (2, 10)]).items(), + 'q': {1: 9, 2: 10}.keys() + } + + from tests.parallel_module.common import assert_equal + + assert_equal(list(skip_key_target.pop('o')), list(result1.pop('o'))) + assert_equal(list(skip_key_target.pop('p')), list(result1.pop('p'))) + assert_equal(list(skip_key_target.pop('q')), list(result1.pop('q'))) + + assert_equal(result1, skip_key_target) + assert_equal(result2, noskip_key_target) diff --git a/tests/utils.py b/tests/utils.py index 22036f42..0aae6c4e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -26,6 +26,10 @@ from nnscaler.runtime.device import DeviceGroup, CompileFlag +MASTER_PORT = os.environ.get("MASTER_PORT", "29401") +PYTEST_RUN_ID = MASTER_PORT + + def init_parameter(model: torch.nn.Module, seed: int = 0): """ Initialize a model's parameters with truncated normal distribution. @@ -58,7 +62,7 @@ def init_random(seed: int = 1): torch.cuda.manual_seed(seed) -def assert_parity(baseline_fn: Callable, compile_fn: Callable, atol: float=1e-4) -> bool: +def assert_parity(baseline_fn: Callable, compile_fn: Callable, atol: float=1e-3) -> bool: """Compare the output of baseline_fn and compile_fn Error will raise if the output of two functions are not the same. @@ -92,10 +96,10 @@ def assert_same_complex(gt, out): assert_same_complex(gt[key], out[key]) elif isinstance(gt, torch.Tensor): assert isinstance(out, torch.Tensor) - assert torch.allclose(gt, out, atol=atol), f'mismatched: {gt} != {out}' + assert torch.allclose(gt, out, atol=atol), f'mismatched (with atol {atol}): {gt} != {out}' elif isinstance(gt, float): assert isinstance(out, float) - assert math.isclose(gt, out, abs_tol=atol), f'mismatched: {gt} != {out}' + assert math.isclose(gt, out, abs_tol=atol), f'mismatched (with atol {atol}): {gt} != {out}' else: assert gt == out, f'mismatched: {gt} != {out}' assert_same_complex(baseline_outputs, compile_outputs) @@ -114,6 +118,7 @@ def replace_all_device_with(device='cpu', force=False): orig_to = torch.Tensor.to orig_cuda = torch.Tensor.cuda orig_cpu = torch.Tensor.cpu + orig_is_cuda = torch.Tensor.is_cuda def patch_tensor_constructor(fn): orig_func = getattr(fn, '__cube_orig_func__', fn) # to support nested patching @@ -158,6 +163,8 @@ def wrapper(*args, **kwargs): } def patched_to(self, *args, **kwargs): + if device == 'meta': + return self if len(args) > 0 and isinstance(args[0], (torch.device, str)): return orig_to(self, device, *args[1:], **kwargs) if 'device' in kwargs: @@ -166,15 +173,20 @@ def patched_to(self, *args, **kwargs): return orig_to(self, *args, **kwargs) def patched_cuda(self, *args, **kwargs): + if device == 'meta': + return self return orig_to(self, device) def patched_cpu(self, *args, **kwargs): + if device == 'meta': + return self return orig_to(self, device) try: torch.Tensor.to = patched_to torch.Tensor.cuda = patched_cuda torch.Tensor.cpu = patched_cpu + torch.Tensor.is_cuda = property(lambda self: True) # patch tensor constructors for tf_name, fn in old_tensor_constructors.items(): setattr(torch, tf_name, patched_tensor_constructors[tf_name]) @@ -205,6 +217,7 @@ def patched_cpu(self, *args, **kwargs): torch.Tensor.to = orig_to torch.Tensor.cuda = orig_cuda torch.Tensor.cpu = orig_cpu + torch.Tensor.is_cuda = orig_is_cuda # mock process group is from pytorch testing code @@ -373,6 +386,7 @@ def catch_stdout(): def clear_dir_on_rank0(tempdir): if torch.distributed.get_rank() == 0 and tempdir.exists(): shutil.rmtree(tempdir) + torch.distributed.barrier() yield tempdir torch.distributed.barrier() if torch.distributed.get_rank() == 0 and tempdir.exists(): diff --git a/tox.ini b/tox.ini index 04e0c3ae..0cfe015c 100644 --- a/tox.ini +++ b/tox.ini @@ -8,9 +8,11 @@ envlist = py310 skipsdist = True [testenv] -allowlist_externals = rm +allowlist_externals = + rm + uv passenv = * -install_command = pip install {opts} {packages} +install_command = uv pip install {opts} {packages} deps = -rrequirements.txt -rrequirements-dev.txt diff --git a/utility/aggregate.sh b/utility/aggregate.sh new file mode 100644 index 00000000..19185564 --- /dev/null +++ b/utility/aggregate.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# gather the folder to all workers to node-0 under the same workspace + +set -ex + +WORKSPACE=/workspace +FOLDER=MagicCube + +WORKER_PREFIX=node- +WORKER_NUM=2 + +for ((i=1; i<${WORKER_NUM}; i++)); do + WORKER=${WORKER_PREFIX}${i} + scp -r ${WORKER}:${WORKSPACE}/${FOLDER} ${WORKSPACE}/${FOLDER}-${WORKER} +done diff --git a/utility/broadcast.sh b/utility/broadcast.sh new file mode 100644 index 00000000..dbb77c7a --- /dev/null +++ b/utility/broadcast.sh @@ -0,0 +1,15 @@ +#!/bin/bash +# broadcast the folder to all workers under the same workspace + +set -ex + +WORKSPACE=/workspace +FOLDER=MagicCube + +WORKER_PREFIX=node- +WORKER_NUM=2 + +for ((i=1; i<=${WORKER_NUM}; i++)); do + WORKER=${WORKER_PREFIX}${i} + scp -r ${WORKSPACE}/${SYNC_FOLDER} ${WORKER}:${WORKSPACE} +done diff --git a/utility/comm_profile.py b/utility/comm_profile.py new file mode 100644 index 00000000..5f767c24 --- /dev/null +++ b/utility/comm_profile.py @@ -0,0 +1,108 @@ +import argparse +import json +import torch +from pathlib import Path +import os +from typing import Tuple, List, Dict + +import nnscaler +from nnscaler.runtime.adapter.collectives import all_gather, all_reduce, all_to_all, reduce_scatter +from nnscaler.profiler import CudaTimer +from nnscaler.runtime.device import DeviceGroup +from nnscaler.autodist.util import get_node_arch, get_default_profile_path + + +class CommProfiler: + + def __init__(self, + nranks: int, + warmup_times: int = 10, + profile_times: int = 10) -> None: + self.nranks = nranks + self.warmup_times = warmup_times + self.profile_times = profile_times + self.ranks = tuple(range(self.nranks)) + + def collect_profile_info(self, + primitive: str) -> Tuple[List[float], List[float]]: + + b_size = 16 + sequence_len = 16 + quarter_mb_size_list = [ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048 + ] + model_dim_list = [ + mem * 256 * 256 // b_size // sequence_len + for mem in quarter_mb_size_list + ] + sizes_in_mb = [0.25 * val for val in quarter_mb_size_list] + times_in_s = [] + for cur_sz, d_size in zip(sizes_in_mb, model_dim_list): + assert d_size % self.nranks == 0 + if primitive in ['all gather', 'all to all']: + d_size = d_size // self.nranks + tensor = torch.rand([b_size, sequence_len, d_size], + dtype=torch.float32, + device=torch.cuda.current_device()) + if primitive == 'all gather': + func = all_gather + kwargs = {'tensor': tensor, 'dim': 2, 'ranks': self.ranks} + elif primitive == 'all reduce': + func = all_reduce + kwargs = {'tensor': tensor, 'ranks': self.ranks} + elif primitive == 'reduce scatter': + func = reduce_scatter + kwargs = {'tensor': tensor, 'dim': 2, 'ranks': self.ranks} + elif primitive == 'all to all': + func = all_to_all + kwargs = { + 'tensor': tensor, + 'idim': 0, + 'odim': 2, + 'ranks': self.ranks + } + else: + raise ValueError('Unknown primitive: {}'.format(primitive)) + for _ in range(self.warmup_times): + func(**kwargs) + CudaTimer().clear() + for _ in range(self.profile_times): + otensor = func(**kwargs) + cur_t = CudaTimer().instance.field_data['comm'] / self.profile_times + times_in_s.append(cur_t) + return sizes_in_mb, times_in_s + + def profile(self) -> Dict[str, Tuple[List[float], List[float]]]: + profile_info = {} + for primitive in [ + 'all gather', 'all reduce', 'reduce scatter', 'all to all' + ]: + profile_info[primitive] = self.collect_profile_info( + primitive=primitive) + return profile_info + +if __name__ == '__main__': + + parser = argparse.ArgumentParser( + description='Profile runtime communication cost') + parser.add_argument('--comm_profile_dir', + type=str, + default=get_default_profile_path() / get_node_arch() / 'comm', + help='autodist comm profile folder') + args = parser.parse_args() + + nnscaler.init() + + CudaTimer(enable=True, predefined=True) + world_size = DeviceGroup().world_size + comm_profiler = CommProfiler(nranks=world_size) + + profile_info = comm_profiler.profile() + + if torch.distributed.get_rank() == 0: + dir_path = Path(args.comm_profile_dir) + if not dir_path.exists(): + dir_path.mkdir(parents=True, exist_ok=True) + file_name = dir_path / f'intra_{world_size}.json' + with open(file_name, 'w') as f: + json.dump(profile_info, f, indent=2) diff --git a/utility/dgx1_reorder_gpu.py b/utility/dgx1_reorder_gpu.py new file mode 100644 index 00000000..aa312587 --- /dev/null +++ b/utility/dgx1_reorder_gpu.py @@ -0,0 +1,119 @@ +""" +Reorder GPU index by finding DGX-1 topology Find dgx topology + +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +1 = 0 = 4 = 5 +โ€– x | | x โ€– +2 = 3 = 7 = 6 +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + +""" +from typing import List +import subprocess +import numpy as np + +_kConnType = { + "NV1": 1, + "NV2": 2, + "NODE": 3, + "X": -1, +} + +_kConnTypeStr = {val: key for key, val in _kConnType.items()} + + + +def get_topology(): + cmds = [ + 'nvidia-smi', + 'topo', + '-m', + ] + + proc = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = proc.communicate() + outputs = stdout.decode('utf-8').split('\n') + + outputs = [out for out in outputs if out.startswith('GPU')] + ngpus = len(outputs) + print(f'Detected GPU number: {ngpus}') + + topology = np.empty((ngpus, ngpus), dtype=int) + for src, output in enumerate(outputs): + connections = output.split('\t')[1:1+ngpus] + for dst, link in enumerate(connections): + link = link.replace(" ", "") + assert link in _kConnType, f"Find link not in DGX-1 topology: {link}" + topology[src, dst] = _kConnType[link] + return topology + + +def topology_repr(topology: np.ndarray, reorder: List[int]): + reorder = list(reorder) + ngpus = topology.shape[0] + reorder_topo = np.empty((ngpus, ngpus), dtype=object) + for src in range(ngpus): + for dst in range(ngpus): + link = _kConnTypeStr[topology[src, dst]] + reorder_topo[reorder.index(src), reorder.index(dst)] = link + maxlen = max(len(key) for key in _kConnType) + dscp = '' + for gidx, line in enumerate(reorder_topo): + dscp += f'GPU{gidx}: '+ ' '.join(link.ljust(maxlen) for link in line) + '\n' + return dscp + + +def reorder(topology: np.ndarray) -> np.ndarray: + """ + Reorder GPU according to DGX-1 topology + + โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” + 1 = 0 = 4 = 5 + โ€– x | | x โ€– + 2 = 3 = 7 = 6 + โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ + """ + ngpus = topology.shape[0] + # find NV2 ring + ring = [0] + while len(ring) < ngpus: + nv2s = np.where(topology[ring[-1]] == _kConnType['NV2'])[0] + find_next = False + for gid in nv2s: + if gid not in ring: + ring.append(gid) + find_next = True + break + assert find_next + ring = np.array(ring, dtype=int) + print(f'Get ring: {ring}') + # find fc + for idx, src in enumerate(ring): + is_fc = True + pairs = [ + (src, ring[(idx + 3) % len(ring)]), + (src, ring[(idx + 2) % len(ring)]), + (ring[(idx+1) % len(ring)], ring[(idx+3) % len(ring)]) + ] + for src, dst in pairs: + if topology[src, dst] != _kConnType['NV1']: + is_fc = False + break + if is_fc: + break + assert is_fc, f"Cannot find FC group." + ring = np.roll(ring, 0-idx) + return ring + + +if __name__ == '__main__': + topology = get_topology() + print('original topology:') + print(topology_repr(topology, list(range(topology.shape[0])))) + reorder = reorder(topology) + print('reorder topology:') + print(topology_repr(topology, reorder)) + print( + f"Command need to be added into environment:\n" + f"export CUDA_VISIBLE_DEVICES={','.join(str(gid) for gid in reorder)}" + ) diff --git a/utility/keep.py b/utility/keep.py new file mode 100644 index 00000000..d45a02fc --- /dev/null +++ b/utility/keep.py @@ -0,0 +1,72 @@ +import torch +import time +import argparse + +import subprocess +import re + +def get_gpu_util(rank): + from shutil import which + smi = None + if which('nvidia-smi') is not None: + smi = 'nvidia-smi' + elif which('rocm-smi') is not None: + smi = 'rocm-smi' + else: + raise Exception('Cannot find either nvidia-smi or rocm-smi!') + + cmds = [ + smi, + '-i', + str(rank), + ] + proc = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = proc.communicate() + outputs = stdout.decode('utf-8').split('\n') + + util = 0 + for output in outputs[::-1]: + # switch to performance line + if 'Default' in output: + # match all the numbers and return the last one + util = re.findall(r'\d+', output)[-1] + util = int(util) + break + else: + print("rank {}: couldn't match any, check GPU status!".format(rank)) + return util + + +def keep(rank, args): + + torch.cuda.set_device(rank) + a = torch.rand((8192, 8192)).cuda() + b = torch.rand((8192, 8192)).cuda() + + print(f'benchmarking {args.gpus} gpus...') + while True: + tic = time.time() + for _ in range(5000): + c = a * b + torch.cuda.synchronize() + toc = time.time() + # if rank == 0: + # print('benchmark 8K matmul: time span: {}ms'.format((toc - tic) * 1000 / 5000)) + time.sleep(args.interval) + while True: + util = get_gpu_util(rank) + if util <= 10: + break + # print('rank {}: find gpu busy, keep sleeping...'.format(rank)) + time.sleep(args.interval) + # print('rank {} gets up'.format(rank)) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--interval', type=int, default=2) + parser.add_argument('--gpus', type=int, default=1) + args = parser.parse_args() + + torch.multiprocessing.spawn(keep, args=(args,), nprocs=args.gpus, join=True) diff --git a/utility/nightly_test/nightly_test.py b/utility/nightly_test/nightly_test.py new file mode 100644 index 00000000..99eb2b2d --- /dev/null +++ b/utility/nightly_test/nightly_test.py @@ -0,0 +1,228 @@ +from test_utils import TestUtils +from azure.communication.email import EmailClient +from subprocess import CalledProcessError +import subprocess +from datetime import datetime, timedelta +from pathlib import Path +import argparse +import base64 +import zipfile +import json +import os +import sys + +sender_address = "DoNotReply@ca1e34f6-1a6d-4181-8b16-692dbe193525.azurecomm.net" + +def zip_folder(folder_path, output_path): + """ Zip the folder to the output path + Args: + folder_path: the folder path + output_path: the output path + """ + with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf: + for root, _, files in os.walk(folder_path): + for file in files: + relative_path = os.path.relpath(os.path.join(root, file), os.path.dirname(folder_path)) + zipf.write(os.path.join(root, file), arcname=relative_path) + +def get_branch_commit(repo_path, branch_name = None, days_ago = 0): + """ Get the branch name or commit ID of the branch_name that is days_ago + Args: + repo_path: the path of the git repo + branch_name: the branch name, if not provided return the branch name + days_ago: the days ago, 0 means get current commit ID of the branch + Returns: + The branch name of the commit ID + """ + if branch_name is None: + git_command = 'git rev-parse --abbrev-ref HEAD' + return TestUtils.execute_command(git_command, repo_path) + elif days_ago == 0: + git_command = 'git rev-parse HEAD' + return TestUtils.execute_command(git_command, repo_path) + else: + before_date = (datetime.now() - timedelta(days=int(days_ago)) + timedelta(hours=15)).strftime('%Y-%m-%d %H:%M:%S') # add 15 hours to align with Beijing time + git_command = 'git fetch && git rev-list -n 1 --before="{}" {}'.format(before_date, branch_name) + return TestUtils.execute_command(git_command, repo_path) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description='Running Nightly Test') + parser.add_argument('-w', '--workspace', required=True, help='workspace for nightly test') + parser.add_argument('-d', '--data-path', required=True, help='dataset path') + + parser.add_argument('-n', '--nnscaler-commit-id', help='nnscaler commit id, decide the version of nnscaler for unit test and example parity-check') + + parser.add_argument('-u', '--unit-test', default=False, action=argparse.BooleanOptionalAction, help='unit test for nnscaler') + + parser.add_argument('-ep', '--example-parity-check', default=False, action=argparse.BooleanOptionalAction, help='example parity check for nnscaler. It will compare nnscaler or main with or main') + + # Keeping old argument name for compatibility if needed, but help text updated + parser.add_argument('-p2', '--parity-check2', dest='example_parity_check', action='store_true', help='Alias for --example-parity-check') + + parser.add_argument('-pb', '--parity-check-conda-base', help='base conda environment for parity check, needed if example-parity-check is True') + parser.add_argument('-ngt', '--cube-branch-gt', default='main', help='cube branch for ground truth, default is main') + + parser.add_argument('-e', '--email-connect-string', help='email connect string for sending email address') + parser.add_argument('-et', '--email-to', action='append', default=[], help='multiple -et will be combined') + parser.add_argument('-ec', '--email-cc', action='append', default=[], help='multiple -ec will be combined') + args = parser.parse_args() + return args + +if __name__ == "__main__": + args = parse_arguments() + workspace = Path(args.workspace).expanduser() + data_path = Path(args.data_path).expanduser() + if not workspace.exists(): + raise ValueError(f"Invalid workspace path: {workspace}") + if not data_path.exists(): + raise ValueError(f"Invalid data_path path: {data_path}") + log_folder = TestUtils.gen_log_folder(workspace) + + # Assuming nnscaler is cloned as "nnscaler" in the workspace + nnscaler_repo_path = workspace / "nnscaler" + pytest_dir = nnscaler_repo_path / "tests" + + script_dir = Path(__file__).parent.absolute() + parity_alert_script = script_dir / "parity_alert_examples/parity_alert.sh" + parity_check_cases_dir = script_dir / "parity_alert_examples/test_cases" + + if not pytest_dir.exists(): + raise ValueError(f"Invalid pytest_dir path: {pytest_dir}") + if not parity_alert_script.exists(): + raise ValueError(f"Invalid parity_alert_script path: {parity_alert_script}") + + if args.nnscaler_commit_id: + cmd = f"parallel-ssh -x -q -h ~/.pssh_hosts_files git -C {nnscaler_repo_path} checkout {args.nnscaler_commit_id}" + TestUtils.call([cmd]) + + with open(TestUtils.gen_log_folder(workspace) / "nightly_test.log", 'a') as nightly_test_file: + nnscaler_branch = get_branch_commit(nnscaler_repo_path) + nnscaler_commit_id = get_branch_commit(nnscaler_repo_path, nnscaler_branch) + nightly_test_file.write(f"nnscaler on branch {nnscaler_branch}, commit ID {nnscaler_commit_id}\n\n") + nightly_test_file.flush() + + # Run Unit Test + pytest_output = "" + if args.unit_test: + pytest_cmd = f"{sys.executable} -m pytest {pytest_dir}" + try: + pytest_log_file = log_folder / "pytest.log" + with open(pytest_log_file, 'w') as f: + # Run pytest from inside nnscaler repo + result = subprocess.run([sys.executable, '-m', 'pytest', '-v', str(pytest_dir)], stdout=f, stderr=f, cwd=nnscaler_repo_path) + if result.returncode != 0: + pytest_output = f"NNScaler Unit test didn't pass, see {pytest_log_file.name} for more details." + else: + pytest_output = "NNScaler Unit test passed" + except CalledProcessError as e: + pytest_output = f"Command {pytest_cmd} failed with error code {e.returncode}" + finally: + nightly_test_file.write(pytest_output + "\n") + + # Run Example Parity Check + parity_alert_output = "" + if args.example_parity_check: + tmp_parity_check = workspace / 'tmp_example_parity_check' + if os.path.isdir(tmp_parity_check): + import shutil + shutil.rmtree(tmp_parity_check) + + if not args.nnscaler_commit_id: + # If not specified, get the current one for consistency in logging/checking + args.nnscaler_commit_id = get_branch_commit(nnscaler_repo_path, "origin/main", 0) + + nightly_test_file.write(f"Example Parity check:\nnnscaler commit ID: {args.nnscaler_commit_id}" + "\n") + + parity_check_cmd = f"bash {parity_alert_script} {tmp_parity_check} {data_path} {parity_check_cases_dir} --cube-branch {args.nnscaler_commit_id} --cube-branch-gt {args.cube_branch_gt} --conda-base {args.parity_check_conda_base}" + + env = os.environ.copy() + # Assuming we might need to set PYTHONPATH if needed for some scripts, but usually the parity script handles env setup + # But let's keep consistency if we copied parity_alert_examples which relies on some imports + try: + parity_log_file = log_folder / "example_parity_check.log" + with open(parity_log_file, 'w') as f: + # CWD to the directory of parity_alert_examples for any relative path assumptions inside train.py potentially + cwd_path = script_dir / "parity_alert_examples" + result = subprocess.run(parity_check_cmd, stdout=f, stderr=f, shell=True, env=env, cwd=cwd_path) + if result.returncode != 0: + parity_alert_output = f"Example Parity Check didn't pass, see {parity_log_file.name} for more details." + else: + parity_alert_output = "Example Parity Check passed" + except CalledProcessError as e: + parity_alert_output = f"Command {parity_check_cmd} failed with error code {e.returncode}" + finally: + nightly_test_file.write(parity_alert_output + "\n") + + nightly_test_file.flush() + + # Send email + if args.email_connect_string: + if not args.email_to: + raise ValueError(f"Invalid email_to: {args.email_to}") + zip_output = log_folder.parent / 'nightly_test_logs.zip' + zip_folder(log_folder, zip_output) + with open(zip_output, "rb") as file: + zip_b64encoded = base64.b64encode(file.read()) + + html_output = """ + + + Test Results + + + + """ + + if args.unit_test: + pytest_html_message = f"""

NNScaler Unit Test

{pytest_output}

""" + html_output += pytest_html_message + + if args.example_parity_check: + parity_html_message = f"""

Example Parity Check

{parity_alert_output}

""" + html_output += parity_html_message + + html_output +="""""" + + message = { + "senderAddress": sender_address, + "recipients": { + "to": [{ "address": t } for t in args.email_to], + "cc": [{ "address": t } for t in args.email_cc] + }, + "content": { + "subject": "Nightly Test Notification", + "html": html_output + }, + "attachments": [ + { + "name": "attachment.zip", + "contentType": "application/zip", + "contentInBase64": zip_b64encoded.decode() + } + ] + } + + try: + POLLER_WAIT_TIME = 10 + client = EmailClient.from_connection_string(args.email_connect_string) + poller = client.begin_send(message) + time_elapsed = 0 + while not poller.done(): + poller.wait(POLLER_WAIT_TIME) + time_elapsed += POLLER_WAIT_TIME + if time_elapsed > 18 * POLLER_WAIT_TIME: + raise RuntimeError("Polling timed out.") + if poller.result()["status"] == "Succeeded": + nightly_test_file.write(f"Successfully sent the email (operation id: {poller.result()['id']})") + else: + raise RuntimeError(str(poller.result()["error"])) + except Exception as ex: + nightly_test_file.write(str(ex)) + else: + nightly_test_file.write("No email connection string provided, skip sending email") diff --git a/utility/nightly_test/parity_alert_examples/parity_alert.sh b/utility/nightly_test/parity_alert_examples/parity_alert.sh new file mode 100644 index 00000000..7df42bf2 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/parity_alert.sh @@ -0,0 +1,164 @@ +#!/bin/bash + +# For parity check. +# Example: +# bash parity_alert.sh [] +# : the workspace where all codes are stored. +# : the folder when the train data for torchscale is stored. +# : the definition of parity check. +# Default value is ${the dir of the current script}/test_cases/ +# Options: +# --cube-branch-gt : default is main +# --cube-branch : default is main +# --conda-base : default is base +# --test-cases : default is all +# The test cases are listed under (`test_cases/`) folder, e.g., pasdata, dp2, tp2, hybrid2. +# +# Currently the workspace is not cleared after execution, so it can help fix the parity problem if any. +# To clean the workspace +# 1. run `rm -rf ` to clean the cloned source code. +# 2. run `conda env remove -n parity` to remove conda env. + +set -e + +export NCCL_DEBUG=WARN + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +POSITIONAL_ARGS=() + +CUBE_BRANCH_GT=main + +CUBE_BRANCH_NEW=main + +CONDA_ENV_BASE=base +TEST_CASES= + +while [[ $# -gt 0 ]]; do + case $1 in + --cube-branch-gt) + CUBE_BRANCH_GT="$2" + shift # past argument + shift # past value + ;; + --cube-branch) + CUBE_BRANCH_NEW="$2" + shift # past argument + shift # past value + ;; + --conda-base) + CONDA_ENV_BASE="$2" + shift # past argument + shift # past value + ;; + --test-cases) + TEST_CASES="$2" + shift # past argument + shift # past value + ;; + -*|--*) + echo "Unknown option $1" + exit 1 + ;; + *) + POSITIONAL_ARGS+=("$1") # save positional arg + shift # past argument + ;; + esac +done + +set -- "${POSITIONAL_ARGS[@]}" # restore positional parameters + +OPERATION=$1 + +if [[ $# -ne 2 ]] && [[ $# -ne 3 ]]; then + echo "usage: $0 WORKSPACE TRAIN_DATA_DIR [PARITY_CHECK_DATA_DIR]" + echo " [--cube-branch-gt ]" + echo " [--cube-branch ]" + echo " [--conda-base ]" + echo " [--test-cases ]" + exit 1 +fi + + +WORKSPACE=$1 +TRAIN_DATA_DIR=$2 +PARITY_CHECK_DATA_DIR=${3:-${SCRIPT_DIR}/test_cases} + +if [[ -d $WORKSPACE ]]; then + echo "Error: $WORKSPACE has existed, please remove the folder before running the test(s)." + exit 2 +fi + + +ENV_NAME=parity_$(echo $RANDOM | md5sum | head -c 10) +TMP_SETUP_ENV_SH=tmp_setup_env.sh +TMP_SWITCH_BRANCH_SH=tmp_switch_branch.sh +TMP_MODEL_DIR=result_models # will not be removed after execution +# get an unused port +UNUSED_PORT=`python -c 'import socket; s=socket.socket(); s.bind(("", 0)); print(s.getsockname()[1]); s.close()'` + +conda create -y -n ${ENV_NAME} --clone ${CONDA_ENV_BASE} + +LIBSTDC_PATH=$(conda env list | grep ${ENV_NAME} | awk '{print $NF}')/lib/libstdc++.so.6 +rm -f ${LIBSTDC_PATH} + +trap "rm -rf tmp_* && conda env remove -n ${ENV_NAME} -y" EXIT + +cat > ${TMP_SETUP_ENV_SH} << EOF +#!/bin/bash + +set -e + +# init python env +pip install build + +mkdir -p ${WORKSPACE} +cd ${WORKSPACE} + +git clone --recursive "https://github.com/msrasys/nnscaler.git" -b $CUBE_BRANCH_GT +cd nnscaler +# Rename directory to match expected 'MagicCube' or just adapt strict usage. +# The original script used 'MagicCube' directory name. Let's stick effectively to cloning nnscaler. +# However, train.py and others might expect import structure. + +pip install -e . + +python -c 'import os,sys,nnscaler,cppimport.import_hook ; sys.path.append(os.path.dirname(nnscaler.__path__[0])) ; import nnscaler.autodist.dp_solver' +cd .. + +# verify installation +python -c 'import torch; import nnscaler; print(torch.__path__, nnscaler.__path__)' + +EOF + +cat > ${TMP_SWITCH_BRANCH_SH} << EOF +#!/bin/bash + +set -e + +cd ${WORKSPACE} + +cd nnscaler +git checkout $CUBE_BRANCH_NEW + +pip install -e . + +python -c 'import os,sys,nnscaler,cppimport.import_hook ; sys.path.append(os.path.dirname(nnscaler.__path__[0])) ; import nnscaler.autodist.dp_solver' + +cd .. +EOF + +export TEST_CASES="$TEST_CASES" +export TRAIN_DATA_DIR="$TRAIN_DATA_DIR" +export UNUSED_PORT="$UNUSED_PORT" +export DETERMINISTIC=1 + +conda run --no-capture-output -n ${ENV_NAME} bash ${TMP_SETUP_ENV_SH} + +conda run --no-capture-output -n ${ENV_NAME} python ${SCRIPT_DIR}/train.py ${WORKSPACE} ${PARITY_CHECK_DATA_DIR} ${TMP_MODEL_DIR}/gt + +conda run --no-capture-output -n ${ENV_NAME} bash ${TMP_SWITCH_BRANCH_SH} + +conda run --no-capture-output -n ${ENV_NAME} python ${SCRIPT_DIR}/train.py ${WORKSPACE} ${PARITY_CHECK_DATA_DIR} ${TMP_MODEL_DIR}/new + +conda run --no-capture-output -n ${ENV_NAME} python ${SCRIPT_DIR}/parity_check.py ${TMP_MODEL_DIR}/gt ${TMP_MODEL_DIR}/new diff --git a/utility/nightly_test/parity_alert_examples/parity_check.py b/utility/nightly_test/parity_alert_examples/parity_check.py new file mode 100644 index 00000000..94d852f8 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/parity_check.py @@ -0,0 +1,48 @@ +import os +from pathlib import Path +import sys + +import torch + + +def parity_check(task_name, ground_truth_model_file, new_model_file): + gt_ckpt = torch.load(ground_truth_model_file, map_location='cpu', weights_only=False) + new_ckpt = torch.load(new_model_file, map_location='cpu', weights_only=False) + if 'model' in gt_ckpt: + gt_model = gt_ckpt['model'] + new_model = new_ckpt['model'] + elif 'state_dict' in gt_ckpt: + gt_model = gt_ckpt['state_dict'] + new_model = new_ckpt['state_dict'] + for name in gt_model: + if not torch.allclose(gt_model[name], new_model[name], rtol=1e-06, atol=1e-06): + raise Exception(f'{task_name} failed: {name} mismatch (rtol=1e-06, atol=1e-06)') + print('All weights match (rtol=1e-06, atol=1e-06)') + + +def main(gt_dir: str, new_dir: str): + new_dir = Path(new_dir).absolute() + + test_cases = os.getenv('TEST_CASES') + if test_cases: + test_cases = test_cases.split(',') + print(f'Check test cases: {test_cases}') + else: + test_cases = None + print('Check all test cases') + passed = [] + for d in Path(gt_dir).glob('*'): + if not d.is_dir(): + continue + if not test_cases or d.name in test_cases: + print(f'Checking for {d.name}...') + parity_check(d.name, d / 'model.pt', new_dir / d.name / 'model.pt') + passed.append(d.name) + print(f'All passed: {passed}') + + +if __name__ == '__main__': + if len(sys.argv) !=3: + print('Usage: python check.py ') + exit(1) + main(sys.argv[1], sys.argv[2]) diff --git a/utility/nightly_test/parity_alert_examples/test_cases/llama/config.yaml b/utility/nightly_test/parity_alert_examples/test_cases/llama/config.yaml new file mode 100644 index 00000000..ffa2133e --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/test_cases/llama/config.yaml @@ -0,0 +1,16 @@ +# NOTE: +# Must set HF_TOKEN +# Must install apex and flash-attn manually + +name: Llama 3 8B 128K +train: + path: nnscaler/examples/llama + output: ./merged.ckpt + commands: + - rm -rf .nnscaler ./checkpoints ./merged.ckpt + - pip install -r requirements.txt + - python bookcorpus.py --data_path_or_name bookcorpus/bookcorpus --tokenizer_path_or_name meta-llama/Meta-Llama-3-8B-Instruct --save_path ./bookcorpus_llama3_4K --sequence_length 4096 + - python create_mini_model.py --model_id meta-llama/Meta-Llama-3-8B-Instruct --output_id ./llama3_mini + - python train.py --run_mode compile --plan_ngpus 4 --runtime_ngpus 4 --name llama3_debug --model_id ./llama3_mini --attn_implementation sdpa --dataset_path ./bookcorpus_llama3_4K --max_train_steps 50 --pipeline_pivots LlamaDecoderLayer --pipeline_nstages 2 + - torchrun --nproc_per_node=4 train.py --plan_ngpus 4 --runtime_ngpus 4 --name llama3_debug --model_id ./llama3_mini --attn_implementation sdpa --dataset_path ./bookcorpus_llama3_4K --max_train_steps 50 --pipeline_pivots LlamaDecoderLayer --pipeline_nstages 2 + - python ckpt_merger.py --ckpt_dir ./checkpoints/last --output_fname ./merged.ckpt diff --git a/utility/nightly_test/parity_alert_examples/test_cases/llama3_demo/config.yaml b/utility/nightly_test/parity_alert_examples/test_cases/llama3_demo/config.yaml new file mode 100644 index 00000000..8dbcb120 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/test_cases/llama3_demo/config.yaml @@ -0,0 +1,12 @@ +# NOTE: Must set HF_TOKEN + +name: Llama 3 demo +train: + path: nnscaler/examples/llama3_demo + output: checkpoints/merged.ckpt + commands: + - rm -rf .nnscaler ./checkpoints + - pip install -r requirements.txt + - python train.py --prepare_data --mini + - torchrun --nproc_per_node=4 train.py --mini --max_train_steps=50 + - python train.py --merge_checkpoint=./checkpoints/last diff --git a/utility/nightly_test/parity_alert_examples/test_cases/nanogpt/config.yaml b/utility/nightly_test/parity_alert_examples/test_cases/nanogpt/config.yaml new file mode 100644 index 00000000..6f1df5f1 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/test_cases/nanogpt/config.yaml @@ -0,0 +1,13 @@ +name: nanoGPT lightning +train: + path: nnscaler/examples/nanogpt + output: _merge/merged.pt + commands: + - rm -rf .nnscaler + - rm -rf lightning_logs _merge + - pip install -r requirements.txt + - python nanoGPT/data/shakespeare_char/prepare.py + - torchrun --nproc_per_node=2 train_nnscaler.py nanoGPT/config/train_shakespeare_char.py --max_iters=100 + - mkdir _merge + - cp lightning_logs/version_0/checkpoints/*/*.pt _merge + - python -c "from nnscaler.integration.lightning.pytorch import NnScalerStrategy ; NnScalerStrategy.merge_checkpoint(['_merge/0.pt', '_merge/1.pt'], '_merge/merged.pt')" diff --git a/utility/nightly_test/parity_alert_examples/train.py b/utility/nightly_test/parity_alert_examples/train.py new file mode 100644 index 00000000..c96e4b10 --- /dev/null +++ b/utility/nightly_test/parity_alert_examples/train.py @@ -0,0 +1,60 @@ +from pathlib import Path +from functools import partial +from subprocess import check_call as _call, check_output +import os +import sys + +import shutil +import yaml + +call = partial(_call, shell=True) + + +def train_model(config_dir: Path, save_dir: Path): + save_dir = save_dir / config_dir.name + save_dir.mkdir(parents=True, exist_ok=True) + config_file = config_dir / 'config.yaml' + + with open(config_file) as f: + config = yaml.safe_load(f) + + path = Path(config['train']['path']).absolute() + new_model = path / config['train']['output'] + env = {} + env.update(os.environ) + env.update({ + 'TRAIN_DATA_DIR': str(Path(os.getenv('TRAIN_DATA_DIR'))), + 'CONFIG_DIR': str(config_dir), + 'SAVE_DIR': str(save_dir), + 'RDZV_ENDPOINT': 'localhost:' + os.getenv('UNUSED_PORT'), + }) + env.update(config['train'].get('envs', {})) + for command in config['train']['commands']: + call(command, env=env, cwd=path) + shutil.copy2(new_model, save_dir / 'model.pt') + + +def main(workspace: str, parity_check_dir: str, parity_save_dir: str): + parity_check_root = Path(parity_check_dir).absolute() + parity_save_root = Path(parity_save_dir).absolute() + os.chdir(workspace) + test_cases = os.getenv('TEST_CASES') + if test_cases: + test_cases = test_cases.split(',') + print(f'Run test cases: {test_cases}') + else: + test_cases = None + print('Run all test cases') + for d in parity_check_root.glob('*'): + if not d.is_dir(): + continue + if not test_cases or d.name in test_cases: + print(f'Training for {d.name}...') + train_model(d, parity_save_root) + + +if __name__ == '__main__': + if len(sys.argv) !=4: + print('Usage: python train.py ') + exit(1) + main(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/utility/nightly_test/test_utils.py b/utility/nightly_test/test_utils.py new file mode 100644 index 00000000..bec73fcf --- /dev/null +++ b/utility/nightly_test/test_utils.py @@ -0,0 +1,183 @@ +from subprocess import CalledProcessError +import subprocess +import asyncio +import logging +import logging.handlers +from pathlib import Path +import os +import copy +import yaml + +logging.basicConfig( + filemode='a', + format="%(asctime)s | %(levelname)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), +) + +logger = logging.getLogger("CubeSystemTest") +result_logger = logging.getLogger("cube_system_test_results") +warning_logger = logging.getLogger("CompareInterface") + +# smaller buffer to quick output +buffer_handler = logging.handlers.BufferingHandler(10) +logger.addHandler(buffer_handler) +result_logger.addHandler(buffer_handler) + +global_time = None + +class TestUtils: + @staticmethod + def execute_command(cmd: str, cwd: str): + """Execute a command and log the output""" + try: + result = subprocess.check_output(cmd, shell=True, cwd=cwd).decode('utf-8').strip() + return result + except subprocess.CalledProcessError as e: + print("An error occurred while trying to execute:", cmd) + return None + + @staticmethod + def call(cmds): + """Call commands async and log the output""" + + if isinstance(cmds, str): + cmds = [cmds] + try: + results = asyncio.run(TestUtils.run_commands_async(cmds)) + for result in results: + stdout, stderr = result + if stdout: + logger.info(f'{stdout.decode()}') + if stderr: + err_msg = stderr.decode().strip() + if ("Traceback (most recent call last):" in err_msg): + result_logger.error(f'{err_msg}') + else: + logger.error(f'{err_msg}') + except CalledProcessError as e: + result_logger.error(f"Commands {cmds} failed with error code {e.returncode}") + raise + + @staticmethod + async def run_command_async(cmd: str): + """run a command async and return the output of stdout and stderr""" + logger.info(f"Running command: {cmd}") + proc = await asyncio.create_subprocess_shell( + cmd, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) + stdout, stderr = await proc.communicate() + return stdout, stderr + + @staticmethod + async def run_commands_async(cmds: list): + """run commands async and return the output of stdout and stderr""" + tasks = [asyncio.ensure_future(TestUtils.run_command_async(cmd)) for cmd in cmds] + results = await asyncio.gather(*tasks) + return results + + @staticmethod + def get_ipv4_address(): + import re + interface_name = subprocess.check_output("route -n | grep '^0.0.0.0' | awk '{print $8}'", shell=True).decode().strip() + ifconfig_output = subprocess.check_output(f"ifconfig {interface_name}", shell=True).decode() + ip_address_match = re.search(r'inet (\S+)', ifconfig_output) + ip_pattern = re.compile( + r'^(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.' + r'(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.' + r'(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.' + r'(25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$' + ) + if ip_address_match: + ip_addr = ip_address_match.group(1) + if ip_addr and ip_pattern.match(ip_addr): + return ip_addr + + if os.getenv("CUBE_MASTER_ADDR"): + ip_addr = os.getenv("CUBE_MASTER_ADDR").strip() + if ip_addr and ip_pattern.match(ip_addr): + return ip_addr + + raise RuntimeError(f"cannot get ip address for interface {interface_name}, you can set master_addr manually by setting the environment variable CUBE_MASTER_ADDR") + + @staticmethod + def gen_log_folder(workspace): + global global_time + if global_time is None: + from datetime import datetime + now = datetime.now() + global_time = now.strftime("%Y%m%d_%H%M%S") + log_folder = Path(workspace) / 'cube_test_logs' / global_time + if not log_folder.exists(): + log_folder.mkdir(parents=True, exist_ok=True) + return log_folder + + @staticmethod + def parse_hosts_file(file_path): + file_path = Path(file_path).expanduser() + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + with open(file_path, 'r') as file: + lines = file.readlines() + ssh_host_list = [line.strip() for line in lines] + return ssh_host_list + + @staticmethod + def logger_redirect(logger1, log_folder, filename) -> tuple[str, logging.FileHandler]: + import logging.handlers + file_path = f"{log_folder}/{filename}.log" + result_handler = logging.FileHandler(file_path, 'a') + formatter = logging.Formatter("%(asctime)s | %(levelname)s | %(name)s | %(message)s") + result_handler.setFormatter(formatter) + logger1.addHandler(result_handler) + return file_path, result_handler + + @staticmethod + def merge_dict(dict_a, dict_b): + a = copy.deepcopy(dict_a) + b = copy.deepcopy(dict_b) + for key in b: + if key in a: + if isinstance(a[key], dict) and isinstance(b[key], dict): + a[key] = TestUtils.merge_dict(a[key], b[key]) + elif b[key] is None or (b[key] == {}): + continue + else: + a[key] = b[key] + else: + a[key] = b[key] + return a + + @staticmethod + def merge_dicts(*dicts): + result = {} + for current_dict in dicts: + result = TestUtils.merge_dict(result, current_dict) + return result + + @staticmethod + def load_yaml_file(file_path) -> dict: + with open(file_path, 'r') as f: + element = yaml.safe_load(f) + if isinstance(element, dict): + TestUtils.recursive_replace_keys(element, '_', '-', 'fairseq') + TestUtils.recursive_replace_keys(element, '-', '_', 'torchrun') + TestUtils.recursive_replace_keys(element, '-', '_', 'envs') + return element + else: + raise ValueError(f"Invalid config_file {file_path}") + + @staticmethod + def recursive_replace_keys(d, old_char, new_char, target_key): + if target_key in d: + target_dict = d[target_key] + d[target_key] = {k.replace(old_char, new_char): v for k, v in target_dict.items()} + else: + for _, value in d.items(): + if isinstance(value, dict): + TestUtils.recursive_replace_keys(value, old_char, new_char, target_key) + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + TestUtils.recursive_replace_keys(item, old_char, new_char, target_key) \ No newline at end of file diff --git a/utility/prim_profiler.py b/utility/prim_profiler.py new file mode 100644 index 00000000..6daef5ff --- /dev/null +++ b/utility/prim_profiler.py @@ -0,0 +1,52 @@ +import torch +import os +import sys +import shutil +from datetime import datetime +import subprocess +import torch +import logging +from pathlib import Path +from nnscaler.autodist.util import get_node_arch, get_default_profile_path + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("nnscaler.comm_profiler") + + +def main(): + default_path = get_default_profile_path() + + if not default_path.is_dir(): + default_path.mkdir(parents=True) + logger.info(f'create folder: {default_path}') + else: + logger.info(f'folder already exists: {default_path}') + + comm_path = default_path / 'comm' + + if comm_path.is_dir(): + logger.info(f'back up legacy comm info: {comm_path}') + shutil.move( + comm_path, + default_path / f'comm_back_{str(datetime.now().timestamp())}') + comm_path.mkdir(parents=True, exist_ok=True) + + logger.info(f'CUDA device num: {torch.cuda.device_count()}') + profiler_fname = Path(__file__).parent / 'comm_profile.py' + device_num = 2 + while device_num <= torch.cuda.device_count(): + command = f'torchrun --master_port 21212 --nproc_per_node={device_num} {profiler_fname} --comm_profile_dir={comm_path}' + output = subprocess.check_output(command, shell=True, text=True) + device_num = device_num * 2 + + logger.info(f'comm profile done') + + +if __name__ == '__main__': + main() diff --git a/utility/test_rvd_prim.py b/utility/test_rvd_prim.py new file mode 100644 index 00000000..7e8251c2 --- /dev/null +++ b/utility/test_rvd_prim.py @@ -0,0 +1,137 @@ +""" +OMP_NUM_THREADS=4 torchrun \ + --nproc_per_node=8 \ + utility/test_rvd_prim.py --prims allreduce + +OMP_NUM_THREADS=4 torchrun \ + --nnode=2 --node_rank=$NODE_RANK --master_addr=node-0 \ + --nproc_per_node=8 \ + utility/test_rvd_prim.py --prims all +""" + +from typing import Callable +import nnscaler +import torch +import time +import argparse +from nnscaler.profiler.timer import CudaTimer, print_each_rank + +from nnscaler.runtime.adapter.collectives import all_reduce, all_gather, reduce_scatter, all_to_all +from nnscaler.runtime.device import DeviceGroup + + +def prim_allreduce(itensor, ranks, dim0=None, dim1=None): + return all_reduce(itensor, ranks) + + +def bw_allreduce(itensor: torch.Tensor, ranks, sec_per_call: float): + msg_size = itensor.nelement() * 4 / 1e9 + ndevs = len(ranks) + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * 2 * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_allgather(itensor, ranks, dim0=0, dim1=None): + return all_gather(itensor, dim0, ranks) + + +def bw_allgather(itensor: torch.Tensor, ranks, sec_per_call: float): + ndevs = len(ranks) + msg_size = itensor.nelement() * 4 / 1e9 * ndevs + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_reducescatter(itensor, ranks, dim0=0, dim1=None): + return reduce_scatter(itensor, dim0, ranks) + + +def bw_reducescatter(itensor: torch.Tensor, ranks, sec_per_call: float): + msg_size = itensor.nelement() * 4 / 1e9 + ndevs = len(ranks) + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_alltoall(itensor, ranks, dim0=0, dim1=1): + return all_to_all(itensor, dim0, dim1, ranks) + + +def bw_alltoall(itensor: torch.Tensor, ranks, sec_per_call: float): + msg_size = itensor.nelement() * 4 / 1e9 + ndevs = len(ranks) + algo_bw = msg_size / sec_per_call + bus_bw = algo_bw * (ndevs - 1) / ndevs + return algo_bw, bus_bw + + +def prim_bw(prim: Callable, bandwidth: Callable, ranks, size, warmup=100, profile=100): + if 'allgather' in prim.__name__: + size = size // len(ranks) + tensor: torch.Tensor = torch.zeros(size, device=torch.cuda.current_device()) + tensor = tensor.view(256, -1).contiguous() + torch.distributed.barrier() + # warm up + for _ in range(warmup): + _ = prim(tensor, ranks) + # profile + torch.cuda.synchronize() + torch.distributed.barrier() + tic = time.perf_counter() + for _ in range(profile): + _ = prim(tensor, ranks) + torch.cuda.synchronize() + toc = time.perf_counter() + + span = (toc - tic) / profile # seconds + msg_size = tensor.nelement() * 4 // 1024 // 1024 # MB + if 'allgather' in prim.__name__: + msg_size = len(ranks) * tensor.nelement() * 4 // 1024 // 1024 # MB + algo_bw, bus_bw = bandwidth(tensor, ranks, span) + print_each_rank( + '{} msg {} MB | wall-time(ms) algo-bw(GB/s) bus-bw(GB/s) {:.2f} {:.2f} {:.2f}'.format( + prim.__name__, msg_size, span*1000, algo_bw, bus_bw + ), rank_only=0 + ) + + +if __name__ == '__main__': + + nnscaler.init() + + parser = argparse.ArgumentParser(description='comm primitive') + parser.add_argument('--prims', type=str, nargs='+', action='append', + help='prims: all, allreduce, reducescatter, allgather, alltoall') + parser.add_argument('--begin', type=int, default=1, + help='start message size in MB') + parser.add_argument('--end', type=int, default=256, + help='end message size in MB') + args = parser.parse_args() + args.prims = args.prims[0] + + prims, bws = [], [] + if 'allreduce' in args.prims or 'all' in args.prims: + prims.append(prim_allreduce) + bws.append(bw_allreduce) + if 'allgather' in args.prims or 'all' in args.prims: + prims.append(prim_allgather) + bws.append(bw_allgather) + if 'reducescatter' in args.prims or 'all' in args.prims: + prims.append(prim_reducescatter) + bws.append(bw_reducescatter) + if 'alltoall' in args.prims or 'all' in args.prims: + prims.append(prim_alltoall) + bws.append(bw_alltoall) + + ranks = tuple(range(DeviceGroup().world_size)) + CudaTimer(enable=False) + for prim, bw in zip(prims, bws): + print_each_rank(f'====> test start {prim.__name__}', rank_only=0) + size = args.begin + while size <= args.end: + prim_bw(prim, bw, ranks, size * 1024 * 1024 // 4) + size *= 2 + print_each_rank(f'====> test finish {prim.__name__}', rank_only=0) diff --git a/utility/verify_ops/verify_dimops.py b/utility/verify_ops/verify_dimops.py new file mode 100644 index 00000000..bb88bc72 --- /dev/null +++ b/utility/verify_ops/verify_dimops.py @@ -0,0 +1,470 @@ +""" +This test verifies the correctness of an operator's annotation by running its distributed versions. +The processing pipeline is: +1. generate the input and calculate the output for the operator on a single device +2. construct the partition search space based on its annotation +3. for each partition choice, nnscaler will generate runnable code with communication adapters automatically +4. compare each distributed result with single device version, the difference should be less than a threshold +NOTE: only consider partitioning along one dimension currently +""" + +import os +from typing import Dict, List, Tuple, Any, Union +from dataclasses import dataclass, field +import logging +import subprocess +import torch + +from nnscaler.graph.function.dimops import IRDimops, OpAnno, DimAnno +from nnscaler.ir.cten import IRTensor, IRObject + + +logger = logging.getLogger(__name__) + + +_SINGLE_GPU_TEST_FILE = "single_gpu_test.py" +_TWO_GPUS_TEST_FILE = "two_gpus_test.py" + +module_template_common = """ +import os +import numpy +import sys +import torch +import nnscaler + +from nnscaler.graph import IRGraph +from nnscaler.ir.operator import IRFwOperation +from nnscaler.parallel import parallelize, ComputeConfig, ReuseType +{import_cumsomized_func} + +import nnscaler.graph +import nnscaler.graph.function +import nnscaler.graph.function.wrapnn + +import torch +import numpy as np +import random + +class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + + def forward(self, {args}): + # Add clone to resolve the issue: + # a leaf Variable that requires grad is being used in an in-place operation. + {clone_args} + + {func_sig_call} + + out = 0 + for one_out in [{outputs}]: + if not isinstance(one_out, torch.Tensor): + continue + out += torch.sum(one_out) + return out + +model = TestModule() #.to(torch.float16) +""" + +module_template_single_main = """ +# Load inputs from file, ensuring inputs.pt is always a tuple, even when there's only one input +{args}, = torch.load('{func_sig}_inputs.pt', map_location=torch.device('cuda:0')) + +model = model.cuda() + +single_loss = model({args}) +single_loss.backward() + +grad_tensors = {grad_tensors} +torch.save([grad_tensors, single_loss], '{func_sig}_loss_single.pt') +print('single gpu loss: ', single_loss) +""" + +module_template_single = module_template_common + module_template_single_main + +module_template_parallel_main = """ +nnscaler.init() +rank_id = torch.distributed.get_rank() + +{args}, = torch.load('{func_sig}_inputs.pt', map_location=torch.device(f'cuda:{{rank_id}}')) + +def policy(graph: IRGraph, resource) -> IRGraph: + ngpus = 2 + partitioned = False + + for idx, node in enumerate(graph.select(ntype=IRFwOperation)): + if not partitioned and node.signature == '{func_sig}': + print('Partitioned node: ', node) + sub_nodes = graph.partition( + node, node.algorithm('dim'), idx={idx}, dim={dim}, num=ngpus) + partitioned = True + else: + sub_nodes = graph.replicate(node, times=ngpus) + for idx, sub_node in enumerate(sub_nodes): + graph.assign(sub_node, idx) + + assert partitioned, f'No node is partitioned for {func_sig}.' + return graph + +parallel_model = parallelize( + model, + dummy_forward_args={dummy_input_str}, + pas_policy=policy, + compute_config=ComputeConfig(2, 2), + reuse=ReuseType.OVERRIDE +) + +parallel_model.train() + +parallel_loss = parallel_model({args}) +parallel_loss.backward() + +grad_tensors = {grad_tensors} +torch.save([grad_tensors, parallel_loss], '{func_sig}_loss_para_'+str(rank_id)+'.pt') +print('two gpus loss: ', parallel_loss) +""" + +module_template_parallel = module_template_common + module_template_parallel_main + + +@dataclass +class TensorInfo: + value_form: str # 'shape' or 'value' + value: Union[Tuple[int], Any] + dtype: torch.dtype = torch.float32 + requires_grad: bool = True + + # make TensorInfo hashable + def __hash__(self): + value = self.value + if isinstance(value, slice): + value = (value.start, value.stop, value.step) + return hash((self.value_form, value)) + + +@dataclass +class VerifyConfig: + fsig: str + args: List[TensorInfo] + kwargs: Dict[str, Any] + noutputs: int + parti_options: List[Dict[str, int]] + import_customized_func: str = "" + non_grad_indices: List[int] = field(default_factory=list) + + +def _complex(val: Any): + """ + Convert IRObject to concrete value + NOTE: only used for handling kwargs + """ + if isinstance(val, tuple): + return tuple(_complex(t) for t in val) + if isinstance(val, list): + return list(_complex(t) for t in val) + if isinstance(val, dict): + return {_complex(key): _complex(val) for key, val in val.items()} + if isinstance(val, slice): + return slice(_complex(val.start), _complex(val.stop), _complex(val.step)) + if isinstance(val, IRObject): + assert not isinstance(val, IRTensor), "IRTensor should not be in kwargs" + return _complex(val.value) + return val + + +def get_candidate_options( + anno: OpAnno, ins_outs_shape: List[TensorInfo], npartitions: int = 2 +) -> List[Dict[str, int]]: + """ + Get all the feasible partitions specified by the annotation of an operator. + Checks whether the dimension can be divided, and also checks whether the size of the dimension can be evenly divided by the number of partitions + Args: + anno (OpAnno): operator annotation + ins_outs_shape (List[TensorInfo]): input and output shapes + npartitions (int, optional): number of partitions. Defaults to 2. + Returns: + List[Dict[str, int]]: a list of feasible partitions + + """ + all_configs = anno.transform_space() + + candidate_partitions = [] + for idx, dim in all_configs: + if ( + ins_outs_shape[idx].value_form == "shape" + and ins_outs_shape[idx].value[dim] % npartitions == 0 + ): + candidate_partitions.append({"idx": idx, "dim": dim}) + + return candidate_partitions + + +def handle_buffer_parameters(inputs, non_grad_indices): + """ + Detach specified buffer parameters from the computational graph and disable their gradient computation. + This is necessary for parameters that should not participate in the backward pass, + such as statistical parameters in certain layers (e.g., running_mean in normalization layers). + + Args: + inputs (List[torch.Tensor]): The list of input tensors. + non_grad_indices (List[int]): The indices of buffer parameters in the input list. + """ + for idx in non_grad_indices: + if inputs[idx] is not None: + inputs[idx] = inputs[idx].detach() + inputs[idx].requires_grad = False + + +def _create_op_inputs(verify_config: VerifyConfig) -> List[Any]: + """ + Create input tensors/non-tensors for the operator. + The input tensors/non-tensors are only for args, not for kwargs. + Args: + verify_config (VerifyConfig): configuration for verifying the partitions + Returns: + List[Any]: input tensors + """ + torch.manual_seed(0) + inputs = [] + + def process_slice(slice_obj): + start = ( + slice_obj.start.value + if isinstance(slice_obj.start, IRObject) + else slice_obj.start + ) + stop = ( + slice_obj.stop.value + if isinstance(slice_obj.stop, IRObject) + else slice_obj.stop + ) + step = slice_obj.step + return slice(start, stop, step) + + for i, tensor_info in enumerate(verify_config.args): + if tensor_info.value_form == "shape": + # Special handling: For torch. rsqrt, generate random integers between 1 and 10 to avoid invalid values + if verify_config.fsig == "torch.rsqrt": + inputs.append( + torch.randint( + 1, + 10, + tensor_info.value, + dtype=tensor_info.dtype, + requires_grad=tensor_info.requires_grad, + ) + ) + # Special handling: for the first parameter of torch.where which is a boolean mask + elif verify_config.fsig == "torch.where" and i == 0: + inputs.append( + torch.rand( + *tensor_info.value, dtype=tensor_info.dtype, requires_grad=tensor_info.requires_grad + ) + > 0.5 + ) + elif verify_config.fsig == "torch.add" and tensor_info.value == (1,): + # Special handling:add in the model generates values that cannot be partitioned + inputs.append(torch.randn(4, dtype=tensor_info.dtype, requires_grad=tensor_info.requires_grad)) + else: + if tensor_info.value == (): + inputs.append( + torch.randn( + (), dtype=tensor_info.dtype, requires_grad=tensor_info.requires_grad + ).squeeze() + ) + else: + inputs.append( + torch.randn( + *tensor_info.value, + dtype=tensor_info.dtype, + requires_grad=tensor_info.requires_grad, + ) + ) + elif tensor_info.value_form == "value" and isinstance(tensor_info.value, slice): + inputs.append(process_slice(tensor_info.value)) + else: + inputs.append(tensor_info.value) + if verify_config.non_grad_indices: + handle_buffer_parameters(inputs, verify_config.non_grad_indices) + return inputs + + +def verify_partition_options(verify_config: VerifyConfig) -> bool: + errors = [] + try: + logger.info(f"Verifying partitions of {verify_config.fsig}...") + inputs = _create_op_inputs(verify_config) + torch.save(inputs, f"{verify_config.fsig}_inputs.pt") + logger.info(f"Input tensors saved to {verify_config.fsig}_inputs.pt") + + outputs_str = ", ".join([f"_out{i}" for i in range(verify_config.noutputs)]) + + kwargs_str = ", ".join( + [ + f'{k}="{v}"' if isinstance(v, str) else f"{k}={_complex(v)}" + for k, v in verify_config.kwargs.items() + ] + ) + + func_sig_call = verify_config.fsig + args_str = ", ".join([f"_in{i}" for i in range(len(verify_config.args))]) + tensor_member_methods_prefix = 'torch.Tensor.' + if func_sig_call.startswith(tensor_member_methods_prefix): + # workaround because tracer does not support tensor member methods + func_sig_call = f'_in0.' + func_sig_call[len(tensor_member_methods_prefix):] + func_args_str = ", ".join([f"_in{i}" for i in range(1, len(verify_config.args))]) + else: + func_args_str = args_str + + if func_args_str: + func_call = f"{outputs_str} = {func_sig_call}({func_args_str}, {kwargs_str})" + else: + func_call = f"{outputs_str} = {func_sig_call}({kwargs_str})" + + clone_args_right = ", ".join( + [ + f"_in{i}.clone()" + for i, tinfo in enumerate(verify_config.args) + if tinfo.value_form == "shape" + ] + ) + if clone_args_right: + clone_args_left = ", ".join( + [ + f"_in{i}" + for i, tinfo in enumerate(verify_config.args) + if tinfo.value_form == "shape" + ] + ) + clone_args = f"{clone_args_left} = {clone_args_right}" + else: + clone_args = "" + + dummy_input_str = ( + "{" + + ", ".join([f'"_in{i}": _in{i}' for i in range(len(verify_config.args))]) + + "}" + ) + + grad_tensors = ( + "[" + + ", ".join( + [ + f"_in{i}.grad" + for i in range(len(verify_config.args)) + if i not in verify_config.non_grad_indices + and verify_config.args[i].value_form == "shape" + ] + ) + + "]" + ) + module_single_str = module_template_single.format( + import_cumsomized_func=verify_config.import_customized_func, + clone_args=clone_args, + args=args_str, + kwargs=kwargs_str, + func_sig=verify_config.fsig, + func_sig_call=func_call, + outputs=outputs_str, + grad_tensors=grad_tensors, + ) + with open(_SINGLE_GPU_TEST_FILE, "w") as f: + f.write(module_single_str) + logger.info("Generated test code for single gpu and running...") + subprocess.run(["rm", "-f", f"{verify_config.fsig}_loss_single.pt"]) + subprocess.run(["python", _SINGLE_GPU_TEST_FILE]) + logger.info( + f"Single GPU test completed. Output saved to {verify_config.fsig}_loss_single.pt" + ) + logger.info(f"verify_config: {verify_config}") + logger.info(f"verify_config.parti_options: {verify_config.parti_options}") + + for poption in verify_config.parti_options: + try: + logger.info(f"Verifying the partition {poption}...") + module_para_str = module_template_parallel.format( + import_cumsomized_func=verify_config.import_customized_func, + clone_args=clone_args, + args=args_str, + kwargs=kwargs_str, + func_sig=verify_config.fsig, + func_sig_call=func_call, + outputs=outputs_str, + dummy_input_str=dummy_input_str, + grad_tensors=grad_tensors, + idx=poption["idx"], + dim=poption["dim"], + ) + with open(_TWO_GPUS_TEST_FILE, "w") as f: + f.write(module_para_str) + logger.info("Generated test code for two gpus.") + + subprocess.run(["rm", "-f", f"{verify_config.fsig}_loss_para_0.pt"]) + subprocess.run(["rm", "-f", f"{verify_config.fsig}_loss_para_1.pt"]) + subprocess.run( + [ + "torchrun", + "--nproc_per_node=2", + "--nnodes=1", + "--rdzv-endpoint=localhost:23457", + _TWO_GPUS_TEST_FILE, + ] + ) + logger.info( + f"Two GPU test completed. Outputs saved to {verify_config.fsig}_loss_para_0.pt and {verify_config.fsig}_loss_para_1.pt" + ) + single = torch.load(f"{verify_config.fsig}_loss_single.pt") + logger.info( + f"Loading single loss from: {verify_config.fsig}_loss_single.pt" + ) + para0 = torch.load(f"{verify_config.fsig}_loss_para_0.pt") + para1 = torch.load(f"{verify_config.fsig}_loss_para_1.pt") + + logger.info(f"Single loss: {single[1]}") + logger.info(f"Multi-GPU loss (para0): {para0[1]}") + logger.info(f"Multi-GPU loss (para1): {para1[1]}") + + assert torch.allclose( + single[1], para0[1], rtol=1e-3, atol=1e-5 + ), f"Loss mismatch between single and multi-GPU (para0)" + assert torch.equal( + para0[1], para1[1].to(para0[1]) + ), f"Loss mismatch between multi-GPU (para0 and para1)" + + for i in range(len(single[0])): + if single[0][i] is None or para0[0][i] is None: + logger.debug( + f"Skipping comparison for index {i} because it is None" + ) + continue + logger.debug(f"Absolute error: {single[0][i] - para0[0][i]}") + logger.debug( + f"Relative error: {(single[0][i] - para0[0][i]) / single[0][i]}" + ) + assert torch.allclose( + single[0][i], para0[0][i], rtol=1e-3, atol=1e-5 + ), f"Gradient mismatch between single and multi-GPU (para0)" + assert torch.equal( + para0[0][i], para1[0][i].to(para0[0][i]) + ), f"Gradient mismatch between multi-GPU (para0 and para1)" + + logger.info( + f"{verify_config.fsig} of partition {poption} passed the allclose comparison." + ) + except Exception as e: + error_message = f"Partition {poption} failed with error: {str(e)}" + logger.error(error_message) + errors.append(error_message) + if errors: + logger.error("Some partitions failed:") + for error in errors: + logger.error(error) + return False + else: + logger.info( + f"Verified all the partitions of {verify_config.fsig} successfully." + ) + return True + except Exception as e: + logger.exception("Exception occurred during verification process") + raise e diff --git a/utility/verify_ops/verify_graph_operations.py b/utility/verify_ops/verify_graph_operations.py new file mode 100644 index 00000000..680d3fc1 --- /dev/null +++ b/utility/verify_ops/verify_graph_operations.py @@ -0,0 +1,161 @@ +import argparse +import os +import sys +import torch +from nnscaler.graph.function.dimops import DimAnno, IRDimops, OpAnno +from nnscaler.graph.graph import IRGraph +from nnscaler.ir.cten import IRObject, IRTensor +from pathlib import Path +import logging + +from verify_dimops import TensorInfo, get_candidate_options + +_VERIFIED_OPS_FILE_NAME = "verified_ops.pt" +_DEFAULT_CACHE_DIR = Path(os.path.expanduser("~/.cache/nnscaler")) + + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger(__name__) + + +def load_verified_ops(outdir: Path): + verified_ops_file = outdir / _VERIFIED_OPS_FILE_NAME + if verified_ops_file.exists(): + logger.info(f"{verified_ops_file} exists, load it.") + return torch.load(verified_ops_file) + else: + logger.info(f"{verified_ops_file} does not exist, start from scratch.") + return set() + + +def save_verified_ops(outdir: Path, verified_ops: set): + verified_ops_file = outdir / _VERIFIED_OPS_FILE_NAME + torch.save(verified_ops, verified_ops_file) + logger.info(f"Verification results saved to {verified_ops_file}") + + +def verify_op_partitions(graph: IRGraph, outdir: Path): + """ + Test if the partitioned ops in the graph are computationally correct. + + Args: + graph (IRGraph): the graph to be verified + outdir (Path): the directory to save the verified ops + + Returns: + None + """ + from verify_dimops import ( + VerifyConfig, + TensorInfo, + verify_partition_options, + ) + + verified_ops = load_verified_ops(outdir) + skipped_nodes = [] + + gnodes = graph.nodes(flatten=True) + for idx, node in enumerate(gnodes): + logger.info(f"node: {node}") + logger.info(f"Verification progress: {idx} / {len(gnodes)}") + if node.isfw() and isinstance(node, IRDimops): + ins_info = [ + ( + TensorInfo("shape", _input.shape) + if isinstance(_input, IRTensor) + else TensorInfo( + "value", + _input.value if isinstance(_input, IRObject) else _input, + ) + ) + for _input in node.inputs() + ] + if not ins_info: + skipped_nodes.append(f"{node.signature} (type: {type(node)})") + logger.info(f"ins_info is empty for node: {node.signature}, skipping.") + continue + + outs_info = [ + ( + TensorInfo("shape", output.shape) + if isinstance(output, IRTensor) + else TensorInfo( + "value", + output.value if isinstance(output, IRObject) else output, + ) + ) + for output in node.outputs() + ] + if (node.signature, tuple(ins_info + outs_info)) in verified_ops: + logger.info(f"{node.signature} has been verified before, skip.") + continue + + logger.info(f"Node annos: {node.signature}, {node.anno}") + + parti_options = get_candidate_options(node.anno, ins_info + outs_info) + + logger.info(f"Candidate partition options: {parti_options}") + + verify_config = VerifyConfig( + fsig=node.signature, + args=ins_info, + kwargs=node.kwargs, + noutputs=len(node.outputs()), + parti_options=parti_options, + ) + try: + iscorrect = verify_partition_options(verify_config) + except Exception as e: + logger.warning( + f"Verification failed for {node.signature}, {e}, please manually verify." + ) + iscorrect = True # fake true to skip this node + if not iscorrect: + logger.warning(f"Verification failed for {node.signature}, continuing execution.") + continue + + verified_ops.add((node.signature, tuple(ins_info + outs_info))) + save_verified_ops(outdir, verified_ops) + + if skipped_nodes: + logger.info("Skipped the following nodes due to empty ins_info:") + for node_info in skipped_nodes: + logger.info(f" - {node_info}") + +def main(): + parser = argparse.ArgumentParser( + description="Verify partitions of operations in an IRGraph." + ) + parser.add_argument( + "--graph", type=str, required=True, help="Path to the graph file." + ) + parser.add_argument( + "--outdir", + type=str, + help="Optional directory to save the verified operations. If not provided, results will be saved to the default cache directory.", + ) + + args = parser.parse_args() + + graph_path = Path(args.graph) + if not graph_path.exists(): + raise FileNotFoundError(f"Graph file {graph_path} does not exist.") + + graph = IRGraph.load(graph_path) + + if args.outdir: + outdir = Path(args.outdir) + else: + outdir = _DEFAULT_CACHE_DIR + + outdir.mkdir(parents=True, exist_ok=True) + verify_op_partitions(graph, outdir) + + +if __name__ == "__main__": + main() diff --git a/utility/visualize_value_tracks.py b/utility/visualize_value_tracks.py new file mode 100644 index 00000000..164b838e --- /dev/null +++ b/utility/visualize_value_tracks.py @@ -0,0 +1,158 @@ +import argparse +import matplotlib.pyplot as plt +from nnscaler.graph import IRGraph +from matplotlib.patches import FancyArrowPatch +from nnscaler.ir.cten import IR, IRTensor, IRObject + + +class Visualizer: + NUM_ROWS_PER_OP = 3 + TEXT_HEIGHT_IN_INCH = 0.4 + PER_OP_GAP_IN_INCH = 0.2 + PER_ROW_HEIGHT_IN_INCH = TEXT_HEIGHT_IN_INCH * 1.1 + PER_OP_HEIGHT_IN_INCH = PER_ROW_HEIGHT_IN_INCH * NUM_ROWS_PER_OP + PER_INOUT_GAP = 0.01 + + INIT_Y = 0.001 + INIT_X = 0.001 + + def __init__(self, graph): + self.graph = graph + self.value_loc = {} + self.ops = [node for node in self.graph.nodes() if node.isfw()] + + self.fig_heigth_in_inch = ( + self.PER_OP_HEIGHT_IN_INCH + self.PER_OP_GAP_IN_INCH + ) * (len(self.ops) + 1) + self.coord_per_inch = 1.0 / self.fig_heigth_in_inch + self.per_op_height = self.PER_OP_HEIGHT_IN_INCH * self.coord_per_inch + self.per_row_height = self.per_op_height / self.NUM_ROWS_PER_OP + self.per_op_gap = self.PER_OP_GAP_IN_INCH * self.coord_per_inch + + self.fig, self.ax = plt.subplots(figsize=(30, self.fig_heigth_in_inch)) + self.ax.axis('off') + self.ax.invert_yaxis() + + def draw_value(self, value, value_track, cur_x, cur_y, previous_value_loc): + t = self.ax.text(cur_x, cur_y, str(value), + fontsize=14, ha="left", va="top") + bbox = t.get_window_extent() + bbox = bbox.transformed(self.ax.transData.inverted()) + if value_track is not None: + if value_track.value_id in previous_value_loc: + prev_x, prev_y = previous_value_loc[value_track.value_id] + arrow = FancyArrowPatch( + (prev_x, prev_y), + (cur_x + bbox.width/2, cur_y), + arrowstyle="Simple,tail_width=0.25,head_width=1,head_length=1", + mutation_scale=6, + color="#2c7bb6", + linewidth=0.02, + connectionstyle="arc3,rad=0", + alpha=0.5, + zorder=4 + ) + self.ax.add_patch(arrow) + self.value_loc[value_track.value_id] = (cur_x + bbox.width/2, cur_y) + + cur_x += bbox.width + self.PER_INOUT_GAP/2 + return cur_x + + def draw_obj(self, obj, cur_x, cur_y, previous_value_loc): + if isinstance(obj, IRTensor): + cur_x = self.draw_value('T(', None, cur_x, cur_y, previous_value_loc) + for i, d in enumerate(obj.shape): + if i > 0: + cur_x = self.draw_value(',', None, cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(str(d), obj.dim_tracks[i], cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(')', None, cur_x, cur_y, previous_value_loc) + else: + assert isinstance(obj, IRObject) + cur_x = self.draw_value('O(', None, cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(str(obj.value), obj.value_track, cur_x, cur_y, previous_value_loc) + cur_x = self.draw_value(')', None, cur_x, cur_y, previous_value_loc) + cur_x += self.PER_INOUT_GAP + return cur_x + + def draw_objs(self, objs, cur_x, cur_y): + previous_value_loc = dict(self.value_loc) + for inp in objs: + cur_x = self.draw_obj(inp, cur_x, cur_y, previous_value_loc) + + def draw_graph_inputs(self, g, cur_x, cur_y): + label = "GRAPH IN: " + t = self.ax.text(cur_x, cur_y, label, + fontsize=14, fontweight="bold", ha="left", va="top") + + bbox = t.get_window_extent() + bbox = bbox.transformed(self.ax.transData.inverted()) + cur_x = cur_x + bbox.width + self.PER_INOUT_GAP + + ir_objs = [] + for inp in g.inputs(): + if isinstance(inp, (IRObject, IRTensor)): + ir_objs.append(inp) + elif isinstance(inp, IRObject): + sub_objs = IR.get_objects(inp.value) + if sub_objs: + ir_objs.extend(sub_objs) + else: + ir_objs.append(inp) + + self.draw_objs(ir_objs, cur_x, cur_y) + + def draw_inout(self, node, cur_y, is_in): + if is_in: + ir_objs = node.iobjs() + label = "IN: " + cur_y += self.per_row_height + else: + ir_objs = node.oobjs() + label = "OU: " + cur_y += self.per_row_height * 2 + + t = self.ax.text(self.INIT_X, cur_y, label, + fontsize=14, fontweight="bold", ha="left", va="top") + + bbox = t.get_window_extent() + bbox = bbox.transformed(self.ax.transData.inverted()) + cur_x = self.INIT_X + bbox.width + self.PER_INOUT_GAP + + self.draw_objs(ir_objs, cur_x, cur_y) + + def visualize(self): + self.draw_graph_inputs(self.graph, self.INIT_X, self.INIT_Y) + cur_y = self.INIT_Y + (self.per_op_height + self.per_op_gap)/2 + + for node in self.ops: + op_name = node.name + self.ax.text(self.INIT_X, cur_y, op_name + ":", + fontsize=16, fontweight="bold", ha="left", va="top") + + self.draw_inout(node, cur_y, is_in=True) + self.draw_inout(node, cur_y, is_in=False) + + cur_y += self.per_op_height + self.per_op_gap + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + 'graphfile', + type=str, + help="Graph dump file" + ) + parser.add_argument( + 'imagefile', + type=str, + nargs='?', + default=None, + help="Save generated image to file" + ) + args = parser.parse_args() + g = IRGraph.load(args.graphfile) + visualizer = Visualizer(g) + visualizer.visualize() + if args.imagefile: + plt.savefig(args.imagefile, bbox_inches='tight', dpi=100) + plt.show()