File tree Expand file tree Collapse file tree
excuter/op-mem-cuda/src/deepx/tensorfunc Expand file tree Collapse file tree Original file line number Diff line number Diff line change 44#include < cuda_runtime.h>
55#include < cuda_fp16.h>
66#include < cuda_bf16.h>
7+ #include < cuda_fp8.h>
78#include < cublas_v2.h>
89
910namespace deepx ::tensorfunc
@@ -37,6 +38,27 @@ namespace deepx::tensorfunc
3738 *out = hsqrt (*a);
3839 }
3940
// sqrt specialization for FP8 E4M3.
// The input is widened to half precision, hsqrt is applied, and the result is
// narrowed back to E4M3 with saturate-to-finite behaviour (__NV_SATFINITE).
//
// a:   pointer to the input element.
// out: pointer that receives sqrt(*a).
template <>
__device__ __forceinline__ void deepx_sqrt<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *a, __nv_fp8_e4m3 *out)
{
    // Read the encoded byte straight from the storage member.
    // static_cast<__nv_fp8_storage_t>(*a) would call the class's numeric
    // `operator unsigned char()` and return the rounded integer VALUE of the
    // number instead of its bit pattern.
    const __half input_fp16 = __half(__nv_cvt_fp8_to_halfraw(a->__x, __NV_E4M3));

    // CUDA built-in half-precision square root.
    const __half result_fp16 = hsqrt(input_fp16);

    // Write the encoded byte straight into the storage member.
    // Constructing __nv_fp8_e4m3 from the returned storage byte would pick an
    // integer constructor and convert the byte as a number, not as raw bits.
    out->__x = __nv_cvt_halfraw_to_fp8(static_cast<__half_raw>(result_fp16), __NV_SATFINITE, __NV_E4M3);
}
48+
// sqrt specialization for FP8 E5M2.
// The input is widened to half precision, hsqrt is applied, and the result is
// narrowed back to E5M2 with saturate-to-finite behaviour (__NV_SATFINITE).
//
// a:   pointer to the input element.
// out: pointer that receives sqrt(*a).
template <>
__device__ __forceinline__ void deepx_sqrt<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *a, __nv_fp8_e5m2 *out)
{
    // 1. Widen FP8 (E5M2) to half using the raw storage byte.
    //    static_cast<__nv_fp8_storage_t>(*a) would call the class's numeric
    //    `operator unsigned char()` and return the rounded integer VALUE of
    //    the number instead of its bit pattern.
    const __half input_fp16 = __half(__nv_cvt_fp8_to_halfraw(a->__x, __NV_E5M2));

    // 2. Take the square root in half precision.
    const __half result_fp16 = hsqrt(input_fp16);

    // 3. Narrow back to FP8 (E5M2) and store the encoded byte directly.
    //    Constructing __nv_fp8_e5m2 from the returned storage byte would pick
    //    an integer constructor and convert the byte as a number, not as raw
    //    bits.
    out->__x = __nv_cvt_halfraw_to_fp8(static_cast<__half_raw>(result_fp16), __NV_SATFINITE, __NV_E5M2);
}
60+
61+
4062 // pow: primary template declaration only (no body here); reads through `a` and `b` and writes through `out`. Presumably *out = (*a) raised to (*b), with per-type definitions supplied elsewhere — confirm against the specializations.
4163 template <typename T>
4264 __device__ __forceinline__ void deepx_pow (const T *a, const T *b, T *out);
You can’t perform that action at this time.
0 commit comments