Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/fe/sema.c
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ static const cuda_builtin_t cuda_builtins[] = {
{"__shfl_up_sync", -1, 0, 1},
{"__shfl_down_sync",-1, 0, 1},
{"__shfl_xor_sync", -1, 0, 1},
{"sqrtf",1,0,0},{"__fsqrt_rn",1,0,0},{"rsqrtf",1,0,0},{"__frsqrt_rn",1,0,0},
{"sqrtf",1,0,0},{"sqrt",1,0,0},{"__fsqrt_rn",1,0,0},{"rsqrtf",1,0,0},{"__frsqrt_rn",1,0,0},
{"__frcp_rn",1,0,0},{"expf",1,0,0},{"__expf",1,0,0},{"exp2f",1,0,0},
{"logf",1,0,0},{"__logf",1,0,0},{"log2f",1,0,0},{"__log2f",1,0,0},
{"log10f",1,0,0},{"sinf",1,0,0},{"__sinf",1,0,0},{"cosf",1,0,0},
Expand Down
65 changes: 61 additions & 4 deletions src/ir/bir_lower.c
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,37 @@ static int is_float_type(const lower_t *L, uint32_t t)
return k == BIR_TYPE_FLOAT || k == BIR_TYPE_BFLOAT;
}

/* Implicit type coercion: insert conversion if val's type != dst.
* Handles int->float, float->float (width change), int widening.
* Used by binary ops to promote operands (C usual arithmetic). */
/* Insert a conversion instruction so that `val` carries type `dst_t`.
 * Returns `val` unchanged when no conversion is needed or possible:
 * identical types, an out-of-range type index, or two integer types of
 * equal width. `src_unsigned` picks the unsigned flavor of the
 * int<->float and integer-widening conversions (ZEXT vs SEXT,
 * UITOFP vs SITOFP, FPTOUI vs FPTOSI). */
static uint32_t coerce_to(lower_t *L, uint32_t val, uint32_t dst_t,
                          int src_unsigned)
{
    uint32_t src_t = ref_type(L, val);
    if (src_t == dst_t) return val;
    if (src_t >= L->M->num_types || dst_t >= L->M->num_types) return val;

    int src_fp = is_float_type(L, src_t);
    int dst_fp = is_float_type(L, dst_t);
    uint16_t conv;

    if (src_fp) {
        if (dst_fp) {
            /* float -> float: extend or truncate by bit width */
            conv = (L->M->types[dst_t].width > L->M->types[src_t].width)
                       ? BIR_FPEXT : BIR_FPTRUNC;
        } else {
            /* float -> int */
            conv = src_unsigned ? BIR_FPTOUI : BIR_FPTOSI;
        }
    } else if (dst_fp) {
        /* int -> float */
        conv = src_unsigned ? BIR_UITOFP : BIR_SITOFP;
    } else {
        /* int -> int: widen or truncate; equal widths need no op */
        int src_w = L->M->types[src_t].width;
        int dst_w = L->M->types[dst_t].width;
        if (dst_w == src_w) return val;
        conv = (dst_w > src_w) ? (src_unsigned ? BIR_ZEXT : BIR_SEXT)
                               : BIR_TRUNC;
    }

    uint32_t conv_inst = emit(L, conv, dst_t, 1, 0);
    set_op(L, conv_inst, 0, val);
    return BIR_MAKE_VAL(conv_inst);
}

static int is_ptr_type(const lower_t *L, uint32_t t)
{
return t < L->M->num_types && L->M->types[t].kind == BIR_TYPE_PTR;
Expand Down Expand Up @@ -1084,13 +1115,39 @@ static uint32_t lower_expr(lower_t *L, uint32_t node)
}
}

