diff --git a/xls/dslx/bytecode/bytecode_emitter.cc b/xls/dslx/bytecode/bytecode_emitter.cc index 9bc353890b..cd72b7e896 100644 --- a/xls/dslx/bytecode/bytecode_emitter.cc +++ b/xls/dslx/bytecode/bytecode_emitter.cc @@ -1405,7 +1405,10 @@ absl::Status BytecodeEmitter::DestructureLet( } absl::Status BytecodeEmitter::HandleLambda(const Lambda* node) { - return absl::UnimplementedError("lambdas not yet supported"); + XLS_ASSIGN_OR_RETURN(auto result, + HandleNameDefInternal(node->function()->name_def())); + AddResult(node->span(), result); + return absl::OkStatus(); } absl::Status BytecodeEmitter::HandleLet(const Let* node) { @@ -1421,13 +1424,17 @@ absl::Status BytecodeEmitter::HandleLet(const Let* node) { absl::Status BytecodeEmitter::HandleNameRef(const NameRef* node) { XLS_ASSIGN_OR_RETURN(auto result, HandleNameRefInternal(node)); + AddResult(node->span(), result); + return absl::OkStatus(); +} + +void BytecodeEmitter::AddResult( + const Span span, std::variant result) { if (std::holds_alternative(result)) { - Add(Bytecode::MakeLiteral(node->span(), std::get(result))); + Add(Bytecode::MakeLiteral(span, std::get(result))); } else { - Add(Bytecode::MakeLoad(node->span(), - std::get(result))); + Add(Bytecode::MakeLoad(span, std::get(result))); } - return absl::OkStatus(); } absl::StatusOr BytecodeEmitter::HandleExternRef( @@ -1498,40 +1505,41 @@ BytecodeEmitter::HandleNameRefInternal(const NameRef* node) { use_tree_entry != nullptr) { return HandleExternRef(*node, *name_def, *use_tree_entry); } + return HandleNameDefInternal(name_def); + }}, + any_name_def); +} - // Emit function and constant refs directly so that they can be - // stack elements without having to load slots with them. - if (auto* f = dynamic_cast(definer); f != nullptr) { - return InterpValue::MakeFunction( - InterpValue::UserFnData{f->owner(), f}); - } - if (auto* cd = dynamic_cast(name_def->definer()); - cd != nullptr) { - return type_info_->GetConstExpr(cd->value()); - } +absl::StatusOr> +BytecodeEmitter::HandleNameDefInternal(const NameDef* node) { + AstNode* definer = node->definer(); + // Emit function and constant refs directly so that they can be + // stack elements without having to load slots with them. + if (auto* f = dynamic_cast(definer); f != nullptr) { + return InterpValue::MakeFunction(InterpValue::UserFnData{f->owner(), f}); + } + if (auto* cd = dynamic_cast(definer); cd != nullptr) { + return type_info_->GetConstExpr(cd->value()); + } - // The value is either a local name or a parametric name. - if (namedef_to_slot_.contains(name_def)) { - int64_t slotno = namedef_to_slot_.at(name_def); - return Bytecode::SlotIndex(slotno); - } + // The value is either a local name or a parametric name. + if (namedef_to_slot_.contains(node)) { + int64_t slotno = namedef_to_slot_.at(node); + return Bytecode::SlotIndex(slotno); + } - if (caller_bindings_.has_value()) { - absl::flat_hash_map bindings_map = - caller_bindings_.value().ToMap(); - if (bindings_map.contains(name_def->identifier())) { - return caller_bindings_.value().ToMap().at( - name_def->identifier()); - } - } + if (caller_bindings_.has_value()) { + absl::flat_hash_map bindings_map = + caller_bindings_.value().ToMap(); + if (bindings_map.contains(node->identifier())) { + return caller_bindings_.value().ToMap().at(node->identifier()); + } + } - return absl::InternalError(absl::StrCat( - "BytecodeEmitter could not find slot or binding for name: ", - name_def->ToString(), " @ ", - name_def->span().ToString(file_table()), - " stack: ", GetSymbolizedStackTraceAsString())); - }}, - any_name_def); + return absl::InternalError( + absl::StrCat("BytecodeEmitter could not find slot or binding for name: ", + node->ToString(), " @ ", node->span().ToString(file_table()), + " stack: ", GetSymbolizedStackTraceAsString())); } absl::Status BytecodeEmitter::HandleNumber(const Number* node) { diff --git a/xls/dslx/bytecode/bytecode_emitter.h b/xls/dslx/bytecode/bytecode_emitter.h index 66accc232c..4a315d5f9a 100644 --- a/xls/dslx/bytecode/bytecode_emitter.h +++ b/xls/dslx/bytecode/bytecode_emitter.h @@ -100,6 +100,8 @@ class BytecodeEmitter : public ExprVisitor { // Adds the given bytecode to the program. void Add(Bytecode bytecode) { bytecode_.push_back(std::move(bytecode)); } + void AddResult(const Span span, + std::variant result); absl::Status HandleArray(const Array* node) override; absl::Status HandleAttr(const Attr* node) override; absl::Status HandleBinop(const Binop* node) override; @@ -125,6 +127,8 @@ class BytecodeEmitter : public ExprVisitor { absl::StatusOr> HandleNameRefInternal(const NameRef* node); + absl::StatusOr> + HandleNameDefInternal(const NameDef* node); absl::StatusOr HandleExternRef(const NameRef& name_ref, const NameDef& name_def, diff --git a/xls/dslx/frontend/ast.h b/xls/dslx/frontend/ast.h index 469c6c9488..d075acd0ca 100644 --- a/xls/dslx/frontend/ast.h +++ b/xls/dslx/frontend/ast.h @@ -2596,9 +2596,7 @@ class Function : public AstNode { // Example: `let squares = map(range(u32:0, u32:5), |x| { x * x });` // // Attributes: -// * params: The explicit parameters of the lambda. -// * return_type: The return type of the lambda. -// * body: The body of the lambda. +// * function: A Function that represents the lambda behavior. class Lambda : public Expr { public: Lambda(Module* owner, Span span, Function* function); diff --git a/xls/dslx/frontend/parser.cc b/xls/dslx/frontend/parser.cc index fe37fbb9a9..a31c24a6c7 100644 --- a/xls/dslx/frontend/parser.cc +++ b/xls/dslx/frontend/parser.cc @@ -321,6 +321,7 @@ absl::StatusOr Parser::ParseLambda(Bindings& bindings) { module_->Make(sp, fn_name_def, parametrics, params, return_type, body, FunctionTag::kLambda, /*is_public=*/false, /*is_stub=*/false); + fn_name_def->set_definer(fn); return module_->Make(sp, fn); } diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc index 532c1feb67..1626969597 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc @@ -9209,22 +9209,21 @@ const F = Foo { a: C }; HasSpan(8, 19, 8, 20))); } -// TODO: erinzmoore - Add `const_assert!` statements to lambda tests once -// constexpr is supported for lambdas. TEST(TypecheckV2Test, LambdaWithExplicitTypes) { EXPECT_THAT(R"( -const M = 0..6; - -const ARR = map(M, | i: u16 | -> u16 { 2 * i }); +const M = u16:0..6; +const ARR = map(M, | i: u16 | -> u16 { u16:2 * i }); +const_assert!(ARR[1] == u16:2); )", - TypecheckSucceeds(AllOf(HasNodeWithType("M", "uN[3][6]"), - HasNodeWithType("ARR", "uN[16][6]")))); + TypecheckSucceeds(HasNodeWithType("ARR", "uN[16][6]"))); } TEST(TypecheckV2Test, LambdaWithImplicitParam) { EXPECT_THAT( R"( -const ARR = map(0..6, | i | -> u16 { i }); +const ARR = map(u16:0..6, | i | -> u16 { i }); +const_assert!(ARR[1] == u16:1); +const_assert!(ARR[5] == u16:5); )", TypecheckSucceeds(HasNodeWithType("ARR", "uN[16][6]"))); } @@ -9233,8 +9232,9 @@ TEST(TypecheckV2Test, LambdaWithMultipleParams) { EXPECT_THAT( R"( fn main() -> u32 { - (|i, j| -> u32 {i * j})(2, 4) + (|i, j| -> u32 {i * j})(u32:2, u32:4) } +const_assert!(main() == 8); )", TypecheckSucceeds(HasNodeWithType("main", "() -> uN[32]"))); } @@ -9249,13 +9249,15 @@ fn main() -> u32 { TypecheckFails(HasSizeMismatch("bool", "u32"))); } +// TODO: Add `const_assert!` statements once we fully support lambdas with +// captured variables. TEST(TypecheckV2Test, LambdaWithContextCapture) { EXPECT_THAT( R"( fn main() -> u32 { - const X = u32:0; - let ARR = map(0..5, |i| -> u32 { X * i }); - ARR[0] + const X = u32:8; + let ARR = map(u32:0..5, |i| -> u32 { X * i }); + ARR[4] } )", TypecheckSucceeds(HasNodeWithType("ARR", "uN[32][5]"))); @@ -9272,11 +9274,11 @@ fn main() { TypecheckFails(HasSizeMismatch("uN[1]", "uN[32]"))); } -// TODO: Support lambdas in constant_collector. -TEST(TypecheckV2Test, DISABLED_LambdaConstEval) { +// TODO: Fully support lambdas with captured variables. +TEST(TypecheckV2Test, DISABLED_LambdaGeneratedValueAsType) { EXPECT_THAT(R"( const X = u32:3; -const ARR = map(0..5, |i: u32| -> u32 { X * i }); +const ARR = map(u32:0..5, |i: u32| -> u32 { X * i }); const TEST = uN[ARR[1]]:0; )", TypecheckSucceeds(AllOf(HasNodeWithType("ARR", "uN[32][5]"),