From 7edb0a9d38b1f2f2e770cfaf24c5015ed762b167 Mon Sep 17 00:00:00 2001
From: ZaneHam
Date: Wed, 25 Mar 2026 20:16:26 +1300
Subject: [PATCH] NVIDIA PTX: f64 support for local memory, conversions, and math
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Add double-precision (f64) codegen support to the NVIDIA PTX backend.
Previously all f64 local loads/stores fell through to u32, fabs used the
f32 abs, and int-to-float conversions hardcoded f32 targets. Discovered
by Kokako (open-source quantum chemistry), our first f64 GPU workload.

NVIDIA backend (isel.c, emit.c, nvidia.h):
- NV_LD_LOC_F64, NV_ST_LOC_F64: f64 local (scratch) memory
- NV_CVT_F64_S32, NV_CVT_F64_U32: int32 -> fp64 conversion
- NV_CVT_S32_F64, NV_CVT_U32_F64: fp64 -> int32 conversion
- BIR_FABS: dispatch to NV_ABS_F64 when the operand is f64
- BIR_SITOFP/UITOFP: dispatch to f64 cvt when the result type is f64
- BIR_FPTOSI/FPTOUI: dispatch to f64 cvt when the source type is f64
- NV_MOV_F64 immediate: emit 0d hex format (not a bare integer)
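
Illustrative PTX for the new ops above, with invented register names
(the opcode strings are exactly what emit.c now produces):

    cvt.rn.f64.s32  %fd1, %r1;     // (double)i : exact, any int32 fits
    cvt.rzi.s32.f64 %r2,  %fd2;    // (int)x    : truncate toward zero
    ld.local.f64    %fd3, [%rd1];  // f64 scratch load
    st.local.f64    [%rd1], %fd3;  // f64 scratch store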

Middle-end (bir_lower.c):
- coerce_to(): implicit operand promotion for binary expressions
- Handles int * double -> SITOFP + FMUL (C usual arithmetic
  conversions); see the sketch below
- One fix, all backends benefit
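
A sketch of the case that motivated it (variable names invented):

    int n;  double x;
    double y = n * x;

Before, the multiply lowered as a mixed-type integer mul (mul.u32 fed
an f64 register), which the PTX JIT does not accept. Now n is promoted
via SITOFP and the product is a proper f64 FMUL.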

Frontend (sema.c):
- Register sqrt as f64 math builtin (-> PTX sqrt.rn.f64, exact)
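
Illustrative kernel use (names invented):

    double r = sqrt(dx*dx + dy*dy + dz*dz);

This now lowers through BIR_SQRT to sqrt.rn.f64 (IEEE round-to-nearest);
previously sqrt was simply not registered as a builtin at all.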

Validated: Kokako ERI kernel (Obara-Saika VRR+HRR) produces bit-identical
results to the CPU for benzene C6H6 (36 basis functions, 222K integrals)
on an RTX 4060 Ti, with a 20x speedup at FP64. 90/90 existing tests pass.
Moa (f32) unaffected.
---
 src/fe/sema.c       |  2 +-
 src/ir/bir_lower.c  | 65 ++++++++++++++++++++++++++++++++++++++++++---
 src/nvidia/emit.c   | 34 +++++++++++++++++++++---
 src/nvidia/isel.c   | 28 +++++++++++++------
 src/nvidia/nvidia.h | 10 ++++---
 5 files changed, 119 insertions(+), 20 deletions(-)

diff --git a/src/fe/sema.c b/src/fe/sema.c
index 43dee35..99286c4 100644
--- a/src/fe/sema.c
+++ b/src/fe/sema.c
@@ -605,7 +605,7 @@ static const cuda_builtin_t cuda_builtins[] = {
     {"__shfl_up_sync",  -1, 0, 1},
     {"__shfl_down_sync",-1, 0, 1},
     {"__shfl_xor_sync", -1, 0, 1},
-    {"sqrtf",1,0,0},{"__fsqrt_rn",1,0,0},{"rsqrtf",1,0,0},{"__frsqrt_rn",1,0,0},
+    {"sqrtf",1,0,0},{"sqrt",1,0,0},{"__fsqrt_rn",1,0,0},{"rsqrtf",1,0,0},{"__frsqrt_rn",1,0,0},
     {"__frcp_rn",1,0,0},{"expf",1,0,0},{"__expf",1,0,0},{"exp2f",1,0,0},
     {"logf",1,0,0},{"__logf",1,0,0},{"log2f",1,0,0},{"__log2f",1,0,0},
     {"log10f",1,0,0},{"sinf",1,0,0},{"__sinf",1,0,0},{"cosf",1,0,0},
diff --git a/src/ir/bir_lower.c b/src/ir/bir_lower.c
index 738cb37..80c9834 100644
--- a/src/ir/bir_lower.c
+++ b/src/ir/bir_lower.c
@@ -522,6 +522,37 @@ static int is_float_type(const lower_t *L, uint32_t t)
     return k == BIR_TYPE_FLOAT || k == BIR_TYPE_BFLOAT;
 }
 
+/* Implicit type coercion: insert a conversion if val's type != dst_t.
+ * Handles int <-> float, float width changes, and integer widening /
+ * narrowing.  Used by binary ops to promote operands per C's usual
+ * arithmetic conversions. */
+static uint32_t coerce_to(lower_t *L, uint32_t val, uint32_t dst_t,
+                          int src_unsigned)
+{
+    uint32_t src_t = ref_type(L, val);
+    if (src_t == dst_t) return val;
+    if (src_t >= L->M->num_types || dst_t >= L->M->num_types) return val;
+
+    int sf = is_float_type(L, src_t), df = is_float_type(L, dst_t);
+    uint16_t cop;
+    if (sf && df) {
+        cop = (L->M->types[dst_t].width > L->M->types[src_t].width)
+                  ? BIR_FPEXT : BIR_FPTRUNC;
+    } else if (!sf && df) {
+        cop = src_unsigned ? BIR_UITOFP : BIR_SITOFP;
+    } else if (sf && !df) {
+        cop = src_unsigned ? BIR_FPTOUI : BIR_FPTOSI;
+    } else {
+        int sw = L->M->types[src_t].width;
+        int dw = L->M->types[dst_t].width;
+        if (dw > sw) cop = src_unsigned ? BIR_ZEXT : BIR_SEXT;
+        else if (dw < sw) cop = BIR_TRUNC;
+        else return val; /* same-width int — no conversion */
+    }
+    uint32_t inst = emit(L, cop, dst_t, 1, 0);
+    set_op(L, inst, 0, val);
+    return BIR_MAKE_VAL(inst);
+}
+
 static int is_ptr_type(const lower_t *L, uint32_t t)
 {
     return t < L->M->num_types && L->M->types[t].kind == BIR_TYPE_PTR;
@@ -1084,13 +1115,39 @@ static uint32_t lower_expr(lower_t *L, uint32_t node)
         }
     }
 
-    int fp = is_float_type(L, lt);
+    uint32_t rt = ref_type(L, rhs);
+    /* Usual arithmetic conversion: promote both operands
+     * to the wider/float type. C says int*double → double,
+     * not int*double → garbage. Without this, backends get
+     * mixed-type ops (mul.u32 with an f64 register) and
+     * the PTX JIT has strong opinions about that. */
+    uint32_t res_t = lt;
+    int lf = is_float_type(L, lt), rf = is_float_type(L, rt);
+    if (lf && !rf) {
+        rhs = coerce_to(L, rhs, lt, node_is_unsigned(L, rhs_n));
+        res_t = lt;
+    } else if (!lf && rf) {
+        lhs = coerce_to(L, lhs, rt, node_is_unsigned(L, lhs_n));
+        res_t = rt;
+    } else if (lf && rf) {
+        /* Both float: promote narrower to wider */
+        if (lt < L->M->num_types && rt < L->M->num_types
+            && L->M->types[rt].width > L->M->types[lt].width) {
+            lhs = coerce_to(L, lhs, rt, 0);
+            res_t = rt;
+        } else if (lt < L->M->num_types && rt < L->M->num_types
+                   && L->M->types[lt].width > L->M->types[rt].width) {
+            rhs = coerce_to(L, rhs, lt, 0);
+            res_t = lt;
+        }
+    }
+    int fp = is_float_type(L, res_t);
     int opc = bin_op_code(op, fp, node_is_unsigned(L, node));
     if (opc < 0) { lower_error(L, node, BC_E102); return lhs; }
-    uint32_t inst = emit(L, (uint16_t)opc, lt, 2, 0);
+    uint32_t inst = emit(L, (uint16_t)opc, res_t, 2, 0);
     set_op(L, inst, 0, lhs);
     set_op(L, inst, 1, rhs);
     return BIR_MAKE_VAL(inst);
@@ -1522,7 +1579,7 @@ static uint32_t lower_expr(lower_t *L, uint32_t node)
     /* ---- Math builtins: unary ---- */
     {
         static const struct { const char *n; uint16_t op; } mt1[] = {
-            {"sqrtf",BIR_SQRT},{"__fsqrt_rn",BIR_SQRT},
+            {"sqrtf",BIR_SQRT},{"sqrt",BIR_SQRT},{"__fsqrt_rn",BIR_SQRT},
             {"rsqrtf",BIR_RSQ},{"__frsqrt_rn",BIR_RSQ},
             {"__frcp_rn",BIR_RCP},
             {"exp2f",BIR_EXP2},{"log2f",BIR_LOG2},{"__log2f",BIR_LOG2},
@@ -1530,7 +1587,7 @@ static uint32_t lower_expr(lower_t *L, uint32_t node)
             {"floorf",BIR_FLOOR},{"ceilf",BIR_CEIL},
             {"truncf",BIR_FTRUNC},{"roundf",BIR_RNDNE},{"rintf",BIR_RNDNE},
         };
-        for (int mi = 0; mi < 15; mi++) {
+        for (int mi = 0; mi < 16; mi++) {
             if (strcmp(cname, mt1[mi].n) != 0) continue;
             uint32_t an = ND(L, callee_n)->next_sibling;
             uint32_t v = lower_expr(L, an);
diff --git a/src/nvidia/emit.c b/src/nvidia/emit.c
index dbb61d9..eea7f83 100644
--- a/src/nvidia/emit.c
+++ b/src/nvidia/emit.c
@@ -470,7 +470,11 @@ static void em_inst(nv_module_t *nv, const nv_minst_t *I)
     case NV_MOV_F64:
         nv_apnd(nv, "mov.f64 ");
         em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
-        em_opnd(nv, &I->ops[1]);
+        /* PTX f64 immediates need the 0dXXXX raw-bits format, not a bare int */
+        if (I->ops[1].kind == NV_MOP_IMM && I->ops[1].imm == 0)
+            nv_apnd(nv, "0d0000000000000000");
+        else
+            em_opnd(nv, &I->ops[1]);
         break;
     case NV_MOV_PRED:
         nv_apnd(nv, "mov.pred ");
@@ -489,6 +493,16 @@ static void em_inst(nv_module_t *nv, const nv_minst_t *I)
         em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
         em_opnd(nv, &I->ops[1]);
         break;
+    case NV_CVT_S32_F64:
+        nv_apnd(nv, "cvt.rzi.s32.f64 ");
+        em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
+        em_opnd(nv, &I->ops[1]);
+        break;
+    case NV_CVT_U32_F64:
+        nv_apnd(nv, "cvt.rzi.u32.f64 ");
+        em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
+        em_opnd(nv, &I->ops[1]);
+        break;
     case NV_CVT_F32_U32:
         nv_apnd(nv, "cvt.rn.f32.u32 ");
         em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
@@ -544,6 +558,16 @@ static void em_inst(nv_module_t *nv, const nv_minst_t *I)
         em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
         em_opnd(nv, &I->ops[1]);
         break;
+    case NV_CVT_F64_S32:
+        nv_apnd(nv, "cvt.rn.f64.s32 ");
+        em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
+        em_opnd(nv, &I->ops[1]);
+        break;
+    case NV_CVT_F64_U32:
+        nv_apnd(nv, "cvt.rn.f64.u32 ");
+        em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
+        em_opnd(nv, &I->ops[1]);
+        break;
     case NV_CVT_F32_F16:
         nv_apnd(nv, "cvt.f32.f16 ");
         em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
@@ -610,11 +634,13 @@ static void em_inst(nv_module_t *nv, const nv_minst_t *I)
     }
 
     /* ---- Loads/Stores: Local ---- */
-    case NV_LD_LOC_U32: case NV_LD_LOC_U64: case NV_LD_LOC_F32: {
+    case NV_LD_LOC_U32: case NV_LD_LOC_U64:
+    case NV_LD_LOC_F32: case NV_LD_LOC_F64: {
        const char *tsuf;
        switch (I->op) {
        case NV_LD_LOC_U64: tsuf = ".u64"; break;
        case NV_LD_LOC_F32: tsuf = ".f32"; break;
+       case NV_LD_LOC_F64: tsuf = ".f64"; break;
        default: tsuf = ".u32"; break;
        }
        nv_apnd(nv, "ld.local%s ", tsuf);
@@ -622,11 +648,13 @@ static void em_inst(nv_module_t *nv, const nv_minst_t *I)
        em_opnd(nv, &I->ops[1]); nv_apnd(nv, "]");
        break;
     }
-    case NV_ST_LOC_U32: case NV_ST_LOC_U64: case NV_ST_LOC_F32: {
+    case NV_ST_LOC_U32: case NV_ST_LOC_U64:
+    case NV_ST_LOC_F32: case NV_ST_LOC_F64: {
        const char *tsuf;
        switch (I->op) {
        case NV_ST_LOC_U64: tsuf = ".u64"; break;
        case NV_ST_LOC_F32: tsuf = ".f32"; break;
+       case NV_ST_LOC_F64: tsuf = ".f64"; break;
        default: tsuf = ".u32"; break;
        }
        nv_apnd(nv, "st.local%s [", tsuf);
diff --git a/src/nvidia/isel.c b/src/nvidia/isel.c
index c68289b..d763555 100644
--- a/src/nvidia/isel.c
+++ b/src/nvidia/isel.c
@@ -560,10 +560,20 @@ static void is_cvt(uint32_t idx, const bir_inst_t *I)
     uint16_t op;
 
     switch (I->op) {
-    case BIR_FPTOSI: op = NV_CVT_S32_F32; break;
-    case BIR_FPTOUI: op = NV_CVT_U32_F32; break;
-    case BIR_SITOFP: op = NV_CVT_F32_S32; break;
-    case BIR_UITOFP: op = NV_CVT_F32_U32; break;
+    case BIR_FPTOSI: {
+        uint32_t si0 = BIR_VAL_INDEX(I->operands[0]);
+        uint8_t srf = (si0 < S.bir->num_insts) ? bir_rfile(S.bir->insts[si0].type) : NV_RF_F32;
+        op = (srf == NV_RF_F64) ? NV_CVT_S32_F64 : NV_CVT_S32_F32;
+        break;
+    }
+    case BIR_FPTOUI: {
+        uint32_t si0 = BIR_VAL_INDEX(I->operands[0]);
+        uint8_t srf = (si0 < S.bir->num_insts) ? bir_rfile(S.bir->insts[si0].type) : NV_RF_F32;
+        op = (srf == NV_RF_F64) ? NV_CVT_U32_F64 : NV_CVT_U32_F32;
+        break;
+    }
+    case BIR_SITOFP: op = (bir_rfile(I->type) == NV_RF_F64) ? NV_CVT_F64_S32 : NV_CVT_F32_S32; break;
+    case BIR_UITOFP: op = (bir_rfile(I->type) == NV_RF_F64) ? NV_CVT_F64_U32 : NV_CVT_F32_U32; break;
     case BIR_FPTRUNC: op = NV_CVT_F32_F64; break;
     case BIR_FPEXT: op = NV_CVT_F64_F32; break;
     case BIR_ZEXT: op = NV_CVT_U64_U32; break;
@@ -610,7 +620,8 @@ static void is_load(uint32_t idx, const bir_inst_t *I)
         op = (drf == NV_RF_F32) ? NV_LD_SHR_F32 : NV_LD_SHR_U32;
         break;
     case BIR_AS_PRIVATE:
-        op = (drf == NV_RF_F32) ? NV_LD_LOC_F32 :
+        op = (drf == NV_RF_F64) ? NV_LD_LOC_F64 :
+             (drf == NV_RF_F32) ? NV_LD_LOC_F32 :
              (drf == NV_RF_U64) ? NV_LD_LOC_U64 : NV_LD_LOC_U32;
         break;
     default: /* global */
@@ -659,7 +670,8 @@ static void is_store(const bir_inst_t *I)
         op = (vrf == NV_RF_F32) ? NV_ST_SHR_F32 : NV_ST_SHR_U32;
         break;
     case BIR_AS_PRIVATE:
-        op = (vrf == NV_RF_F32) ? NV_ST_LOC_F32 :
+        op = (vrf == NV_RF_F64) ? NV_ST_LOC_F64 :
+             (vrf == NV_RF_F32) ? NV_ST_LOC_F32 :
              (vrf == NV_RF_U64) ? NV_ST_LOC_U64 : NV_ST_LOC_U32;
         break;
     default:
@@ -1034,9 +1046,9 @@ static void is_math(uint32_t idx, const bir_inst_t *I)
         op = (I->op == BIR_SIN) ? NV_SIN_F32 : NV_COS_F32;
         break;
     }
-    case BIR_EXP2: op = NV_EX2_F32; break;
+    case BIR_EXP2: op = NV_EX2_F32; break; /* PTX ex2 is f32 only */
     case BIR_LOG2: op = NV_LG2_F32; break;
-    case BIR_FABS: op = NV_ABS_F32; break;
+    case BIR_FABS: op = (rf == NV_RF_F64) ? NV_ABS_F64 : NV_ABS_F32; break;
     case BIR_FLOOR: op = NV_FLOOR_F32; break;
     case BIR_CEIL: op = NV_CEIL_F32; break;
     case BIR_FTRUNC: op = NV_TRUNC_F32; break;
diff --git a/src/nvidia/nvidia.h b/src/nvidia/nvidia.h
index c0c5ca8..e4f4646 100644
--- a/src/nvidia/nvidia.h
+++ b/src/nvidia/nvidia.h
@@ -66,13 +66,15 @@ typedef enum {
     NV_MOV_PRED,
 
     /* Conversions */
-    NV_CVT_U32_F32, NV_CVT_S32_F32, /* fptosi/fptoui */
+    NV_CVT_U32_F32, NV_CVT_S32_F32, /* fptosi/fptoui (f32 src) */
+    NV_CVT_U32_F64, NV_CVT_S32_F64, /* fptosi/fptoui (f64 src) */
     NV_CVT_F32_U32, NV_CVT_F32_S32, /* uitofp/sitofp */
     NV_CVT_F32_F64, NV_CVT_F64_F32, /* fptrunc/fpext */
     NV_CVT_U64_U32, NV_CVT_S64_S32, /* zext/sext to 64 */
     NV_CVT_U32_U64,                 /* trunc 64->32 */
     NV_CVT_U64_F64, NV_CVT_S64_F64, /* fp64->int64 */
     NV_CVT_F64_U64, NV_CVT_F64_S64, /* int64->fp64 */
+    NV_CVT_F64_U32, NV_CVT_F64_S32, /* int32->fp64 */
     NV_CVT_F32_F16, NV_CVT_F16_F32, /* half conversions */
 
     /* Loads / stores — global */
@@ -89,9 +91,9 @@ typedef enum {
 
     /* Loads / stores — local (scratch / alloca) */
     NV_LD_LOC_U32, NV_LD_LOC_U64,
-    NV_LD_LOC_F32,
+    NV_LD_LOC_F32, NV_LD_LOC_F64,
     NV_ST_LOC_U32, NV_ST_LOC_U64,
-    NV_ST_LOC_F32,
+    NV_ST_LOC_F32, NV_ST_LOC_F64,
 
     /* Parameter loads */
     NV_LD_PARAM_U32, NV_LD_PARAM_U64,
@@ -126,7 +128,7 @@ typedef enum {
     NV_RCP_F32,   /* rcp.approx.f32 */
     NV_SIN_F32,   /* sin.approx.f32 */
     NV_COS_F32,   /* cos.approx.f32 */
-    NV_EX2_F32,   /* ex2.approx.f32 */
+    NV_EX2_F32,   /* ex2.approx.f32 (no f64 in PTX) */
     NV_LG2_F32,   /* lg2.approx.f32 */
     NV_FLOOR_F32, /* cvt.rmi.f32.f32 (floor) */
     NV_CEIL_F32,  /* cvt.rpi.f32.f32 (ceil) */