diff --git a/src/num_sys_class.py b/src/num_sys_class.py index 7c10838..14efc2e 100644 --- a/src/num_sys_class.py +++ b/src/num_sys_class.py @@ -350,6 +350,10 @@ def quant_bfloat_py(self, float_arr, n_bits, n_exp): power_exp_diff = torch.exp2(exp_diff) mant_adj = mant / power_exp_diff + # handle mantissa underflow after shifting + # values that become subnormal should be clamped to maintain valid representation + mant_adj[mant_adj < 1.0] = 0.0 + exp_adj = torch.full(exp.shape, shared_exp, device=float_arr.device) # exp should not be larger than max_exp @@ -415,6 +419,10 @@ def quant_bfloat_meta_py(self, float_arr, n_bits=8, n_exp=3): power_exp_diff = torch.exp2(exp_diff) mant_adj = mant / power_exp_diff + # handle mantissa underflow after shifting + # values that become subnormal should be clamped to maintain valid representation + mant_adj[mant_adj < 1.0] = 0.0 + exp_adj = torch.full(exp.shape, shared_exp, device=float_arr.device) # exp should not be larger than max_exp