int fp = is_float_type(L, lt);
uint32_t rt = ref_type(L, rhs);
/* Usual arithmetic conversion: promote both operands
* to the wider/float type. C says int*double → double,
* not int*double → garbage. Without this, backends get
* mixed-type ops (mul.u32 with an f64 register) and
* the PTX JIT has strong opinions about that. */
uint32_t res_t = lt;
int lf = is_float_type(L, lt), rf = is_float_type(L, rt);
if (lf && !rf) {
rhs = coerce_to(L, rhs, lt, node_is_unsigned(L, rhs_n));
res_t = lt;
} else if (!lf && rf) {
lhs = coerce_to(L, lhs, rt, node_is_unsigned(L, lhs_n));
res_t = rt;
} else if (lf && rf) {
/* Both float: promote narrower to wider */
if (lt < L->M->num_types && rt < L->M->num_types
&& L->M->types[rt].width > L->M->types[lt].width) {
lhs = coerce_to(L, lhs, rt, 0);
res_t = rt;
} else if (lt < L->M->num_types && rt < L->M->num_types
&& L->M->types[lt].width > L->M->types[rt].width) {
rhs = coerce_to(L, rhs, lt, 0);
res_t = lt;
}
}
int fp = is_float_type(L, res_t);
int opc = bin_op_code(op, fp, node_is_unsigned(L, node));
if (opc < 0) {
lower_error(L, node, BC_E102);
return lhs;
}
uint32_t inst = emit(L, (uint16_t)opc, lt, 2, 0);
uint32_t inst = emit(L, (uint16_t)opc, res_t, 2, 0);
set_op(L, inst, 0, lhs);
set_op(L, inst, 1, rhs);
return BIR_MAKE_VAL(inst);
Expand Down Expand Up @@ -1522,15 +1579,15 @@ static uint32_t lower_expr(lower_t *L, uint32_t node)
/* ---- Math builtins: unary ---- */
{
static const struct { const char *n; uint16_t op; } mt1[] = {
{"sqrtf",BIR_SQRT},{"__fsqrt_rn",BIR_SQRT},
{"sqrtf",BIR_SQRT},{"sqrt",BIR_SQRT},{"__fsqrt_rn",BIR_SQRT},
{"rsqrtf",BIR_RSQ},{"__frsqrt_rn",BIR_RSQ},
{"__frcp_rn",BIR_RCP},
{"exp2f",BIR_EXP2},{"log2f",BIR_LOG2},{"__log2f",BIR_LOG2},
{"fabsf",BIR_FABS},{"fabs",BIR_FABS},
{"floorf",BIR_FLOOR},{"ceilf",BIR_CEIL},
{"truncf",BIR_FTRUNC},{"roundf",BIR_RNDNE},{"rintf",BIR_RNDNE},
};
for (int mi = 0; mi < 15; mi++) {
for (int mi = 0; mi < 16; mi++) {
if (strcmp(cname, mt1[mi].n) != 0) continue;
uint32_t an = ND(L, callee_n)->next_sibling;
uint32_t v = lower_expr(L, an);
Expand Down
34 changes: 31 additions & 3 deletions src/nvidia/emit.c
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,11 @@ static void em_inst(nv_module_t *nv, const nv_minst_t *I)
case NV_MOV_F64:
nv_apnd(nv, "mov.f64 ");
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
em_opnd(nv, &I->ops[1]);
/* PTX f64 immediates need 0dXXXX hex format, not bare int */
if (I->ops[1].kind == NV_MOP_IMM && I->ops[1].imm == 0)
nv_apnd(nv, "0d0000000000000000");
else
em_opnd(nv, &I->ops[1]);
break;
case NV_MOV_PRED:
nv_apnd(nv, "mov.pred ");
Expand All @@ -489,6 +493,16 @@ static void em_inst(nv_module_t *nv, const nv_minst_t *I)
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
em_opnd(nv, &I->ops[1]);
break;
case NV_CVT_S32_F64:
nv_apnd(nv, "cvt.rzi.s32.f64 ");
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
em_opnd(nv, &I->ops[1]);
break;
case NV_CVT_U32_F64:
nv_apnd(nv, "cvt.rzi.u32.f64 ");
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
em_opnd(nv, &I->ops[1]);
break;
case NV_CVT_F32_U32:
nv_apnd(nv, "cvt.rn.f32.u32 ");
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
Expand Down Expand Up @@ -544,6 +558,16 @@ static void em_inst(nv_module_t *nv, const nv_minst_t *I)
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
em_opnd(nv, &I->ops[1]);
break;
case NV_CVT_F64_S32:
nv_apnd(nv, "cvt.rn.f64.s32 ");
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
em_opnd(nv, &I->ops[1]);
break;
case NV_CVT_F64_U32:
nv_apnd(nv, "cvt.rn.f64.u32 ");
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
em_opnd(nv, &I->ops[1]);
break;
case NV_CVT_F32_F16:
nv_apnd(nv, "cvt.f32.f16 ");
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", ");
Expand Down Expand Up @@ -610,23 +634,27 @@ static void em_inst(nv_module_t *nv, const nv_minst_t *I)
}

/* ---- Loads/Stores: Local ---- */
case NV_LD_LOC_U32: case NV_LD_LOC_U64: case NV_LD_LOC_F32: {
case NV_LD_LOC_U32: case NV_LD_LOC_U64:
case NV_LD_LOC_F32: case NV_LD_LOC_F64: {
const char *tsuf;
switch (I->op) {
case NV_LD_LOC_U64: tsuf = ".u64"; break;
case NV_LD_LOC_F32: tsuf = ".f32"; break;
case NV_LD_LOC_F64: tsuf = ".f64"; break;
default: tsuf = ".u32"; break;
}
nv_apnd(nv, "ld.local%s ", tsuf);
em_opnd(nv, &I->ops[0]); nv_apnd(nv, ", [");
em_opnd(nv, &I->ops[1]); nv_apnd(nv, "]");
break;
}
case NV_ST_LOC_U32: case NV_ST_LOC_U64: case NV_ST_LOC_F32: {
case NV_ST_LOC_U32: case NV_ST_LOC_U64:
case NV_ST_LOC_F32: case NV_ST_LOC_F64: {
const char *tsuf;
switch (I->op) {
case NV_ST_LOC_U64: tsuf = ".u64"; break;
case NV_ST_LOC_F32: tsuf = ".f32"; break;
case NV_ST_LOC_F64: tsuf = ".f64"; break;
default: tsuf = ".u32"; break;
}
nv_apnd(nv, "st.local%s [", tsuf);
Expand Down
28 changes: 20 additions & 8 deletions src/nvidia/isel.c
Original file line number Diff line number Diff line change
Expand Up @@ -560,10 +560,20 @@ static void is_cvt(uint32_t idx, const bir_inst_t *I)

uint16_t op;
switch (I->op) {
case BIR_FPTOSI: op = NV_CVT_S32_F32; break;
case BIR_FPTOUI: op = NV_CVT_U32_F32; break;
case BIR_SITOFP: op = NV_CVT_F32_S32; break;
case BIR_UITOFP: op = NV_CVT_F32_U32; break;
case BIR_FPTOSI: {
uint32_t si0 = BIR_VAL_INDEX(I->operands[0]);
uint8_t srf = (si0 < S.bir->num_insts) ? bir_rfile(S.bir->insts[si0].type) : NV_RF_F32;
op = (srf == NV_RF_F64) ? NV_CVT_S32_F64 : NV_CVT_S32_F32;
break;
}
case BIR_FPTOUI: {
uint32_t si0 = BIR_VAL_INDEX(I->operands[0]);
uint8_t srf = (si0 < S.bir->num_insts) ? bir_rfile(S.bir->insts[si0].type) : NV_RF_F32;
op = (srf == NV_RF_F64) ? NV_CVT_U32_F64 : NV_CVT_U32_F32;
break;
}
case BIR_SITOFP: op = (bir_rfile(I->type) == NV_RF_F64) ? NV_CVT_F64_S32 : NV_CVT_F32_S32; break;
case BIR_UITOFP: op = (bir_rfile(I->type) == NV_RF_F64) ? NV_CVT_F64_U32 : NV_CVT_F32_U32; break;
case BIR_FPTRUNC: op = NV_CVT_F32_F64; break;
case BIR_FPEXT: op = NV_CVT_F64_F32; break;
case BIR_ZEXT: op = NV_CVT_U64_U32; break;
Expand Down Expand Up @@ -610,7 +620,8 @@ static void is_load(uint32_t idx, const bir_inst_t *I)
op = (drf == NV_RF_F32) ? NV_LD_SHR_F32 : NV_LD_SHR_U32;
break;
case BIR_AS_PRIVATE:
op = (drf == NV_RF_F32) ? NV_LD_LOC_F32 :
op = (drf == NV_RF_F64) ? NV_LD_LOC_F64 :
(drf == NV_RF_F32) ? NV_LD_LOC_F32 :
(drf == NV_RF_U64) ? NV_LD_LOC_U64 : NV_LD_LOC_U32;
break;
default: /* global */
Expand Down Expand Up @@ -659,7 +670,8 @@ static void is_store(const bir_inst_t *I)
op = (vrf == NV_RF_F32) ? NV_ST_SHR_F32 : NV_ST_SHR_U32;
break;
case BIR_AS_PRIVATE:
op = (vrf == NV_RF_F32) ? NV_ST_LOC_F32 :
op = (vrf == NV_RF_F64) ? NV_ST_LOC_F64 :
(vrf == NV_RF_F32) ? NV_ST_LOC_F32 :
(vrf == NV_RF_U64) ? NV_ST_LOC_U64 : NV_ST_LOC_U32;
break;
default:
Expand Down Expand Up @@ -1034,9 +1046,9 @@ static void is_math(uint32_t idx, const bir_inst_t *I)
op = (I->op == BIR_SIN) ? NV_SIN_F32 : NV_COS_F32;
break;
}
case BIR_EXP2: op = NV_EX2_F32; break;
case BIR_EXP2: op = NV_EX2_F32; break; /* PTX ex2 is f32 only */
case BIR_LOG2: op = NV_LG2_F32; break;
case BIR_FABS: op = NV_ABS_F32; break;
case BIR_FABS: op = (rf == NV_RF_F64) ? NV_ABS_F64 : NV_ABS_F32; break;
case BIR_FLOOR: op = NV_FLOOR_F32; break;
case BIR_CEIL: op = NV_CEIL_F32; break;
case BIR_FTRUNC: op = NV_TRUNC_F32; break;
Expand Down
10 changes: 6 additions & 4 deletions src/nvidia/nvidia.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,13 +66,15 @@ typedef enum {
NV_MOV_PRED,

/* Conversions */
NV_CVT_U32_F32, NV_CVT_S32_F32, /* fptosi/fptoui */
NV_CVT_U32_F32, NV_CVT_S32_F32, /* fptosi/fptoui (f32 src) */
NV_CVT_U32_F64, NV_CVT_S32_F64, /* fptosi/fptoui (f64 src) */
NV_CVT_F32_U32, NV_CVT_F32_S32, /* uitofp/sitofp */
NV_CVT_F32_F64, NV_CVT_F64_F32, /* fptrunc/fpext */
NV_CVT_U64_U32, NV_CVT_S64_S32, /* zext/sext to 64 */
NV_CVT_U32_U64, /* trunc 64->32 */
NV_CVT_U64_F64, NV_CVT_S64_F64, /* fp64->int64 */
NV_CVT_F64_U64, NV_CVT_F64_S64, /* int64->fp64 */
NV_CVT_F64_U32, NV_CVT_F64_S32, /* int32->fp64 */
NV_CVT_F32_F16, NV_CVT_F16_F32, /* half conversions */

/* Loads / stores — global */
Expand All @@ -89,9 +91,9 @@ typedef enum {

/* Loads / stores — local (scratch / alloca) */
NV_LD_LOC_U32, NV_LD_LOC_U64,
NV_LD_LOC_F32,
NV_LD_LOC_F32, NV_LD_LOC_F64,
NV_ST_LOC_U32, NV_ST_LOC_U64,
NV_ST_LOC_F32,
NV_ST_LOC_F32, NV_ST_LOC_F64,

/* Parameter loads */
NV_LD_PARAM_U32, NV_LD_PARAM_U64,
Expand Down Expand Up @@ -126,7 +128,7 @@ typedef enum {
NV_RCP_F32, /* rcp.approx.f32 */
NV_SIN_F32, /* sin.approx.f32 */
NV_COS_F32, /* cos.approx.f32 */
NV_EX2_F32, /* ex2.approx.f32 */
NV_EX2_F32, /* ex2.approx.f32 (no f64 in PTX) */
NV_LG2_F32, /* lg2.approx.f32 */
NV_FLOOR_F32, /* cvt.rmi.f32.f32 (floor) */
NV_CEIL_F32, /* cvt.rpi.f32.f32 (ceil) */
Expand Down
Loading