A multi-GPU "dense-equivalent" training method using semi-structured 2:4 sparse kernels.
- Installation
- Quick Start
- Elevator Pitch
- Repo Architecture
- Usage Examples
- Project Structure
- Contributing
- License
Prerequisites:
- Python 3.8 or higher
- PyTorch 2.0+ with CUDA support
- NVIDIA GPU with Tensor Core support (Ampere or newer recommended for optimal 2:4 sparse performance)
- 2 GPUs for multi-GPU experiments
- CUDA 11.8 or higher
- Clone the repository:
git clone https://github.com/MathewYoussef/Block-Shock.git
cd Block-Shock
- Install dependencies:
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
- Verify installation:
python verify_setup.py
This will check that all dependencies are installed and your environment is properly configured.
The easiest way to get started is with the example scripts:
# Run a simple correctness check (requires 2 GPUs)
torchrun --standalone --nproc_per_node=2 examples/simple_correctness_check.py
# Benchmark single GPU dense baseline
python examples/benchmark_comparison.py --method dense_single --N 4096
# Benchmark Block-Shock (requires 2 GPUs)
torchrun --standalone --nproc_per_node=2 examples/benchmark_comparison.py --method block_shock --N 8192
See examples/README.md for more details.
Run a simple correctness check with dense single-GPU baseline:
python -m src.main \
--config configs/base.yaml \
--phase configs/phases/phase0_correctness.yaml \
--method configs/methods/dense_single.yaml \
--workload configs/workloads/gaussian.yaml \
--hardware configs/hardware/local_2gpu.yaml
Run a 2-GPU Block-Shock forward benchmark:
torchrun --standalone --nproc_per_node=2 -m src.main \
--config configs/base.yaml \
--phase configs/phases/phase1_forward.yaml \
--method configs/methods/block_shock_2gpu.yaml \
--workload configs/workloads/gaussian.yaml \
--hardware configs/hardware/local_2gpu.yaml
You want dense model capacity (all weights exist and get updated), but you want to ride NVIDIA's 2:4 sparse Tensor Core / cuSPARSELt path.
Write the dense weight as a sum of multiple 2:4-sparse matrices:
W = W(0) + W(1)
Each W(g) is 2:4 sparse (50% zeros in each 4-wide group) and is placed on a different GPU. Then:
y = x W^T
= x (W(0) + W(1))^T
= x W(0)^T + x W(1)^T
Compute each term on its GPU with semi-structured sparse matmul, and sum the outputs with an all-reduce. This is the core Block-Shock "dense by superposition" trick.
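As a minimal sketch of that identity using only plain dense ops (no sparse kernels, no communication; the mask pattern and sizes are illustrative):
import torch

N, B = 128, 4
W = torch.randn(N, N)
X = torch.randn(B, N)

# Complementary masks over every 4-wide group along the input dimension:
# one shard keeps positions 0-1 of each group, the other keeps positions 2-3.
pattern = torch.tensor([1.0, 1.0, 0.0, 0.0]).repeat(N // 4)
W0 = W * pattern          # 2:4 sparse shard for "GPU 0"
W1 = W * (1.0 - pattern)  # complementary 2:4 sparse shard for "GPU 1"

Y_dense = X @ W.T
Y_sum = X @ W0.T + X @ W1.T   # what the two GPUs compute, then all-reduce
print(torch.allclose(Y_dense, Y_sum, atol=1e-5))  # True, up to float rounding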
This is structurally aligned with how semi-structured kernels work: sparse weights stored in a compressed form plus a metadata mask.
- Semi-structured 2:4 is fixed at 50% sparsity, with a theoretical 2x improvement in the ideal case.
- PyTorch's semi-structured format explicitly targets compressed weights and sparse GEMM dispatch (cuSPARSELt) rather than "masked dense."
- For inference, the workflow is explicitly: prune/mask -> compress -> run sparse kernel.
- You do two sparse matmuls (one per GPU) plus a reduction, instead of one dense matmul. Communication and overhead can eat your win.
- Training adds another pain point: weights change each step, so "compress once" stops being true.
- Semi-structured ops support is limited to a specific set (mm/addmm/linear and transposes).
- Speedups vary a lot by architecture and workload; even PyTorch's own tutorial warns speedups may differ.
Phases 0-3 share a single experiment engine. Each phase is just a different pipeline of toggles:
- Phase 0: forward correctness (small N) + equivalence checks
- Phase 1: forward throughput (timing only)
- Phase 2: forward + backward wrt input (weights frozen)
- Phase 3: full training step (forward + backward + optimizer + optional recompress)
Key directories:
- configs/: YAML stack (base + sweeps + phases + methods + masks + workloads + hardware)
- src/: core runner and shared modules used by all phases
- results/: raw runs, tables, and plots
- analysis/: aggregation and plotting scripts
A run is one invocation of the experiment runner with a fully resolved config. Each run produces a unique run_id, a config snapshot, environment metadata, and one or more metrics records.
Each run writes to results/raw/<phase>/<method>/<run_id>/ with:
- config.yaml (resolved config)
- env.json (hardware + software metadata)
- seed.txt (seed used for reproducibility)
- metrics.jsonl (one JSON record per measurement block)
Aggregated tables and plots are written to results/tables/ and results/plots/.
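analysis/aggregate.py is the supported way to build those tables; as a rough ad-hoc sketch of the same idea (the glob depth and record fields are assumptions based on the layout described above):
import glob
import json
import pandas as pd

rows = []
for path in glob.glob("results/raw/*/*/*/metrics.jsonl"):
    with open(path) as f:
        rows.extend(json.loads(line) for line in f if line.strip())

pd.DataFrame(rows).to_csv("results/tables/runs_adhoc.csv", index=False)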
Official runs
To keep an official record under version control, run with configs/official.yaml. Each run writes to:
results/official/runs/<run_group>/<phase>/<method>/<run_id>/
run_group is auto-generated (UTC timestamp) so new runs never collide.
Sweeps write to:
results/official/sweeps/<tag>/<phase>/<method>/<run_id>/
Configs are merged in order (later files override earlier keys):
- configs/base.yaml
- one phase config from configs/phases/
- one method config from configs/methods/
- one workload config from configs/workloads/
- one hardware config from configs/hardware/
- an optional sweep config from configs/sweeps/
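A minimal sketch of the "later overrides earlier" semantics; the real loader is src/config.py, so the recursive-merge behavior shown here is an assumption:
import yaml

def deep_merge(base: dict, override: dict) -> dict:
    # Recursively merge override into base; keys from later files win.
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_merge(out[key], value)
        else:
            out[key] = value
    return out

config = {}
for path in [
    "configs/base.yaml",
    "configs/phases/phase1_forward.yaml",
    "configs/methods/block_shock_2gpu.yaml",
    "configs/workloads/gaussian.yaml",
    "configs/hardware/local_2gpu.yaml",
]:
    with open(path) as f:
        config = deep_merge(config, yaml.safe_load(f) or {})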
Available methods:
- Dense single GPU (configs/methods/dense_single.yaml)
- Dense TP 2-GPU (configs/methods/dense_tp.yaml)
- Masked split dense (ablation) (configs/methods/masked_split_dense.yaml)
- Block-Shock 2-GPU (configs/methods/block_shock_2gpu.yaml)
Block-Shock writes a dense weight as the sum of multiple 2:4-sparse matrices placed on different GPUs, computes each sparse matmul with semi-structured kernels, then all-reduces the partial outputs. This preserves dense capacity while attempting to exploit the 2:4 sparse Tensor Core path.
Note: The mask section appears in merged configs because it is set in configs/base.yaml. It is ignored by non-sparse baselines (e.g., dense_single, dense_tp) and only used by sparse/masked methods.
Note: For TP baselines, inputs are currently broadcast from rank 0 to ensure correctness. A future improvement is to generate X once on rank 0 and scatter feature shards to each GPU.
Phase 0 compares each method against a single dense reference:
- Reference: dense single-GPU F.linear using weight W (and optional bias b).
- Test: each method computes its output on the same input X and the same W.
Determinism rules:
- X is generated with a fixed seed.
- W (and optional b) are generated with the same seed for both reference and test.
- Exact zeros in W are nudged to eps to avoid accidental 1-of-4 blocks in 2:4 validation; this is applied consistently across methods (see the sketch below).
- This ensures exact reproducibility across runs.
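A one-liner sketch of the zero-nudging rule (the eps value here is an illustrative assumption, not the configured default):
import torch

eps = 1e-3                                              # illustrative value
W = torch.randn(64, 64)
W = torch.where(W == 0.0, torch.full_like(W, eps), W)   # nudge exact zeros to eps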
Memory metrics:
- Runs log per-method weight storage estimates (e.g., weight_bytes_total, weight_bytes_sparse_est) to help attribute memory overhead alongside timing.
- For large-N Phase 1 sweeps, you can enable method.drop_full_weight: true to release the dense W_full after shards/compressed weights are created (kept when correctness checks are enabled).
Comparison metrics:
- max_abs_error
- mean_abs_error
- max_rel_error (with configurable phase.rel_eps)
- passed (boolean, based on phase.tol_max_abs and phase.tol_max_rel)
Phase 0 also runs configurable warmup and timed iterations:
- phase.warmup_iters (default 10)
- phase.timed_iters (default 100)
- Phase 0: correctness against dense reference, logs error metrics, uses sync timing.
- Phase 1: forward-only throughput, no reference comparisons, uses cuda_events timing.
- If a method records communication timing (e.g., dense TP), it populates timings_ms.allreduce with p50/p95 stats.
- Methods that call collectives run collective_prep first; its timing is logged under timings_ms.layout_fix, along with layout_fix_trigger_rate and layout_fix_bytes_per_copy.
Bias handling:
- If a method uses bias, the reference includes the same bias.
- If a method does not use bias, the reference bias is disabled.
The project defaults to bf16 because NVIDIA 2:4 semi-structured kernels require bf16/fp16 and dimensions that are multiples of 64. Keep model.dtype: bf16 unless you are running a specific fp32 debug check.
python -m src.main --help
python -m src.main --config configs/base.yaml --phase configs/phases/phase0_correctness.yaml --method configs/methods/dense_single.yaml --workload configs/workloads/gaussian.yaml --hardware configs/hardware/local_2gpu.yaml
Timing modes:
- sync: CPU wall time with torch.cuda.synchronize() before/after each region. Use for correctness and debugging.
- cuda_events: GPU event timing on the current stream, synced once at summary (see the sketch below). Use for Phase 1 benchmarking.
- none: no sync. Not recommended for benchmarking.
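For reference, the general cuda_events pattern in PyTorch looks like this (a generic sketch, not the project's metrics.py; sizes are illustrative):
import torch

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()               # enqueued on the current stream
y = x @ w.T                  # region being timed
end.record()

torch.cuda.synchronize()     # sync once when summarizing
print(start.elapsed_time(end), "ms")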
Phase 0 - Correctness Check:
# Single GPU dense baseline
torchrun --standalone --nproc_per_node=1 -m src.main \
--config configs/base.yaml \
--phase configs/phases/phase0_correctness.yaml \
--method configs/methods/dense_single.yaml \
--workload configs/workloads/gaussian.yaml \
--hardware configs/hardware/local_2gpu.yaml
# Block-Shock 2-GPU
torchrun --standalone --nproc_per_node=2 -m src.main \
--config configs/base.yaml \
--phase configs/phases/phase0_correctness.yaml \
--method configs/methods/block_shock_2gpu.yaml \
--workload configs/workloads/gaussian.yaml \
--hardware configs/hardware/local_2gpu.yaml
Phase 1 - Forward Throughput Benchmark:
# Dense tensor parallel baseline
torchrun --standalone --nproc_per_node=2 -m src.main \
--config configs/base.yaml \
--phase configs/phases/phase1_forward.yaml \
--method configs/methods/dense_tp.yaml \
--workload configs/workloads/gaussian.yaml \
--hardware configs/hardware/local_2gpu.yaml
# Masked split dense (ablation)
torchrun --standalone --nproc_per_node=2 -m src.main \
--config configs/base.yaml \
--phase configs/phases/phase1_forward.yaml \
--method configs/methods/masked_split_dense.yaml \
--workload configs/workloads/gaussian.yaml \
--hardware configs/hardware/local_2gpu.yaml
Phase 2 - Backward Input Gradients:
torchrun --standalone --nproc_per_node=2 -m src.main \
--config configs/base.yaml \
--phase configs/phases/phase2_backward_input.yaml \
--method configs/methods/dense_tp.yaml \
--workload configs/workloads/gaussian.yaml \
--hardware configs/hardware/local_2gpu.yaml
Run an N-dimension sweep and generate plots:
# Execute sweep
python scripts/run_sweep.py --sweep configs/sweeps/N_sweep.yaml
# Aggregate results
python analysis/aggregate.py \
--input results/official/sweeps/<tag> \
--output results/tables/runs.csv
# Generate plots
python analysis/plot_speedups.py \
--input results/tables/runs.csv \
--out-dir results/plots
Create your own workload configuration in configs/workloads/:
# configs/workloads/my_workload.yaml
workload:
name: my_custom_workload
type: random_normal
mean: 0.0
std: 0.5
Override model size on the command line:
python -m src.main \
--config configs/base.yaml \
--phase configs/phases/phase1_forward.yaml \
--method configs/methods/block_shock_2gpu.yaml \
--workload configs/workloads/gaussian.yaml \
--hardware configs/hardware/local_2gpu.yaml
Block-Shock/
├── .github/ # GitHub configuration
│ └── workflows/ # CI/CD workflows
│ └── tests.yml # Automated testing workflow
├── analysis/ # Data aggregation and visualization scripts
│ ├── aggregate.py # Aggregate JSONL metrics to CSV
│ ├── plot_speedups.py # Generate performance plots
│ └── report.md # Analysis reports
├── configs/ # YAML configuration files
│ ├── base.yaml # Base configuration
│ ├── phases/ # Phase configurations (0, 1, 2)
│ ├── methods/ # Method implementations config
│ ├── workloads/ # Input data generation patterns
│ ├── hardware/ # Hardware setup (GPU counts, backend)
│ ├── masks/ # 2:4 sparsity mask patterns
│ └── sweeps/ # Parameter sweep configurations
├── examples/ # Example scripts and tutorials
│ ├── README.md # Examples documentation
│ ├── simple_correctness_check.py # Basic correctness example
│ └── benchmark_comparison.py # Performance comparison example
├── results/ # Experimental outputs
│ ├── raw/ # Raw run data
│ ├── official/ # Versioned official runs
│ ├── tables/ # Aggregated CSV data
│ └── plots/ # Generated visualizations
├── scripts/ # Utility scripts
│ └── run_sweep.py # Sweep execution script
├── src/ # Source code
│ ├── main.py # Entry point
│ ├── orchestrator.py # Phase pipeline orchestration
│ ├── config.py # Config loading and merging
│ ├── distributed.py # Distributed training utilities
│ ├── logging_utils.py # Logging and metrics I/O
│ ├── metrics.py # Timing and metric tracking
│ ├── workloads.py # Input data generation
│ ├── validation.py # Correctness checks
│ ├── methods/ # Implementation of all methods
│ │ ├── dense_single.py # Single GPU dense baseline
│ │ ├── dense_tp.py # Dense tensor parallel
│ │ ├── masked_split_dense.py # Dense compute with masks
│ │ └── block_shock.py # Block-Shock sparse method
│ └── sparsity/ # Sparsity utilities
│ ├── masks.py # Mask generation and validation
│ └── semistructured.py # Semi-structured sparse ops
├── tests/ # Unit tests
│ ├── __init__.py
│ ├── test_config.py # Config system tests
│ ├── test_masks.py # Mask generation tests
│ └── test_workloads.py # Workload generation tests
├── .gitignore
├── CONTRIBUTING.md # Contribution guidelines
├── LICENSE
├── README.md
├── PROGRESS.md # Development progress tracking
└── requirements.txt # Python dependencies
We welcome contributions! Please see CONTRIBUTING.md for guidelines on how to contribute to this project.
This project is licensed under the MIT License - see the LICENSE file for details.
X.1 Drift test pipeline
Create:
- src/orchestrator.py (new pipeline)
- configs/phases/phaseX_drift.yaml
What it must do:
- Repeatedly apply a forward-only block for T steps
- Compare trajectories vs dense reference (error vs step)
- Log drift metrics (max/mean/rel error per step)
Definition of Done
- A run produces per-step error curves for dense vs Block-Shock forward
Phase 2 and Phase 3 are deferred for now. Semi-structured sparse backward is not supported in the current PyTorch build, so the focus remains on forward-only experiments.
14.1 Sweeps
Create:
- configs/sweeps/N_sweep.yaml
- configs/sweeps/batch_sweep.yaml
Definition of Done
- You can launch a sweep and get one run folder per configuration
14.2 Aggregation and plots
Create:
- analysis/aggregate.py
- analysis/plot_speedups.py
Definition of Done
- One chart: throughput vs N (for each method)
- One table: step time breakdown (forward/backward/comm/compress)
Sweep usage (forward-only):
python scripts/run_sweep.py --sweep configs/sweeps/N_sweep.yaml
python analysis/aggregate.py --input results/official/sweeps/<tag> --output results/tables/runs.csv
python analysis/plot_speedups.py --input results/tables/runs.csv --out-dir results/plots
Plots produced:
- Phase 1 forward/allreduce/layout-fix timing vs N (avg + p50/p95; one line per method)
- Phase 1 forward minus layout-fix timing vs N (avg + p50/p95; optimistic upper bound)
- Phase 1 memory plots vs N (peak allocated memory + weight storage estimates + best-effort actual bytes, GiB + bytes)
- Phase 1 forward avg normalized by weight bytes (ns/byte)
- Phase 0 error metrics vs N (max_abs_error, mean_abs_error, max_rel_error)
- Quality-adjusted speed: (1/forward_avg_ms) / (1 + mean_abs_error)
Note: If you run a Phase 1-only sweep, the Phase 0 error plots and quality-adjusted plot will be empty.
At every milestone, do a vertical slice:
- Implement just enough to run one method through one phase
- Lock correctness
- Then add the next method
A single giant linear layer:
W in R^(N x N) (e.g., 4096 or 40960)
X in R^(B x N) (B = batch/tokens)
Y in R^(B x N)
Use bf16 (bf16 on Blackwell-class hardware). Semi-structured bf16 requires 2D tensors with both dims multiples of 64.
Choose complementary masks per 4-wide group (example shown; any complementary pair works):
GPU0 mask per block: 1100
GPU1 mask per block: 0011
Construct:
W(0) = W .* M(0)
W(1) = W .* M(1)
Each is valid 2:4.
On each GPU g:
- Replicate X on both GPUs
- Compute partial output: Y(g) = X W(g)^T using semi-structured sparse ops
- All-reduce sum: Y = Y(0) + Y(1)
PyTorch's semi-structured sparse ops include torch.mm(dense, sparse) and aten.linear.default(dense, sparse, bias) among others.
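A hedged sketch of the per-rank forward (launch with torchrun and 2 processes; the real implementation lives in src/methods/block_shock.py, so the shapes, seed handling, and structure here are illustrative, and an Ampere-or-newer GPU is assumed):
# torchrun --standalone --nproc_per_node=2 forward_sketch.py
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.sparse import to_sparse_semi_structured

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

N, B = 4096, 256
torch.manual_seed(0)                       # same W and X materialized on both ranks
W = torch.randn(N, N, device="cuda", dtype=torch.bfloat16)
X = torch.randn(B, N, device="cuda", dtype=torch.bfloat16)

# Complementary 1100 / 0011 masks over each 4-wide group of input features.
pattern = torch.tensor([True, True, False, False], device="cuda").repeat(N // 4)
mask = pattern if rank == 0 else ~pattern
W_g = W * mask                             # this rank's valid 2:4 shard

W_g_sparse = to_sparse_semi_structured(W_g)   # compress once
Y = F.linear(X, W_g_sparse)                   # partial Y(g) = X W(g)^T via sparse kernel
dist.all_reduce(Y)                            # Y = Y(0) + Y(1)

dist.destroy_process_group()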
Because Y = sum_g Y(g):
grad_X = sum_g grad_Y W(g)
Each GPU computes its contribution, then all-reduce sum.
Gradient wrt weights:
grad_W(g) = (grad_Y)^T X
Apply the mask so each GPU updates only its half of weights.
This is where the experiment becomes interesting.
- Each GPU stores W(g) as a dense masked tensor (bf16).
- After the optimizer update, re-apply the mask (keep the 2:4 pattern).
- Re-compress to semi-structured (to_sparse_semi_structured) for the next forward.
This explicitly measures whether "compress + sparse GEMM" can beat dense, step after step.
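A rough single-rank sketch of that per-step loop (SGD, illustrative sizes; the weight gradient is formed manually because semi-structured backward is unsupported, and the cross-GPU all-reduce is omitted):
import torch
import torch.nn.functional as F
from torch.sparse import to_sparse_semi_structured

N, B = 4096, 256
mask = torch.tensor([True, True, False, False], device="cuda").repeat(N // 4)
W_g = (torch.randn(N, N, device="cuda", dtype=torch.bfloat16) * mask).requires_grad_()
opt = torch.optim.SGD([W_g], lr=1e-3)

for step in range(3):
    X = torch.randn(B, N, device="cuda", dtype=torch.bfloat16)
    W_g_sparse = to_sparse_semi_structured(W_g.detach())   # re-compress every step
    Y = F.linear(X, W_g_sparse)                             # sparse forward (no autograd)
    grad_Y = torch.ones_like(Y)                             # stand-in for dLoss/dY
    W_g.grad = (grad_Y.T @ X) * mask                        # grad_W(g) = grad_Y^T X, masked
    opt.step()
    opt.zero_grad()
    with torch.no_grad():
        W_g.mul_(mask)                                      # re-apply the 2:4 mask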
- No optimizer updates.
- Compress once.
- Pure forward benchmark (and optionally grad wrt input only).
This gives an upper bound on the sparse kernel win without the "compress cost" dominating.
Standard nn.Linear (or X @ W.T) on one GPU. Purpose: absolute reference for correctness and speed.
Pick a standard TP split (two good choices):
B1) Row-parallel (sum outputs)
- Shard input features and weight columns
- Each GPU computes partial
Y(g) - All-reduce sum outputs
This baseline matches Block-Shock's "sum outputs" comms pattern.
B2) Column-parallel (gather outputs)
- Shard output features and weight rows
- All-gather output shards
Include B2 if you want a broader view; B1 is the fairest apples-to-apples.
Same split W = W(0) + W(1) using the same masks, but do dense GEMM on each GPU (so zeros do not get skipped). This isolates whether the win is really from cuSPARSELt semi-structured kernels.
- Forward: max abs error and relative error vs dense baseline
- Backward (if training): compare gradients vs dense (spot checks)
For each method:
- Forward time (ms)
- Backward time (ms)
- Optimizer step time (ms) (for training variants)
- End-to-end step time (ms)
- Achieved throughput (samples/s or tokens/s)
- Communication overhead (time spent in all-reduce)
- Peak allocated memory per GPU
- Weight storage size (dense vs semi-structured)
- torch.profiler trace (kernel names + NCCL ops)
- Optional: per-kernel time breakdown
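A minimal torch.profiler sketch for grabbing kernel names (NCCL ops show up the same way in distributed runs; the workload and output path here are illustrative):
import torch
from torch.profiler import profile, ProfilerActivity

x = torch.randn(1024, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    for _ in range(10):
        y = x @ w.T

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
prof.export_chrome_trace("results/trace.json")   # open in chrome://tracing or Perfetto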
40960 x 40960 is doable and a great stress test. It is divisible by 64 (good for bf16 semi-structured).
40960^2 = 1,677,721,600 elements
bf16 = 2 bytes -> 3.125 GiB weight tensor
PyTorch documents 2:4 bf16 as compression factor 9/16.
one 2:4 compressed copy ~= 3.125 * 9/16 ~= 1.758 GiB
two complementary copies ~= 3.516 GiB total
That is about 12.5% overhead vs dense weight storage.
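The same arithmetic as a quick sanity check (GiB = 2^30 bytes):
N = 40960
dense_gib = N * N * 2 / 2**30              # bf16 = 2 bytes/elem -> 3.125 GiB
one_copy_gib = dense_gib * 9 / 16          # ~1.758 GiB per 2:4 compressed copy
two_copies_gib = 2 * one_copy_gib          # ~3.516 GiB across both GPUs
overhead = two_copies_gib / dense_gib - 1  # 0.125 -> 12.5% vs dense storage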
Training memory is dominated by gradients and optimizer states. Start with:
- forward-only
- SGD
- Adam (if you want to feel pain in 4D)
- H1 (kernel ceiling): In forward-only mode, Block-Shock's two-GPU semi-structured method achieves higher throughput than two-GPU dense TP for sufficiently large N and B.
- H2 (training reality): In full training, Block-Shock's advantage shrinks or disappears unless the compress/repack overhead is small relative to GEMM time.
- H3 (ablation): Masked-split dense (Baseline C) performs worse than Block-Shock, proving the gain (if any) is from semi-structured kernels, not just splitting.
Verify Block-Shock output matches dense exactly (it should if masks are complementary and you preserve all weights).
- N sweep: 4096 -> 8192 -> 16384 -> 40960
- Batch sweep: B = 16, 64, 256, 1024 (as memory allows)
- Measure forward time and comm time
- Freeze weights, enable grad on X, backprop
- Verify grad correctness vs dense
- Time forward + backward
- SGD first (minimize state)
- Then Adam (realistic)
- Include compress cost explicitly in timing