Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 43 additions & 35 deletions xls/dslx/bytecode/bytecode_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,10 @@ absl::Status BytecodeEmitter::DestructureLet(
}

absl::Status BytecodeEmitter::HandleLambda(const Lambda* node) {
return absl::UnimplementedError("lambdas not yet supported");
XLS_ASSIGN_OR_RETURN(auto result,
HandleNameDefInternal(node->function()->name_def()));
AddResult(node->span(), result);
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleLet(const Let* node) {
Expand All @@ -1421,13 +1424,17 @@ absl::Status BytecodeEmitter::HandleLet(const Let* node) {

absl::Status BytecodeEmitter::HandleNameRef(const NameRef* node) {
XLS_ASSIGN_OR_RETURN(auto result, HandleNameRefInternal(node));
AddResult(node->span(), result);
return absl::OkStatus();
}

void BytecodeEmitter::AddResult(
const Span span, std::variant<InterpValue, Bytecode::SlotIndex> result) {
if (std::holds_alternative<InterpValue>(result)) {
Add(Bytecode::MakeLiteral(node->span(), std::get<InterpValue>(result)));
Add(Bytecode::MakeLiteral(span, std::get<InterpValue>(result)));
} else {
Add(Bytecode::MakeLoad(node->span(),
std::get<Bytecode::SlotIndex>(result)));
Add(Bytecode::MakeLoad(span, std::get<Bytecode::SlotIndex>(result)));
}
return absl::OkStatus();
}

absl::StatusOr<InterpValue> BytecodeEmitter::HandleExternRef(
Expand Down Expand Up @@ -1498,40 +1505,41 @@ BytecodeEmitter::HandleNameRefInternal(const NameRef* node) {
use_tree_entry != nullptr) {
return HandleExternRef(*node, *name_def, *use_tree_entry);
}
return HandleNameDefInternal(name_def);
}},
any_name_def);
}

// Emit function and constant refs directly so that they can be
// stack elements without having to load slots with them.
if (auto* f = dynamic_cast<Function*>(definer); f != nullptr) {
return InterpValue::MakeFunction(
InterpValue::UserFnData{f->owner(), f});
}
if (auto* cd = dynamic_cast<ConstantDef*>(name_def->definer());
cd != nullptr) {
return type_info_->GetConstExpr(cd->value());
}
absl::StatusOr<std::variant<InterpValue, Bytecode::SlotIndex>>
BytecodeEmitter::HandleNameDefInternal(const NameDef* node) {
AstNode* definer = node->definer();
// Emit function and constant refs directly so that they can be
// stack elements without having to load slots with them.
if (auto* f = dynamic_cast<Function*>(definer); f != nullptr) {
return InterpValue::MakeFunction(InterpValue::UserFnData{f->owner(), f});
}
if (auto* cd = dynamic_cast<ConstantDef*>(definer); cd != nullptr) {
return type_info_->GetConstExpr(cd->value());
}

// The value is either a local name or a parametric name.
if (namedef_to_slot_.contains(name_def)) {
int64_t slotno = namedef_to_slot_.at(name_def);
return Bytecode::SlotIndex(slotno);
}
// The value is either a local name or a parametric name.
if (namedef_to_slot_.contains(node)) {
int64_t slotno = namedef_to_slot_.at(node);
return Bytecode::SlotIndex(slotno);
}

if (caller_bindings_.has_value()) {
absl::flat_hash_map<std::string, InterpValue> bindings_map =
caller_bindings_.value().ToMap();
if (bindings_map.contains(name_def->identifier())) {
return caller_bindings_.value().ToMap().at(
name_def->identifier());
}
}
if (caller_bindings_.has_value()) {
absl::flat_hash_map<std::string, InterpValue> bindings_map =
caller_bindings_.value().ToMap();
if (bindings_map.contains(node->identifier())) {
return caller_bindings_.value().ToMap().at(node->identifier());
}
}

return absl::InternalError(absl::StrCat(
"BytecodeEmitter could not find slot or binding for name: ",
name_def->ToString(), " @ ",
name_def->span().ToString(file_table()),
" stack: ", GetSymbolizedStackTraceAsString()));
}},
any_name_def);
return absl::InternalError(
absl::StrCat("BytecodeEmitter could not find slot or binding for name: ",
node->ToString(), " @ ", node->span().ToString(file_table()),
" stack: ", GetSymbolizedStackTraceAsString()));
}

