diff --git a/xls/dslx/ir_convert/get_conversion_records.cc b/xls/dslx/ir_convert/get_conversion_records.cc index 1c52f7e8a3..376428ac0e 100644 --- a/xls/dslx/ir_convert/get_conversion_records.cc +++ b/xls/dslx/ir_convert/get_conversion_records.cc @@ -114,6 +114,21 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault { : *type_info_->GetImportedTypeInfo(node->owner()); } + absl::Status HandleConditional(const Conditional* expr) override { + if (expr->IsConst()) { + XLS_ASSIGN_OR_RETURN(InterpValue test_value, + type_info_->GetConstExpr(expr->test())); + if (test_value.IsTrue()) { + XLS_RETURN_IF_ERROR(DefaultHandler(expr->consequent())); + } else { + XLS_RETURN_IF_ERROR(DefaultHandler(ToExprNode(expr->alternate()))); + } + return absl::OkStatus(); + } + + return DefaultHandler(expr); + } + // Generates a conversion record for the given function if it is a real // function (not parametric or compiler-derived) that has no incoming calls // known to `type_info_`. Also traverses such functions to ensure that diff --git a/xls/dslx/ir_convert/ir_converter_test.cc b/xls/dslx/ir_convert/ir_converter_test.cc index da1e1d864f..a0ebe3c34d 100644 --- a/xls/dslx/ir_convert/ir_converter_test.cc +++ b/xls/dslx/ir_convert/ir_converter_test.cc @@ -927,6 +927,66 @@ fn main() -> u32 { ExpectIr(converted); } +TEST_F(IrConverterTest, ConstConditionalProcScoped) { + constexpr std::string_view program = R"( + proc Multiply { + input: chan in; + output: chan out; + + init {} + + config(input: chan in, output: chan out) { + (input, output) + } + + next(state: ()) { + let (tok, req) = recv(join(), input); + let data = req * u32:2; + let tok = send(tok, output, data); + } + } + + proc Passthrough { + input: chan in; + output: chan out; + + init {} + + config(input: chan in, output: chan out) { + (input, output) + } + + next(state: ()) { + let (tok, req) = recv(join(), input); + let tok = send(tok, output, req); + } + } + + const CONFIG = u32:31; + + proc Top { + init {} + + config(req_r: chan in, resp_s: chan out) { + const if CONFIG <= u32:27 { + spawn Passthrough(req_r, resp_s); + } else { + spawn Multiply(req_r, resp_s); + }; + () + } + + next(state: ()) { state } + } + )"; + + ConvertOptions options; + options.lower_to_proc_scoped_channels = true; + XLS_ASSERT_OK_AND_ASSIGN(std::string converted, + ConvertOneFunctionForTest(program, "Top", options)); + ExpectIr(converted); +} + TEST_F(IrConverterTest, ConstantsWithConditionalsPlusStuff) { constexpr std::string_view program = R"( diff --git a/xls/dslx/ir_convert/testdata/ir_converter_test_ConstConditionalProcScoped.ir b/xls/dslx/ir_convert/testdata/ir_converter_test_ConstConditionalProcScoped.ir new file mode 100644 index 0000000000..fd6316d823 --- /dev/null +++ b/xls/dslx/ir_convert/testdata/ir_converter_test_ConstConditionalProcScoped.ir @@ -0,0 +1,34 @@ +package test_module + +file_number 0 "test_module.x" + +proc __test_module__Multiply_0_next<_input: bits[32] in, _output: bits[32] out>(__state: (), init={()}) { + chan_interface _input(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface _output(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + after_all.5: token = after_all(id=5) + literal.3: bits[1] = literal(value=1, id=3) + receive.6: (token, bits[32]) = receive(after_all.5, predicate=literal.3, channel=_input, id=6) + req: bits[32] = tuple_index(receive.6, index=1, id=9, pos=[(0,12,16)]) + literal.10: bits[32] = literal(value=2, id=10, pos=[(0,13,23)]) + tok: token = tuple_index(receive.6, index=0, id=8, pos=[(0,12,11)]) + data: bits[32] = umul(req, literal.10, id=11, pos=[(0,13,17)]) + __state: () = state_read(state_element=__state, id=2) + tuple.13: () = tuple(id=13, pos=[(0,11,19)]) + __token: token = literal(value=token, id=1) + tuple.4: () = tuple(id=4, pos=[(0,8,6)]) + tuple_index.7: token = tuple_index(receive.6, index=0, id=7) + tok__1: token = send(tok, data, predicate=literal.3, channel=_output, id=12) + next_value.14: () = next_value(param=__state, value=tuple.13, id=14) +} + +top proc __test_module__Top_0_next<_req_r: bits[32] in, _resp_s: bits[32] out>(__state: (), init={()}) { + chan_interface _req_r(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + chan_interface _resp_s(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none) + proc_instantiation __test_module__Multiply_0_next_inst(_req_r, _resp_s, proc=__test_module__Multiply_0_next) + __state: () = state_read(state_element=__state, id=16) + __token: token = literal(value=token, id=15) + literal.17: bits[1] = literal(value=1, id=17) + tuple.18: () = tuple(id=18, pos=[(0,42,13)]) + tuple.19: () = tuple(id=19, pos=[(0,45,6)]) + next_value.20: () = next_value(param=__state, value=__state, id=20) +}