Hello,
I measured the runtime of your BitLinear and BitLinearBitBLAS against nn.Linear, and for smaller input_features and output_features they are slower than nn.Linear. Is there a solution for this?
I used the quant_utils from your BitNet integration: https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet/utils_quant.py
My GPU is NVIDIA GeForce RTX 3090
import time

import torch
import torch.nn as nn

# from bitlinear import BitLinear, BitLinearBitBLAS
from utils_quant import BitLinear, BitLinearBitBLAS


# Measure the average forward-pass time of a layer
def measure_time(layer, input_tensor, num_runs=100, warmup=100):
    with torch.no_grad():
        # Warm up so one-time costs (kernel tuning, caches) are excluded
        for _ in range(warmup):
            _ = layer(input_tensor)
        torch.cuda.synchronize()  # flush queued warm-up work before timing
        start_time = time.time()
        for _ in range(num_runs):
            _ = layer(input_tensor)
        torch.cuda.synchronize()  # wait for all timed kernels to finish
        end_time = time.time()
    avg_time = (end_time - start_time) / num_runs
    return avg_time
# Test parameters
input_features = 512
output_features = 256
batch_size = 8
# input_features = 1024
# output_features = 512
# batch_size = 32
# input_features = 10240
# output_features = 5120
# batch_size = 32
# input_features = 20480
# output_features = 10240
# batch_size = 32
# Create random input tensor
input_tensor = torch.randn(batch_size, input_features).cuda()
# Initialize layers
nn_linear_layer = nn.Linear(input_features, output_features).cuda()
bit_linear_layer = BitLinear(input_features, output_features).cuda()
bitblas_linear_layer = BitLinearBitBLAS.from_bit_linear(bit_linear_layer)
# Measure computation time
num_runs = 100
nn_linear_time = measure_time(nn_linear_layer, input_tensor, num_runs)
bit_linear_time = measure_time(bit_linear_layer, input_tensor, num_runs)
bitblas_linear_time = measure_time(bitblas_linear_layer, input_tensor, num_runs)
print('input_features, output_features, batch_size: ', input_features, output_features, batch_size)
print(f"Average computation time for nn.Linear: {nn_linear_time * 1000:.4f} ms")
print(f"Average computation time for fp32 simulated BitLinear: {bit_linear_time * 1000:.4f} ms")
print(f"Average computation time for BitBLAS BitLinear: {bitblas_linear_time * 1000:.4f} ms")
Here are the testing results:

input_features, output_features, batch_size: 512 256 8
Average computation time for nn.Linear: 0.0230 ms
Average computation time for fp32 simulated BitLinear: 0.3450 ms
Average computation time for BitBLAS BitLinear: 0.3091 ms

input_features, output_features, batch_size: 1024 512 32
Average computation time for nn.Linear: 0.0265 ms
Average computation time for fp32 simulated BitLinear: 0.3427 ms
Average computation time for BitBLAS BitLinear: 0.3137 ms

input_features, output_features, batch_size: 10240 5120 32
Average computation time for nn.Linear: 0.5421 ms
Average computation time for fp32 simulated BitLinear: 6.3314 ms
Average computation time for BitBLAS BitLinear: 0.3170 ms

input_features, output_features, batch_size: 20480 10240 32
Average computation time for nn.Linear: 2.1726 ms
Average computation time for fp32 simulated BitLinear: 25.2509 ms
Average computation time for BitBLAS BitLinear: 0.5633 ms
Thanks for your reply in advance.