Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions src/num_sys_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,120 @@ def real_to_format_tensor(self, tensor):
return tensor.to(torch.bfloat16)


class num_fp8_e4m3(_ieee754):
    """FP8 E4M3 number system (1 sign bit, 4 exponent bits, 3 mantissa bits).

    Used primarily for weights in transformer inference.
    Range: [-448, 448]. There is no inf encoding: an all-ones exponent is
    decoded as an ordinary binade, and only the pattern with all exponent
    bits *and* all mantissa bits set decodes to NaN.
    """

    def __init__(self):
        super().__init__(
            exp_len=4,
            mant_len=3,
            bias=7,
            denorm=True,
            max_val=448.0,
            min_val=None,
        )
        # E4M3 trades the inf encodings for one extra binade of finite values.
        self.has_inf = False

    def real_to_format(self, num):
        """Encode a scalar as a list of bit characters.

        Magnitudes above ``max_val`` saturate to +/-``max_val`` before being
        handed to the base-class encoder, because the format cannot
        represent infinity.
        """
        if abs(num) > self.max_val:
            num = self.max_val if num > 0 else -self.max_val
        return super().real_to_format(num)

    def format_to_real(self, bit_arr):
        """Decode a list of bit characters (sign, exponent, mantissa).

        Patterns whose exponent field is not all ones are delegated to the
        base class. An all-ones exponent is decoded here: it is NaN only when
        the mantissa is all ones too, otherwise it is a normal finite value.
        """
        exp_str = "".join(bit_arr[1:self.exp_len + 1])
        mant_str = "".join(bit_arr[self.exp_len + 1:])

        if exp_str != "1" * self.exp_len:
            return super().format_to_real(bit_arr)

        if mant_str == "1" * self.mant_len:
            return float('nan')

        # NOTE(review): handled locally on the assumption that the base
        # decoder maps an all-ones exponent to inf/NaN -- confirm in _ieee754.
        sign = pow(-1, int(bit_arr[0]))
        mant = _number_sys.bin_to_frac(mant_str) + 1
        exp = int(exp_str, 2) - self.bias
        return sign * mant * pow(2, exp)

    def real_to_format_tensor(self, tensor):
        """Return ``tensor`` with every element rounded to an E4M3 value."""
        return self._quantize_fp8_e4m3(tensor)

    def _quantize_fp8_e4m3(self, float_arr):
        """Round each element of a floating-point tensor onto the E4M3 grid.

        Magnitudes are saturated at ``max_val``. Each value is then rounded
        (``torch.round``, i.e. ties to even) to a multiple of the spacing of
        its binade, ``2 ** (exponent - mant_len)``. The exponent is floored at
        the minimum normal exponent ``1 - bias`` so that all subnormals share
        the fixed spacing ``2 ** (1 - bias - mant_len)``; magnitudes below
        half of that spacing therefore round to zero.

        The input tensor is not modified.
        """
        sign = torch.sign(float_arr)
        magnitude = torch.clamp(torch.abs(float_arr), max=self.max_val)

        # frexp yields a mantissa in [0.5, 1); subtracting one from its
        # exponent gives the exponent of the [1, 2) normalised form.
        _, exp = torch.frexp(magnitude)
        exp = torch.clamp(exp - 1, min=1 - self.bias)

        # Spacing between adjacent representable values in each binade.
        step = torch.exp2((exp - self.mant_len).float())

        return sign * (torch.round(magnitude / step) * step)


