Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions xls/dslx/ir_convert/get_conversion_records.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
: *type_info_->GetImportedTypeInfo(node->owner());
}

absl::Status HandleConditional(const Conditional* expr) override {
if (expr->IsConst()) {
XLS_ASSIGN_OR_RETURN(InterpValue test_value,
type_info_->GetConstExpr(expr->test()));
if (test_value.IsTrue()) {
XLS_RETURN_IF_ERROR(DefaultHandler(expr->consequent()));
} else {
XLS_RETURN_IF_ERROR(DefaultHandler(ToExprNode(expr->alternate())));
}
return absl::OkStatus();
}

return DefaultHandler(expr);
}

// Generates a conversion record for the given function if it is a real
// function (not parametric or compiler-derived) that has no incoming calls
// known to `type_info_`. Also traverses such functions to ensure that
Expand Down
60 changes: 60 additions & 0 deletions xls/dslx/ir_convert/ir_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,66 @@ fn main() -> u32 {
ExpectIr(converted);
}

TEST_F(IrConverterTest, ConstConditionalProcScoped) {
constexpr std::string_view program = R"(
proc Multiply {
input: chan<u32> in;
output: chan<u32> out;

init {}

config(input: chan<u32> in, output: chan<u32> out) {
(input, output)
}

next(state: ()) {
let (tok, req) = recv(join(), input);
let data = req * u32:2;
let tok = send(tok, output, data);
}
}

proc Passthrough {
input: chan<u32> in;
output: chan<u32> out;

init {}

config(input: chan<u32> in, output: chan<u32> out) {
(input, output)
}

next(state: ()) {
let (tok, req) = recv(join(), input);
let tok = send(tok, output, req);
}
}

const CONFIG = u32:31;

proc Top {
init {}

config(req_r: chan<u32> in, resp_s: chan<u32> out) {
const if CONFIG <= u32:27 {
spawn Passthrough(req_r, resp_s);
} else {
spawn Multiply(req_r, resp_s);
};
()
}

next(state: ()) { state }
}
)";

ConvertOptions options;
options.lower_to_proc_scoped_channels = true;
XLS_ASSERT_OK_AND_ASSIGN(std::string converted,
ConvertOneFunctionForTest(program, "Top", options));
ExpectIr(converted);
}

TEST_F(IrConverterTest, ConstantsWithConditionalsPlusStuff) {
constexpr std::string_view program =
R"(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package test_module

file_number 0 "test_module.x"

proc __test_module__Multiply_0_next<_input: bits[32] in, _output: bits[32] out>(__state: (), init={()}) {
chan_interface _input(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface _output(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
after_all.5: token = after_all(id=5)
literal.3: bits[1] = literal(value=1, id=3)
receive.6: (token, bits[32]) = receive(after_all.5, predicate=literal.3, channel=_input, id=6)
req: bits[32] = tuple_index(receive.6, index=1, id=9, pos=[(0,12,16)])
literal.10: bits[32] = literal(value=2, id=10, pos=[(0,13,23)])
tok: token = tuple_index(receive.6, index=0, id=8, pos=[(0,12,11)])
data: bits[32] = umul(req, literal.10, id=11, pos=[(0,13,17)])
__state: () = state_read(state_element=__state, id=2)
tuple.13: () = tuple(id=13, pos=[(0,11,19)])
__token: token = literal(value=token, id=1)
tuple.4: () = tuple(id=4, pos=[(0,8,6)])
tuple_index.7: token = tuple_index(receive.6, index=0, id=7)
tok__1: token = send(tok, data, predicate=literal.3, channel=_output, id=12)
next_value.14: () = next_value(param=__state, value=tuple.13, id=14)
}

top proc __test_module__Top_0_next<_req_r: bits[32] in, _resp_s: bits[32] out>(__state: (), init={()}) {
chan_interface _req_r(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
chan_interface _resp_s(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
proc_instantiation __test_module__Multiply_0_next_inst(_req_r, _resp_s, proc=__test_module__Multiply_0_next)
__state: () = state_read(state_element=__state, id=16)
__token: token = literal(value=token, id=15)
literal.17: bits[1] = literal(value=1, id=17)
tuple.18: () = tuple(id=18, pos=[(0,42,13)])
tuple.19: () = tuple(id=19, pos=[(0,45,6)])
next_value.20: () = next_value(param=__state, value=__state, id=20)
}