Skip to content

Commit c109e0d

Browse files
committed
sqrt support fp8
1 parent 933b55c commit c109e0d

1 file changed

Lines changed: 22 additions & 0 deletions

File tree

excuter/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cuda_runtime.h>
55
#include <cuda_fp16.h>
66
#include <cuda_bf16.h>
7+
#include <cuda_fp8.h>
78
#include <cublas_v2.h>
89

910
namespace deepx::tensorfunc
@@ -37,6 +38,27 @@ namespace deepx::tensorfunc
3738
*out = hsqrt(*a);
3839
}
3940

41+
template <>
42+
__device__ __forceinline__ void deepx_sqrt<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *a, __nv_fp8_e4m3 *out)
43+
{
44+
__half input_fp16 = __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(*a), __NV_E4M3);
45+
__half result_fp16 = hsqrt(input_fp16); // CUDA 内置半精度平方根
46+
*out = static_cast<__nv_fp8_e4m3>(__nv_cvt_halfraw_to_fp8(result_fp16, __NV_SATFINITE, __NV_E4M3));
47+
}
48+
49+
template <>
50+
__device__ __forceinline__ void deepx_sqrt<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *a, __nv_fp8_e5m2 *out)
51+
{
52+
__half input_fp16 = __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(*a), __NV_E5M2);
53+
54+
// 2. 执行平方根
55+
__half result_fp16 = hsqrt(input_fp16);
56+
57+
// 3. 转回 FP8 → E5M2 格式
58+
*out =static_cast<__nv_fp8_e5m2>(__nv_cvt_halfraw_to_fp8(result_fp16, __NV_SATFINITE, __NV_E5M2));
59+
}
60+
61+
4062
// pow
4163
template <typename T>
4264
__device__ __forceinline__ void deepx_pow(const T *a, const T *b, T *out);

0 commit comments

Comments
 (0)