class num_fp8_e5m2(_ieee754):
    """FP8 E5M2 number system (1 sign bit, 5 exponent bits, 2 mantissa bits).

    Used primarily for gradients in transformer training.
    Range: [-57344, 57344], supports inf representation.
    """

    def __init__(self):
        super().__init__(
            exp_len=5,
            mant_len=2,
            bias=15,
            denorm=True,
            max_val=57344.0,
            min_val=None,
        )
        self.has_inf = True

    def real_to_format(self, num):
        """Encode a scalar as a list of bit characters.

        Any magnitude above ``max_val`` is encoded as signed infinity; all
        other inputs go straight to the base-class encoder.
        """
        if abs(num) > self.max_val:
            return super().real_to_format(
                float('inf') if num > 0 else float('-inf'))
        return super().real_to_format(num)

    def real_to_format_tensor(self, tensor):
        """Return ``tensor`` with every element rounded to an E5M2 value."""
        return self._quantize_fp8_e5m2(tensor)

    def _quantize_fp8_e5m2(self, float_arr):
        """Round each element of a floating-point tensor onto the E5M2 grid.

        Elements whose magnitude exceeds ``max_val`` become signed infinity,
        matching ``real_to_format``. Every other value is rounded
        (``torch.round``, i.e. ties to even) to a multiple of the spacing of
        its binade, ``2 ** (exponent - mant_len)``. The exponent is floored at
        the minimum normal exponent ``1 - bias`` so that all subnormals share
        the fixed spacing ``2 ** (1 - bias - mant_len)``; magnitudes below
        half of that spacing therefore round to zero.

        The input tensor is not modified.
        """
        sign = torch.sign(float_arr)
        magnitude = torch.abs(float_arr)

        # Remember the overflowing elements before saturating them, so the
        # rounding below only ever sees finite magnitudes.
        inf_mask = magnitude > self.max_val
        magnitude = torch.clamp(magnitude, max=self.max_val)

        # frexp yields a mantissa in [0.5, 1); subtracting one from its
        # exponent gives the exponent of the [1, 2) normalised form.
        _, exp = torch.frexp(magnitude)
        exp = torch.clamp(exp - 1, min=1 - self.bias)

        # Spacing between adjacent representable values in each binade.
        step = torch.exp2((exp - self.mant_len).float())

        result = sign * (torch.round(magnitude / step) * step)
        result[inf_mask] = sign[inf_mask] * float('inf')

        return result


class num_fixed_pt(_number_sys):
"""Fixed Point Number System"""
# 1 bit for sign + len(integer part) + len(frac part)
Expand Down
4 changes: 4 additions & 0 deletions src/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,10 @@ def getNumSysName(name, bits=16, radix_up=5, radix_down=10, bias=None):
return num_fp16(), name
elif name == "bfloat16":
return num_bfloat16(), name
elif name == "fp8_e4m3":
return num_fp8_e4m3(), name
elif name == "fp8_e5m2":
return num_fp8_e5m2(), name

# generic number systems in PyTorch
elif name == "fp_n":
Expand Down
133 changes: 133 additions & 0 deletions val/test_num_sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,139 @@ def test_num_bfloat16():
'0', '0', '1', '1', '0', '1', '0']) == 255)


# FP8 E4M3
def test_num_fp8_e4m3():
    """Check scalar encode/decode, tensor quantization and bit flips for E4M3.

    Bit patterns are written as ``sign | 4 exponent bits | 3 mantissa bits``
    with an exponent bias of 7.
    """
    # bitwidth of 8, 1 sign bit, 4 exponent bits, 3 mantissa bits
    fp8 = num_fp8_e4m3()

    # basic conversion tests: 1.0 = 0 | 0111 | 000 (exponent 7 - 7 = 0)
    assert fp8.format_to_real(['0', '0', '1', '1', '1', '0', '0', '0']) == 1.0
    assert (fp8.real_to_format(1.0)
            == ['0', '0', '1', '1', '1', '0', '0', '0'])

    assert fp8.format_to_real(['1', '0', '1', '1', '1', '0', '0', '0']) == -1.0
    assert (fp8.real_to_format(-1.0)
            == ['1', '0', '1', '1', '1', '0', '0', '0'])

    # larger values: 2.0 = 0 | 1000 | 000
    assert fp8.format_to_real(['0', '1', '0', '0', '0', '0', '0', '0']) == 2.0
    assert (fp8.real_to_format(2.0)
            == ['0', '1', '0', '0', '0', '0', '0', '0'])

    # smaller values: 0.625 = 1.25 * 2^-1 = 0 | 0110 | 010
    assert fp8.format_to_real(['0', '0', '1', '1', '0', '0', '1', '0']) == 0.625
    assert (fp8.real_to_format(0.625)
            == ['0', '0', '1', '1', '0', '0', '1', '0'])

    # zero (the negative-zero pattern also compares equal to 0.0)
    assert fp8.format_to_real(['0', '0', '0', '0', '0', '0', '0', '0']) == 0.0
    assert (fp8.real_to_format(0.0)
            == ['0', '0', '0', '0', '0', '0', '0', '0'])
    assert fp8.format_to_real(['1', '0', '0', '0', '0', '0', '0', '0']) == 0.0

    # NaN (all exponent bits = 1, all mantissa bits = 1)
    assert math.isnan(fp8.format_to_real(
        ['0', '1', '1', '1', '1', '1', '1', '1']))
    assert math.isnan(fp8.format_to_real(
        ['1', '1', '1', '1', '1', '1', '1', '1']))

    # max value: 448 = 1.75 * 2^8 = 0 | 1111 | 110
    assert fp8.format_to_real(['0', '1', '1', '1', '1', '1', '1', '0']) == 448.0
    assert (fp8.real_to_format(448.0)
            == ['0', '1', '1', '1', '1', '1', '1', '0'])

    # clamping beyond max
    result = fp8.real_to_format(500.0)
    assert fp8.format_to_real(result) == 448.0

    # tensor quantization
    test1 = torch.tensor([[-1.0, 2.0, -0.5, 0.25],
                          [-1.5, 4.0, 1.0, 0.125],
                          [0.0, -0.0, -8.0, 16.0],
                          [-0.25, 0.5, 0.75, -2.0]])

    output1 = fp8.real_to_format_tensor(test1)
    assert output1.shape == test1.shape
    assert output1[0][0].item() == -1.0
    assert output1[0][1].item() == 2.0
    assert output1[2][2].item() == -8.0
    assert output1[2][3].item() == 16.0

    # bit flip tests
    assert fp8.single_bit_flip_in_format(1.0, 0) == 1.125
    assert fp8.single_bit_flip_in_format(1.0, 7) == -1.0

    # convert_numsys_flip
    assert fp8.convert_numsys_flip(1.0, 0) == 1.0
    assert fp8.convert_numsys_flip(1.0, 0, True) == 1.125


