diff --git a/src/num_sys_class.py b/src/num_sys_class.py index 7c10838..504e275 100644 --- a/src/num_sys_class.py +++ b/src/num_sys_class.py @@ -265,6 +265,120 @@ def real_to_format_tensor(self, tensor): return tensor.to(torch.bfloat16) +class num_fp8_e4m3(_ieee754): + """FP8 E4M3 Number System (4-bit exponent, 3-bit mantissa) + + Used primarily for weights in transformer inference. + Range: [-448, 448], no inf representation (all exponent bits = 1 reserved for NaN) + """ + def __init__(self): + super(num_fp8_e4m3, self).__init__( + exp_len=4, + mant_len=3, + bias=7, + denorm=True, + max_val=448.0, + min_val=None + ) + self.has_inf = False + + def real_to_format(self, num): + if abs(num) > self.max_val: + num = self.max_val if num > 0 else -self.max_val + return super().real_to_format(num) + + def format_to_real(self, bit_arr): + exp_str = "".join(bit_arr[1:self.exp_len + 1]) + mant_str = "".join(bit_arr[self.exp_len + 1:]) + + if exp_str == "1" * self.exp_len and mant_str == "1" * self.mant_len: + return float('nan') + + if exp_str == "1" * self.exp_len: + sign = pow(-1, int(bit_arr[0])) + mant = _number_sys.bin_to_frac(mant_str) + 1 + exp = int(exp_str, 2) - self.bias + return sign * mant * pow(2, exp) + + return super().format_to_real(bit_arr) + + def real_to_format_tensor(self, tensor): + return self._quantize_fp8_e4m3(tensor) + + def _quantize_fp8_e4m3(self, float_arr): + sign = torch.sign(float_arr) + float_arr = torch.abs(float_arr) + + float_arr = torch.clamp(float_arr, max=self.max_val) + + min_normal = 2 ** (-self.bias + 1) + min_subnormal = 2 ** (-self.bias - self.mant_len + 1) + float_arr[float_arr < min_subnormal] = 0 + + mant, exp = torch.frexp(float_arr) + mant = 2 * mant + exp = exp - 1 + + n_mant = self.mant_len + scale = 2 ** (-n_mant) + mant = ((mant / scale).round()) * scale + + power_exp = torch.exp2(exp.float()) + + return sign * power_exp * mant + + +class num_fp8_e5m2(_ieee754): + """FP8 E5M2 Number System (5-bit exponent, 2-bit mantissa) + + Used primarily for gradients in transformer training. + Range: [-57344, 57344], supports inf representation + """ + def __init__(self): + super(num_fp8_e5m2, self).__init__( + exp_len=5, + mant_len=2, + bias=15, + denorm=True, + max_val=57344.0, + min_val=None + ) + self.has_inf = True + + def real_to_format(self, num): + if abs(num) > self.max_val: + return super().real_to_format(float('inf') if num > 0 else float('-inf')) + return super().real_to_format(num) + + def real_to_format_tensor(self, tensor): + return self._quantize_fp8_e5m2(tensor) + + def _quantize_fp8_e5m2(self, float_arr): + sign = torch.sign(float_arr) + float_arr = torch.abs(float_arr) + + inf_mask = float_arr > self.max_val + float_arr = torch.clamp(float_arr, max=self.max_val) + + min_subnormal = 2 ** (-self.bias - self.mant_len + 1) + float_arr[float_arr < min_subnormal] = 0 + + mant, exp = torch.frexp(float_arr) + mant = 2 * mant + exp = exp - 1 + + n_mant = self.mant_len + scale = 2 ** (-n_mant) + mant = ((mant / scale).round()) * scale + + power_exp = torch.exp2(exp.float()) + + result = sign * power_exp * mant + result[inf_mask] = sign[inf_mask] * float('inf') + + return result + + class num_fixed_pt(_number_sys): """Fixed Point Number System""" # 1 bit for sign + len(integer part) + len(frac part) diff --git a/src/util.py b/src/util.py index 2f7cdc0..ae678cf 100644 --- a/src/util.py +++ b/src/util.py @@ -616,6 +616,10 @@ def getNumSysName(name, bits=16, radix_up=5, radix_down=10, bias=None): return num_fp16(), name elif name == "bfloat16": return num_bfloat16(), name + elif name == "fp8_e4m3": + return num_fp8_e4m3(), name + elif name == "fp8_e5m2": + return num_fp8_e5m2(), name # generic number systems in PyTorch elif name == "fp_n": diff --git a/val/test_num_sys.py b/val/test_num_sys.py index f9bf3c3..4afa6bd 100644 --- a/val/test_num_sys.py +++ b/val/test_num_sys.py @@ -334,6 +334,139 @@ def test_num_bfloat16(): '0', '0', '1', '1', '0', '1', '0']) == 255) +# FP8 E4M3 +def test_num_fp8_e4m3(): + # bitwidth of 8, 1 sign bit, 4 exponent bits, 3 mantissa bits + fp8 = num_fp8_e4m3() + + # basic conversion tests + assert(fp8.format_to_real(['0', '0', '1', '1', '1', '1', '0', '0']) == 1.0) + assert(fp8.real_to_format(1.0) + == ['0', '0', '1', '1', '1', '1', '0', '0']) + + assert(fp8.format_to_real(['1', '0', '1', '1', '1', '1', '0', '0']) == -1.0) + assert(fp8.real_to_format(-1.0) + == ['1', '0', '1', '1', '1', '1', '0', '0']) + + # larger values + assert(fp8.format_to_real(['0', '1', '0', '0', '0', '0', '0', '0']) == 2.0) + assert(fp8.real_to_format(2.0) + == ['0', '1', '0', '0', '0', '0', '0', '0']) + + # smaller values + assert(fp8.format_to_real(['0', '0', '1', '1', '0', '1', '0', '0']) == 0.625) + assert(fp8.real_to_format(0.625) + == ['0', '0', '1', '1', '0', '1', '0', '0']) + + # zero + assert(fp8.format_to_real(['0', '0', '0', '0', '0', '0', '0', '0']) == 0.0) + assert(fp8.real_to_format(0.0) + == ['0', '0', '0', '0', '0', '0', '0', '0']) + assert(fp8.format_to_real(['1', '0', '0', '0', '0', '0', '0', '0']) == 0.0) + + # NaN (all exponent bits = 1, all mantissa bits = 1) + assert(math.isnan(fp8.format_to_real( + ['0', '1', '1', '1', '1', '1', '1', '1']))) + assert(math.isnan(fp8.format_to_real( + ['1', '1', '1', '1', '1', '1', '1', '1']))) + + # max value (448) + assert(fp8.format_to_real(['0', '1', '1', '1', '1', '1', '1', '0']) == 448.0) + assert(fp8.real_to_format(448.0) + == ['0', '1', '1', '1', '1', '1', '1', '0']) + + # clamping beyond max + result = fp8.real_to_format(500.0) + assert(fp8.format_to_real(result) == 448.0) + + # tensor quantization + test1 = torch.tensor([[-1.0, 2.0, -0.5, 0.25], + [-1.5, 4.0, 1.0, 0.125], + [0.0, -0.0, -8.0, 16.0], + [-0.25, 0.5, 0.75, -2.0]]) + + output1 = fp8.real_to_format_tensor(test1) + assert(output1.shape == test1.shape) + assert(output1[0][0].item() == -1.0) + assert(output1[0][1].item() == 2.0) + assert(output1[2][2].item() == -8.0) + assert(output1[2][3].item() == 16.0) + + # bit flip tests + assert(fp8.single_bit_flip_in_format(1.0, 0) == 1.125) + assert(fp8.single_bit_flip_in_format(1.0, 7) == -1.0) + + # convert_numsys_flip + assert(fp8.convert_numsys_flip(1.0, 0) == 1.0) + assert(fp8.convert_numsys_flip(1.0, 0, True) == 1.125) + + +# FP8 E5M2 +def test_num_fp8_e5m2(): + # bitwidth of 8, 1 sign bit, 5 exponent bits, 2 mantissa bits + fp8 = num_fp8_e5m2() + + # basic conversion tests + assert(fp8.format_to_real(['0', '0', '1', '1', '1', '1', '0', '0']) == 1.0) + assert(fp8.real_to_format(1.0) + == ['0', '0', '1', '1', '1', '1', '0', '0']) + + assert(fp8.format_to_real(['1', '0', '1', '1', '1', '1', '0', '0']) == -1.0) + assert(fp8.real_to_format(-1.0) + == ['1', '0', '1', '1', '1', '1', '0', '0']) + + # larger values + assert(fp8.format_to_real(['0', '1', '0', '0', '0', '0', '0', '0']) == 2.0) + assert(fp8.real_to_format(2.0) + == ['0', '1', '0', '0', '0', '0', '0', '0']) + + # smaller values (0.5) + assert(fp8.format_to_real(['0', '0', '1', '1', '1', '0', '0', '0']) == 0.5) + assert(fp8.real_to_format(0.5) + == ['0', '0', '1', '1', '1', '0', '0', '0']) + + # zero + assert(fp8.format_to_real(['0', '0', '0', '0', '0', '0', '0', '0']) == 0.0) + assert(fp8.real_to_format(0.0) + == ['0', '0', '0', '0', '0', '0', '0', '0']) + assert(fp8.format_to_real(['1', '0', '0', '0', '0', '0', '0', '0']) == 0.0) + + # infinity + assert(fp8.format_to_real(['0', '1', '1', '1', '1', '1', '0', '0']) + == float('inf')) + assert(fp8.format_to_real(['1', '1', '1', '1', '1', '1', '0', '0']) + == float('-inf')) + + # NaN + assert(math.isnan(fp8.format_to_real( + ['0', '1', '1', '1', '1', '1', '0', '1']))) + assert(math.isnan(fp8.format_to_real( + ['1', '1', '1', '1', '1', '1', '1', '0']))) + assert(math.isnan(fp8.format_to_real( + ['0', '1', '1', '1', '1', '1', '1', '1']))) + + # tensor quantization + test1 = torch.tensor([[-1.0, 2.0, -0.5, 0.25], + [-1.5, 4.0, 1.0, 0.125], + [0.0, -0.0, -8.0, 16.0], + [-0.25, 0.5, 0.75, -2.0]]) + + output1 = fp8.real_to_format_tensor(test1) + assert(output1.shape == test1.shape) + assert(output1[0][0].item() == -1.0) + assert(output1[0][1].item() == 2.0) + assert(output1[2][2].item() == -8.0) + assert(output1[2][3].item() == 16.0) + + # bit flip tests + assert(fp8.single_bit_flip_in_format(1.0, 0) == 1.25) + assert(fp8.single_bit_flip_in_format(1.0, 7) == -1.0) + + # convert_numsys_flip + assert(fp8.convert_numsys_flip(1.0, 0) == 1.0) + assert(fp8.convert_numsys_flip(1.0, 0, True) == 1.25) + + # Fixed Point def test_fixed(): # bitwidth of 6, 1 sign bit, 2 integer bits, 3 fraction bits