Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
978f75f
test: simple class
calebeden Jan 6, 2025
af28873
test: class keyword
calebeden Jan 6, 2025
e2fc8b3
feat(Lexer): class keyword
calebeden Jan 6, 2025
76b7348
refactor(Parser): prepare module-level parsing for classes
calebeden Jan 6, 2025
4a0a0ca
feat(Parser): classes
calebeden Jan 6, 2025
27eeec3
fix: test error message
calebeden Jan 6, 2025
02dc63c
fix: class field syntax
calebeden Jan 7, 2025
4ddaf1d
feat: class definitions
calebeden Jan 7, 2025
f3473fc
test: class definition without keyword
calebeden Jan 7, 2025
6542cca
test: constructor keyword
calebeden Jan 20, 2025
5ffd68a
feat(Lexer): constructor keyword
calebeden Jan 20, 2025
565b55f
test: impl block
calebeden Jan 20, 2025
323dc21
test: impl keyword
calebeden Jan 20, 2025
60dbfd9
feat(Lexer): impl keyword
calebeden Jan 20, 2025
45bcdaf
test: classes should not directly contain methods
calebeden Jan 20, 2025
e920577
feat(Parser): impl blocks
calebeden Jan 20, 2025
94273b6
feat(SemanticAnalyzer): impl blocks
calebeden Jan 20, 2025
c0bc6dd
feat(CCodeGenerator): impl blocks
calebeden Jan 20, 2025
fdf0e0d
test: static method
calebeden Jan 20, 2025
ac252c9
feat: static methods
calebeden Jan 20, 2025
fc035c4
test: constructor
calebeden Jan 20, 2025
fa435c9
feat: constructors
calebeden Jan 20, 2025
db215eb
test: object field access
calebeden Jan 20, 2025
a15c6ce
feat: object field access
calebeden Jan 20, 2025
625d263
test: instance method
calebeden Jan 20, 2025
58aa7f3
feat: instance methods
calebeden Jan 20, 2025
c1748fe
feat: change class fields to not use let statement syntax
calebeden Jan 20, 2025
c66bdba
test: nested object field access
calebeden Jan 20, 2025
6b915f6
fix: add assignment operator for class types
calebeden Jan 20, 2025
6387d02
test: object method access without intermediate from constructor
calebeden Jan 20, 2025
3dd5f47
fix: nested object field and method references
calebeden Jan 21, 2025
fba9f9f
refactor(tests): clean up tests for classes
calebeden Mar 23, 2026
45f4240
Merge branch 'main' into classes
calebeden Mar 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
242 changes: 218 additions & 24 deletions src/ast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,25 @@ Expression &FunctionCallExpression::getFunction() {
return *function;
}

void FunctionCallExpression::addFirstArgument(std::unique_ptr<Expression> argument) {
arguments.insert(arguments.begin(), std::move(argument));
}

void FunctionCallExpression::forEachArgument(
const std::function<void(Expression &)> &argumentHandler) {
for (auto &argument : arguments) {
argumentHandler(*argument);
}
}

void FunctionCallExpression::setVariant(FunctionVariant variant) {
this->variant = variant;
}

FunctionVariant FunctionCallExpression::getVariant() const {
return variant;
}

void FunctionCallExpression::accept(ASTVisitor &visitor) {
visitor.visit(*this);
}
Expand Down Expand Up @@ -174,6 +186,14 @@ int BlockExpression::getSymbolType(std::string_view symbol) {
return std::get<0>(symbols[symbol]);
}

void BlockExpression::setSymbolType(std::string_view symbol, int typeID) {
if (symbols.find(symbol) == symbols.end()) {
std::cerr << "Symbol not found";
exit(EXIT_FAILURE);
}
std::get<0>(symbols[symbol]) = typeID;
}

