
Speed Comparison: BitLinear and nn.Linear #118

@ZiqingChang

Description

Hello,

I benchmarked your BitLinear and BitLinearBitBLAS against nn.Linear, and for smaller input_features and output_features both are slower than nn.Linear. Is there a solution for this?

I used utils_quant from your BitNet integration: https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/utils_quant.py

My GPU is NVIDIA GeForce RTX 3090

import time
import torch
import torch.nn as nn

# from bitlinear import BitLinear, BitLinearBitBLAS
from utils_quant import BitLinear, BitLinearBitBLAS

# Function to measure average forward-pass time
def measure_time(layer, input_tensor, num_runs=100):
    with torch.no_grad():
        # Warm up
        for _ in range(100):
            _ = layer(input_tensor)
        # Drain queued warm-up kernels so they are not counted in the timing
        torch.cuda.synchronize()

        start_time = time.time()
        for _ in range(num_runs):
            _ = layer(input_tensor)
        # Wait for all timed kernels to finish before taking the end timestamp
        torch.cuda.synchronize()
        end_time = time.time()

        avg_time = (end_time - start_time) / num_runs
    return avg_time

# Test parameters
input_features = 512
output_features = 256
batch_size = 8

# input_features = 1024
# output_features = 512
# batch_size = 32

# input_features = 10240
# output_features = 5120
# batch_size = 32

# input_features = 20480
# output_features = 10240
# batch_size = 32


# Create random input tensor
input_tensor = torch.randn(batch_size, input_features).cuda()

# Initialize layers
nn_linear_layer = nn.Linear(input_features, output_features).cuda()
bit_linear_layer = BitLinear(input_features, output_features).cuda()

bitblas_linear_layer = BitLinearBitBLAS.from_bit_linear(bit_linear_layer)

# Measure computation time
num_runs = 100
nn_linear_time = measure_time(nn_linear_layer, input_tensor, num_runs)
bit_linear_time = measure_time(bit_linear_layer, input_tensor, num_runs)
bitblas_linear_time = measure_time(bitblas_linear_layer, input_tensor, num_runs)

print('input_features, output_features, batch_size: ', input_features, output_features, batch_size)
print(f"Average computation time for nn.Linear: {nn_linear_time * 1000:.4f} ms")
print(f"Average computation time for fp32 simulated BitLinear: {bit_linear_time * 1000:.4f} ms")
print(f"Average computation time for BitBLAS BitLinear: {bitblas_linear_time * 1000:.4f} ms")
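Since CUDA kernel launches are asynchronous, host-side timers can easily under- or over-count GPU work. For reference, here is a minimal, framework-agnostic sketch of the timing pattern I am aiming for (the `bench` helper, the `sync` callback, and the use of `time.perf_counter` are my own choices, not from BitBLAS): synchronize once after warm-up and once after the timed loop.

```python
import time

def bench(fn, num_runs=100, warmup=10, sync=None):
    """Average wall-clock time per call of `fn`.

    `sync` is an optional barrier (e.g. torch.cuda.synchronize) invoked
    after warm-up and again after the timed loop, so that queued
    asynchronous work is fully drained before each timestamp is taken."""
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # drain warm-up work before starting the clock
    start = time.perf_counter()
    for _ in range(num_runs):
        fn()
    if sync is not None:
        sync()  # wait for all timed work to finish
    end = time.perf_counter()
    return (end - start) / num_runs

# CPU-only usage example (no GPU required):
avg = bench(lambda: sum(range(10_000)), num_runs=50)
print(f"{avg * 1e6:.2f} us per call")
```

On GPU one would pass `sync=torch.cuda.synchronize`; `time.perf_counter` is preferred over `time.time` for interval measurement because it is monotonic.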


Here are the test results:

input_features, output_features, batch_size:  512 256 8
Average computation time for nn.Linear: 0.0230 ms
Average computation time for fp32 simulated BitLinear: 0.3450 ms
Average computation time for BitBLAS BitLinear: 0.3091 ms

input_features, output_features, batch_size:  1024 512 32
Average computation time for nn.Linear: 0.0265 ms
Average computation time for fp32 simulated BitLinear: 0.3427 ms
Average computation time for BitBLAS BitLinear: 0.3137 ms

input_features, output_features, batch_size:  10240 5120 32
Average computation time for nn.Linear: 0.5421 ms
Average computation time for fp32 simulated BitLinear: 6.3314 ms
Average computation time for BitBLAS BitLinear: 0.3170 ms

input_features, output_features, batch_size:  20480 10240 32
Average computation time for nn.Linear: 2.1726 ms
Average computation time for fp32 simulated BitLinear: 25.2509 ms
Average computation time for BitBLAS BitLinear: 0.5633 ms
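The pattern above looks like what I would expect if the small shapes are launch-latency-bound rather than compute-bound. A rough back-of-envelope check (the ~35.6 TFLOPS fp32 peak for the RTX 3090 and the ~5 µs per-kernel launch overhead are my ballpark assumptions, not measured values):

```python
# Rough arithmetic: ideal GEMM compute time vs. kernel-launch overhead.
# Assumptions (mine, not from the measurements): ~35.6 TFLOPS fp32 peak
# for an RTX 3090, ~5 microseconds of per-kernel launch latency.

PEAK_FLOPS = 35.6e12       # assumed fp32 peak throughput
LAUNCH_OVERHEAD_S = 5e-6   # assumed per-kernel launch latency

def gemm_flops(batch, in_features, out_features):
    # One linear-layer forward pass is a (batch x in) @ (in x out) GEMM,
    # i.e. 2 * M * K * N floating-point operations.
    return 2 * batch * in_features * out_features

for batch, fin, fout in [(8, 512, 256), (32, 20480, 10240)]:
    flops = gemm_flops(batch, fin, fout)
    compute_us = flops / PEAK_FLOPS * 1e6
    print(f"{fin}x{fout}, batch {batch}: {flops/1e6:.1f} MFLOPs, "
          f"ideal compute {compute_us:.2f} us vs. "
          f"launch overhead ~{LAUNCH_OVERHEAD_S*1e6:.0f} us")
```

For the smallest shape the ideal compute time is well under a microsecond, so the measured ~0.3 ms would be almost entirely launch and quantization overhead; for the largest shape the GEMM itself dominates, which matches BitBLAS only pulling ahead there.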

Thanks for your reply in advance.
