Skip to content

add support for conv widths 5-8#87

Open
biobenkj wants to merge 3 commits into
Dao-AILab:main from
biobenkj:support-width-5-8
Open

add support for conv widths 5-8#87
biobenkj wants to merge 3 commits into
Dao-AILab:main from
biobenkj:support-width-5-8

Conversation

@biobenkj
Copy link
Copy Markdown

No description provided.

@biobenkj
Copy link
Copy Markdown
Author

My apologies! I moved too quickly and need to make sure the tests all pass first. I will leave this open for now.

@biobenkj
Copy link
Copy Markdown
Author

Okay, I think all tests should pass now for kernel widths 5-8. I've also added a small test script, and its tests appear to pass. My primary need for expanding these kernel widths is for applications related to sequence annotation and cell state inference. I've confirmed builds on Python 3.10-3.14 with pip.

simple tests for kernel widths:

#!/usr/bin/env python3
"""
Quick test to verify width 5-8 support works correctly.
Tests only the essential functionality without the full parametrized test suite.
"""

import torch
from causal_conv1d import causal_conv1d_fn

def test_width_5_8_with_fp16():
    """Smoke-test ``causal_conv1d_fn`` on float16 inputs for widths 5 through 8.

    For each width, random input, weight and bias tensors are created and run
    through ``causal_conv1d_fn`` with the "silu" activation. The output must
    keep the input's shape and the float16 dtype. Only shape and dtype are
    checked; numerical values are not compared against a reference.

    Raises:
        AssertionError: If the output shape or dtype differs from the input's.
    """
    print("Testing widths 5-8 with fp16...")
    half = torch.float16
    # Uses CUDA when available, otherwise the tensors are created on the CPU.
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch, channels, seqlen = 2, 64, 128

    for width in range(5, 9):
        # Creation order (x, weight, bias) is kept so RNG consumption matches.
        x = torch.randn(batch, channels, seqlen, device=target, dtype=half)
        weight = torch.randn(channels, width, device=target, dtype=half)
        bias = torch.randn(channels, device=target, dtype=half)

        result = causal_conv1d_fn(x, weight, bias, activation="silu")

        assert result.shape == x.shape
        assert result.dtype == half
        print(f"  width={width} with fp16: PASSED")

def test_width_5_8_with_bf16():
    """Smoke-test ``causal_conv1d_fn`` on bfloat16 inputs for widths 5 through 8.

    For each width, random input, weight and bias tensors are created and run
    through ``causal_conv1d_fn`` with the "silu" activation. The output must
    keep the input's shape and the bfloat16 dtype. Only shape and dtype are
    checked; numerical values are not compared against a reference.

    Raises:
        AssertionError: If the output shape or dtype differs from the input's.
    """
    print("\nTesting widths 5-8 with bf16...")
    brain_half = torch.bfloat16
    # Uses CUDA when available, otherwise the tensors are created on the CPU.
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch, channels, seqlen = 2, 64, 128

    for width in range(5, 9):
        # Creation order (x, weight, bias) is kept so RNG consumption matches.
        x = torch.randn(batch, channels, seqlen, device=target, dtype=brain_half)
        weight = torch.randn(channels, width, device=target, dtype=brain_half)
        bias = torch.randn(channels, device=target, dtype=brain_half)

        result = causal_conv1d_fn(x, weight, bias, activation="silu")

        assert result.shape == x.shape
        assert result.dtype == brain_half
        print(f"  width={width} with bf16: PASSED")

def test_width_5_8_rejects_fp32():
    """Check that float32 inputs with widths 5 through 8 are refused.

    For each width, ``causal_conv1d_fn`` is called with float32 tensors and is
    expected to raise a ``RuntimeError`` whose message contains
    "width 5-8 is only supported for float16/bfloat16".

    Raises:
        AssertionError: If the call succeeds, or if the ``RuntimeError``
            message does not contain the expected text.
    """
    print("\nTesting widths 5-8 properly reject fp32...")
    single = torch.float32
    # Uses CUDA when available, otherwise the tensors are created on the CPU.
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    expected_fragment = "width 5-8 is only supported for float16/bfloat16"

    for width in range(5, 9):
        x = torch.randn(2, 64, 128, device=target, dtype=single)
        weight = torch.randn(64, width, device=target, dtype=single)
        bias = torch.randn(64, device=target, dtype=single)

        try:
            causal_conv1d_fn(x, weight, bias, activation="silu")
        except RuntimeError as e:
            error_msg = str(e)
            assert expected_fragment in error_msg, \
                f"Unexpected error message: {error_msg}"
            print(f"  width={width} with fp32: Correctly rejected")
        else:
            # Reaching here means the call was accepted, which is the failure case.
            raise AssertionError(f"width={width} with fp32 should have failed but didn't!")

def test_width_2_4_still_works_with_fp32():
    """Regression check: float32 inputs keep working for widths 2 through 4.

    For each width, random float32 input, weight and bias tensors are run
    through ``causal_conv1d_fn`` with the "silu" activation. The output must
    keep the input's shape and the float32 dtype.

    Raises:
        AssertionError: If the output shape or dtype differs from the input's.
    """
    print("\nTesting widths 2-4 still work with fp32...")
    single = torch.float32
    # Uses CUDA when available, otherwise the tensors are created on the CPU.
    target = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch, channels, seqlen = 2, 64, 128

    for width in range(2, 5):
        # Creation order (x, weight, bias) is kept so RNG consumption matches.
        x = torch.randn(batch, channels, seqlen, device=target, dtype=single)
        weight = torch.randn(channels, width, device=target, dtype=single)
        bias = torch.randn(channels, device=target, dtype=single)

        result = causal_conv1d_fn(x, weight, bias, activation="silu")

        assert result.shape == x.shape
        assert result.dtype == single
        print(f"  width={width} with fp32: PASSED")

if __name__ == "__main__":
    # Script entry point: run every quick check in order, print a summary on
    # success, or print the error plus traceback and exit with status 1.
    #
    # `sys` is imported here because the script's top-level imports do not
    # include it; without this import the failure path's `sys.exit(1)` would
    # raise NameError instead of returning a non-zero exit status.
    import sys
    import traceback

    print("=" * 80)
    print("Quick test for causal_conv1d width 5-8 support")
    print("=" * 80)

    try:
        test_width_5_8_with_fp16()
        test_width_5_8_with_bf16()
        test_width_5_8_rejects_fp32()
        test_width_2_4_still_works_with_fp32()

        print("\n" + "=" * 80)
        print("All tests PASSED!")
        print("=" * 80)
        print("\nSummary:")
        print("  • Widths 5-8 work correctly with fp16/bf16")
        print("  • Widths 5-8 properly reject fp32 with clear error")
        print("  • Widths 2-4 continue to work with fp32")
        print("  • Implementation is working as expected!")

    except Exception as e:
        # Top-level boundary: report any failure (including AssertionError
        # raised by the checks) and signal it through the exit status.
        print("\n" + "=" * 80)
        print("Test FAILED!")
        print("=" * 80)
        print(f"Error: {e}")
        traceback.print_exc()
        sys.exit(1)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant