From 14975aa77862ee019bfc5035358cf49f0471df56 Mon Sep 17 00:00:00 2001 From: Erin Moore Date: Mon, 26 Jan 2026 13:01:43 -0800 Subject: [PATCH] Properly handle lambda parameters from context. * Add a way to mark whether a parameter is captured and track the node it captures. * Use the captured node type directly intead of including in parametrics. * Ignore the captured nodes when resolving the function type. PiperOrigin-RevId: 861320585 --- xls/dslx/frontend/ast.cc | 13 ++++++- xls/dslx/frontend/ast.h | 11 ++++++ xls/dslx/frontend/ast_cloner.cc | 7 +++- xls/dslx/frontend/ast_cloner_test.cc | 3 +- xls/dslx/frontend/parser.cc | 34 ++++++++++++++----- xls/dslx/frontend/parser.h | 4 +-- .../inference_table_converter_impl.cc | 3 +- .../type_system_v2/populate_table_visitor.cc | 14 +++++++- .../type_system_v2/type_annotation_utils.cc | 3 ++ .../typecheck_module_v2_test.cc | 4 +-- 10 files changed, 79 insertions(+), 17 deletions(-) diff --git a/xls/dslx/frontend/ast.cc b/xls/dslx/frontend/ast.cc index eb07e51330..4078be3ac4 100644 --- a/xls/dslx/frontend/ast.cc +++ b/xls/dslx/frontend/ast.cc @@ -1065,7 +1065,8 @@ Param::Param(Module* owner, NameDef* name_def, TypeAnnotation* type_annotation) : AstNode(owner), name_def_(name_def), type_annotation_(type_annotation), - span_(name_def_->span().start(), type_annotation_->span().limit()) {} + span_(name_def_->span().start(), type_annotation_->span().limit()), + context_node_(nullptr) {} Param::~Param() = default; @@ -2366,6 +2367,16 @@ std::vector Function::GetFreeParametricKeys() const { return results; } +int Function::GetNumCapturedParams() const { + int num_captured_params = 0; + for (const Param* param : params_) { + if (param->IsCaptured()) { + ++num_captured_params; + } + } + return num_captured_params; +} + // -- class TestFunction TestFunction::~TestFunction() = default; diff --git a/xls/dslx/frontend/ast.h b/xls/dslx/frontend/ast.h index 469c6c9488..23a474f1c1 100644 --- a/xls/dslx/frontend/ast.h +++ b/xls/dslx/frontend/ast.h @@ -2130,10 +2130,19 @@ class Param : public AstNode { const std::string& identifier() const { return name_def_->identifier(); } std::optional GetSpan() const override { return span_; } + bool IsCaptured() const { return context_node_ != nullptr; } + void set_context_node(AstNode* context_node) { context_node_ = context_node; } + AstNode* context_node() const { return context_node_; } + private: NameDef* name_def_; TypeAnnotation* type_annotation_; Span span_; + + // If this parameter is captured from context, this is the node that is + // captured. May only be used for lambda functions; will be null in other + // cases. + AstNode* context_node_; }; #define XLS_DSLX_UNOP_KIND_EACH(X) \ @@ -2568,6 +2577,8 @@ class Function : public AstNode { parametric_bindings_.back()->span().limit()); } + int GetNumCapturedParams() const; + private: Span span_; NameDef* name_def_; diff --git a/xls/dslx/frontend/ast_cloner.cc b/xls/dslx/frontend/ast_cloner.cc index d7a01f597a..ef6ccf3d47 100644 --- a/xls/dslx/frontend/ast_cloner.cc +++ b/xls/dslx/frontend/ast_cloner.cc @@ -671,9 +671,14 @@ class AstCloner : public AstNodeVisitor { absl::Status HandleParam(const Param* n) override { XLS_RETURN_IF_ERROR(VisitChildren(n)); - old_to_new_[n] = module(n)->Make( + Param* new_param = module(n)->Make( down_cast(old_to_new_.at(n->name_def())), down_cast(old_to_new_.at(n->type_annotation()))); + if (n->context_node() != nullptr) { + XLS_RETURN_IF_ERROR(ReplaceOrVisit(n->context_node())); + new_param->set_context_node(old_to_new_.at(n->context_node())); + } + old_to_new_[n] = new_param; return absl::OkStatus(); } diff --git a/xls/dslx/frontend/ast_cloner_test.cc b/xls/dslx/frontend/ast_cloner_test.cc index 590c73c1ac..9c89a48a8b 100644 --- a/xls/dslx/frontend/ast_cloner_test.cc +++ b/xls/dslx/frontend/ast_cloner_test.cc @@ -144,7 +144,8 @@ fn main() -> u32 { TEST(AstClonerTest, Lambda) { constexpr std::string_view kProgram = R"(fn main() -> u32[10] { - let ARR = map(range(0, 10), |i: u32| -> u32 { 2 * i }); + let a = u32:0; + let ARR = map(range(0, 10), |i: u32, a: u32| -> u32 { a * i }); ARR })"; diff --git a/xls/dslx/frontend/parser.cc b/xls/dslx/frontend/parser.cc index f71265ada2..4ce4a16fb1 100644 --- a/xls/dslx/frontend/parser.cc +++ b/xls/dslx/frontend/parser.cc @@ -270,8 +270,10 @@ absl::StatusOr Parser::ParseLambda(Bindings& bindings) { VLOG(5) << "ParseLambda @ " << start_pos; XLS_ASSIGN_OR_RETURN(const Token* peek, PeekToken()); std::vector parametrics; + Bindings lambda_bindings(&bindings); const auto& missing_annotation_generator = - [&](const Span& span) -> absl::StatusOr { + [&](const Span& span, + std::string_view param_name) -> absl::StatusOr { // For lambdas, we allow an implicit type annotation. Treat this as a new // generic type and create a name_def to use to reference it. This must be // added as a parametric binding to the function. For example: @@ -284,6 +286,12 @@ absl::StatusOr Parser::ParseLambda(Bindings& bindings) { // 2 * i // } TypeAnnotation* gta = module_->Make(span); + // If already bound, this parameter is a captured variable from context and + // we can use the generic type directly without adding to the parametric + // bindings. + if (bindings.HasName(param_name)) { + return gta; + } NameDef* generic_name_def = module_->Make( span, absl::Substitute("lambda_param_type_$0_at_$1", parametrics.size(), @@ -299,22 +307,31 @@ absl::StatusOr Parser::ParseLambda(Bindings& bindings) { std::vector params; if (peek->kind() == TokenKind::kBar) { XLS_ASSIGN_OR_RETURN( - params, ParseParamsInternal(bindings, TokenKind::kBar, TokenKind::kBar, - missing_annotation_generator)); + params, + ParseParamsInternal(lambda_bindings, TokenKind::kBar, TokenKind::kBar, + missing_annotation_generator)); } else { XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kDoubleBar)); } + for (auto p : params) { + std::optional captured_node = + bindings.ResolveNode(p->identifier()); + if (captured_node.has_value()) { + p->set_context_node(ToAstNode(*captured_node)); + } + } XLS_ASSIGN_OR_RETURN(bool dropped_arrow, TryDropToken(TokenKind::kArrow)); TypeAnnotation* return_type = nullptr; if (dropped_arrow) { - XLS_ASSIGN_OR_RETURN(return_type, ParseTypeAnnotation(bindings)); + XLS_ASSIGN_OR_RETURN(return_type, ParseTypeAnnotation(lambda_bindings)); } else { - XLS_ASSIGN_OR_RETURN( - return_type, missing_annotation_generator(Span(start_pos, GetPos()))); + XLS_ASSIGN_OR_RETURN(return_type, missing_annotation_generator( + Span(start_pos, GetPos()), "")); } - XLS_ASSIGN_OR_RETURN(StatementBlock * body, ParseBlockExpression(bindings)); + XLS_ASSIGN_OR_RETURN(StatementBlock * body, + ParseBlockExpression(lambda_bindings)); Span sp = Span(start_pos, GetPos()); NameDef* fn_name_def = module_->Make(sp, "lambda_fn", nullptr); Function* fn = @@ -3917,7 +3934,8 @@ absl::StatusOr Parser::ParseParam( } else { XLS_ASSIGN_OR_RETURN(bool peek_is_colon, PeekTokenIs(TokenKind::kColon)); if (!peek_is_colon && missing_annotation_generator) { - XLS_ASSIGN_OR_RETURN(type, missing_annotation_generator(name->span())); + XLS_ASSIGN_OR_RETURN( + type, missing_annotation_generator(name->span(), name->identifier())); } else { XLS_RETURN_IF_ERROR( DropTokenOrError(TokenKind::kColon, /*start=*/nullptr, diff --git a/xls/dslx/frontend/parser.h b/xls/dslx/frontend/parser.h index d69bf74a3b..d8ebe0abe5 100644 --- a/xls/dslx/frontend/parser.h +++ b/xls/dslx/frontend/parser.h @@ -504,8 +504,8 @@ class Parser : public TokenParser { absl::StatusOr ParseConditionalNode( Bindings& bindings, ExprRestrictions restrictions, bool is_const = true); - using AnnotationGeneratorFn = - std::function(const Span&)>; + using AnnotationGeneratorFn = std::function( + const Span&, std::string_view)>; // Parse a parameter. If `missing_annotation_generator` is provided, it will // be called if the parameter is missing a type annotation; the returned type diff --git a/xls/dslx/type_system_v2/inference_table_converter_impl.cc b/xls/dslx/type_system_v2/inference_table_converter_impl.cc index b2fdc71a81..33abe168bd 100644 --- a/xls/dslx/type_system_v2/inference_table_converter_impl.cc +++ b/xls/dslx/type_system_v2/inference_table_converter_impl.cc @@ -653,7 +653,8 @@ class InferenceTableConverterImpl : public InferenceTableConverter, file_table_); } const int formal_param_count_without_self = - (function->params().size() - (function->IsMethod() ? 1 : 0)); + (function->params().size() - (function->IsMethod() ? 1 : 0) - + function->GetNumCapturedParams()); if (invocation->args().size() != formal_param_count_without_self) { // Note that the eventual unification of the signature would also catch // this, but this redundant check ensures that an arg count mismatch error diff --git a/xls/dslx/type_system_v2/populate_table_visitor.cc b/xls/dslx/type_system_v2/populate_table_visitor.cc index a71990a90c..4b2f94c98c 100644 --- a/xls/dslx/type_system_v2/populate_table_visitor.cc +++ b/xls/dslx/type_system_v2/populate_table_visitor.cc @@ -179,7 +179,19 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor, absl::Status HandleParam(const Param* node) override { VLOG(5) << "HandleParam: " << node->ToString(); - XLS_RETURN_IF_ERROR(DefineTypeVariableForVariableOrConstant(node).status()); + if (node->IsCaptured()) { + // If the parameter is captured from context, re-use the captured node + // type variable. + const NameRef* captured_type_variable = + *table_.GetTypeVariable(node->context_node()); + XLS_RET_CHECK(captured_type_variable != nullptr); + XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node, captured_type_variable)); + XLS_RETURN_IF_ERROR( + table_.SetTypeVariable(node->name_def(), captured_type_variable)); + } else { + XLS_RETURN_IF_ERROR( + DefineTypeVariableForVariableOrConstant(node).status()); + } return DefaultHandler(node); } diff --git a/xls/dslx/type_system_v2/type_annotation_utils.cc b/xls/dslx/type_system_v2/type_annotation_utils.cc index e8c97fcc11..173b275d38 100644 --- a/xls/dslx/type_system_v2/type_annotation_utils.cc +++ b/xls/dslx/type_system_v2/type_annotation_utils.cc @@ -267,6 +267,9 @@ FunctionTypeAnnotation* CreateFunctionTypeAnnotation(Module& module, std::vector param_types; param_types.reserve(function.params().size()); for (const Param* param : function.params()) { + if (param->IsCaptured()) { + continue; + } param_types.push_back(param->type_annotation()); } return module.Make( diff --git a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc index 532c1feb67..f40f329769 100644 --- a/xls/dslx/type_system_v2/typecheck_module_v2_test.cc +++ b/xls/dslx/type_system_v2/typecheck_module_v2_test.cc @@ -9254,7 +9254,7 @@ TEST(TypecheckV2Test, LambdaWithContextCapture) { R"( fn main() -> u32 { const X = u32:0; - let ARR = map(0..5, |i| -> u32 { X * i }); + let ARR = map(0..5, |i, X| -> u32 { X * i }); ARR[0] } )", @@ -9266,7 +9266,7 @@ TEST(TypecheckV2Test, LambdaWithContextParamsTypeMismatch) { R"( fn main() { const X = false; - let ARR = map(0..5, |i| -> u32 { X * i }); + let ARR = map(0..5, |i, X| -> u32 { X * i }); } )", TypecheckFails(HasSizeMismatch("uN[1]", "uN[32]")));