# FP8 E5M2
def test_num_fp8_e5m2():
    """Check scalar encode/decode, tensor quantization and bit flips for E5M2.

    Bit patterns are written as strings laid out
    ``sign | 5 exponent bits | 2 mantissa bits`` and expanded to the
    list-of-characters form the number system works with.
    """
    # bitwidth of 8, 1 sign bit, 5 exponent bits, 2 mantissa bits
    codec = num_fp8_e5m2()

    # values that must survive decode and encode unchanged:
    # basic conversions, a larger value, a smaller value (0.5) and zero
    round_trip_cases = (
        ("00111100", 1.0),
        ("10111100", -1.0),
        ("01000000", 2.0),
        ("00111000", 0.5),
        ("00000000", 0.0),
    )
    for pattern, value in round_trip_cases:
        assert codec.format_to_real(list(pattern)) == value
        assert codec.real_to_format(value) == list(pattern)

    # the negative-zero pattern decodes to something equal to zero
    assert codec.format_to_real(list("10000000")) == 0.0

    # infinity
    positive_inf = codec.format_to_real(list("01111100"))
    assert positive_inf == float('inf')
    negative_inf = codec.format_to_real(list("11111100"))
    assert negative_inf == float('-inf')

    # NaN
    for pattern in ("01111101", "11111110", "01111111"):
        decoded = codec.format_to_real(list(pattern))
        assert math.isnan(decoded)

    # tensor quantization
    samples = torch.tensor([[-1.0, 2.0, -0.5, 0.25],
                            [-1.5, 4.0, 1.0, 0.125],
                            [0.0, -0.0, -8.0, 16.0],
                            [-0.25, 0.5, 0.75, -2.0]])

    quantized = codec.real_to_format_tensor(samples)
    assert quantized.shape == samples.shape
    spot_checks = (((0, 0), -1.0), ((0, 1), 2.0), ((2, 2), -8.0), ((2, 3), 16.0))
    for (row, col), expected in spot_checks:
        assert quantized[row][col].item() == expected

    # bit flip tests
    assert codec.single_bit_flip_in_format(1.0, 0) == 1.25
    assert codec.single_bit_flip_in_format(1.0, 7) == -1.0

    # convert_numsys_flip
    assert codec.convert_numsys_flip(1.0, 0) == 1.0
    assert codec.convert_numsys_flip(1.0, 0, True) == 1.25


# Fixed Point
def test_fixed():
# bitwidth of 6, 1 sign bit, 2 integer bits, 3 fraction bits
Expand Down