SymbolSource BlockExpression::getSymbolSource(std::string_view symbol) {
if (symbols.find(symbol) == symbols.end()) {
return SymbolSource::Unknown;
Expand Down Expand Up @@ -266,6 +286,40 @@ IfElseExpression::IfElseExpression(std::unique_ptr<Expression> condition,
thenBlock(std::move(thenBlock)), elseExpression(std::move(elseExpression)) {
}

PathExpression::PathExpression(std::vector<std::unique_ptr<SymbolExpression>> path)
: Expression(Slice::merge(path[0]->getSlice(), path.back()->getSlice())),
path(std::move(path)) {
}

void PathExpression::forEachSymbol(
const std::function<void(SymbolExpression &)> &symbolHandler) {
for (auto &symbol : path) {
symbolHandler(*symbol);
}
}

void PathExpression::accept(ASTVisitor &visitor) {
visitor.visit(*this);
}

FieldAccessExpression::FieldAccessExpression(std::unique_ptr<Expression> object,
std::unique_ptr<Expression> field)
: Expression(Slice::merge(object->getSlice(), field->getSlice())),
object(std::move(object)), field(std::move(field)) {
}

Expression &FieldAccessExpression::getObject() {
return *object;
}

Expression &FieldAccessExpression::getField() {
return *field;
}

void FieldAccessExpression::accept(ASTVisitor &visitor) {
visitor.visit(*this);
}

WhileExpression::WhileExpression(const Keyword &whileKeyword,
std::unique_ptr<Expression> condition, std::unique_ptr<BlockExpression> body)
: Expression(Slice::merge(whileKeyword.s, body->getSlice())),
Expand Down Expand Up @@ -333,6 +387,12 @@ LetStatement::LetStatement(const Keyword &let, std::unique_ptr<Symbol> symbol,
equalSign(std::move(equalSign)), expression(std::move(expression)) {
}

LetStatement::LetStatement(std::unique_ptr<Symbol> symbol,
std::unique_ptr<Symbol> typeAnnotation, Punctuation *semicolon)
: Statement(Slice::merge(symbol->s, semicolon->s)), symbol(std::move(symbol)),
typeAnnotation(std::move(typeAnnotation)), isFieldDeclaration(true) {
}

LetStatement::LetStatement(std::unique_ptr<Symbol> symbol,
std::unique_ptr<Expression> expression)
: Statement(Slice("", "", 0, 0)), symbol(std::move(symbol)), typeAnnotation(nullptr),
Expand Down Expand Up @@ -363,22 +423,28 @@ int LetStatement::getSymbolTypeID() const {
return symbolTypeID;
}

bool LetStatement::getIsFieldDeclaration() const {
return isFieldDeclaration;
}

void LetStatement::accept(ASTVisitor &visitor) {
visitor.visit(*this);
}

Function::Function(
std::vector<std::pair<std::unique_ptr<Symbol>, std::unique_ptr<Symbol>>> parameters,
std::unique_ptr<Symbol> returnTypeAnnotation, std::unique_ptr<BlockExpression> body)
std::unique_ptr<Symbol> returnTypeAnnotation, std::unique_ptr<BlockExpression> body,
FunctionVariant variant)
: parameters(std::move(parameters)),
returnTypeAnnotation(std::move(returnTypeAnnotation)), body(std::move(body)) {
returnTypeAnnotation(std::move(returnTypeAnnotation)), body(std::move(body)),
variant(variant) {
}

Function::Function(
std::vector<std::pair<std::unique_ptr<Symbol>, std::unique_ptr<Symbol>>> parameters,
std::unique_ptr<BlockExpression> body)
std::unique_ptr<BlockExpression> body, FunctionVariant variant)
: parameters(std::move(parameters)), returnTypeAnnotation(nullptr),
body(std::move(body)) {
body(std::move(body)), variant(variant) {
}

void Function::forEachParameter(
Expand All @@ -404,27 +470,79 @@ void Function::setTypeID(int typeID) {
this->typeID = typeID;
}

void Function::setVariant(FunctionVariant variant) {
this->variant = variant;
}

FunctionVariant Function::getVariant() const {
return variant;
}

void Function::accept(ASTVisitor &visitor) {
visitor.visit(*this);
}

Type::Type(int id, int parentID, std::string_view name)
: id(id), parentID(parentID), name(name) {
Class::Class(std::vector<std::unique_ptr<LetStatement>> fieldDeclarations)
: fieldDeclarations(std::move(fieldDeclarations)) {
}

Class::Class(std::unique_ptr<BlockExpression> scope) : scope(std::move(scope)) {
}

void Class::forEachFieldDeclaration(
const std::function<void(LetStatement &)> &fieldHandler) {
for (auto &field : fieldDeclarations) {
fieldHandler(*field);
}
}

BlockExpression &Class::getScope() {
return *scope;
}

void Class::accept(ASTVisitor &visitor) {
visitor.visit(*this);
}

Impl::Impl(std::unordered_map<std::string_view, std::unique_ptr<Function>> methods)
: methods(std::move(methods)) {
}

void Impl::forEachMethod(
const std::function<void(std::string_view, Function &)> &methodHandler) {
for (auto &[name, method] : methods) {
methodHandler(name, *method);
}
}

Function *Impl::getMethod(std::string_view name) {
if (methods.find(name) == methods.end()) {
return nullptr;
}
return methods[name].get();
}

void Impl::accept(ASTVisitor &visitor) {
visitor.visit(*this);
}

Type::Type(int id, int parentID, std::string_view name, bool isClass)
: id(id), parentID(parentID), isClass(isClass), name(name) {
}

Module::Module(std::filesystem::path source) : source(std::move(source)) {
insertType("()");
insertType("!");
insertType("i8");
insertType("i16");
insertType("i32");
insertType("i64");
insertType("u8");
insertType("u16");
insertType("u32");
insertType("u64");
insertType("bool");
insertType("char");
insertType("()", false);
insertType("!", false);
insertType("i8", false);
insertType("i16", false);
insertType("i32", false);
insertType("i64", false);
insertType("u8", false);
insertType("u16", false);
insertType("u32", false);
insertType("u64", false);
insertType("bool", false);
insertType("char", false);
}

Module::Module(const Module &module)
Expand All @@ -445,25 +563,50 @@ void Module::forEachFunction(
}
}

void Module::addClass(std::unique_ptr<Symbol> name, std::unique_ptr<Class> cls,
bool isBuiltin) {
classes[name->s.contents] = {std::move(cls), isBuiltin};
}

void Module::addImpl(std::unique_ptr<Symbol> className, std::unique_ptr<Impl> impl,
bool isBuiltin) {
impls[className->s.contents] = {std::move(impl), isBuiltin};
}

void Module::forEachClass(
const std::function<void(std::string_view, Class &, bool)> &classHandler) {
for (auto &[name, cls] : classes) {
classHandler(name, *std::get<0>(cls), std::get<1>(cls));
}
}

void Module::forEachImpl(
const std::function<void(std::string_view, Impl &, bool)> &implHandler) {
for (auto &[name, impl] : impls) {
implHandler(name, *std::get<0>(impl), std::get<1>(impl));
}
}

Type Module::getType(std::string_view typeName) {
if (typeTableByName.find(typeName) == typeTableByName.end()) {
return Type(-1, -1, "");
return Type(-1, -1, "", false);
}
return typeTableByName.at(typeName);
}

Type Module::getType(int id) {
if (typeTableByID.find(id) == typeTableByID.end()) {
return Type(-1, -1, "");
return Type(-1, -1, "", false);
}
return typeTableByID.at(id);
}

void Module::insertType(std::string_view typeName) {
void Module::insertType(std::string_view typeName, bool isClass) {
if (typeTableByName.find(typeName) == typeTableByName.end()) {
typeTableByName.insert({typeName, Type(typeTableByName.size(), -1, typeName)});
typeTableByName.insert(
{typeName, Type(typeTableByName.size(), -1, typeName, isClass)});
typeTableByID.insert(
{typeTableByID.size(), Type(typeTableByID.size(), -1, typeName)});
{typeTableByID.size(), Type(typeTableByID.size(), -1, typeName, isClass)});
} else {
std::cerr << "Type already exists";
exit(EXIT_FAILURE);
Expand All @@ -490,7 +633,7 @@ Type Module::getCommonTypeAncestor(int type1, int type2) {
if (type2 == getType("!").id) {
return getType(type1);
}
return Type(-1, -1, "");
return Type(-1, -1, "", false);
}

void Module::addUnaryOperator(Operator::Type op, int operandType, int resultType) {
Expand Down Expand Up @@ -539,6 +682,33 @@ Function *Module::getFunction(std::string_view name) {
return std::get<0>(functions[name]).get();
}

Impl *Module::getImpl(std::string_view name) {
if (impls.find(name) == impls.end()) {
return nullptr;
}
return std::get<0>(impls[name]).get();
}

std::pair<std::string_view, Impl *> Module::getImpl(int typeID) {
for (auto &[name, impl] : impls) {
int implTypeID = getType(name).id;
if (implTypeID == typeID) {
return {name, std::get<0>(impl).get()};
}
}
return {{}, nullptr};
}

std::pair<std::string_view, Class *> Module::getClass(int typeID) {
for (auto &[name, cls] : classes) {
int implTypeID = getType(name).id;
if (implTypeID == typeID) {
return {name, std::get<0>(cls).get()};
}
}
return {{}, nullptr};
}

std::filesystem::path Module::getSource() {
return source;
}
Expand Down Expand Up @@ -619,6 +789,23 @@ void ASTPrinter::visit(ParenthesizedExpression &node) {
node.getExpression().accept(*this);
}

void ASTPrinter::visit(PathExpression &node) {
bool first = true;
node.forEachSymbol([this, &first](SymbolExpression &symbol) {
symbol.accept(*this);
if (!first) {
std::cerr << "::";
}
first = false;
});
}

void ASTPrinter::visit(FieldAccessExpression &node) {
node.getObject().accept(*this);
std::cerr << '.';
node.getField().accept(*this);
}

void ASTPrinter::visit(IfElseExpression &node) {
std::cerr << "if ";
node.getCondition().accept(*this);
Expand Down Expand Up @@ -670,6 +857,12 @@ void ASTPrinter::visit(LetStatement &node) {
void ASTPrinter::visit([[maybe_unused]] Function &node) {
}

void ASTPrinter::visit([[maybe_unused]] Class &node) {
}

void ASTPrinter::visit([[maybe_unused]] Impl &node) {
}

void ASTPrinter::visit(Module &node) {
node.forEachFunction(
[this](std::string_view name, Function &function, bool /*unused*/) {
Expand All @@ -678,4 +871,5 @@ void ASTPrinter::visit(Module &node) {
function.getBody().accept(*this);
std::cerr << '\n';
});
// TODO classes
}
Loading
Loading