Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion xls/dslx/frontend/ast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,8 @@ Param::Param(Module* owner, NameDef* name_def, TypeAnnotation* type_annotation)
: AstNode(owner),
name_def_(name_def),
type_annotation_(type_annotation),
span_(name_def_->span().start(), type_annotation_->span().limit()) {}
span_(name_def_->span().start(), type_annotation_->span().limit()),
context_node_(nullptr) {}

Param::~Param() = default;

Expand Down Expand Up @@ -2366,6 +2367,16 @@ std::vector<std::string> Function::GetFreeParametricKeys() const {
return results;
}

int Function::GetNumCapturedParams() const {
int num_captured_params = 0;
for (const Param* param : params_) {
if (param->IsCaptured()) {
++num_captured_params;
}
}
return num_captured_params;
}

// -- class TestFunction

TestFunction::~TestFunction() = default;
Expand Down
11 changes: 11 additions & 0 deletions xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -2130,10 +2130,19 @@ class Param : public AstNode {
const std::string& identifier() const { return name_def_->identifier(); }
std::optional<Span> GetSpan() const override { return span_; }

bool IsCaptured() const { return context_node_ != nullptr; }
void set_context_node(AstNode* context_node) { context_node_ = context_node; }
AstNode* context_node() const { return context_node_; }

private:
NameDef* name_def_;
TypeAnnotation* type_annotation_;
Span span_;

// If this parameter is captured from context, this is the node that is
// captured. May only be used for lambda functions; will be null in other
// cases.
AstNode* context_node_;
};

#define XLS_DSLX_UNOP_KIND_EACH(X) \
Expand Down Expand Up @@ -2568,6 +2577,8 @@ class Function : public AstNode {
parametric_bindings_.back()->span().limit());
}

int GetNumCapturedParams() const;