absl::Status BytecodeEmitter::HandleNumber(const Number* node) {
Expand Down
4 changes: 4 additions & 0 deletions xls/dslx/bytecode/bytecode_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ class BytecodeEmitter : public ExprVisitor {

// Adds the given bytecode to the program.
void Add(Bytecode bytecode) { bytecode_.push_back(std::move(bytecode)); }
void AddResult(const Span span,
std::variant<InterpValue, Bytecode::SlotIndex> result);
absl::Status HandleArray(const Array* node) override;
absl::Status HandleAttr(const Attr* node) override;
absl::Status HandleBinop(const Binop* node) override;
Expand All @@ -125,6 +127,8 @@ class BytecodeEmitter : public ExprVisitor {

absl::StatusOr<std::variant<InterpValue, Bytecode::SlotIndex>>
HandleNameRefInternal(const NameRef* node);
absl::StatusOr<std::variant<InterpValue, Bytecode::SlotIndex>>
HandleNameDefInternal(const NameDef* node);

absl::StatusOr<InterpValue> HandleExternRef(const NameRef& name_ref,
const NameDef& name_def,
Expand Down
4 changes: 1 addition & 3 deletions xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -2596,9 +2596,7 @@ class Function : public AstNode {
// Example: `let squares = map(range(u32:0, u32:5), |x| { x * x });`
//
// Attributes:
// * params: The explicit parameters of the lambda.
// * return_type: The return type of the lambda.
// * body: The body of the lambda.
// * function: A Function that represents the lambda behavior.
class Lambda : public Expr {
public:
Lambda(Module* owner, Span span, Function* function);
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/frontend/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ absl::StatusOr<Lambda*> Parser::ParseLambda(Bindings& bindings) {
module_->Make<Function>(sp, fn_name_def, parametrics, params, return_type,
body, FunctionTag::kLambda,
/*is_public=*/false, /*is_stub=*/false);
fn_name_def->set_definer(fn);
return module_->Make<Lambda>(sp, fn);
}

Expand Down
32 changes: 17 additions & 15 deletions xls/dslx/type_system_v2/typecheck_module_v2_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9209,22 +9209,21 @@ const F = Foo { a: C };
HasSpan(8, 19, 8, 20)));
}

// TODO: erinzmoore - Add `const_assert!` statements to lambda tests once
// constexpr is supported for lambdas.
TEST(TypecheckV2Test, LambdaWithExplicitTypes) {
EXPECT_THAT(R"(
const M = 0..6;

const ARR = map(M, | i: u16 | -> u16 { 2 * i });
const M = u16:0..6;
const ARR = map(M, | i: u16 | -> u16 { u16:2 * i });
const_assert!(ARR[1] == u16:2);
)",
TypecheckSucceeds(AllOf(HasNodeWithType("M", "uN[3][6]"),
HasNodeWithType("ARR", "uN[16][6]"))));
TypecheckSucceeds(HasNodeWithType("ARR", "uN[16][6]")));
}

TEST(TypecheckV2Test, LambdaWithImplicitParam) {
EXPECT_THAT(
R"(
const ARR = map(0..6, | i | -> u16 { i });
const ARR = map(u16:0..6, | i | -> u16 { i });
const_assert!(ARR[1] == u16:1);
const_assert!(ARR[5] == u16:5);
)",
TypecheckSucceeds(HasNodeWithType("ARR", "uN[16][6]")));
}
Expand All @@ -9233,8 +9232,9 @@ TEST(TypecheckV2Test, LambdaWithMultipleParams) {
EXPECT_THAT(
R"(
fn main() -> u32 {
(|i, j| -> u32 {i * j})(2, 4)
(|i, j| -> u32 {i * j})(u32:2, u32:4)
}
const_assert!(main() == 8);
)",
TypecheckSucceeds(HasNodeWithType("main", "() -> uN[32]")));
}
Expand All @@ -9249,13 +9249,15 @@ fn main() -> u32 {
TypecheckFails(HasSizeMismatch("bool", "u32")));
}

// TODO: Add `const_assert!` statements once we fully support lambdas with
// captured variables.
TEST(TypecheckV2Test, LambdaWithContextCapture) {
EXPECT_THAT(
R"(
fn main() -> u32 {
const X = u32:0;
let ARR = map(0..5, |i| -> u32 { X * i });
ARR[0]
const X = u32:8;
let ARR = map(u32:0..5, |i| -> u32 { X * i });
ARR[4]
}
)",
TypecheckSucceeds(HasNodeWithType("ARR", "uN[32][5]")));
Expand All @@ -9272,11 +9274,11 @@ fn main() {
TypecheckFails(HasSizeMismatch("uN[1]", "uN[32]")));
}

// TODO: Support lambdas in constant_collector.
TEST(TypecheckV2Test, DISABLED_LambdaConstEval) {
// TODO: Fully support lambdas with captured variables.
TEST(TypecheckV2Test, DISABLED_LambdaGeneratedValueAsType) {
EXPECT_THAT(R"(
const X = u32:3;
const ARR = map(0..5, |i: u32| -> u32 { X * i });
const ARR = map(u32:0..5, |i: u32| -> u32 { X * i });
const TEST = uN[ARR[1]]:0;
)",
TypecheckSucceeds(AllOf(HasNodeWithType("ARR", "uN[32][5]"),
Expand Down
Loading