private:
Span span_;
NameDef* name_def_;
Expand Down
7 changes: 6 additions & 1 deletion xls/dslx/frontend/ast_cloner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -671,9 +671,14 @@ class AstCloner : public AstNodeVisitor {

absl::Status HandleParam(const Param* n) override {
XLS_RETURN_IF_ERROR(VisitChildren(n));
old_to_new_[n] = module(n)->Make<Param>(
Param* new_param = module(n)->Make<Param>(
down_cast<NameDef*>(old_to_new_.at(n->name_def())),
down_cast<TypeAnnotation*>(old_to_new_.at(n->type_annotation())));
if (n->context_node() != nullptr) {
XLS_RETURN_IF_ERROR(ReplaceOrVisit(n->context_node()));
new_param->set_context_node(old_to_new_.at(n->context_node()));
}
old_to_new_[n] = new_param;
return absl::OkStatus();
}

Expand Down
3 changes: 2 additions & 1 deletion xls/dslx/frontend/ast_cloner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ fn main() -> u32 {

TEST(AstClonerTest, Lambda) {
constexpr std::string_view kProgram = R"(fn main() -> u32[10] {
let ARR = map(range(0, 10), |i: u32| -> u32 { 2 * i });
let a = u32:0;
let ARR = map(range(0, 10), |i: u32, a: u32| -> u32 { a * i });
ARR
})";

Expand Down
34 changes: 26 additions & 8 deletions xls/dslx/frontend/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,8 +270,10 @@ absl::StatusOr<Lambda*> Parser::ParseLambda(Bindings& bindings) {
VLOG(5) << "ParseLambda @ " << start_pos;
XLS_ASSIGN_OR_RETURN(const Token* peek, PeekToken());
std::vector<ParametricBinding*> parametrics;
Bindings lambda_bindings(&bindings);
const auto& missing_annotation_generator =
[&](const Span& span) -> absl::StatusOr<TypeAnnotation*> {
[&](const Span& span,
std::string_view param_name) -> absl::StatusOr<TypeAnnotation*> {
// For lambdas, we allow an implicit type annotation. Treat this as a new
// generic type and create a name_def to use to reference it. This must be
// added as a parametric binding to the function. For example:
Expand All @@ -284,6 +286,12 @@ absl::StatusOr<Lambda*> Parser::ParseLambda(Bindings& bindings) {
// 2 * i
// }
TypeAnnotation* gta = module_->Make<GenericTypeAnnotation>(span);
// If already bound, this parameter is a captured variable from context and
// we can use the generic type directly without adding to the parametric
// bindings.
if (bindings.HasName(param_name)) {
return gta;
}
NameDef* generic_name_def = module_->Make<NameDef>(
span,
absl::Substitute("lambda_param_type_$0_at_$1", parametrics.size(),
Expand All @@ -299,22 +307,31 @@ absl::StatusOr<Lambda*> Parser::ParseLambda(Bindings& bindings) {
std::vector<Param*> params;
if (peek->kind() == TokenKind::kBar) {
XLS_ASSIGN_OR_RETURN(
params, ParseParamsInternal(bindings, TokenKind::kBar, TokenKind::kBar,
missing_annotation_generator));
params,
ParseParamsInternal(lambda_bindings, TokenKind::kBar, TokenKind::kBar,
missing_annotation_generator));
} else {
XLS_RETURN_IF_ERROR(DropTokenOrError(TokenKind::kDoubleBar));
}
for (auto p : params) {
std::optional<BoundNode> captured_node =
bindings.ResolveNode(p->identifier());
if (captured_node.has_value()) {
p->set_context_node(ToAstNode(*captured_node));
}
}

XLS_ASSIGN_OR_RETURN(bool dropped_arrow, TryDropToken(TokenKind::kArrow));
TypeAnnotation* return_type = nullptr;
if (dropped_arrow) {
XLS_ASSIGN_OR_RETURN(return_type, ParseTypeAnnotation(bindings));
XLS_ASSIGN_OR_RETURN(return_type, ParseTypeAnnotation(lambda_bindings));
} else {
XLS_ASSIGN_OR_RETURN(
return_type, missing_annotation_generator(Span(start_pos, GetPos())));
XLS_ASSIGN_OR_RETURN(return_type, missing_annotation_generator(
Span(start_pos, GetPos()), ""));
}

XLS_ASSIGN_OR_RETURN(StatementBlock * body, ParseBlockExpression(bindings));
XLS_ASSIGN_OR_RETURN(StatementBlock * body,
ParseBlockExpression(lambda_bindings));
Span sp = Span(start_pos, GetPos());
NameDef* fn_name_def = module_->Make<NameDef>(sp, "lambda_fn", nullptr);
Function* fn =
Expand Down Expand Up @@ -3917,7 +3934,8 @@ absl::StatusOr<Param*> Parser::ParseParam(
} else {
XLS_ASSIGN_OR_RETURN(bool peek_is_colon, PeekTokenIs(TokenKind::kColon));
if (!peek_is_colon && missing_annotation_generator) {
XLS_ASSIGN_OR_RETURN(type, missing_annotation_generator(name->span()));
XLS_ASSIGN_OR_RETURN(
type, missing_annotation_generator(name->span(), name->identifier()));
} else {
XLS_RETURN_IF_ERROR(
DropTokenOrError(TokenKind::kColon, /*start=*/nullptr,
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/frontend/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,8 @@ class Parser : public TokenParser {
absl::StatusOr<Conditional*> ParseConditionalNode(
Bindings& bindings, ExprRestrictions restrictions, bool is_const = true);

using AnnotationGeneratorFn =
std::function<absl::StatusOr<TypeAnnotation*>(const Span&)>;
using AnnotationGeneratorFn = std::function<absl::StatusOr<TypeAnnotation*>(
const Span&, std::string_view)>;

// Parse a parameter. If `missing_annotation_generator` is provided, it will
// be called if the parameter is missing a type annotation; the returned type
Expand Down
3 changes: 2 additions & 1 deletion xls/dslx/type_system_v2/inference_table_converter_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,8 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
file_table_);
}
const int formal_param_count_without_self =
(function->params().size() - (function->IsMethod() ? 1 : 0));
(function->params().size() - (function->IsMethod() ? 1 : 0) -
function->GetNumCapturedParams());
if (invocation->args().size() != formal_param_count_without_self) {
// Note that the eventual unification of the signature would also catch
// this, but this redundant check ensures that an arg count mismatch error
Expand Down
14 changes: 13 additions & 1 deletion xls/dslx/type_system_v2/populate_table_visitor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,19 @@ class PopulateInferenceTableVisitor : public PopulateTableVisitor,

absl::Status HandleParam(const Param* node) override {
VLOG(5) << "HandleParam: " << node->ToString();
XLS_RETURN_IF_ERROR(DefineTypeVariableForVariableOrConstant(node).status());
if (node->IsCaptured()) {
// If the parameter is captured from context, re-use the captured node
// type variable.
const NameRef* captured_type_variable =
*table_.GetTypeVariable(node->context_node());
XLS_RET_CHECK(captured_type_variable != nullptr);
XLS_RETURN_IF_ERROR(table_.SetTypeVariable(node, captured_type_variable));
XLS_RETURN_IF_ERROR(
table_.SetTypeVariable(node->name_def(), captured_type_variable));
} else {
XLS_RETURN_IF_ERROR(
DefineTypeVariableForVariableOrConstant(node).status());
}
return DefaultHandler(node);
}

Expand Down
3 changes: 3 additions & 0 deletions xls/dslx/type_system_v2/type_annotation_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,9 @@ FunctionTypeAnnotation* CreateFunctionTypeAnnotation(Module& module,
std::vector<const TypeAnnotation*> param_types;
param_types.reserve(function.params().size());
for (const Param* param : function.params()) {
if (param->IsCaptured()) {
continue;
}
param_types.push_back(param->type_annotation());
}
return module.Make<FunctionTypeAnnotation>(
Expand Down
4 changes: 2 additions & 2 deletions xls/dslx/type_system_v2/typecheck_module_v2_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9254,7 +9254,7 @@ TEST(TypecheckV2Test, LambdaWithContextCapture) {
R"(
fn main() -> u32 {
const X = u32:0;
let ARR = map(0..5, |i| -> u32 { X * i });
let ARR = map(0..5, |i, X| -> u32 { X * i });
ARR[0]
}
)",
Expand All @@ -9266,7 +9266,7 @@ TEST(TypecheckV2Test, LambdaWithContextParamsTypeMismatch) {
R"(
fn main() {
const X = false;
let ARR = map(0..5, |i| -> u32 { X * i });
let ARR = map(0..5, |i, X| -> u32 { X * i });
}
)",
TypecheckFails(HasSizeMismatch("uN[1]", "uN[32]")));
Expand Down